Diffstat (limited to 'pkg/tcpip/transport')
53 files changed, 28604 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD new file mode 100644 index 000000000..7e5c79776 --- /dev/null +++ b/pkg/tcpip/transport/icmp/BUILD @@ -0,0 +1,40 @@ +load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "icmp_packet_list", + out = "icmp_packet_list.go", + package = "icmp", + prefix = "icmpPacket", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*icmpPacket", + "Linker": "*icmpPacket", + }, +) + +go_library( + name = "icmp", + srcs = [ + "endpoint.go", + "endpoint_state.go", + "icmp_packet_list.go", + "protocol.go", + ], + imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/sleep", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/ports", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/raw", + "//pkg/tcpip/transport/tcp", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go new file mode 100644 index 000000000..62d1acad4 --- /dev/null +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -0,0 +1,831 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package icmp + +import ( + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/ports" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +// +stateify savable +type icmpPacket struct { + icmpPacketEntry + senderAddress tcpip.FullAddress + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + timestamp int64 +} + +type endpointState int + +const ( + stateInitial endpointState = iota + stateBound + stateConnected + stateClosed +) + +// endpoint represents an ICMP endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. +// +// +stateify savable +type endpoint struct { + stack.TransportEndpointInfo + + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + waiterQueue *waiter.Queue + uniqueID uint64 + + // The following fields are used to manage the receive queue, and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList icmpPacketList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by the mu mutex. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + // shutdownFlags represent the current shutdown state of the endpoint. 
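+ // Close() forces them to ShutdownRead|ShutdownWrite; Shutdown() ORs in
+ // the caller's flags.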
+ shutdownFlags tcpip.ShutdownFlags + state endpointState + route stack.Route `state:"manual"` + ttl uint8 + stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner +} + +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + state: stateInitial, + uniqueID: s.UniqueID(), + }, nil +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + +// Close puts the endpoint in a closed state and frees all resources +// associated with it. +func (e *endpoint) Close() { + e.mu.Lock() + e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite + switch e.state { + case stateBound, stateConnected: + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.TransProto, e.ID, e, ports.Flags{}, 0 /* bindToDevice */) + } + + // Close the receive list and drain it. + e.rcvMu.Lock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } + e.rcvMu.Unlock() + + e.route.Release() + + // Update the state. + e.state = stateClosed + + e.mu.Unlock() + + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. +func (e *endpoint) ModerateRecvBuf(copied int) {} + +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + +// Read reads data from the endpoint. This method does not block if +// there is no data pending. +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + e.rcvMu.Lock() + + if e.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() + err = tcpip.ErrClosedForReceive + } + e.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + p := e.rcvList.Front() + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() + + e.rcvMu.Unlock() + + if addr != nil { + *addr = p.senderAddress + } + + return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil +} + +// prepareForWrite prepares the endpoint for sending data. In particular, it +// binds it if it's still in the initial state. To do so, it must first +// reacquire the mutex in exclusive mode. +// +// Returns true for retry if preparation should be retried. +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { + switch e.state { + case stateInitial: + case stateConnected: + return false, nil + + case stateBound: + if to == nil { + return false, tcpip.ErrDestinationRequired + } + return false, nil + default: + return false, tcpip.ErrInvalidEndpointState + } + + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // The state changed when we released the shared locked and re-acquired + // it in exclusive mode. Try again. + if e.state != stateInitial { + return true, nil + } + + // The state is still 'initial', so try to bind the endpoint. 
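+ // Binding to the zero address lets registerWithStack (via bindLocked)
+ // pick a free ephemeral ident to use as the local port.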
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil { + return false, err + } + + return true, nil +} + +// Write writes data to the endpoint's peer. This method does not block +// if the data cannot be written. +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, nil, tcpip.ErrInvalidOptionValue + } + + to := opts.To + + e.mu.RLock() + defer e.mu.RUnlock() + + // If we've shutdown with SHUT_WR we are in an invalid state for sending. + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { + return 0, nil, tcpip.ErrClosedForSend + } + + // Prepare for write. + for { + retry, err := e.prepareForWrite(to) + if err != nil { + return 0, nil, err + } + + if !retry { + break + } + } + + var route *stack.Route + if to == nil { + route = &e.route + + if route.IsResolutionRequired() { + // Promote lock to exclusive if using a shared route, + // given that it may need to change in Route.Resolve() + // call below. + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // Recheck state after lock was re-acquired. + if e.state != stateConnected { + return 0, nil, tcpip.ErrInvalidEndpointState + } + } + } else { + // Reject destination address if it goes through a different + // NIC than the endpoint was bound to. + nicID := to.NIC + if e.BindNICID != 0 { + if nicID != 0 && nicID != e.BindNICID { + return 0, nil, tcpip.ErrNoRoute + } + + nicID = e.BindNICID + } + + dst, netProto, err := e.checkV4MappedLocked(*to) + if err != nil { + return 0, nil, err + } + + // Find the endpoint. + r, err := e.stack.FindRoute(nicID, e.BindAddr, dst.Addr, netProto, false /* multicastLoop */) + if err != nil { + return 0, nil, err + } + defer r.Release() + + route = &r + } + + if route.IsResolutionRequired() { + if ch, err := route.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + v, err := p.FullPayload() + if err != nil { + return 0, nil, err + } + + switch e.NetProto { + case header.IPv4ProtocolNumber: + err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) + + case header.IPv6ProtocolNumber: + err = send6(route, e.ID.LocalPort, v, e.ttl) + } + + if err != nil { + return 0, nil, err + } + + return int64(len(v)), nil, nil +} + +// Peek only returns data from a single datagram, so do nothing here. +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// SetSockOpt sets a socket option. 
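+// It currently accepts and ignores all options.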
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return nil +} + +// SetSockOptBool sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return nil +} + +// SetSockOptInt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + + } + return nil +} + +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + e.rcvMu.Lock() + if !e.rcvList.Empty() { + p := e.rcvList.Front() + v = p.data.Size() + } + e.rcvMu.Unlock() + return v, nil + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSize + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil + + case tcpip.TTLOption: + e.rcvMu.Lock() + v := int(e.ttl) + e.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { + if len(data) < header.ICMPv4MinimumSize { + return tcpip.ErrInvalidEndpointState + } + + hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength())) + + icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize)) + copy(icmpv4, data) + // Set the ident to the user-specified port. Sequence number should + // already be set by the user. + icmpv4.SetIdent(ident) + data = data[header.ICMPv4MinimumSize:] + + // Linux performs these basic checks. + if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { + return tcpip.ErrInvalidEndpointState + } + + icmpv4.SetChecksum(0) + icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + + if ttl == 0 { + ttl = r.DefaultTTL() + } + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ + Header: hdr, + Data: data.ToVectorisedView(), + TransportHeader: buffer.View(icmpv4), + Owner: owner, + }) +} + +func send6(r *stack.Route, ident uint16, data buffer.View, ttl uint8) *tcpip.Error { + if len(data) < header.ICMPv6EchoMinimumSize { + return tcpip.ErrInvalidEndpointState + } + + hdr := buffer.NewPrependable(header.ICMPv6MinimumSize + int(r.MaxHeaderLength())) + + icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6MinimumSize)) + copy(icmpv6, data) + // Set the ident. Sequence number is provided by the user. 
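+ // Replies are demultiplexed back to this endpoint by this ident; see
+ // ParsePorts in protocol.go.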
+ icmpv6.SetIdent(ident) + data = data[header.ICMPv6MinimumSize:] + + if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { + return tcpip.ErrInvalidEndpointState + } + + dataVV := data.ToVectorisedView() + icmpv6.SetChecksum(header.ICMPv6Checksum(icmpv6, r.LocalAddress, r.RemoteAddress, dataVV)) + + if ttl == 0 { + ttl = r.DefaultTTL() + } + return r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS}, &stack.PacketBuffer{ + Header: hdr, + Data: dataVV, + TransportHeader: buffer.View(icmpv6), + }) +} + +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, false /* v6only */) + if err != nil { + return tcpip.FullAddress{}, 0, err + } + return unwrapped, netProto, nil +} + +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Connect connects the endpoint to its peer. Specifying a NIC is optional. +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + nicID := addr.NIC + localPort := uint16(0) + switch e.state { + case stateInitial: + case stateBound, stateConnected: + localPort = e.ID.LocalPort + if e.BindNICID == 0 { + break + } + + if nicID != 0 && nicID != e.BindNICID { + return tcpip.ErrInvalidEndpointState + } + + nicID = e.BindNICID + default: + return tcpip.ErrInvalidEndpointState + } + + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicID, e.BindAddr, addr.Addr, netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + id := stack.TransportEndpointID{ + LocalAddress: r.LocalAddress, + LocalPort: localPort, + RemoteAddress: r.RemoteAddress, + } + + // Even if we're connected, this endpoint can still be used to send + // packets on a different network protocol, so we register both even if + // v6only is set to false and this is an ipv6 endpoint. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + + id, err = e.registerWithStack(nicID, netProtos, id) + if err != nil { + return err + } + + e.ID = id + e.route = r.Clone() + e.RegisterNICID = nicID + + e.state = stateConnected + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// ConnectEndpoint is not supported. +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// Shutdown closes the read and/or write end of the endpoint connection +// to its peer. +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + e.shutdownFlags |= flags + + if e.state != stateConnected { + return tcpip.ErrNotConnected + } + + if flags&tcpip.ShutdownRead != 0 { + e.rcvMu.Lock() + wasClosed := e.rcvClosed + e.rcvClosed = true + e.rcvMu.Unlock() + + if !wasClosed { + e.waiterQueue.Notify(waiter.EventIn) + } + } + + return nil +} + +// Listen is not supported by UDP, it just fails. +func (*endpoint) Listen(int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept is not supported by UDP, it just fails. 
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if id.LocalPort != 0 { + // The endpoint already has a local port, just attempt to + // register it. + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindToDevice */) + return id, err + } + + // We need to find a port for the endpoint. + _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + id.LocalPort = p + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, e.TransProto, id, e, ports.Flags{}, 0 /* bindtodevice */) + switch err { + case nil: + return true, nil + case tcpip.ErrPortInUse: + return false, nil + default: + return false, err + } + }) + + return id, err +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { + // Don't allow binding once endpoint is not in the initial state + // anymore. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + + if len(addr.Addr) != 0 { + // A local address was specified, verify that it's valid. + if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { + return tcpip.ErrBadLocalAddress + } + } + + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err = e.registerWithStack(addr.NIC, netProtos, id) + if err != nil { + return err + } + + e.ID = id + e.RegisterNICID = addr.NIC + + // Mark endpoint as bound. + e.state = stateBound + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// Bind binds the endpoint to a specific local address and port. +// Specifying a NIC is optional. +func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + err := e.bindLocked(addr) + if err != nil { + return err + } + + e.BindNICID = addr.NIC + e.BindAddr = addr.Addr + + return nil +} + +// GetLocalAddress returns the address to which the endpoint is bound. +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + return tcpip.FullAddress{ + NIC: e.RegisterNICID, + Addr: e.ID.LocalAddress, + Port: e.ID.LocalPort, + }, nil +} + +// GetRemoteAddress returns the address to which the endpoint is connected. +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + NIC: e.RegisterNICID, + Addr: e.ID.RemoteAddress, + Port: e.ID.RemotePort, + }, nil +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine if the endpoint is readable if requested. 
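+ // A closed receive side also reports readability so that blocked
+ // readers wake up and observe ErrClosedForReceive.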
+ if (mask & waiter.EventIn) != 0 { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvMu.Unlock() + } + + return result +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + // Only accept echo replies. + switch e.NetProto { + case header.IPv4ProtocolNumber: + h := header.ICMPv4(pkt.TransportHeader) + if len(h) < header.ICMPv4MinimumSize || h.Type() != header.ICMPv4EchoReply { + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + case header.IPv6ProtocolNumber: + h := header.ICMPv6(pkt.TransportHeader) + if len(h) < header.ICMPv6MinimumSize || h.Type() != header.ICMPv6EchoReply { + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.MalformedPacketsReceived.Increment() + return + } + } + + e.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. + if !e.rcvReady || e.rcvClosed { + e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if e.rcvBufSize >= e.rcvBufSizeMax { + e.rcvMu.Unlock() + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return + } + + wasEmpty := e.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + packet := &icmpPacket{ + senderAddress: tcpip.FullAddress{ + NIC: r.NICID(), + Addr: id.RemoteAddress, + }, + } + + // ICMP socket's data includes ICMP header. + packet.data = pkt.TransportHeader.ToVectorisedView() + packet.data.Append(pkt.Data) + + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() + + packet.timestamp = e.stack.NowNanoseconds() + + e.rcvMu.Unlock() + e.stats.PacketsReceived.Increment() + // Notify any waiters that there's data to be read now. + if wasEmpty { + e.waiterQueue.Notify(waiter.EventIn) + } +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { +} + +// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't +// expose internal socket state. +func (e *endpoint) State() uint32 { + return 0 +} + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements stack.TransportEndpoint.Wait. +func (*endpoint) Wait() {} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go new file mode 100644 index 000000000..9d263c0ec --- /dev/null +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -0,0 +1,95 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package icmp + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// saveData saves icmpPacket.data field. +func (p *icmpPacket) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, + // which is not allowed by state framework (in-struct pointer). + return p.data.Clone(nil) +} + +// loadData loads icmpPacket.data field. +func (p *icmpPacket) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing p.views for data.views. + p.data = data +} + +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after savercvBufSizeMax(), which would have + // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + e.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (e *endpoint) saveRcvBufSizeMax() int { + max := e.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + e.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + e.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (e *endpoint) loadRcvBufSizeMax(max int) { + e.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (e *endpoint) afterLoad() { + stack.StackFromEnv.RegisterRestoredEndpoint(e) +} + +// Resume implements tcpip.ResumableEndpoint.Resume. +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s + + if e.state != stateBound && e.state != stateConnected { + return + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.ID.RemoteAddress, e.NetProto, false /* multicastLoop */) + if err != nil { + panic(err) + } + + e.ID.LocalAddress = e.route.LocalAddress + } else if len(e.ID.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.ID.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + e.ID, err = e.registerWithStack(e.RegisterNICID, []tcpip.NetworkProtocolNumber{e.NetProto}, e.ID) + if err != nil { + panic(err) + } +} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go new file mode 100644 index 000000000..74ef6541e --- /dev/null +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -0,0 +1,145 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport +// protocols for use in ping. To use it in the networking stack, this package +// must be added to the project, and activated on the stack by passing +// icmp.NewProtocol4() and/or icmp.NewProtocol6() as one of the transport +// protocols when calling stack.New(). Then endpoints can be created by passing +// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number +// when calling Stack.NewEndpoint(). +package icmp + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // ProtocolNumber4 is the ICMP protocol number. + ProtocolNumber4 = header.ICMPv4ProtocolNumber + + // ProtocolNumber6 is the IPv6-ICMP protocol number. + ProtocolNumber6 = header.ICMPv6ProtocolNumber +) + +// protocol implements stack.TransportProtocol. +type protocol struct { + number tcpip.TransportProtocolNumber +} + +// Number returns the ICMP protocol number. +func (p *protocol) Number() tcpip.TransportProtocolNumber { + return p.number +} + +func (p *protocol) netProto() tcpip.NetworkProtocolNumber { + switch p.number { + case ProtocolNumber4: + return header.IPv4ProtocolNumber + case ProtocolNumber6: + return header.IPv6ProtocolNumber + } + panic(fmt.Sprint("unknown protocol number: ", p.number)) +} + +// NewEndpoint creates a new icmp endpoint. It implements +// stack.TransportProtocol.NewEndpoint. +func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != p.netProto() { + return nil, tcpip.ErrUnknownProtocol + } + return newEndpoint(stack, netProto, p.number, waiterQueue) +} + +// NewRawEndpoint creates a new raw icmp endpoint. It implements +// stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != p.netProto() { + return nil, tcpip.ErrUnknownProtocol + } + return raw.NewEndpoint(stack, netProto, p.number, waiterQueue) +} + +// MinimumPacketSize returns the minimum valid icmp packet size. +func (p *protocol) MinimumPacketSize() int { + switch p.number { + case ProtocolNumber4: + return header.ICMPv4MinimumSize + case ProtocolNumber6: + return header.ICMPv6MinimumSize + } + panic(fmt.Sprint("unknown protocol number: ", p.number)) +} + +// ParsePorts in case of ICMP sets src to 0, dst to ICMP ID, and err to nil. 
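+//
+// The ident acts as the port for demultiplexing: an echo reply carrying
+// ident 0x1234 is delivered to the endpoint registered with local port
+// 0x1234, i.e. the one that sent the matching request.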
+func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + switch p.number { + case ProtocolNumber4: + hdr := header.ICMPv4(v) + return 0, hdr.Ident(), nil + case ProtocolNumber6: + hdr := header.ICMPv6(v) + return 0, hdr.Ident(), nil + } + panic(fmt.Sprint("unknown protocol number: ", p.number)) +} + +// HandleUnknownDestinationPacket handles packets targeted at this protocol but +// that don't match any existing endpoint. +func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) bool { + return true +} + +// SetOption implements stack.TransportProtocol.SetOption. +func (*protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements stack.TransportProtocol.Option. +func (*protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + // TODO(gvisor.dev/issue/170): Implement parsing of ICMP. + // + // Right now, the Parse() method is tied to enabled protocols passed into + // stack.New. This works for UDP and TCP, but we handle ICMP traffic even + // when netstack users don't pass ICMP as a supported protocol. + return false +} + +// NewProtocol4 returns an ICMPv4 transport protocol. +func NewProtocol4() stack.TransportProtocol { + return &protocol{ProtocolNumber4} +} + +// NewProtocol6 returns an ICMPv6 transport protocol. +func NewProtocol6() stack.TransportProtocol { + return &protocol{ProtocolNumber6} +} diff --git a/pkg/tcpip/transport/packet/BUILD b/pkg/tcpip/transport/packet/BUILD new file mode 100644 index 000000000..b989b1209 --- /dev/null +++ b/pkg/tcpip/transport/packet/BUILD @@ -0,0 +1,37 @@ +load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "packet_list", + out = "packet_list.go", + package = "packet", + prefix = "packet", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*packet", + "Linker": "*packet", + }, +) + +go_library( + name = "packet", + srcs = [ + "endpoint.go", + "endpoint_state.go", + "packet_list.go", + ], + imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/sleep", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go new file mode 100644 index 000000000..a8f8454dd --- /dev/null +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -0,0 +1,469 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +// Package packet provides the implementation of packet sockets (see +// packet(7)). Packet sockets allow applications to: +// +// * manually write and inspect link, network, and transport headers +// * receive all traffic of a given network protocol, or all protocols +// +// Packet sockets are similar to raw sockets, but provide even more power to +// users, letting them effectively talk directly to the network device. +// +// Packet sockets skip the input and output iptables chains. +package packet + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +// +stateify savable +type packet struct { + packetEntry + // data holds the actual packet data, including any headers and + // payload. + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // timestampNS is the unix time at which the packet was received. + timestampNS int64 + // senderAddr is the network address of the sender. + senderAddr tcpip.FullAddress +} + +// endpoint is the packet socket implementation of tcpip.Endpoint. It is legal +// to have goroutines make concurrent calls into the endpoint. +// +// Lock order: +// endpoint.mu +// endpoint.rcvMu +// +// +stateify savable +type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue + cooked bool + + // The following fields are used to manage the receive queue and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + closed bool + stats tcpip.TransportEndpointStats `state:"nosave"` + bound bool +} + +// NewEndpoint returns a new packet endpoint. +func NewEndpoint(s *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + }, + cooked: cooked, + netProto: netProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } + + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + ep.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + ep.rcvBufSizeMax = rs.Default + } + + if err := s.RegisterPacketEndpoint(0, netProto, ep); err != nil { + return nil, err + } + return ep, nil +} + +// Abort implements stack.TransportEndpoint.Abort. +func (ep *endpoint) Abort() { + ep.Close() +} + +// Close implements tcpip.Endpoint.Close. +func (ep *endpoint) Close() { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return + } + + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + + // Clear the receive list. + ep.rcvClosed = true + ep.rcvBufSize = 0 + for !ep.rcvList.Empty() { + ep.rcvList.Remove(ep.rcvList.Front()) + } + + ep.closed = true + ep.bound = false + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. 
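+// It is a no-op: packet sockets do not auto-tune their receive buffers.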
+func (ep *endpoint) ModerateRecvBuf(copied int) {} + +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + ep.rcvMu.Lock() + + // If there's no data to read, return that read would block or that the + // endpoint is closed. + if ep.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if ep.rcvClosed { + ep.stats.ReadErrors.ReadClosed.Increment() + err = tcpip.ErrClosedForReceive + } + ep.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + packet := ep.rcvList.Front() + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + + ep.rcvMu.Unlock() + + if addr != nil { + *addr = packet.senderAddr + } + + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil +} + +func (ep *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // TODO(b/129292371): Implement. + return 0, nil, tcpip.ErrInvalidOptionValue +} + +// Peek implements tcpip.Endpoint.Peek. +func (ep *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be +// disconnected, and this function always returns tpcip.ErrNotSupported. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Connect implements tcpip.Endpoint.Connect. Packet sockets cannot be +// connected, and this function always returnes tcpip.ErrNotSupported. +func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Shutdown implements tcpip.Endpoint.Shutdown. Packet sockets cannot be used +// with Shutdown, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Listen implements tcpip.Endpoint.Listen. Packet sockets cannot be used with +// Listen, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Listen(backlog int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept implements tcpip.Endpoint.Accept. Packet sockets cannot be used with +// Accept, and this function always returns tcpip.ErrNotSupported. +func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +// Bind implements tcpip.Endpoint.Bind. +func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + // TODO(gvisor.dev/issue/173): Add Bind support. + + // "By default, all packets of the specified protocol type are passed + // to a packet socket. To get packets only from a specific interface + // use bind(2) specifying an address in a struct sockaddr_ll to bind + // the packet socket to an interface. Fields used for binding are + // sll_family (should be AF_PACKET), sll_protocol, and sll_ifindex." + // - packet(7). + + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.bound { + return tcpip.ErrAlreadyBound + } + + // Unregister endpoint with all the nics. + ep.stack.UnregisterPacketEndpoint(0, ep.netProto, ep) + + // Bind endpoint to receive packets from specific interface. + if err := ep.stack.RegisterPacketEndpoint(addr.NIC, ep.netProto, ep); err != nil { + return err + } + + ep.bound = true + + return nil +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. 
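+// It is not supported for packet sockets in this implementation.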
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, tcpip.ErrNotSupported +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + // Even a connected socket doesn't return a remote address. + return tcpip.FullAddress{}, tcpip.ErrNotConnected +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine whether the endpoint is readable. + if (mask & waiter.EventIn) != 0 { + ep.rcvMu.Lock() + if !ep.rcvList.Empty() || ep.rcvClosed { + result |= waiter.EventIn + } + ep.rcvMu.Unlock() + } + + return result +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. Packet sockets cannot be +// used with SetSockOpt, and this function always returns +// tcpip.ErrNotSupported. +func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (ep *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := ep.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + if v > ss.Max { + v = ss.Max + } + if v < ss.Min { + v = ss.Min + } + ep.mu.Lock() + ep.sndBufSizeMax = v + ep.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := ep.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) + } + if v > rs.Max { + v = rs.Max + } + if v < rs.Min { + v = rs.Min + } + ep.rcvMu.Lock() + ep.rcvBufSizeMax = v + ep.rcvMu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (ep *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + return false, tcpip.ErrNotSupported +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (ep *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { + switch opt { + case tcpip.ReceiveQueueSizeOption: + v := 0 + ep.rcvMu.Lock() + if !ep.rcvList.Empty() { + p := ep.rcvList.Front() + v = p.data.Size() + } + ep.rcvMu.Unlock() + return v, nil + + case tcpip.SendBufferSizeOption: + ep.mu.Lock() + v := ep.sndBufSizeMax + ep.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + v := ep.rcvBufSizeMax + ep.rcvMu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// HandlePacket implements stack.PacketEndpoint.HandlePacket. +func (ep *endpoint) HandlePacket(nicID tcpip.NICID, localAddr tcpip.LinkAddress, netProto tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) { + ep.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. 
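+ // Also drop it if the endpoint has been closed for receiving.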
+ if ep.rcvClosed { + ep.rcvMu.Unlock() + ep.stack.Stats().DroppedPackets.Increment() + ep.stats.ReceiveErrors.ClosedReceiver.Increment() + return + } + + if ep.rcvBufSize >= ep.rcvBufSizeMax { + ep.rcvMu.Unlock() + ep.stack.Stats().DroppedPackets.Increment() + ep.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return + } + + wasEmpty := ep.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + var packet packet + // TODO(b/129292371): Return network protocol. + if len(pkt.LinkHeader) > 0 { + // Get info directly from the ethernet header. + hdr := header.Ethernet(pkt.LinkHeader) + packet.senderAddr = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(hdr.SourceAddress()), + } + } else { + // Guess the would-be ethernet header. + packet.senderAddr = tcpip.FullAddress{ + NIC: nicID, + Addr: tcpip.Address(localAddr), + } + } + + if ep.cooked { + // Cooked packets can simply be queued. + packet.data = pkt.Data + } else { + // Raw packets need their ethernet headers prepended before + // queueing. + var linkHeader buffer.View + if len(pkt.LinkHeader) == 0 { + // We weren't provided with an actual ethernet header, + // so fake one. + ethFields := header.EthernetFields{ + SrcAddr: tcpip.LinkAddress([]byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00}), + DstAddr: localAddr, + Type: netProto, + } + fakeHeader := make(header.Ethernet, header.EthernetMinimumSize) + fakeHeader.Encode(ðFields) + linkHeader = buffer.View(fakeHeader) + } else { + linkHeader = append(buffer.View(nil), pkt.LinkHeader...) + } + combinedVV := linkHeader.ToVectorisedView() + combinedVV.Append(pkt.Data) + packet.data = combinedVV + } + packet.timestampNS = ep.stack.NowNanoseconds() + + ep.rcvList.PushBack(&packet) + ep.rcvBufSize += packet.data.Size() + + ep.rcvMu.Unlock() + ep.stats.PacketsReceived.Increment() + // Notify waiters that there's data to be read. + if wasEmpty { + ep.waiterQueue.Notify(waiter.EventIn) + } +} + +// State implements socket.Socket.State. +func (ep *endpoint) State() uint32 { + return 0 +} + +// Info returns a copy of the endpoint info. +func (ep *endpoint) Info() tcpip.EndpointInfo { + ep.mu.RLock() + // Make a copy of the endpoint info. + ret := ep.TransportEndpointInfo + ep.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (ep *endpoint) Stats() tcpip.EndpointStats { + return &ep.stats +} + +func (ep *endpoint) SetOwner(owner tcpip.PacketOwner) {} diff --git a/pkg/tcpip/transport/packet/endpoint_state.go b/pkg/tcpip/transport/packet/endpoint_state.go new file mode 100644 index 000000000..9b88f17e4 --- /dev/null +++ b/pkg/tcpip/transport/packet/endpoint_state.go @@ -0,0 +1,72 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package packet + +import ( + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// saveData saves packet.data field. 
+func (p *packet) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, + // which is not allowed by state framework (in-struct pointer). + return p.data.Clone(nil) +} + +// loadData loads packet.data field. +func (p *packet) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing p.views for data.views. + p.data = data +} + +// beforeSave is invoked by stateify. +func (ep *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after saveRcvBufSizeMax(), which would have + // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + ep.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) saveRcvBufSizeMax() int { + max := ep.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + ep.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + ep.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) loadRcvBufSizeMax(max int) { + ep.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (ep *endpoint) afterLoad() { + // StackFromEnv is a stack used specifically for save/restore. + ep.stack = stack.StackFromEnv + + // TODO(gvisor.dev/173): Once bind is supported, choose the right NIC. + if err := ep.stack.RegisterPacketEndpoint(0, ep.netProto, ep); err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/raw/BUILD b/pkg/tcpip/transport/raw/BUILD new file mode 100644 index 000000000..2eab09088 --- /dev/null +++ b/pkg/tcpip/transport/raw/BUILD @@ -0,0 +1,39 @@ +load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "raw_packet_list", + out = "raw_packet_list.go", + package = "raw", + prefix = "rawPacket", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*rawPacket", + "Linker": "*rawPacket", + }, +) + +go_library( + name = "raw", + srcs = [ + "endpoint.go", + "endpoint_state.go", + "protocol.go", + "raw_packet_list.go", + ], + imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/sleep", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/packet", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go new file mode 100644 index 000000000..5b6e7d102 --- /dev/null +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -0,0 +1,729 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +// Package raw provides the implementation of raw sockets (see raw(7)). Raw +// sockets allow applications to: +// +// * manually write and inspect transport layer headers and payloads +// * receive all traffic of a given transport protocol (e.g. ICMP or UDP) +// * optionally write and inspect network layer headers of packets +// +// Raw sockets don't have any notion of ports, and incoming packets are +// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will +// receive every UDP packet received by netstack. bind(2) and connect(2) can be +// used to filter incoming packets by source and destination. +package raw + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/waiter" +) + +// +stateify savable +type rawPacket struct { + rawPacketEntry + // data holds the actual packet data, including any headers and + // payload. + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // timestampNS is the unix time at which the packet was received. + timestampNS int64 + // senderAddr is the network address of the sender. + senderAddr tcpip.FullAddress +} + +// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to +// have goroutines make concurrent calls into the endpoint. +// +// Lock order: +// endpoint.mu +// endpoint.rcvMu +// +// +stateify savable +type endpoint struct { + stack.TransportEndpointInfo + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + waiterQueue *waiter.Queue + associated bool + hdrIncluded bool + + // The following fields are used to manage the receive queue and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvList rawPacketList + rcvBufSize int + rcvBufSizeMax int `state:".(int)"` + rcvClosed bool + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + sndBufSizeMax int + closed bool + connected bool + bound bool + // route is the route to a remote network endpoint. It is set via + // Connect(), and is valid only when conneted is true. + route stack.Route `state:"manual"` + stats tcpip.TransportEndpointStats `state:"nosave"` + + // owner is used to get uid and gid of the packet. + owner tcpip.PacketOwner +} + +// NewEndpoint returns a raw endpoint for the given protocols. +func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, transProto, waiterQueue, true /* associated */) +} + +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { + if netProto != header.IPv4ProtocolNumber && netProto != header.IPv6ProtocolNumber { + return nil, tcpip.ErrUnknownProtocol + } + + e := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: transProto, + }, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSizeMax: 32 * 1024, + associated: associated, + hdrIncluded: !associated, + } + + // Override with stack defaults. 
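+ // If the stack carries explicit send/receive buffer size options, they
+ // replace the 32 KiB defaults chosen above.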
+ var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + e.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + e.rcvBufSizeMax = rs.Default + } + + // Unassociated endpoints are write-only and users call Write() with IP + // headers included. Because they're write-only, We don't need to + // register with the stack. + if !associated { + e.rcvBufSizeMax = 0 + e.waiterQueue = nil + return e, nil + } + + if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { + return nil, err + } + + return e, nil +} + +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + +// Close implements tcpip.Endpoint.Close. +func (e *endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + if e.closed || !e.associated { + return + } + + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + + e.rcvMu.Lock() + defer e.rcvMu.Unlock() + + // Clear the receive list. + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + e.rcvList.Remove(e.rcvList.Front()) + } + + if e.connected { + e.route.Release() + e.connected = false + } + + e.closed = true + + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. +func (e *endpoint) ModerateRecvBuf(copied int) {} + +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} + +// Read implements tcpip.Endpoint.Read. +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + e.rcvMu.Lock() + + // If there's no data to read, return that read would block or that the + // endpoint is closed. + if e.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if e.rcvClosed { + e.stats.ReadErrors.ReadClosed.Increment() + err = tcpip.ErrClosedForReceive + } + e.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + pkt := e.rcvList.Front() + e.rcvList.Remove(pkt) + e.rcvBufSize -= pkt.data.Size() + + e.rcvMu.Unlock() + + if addr != nil { + *addr = pkt.senderAddr + } + + return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil +} + +// Write implements tcpip.Endpoint.Write. +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // We can create, but not write to, unassociated IPv6 endpoints. + if !e.associated && e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber { + return 0, nil, tcpip.ErrInvalidOptionValue + } + + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. 
+		e.stats.SendErrors.SendToNetworkFailed.Increment()
+	}
+	return n, ch, err
+}
+
+func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
+	// MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+	if opts.More {
+		return 0, nil, tcpip.ErrInvalidOptionValue
+	}
+
+	e.mu.RLock()
+
+	if e.closed {
+		e.mu.RUnlock()
+		return 0, nil, tcpip.ErrInvalidEndpointState
+	}
+
+	payloadBytes, err := p.FullPayload()
+	if err != nil {
+		e.mu.RUnlock()
+		return 0, nil, err
+	}
+
+	// If this is an unassociated socket and the caller provided a nonzero
+	// destination address, route using that address.
+	if e.hdrIncluded {
+		ip := header.IPv4(payloadBytes)
+		if !ip.IsValid(len(payloadBytes)) {
+			e.mu.RUnlock()
+			return 0, nil, tcpip.ErrInvalidOptionValue
+		}
+		dstAddr := ip.DestinationAddress()
+		// Route using the destination address from the IP header,
+		// unless opts.To is set (e.g. if sendto specifies a specific
+		// address).
+		if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil {
+			opts.To = &tcpip.FullAddress{
+				NIC:  0,       // NIC is unset.
+				Addr: dstAddr, // The address from the payload.
+				Port: 0,       // There are no ports here.
+			}
+		}
+	}
+
+	// Did the caller provide a destination? If not, use the connected
+	// destination.
+	if opts.To == nil {
+		// If the user doesn't specify a destination, they should have
+		// connected to another address.
+		if !e.connected {
+			e.mu.RUnlock()
+			return 0, nil, tcpip.ErrDestinationRequired
+		}
+
+		if e.route.IsResolutionRequired() {
+			savedRoute := &e.route
+			// Promote lock to exclusive if using a shared route,
+			// given that it may need to change in finishWrite.
+			e.mu.RUnlock()
+			e.mu.Lock()
+
+			// Make sure that the route didn't change during the
+			// time we didn't hold the lock.
+			if !e.connected || savedRoute != &e.route {
+				e.mu.Unlock()
+				return 0, nil, tcpip.ErrInvalidEndpointState
+			}
+
+			n, ch, err := e.finishWrite(payloadBytes, savedRoute)
+			e.mu.Unlock()
+			return n, ch, err
+		}
+
+		n, ch, err := e.finishWrite(payloadBytes, &e.route)
+		e.mu.RUnlock()
+		return n, ch, err
+	}
+
+	// The caller provided a destination. Reject the destination address if
+	// it goes through a different NIC than the endpoint was bound to.
+	nic := opts.To.NIC
+	if e.bound && nic != 0 && nic != e.BindNICID {
+		e.mu.RUnlock()
+		return 0, nil, tcpip.ErrNoRoute
+	}
+
+	// Find the route to the destination. If BindAddress is 0,
+	// FindRoute will choose an appropriate source address.
+	route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
+	if err != nil {
+		e.mu.RUnlock()
+		return 0, nil, err
+	}
+
+	n, ch, err := e.finishWrite(payloadBytes, &route)
+	route.Release()
+	e.mu.RUnlock()
+	return n, ch, err
+}
+
+// finishWrite writes the payload to a route. It resolves the route if
+// necessary. It's really just a helper to make defer unnecessary in Write.
+func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, <-chan struct{}, *tcpip.Error) {
+	// We may need to resolve the route (match a link layer address to the
+	// network address). If that requires blocking (e.g. to use ARP),
+	// return a channel on which the caller can wait.
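+	//
+	// A rough sketch of how a caller might drive this (payload being a
+	// tcpip.Payloader, e.g. tcpip.SlicePayload):
+	//
+	//	n, ch, err := ep.Write(payload, opts)
+	//	if err == tcpip.ErrNoLinkAddress && ch != nil {
+	//		<-ch // Wait for link-address resolution to finish.
+	//		n, _, err = ep.Write(payload, opts)
+	//	}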
+ if route.IsResolutionRequired() { + if ch, err := route.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + if e.hdrIncluded { + if err := route.WriteHeaderIncludedPacket(&stack.PacketBuffer{ + Data: buffer.View(payloadBytes).ToVectorisedView(), + }); err != nil { + return 0, nil, err + } + } else { + hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) + if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: e.TransProto, TTL: route.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ + Header: hdr, + Data: buffer.View(payloadBytes).ToVectorisedView(), + Owner: e.owner, + }); err != nil { + return 0, nil, err + } + } + + return int64(len(payloadBytes)), nil, nil +} + +// Peek implements tcpip.Endpoint.Peek. +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// Disconnect implements tcpip.Endpoint.Disconnect. +func (*endpoint) Disconnect() *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Connect implements tcpip.Endpoint.Connect. +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.closed { + return tcpip.ErrInvalidEndpointState + } + + nic := addr.NIC + if e.bound { + if e.BindNICID == 0 { + // If we're bound, but not to a specific NIC, the NIC + // in addr will be used. Nothing to do here. + } else if addr.NIC == 0 { + // If we're bound to a specific NIC, but addr doesn't + // specify a NIC, use the bound NIC. + nic = e.BindNICID + } else if addr.NIC != e.BindNICID { + // We're bound and addr specifies a NIC. They must be + // the same. + return tcpip.ErrInvalidEndpointState + } + } + + // Find a route to the destination. + route, err := e.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, e.NetProto, false) + if err != nil { + return err + } + defer route.Release() + + if e.associated { + // Re-register the endpoint with the appropriate NIC. + if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + return err + } + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.RegisterNICID = nic + } + + // Save the route we've connected via. + e.route = route.Clone() + e.connected = true + + return nil +} + +// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + if !e.connected { + return tcpip.ErrNotConnected + } + return nil +} + +// Listen implements tcpip.Endpoint.Listen. +func (e *endpoint) Listen(backlog int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept implements tcpip.Endpoint.Accept. +func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +// Bind implements tcpip.Endpoint.Bind. +func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // If a local address was specified, verify that it's valid. + if e.stack.CheckLocalAddress(addr.NIC, e.NetProto, addr.Addr) == 0 { + return tcpip.ErrBadLocalAddress + } + + if e.associated { + // Re-register the endpoint with the appropriate NIC. 
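+		// The new registration is created before the old one is
+		// removed so that, if registering on the new NIC fails, the
+		// endpoint keeps its existing registration.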
+ if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + return err + } + e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) + e.RegisterNICID = addr.NIC + e.BindNICID = addr.NIC + } + + e.BindAddr = addr.Addr + e.bound = true + + return nil +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, tcpip.ErrNotSupported +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + // Even a connected socket doesn't return a remote address. + return tcpip.FullAddress{}, tcpip.ErrNotConnected +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine whether the endpoint is readable. + if (mask & waiter.EventIn) != 0 { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvMu.Unlock() + } + + return result +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + e.hdrIncluded = v + e.mu.Unlock() + return nil + } + return tcpip.ErrUnknownProtocolOption +} + +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + var ss stack.SendBufferSizeOption + if err := e.stack.Option(&ss); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", ss, err)) + } + if v > ss.Max { + v = ss.Max + } + if v < ss.Min { + v = ss.Min + } + e.mu.Lock() + e.sndBufSizeMax = v + e.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("s.Option(%#v) = %s", rs, err)) + } + if v > rs.Max { + v = rs.Max + } + if v < rs.Min { + v = rs.Min + } + e.rcvMu.Lock() + e.rcvBufSizeMax = v + e.rcvMu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.IPHdrIncludedOption: + e.mu.Lock() + v := e.hdrIncluded + e.mu.Unlock() + return v, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. 
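+// It supports ReceiveQueueSizeOption (the size of the packet at the front of
+// the receive queue), SendBufferSizeOption and ReceiveBufferSizeOption; any
+// other option returns ErrUnknownProtocolOption.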
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+	switch opt {
+	case tcpip.ReceiveQueueSizeOption:
+		v := 0
+		e.rcvMu.Lock()
+		if !e.rcvList.Empty() {
+			p := e.rcvList.Front()
+			v = p.data.Size()
+		}
+		e.rcvMu.Unlock()
+		return v, nil
+
+	case tcpip.SendBufferSizeOption:
+		e.mu.Lock()
+		v := e.sndBufSizeMax
+		e.mu.Unlock()
+		return v, nil
+
+	case tcpip.ReceiveBufferSizeOption:
+		e.rcvMu.Lock()
+		v := e.rcvBufSizeMax
+		e.rcvMu.Unlock()
+		return v, nil
+
+	default:
+		return -1, tcpip.ErrUnknownProtocolOption
+	}
+}
+
+// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
+func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) {
+	e.rcvMu.Lock()
+
+	// Drop the packet if our receiver is closed or if this is an
+	// unassociated endpoint (i.e., an endpoint created with IPPROTO_RAW).
+	// Such endpoints are send-only.
+	// See: https://man7.org/linux/man-pages/man7/raw.7.html
+	//
+	//	An IPPROTO_RAW socket is send only. If you really want to receive
+	//	all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+	//	Note that packet sockets don't reassemble IP fragments, unlike raw
+	//	sockets.
+	if e.rcvClosed || !e.associated {
+		e.rcvMu.Unlock()
+		e.stack.Stats().DroppedPackets.Increment()
+		e.stats.ReceiveErrors.ClosedReceiver.Increment()
+		return
+	}
+
+	if e.rcvBufSize >= e.rcvBufSizeMax {
+		e.rcvMu.Unlock()
+		e.stack.Stats().DroppedPackets.Increment()
+		e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+		return
+	}
+
+	if e.bound {
+		// If bound to a NIC, only accept data for that NIC.
+		if e.BindNICID != 0 && e.BindNICID != route.NICID() {
+			e.rcvMu.Unlock()
+			return
+		}
+		// If bound to an address, only accept data for that address.
+		if e.BindAddr != "" && e.BindAddr != route.RemoteAddress {
+			e.rcvMu.Unlock()
+			return
+		}
+	}
+
+	// If connected, only accept packets from the remote address we
+	// connected to.
+	if e.connected && e.route.RemoteAddress != route.RemoteAddress {
+		e.rcvMu.Unlock()
+		return
+	}
+
+	wasEmpty := e.rcvBufSize == 0
+
+	// Push new packet into receive list and increment the buffer size.
+	packet := &rawPacket{
+		senderAddr: tcpip.FullAddress{
+			NIC:  route.NICID(),
+			Addr: route.RemoteAddress,
+		},
+	}
+
+	// Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+	// We copy headers' underlying bytes because pkt.*Header may point to
+	// the middle of a slice, and another struct may point to the "outer"
+	// slice. Save/restore doesn't support overlapping slices and will fail.
+	var combinedVV buffer.VectorisedView
+	if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+		headers := make(buffer.View, 0, len(pkt.NetworkHeader)+len(pkt.TransportHeader))
+		headers = append(headers, pkt.NetworkHeader...)
+		headers = append(headers, pkt.TransportHeader...)
+		combinedVV = headers.ToVectorisedView()
+	} else {
+		combinedVV = append(buffer.View(nil), pkt.TransportHeader...).ToVectorisedView()
+	}
+	combinedVV.Append(pkt.Data)
+	packet.data = combinedVV
+	packet.timestampNS = e.stack.NowNanoseconds()
+
+	e.rcvList.PushBack(packet)
+	e.rcvBufSize += packet.data.Size()
+	e.rcvMu.Unlock()
+	e.stats.PacketsReceived.Increment()
+	// Notify waiters that there's data to be read.
+	if wasEmpty {
+		e.waiterQueue.Notify(waiter.EventIn)
+	}
+}
+
+// State implements socket.Socket.State.
+func (e *endpoint) State() uint32 {
+	return 0
+}
+
+// Info returns a copy of the endpoint info.
+func (e *endpoint) Info() tcpip.EndpointInfo {
+	e.mu.RLock()
+	// Make a copy of the endpoint info.
+	ret := e.TransportEndpointInfo
+	e.mu.RUnlock()
+	return &ret
+}
+
+// Stats returns a pointer to the endpoint stats.
+func (e *endpoint) Stats() tcpip.EndpointStats {
+	return &e.stats
+}
+
+// Wait implements stack.TransportEndpoint.Wait.
+func (*endpoint) Wait() {}
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
new file mode 100644
index 000000000..33bfb56cd
--- /dev/null
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -0,0 +1,94 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raw
+
+import (
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves rawPacket.data field.
+func (p *rawPacket) saveData() buffer.VectorisedView {
+	// We cannot save p.data directly as p.data.views may alias to p.views,
+	// which is not allowed by the state framework (in-struct pointer).
+	return p.data.Clone(nil)
+}
+
+// loadData loads rawPacket.data field.
+func (p *rawPacket) loadData(data buffer.VectorisedView) {
+	// NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+	// here because data.views is not guaranteed to be loaded by now. Plus,
+	// data.views will be allocated anyway so there really is little point
+	// in utilizing p.views for data.views.
+	p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+	// Stop incoming packets from being handled (and from mutating endpoint
+	// state). The lock will be released after saveRcvBufSizeMax(), which
+	// would have saved ep.rcvBufSizeMax and set it to 0 to continue
+	// blocking incoming packets.
+	ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+	max := ep.rcvBufSizeMax
+	// Make sure no new packets will be handled regardless of the lock.
+	ep.rcvBufSizeMax = 0
+	// Release the lock acquired in beforeSave() so regular endpoint closing
+	// logic can proceed after save.
+	ep.rcvMu.Unlock()
+	return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+	ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+	stack.StackFromEnv.RegisterRestoredEndpoint(ep)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (ep *endpoint) Resume(s *stack.Stack) {
+	ep.stack = s
+
+	// If the endpoint is connected, re-connect.
+	if ep.connected {
+		var err *tcpip.Error
+		ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false)
+		if err != nil {
+			panic(err)
+		}
+	}
+
+	// If the endpoint is bound, re-bind.
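+	// Binding is stateless on the stack side, so only the address check
+	// needs to be repeated; a failure means the restore environment no
+	// longer has the address and is treated as fatal.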
+ if ep.bound { + if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + if ep.associated { + if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { + panic(err) + } + } +} diff --git a/pkg/tcpip/transport/raw/protocol.go b/pkg/tcpip/transport/raw/protocol.go new file mode 100644 index 000000000..f30aa2a4a --- /dev/null +++ b/pkg/tcpip/transport/raw/protocol.go @@ -0,0 +1,35 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package raw + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/packet" + "gvisor.dev/gvisor/pkg/waiter" +) + +// EndpointFactory implements stack.RawFactory. +type EndpointFactory struct{} + +// NewUnassociatedEndpoint implements stack.RawFactory.NewUnassociatedEndpoint. +func (EndpointFactory) NewUnassociatedEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, transProto, waiterQueue, false /* associated */) +} + +// NewPacketEndpoint implements stack.RawFactory.NewPacketEndpoint. 
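+// cooked chooses between the two packet socket flavors: a cooked endpoint
+// (SOCK_DGRAM-like) delivers packets with the link-layer header removed,
+// while a non-cooked one (SOCK_RAW-like) includes it.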
+func (EndpointFactory) NewPacketEndpoint(stack *stack.Stack, cooked bool, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return packet.NewEndpoint(stack, cooked, netProto, waiterQueue) +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD new file mode 100644 index 000000000..18ff89ffc --- /dev/null +++ b/pkg/tcpip/transport/tcp/BUILD @@ -0,0 +1,126 @@ +load("//tools:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") + +package(licenses = ["notice"]) + +go_template_instance( + name = "tcp_segment_list", + out = "tcp_segment_list.go", + package = "tcp", + prefix = "segment", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*segment", + "Linker": "*segment", + }, +) + +go_template_instance( + name = "tcp_endpoint_list", + out = "tcp_endpoint_list.go", + package = "tcp", + prefix = "endpoint", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*endpoint", + "Linker": "*endpoint", + }, +) + +go_library( + name = "tcp", + srcs = [ + "accept.go", + "connect.go", + "connect_unsafe.go", + "cubic.go", + "cubic_state.go", + "dispatcher.go", + "endpoint.go", + "endpoint_state.go", + "forwarder.go", + "protocol.go", + "rcv.go", + "rcv_state.go", + "reno.go", + "sack.go", + "sack_scoreboard.go", + "segment.go", + "segment_heap.go", + "segment_queue.go", + "segment_state.go", + "snd.go", + "snd_state.go", + "tcp_endpoint_list.go", + "tcp_segment_list.go", + "timer.go", + ], + imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/log", + "//pkg/rand", + "//pkg/sleep", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/hash/jenkins", + "//pkg/tcpip/header", + "//pkg/tcpip/ports", + "//pkg/tcpip/seqnum", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/raw", + "//pkg/waiter", + "@com_github_google_btree//:go_default_library", + ], +) + +go_test( + name = "tcp_x_test", + size = "medium", + srcs = [ + "dual_stack_test.go", + "sack_scoreboard_test.go", + "tcp_noracedetector_test.go", + "tcp_sack_test.go", + "tcp_test.go", + "tcp_timestamp_test.go", + ], + shard_count = 10, + deps = [ + ":tcp", + "//pkg/sync", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/ports", + "//pkg/tcpip/seqnum", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp/testing/context", + "//pkg/test/testutil", + "//pkg/waiter", + ], +) + +go_test( + name = "rcv_test", + size = "small", + srcs = ["rcv_test.go"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + ], +) + +go_test( + name = "tcp_test", + size = "small", + srcs = ["timer_test.go"], + library = ":tcp", + deps = ["//pkg/sleep"], +) diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go new file mode 100644 index 000000000..6e00e5526 --- /dev/null +++ b/pkg/tcpip/transport/tcp/accept.go @@ -0,0 +1,752 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+	"crypto/sha1"
+	"encoding/binary"
+	"fmt"
+	"hash"
+	"io"
+	"time"
+
+	"gvisor.dev/gvisor/pkg/rand"
+	"gvisor.dev/gvisor/pkg/sleep"
+	"gvisor.dev/gvisor/pkg/sync"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+	// tsLen is the length, in bits, of the timestamp in the SYN cookie.
+	tsLen = 8
+
+	// tsMask is a mask for timestamp values (i.e., tsLen bits).
+	tsMask = (1 << tsLen) - 1
+
+	// tsOffset is the offset, in bits, of the timestamp in the SYN cookie.
+	tsOffset = 24
+
+	// hashMask is the mask for hash values (i.e., tsOffset bits).
+	hashMask = (1 << tsOffset) - 1
+
+	// maxTSDiff is the maximum allowed difference between a received cookie
+	// timestamp and the current timestamp. If the difference is greater
+	// than maxTSDiff, the cookie is expired.
+	maxTSDiff = 2
+
+	// SynRcvdCountThreshold is the default global maximum number of
+	// connections that are allowed to be in SYN-RCVD state before TCP
+	// starts using SYN cookies to accept connections.
+	SynRcvdCountThreshold uint64 = 1000
+)
+
+var (
+	// mssTable is a slice containing the possible MSS values that we
+	// encode in the SYN cookie with two bits.
+	mssTable = []uint16{536, 1300, 1440, 1460}
+)
+
+func encodeMSS(mss uint16) uint32 {
+	for i := len(mssTable) - 1; i > 0; i-- {
+		if mss >= mssTable[i] {
+			return uint32(i)
+		}
+	}
+	return 0
+}
+
+// listenContext is used by a listening endpoint to store state used while
+// listening for connections. This struct is allocated by the listen goroutine
+// and must not be accessed or have its methods called concurrently as they
+// may mutate the stored objects.
+type listenContext struct {
+	stack *stack.Stack
+
+	// synRcvdCount is a reference to the stack level synRcvdCount.
+	synRcvdCount *synRcvdCounter
+
+	// rcvWnd is the receive window that is sent by this listening context
+	// in the initial SYN-ACK.
+	rcvWnd seqnum.Size
+
+	// nonce holds random bytes that are initialized once when the context
+	// is created and used to seed the hash function when generating
+	// the SYN cookie.
+	nonce [2][sha1.BlockSize]byte
+
+	// listenEP is a reference to the listening endpoint associated with
+	// this context. Can be nil if the context is created by the forwarder.
+	listenEP *endpoint
+
+	// hasherMu protects hasher.
+	hasherMu sync.Mutex
+	// hasher is the hash function used to generate a SYN cookie.
+	hasher hash.Hash
+
+	// v6Only is true if listenEP is a dual stack socket and has the
+	// IPV6_V6ONLY option set.
+	v6Only bool
+
+	// netProto indicates the network protocol (IPv4/v6) for the listening
+	// endpoint.
+	netProto tcpip.NetworkProtocolNumber
+
+	// pendingMu protects pendingEndpoints. This should only be accessed
+	// by the listening endpoint's worker goroutine.
+	//
+	// Lock Ordering: listenEP.workerMu -> pendingMu
+	pendingMu sync.Mutex
+	// pending is used to wait for all pendingEndpoints to finish when
+	// a socket is closed.
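+	// Each in-flight handshake adds to this WaitGroup via
+	// addPendingEndpoint and releases it via removePendingEndpoint;
+	// closeAllPendingEndpoints blocks on it during close.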
+	pending sync.WaitGroup
+	// pendingEndpoints is a map of all endpoints for which a handshake is
+	// in progress.
+	pendingEndpoints map[stack.TransportEndpointID]*endpoint
+}
+
+// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
+func timeStamp() uint32 {
+	return uint32(time.Now().Unix()>>6) & tsMask
+}
+
+// newListenContext creates a new listen context.
+func newListenContext(stk *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6Only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+	l := &listenContext{
+		stack:            stk,
+		rcvWnd:           rcvWnd,
+		hasher:           sha1.New(),
+		v6Only:           v6Only,
+		netProto:         netProto,
+		listenEP:         listenEP,
+		pendingEndpoints: make(map[stack.TransportEndpointID]*endpoint),
+	}
+	p, ok := stk.TransportProtocolInstance(ProtocolNumber).(*protocol)
+	if !ok {
+		panic(fmt.Sprintf("unable to get TCP protocol instance from stack: %+v", stk))
+	}
+	l.synRcvdCount = p.SynRcvdCounter()
+
+	rand.Read(l.nonce[0][:])
+	rand.Read(l.nonce[1][:])
+
+	return l
+}
+
+// cookieHash calculates the cookieHash for the given id, timestamp and nonce
+// index. The hash is used to create and validate cookies.
+func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {
+	// Initialize block with fixed-size data: the local and remote ports
+	// and the timestamp.
+	var payload [8]byte
+	binary.BigEndian.PutUint16(payload[0:], id.LocalPort)
+	binary.BigEndian.PutUint16(payload[2:], id.RemotePort)
+	binary.BigEndian.PutUint32(payload[4:], ts)
+
+	// Feed everything to the hasher.
+	l.hasherMu.Lock()
+	l.hasher.Reset()
+	l.hasher.Write(payload[:])
+	l.hasher.Write(l.nonce[nonceIndex][:])
+	io.WriteString(l.hasher, string(id.LocalAddress))
+	io.WriteString(l.hasher, string(id.RemoteAddress))
+
+	// Finalize the calculation of the hash and return the first 4 bytes.
+	h := make([]byte, 0, sha1.Size)
+	h = l.hasher.Sum(h)
+	l.hasherMu.Unlock()
+
+	return binary.BigEndian.Uint32(h[:])
+}
+
+// createCookie creates a SYN cookie for the given id and incoming sequence
+// number.
+func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value {
+	ts := timeStamp()
+	v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
+	v += (l.cookieHash(id, ts, 1) + data) & hashMask
+	return seqnum.Value(v)
+}
+
+// isCookieValid checks if the supplied cookie is valid for the given id and
+// sequence number. If it is, it also returns the data originally encoded in the
+// cookie when createCookie was called.
+func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
+	ts := timeStamp()
+	v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
+	cookieTS := v >> tsOffset
+	if ((ts - cookieTS) & tsMask) > maxTSDiff {
+		return 0, false
+	}
+
+	return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
+}
+
+// createConnectingEndpoint creates a new endpoint in a connecting state, with
+// the connection parameters given by the arguments.
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint {
+	// Create a new endpoint.
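+	// For dual-stack listeners netProto is left at 0, in which case the
+	// network protocol is taken from the route of the incoming segment.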
+	netProto := l.netProto
+	if netProto == 0 {
+		netProto = s.route.NetProto
+	}
+	n := newEndpoint(l.stack, netProto, queue)
+	n.v6only = l.v6Only
+	n.ID = s.id
+	n.boundNICID = s.route.NICID()
+	n.route = s.route.Clone()
+	n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
+	n.rcvBufSize = int(l.rcvWnd)
+	n.amss = mssForRoute(&n.route)
+	n.setEndpointState(StateConnecting)
+
+	n.maybeEnableTimestamp(rcvdSynOpts)
+	n.maybeEnableSACKPermitted(rcvdSynOpts)
+
+	n.initGSO()
+
+	// Bootstrap the auto tuning algorithm. Starting at zero will result in
+	// a large step function on the first window adjustment causing the
+	// window to grow to a really large value.
+	n.rcvAutoParams.prevCopied = n.initialReceiveWindow()
+
+	return n
+}
+
+// createEndpointAndPerformHandshake creates a new endpoint in connected state
+// and then performs the TCP 3-way handshake.
+//
+// The new endpoint is returned with e.mu held.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions, queue *waiter.Queue, owner tcpip.PacketOwner) (*endpoint, *tcpip.Error) {
+	// Create new endpoint.
+	irs := s.sequenceNumber
+	isn := generateSecureISN(s.id, l.stack.Seed())
+	ep := l.createConnectingEndpoint(s, isn, irs, opts, queue)
+
+	// Lock the endpoint before registering to ensure that no out-of-band
+	// changes are possible due to incoming packets etc. until the endpoint
+	// is done initializing.
+	ep.mu.Lock()
+	ep.owner = owner
+
+	// listenEP is nil when listenContext is used by tcp.Forwarder.
+	deferAccept := time.Duration(0)
+	if l.listenEP != nil {
+		l.listenEP.mu.Lock()
+		if l.listenEP.EndpointState() != StateListen {
+			l.listenEP.mu.Unlock()
+			// Ensure we release any registrations done by the newly
+			// created endpoint.
+			ep.mu.Unlock()
+			ep.Close()
+
+			return nil, tcpip.ErrConnectionAborted
+		}
+		l.addPendingEndpoint(ep)
+
+		// Propagate any inheritable options from the listening endpoint
+		// to the newly created endpoint.
+		l.listenEP.propagateInheritableOptionsLocked(ep)
+
+		if !ep.reserveTupleLocked() {
+			ep.mu.Unlock()
+			ep.Close()
+
+			if l.listenEP != nil {
+				l.removePendingEndpoint(ep)
+				l.listenEP.mu.Unlock()
+			}
+
+			return nil, tcpip.ErrConnectionAborted
+		}
+
+		deferAccept = l.listenEP.deferAccept
+		l.listenEP.mu.Unlock()
+	}
+
+	// Register new endpoint so that packets are routed to it.
+	if err := ep.stack.RegisterTransportEndpoint(ep.boundNICID, ep.effectiveNetProtos, ProtocolNumber, ep.ID, ep, ep.boundPortFlags, ep.boundBindToDevice); err != nil {
+		ep.mu.Unlock()
+		ep.Close()
+
+		if l.listenEP != nil {
+			l.removePendingEndpoint(ep)
+		}
+
+		ep.drainClosingSegmentQueue()
+
+		return nil, err
+	}
+
+	ep.isRegistered = true
+
+	// Perform the 3-way handshake.
+	h := newPassiveHandshake(ep, seqnum.Size(ep.initialReceiveWindow()), isn, irs, opts, deferAccept)
+	if err := h.execute(); err != nil {
+		ep.mu.Unlock()
+		ep.Close()
+		ep.notifyAborted()
+
+		if l.listenEP != nil {
+			l.removePendingEndpoint(ep)
+		}
+
+		ep.drainClosingSegmentQueue()
+
+		return nil, err
+	}
+	ep.isConnectNotified = true
+
+	// Update the receive window scaling. We can't do it before the
+	// handshake because it's possible that the peer doesn't support window
+	// scaling.
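+	// In that case effectiveRcvWndScale() returns zero and the window is
+	// advertised unscaled.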
+ ep.rcv.rcvWndScale = h.effectiveRcvWndScale() + + return ep, nil +} + +func (l *listenContext) addPendingEndpoint(n *endpoint) { + l.pendingMu.Lock() + l.pendingEndpoints[n.ID] = n + l.pending.Add(1) + l.pendingMu.Unlock() +} + +func (l *listenContext) removePendingEndpoint(n *endpoint) { + l.pendingMu.Lock() + delete(l.pendingEndpoints, n.ID) + l.pending.Done() + l.pendingMu.Unlock() +} + +func (l *listenContext) closeAllPendingEndpoints() { + l.pendingMu.Lock() + for _, n := range l.pendingEndpoints { + n.notifyProtocolGoroutine(notifyClose) + } + l.pendingMu.Unlock() + l.pending.Wait() +} + +// deliverAccepted delivers the newly-accepted endpoint to the listener. If the +// endpoint has transitioned out of the listen state (acceptedChan is nil), +// the new endpoint is closed instead. +func (e *endpoint) deliverAccepted(n *endpoint) { + e.mu.Lock() + e.pendingAccepted.Add(1) + e.mu.Unlock() + defer e.pendingAccepted.Done() + + e.acceptMu.Lock() + for { + if e.acceptedChan == nil { + e.acceptMu.Unlock() + n.notifyProtocolGoroutine(notifyReset) + return + } + select { + case e.acceptedChan <- n: + e.acceptMu.Unlock() + e.waiterQueue.Notify(waiter.EventIn) + return + default: + e.acceptCond.Wait() + } + } +} + +// propagateInheritableOptionsLocked propagates any options set on the listening +// endpoint to the newly created endpoint. +// +// Precondition: e.mu and n.mu must be held. +func (e *endpoint) propagateInheritableOptionsLocked(n *endpoint) { + n.userTimeout = e.userTimeout + n.portFlags = e.portFlags + n.boundBindToDevice = e.boundBindToDevice + n.boundPortFlags = e.boundPortFlags +} + +// reserveTupleLocked reserves an accepted endpoint's tuple. +// +// Preconditions: +// * propagateInheritableOptionsLocked has been called. +// * e.mu is held. +func (e *endpoint) reserveTupleLocked() bool { + dest := tcpip.FullAddress{Addr: e.ID.RemoteAddress, Port: e.ID.RemotePort} + if !e.stack.ReserveTuple( + e.effectiveNetProtos, + ProtocolNumber, + e.ID.LocalAddress, + e.ID.LocalPort, + e.boundPortFlags, + e.boundBindToDevice, + dest, + ) { + return false + } + + e.isPortReserved = true + e.boundDest = dest + return true +} + +// notifyAborted wakes up any waiters on registered, but not accepted +// endpoints. +// +// This is strictly not required normally as a socket that was never accepted +// can't really have any registered waiters except when stack.Wait() is called +// which waits for all registered endpoints to stop and expects an EventHUp. +func (e *endpoint) notifyAborted() { + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// handleSynSegment is called in its own goroutine once the listening endpoint +// receives a SYN segment. It is responsible for completing the handshake and +// queueing the new endpoint for acceptance. +// +// A limited number of these goroutines are allowed before TCP starts using SYN +// cookies to accept connections. 
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
+	defer ctx.synRcvdCount.dec()
+	defer func() {
+		e.mu.Lock()
+		e.decSynRcvdCount()
+		e.mu.Unlock()
+	}()
+	defer s.decRef()
+
+	n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner)
+	if err != nil {
+		e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+		e.stats.FailedConnectionAttempts.Increment()
+		return
+	}
+	ctx.removePendingEndpoint(n)
+	n.startAcceptedLoop()
+	e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
+	e.deliverAccepted(n)
+}
+
+func (e *endpoint) incSynRcvdCount() bool {
+	e.acceptMu.Lock()
+	canInc := e.synRcvdCount < cap(e.acceptedChan)
+	e.acceptMu.Unlock()
+	if canInc {
+		e.synRcvdCount++
+	}
+	return canInc
+}
+
+func (e *endpoint) decSynRcvdCount() {
+	e.synRcvdCount--
+}
+
+func (e *endpoint) acceptQueueIsFull() bool {
+	e.acceptMu.Lock()
+	full := len(e.acceptedChan)+e.synRcvdCount >= cap(e.acceptedChan)
+	e.acceptMu.Unlock()
+	return full
+}
+
+// handleListenSegment is called when a listening endpoint receives a segment
+// and needs to handle it.
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+	e.rcvListMu.Lock()
+	rcvClosed := e.rcvClosed
+	e.rcvListMu.Unlock()
+	if rcvClosed || s.flagsAreSet(header.TCPFlagSyn|header.TCPFlagAck) {
+		// If the endpoint is shutdown, reply with reset.
+		//
+		// RFC 793 section 3.4 page 35 (figure 12) outlines that a RST
+		// must be sent in response to a SYN-ACK while in the listen
+		// state to prevent completing a handshake from an old SYN.
+		replyWithReset(s, e.sendTOS, e.ttl)
+		return
+	}
+
+	// TODO(b/143300739): Use the userMSS of the listening socket
+	// for accepted sockets.
+
+	switch {
+	case s.flags == header.TCPFlagSyn:
+		opts := parseSynSegmentOptions(s)
+		if ctx.synRcvdCount.inc() {
+			// Only handle the syn if the following conditions hold
+			//   - accept queue is not full.
+			//   - number of connections in synRcvd state is less than the
+			//     backlog.
+			if !e.acceptQueueIsFull() && e.incSynRcvdCount() {
+				s.incRef()
+				go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
+				return
+			}
+			ctx.synRcvdCount.dec()
+			e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+			e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
+			e.stack.Stats().DroppedPackets.Increment()
+			return
+		} else {
+			// If cookies are in use but the endpoint accept queue
+			// is full then drop the syn.
+			if e.acceptQueueIsFull() {
+				e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+				e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment()
+				e.stack.Stats().DroppedPackets.Increment()
+				return
+			}
+			cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+
+			// Send SYN without window scaling because we currently
+			// don't encode this information in the cookie.
+			//
+			// Enable Timestamp option if the original syn did have
+			// the timestamp option specified.
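+			//
+			// The cookie itself packs three fields into the 32-bit
+			// ISN: the top tsLen (8) bits carry timeStamp(), while
+			// the low tsOffset (24) bits carry the hash combined
+			// with the 2-bit index into mssTable computed by
+			// encodeMSS above.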
+			synOpts := header.TCPSynOptions{
+				WS:    -1,
+				TS:    opts.TS,
+				TSVal: tcpTimeStamp(timeStampOffset()),
+				TSEcr: opts.TSVal,
+				MSS:   mssForRoute(&s.route),
+			}
+			e.sendSynTCP(&s.route, tcpFields{
+				id:     s.id,
+				ttl:    e.ttl,
+				tos:    e.sendTOS,
+				flags:  header.TCPFlagSyn | header.TCPFlagAck,
+				seq:    cookie,
+				ack:    s.sequenceNumber + 1,
+				rcvWnd: ctx.rcvWnd,
+			}, synOpts)
+			e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+		}
+
+	case (s.flags & header.TCPFlagAck) != 0:
+		if e.acceptQueueIsFull() {
+			// Silently drop the ack as the application can't accept
+			// the connection at this point. The ack will be
+			// retransmitted by the sender anyway and we can
+			// complete the connection at the time of retransmit if
+			// the backlog has space.
+			e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+			e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment()
+			e.stack.Stats().DroppedPackets.Increment()
+			return
+		}
+
+		if !ctx.synRcvdCount.synCookiesInUse() {
+			// When not using SYN cookies, as per RFC 793, section 3.9, page 64:
+			// Any acknowledgment is bad if it arrives on a connection still in
+			// the LISTEN state. An acceptable reset segment should be formed
+			// for any arriving ACK-bearing segment. The RST should be
+			// formatted as follows:
+			//
+			//  <SEQ=SEG.ACK><CTL=RST>
+			//
+			// Send a reset as this is an ACK for which there is no
+			// half open connection and we are not using cookies
+			// yet.
+			//
+			// The only time we should reach here is when a connection
+			// was opened and closed really quickly and a delayed
+			// ACK was received from the sender.
+			replyWithReset(s, e.sendTOS, e.ttl)
+			return
+		}
+
+		iss := s.ackNumber - 1
+		irs := s.sequenceNumber - 1
+
+		// Since SYN cookies are in use this is potentially an ACK to a
+		// SYN-ACK we sent but don't have a half open connection state
+		// as cookies are being used to protect against a potential SYN
+		// flood. In such cases validate the cookie and if valid create
+		// a fully connected endpoint and deliver to the accept queue.
+		//
+		// If not, silently drop the ACK to avoid leaking information
+		// when under a potential syn flood attack.
+		//
+		// Validate the cookie.
+		data, ok := ctx.isCookieValid(s.id, iss, irs)
+		if !ok || int(data) >= len(mssTable) {
+			e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
+			e.stack.Stats().DroppedPackets.Increment()
+			return
+		}
+		e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
+		// Create newly accepted endpoint and deliver it.
+		rcvdSynOptions := &header.TCPSynOptions{
+			MSS: mssTable[data],
+			// Disable Window scaling as original SYN is
+			// lost.
+			WS: -1,
+		}
+
+		// When syn cookies are in use we enable timestamp only
+		// if the ack specifies the timestamp option assuming
+		// that the other end did in fact negotiate the
+		// timestamp option in the original SYN.
+		if s.parsedOptions.TS {
+			rcvdSynOptions.TS = true
+			rcvdSynOptions.TSVal = s.parsedOptions.TSVal
+			rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
+		}
+
+		n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{})
+
+		n.mu.Lock()
+
+		// Propagate any inheritable options from the listening endpoint
+		// to the newly created endpoint.
+		e.propagateInheritableOptionsLocked(n)
+
+		if !n.reserveTupleLocked() {
+			n.mu.Unlock()
+			n.Close()
+
+			e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+			e.stats.FailedConnectionAttempts.Increment()
+			return
+		}
+
+		// Register new endpoint so that packets are routed to it.
+		if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.ID, n, n.boundPortFlags, n.boundBindToDevice); err != nil {
+			n.mu.Unlock()
+			n.Close()
+
+			e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+			e.stats.FailedConnectionAttempts.Increment()
+			return
+		}
+
+		n.isRegistered = true
+
+		// clear the tsOffset for the newly created
+		// endpoint as the Timestamp was already
+		// randomly offset when the original SYN-ACK was
+		// sent above.
+		n.tsOffset = 0
+
+		// Switch state to connected.
+		n.isConnectNotified = true
+		n.transitionToStateEstablishedLocked(&handshake{
+			ep:          n,
+			iss:         iss,
+			ackNum:      irs + 1,
+			rcvWnd:      seqnum.Size(n.initialReceiveWindow()),
+			sndWnd:      s.window,
+			rcvWndScale: e.rcvWndScaleForHandshake(),
+			sndWndScale: rcvdSynOptions.WS,
+			mss:         rcvdSynOptions.MSS,
+		})
+
+		// Do the delivery in a separate goroutine so
+		// that we don't block the listen loop in case
+		// the application is slow to accept or stops
+		// accepting.
+		//
+		// NOTE: This won't result in an unbounded
+		// number of goroutines as we do check before
+		// entering here that there was at least some
+		// space available in the backlog.
+
+		// Start the protocol goroutine.
+		n.startAcceptedLoop()
+		e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+		go e.deliverAccepted(n)
+	}
+}
+
+// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
+// its own goroutine and is responsible for handling connection requests.
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+	e.mu.Lock()
+	v6Only := e.v6only
+	ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto)
+
+	defer func() {
+		// Mark endpoint as closed. This will prevent goroutines running
+		// handleSynSegment() from attempting to queue new connections
+		// to the endpoint.
+		e.setEndpointState(StateClose)
+
+		// close any endpoints in SYN-RCVD state.
+		ctx.closeAllPendingEndpoints()
+
+		// Do cleanup if needed.
+		e.completeWorkerLocked()
+
+		if e.drainDone != nil {
+			close(e.drainDone)
+		}
+		e.mu.Unlock()
+
+		e.drainClosingSegmentQueue()
+
+		// Notify waiters that the endpoint is shutdown.
+		e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
+	}()
+
+	s := sleep.Sleeper{}
+	s.AddWaker(&e.notificationWaker, wakerForNotification)
+	s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
+	for {
+		e.mu.Unlock()
+		index, _ := s.Fetch(true)
+		e.mu.Lock()
+		switch index {
+		case wakerForNotification:
+			n := e.fetchNotifications()
+			if n&notifyClose != 0 {
+				return nil
+			}
+			if n&notifyDrain != 0 {
+				for !e.segmentQueue.empty() {
+					s := e.segmentQueue.dequeue()
+					e.handleListenSegment(ctx, s)
+					s.decRef()
+				}
+				close(e.drainDone)
+				e.mu.Unlock()
+				<-e.undrain
+				e.mu.Lock()
+			}
+
+		case wakerForNewSegment:
+			// Process at most maxSegmentsPerWake segments.
+			mayRequeue := true
+			for i := 0; i < maxSegmentsPerWake; i++ {
+				s := e.segmentQueue.dequeue()
+				if s == nil {
+					mayRequeue = false
+					break
+				}
+
+				e.handleListenSegment(ctx, s)
+				s.decRef()
+			}
+
+			// If the queue is not empty, make sure we'll wake up
+			// in the next iteration.
+			if mayRequeue && !e.segmentQueue.empty() {
+				e.newSegmentWaker.Assert()
+			}
+		}
+	}
+}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
new file mode 100644
index 000000000..81b740115
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -0,0 +1,1713 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+	"encoding/binary"
+	"time"
+
+	"gvisor.dev/gvisor/pkg/rand"
+	"gvisor.dev/gvisor/pkg/sleep"
+	"gvisor.dev/gvisor/pkg/sync"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+// maxSegmentsPerWake is the maximum number of segments to process in the main
+// protocol goroutine per wake-up. Yielding [after this number of segments are
+// processed] allows other events to be processed as well (e.g., timeouts,
+// resets, etc.).
+const maxSegmentsPerWake = 100
+
+type handshakeState int
+
+// The following are the possible states of the TCP connection during a 3-way
+// handshake. A depiction of the states and transitions can be found in RFC 793,
+// page 23.
+const (
+	handshakeSynSent handshakeState = iota
+	handshakeSynRcvd
+	handshakeCompleted
+)
+
+// The following are used to set up sleepers.
+const (
+	wakerForNotification = iota
+	wakerForNewSegment
+	wakerForResend
+	wakerForResolution
+)
+
+const (
+	// Maximum space available for options.
+	maxOptionSize = 40
+)
+
+// handshake holds the state used during a TCP 3-way handshake.
+//
+// NOTE: handshake.ep.mu is held during handshake processing. It is released if
+// we are going to block and reacquired when we start processing an event.
+type handshake struct {
+	ep     *endpoint
+	state  handshakeState
+	active bool
+	flags  uint8
+	ackNum seqnum.Value
+
+	// iss is the initial send sequence number, as defined in RFC 793.
+	iss seqnum.Value
+
+	// rcvWnd is the receive window, as defined in RFC 793.
+	rcvWnd seqnum.Size
+
+	// sndWnd is the send window, as defined in RFC 793.
+	sndWnd seqnum.Size
+
+	// mss is the maximum segment size received from the peer.
+	mss uint16
+
+	// sndWndScale is the send window scale, as defined in RFC 1323. A
+	// negative value means no scaling is supported by the peer.
+	sndWndScale int
+
+	// rcvWndScale is the receive window scale, as defined in RFC 1323.
+	rcvWndScale int
+
+	// startTime is the time at which the first SYN/SYN-ACK was sent.
+	startTime time.Time
+
+	// deferAccept if non-zero will drop the final ACK for a passive
+	// handshake until an ACK segment with data is received or the timeout
+	// is hit.
+	deferAccept time.Duration
+
+	// acked is true if the final ACK for a 3-way handshake has
+	// been received. This is required to stop retransmitting the
+	// original SYN-ACK when deferAccept is enabled.
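+	// (See the wakerForResend handling in execute() for where this is
+	// consulted.)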
+	acked bool
+}
+
+func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
+	h := handshake{
+		ep:          ep,
+		active:      true,
+		rcvWnd:      rcvWnd,
+		rcvWndScale: ep.rcvWndScaleForHandshake(),
+	}
+	h.resetState()
+	return h
+}
+
+func newPassiveHandshake(ep *endpoint, rcvWnd seqnum.Size, isn, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) handshake {
+	h := newHandshake(ep, rcvWnd)
+	h.resetToSynRcvd(isn, irs, opts, deferAccept)
+	return h
+}
+
+// FindWndScale determines the window scale to use for the given maximum window
+// size.
+func FindWndScale(wnd seqnum.Size) int {
+	if wnd < 0x10000 {
+		return 0
+	}
+
+	max := seqnum.Size(0xffff)
+	s := 0
+	for wnd > max && s < header.MaxWndScale {
+		s++
+		max <<= 1
+	}
+
+	return s
+}
+
+// resetState resets the state of the handshake object such that it becomes
+// ready for a new 3-way handshake.
+func (h *handshake) resetState() {
+	b := make([]byte, 4)
+	if _, err := rand.Read(b); err != nil {
+		panic(err)
+	}
+
+	h.state = handshakeSynSent
+	h.flags = header.TCPFlagSyn
+	h.ackNum = 0
+	h.mss = 0
+	h.iss = generateSecureISN(h.ep.ID, h.ep.stack.Seed())
+}
+
+// generateSecureISN generates a secure Initial Sequence number based on the
+// recommendation here https://tools.ietf.org/html/rfc6528#page-3.
+func generateSecureISN(id stack.TransportEndpointID, seed uint32) seqnum.Value {
+	isnHasher := jenkins.Sum32(seed)
+	isnHasher.Write([]byte(id.LocalAddress))
+	isnHasher.Write([]byte(id.RemoteAddress))
+	portBuf := make([]byte, 2)
+	binary.LittleEndian.PutUint16(portBuf, id.LocalPort)
+	isnHasher.Write(portBuf)
+	binary.LittleEndian.PutUint16(portBuf, id.RemotePort)
+	isnHasher.Write(portBuf)
+	// The time period here is 64ns. This is similar to what Linux uses to
+	// generate a sequence number that overlaps less than once per MSL
+	// (2 minutes).
+	//
+	// A 64ns clock ticks 10^9/64 = 15625000 times in a second.
+	// To wrap the whole 32 bit space would require
+	// 2^32/15625000 ~ 274 seconds.
+	//
+	// Which sort of guarantees that we won't reuse the ISN for a new
+	// connection for the same tuple for at least 274s.
+	isn := isnHasher.Sum32() + uint32(time.Now().UnixNano()>>6)
+	return seqnum.Value(isn)
+}
+
+// effectiveRcvWndScale returns the effective receive window scale to be used.
+// If the peer doesn't support window scaling, the effective rcv wnd scale is
+// zero; otherwise it's the value calculated based on the initial rcv wnd.
+func (h *handshake) effectiveRcvWndScale() uint8 {
+	if h.sndWndScale < 0 {
+		return 0
+	}
+	return uint8(h.rcvWndScale)
+}
+
+// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
+// state.
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, deferAccept time.Duration) {
+	h.active = false
+	h.state = handshakeSynRcvd
+	h.flags = header.TCPFlagSyn | header.TCPFlagAck
+	h.iss = iss
+	h.ackNum = irs + 1
+	h.mss = opts.MSS
+	h.sndWndScale = opts.WS
+	h.deferAccept = deferAccept
+	h.ep.setEndpointState(StateSynRecv)
+}
+
+// checkAck checks if the ACK number, if present, of a segment received during
+// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in
+// response.
+func (h *handshake) checkAck(s *segment) bool {
+	if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 {
+		// RFC 793, page 36, states that a reset must be generated when
+		// the connection is in any non-synchronized state and an
+		// incoming segment acknowledges something not yet sent. The
+		// connection remains in the same state.
+		ack := s.sequenceNumber.Add(s.logicalLen())
+		h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, s.ackNumber, ack, 0)
+		return false
+	}
+
+	return true
+}
+
+// synSentState handles a segment received when the TCP 3-way handshake is in
+// the SYN-SENT state.
+func (h *handshake) synSentState(s *segment) *tcpip.Error {
+	// RFC 793, page 37, states that in the SYN-SENT state, a reset is
+	// acceptable if the ack field acknowledges the SYN.
+	if s.flagIsSet(header.TCPFlagRst) {
+		if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
+			// RFC 793, page 67, states that "If the RST bit is set [and] If the ACK
+			// was acceptable then signal the user "error: connection reset", drop
+			// the segment, enter CLOSED state, delete TCB, and return."
+			h.ep.workerCleanup = true
+			// Although the RFC above calls out ECONNRESET, Linux actually returns
+			// ECONNREFUSED here so we do as well.
+			return tcpip.ErrConnectionRefused
+		}
+		return nil
+	}
+
+	if !h.checkAck(s) {
+		return nil
+	}
+
+	// We are in the SYN-SENT state. We only care about segments that have
+	// the SYN flag.
+	if !s.flagIsSet(header.TCPFlagSyn) {
+		return nil
+	}
+
+	// Parse the SYN options.
+	rcvSynOpts := parseSynSegmentOptions(s)
+
+	// Remember if the Timestamp option was negotiated.
+	h.ep.maybeEnableTimestamp(&rcvSynOpts)
+
+	// Remember if the SACKPermitted option was negotiated.
+	h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+
+	// Remember the sequence we'll ack from now on.
+	h.ackNum = s.sequenceNumber + 1
+	h.flags |= header.TCPFlagAck
+	h.mss = rcvSynOpts.MSS
+	h.sndWndScale = rcvSynOpts.WS
+
+	// If this is a SYN ACK response, we only need to acknowledge the SYN
+	// and the handshake is completed.
+	if s.flagIsSet(header.TCPFlagAck) {
+		h.state = handshakeCompleted
+
+		h.ep.transitionToStateEstablishedLocked(h)
+
+		h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
+		return nil
+	}
+
+	// A SYN segment was received, but no ACK in it. We acknowledge the SYN
+	// but resend our own SYN and wait for it to be acknowledged in the
+	// SYN-RCVD state.
+	h.state = handshakeSynRcvd
+	ttl := h.ep.ttl
+	amss := h.ep.amss
+	h.ep.setEndpointState(StateSynRecv)
+	synOpts := header.TCPSynOptions{
+		WS:    int(h.effectiveRcvWndScale()),
+		TS:    rcvSynOpts.TS,
+		TSVal: h.ep.timestamp(),
+		TSEcr: h.ep.recentTimestamp(),
+
+		// We only send SACKPermitted if the other side indicated it
+		// permits SACK. This is not explicitly defined in the RFC but
+		// this is the behaviour implemented by Linux.
+		SACKPermitted: rcvSynOpts.SACKPermitted,
+		MSS:           amss,
+	}
+	if ttl == 0 {
+		ttl = s.route.DefaultTTL()
+	}
+	h.ep.sendSynTCP(&s.route, tcpFields{
+		id:     h.ep.ID,
+		ttl:    ttl,
+		tos:    h.ep.sendTOS,
+		flags:  h.flags,
+		seq:    h.iss,
+		ack:    h.ackNum,
+		rcvWnd: h.rcvWnd,
+	}, synOpts)
+	return nil
+}
+
+// synRcvdState handles a segment received when the TCP 3-way handshake is in
+// the SYN-RCVD state.
+func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
+	if s.flagIsSet(header.TCPFlagRst) {
+		// RFC 793, page 37, states that in the SYN-RCVD state, a reset
+		// is acceptable if the sequence number is in the window.
+ if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { + return tcpip.ErrConnectionRefused + } + return nil + } + + if !h.checkAck(s) { + return nil + } + + // RFC 793, Section 3.9, page 69, states that in the SYN-RCVD state, a + // sequence number outside of the window causes an ACK with the proper seq + // number and "After sending the acknowledgment, drop the unacceptable + // segment and return." + if !s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd) + return nil + } + + if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { + // We received two SYN segments with different sequence + // numbers, so we reset this and restart the whole + // process, except that we don't reset the timer. + ack := s.sequenceNumber.Add(s.logicalLen()) + seq := seqnum.Value(0) + if s.flagIsSet(header.TCPFlagAck) { + seq = s.ackNumber + } + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0) + + if !h.active { + return tcpip.ErrInvalidEndpointState + } + + h.resetState() + synOpts := header.TCPSynOptions{ + WS: h.rcvWndScale, + TS: h.ep.sendTSOk, + TSVal: h.ep.timestamp(), + TSEcr: h.ep.recentTimestamp(), + SACKPermitted: h.ep.sackPermitted, + MSS: h.ep.amss, + } + h.ep.sendSynTCP(&s.route, tcpFields{ + id: h.ep.ID, + ttl: h.ep.ttl, + tos: h.ep.sendTOS, + flags: h.flags, + seq: h.iss, + ack: h.ackNum, + rcvWnd: h.rcvWnd, + }, synOpts) + return nil + } + + // We have previously received (and acknowledged) the peer's SYN. If the + // peer acknowledges our SYN, the handshake is completed. + if s.flagIsSet(header.TCPFlagAck) { + // If deferAccept is not zero and this is a bare ACK and the + // timeout is not hit then drop the ACK. + if h.deferAccept != 0 && s.data.Size() == 0 && time.Since(h.startTime) < h.deferAccept { + h.acked = true + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + + // If the timestamp option is negotiated and the segment does + // not carry a timestamp option then the segment must be dropped + // as per https://tools.ietf.org/html/rfc7323#section-3.2. + if h.ep.sendTSOk && !s.parsedOptions.TS { + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + + // Update timestamp if required. See RFC7323, section-4.3. + if h.ep.sendTSOk && s.parsedOptions.TS { + h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) + } + h.state = handshakeCompleted + + h.ep.transitionToStateEstablishedLocked(h) + + // If the segment has data then requeue it for the receiver + // to process it again once main loop is started. + if s.data.Size() > 0 { + s.incRef() + h.ep.enqueueSegment(s) + } + return nil + } + + return nil +} + +func (h *handshake) handleSegment(s *segment) *tcpip.Error { + h.sndWnd = s.window + if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 { + h.sndWnd <<= uint8(h.sndWndScale) + } + + switch h.state { + case handshakeSynRcvd: + return h.synRcvdState(s) + case handshakeSynSent: + return h.synSentState(s) + } + return nil +} + +// processSegments goes through the segment queue and processes up to +// maxSegmentsPerWake (if they're available). 
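+// Processing stops early once the handshake completes so that any remaining
+// segments are left for the endpoint's main protocol goroutine.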
+func (h *handshake) processSegments() *tcpip.Error {
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := h.ep.segmentQueue.dequeue()
+ if s == nil {
+ return nil
+ }
+
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+
+ // We stop processing packets once the handshake is completed,
+ // otherwise we may process packets meant to be processed by
+ // the main protocol goroutine.
+ if h.state == handshakeCompleted {
+ break
+ }
+ }
+
+ // If the queue is not empty, make sure we'll wake up in the next
+ // iteration.
+ if !h.ep.segmentQueue.empty() {
+ h.ep.newSegmentWaker.Assert()
+ }
+
+ return nil
+}
+
+func (h *handshake) resolveRoute() *tcpip.Error {
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ resolutionWaker := &sleep.Waker{}
+ s.AddWaker(resolutionWaker, wakerForResolution)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ defer s.Done()
+
+ // Initial action is to resolve route.
+ index := wakerForResolution
+ for {
+ switch index {
+ case wakerForResolution:
+ if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+ if err == tcpip.ErrNoLinkAddress {
+ h.ep.stats.SendErrors.NoLinkAddr.Increment()
+ } else if err != nil {
+ h.ep.stats.SendErrors.NoRoute.Increment()
+ }
+ // Either success (err == nil) or failure.
+ return err
+ }
+ // Resolution not completed. Keep trying...
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if n&notifyClose != 0 {
+ h.ep.route.RemoveWaker(resolutionWaker)
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ close(h.ep.drainDone)
+ h.ep.mu.Unlock()
+ <-h.ep.undrain
+ h.ep.mu.Lock()
+ }
+ }
+
+ // Wait for notification.
+ index, _ = s.Fetch(true)
+ }
+}
+
+// execute executes the TCP 3-way handshake.
+func (h *handshake) execute() *tcpip.Error {
+ if h.ep.route.IsResolutionRequired() {
+ if err := h.resolveRoute(); err != nil {
+ return err
+ }
+ }
+
+ h.startTime = time.Now()
+ // Initialize the resend timer.
+ resendWaker := sleep.Waker{}
+ timeOut := time.Duration(time.Second)
+ rt := time.AfterFunc(timeOut, resendWaker.Assert)
+ defer rt.Stop()
+
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ s.AddWaker(&resendWaker, wakerForResend)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+ defer s.Done()
+
+ var sackEnabled SACKEnabled
+ if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
+ // If stack returned an error when checking for SACKEnabled
+ // status then just default to switching off SACK negotiation.
+ sackEnabled = false
+ }
+
+ // Send the initial SYN segment and loop until the handshake is
+ // completed.
+ h.ep.amss = calculateAdvertisedMSS(h.ep.userMSS, h.ep.route)
+
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: true,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTimestamp(),
+ SACKPermitted: bool(sackEnabled),
+ MSS: h.ep.amss,
+ }
+
+ // Execute is also called in a listen context so we want to make sure we
+ // only send the TS/SACK option when we received the TS/SACK in the
+ // initial SYN.
+ if h.state == handshakeSynRcvd {
+ synOpts.TS = h.ep.sendTSOk
+ synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
+ if h.sndWndScale < 0 {
+ // Disable window scaling if the peer did not send us
+ // the window scaling option.
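+ // (makeSynOptions skips the WS option entirely for a negative
+ // value, so the resent SYN-ACK advertises no window scaling.)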
+ synOpts.WS = -1
+ }
+ }
+
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
+
+ for h.state != handshakeCompleted {
+ h.ep.mu.Unlock()
+ index, _ := s.Fetch(true)
+ h.ep.mu.Lock()
+ switch index {
+
+ case wakerForResend:
+ timeOut *= 2
+ if timeOut > MaxRTO {
+ return tcpip.ErrTimeout
+ }
+ rt.Reset(timeOut)
+ // Resend the SYN/SYN-ACK only if one of the following conditions holds.
+ // - It's an active handshake (deferAccept does not apply).
+ // - It's a passive handshake and we have not yet got the final ACK.
+ // - It's a passive handshake and we got an ACK but deferAccept is
+ // enabled and we are now past the deferAccept duration.
+ // The last is required to provide a way for the peer to complete
+ // the connection with another ACK or data (as ACKs are never
+ // retransmitted on their own).
+ if h.active || !h.acked || h.deferAccept != 0 && time.Since(h.startTime) > h.deferAccept {
+ h.ep.sendSynTCP(&h.ep.route, tcpFields{
+ id: h.ep.ID,
+ ttl: h.ep.ttl,
+ tos: h.ep.sendTOS,
+ flags: h.flags,
+ seq: h.iss,
+ ack: h.ackNum,
+ rcvWnd: h.rcvWnd,
+ }, synOpts)
+ }
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if (n&notifyClose)|(n&notifyAbort) != 0 {
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ for !h.ep.segmentQueue.empty() {
+ s := h.ep.segmentQueue.dequeue()
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+ if h.state == handshakeCompleted {
+ return nil
+ }
+ }
+ close(h.ep.drainDone)
+ h.ep.mu.Unlock()
+ <-h.ep.undrain
+ h.ep.mu.Lock()
+ }
+
+ case wakerForNewSegment:
+ if err := h.processSegments(); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
+ synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
+ if synOpts.TS {
+ s.parsedOptions.TSVal = synOpts.TSVal
+ s.parsedOptions.TSEcr = synOpts.TSEcr
+ }
+ return synOpts
+}
+
+var optionPool = sync.Pool{
+ New: func() interface{} {
+ return &[maxOptionSize]byte{}
+ },
+}
+
+func getOptions() []byte {
+ return (*optionPool.Get().(*[maxOptionSize]byte))[:]
+}
+
+func putOptions(options []byte) {
+ // Reslice to full capacity.
+ optionPool.Put(optionsToArray(options))
+}
+
+func makeSynOptions(opts header.TCPSynOptions) []byte {
+ // Emulate Linux option order. This is as follows:
+ //
+ // if md5: NOP NOP MD5SIG 18 md5sig(16)
+ // if mss: MSS 4 mss(2)
+ // if ts and sack_advertise:
+ // SACK 2 TIMESTAMP 2 timestamp(8)
+ // elif ts: NOP NOP TIMESTAMP 10 timestamp(8)
+ // elif sack: NOP NOP SACK 2
+ // if wscale: NOP WINDOW 3 ws(1)
+ // if sack_blocks: NOP NOP SACK (2 + (#blocks * 8))
+ // [for each block] start_seq(4) end_seq(4)
+ // if fastopen_cookie:
+ // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2)
+ // else: FASTOPEN (2 + len(cookie))
+ // cookie(variable) [padding to four bytes]
+ //
+ options := getOptions()
+
+ // Always encode the mss.
+ offset := header.EncodeMSSOption(uint32(opts.MSS), options)
+
+ // Special ordering is required here. If both TS and SACK are enabled,
+ // then the SACK option precedes TS, with no padding. If they are
+ // enabled individually, then we see padding before the option.
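+ // For example: with both enabled the layout is SACK_PERM(2) TS(10)
+ // with no NOPs, while TS alone encodes as NOP NOP TS(10) and
+ // SACK_PERM alone as NOP NOP SACK_PERM(2).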
+ if opts.TS && opts.SACKPermitted {
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.TS {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.SACKPermitted {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ }
+
+ // Initialize the WS option.
+ if opts.WS >= 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeWSOption(opts.WS, options[offset:])
+ }
+
+ // Padding to the end; note that this never applies unless we add a
+ // fastopen option, as we always expect the offset to remain the same.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
+
+// tcpFields is a struct to carry different parameters required by the
+// send*TCP variant functions below.
+type tcpFields struct {
+ id stack.TransportEndpointID
+ ttl uint8
+ tos uint8
+ flags byte
+ seq seqnum.Value
+ ack seqnum.Value
+ rcvWnd seqnum.Size
+ opts []byte
+ txHash uint32
+}
+
+func (e *endpoint) sendSynTCP(r *stack.Route, tf tcpFields, opts header.TCPSynOptions) *tcpip.Error {
+ tf.opts = makeSynOptions(opts)
+ // We ignore SYN send errors and let the callers re-attempt send.
+ if err := e.sendTCP(r, tf, buffer.VectorisedView{}, nil); err != nil {
+ e.stats.SendErrors.SynSendToNetworkFailed.Increment()
+ }
+ putOptions(tf.opts)
+ return nil
+}
+
+func (e *endpoint) sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO) *tcpip.Error {
+ tf.txHash = e.txHash
+ if err := sendTCP(r, tf, data, gso, e.owner); err != nil {
+ e.stats.SendErrors.SegmentSendToNetworkFailed.Increment()
+ return err
+ }
+ e.stats.SegmentsSent.Increment()
+ return nil
+}
+
+func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *stack.GSO) {
+ optLen := len(tf.opts)
+ hdr := &pkt.Header
+ packetSize := pkt.Data.Size()
+ // Initialize the header.
+ tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
+ pkt.TransportHeader = buffer.View(tcp)
+ tcp.Encode(&header.TCPFields{
+ SrcPort: tf.id.LocalPort,
+ DstPort: tf.id.RemotePort,
+ SeqNum: uint32(tf.seq),
+ AckNum: uint32(tf.ack),
+ DataOffset: uint8(header.TCPMinimumSize + optLen),
+ Flags: tf.flags,
+ WindowSize: uint16(tf.rcvWnd),
+ })
+ copy(tcp[header.TCPMinimumSize:], tf.opts)
+
+ length := uint16(hdr.UsedLength() + packetSize)
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ // Only calculate the checksum if offloading isn't supported.
+ if gso != nil && gso.NeedsCsum {
+ // This is called CHECKSUM_PARTIAL in the Linux kernel. We
+ // calculate a checksum of the pseudo-header and save it in the
+ // TCP header, then the kernel calculates a checksum of the
+ // header and data to arrive at the correct checksum for the
+ // TCP packet.
+ tcp.SetChecksum(xsum)
+ } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ xsum = header.ChecksumVV(pkt.Data, xsum)
+ tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
+ }
+}
+
+func sendTCPBatch(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error {
+ // We need to shallow clone the VectorisedView here as ReadToView will
+ // split the VectorisedView and Trim underlying views as it splits.
Not + // doing the clone here will cause the underlying views of data itself + // to be altered. + data = data.Clone(nil) + + optLen := len(tf.opts) + if tf.rcvWnd > 0xffff { + tf.rcvWnd = 0xffff + } + + mss := int(gso.MSS) + n := (data.Size() + mss - 1) / mss + + size := data.Size() + hdrSize := header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen + var pkts stack.PacketBufferList + for i := 0; i < n; i++ { + packetSize := mss + if packetSize > size { + packetSize = size + } + size -= packetSize + var pkt stack.PacketBuffer + pkt.Header = buffer.NewPrependable(hdrSize) + pkt.Hash = tf.txHash + pkt.Owner = owner + pkt.EgressRoute = r + pkt.GSOOptions = gso + pkt.NetworkProtocolNumber = r.NetworkProtocolNumber() + data.ReadToVV(&pkt.Data, packetSize) + buildTCPHdr(r, tf, &pkt, gso) + tf.seq = tf.seq.Add(seqnum.Size(packetSize)) + pkts.PushBack(&pkt) + } + + if tf.ttl == 0 { + tf.ttl = r.DefaultTTL() + } + sent, err := r.WritePackets(gso, pkts, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}) + if err != nil { + r.Stats().TCP.SegmentSendErrors.IncrementBy(uint64(n - sent)) + } + r.Stats().TCP.SegmentsSent.IncrementBy(uint64(sent)) + return err +} + +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. +func sendTCP(r *stack.Route, tf tcpFields, data buffer.VectorisedView, gso *stack.GSO, owner tcpip.PacketOwner) *tcpip.Error { + optLen := len(tf.opts) + if tf.rcvWnd > 0xffff { + tf.rcvWnd = 0xffff + } + + if r.Loop&stack.PacketLoop == 0 && gso != nil && gso.Type == stack.GSOSW && int(gso.MSS) < data.Size() { + return sendTCPBatch(r, tf, data, gso, owner) + } + + pkt := &stack.PacketBuffer{ + Header: buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen), + Data: data, + Hash: tf.txHash, + Owner: owner, + } + buildTCPHdr(r, tf, pkt, gso) + + if tf.ttl == 0 { + tf.ttl = r.DefaultTTL() + } + if err := r.WritePacket(gso, stack.NetworkHeaderParams{Protocol: ProtocolNumber, TTL: tf.ttl, TOS: tf.tos}, pkt); err != nil { + r.Stats().TCP.SegmentSendErrors.Increment() + return err + } + r.Stats().TCP.SegmentsSent.Increment() + if (tf.flags & header.TCPFlagRst) != 0 { + r.Stats().TCP.ResetsSent.Increment() + } + return nil +} + +// makeOptions makes an options slice. +func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { + options := getOptions() + offset := 0 + + // N.B. the ordering here matches the ordering used by Linux internally + // and described in the raw makeOptions function. We don't include + // unnecessary cases here (post connection.) + if e.sendTSOk { + // Embed the timestamp if timestamp has been enabled. + // + // We only use the lower 32 bits of the unix time in + // milliseconds. This is similar to what Linux does where it + // uses the lower 32 bits of the jiffies value in the tsVal + // field of the timestamp option. + // + // Further, RFC7323 section-5.4 recommends millisecond + // resolution as the lowest recommended resolution for the + // timestamp clock. + // + // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. 
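+ // (A 32-bit millisecond clock wraps roughly every 49.7 days, i.e.
+ // 2^32 ms; PAWS is expected to tolerate such tsVal wraparound.)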
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), e.recentTimestamp(), options[offset:])
+ }
+ if e.sackPermitted && len(sackBlocks) > 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
+ }
+
+ // We expect the above to produce an aligned offset.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
+
+// sendRaw sends a TCP segment to the endpoint's peer.
+func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+ var sackBlocks []header.SACKBlock
+ if e.EndpointState() == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
+ }
+ options := e.makeOptions(sackBlocks)
+ err := e.sendTCP(&e.route, tcpFields{
+ id: e.ID,
+ ttl: e.ttl,
+ tos: e.sendTOS,
+ flags: flags,
+ seq: seq,
+ ack: ack,
+ rcvWnd: rcvWnd,
+ opts: options,
+ }, data, e.gso)
+ putOptions(options)
+ return err
+}
+
+func (e *endpoint) handleWrite() *tcpip.Error {
+ // Move packets from send queue to send list. The queue is accessible
+ // from other goroutines and protected by the send mutex, while the send
+ // list is only accessible from the handler goroutine, so it needs no
+ // mutexes.
+ e.sndBufMu.Lock()
+
+ first := e.sndQueue.Front()
+ if first != nil {
+ e.snd.writeList.PushBackList(&e.sndQueue)
+ e.sndBufInQueue = 0
+ }
+
+ e.sndBufMu.Unlock()
+
+ // Initialize the next segment to write if it's currently nil.
+ if e.snd.writeNext == nil {
+ e.snd.writeNext = first
+ }
+
+ // Push out any new packets.
+ e.snd.sendData()
+
+ return nil
+}
+
+func (e *endpoint) handleClose() *tcpip.Error {
+ if !e.EndpointState().connected() {
+ return nil
+ }
+ // Drain the send queue.
+ e.handleWrite()
+
+ // Mark send side as closed.
+ e.snd.closed = true
+
+ return nil
+}
+
+// resetConnectionLocked puts the endpoint in an error state with the given
+// error code and sends a RST unless the error indicates that the connection
+// was reset by the peer (ErrConnectionReset) or timed out (ErrTimeout). This
+// method must only be called from the protocol goroutine.
+func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
+ // Only send a reset if the connection is being aborted for a reason
+ // other than receiving a reset or timing out.
+ e.setEndpointState(StateError)
+ e.HardError = err
+ if err != tcpip.ErrConnectionReset && err != tcpip.ErrTimeout {
+ // The exact sequence number to be used for the RST is the same as the
+ // one used by Linux. We need to handle the case of window being shrunk
+ // which can cause sndNxt to be outside the acceptable window on the
+ // receiver.
+ //
+ // See: https://www.snellman.net/blog/archive/2016-02-01-tcp-rst/ for more
+ // information.
+ sndWndEnd := e.snd.sndUna.Add(e.snd.sndWnd)
+ resetSeqNum := sndWndEnd
+ if !sndWndEnd.LessThan(e.snd.sndNxt) || e.snd.sndNxt.Size(sndWndEnd) < (1<<e.snd.sndWndScale) {
+ resetSeqNum = e.snd.sndNxt
+ }
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, resetSeqNum, e.rcv.rcvNxt, 0)
+ }
+}
+
+// completeWorkerLocked is called by the worker goroutine when it's about to
+// exit.
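+// Precondition: e.mu must be held by the caller, per the Locked suffix.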
+func (e *endpoint) completeWorkerLocked() {
+ // The worker is terminating (either due to moving to the CLOSED or
+ // ERROR state); ensure we release all registrations and port
+ // reservations even if the socket itself is not yet closed by the
+ // application.
+ e.workerRunning = false
+ if e.workerCleanup {
+ e.cleanupLocked()
+ }
+}
+
+// transitionToStateEstablishedLocked transitions a given endpoint
+// to an established state using the handshake parameters provided.
+// It also initializes sender/receiver.
+func (e *endpoint) transitionToStateEstablishedLocked(h *handshake) {
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ rcvBufSize := seqnum.Size(e.receiveBufferSize())
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale(), rcvBufSize)
+ // Bootstrap the auto tuning algorithm. Starting at zero will
+ // result in a really large receive window after the first auto
+ // tuning adjustment.
+ e.rcvAutoParams.prevCopied = int(h.rcvWnd)
+ e.rcvListMu.Unlock()
+
+ e.setEndpointState(StateEstablished)
+}
+
+// transitionToStateCloseLocked ensures that the endpoint is
+// cleaned up from the transport demuxer "before" moving to
+// StateClose. This will ensure that no packet will be
+// delivered to this endpoint from the demuxer when the endpoint
+// is transitioned to StateClose.
+func (e *endpoint) transitionToStateCloseLocked() {
+ if e.EndpointState() == StateClose {
+ return
+ }
+ // Mark the endpoint as fully closed for reads/writes.
+ e.cleanupLocked()
+ e.setEndpointState(StateClose)
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.stack.Stats().TCP.EstablishedClosed.Increment()
+}
+
+// tryDeliverSegmentFromClosedEndpoint attempts to deliver the parsed
+// segment to an endpoint other than the current one. This is called
+// only when the endpoint is in StateClose and we want to deliver the segment
+// to any other listening endpoint. We reply with RST if we cannot find one.
+func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) {
+ ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route)
+ if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" {
+ // Dual-stack socket, try IPv4.
+ ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route)
+ }
+ if ep == nil {
+ replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+ s.decRef()
+ return
+ }
+
+ if e == ep {
+ panic("current endpoint not removed from demuxer, enqueuing segments to itself")
+ }
+
+ if ep := ep.(*endpoint); ep.enqueueSegment(s) {
+ ep.newSegmentWaker.Assert()
+ }
+}
+
+// Drain segment queue from the endpoint and try to re-match the segment to a
+// different endpoint. This is used when the current endpoint is transitioned to
+// StateClose and has been unregistered from the transport demuxer.
+func (e *endpoint) drainClosingSegmentQueue() {
+ for {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ break
+ }
+
+ e.tryDeliverSegmentFromClosedEndpoint(s)
+ }
+}
+
+func (e *endpoint) handleReset(s *segment) (ok bool, err *tcpip.Error) {
+ if e.rcv.acceptable(s.sequenceNumber, 0) {
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields."
So
+ // we only process it if it's acceptable.
+ switch e.EndpointState() {
+ // In case of a RST in CLOSE-WAIT Linux moves
+ // the socket to closed state with an error set
+ // to indicate EPIPE.
+ //
+ // Technically this seems to be at odds w/ RFC.
+ // As per https://tools.ietf.org/html/rfc793#section-2.7
+ // page 69 the behavior for a segment arriving
+ // w/ RST bit set in CLOSE-WAIT is inlined below.
+ //
+ // ESTABLISHED
+ // FIN-WAIT-1
+ // FIN-WAIT-2
+ // CLOSE-WAIT
+
+ // If the RST bit is set then, any outstanding RECEIVEs and
+ // SEND should receive "reset" responses. All segment queues
+ // should be flushed. Users should also receive an unsolicited
+ // general "connection reset" signal. Enter the CLOSED state,
+ // delete the TCB, and return.
+ case StateCloseWait:
+ e.transitionToStateCloseLocked()
+ e.HardError = tcpip.ErrAborted
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ return false, nil
+ default:
+ // Notify the protocol goroutine. This is required when
+ // handleSegment is invoked from the processor goroutine
+ // rather than the worker goroutine.
+ e.notifyProtocolGoroutine(notifyResetByPeer)
+ return false, tcpip.ErrConnectionReset
+ }
+ }
+ return true, nil
+}
+
+// handleSegments processes all inbound segments.
+func (e *endpoint) handleSegments(fastPath bool) *tcpip.Error {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ if e.EndpointState().closed() {
+ return nil
+ }
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+
+ cont, err := e.handleSegment(s)
+ if err != nil {
+ s.decRef()
+ return err
+ }
+ if !cont {
+ s.decRef()
+ return nil
+ }
+ }
+
+ // When fastPath is true we don't want to wake up the worker
+ // goroutine. If the endpoint has more segments to process the
+ // dispatcher will call handleSegments again anyway.
+ if !fastPath && checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+
+ // Send an ACK for all processed packets if needed.
+ if e.rcv.rcvNxt != e.snd.maxSentAck {
+ e.snd.sendAck()
+ }
+
+ e.resetKeepaliveTimer(true /* receivedData */)
+
+ return nil
+}
+
+// handleSegment handles a given segment and notifies the worker goroutine
+// if the connection should be terminated.
+func (e *endpoint) handleSegment(s *segment) (cont bool, err *tcpip.Error) {
+ // Invoke the tcp probe if installed.
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+
+ if s.flagIsSet(header.TCPFlagRst) {
+ if ok, err := e.handleReset(s); !ok {
+ return false, err
+ }
+ } else if s.flagIsSet(header.TCPFlagSyn) {
+ // See: https://tools.ietf.org/html/rfc5961#section-4.1
+ // 1) If the SYN bit is set, irrespective of the sequence number, TCP
+ // MUST send an ACK (also referred to as challenge ACK) to the remote
+ // peer:
+ //
+ // <SEQ=SND.NXT><ACK=RCV.NXT><CTL=ACK>
+ //
+ // After sending the acknowledgment, TCP MUST drop the unacceptable
+ // segment and stop processing further.
+ //
+ // By sending an ACK, the remote peer is challenged to confirm the loss
+ // of the previous connection and the request to start a new connection.
+ // A legitimate peer, after restart, would not have a TCB in the
+ // synchronized state. Thus, when the ACK arrives, the peer should send
+ // a RST segment back with the sequence number derived from the ACK
+ // field that caused the RST.
+ + // This RST will confirm that the remote peer has indeed closed the + // previous connection. Upon receipt of a valid RST, the local TCP + // endpoint MUST terminate its connection. The local TCP endpoint + // should then rely on SYN retransmission from the remote end to + // re-establish the connection. + + e.snd.sendAck() + } else if s.flagIsSet(header.TCPFlagAck) { + // Patch the window size in the segment according to the + // send window scale. + s.window <<= e.snd.sndWndScale + + // RFC 793, page 41 states that "once in the ESTABLISHED + // state all segments must carry current acknowledgment + // information." + drop, err := e.rcv.handleRcvdSegment(s) + if err != nil { + return false, err + } + if drop { + return true, nil + } + + // Now check if the received segment has caused us to transition + // to a CLOSED state, if yes then terminate processing and do + // not invoke the sender. + state := e.state + if state == StateClose { + // When we get into StateClose while processing from the queue, + // return immediately and let the protocolMainloop handle it. + // + // We can reach StateClose only while processing a previous segment + // or a notification from the protocolMainLoop (caller goroutine). + // This means that with this return, the segment dequeue below can + // never occur on a closed endpoint. + s.decRef() + return false, nil + } + + e.snd.handleRcvdSegment(s) + } + + return true, nil +} + +// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP +// keepalive packets periodically when the connection is idle. If we don't hear +// from the other side after a number of tries, we terminate the connection. +func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + userTimeout := e.userTimeout + + e.keepalive.Lock() + if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { + e.keepalive.Unlock() + return nil + } + + // If a userTimeout is set then abort the connection if it is + // exceeded. + if userTimeout != 0 && time.Since(e.rcv.lastRcvdAckTime) >= userTimeout && e.keepalive.unacked > 0 { + e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return tcpip.ErrTimeout + } + + if e.keepalive.unacked >= e.keepalive.count { + e.keepalive.Unlock() + e.stack.Stats().TCP.EstablishedTimedout.Increment() + return tcpip.ErrTimeout + } + + // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with + // seg.seq = snd.nxt-1. + e.keepalive.unacked++ + e.keepalive.Unlock() + e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1) + e.resetKeepaliveTimer(false) + return nil +} + +// resetKeepaliveTimer restarts or stops the keepalive timer, depending on +// whether it is enabled for this endpoint. +func (e *endpoint) resetKeepaliveTimer(receivedData bool) { + e.keepalive.Lock() + if receivedData { + e.keepalive.unacked = 0 + } + // Start the keepalive timer IFF it's enabled and there is no pending + // data to send. + if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + e.keepalive.timer.disable() + e.keepalive.Unlock() + return + } + if e.keepalive.unacked > 0 { + e.keepalive.timer.enable(e.keepalive.interval) + } else { + e.keepalive.timer.enable(e.keepalive.idle) + } + e.keepalive.Unlock() +} + +// disableKeepaliveTimer stops the keepalive timer. +func (e *endpoint) disableKeepaliveTimer() { + e.keepalive.Lock() + e.keepalive.timer.disable() + e.keepalive.Unlock() +} + +// protocolMainLoop is the main loop of the TCP protocol. 
It runs in its own
+// goroutine and is responsible for sending segments and handling received
+// segments.
+func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{}) *tcpip.Error {
+ e.mu.Lock()
+ var closeTimer *time.Timer
+ var closeWaker sleep.Waker
+
+ epilogue := func() {
+ // e.mu is expected to be held upon entering this section.
+
+ if e.snd != nil {
+ e.snd.resendTimer.cleanup()
+ }
+
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+
+ e.mu.Unlock()
+
+ e.drainClosingSegmentQueue()
+
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+
+ if handshake {
+ // This is an active connection, so we must initiate the 3-way
+ // handshake, and then inform potential waiters about its
+ // completion.
+ initialRcvWnd := e.initialReceiveWindow()
+ h := newHandshake(e, seqnum.Size(initialRcvWnd))
+ h.ep.setEndpointState(StateSynSent)
+
+ if err := h.execute(); err != nil {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ e.setEndpointState(StateError)
+ e.HardError = err
+
+ e.workerCleanup = true
+ // Lock released below.
+ epilogue()
+ return err
+ }
+ }
+
+ e.keepalive.timer.init(&e.keepalive.waker)
+ defer e.keepalive.timer.cleanup()
+
+ drained := e.drainDone != nil
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
+
+ // Set up the functions that will be called when the main protocol loop
+ // wakes up.
+ funcs := []struct {
+ w *sleep.Waker
+ f func() *tcpip.Error
+ }{
+ {
+ w: &e.sndWaker,
+ f: e.handleWrite,
+ },
+ {
+ w: &e.sndCloseWaker,
+ f: e.handleClose,
+ },
+ {
+ w: &closeWaker,
+ f: func() *tcpip.Error {
+ // This means the socket is being closed due
+ // to the TCP FIN-WAIT2 timeout having been hit. Just
+ // mark the socket as closed.
+ e.transitionToStateCloseLocked()
+ e.workerCleanup = true
+ return nil
+ },
+ },
+ {
+ w: &e.snd.resendWaker,
+ f: func() *tcpip.Error {
+ if !e.snd.retransmitTimerExpired() {
+ e.stack.Stats().TCP.EstablishedTimedout.Increment()
+ return tcpip.ErrTimeout
+ }
+ return nil
+ },
+ },
+ {
+ w: &e.newSegmentWaker,
+ f: func() *tcpip.Error {
+ return e.handleSegments(false /* fastPath */)
+ },
+ },
+ {
+ w: &e.keepalive.waker,
+ f: e.keepaliveTimerExpired,
+ },
+ {
+ w: &e.notificationWaker,
+ f: func() *tcpip.Error {
+ n := e.fetchNotifications()
+ if n&notifyNonZeroReceiveWindow != 0 {
+ e.rcv.nonZeroWindow()
+ }
+
+ if n&notifyReceiveWindowChanged != 0 {
+ e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
+ }
+
+ if n&notifyMTUChanged != 0 {
+ e.sndBufMu.Lock()
+ count := e.packetTooBigCount
+ e.packetTooBigCount = 0
+ mtu := e.sndMTU
+ e.sndBufMu.Unlock()
+
+ e.snd.updateMaxPayloadSize(mtu, count)
+ }
+
+ if n&notifyReset != 0 || n&notifyAbort != 0 {
+ return tcpip.ErrConnectionAborted
+ }
+
+ if n&notifyResetByPeer != 0 {
+ return tcpip.ErrConnectionReset
+ }
+
+ if n&notifyClose != 0 && closeTimer == nil {
+ if e.EndpointState() == StateFinWait2 && e.closed {
+ // The socket has been closed and we are in FIN_WAIT2
+ // so start the FIN_WAIT2 timer.
+ closeTimer = time.AfterFunc(e.tcpLingerTimeout, closeWaker.Assert)
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+ }
+
+ if n&notifyKeepaliveChanged != 0 {
+ // The timer could fire in background
+ // when the endpoint is drained. That's
+ // OK. See above.
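+ // Passing true also clears keepalive.unacked, just as
+ // if fresh data had been received.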
+ e.resetKeepaliveTimer(true)
+ }
+
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ if err := e.handleSegments(false /* fastPath */); err != nil {
+ return err
+ }
+ }
+ if !e.EndpointState().closed() {
+ // Only block the worker if the endpoint
+ // is not in a closed or error state.
+ close(e.drainDone)
+ e.mu.Unlock()
+ <-e.undrain
+ e.mu.Lock()
+ }
+ }
+
+ if n&notifyTickleWorker != 0 {
+ // Just a tickle notification. No need to do
+ // anything.
+ return nil
+ }
+
+ return nil
+ },
+ },
+ }
+
+ // Initialize the sleeper based on the wakers in funcs.
+ s := sleep.Sleeper{}
+ for i := range funcs {
+ s.AddWaker(funcs[i].w, i)
+ }
+
+ // Notify the caller that the waker initialization is complete and the
+ // endpoint is ready.
+ if wakerInitDone != nil {
+ close(wakerInitDone)
+ }
+
+ // Tell waiters that the endpoint is connected and writable.
+ e.waiterQueue.Notify(waiter.EventOut)
+
+ // The following assertions and notifications are needed for restored
+ // endpoints. Freshly created endpoints have empty states and should
+ // not invoke any.
+ if !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+
+ e.rcvListMu.Lock()
+ if !e.rcvList.Empty() {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ e.rcvListMu.Unlock()
+
+ if e.workerCleanup {
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+
+ // Main loop. Handle segments until both send and receive ends of the
+ // connection have completed.
+ cleanupOnError := func(err *tcpip.Error) {
+ e.stack.Stats().TCP.CurrentConnected.Decrement()
+ e.workerCleanup = true
+ if err != nil {
+ e.resetConnectionLocked(err)
+ }
+ // Lock released below.
+ epilogue()
+ }
+
+loop:
+ for {
+ switch e.EndpointState() {
+ case StateTimeWait, StateClose, StateError:
+ break loop
+ }
+
+ e.mu.Unlock()
+ v, _ := s.Fetch(true)
+ e.mu.Lock()
+
+ // We need to double check here because the notification may be
+ // stale by the time we got around to processing it.
+ switch e.EndpointState() {
+ case StateError:
+ // If the endpoint has already transitioned to an ERROR
+ // state just pass nil here as any reset that may need
+ // to be sent etc should already have been done and we
+ // just want to terminate the loop and cleanup the
+ // endpoint.
+ cleanupOnError(nil)
+ return nil
+ case StateTimeWait:
+ fallthrough
+ case StateClose:
+ break loop
+ default:
+ if err := funcs[v].f(); err != nil {
+ cleanupOnError(err)
+ return nil
+ }
+ }
+ }
+
+ var reuseTW func()
+ if e.EndpointState() == StateTimeWait {
+ // Disable the close timer as we are now entering real TIME_WAIT.
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+ // Mark the current sleeper done so as to free all associated
+ // wakers.
+ s.Done()
+ // Wake up any waiters before we enter TIME_WAIT.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ e.workerCleanup = true
+ reuseTW = e.doTimeWait()
+ }
+
+ // Handle any StateError transition from StateTimeWait.
+ if e.EndpointState() == StateError {
+ cleanupOnError(nil)
+ return nil
+ }
+
+ e.transitionToStateCloseLocked()
+
+ // Lock released below.
+ epilogue()
+
+ // A new SYN was received during TIME_WAIT and we need to abort
+ // the timewait and redirect the segment to the listener queue.
+ if reuseTW != nil {
+ reuseTW()
+ }
+
+ return nil
+}
+
+// handleTimeWaitSegments processes segments received during TIME_WAIT
+// state.
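+// It reports whether the TIME_WAIT timer should be restarted and, when a
+// valid new SYN matches a listening endpoint, a closure that re-delivers
+// the segment to that listener.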
+func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func()) {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+ extTW, newSyn := e.rcv.handleTimeWaitSegment(s)
+ if newSyn {
+ info := e.EndpointInfo.TransportEndpointInfo
+ newID := info.ID
+ newID.RemoteAddress = ""
+ newID.RemotePort = 0
+ netProtos := []tcpip.NetworkProtocolNumber{info.NetProto}
+ // If the local address is an IPv4 address then also
+ // look for IPv6 dual stack endpoints that might be
+ // listening on the local address.
+ if newID.LocalAddress.To4() != "" {
+ netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber}
+ }
+ for _, netProto := range netProtos {
+ if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil {
+ tcpEP := listenEP.(*endpoint)
+ if EndpointState(tcpEP.State()) == StateListen {
+ reuseTW = func() {
+ if !tcpEP.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+ tcpEP.newSegmentWaker.Assert()
+ }
+ // We explicitly do not decRef
+ // the segment as it's still
+ // valid and being reflected to
+ // a listening endpoint.
+ return false, reuseTW
+ }
+ }
+ }
+ }
+ if extTW {
+ extendTimeWait = true
+ }
+ s.decRef()
+ }
+ if checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ return extendTimeWait, nil
+}
+
+// doTimeWait is responsible for handling the TCP behaviour once a socket
+// enters the TIME_WAIT state. Optionally it can return a closure that
+// should be executed after releasing the endpoint registrations. This is
+// done in cases where a new SYN is received during TIME_WAIT that carries
+// a sequence number larger than any seen on the connection.
+func (e *endpoint) doTimeWait() (twReuse func()) {
+ // Trigger a 2 * MSL time wait state. During this period
+ // we will drop all incoming segments.
+ // NOTE: On Linux this is not configurable and is fixed at 60 seconds.
+ timeWaitDuration := DefaultTCPTimeWaitTimeout
+
+ // Get the stack wide configuration.
+ var tcpTW tcpip.TCPTimeWaitTimeoutOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &tcpTW); err == nil {
+ timeWaitDuration = time.Duration(tcpTW)
+ }
+
+ const newSegment = 1
+ const notification = 2
+ const timeWaitDone = 3
+
+ s := sleep.Sleeper{}
+ defer s.Done()
+ s.AddWaker(&e.newSegmentWaker, newSegment)
+ s.AddWaker(&e.notificationWaker, notification)
+
+ var timeWaitWaker sleep.Waker
+ s.AddWaker(&timeWaitWaker, timeWaitDone)
+ timeWaitTimer := time.AfterFunc(timeWaitDuration, timeWaitWaker.Assert)
+ defer timeWaitTimer.Stop()
+
+ for {
+ e.mu.Unlock()
+ v, _ := s.Fetch(true)
+ e.mu.Lock()
+ switch v {
+ case newSegment:
+ extendTimeWait, reuseTW := e.handleTimeWaitSegments()
+ if reuseTW != nil {
+ return reuseTW
+ }
+ if extendTimeWait {
+ timeWaitTimer.Reset(timeWaitDuration)
+ }
+ case notification:
+ n := e.fetchNotifications()
+ if n&notifyClose != 0 || n&notifyAbort != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ // Ignore extending TIME_WAIT during a
+ // save. For sockets in TIME_WAIT we just
+ // terminate the TIME_WAIT early.
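+ // Any reuse closure returned here is intentionally
+ // dropped; redirecting a SYN to a listener is not
+ // attempted mid-save.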
+ e.handleTimeWaitSegments()
+ }
+ close(e.drainDone)
+ e.mu.Unlock()
+ <-e.undrain
+ e.mu.Lock()
+ return nil
+ }
+ case timeWaitDone:
+ return nil
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/connect_unsafe.go b/pkg/tcpip/transport/tcp/connect_unsafe.go
new file mode 100644
index 000000000..cfc304616
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/connect_unsafe.go
@@ -0,0 +1,30 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "reflect"
+ "unsafe"
+)
+
+// optionsToArray converts a slice of capacity >= maxOptionSize to an array.
+//
+// optionsToArray panics if the capacity of options is smaller than
+// maxOptionSize.
+func optionsToArray(options []byte) *[maxOptionSize]byte {
+ // Reslice to full capacity.
+ options = options[0:maxOptionSize]
+ return (*[maxOptionSize]byte)(unsafe.Pointer((*reflect.SliceHeader)(unsafe.Pointer(&options)).Data))
+}
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
new file mode 100644
index 000000000..7b1f5e763
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "math"
+ "time"
+)
+
+// cubicState stores the variables related to TCP CUBIC congestion
+// control algorithm state.
+//
+// See: https://tools.ietf.org/html/rfc8312.
+// +stateify savable
+type cubicState struct {
+ // wLastMax is the previous wMax value.
+ wLastMax float64
+
+ // wMax is the value of the congestion window at the
+ // time of the last congestion event.
+ wMax float64
+
+ // t denotes the time when the current congestion avoidance
+ // was entered.
+ t time.Time `state:".(unixTime)"`
+
+ // numCongestionEvents tracks the number of congestion events since last
+ // RTO.
+ numCongestionEvents int
+
+ // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as
+ // per RFC.
+ c float64
+
+ // k is the time period that the above function takes to increase the
+ // current window size to W_max if there are no further congestion
+ // events and is calculated using the following equation:
+ //
+ // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
+ k float64
+
+ // beta is the CUBIC multiplication decrease factor. That is, when a
+ // congestion event is detected, CUBIC reduces its cwnd to
+ // W_cubic(0)=W_max*beta_cubic.
+ beta float64
+
+ // wC is the window computed by CUBIC at time t.
It's calculated using the + // formula: + // + // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) + wC float64 + + // wEst is the window computed by CUBIC at time t+RTT i.e + // W_cubic(t+RTT). + wEst float64 + + s *sender +} + +// newCubicCC returns a partially initialized cubic state with the constants +// beta and c set and t set to current time. +func newCubicCC(s *sender) *cubicState { + return &cubicState{ + t: time.Now(), + beta: 0.7, + c: 0.4, + s: s, + } +} + +// enterCongestionAvoidance is used to initialize cubic in cases where we exit +// SlowStart without a real congestion event taking place. This can happen when +// a connection goes back to slow start due to a retransmit and we exceed the +// previously lowered ssThresh without experiencing packet loss. +// +// Refer: https://tools.ietf.org/html/rfc8312#section-4.8 +func (c *cubicState) enterCongestionAvoidance() { + // See: https://tools.ietf.org/html/rfc8312#section-4.7 & + // https://tools.ietf.org/html/rfc8312#section-4.8 + if c.numCongestionEvents == 0 { + c.k = 0 + c.t = time.Now() + c.wLastMax = c.wMax + c.wMax = float64(c.s.sndCwnd) + } +} + +// updateSlowStart will update the congestion window as per the slow-start +// algorithm used by NewReno. If after adjusting the congestion window we cross +// the ssThresh then it will return the number of packets that must be consumed +// in congestion avoidance mode. +func (c *cubicState) updateSlowStart(packetsAcked int) int { + // Don't let the congestion window cross into the congestion + // avoidance range. + newcwnd := c.s.sndCwnd + packetsAcked + enterCA := false + if newcwnd >= c.s.sndSsthresh { + newcwnd = c.s.sndSsthresh + c.s.sndCAAckCount = 0 + enterCA = true + } + + packetsAcked -= newcwnd - c.s.sndCwnd + c.s.sndCwnd = newcwnd + if enterCA { + c.enterCongestionAvoidance() + } + return packetsAcked +} + +// Update updates cubic's internal state variables. It must be called on every +// ACK received. +// Refer: https://tools.ietf.org/html/rfc8312#section-4 +func (c *cubicState) Update(packetsAcked int) { + if c.s.sndCwnd < c.s.sndSsthresh { + packetsAcked = c.updateSlowStart(packetsAcked) + if packetsAcked == 0 { + return + } + } else { + c.s.rtt.Lock() + srtt := c.s.rtt.srtt + c.s.rtt.Unlock() + c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt) + } +} + +// cubicCwnd computes the CUBIC congestion window after t seconds from last +// congestion event. +func (c *cubicState) cubicCwnd(t float64) float64 { + return c.c*math.Pow(t, 3.0) + c.wMax +} + +// getCwnd returns the current congestion window as computed by CUBIC. +// Refer: https://tools.ietf.org/html/rfc8312#section-4 +func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { + elapsed := time.Since(c.t).Seconds() + + // Compute the window as per Cubic after 'elapsed' time + // since last congestion event. + c.wC = c.cubicCwnd(elapsed - c.k) + + // Compute the TCP friendly estimate of the congestion window. + c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds()) + + // Make sure in the TCP friendly region CUBIC performs at least + // as well as Reno. + if c.wC < c.wEst && float64(sndCwnd) < c.wEst { + // TCP Friendly region of cubic. + return int(c.wEst) + } + + // In Concave/Convex region of CUBIC, calculate what CUBIC window + // will be after 1 RTT and use that to grow congestion window + // for every ack. 
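+ // For example, with wMax=100, beta=0.7 and c=0.4, k is
+ // cbrt(100*0.3/0.4) ~= 4.22, so the window regrows to wMax about
+ // 4.22 seconds after the congestion event.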
+ tEst := (time.Since(c.t) + srtt).Seconds()
+ wtRtt := c.cubicCwnd(tEst - c.k)
+ // As per 4.3, for each received ACK, cwnd must be incremented
+ // by (W_cubic(t+RTT) - cwnd)/cwnd.
+ cwnd := float64(sndCwnd)
+ for i := 0; i < packetsAcked; i++ {
+ // Concave/Convex regions of cubic have the same formulas.
+ // See: https://tools.ietf.org/html/rfc8312#section-4.3
+ cwnd += (wtRtt - cwnd) / cwnd
+ }
+ return int(cwnd)
+}
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (c *cubicState) HandleNDupAcks() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.5
+ c.numCongestionEvents++
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+ c.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionControl.HandleRTOExpired.
+func (c *cubicState) HandleRTOExpired() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.6
+ c.t = time.Now()
+ c.numCongestionEvents = 0
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+
+ // We lost a packet, so reduce ssthresh.
+ c.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ c.s.sndCwnd = 1
+}
+
+// fastConvergence implements the logic for the Fast Convergence algorithm as
+// described in https://tools.ietf.org/html/rfc8312#section-4.6.
+func (c *cubicState) fastConvergence() {
+ if c.wMax < c.wLastMax {
+ c.wLastMax = c.wMax
+ c.wMax = c.wMax * (1.0 + c.beta) / 2.0
+ } else {
+ c.wLastMax = c.wMax
+ }
+ // Recompute k as wMax may have changed.
+ c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c)
+}
+
+// PostRecovery implements congestionControl.PostRecovery.
+func (c *cubicState) PostRecovery() {
+ c.t = time.Now()
+}
+
+// reduceSlowStartThreshold updates ssThresh to the new value described in
+// https://tools.ietf.org/html/rfc8312#section-4.7.
+func (c *cubicState) reduceSlowStartThreshold() {
+ c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0))
+}
diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/transport/tcp/cubic_state.go
new file mode 100644
index 000000000..d0f58cfaf
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// saveT is invoked by stateify.
+func (c *cubicState) saveT() unixTime {
+ return unixTime{c.t.Unix(), c.t.UnixNano()}
+}
+
+// loadT is invoked by stateify.
+func (c *cubicState) loadT(unix unixTime) {
+ c.t = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go
new file mode 100644
index 000000000..98aecab9e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/dispatcher.go
@@ -0,0 +1,234 @@
+// Copyright 2018 The gVisor Authors.
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "encoding/binary" + + "gvisor.dev/gvisor/pkg/rand" + "gvisor.dev/gvisor/pkg/sleep" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// epQueue is a queue of endpoints. +type epQueue struct { + mu sync.Mutex + list endpointList +} + +// enqueue adds e to the queue if the endpoint is not already on the queue. +func (q *epQueue) enqueue(e *endpoint) { + q.mu.Lock() + if e.pendingProcessing { + q.mu.Unlock() + return + } + q.list.PushBack(e) + e.pendingProcessing = true + q.mu.Unlock() +} + +// dequeue removes and returns the first element from the queue if available, +// returns nil otherwise. +func (q *epQueue) dequeue() *endpoint { + q.mu.Lock() + if e := q.list.Front(); e != nil { + q.list.Remove(e) + e.pendingProcessing = false + q.mu.Unlock() + return e + } + q.mu.Unlock() + return nil +} + +// empty returns true if the queue is empty, false otherwise. +func (q *epQueue) empty() bool { + q.mu.Lock() + v := q.list.Empty() + q.mu.Unlock() + return v +} + +// processor is responsible for processing packets queued to a tcp endpoint. +type processor struct { + epQ epQueue + sleeper sleep.Sleeper + newEndpointWaker sleep.Waker + closeWaker sleep.Waker +} + +func (p *processor) close() { + p.closeWaker.Assert() +} + +func (p *processor) queueEndpoint(ep *endpoint) { + // Queue an endpoint for processing by the processor goroutine. + p.epQ.enqueue(ep) + p.newEndpointWaker.Assert() +} + +const ( + newEndpointWaker = 1 + closeWaker = 2 +) + +func (p *processor) start(wg *sync.WaitGroup) { + defer wg.Done() + defer p.sleeper.Done() + + for { + if id, _ := p.sleeper.Fetch(true); id == closeWaker { + break + } + for { + ep := p.epQ.dequeue() + if ep == nil { + break + } + if ep.segmentQueue.empty() { + continue + } + + // If socket has transitioned out of connected state then just let the + // worker handle the packet. + // + // NOTE: We read this outside of e.mu lock which means that by the time + // we get to handleSegments the endpoint may not be in ESTABLISHED. But + // this should be fine as all normal shutdown states are handled by + // handleSegments and if the endpoint moves to a CLOSED/ERROR state + // then handleSegments is a noop. + if ep.EndpointState() == StateEstablished && ep.mu.TryLock() { + // If the endpoint is in a connected state then we do direct delivery + // to ensure low latency and avoid scheduler interactions. + switch err := ep.handleSegments(true /* fastPath */); { + case err != nil: + // Send any active resets if required. 
+ ep.resetConnectionLocked(err)
+ fallthrough
+ case ep.EndpointState() == StateClose:
+ ep.notifyProtocolGoroutine(notifyTickleWorker)
+ case !ep.segmentQueue.empty():
+ p.epQ.enqueue(ep)
+ }
+ ep.mu.Unlock()
+ } else {
+ ep.newSegmentWaker.Assert()
+ }
+ }
+ }
+}
+
+// dispatcher manages a pool of TCP endpoint processors which are responsible
+// for the processing of inbound segments. This fixed pool of processor
+// goroutines does full TCP processing. The processor is selected based on the
+// hash of the endpoint id to ensure that delivery for the same endpoint happens
+// in-order.
+type dispatcher struct {
+ processors []processor
+ seed uint32
+ wg sync.WaitGroup
+}
+
+func (d *dispatcher) init(nProcessors int) {
+ d.close()
+ d.wait()
+ d.processors = make([]processor, nProcessors)
+ d.seed = generateRandUint32()
+ for i := range d.processors {
+ p := &d.processors[i]
+ p.sleeper.AddWaker(&p.newEndpointWaker, newEndpointWaker)
+ p.sleeper.AddWaker(&p.closeWaker, closeWaker)
+ d.wg.Add(1)
+ // NB: sleeper-waker registration must happen synchronously to avoid races
+ // with `close`. It's possible to pull all this logic into `start`, but
+ // that results in a heap-allocated function literal.
+ go p.start(&d.wg)
+ }
+}
+
+func (d *dispatcher) close() {
+ for i := range d.processors {
+ d.processors[i].close()
+ }
+}
+
+func (d *dispatcher) wait() {
+ d.wg.Wait()
+}
+
+func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ ep := stackEP.(*endpoint)
+ s := newSegment(r, id, pkt)
+ if !s.parse() {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+ s.decRef()
+ return
+ }
+
+ if !s.csumValid {
+ ep.stack.Stats().MalformedRcvdPackets.Increment()
+ ep.stack.Stats().TCP.ChecksumErrors.Increment()
+ ep.stats.ReceiveErrors.ChecksumErrors.Increment()
+ s.decRef()
+ return
+ }
+
+ ep.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ ep.stats.SegmentsReceived.Increment()
+ if (s.flags & header.TCPFlagRst) != 0 {
+ ep.stack.Stats().TCP.ResetsReceived.Increment()
+ }
+
+ if !ep.enqueueSegment(s) {
+ s.decRef()
+ return
+ }
+
+ // For sockets not in an established state let the worker goroutine
+ // handle the packets.
+ if ep.EndpointState() != StateEstablished {
+ ep.newSegmentWaker.Assert()
+ return
+ }
+
+ d.selectProcessor(id).queueEndpoint(ep)
+}
+
+func generateRandUint32() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ return binary.LittleEndian.Uint32(b)
+}
+
+func (d *dispatcher) selectProcessor(id stack.TransportEndpointID) *processor {
+ var payload [4]byte
+ binary.LittleEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.LittleEndian.PutUint16(payload[2:], id.RemotePort)
+
+ h := jenkins.Sum32(d.seed)
+ h.Write(payload[:])
+ h.Write([]byte(id.LocalAddress))
+ h.Write([]byte(id.RemoteAddress))
+
+ return &d.processors[h.Sum32()%uint32(len(d.processors))]
+}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
new file mode 100644
index 000000000..804e95aea
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -0,0 +1,651 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_test + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.dev/gvisor/pkg/waiter" +) + +func TestV4MappedConnectOnV6Only(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Start connection attempt, it must fail. + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) + if err != tcpip.ErrNoRoute { + t.Fatalf("Unexpected return value from Connect: %v", err) + } +} + +func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { + // Start connection attempt. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventOut) + defer c.WQ.EventUnregister(&we) + + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) + if err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + synCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + )) + checker.IPv4(t, b, synCheckers...) + + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: tcp.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK packet. + ackCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + )) + checker.IPv4(t, c.GetPacket(), ackCheckers...) + + // Wait for connection to be established. + select { + case <-ch: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } +} + +func TestV4MappedConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Test the connection request. + testV4Connect(t, c) +} + +func TestV4ConnectWhenBoundToWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV4Connect(t, c) +} + +func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to v4 mapped wildcard. 
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV4Connect(t, c) +} + +func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to v4 mapped address. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV4Connect(t, c) +} + +func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.NetworkChecker) { + // Start connection attempt to IPv6 address. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventOut) + defer c.WQ.EventUnregister(&we) + + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) + if err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetV6Packet() + synCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + )) + checker.IPv6(t, b, synCheckers...) + + tcp := header.TCP(header.IPv6(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + iss := seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: tcp.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK packet. + ackCheckers := append(checkers, checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + )) + checker.IPv6(t, c.GetV6Packet(), ackCheckers...) + + // Wait for connection to be established. + select { + case <-ch: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } +} + +func TestV6Connect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV6ConnectV6Only(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV6ConnectWhenBoundToWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to local address. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV4RefuseOnV6Only(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Start listening. 
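+ // The endpoint is v6-only, so the v4 SYN sent below must be answered
+ // with a RST rather than a SYN-ACK.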
+ if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the RST reply. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.AckNum(uint32(irs)+1), + ), + ) +} + +func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind and listen. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the RST reply. + checker.IPv6(t, c.GetV6Packet(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.AckNum(uint32(irs)+1), + ), + ) +} + +func testV4Accept(t *testing.T, c *context.Context) { + c.SetGSOEnabled(true) + defer c.SetGSOEnabled(false) + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + iss := seqnum.Value(tcp.SequenceNumber()) + checker.IPv4(t, b, + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1), + ), + ) + + // Send ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + nep, _, err := c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + nep, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Make sure we get the same error when calling the original ep and the + // new one. This validates that v4-mapped endpoints are still able to + // query the V6Only flag, whereas pure v4 endpoints are not. + _, expected := c.EP.GetSockOptBool(tcpip.V6OnlyOption) + if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != expected { + t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) + } + + // Check the peer address. 
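+ // The connection was accepted from a v4 packet, so the remote address
+ // should surface as the plain v4 TestAddr rather than its v4-mapped
+ // form.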
+ addr, err := nep.GetRemoteAddress()
+ if err != nil {
+ t.Fatalf("GetRemoteAddress failed: %v", err)
+ }
+
+ if addr.Addr != context.TestAddr {
+ t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr)
+ }
+
+ data := "Don't panic"
+ nep.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+ b = c.GetPacket()
+ tcp = header.TCP(header.IPv4(b).Payload())
+ if string(tcp.Payload()) != data {
+ t.Fatalf("Unexpected data: got %v, want %v", string(tcp.Payload()), data)
+ }
+}
+
+func TestV4AcceptOnV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Accept(t, c)
+}
+
+func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind to v4 mapped wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Accept(t, c)
+}
+
+func TestV4AcceptOnBoundToV4Mapped(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind and listen.
+ if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Accept(t, c)
+}
+
+func TestV6AcceptOnV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ // Bind and listen.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Send a SYN request.
+ irs := seqnum.Value(789)
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetV6Packet()
+ tcp := header.TCP(header.IPv6(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ checker.IPv6(t, b,
+ checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+ checker.AckNum(uint32(irs)+1),
+ ),
+ )
+
+ // Send ACK.
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ nep, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ nep, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Make sure we can still query the v6 only status of the new endpoint,
+ // that is, that it is in fact a v6 socket.
+ if _, err := nep.GetSockOptBool(tcpip.V6OnlyOption); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+
+ // Check the peer address.
+ addr, err := nep.GetRemoteAddress()
+ if err != nil {
+ t.Fatalf("GetRemoteAddress failed: %v", err)
+ }
+
+ if addr.Addr != context.TestV6Addr {
+ t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr)
+ }
+}
+
+func TestV4AcceptOnV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test acceptance.
+ testV4Accept(t, c)
+}
+
+func testV4ListenClose(t *testing.T, c *context.Context) {
+ // Set the SynRcvd threshold to zero to force a syn-cookie based accept
+ // to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption failed: %s", err)
+ }
+
+ const n = uint16(32)
+
+ // Start listening.
+ if err := c.EP.Listen(int(tcp.SynRcvdCountThreshold + 1)); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ irs := seqnum.Value(789)
+ for i := uint16(0); i < n; i++ {
+ // Send a SYN request.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort + i,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: irs,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Each of these ACKs will cause a syn-cookie based connection to be
+ // accepted and delivered to the listening endpoint.
+ for i := uint16(0); i < n; i++ {
+ b := c.GetPacket()
+ tcp := header.TCP(header.IPv4(b).Payload())
+ iss := seqnum.Value(tcp.SequenceNumber())
+ // Send ACK.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcp.DestinationPort(),
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ RcvWnd: 30000,
+ })
+ }
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+ nep, _, err := c.EP.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ nep, _, err = c.EP.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %v", err)
+ }
+
+ case <-time.After(10 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+ nep.Close()
+ c.EP.Close()
+}
+
+func TestV4ListenCloseOnV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Bind to wildcard.
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Test listen close.
+ testV4ListenClose(t, c)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
new file mode 100644
index 000000000..caac6ef57
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -0,0 +1,2888 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "encoding/binary"
+ "fmt"
+ "math"
+ "runtime"
+ "strings"
+ "sync/atomic"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/rand"
+ "gvisor.dev/gvisor/pkg/sleep"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/ports"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// EndpointState represents the state of a TCP endpoint.
+type EndpointState uint32
+
+// Endpoint states. Note that these are represented in a netstack-specific
+// manner and may not be meaningful externally. Specifically, they need to be
+// translated to Linux's representation for these states if presented to
+// userspace.
+const (
+ // Endpoint states internal to netstack. These map to the TCP state CLOSED.
+ StateInitial EndpointState = iota
+ StateBound
+ StateConnecting // Connect() called, but the initial SYN hasn't been sent.
+ StateError
+
+ // TCP protocol states.
+ StateEstablished
+ StateSynSent
+ StateSynRecv
+ StateFinWait1
+ StateFinWait2
+ StateTimeWait
+ StateClose
+ StateCloseWait
+ StateLastAck
+ StateListen
+ StateClosing
+)
+
+// connected returns true when s is one of the states representing an
+// endpoint connected to a peer.
+func (s EndpointState) connected() bool {
+ switch s {
+ case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing:
+ return true
+ default:
+ return false
+ }
+}
+
+// connecting returns true when s is one of the states representing a
+// connection in progress, but not yet fully established.
+func (s EndpointState) connecting() bool {
+ switch s {
+ case StateConnecting, StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// handshake returns true when s is one of the states representing an endpoint
+// in the middle of a TCP handshake.
+func (s EndpointState) handshake() bool {
+ switch s {
+ case StateSynSent, StateSynRecv:
+ return true
+ default:
+ return false
+ }
+}
+
+// closed returns true when s is one of the states an endpoint transitions to
+// when closed or when it encounters an error. This is distinct from a newly
+// initialized endpoint that was never connected.
+func (s EndpointState) closed() bool {
+ switch s {
+ case StateClose, StateError:
+ return true
+ default:
+ return false
+ }
+}
+
+// String implements fmt.Stringer.String.
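+// The strings follow the conventional TCP state names (RFC 793) rather than
+// the Go identifiers; for example, StateSynRecv is rendered as "SYN-RCVD"
+// and StateClose as "CLOSED".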
+func (s EndpointState) String() string {
+ switch s {
+ case StateInitial:
+ return "INITIAL"
+ case StateBound:
+ return "BOUND"
+ case StateConnecting:
+ return "CONNECTING"
+ case StateError:
+ return "ERROR"
+ case StateEstablished:
+ return "ESTABLISHED"
+ case StateSynSent:
+ return "SYN-SENT"
+ case StateSynRecv:
+ return "SYN-RCVD"
+ case StateFinWait1:
+ return "FIN-WAIT1"
+ case StateFinWait2:
+ return "FIN-WAIT2"
+ case StateTimeWait:
+ return "TIME-WAIT"
+ case StateClose:
+ return "CLOSED"
+ case StateCloseWait:
+ return "CLOSE-WAIT"
+ case StateLastAck:
+ return "LAST-ACK"
+ case StateListen:
+ return "LISTEN"
+ case StateClosing:
+ return "CLOSING"
+ default:
+ panic("unreachable")
+ }
+}
+
+// Reasons for notifying the protocol goroutine.
+const (
+ notifyNonZeroReceiveWindow = 1 << iota
+ notifyReceiveWindowChanged
+ notifyClose
+ notifyMTUChanged
+ notifyDrain
+ notifyReset
+ notifyResetByPeer
+ // notifyAbort is a request for an expedited teardown.
+ notifyAbort
+ notifyKeepaliveChanged
+ notifyMSSChanged
+ // notifyTickleWorker is used to tickle the protocol main loop during a
+ // restore after we update the endpoint state to the correct one. This
+ // ensures the loop terminates if the final state of the endpoint is,
+ // say, TIME_WAIT.
+ notifyTickleWorker
+ notifyError
+)
+
+// SACKInfo holds TCP SACK related information for a given endpoint.
+//
+// +stateify savable
+type SACKInfo struct {
+ // Blocks is the maximum number of SACK blocks we track
+ // per endpoint.
+ Blocks [MaxSACKBlocks]header.SACKBlock
+
+ // NumBlocks is the number of valid SACK blocks stored in the
+ // blocks array above.
+ NumBlocks int
+}
+
+// rcvBufAutoTuneParams are used to hold state variables to compute
+// the auto tuned recv buffer size.
+//
+// +stateify savable
+type rcvBufAutoTuneParams struct {
+ // measureTime is the time at which the current measurement
+ // was started.
+ measureTime time.Time `state:".(unixTime)"`
+
+ // copied is the number of bytes copied out of the receive
+ // buffers since this measure began.
+ copied int
+
+ // prevCopied is the number of bytes copied out of the receive
+ // buffers in the previous RTT period.
+ prevCopied int
+
+ // rtt is the non-smoothed minimum RTT as measured by observing the time
+ // between when a byte is first acknowledged and the receipt of data
+ // that is at least one window beyond the sequence number that was
+ // acknowledged.
+ rtt time.Duration
+
+ // rttMeasureSeqNumber is the highest acceptable sequence number at the
+ // time this RTT measurement period began.
+ rttMeasureSeqNumber seqnum.Value
+
+ // rttMeasureTime is the absolute time at which the current rtt
+ // measurement period began.
+ rttMeasureTime time.Time `state:".(unixTime)"`
+
+ // disabled is true if an explicit receive buffer is set for the
+ // endpoint.
+ disabled bool
+}
+
+// ReceiveErrors collects segment receive errors within the transport layer.
+type ReceiveErrors struct {
+ tcpip.ReceiveErrors
+
+ // SegmentQueueDropped is the number of segments dropped due to
+ // a full segment queue.
+ SegmentQueueDropped tcpip.StatCounter
+
+ // ChecksumErrors is the number of segments dropped due to bad checksums.
+ ChecksumErrors tcpip.StatCounter
+
+ // ListenOverflowSynDrop is the number of times the listen queue overflowed
+ // and a SYN was dropped.
+ ListenOverflowSynDrop tcpip.StatCounter
+
+ // ListenOverflowAckDrop is the number of times the final ACK
+ // in the handshake was dropped due to overflow.
+ ListenOverflowAckDrop tcpip.StatCounter
+
+ // ZeroRcvWindowState is the number of times we advertised
+ // a zero receive window when rcvList is full.
+ ZeroRcvWindowState tcpip.StatCounter
+}
+
+// SendErrors collects segment send errors within the transport layer.
+type SendErrors struct {
+ tcpip.SendErrors
+
+ // SegmentSendToNetworkFailed is the number of TCP segments failed to be sent
+ // to the network endpoint.
+ SegmentSendToNetworkFailed tcpip.StatCounter
+
+ // SynSendToNetworkFailed is the number of TCP SYNs failed to be sent
+ // to the network endpoint.
+ SynSendToNetworkFailed tcpip.StatCounter
+
+ // Retransmits is the number of TCP segments retransmitted.
+ Retransmits tcpip.StatCounter
+
+ // FastRetransmit is the number of segments retransmitted in fast
+ // recovery.
+ FastRetransmit tcpip.StatCounter
+
+ // Timeouts is the number of times the RTO expired.
+ Timeouts tcpip.StatCounter
+}
+
+// Stats holds statistics about the endpoint.
+type Stats struct {
+ // SegmentsReceived is the number of TCP segments received that
+ // the transport layer successfully parsed.
+ SegmentsReceived tcpip.StatCounter
+
+ // SegmentsSent is the number of TCP segments sent.
+ SegmentsSent tcpip.StatCounter
+
+ // FailedConnectionAttempts is the number of times we saw Connect and
+ // Accept errors.
+ FailedConnectionAttempts tcpip.StatCounter
+
+ // ReceiveErrors collects segment receive errors within the
+ // transport layer.
+ ReceiveErrors ReceiveErrors
+
+ // ReadErrors collects segment read errors from an endpoint read call.
+ ReadErrors tcpip.ReadErrors
+
+ // SendErrors collects segment send errors within the transport layer.
+ SendErrors SendErrors
+
+ // WriteErrors collects segment write errors from an endpoint write call.
+ WriteErrors tcpip.WriteErrors
+}
+
+// IsEndpointStats is an empty method to implement the tcpip.EndpointStats
+// marker interface.
+func (*Stats) IsEndpointStats() {}
+
+// EndpointInfo holds useful information about a transport endpoint which
+// can be queried by monitoring tools.
+//
+// +stateify savable
+type EndpointInfo struct {
+ stack.TransportEndpointInfo
+
+ // HardError is meaningful only when state is StateError. It stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. HardError is protected by endpoint mu.
+ HardError *tcpip.Error `state:".(string)"`
+}
+
+// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
+// marker interface.
+func (*EndpointInfo) IsEndpointInfo() {}
+
+// endpoint represents a TCP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal
+// to have concurrent goroutines make calls into the endpoint, as they are
+// properly synchronized. The protocol implementation, however, runs in a
+// single goroutine.
+//
+// Each endpoint has a few mutexes:
+//
+// e.mu -> Primary mutex for an endpoint; it must be held for all operations
+// except in e.Readiness, where acquiring it would result in a deadlock in the
+// epoll implementation.
+//
+// The following mutexes can be acquired independent of e.mu, but if acquired
+// together with e.mu then e.mu must be acquired first.
+//
+// e.acceptMu -> Protects acceptedChan.
+// e.rcvListMu -> Protects the rcvList and associated fields.
+// e.sndBufMu -> Protects the sndQueue and associated fields.
+// e.lastErrorMu -> Protects the lastError field.
+//
+// LOCKING/UNLOCKING of the endpoint.
The locking of an endpoint differs
+// based on the context in which the lock is acquired. In the syscall context,
+// e.LockUser/e.UnlockUser should be used; when doing background processing,
+// e.mu.Lock/e.mu.Unlock should be used. The distinction is described below
+// in brief.
+//
+// The reason for this locking behaviour is to avoid wakeups to handle packets.
+// In cases where the endpoint is already locked, the background processor can
+// queue the packet up and go its merry way, and the lock owner will eventually
+// process the backlog when releasing the lock. Similarly, when acquiring the
+// lock from, say, a syscall goroutine, we can implement a bit of spinning if
+// we know that the lock is not held by another syscall goroutine. Background
+// processors should never hold the lock for long, and we can avoid an
+// expensive sleep/wakeup by spinning for a short while.
+//
+// For more details please see the detailed documentation on the
+// e.LockUser/e.UnlockUser methods.
+//
+// +stateify savable
+type endpoint struct {
+ EndpointInfo
+
+ // endpointEntry is used to queue endpoints for processing to a
+ // given TCP processor goroutine.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field.
+ endpointEntry `state:"nosave"`
+
+ // pendingProcessing is true if this endpoint is queued for processing
+ // to a TCP processor.
+ //
+ // Precondition: epQueue.mu must be held to read/write this field.
+ pendingProcessing bool `state:"nosave"`
+
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack `state:"manual"`
+ waiterQueue *waiter.Queue `state:"wait"`
+ uniqueID uint64
+
+ // lastError represents the last error that the endpoint reported;
+ // access to it is protected by the following mutex.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
+
+ // The following fields are used to manage the receive queue. The
+ // protocol goroutine adds ready-for-delivery segments to rcvList,
+ // which are returned by Read() calls to users.
+ //
+ // Once the peer has closed its send side, rcvClosed is set to true
+ // to indicate to users that no more data is coming.
+ //
+ // rcvListMu can be taken after the endpoint mu below.
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ rcvBufSize int
+ rcvBufUsed int
+ rcvAutoParams rcvBufAutoTuneParams
+
+ // mu protects all endpoint fields unless documented otherwise. mu must
+ // be acquired before interacting with the endpoint fields.
+ mu sync.Mutex `state:"nosave"`
+ ownedByUser uint32
+
+ // state must be read/set using the EndpointState()/setEndpointState()
+ // methods.
+ state EndpointState `state:".(EndpointState)"`
+
+ // origEndpointState is only used during a restore phase to save the
+ // endpoint state at restore time as the socket is moved to its correct
+ // state.
+ origEndpointState EndpointState `state:"nosave"`
+
+ isPortReserved bool `state:"manual"`
+ isRegistered bool `state:"manual"`
+ boundNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+ ttl uint8
+ v6only bool
+ isConnectNotified bool
+ // TCP should never broadcast but Linux nevertheless supports enabling/
+ // disabling SO_BROADCAST, albeit as a NOOP.
+ broadcast bool
+
+ // portFlags stores the current values of port related flags.
+ portFlags ports.Flags
+
+ // Values used to reserve a port or register a transport endpoint
+ // (whichever happens first).
+ boundBindToDevice tcpip.NICID + boundPortFlags ports.Flags + boundDest tcpip.FullAddress + + // effectiveNetProtos contains the network protocols actually in use. In + // most cases it will only contain "netProto", but in cases like IPv6 + // endpoints with v6only set to false, this could include multiple + // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., + // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped + // address). + effectiveNetProtos []tcpip.NetworkProtocolNumber + + // workerRunning specifies if a worker goroutine is running. + workerRunning bool + + // workerCleanup specifies if the worker goroutine must perform cleanup + // before exiting. This can only be set to true when workerRunning is + // also true, and they're both protected by the mutex. + workerCleanup bool + + // sendTSOk is used to indicate when the TS Option has been negotiated. + // When sendTSOk is true every non-RST segment should carry a TS as per + // RFC7323#section-1.1 + sendTSOk bool + + // recentTS is the timestamp that should be sent in the TSEcr field of + // the timestamp for future segments sent by the endpoint. This field is + // updated if required when a new segment is received by this endpoint. + // + // recentTS must be read/written atomically. + recentTS uint32 + + // tsOffset is a randomized offset added to the value of the + // TSVal field in the timestamp option. + tsOffset uint32 + + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + + // sackPermitted is set to true if the peer sends the TCPSACKPermitted + // option in the SYN/SYN-ACK. + sackPermitted bool + + // sack holds TCP SACK related information for this endpoint. + sack SACKInfo + + // bindToDevice is set to the NIC on which to bind or disabled if 0. + bindToDevice tcpip.NICID + + // delay enables Nagle's algorithm. + // + // delay is a boolean (0 is false) and must be accessed atomically. + delay uint32 + + // cork holds back segments until full. + // + // cork is a boolean (0 is false) and must be accessed atomically. + cork uint32 + + // scoreboard holds TCP SACK Scoreboard information for this endpoint. + scoreboard *SACKScoreboard + + // The options below aren't implemented, but we remember the user + // settings because applications expect to be able to set/query these + // options. + + // slowAck holds the negated state of quick ack. It is stubbed out and + // does nothing. + // + // slowAck is a boolean (0 is false) and must be accessed atomically. + slowAck uint32 + + // segmentQueue is used to hand received segments to the protocol + // goroutine. Segments are queued as long as the queue is not full, + // and dropped when it is. + segmentQueue segmentQueue `state:"wait"` + + // synRcvdCount is the number of connections for this endpoint that are + // in SYN-RCVD state. + synRcvdCount int + + // userMSS if non-zero is the MSS value explicitly set by the user + // for this endpoint using the TCP_MAXSEG setsockopt. + userMSS uint16 + + // maxSynRetries is the maximum number of SYN retransmits that TCP should + // send before aborting the attempt to connect. It cannot exceed 255. + // + // NOTE: This is currently a no-op and does not change the SYN + // retransmissions. + maxSynRetries uint8 + + // windowClamp is used to bound the size of the advertised window to + // this value. + windowClamp uint32 + + // The following fields are used to manage the send buffer. 
When
+ // segments are ready to be sent, they are added to sndQueue and the
+ // protocol goroutine is signaled via sndWaker.
+ //
+ // When the send side is closed, the protocol goroutine is notified via
+ // sndCloseWaker, and sndClosed is set to true.
+ sndBufMu sync.Mutex `state:"nosave"`
+ sndBufSize int
+ sndBufUsed int
+ sndClosed bool
+ sndBufInQueue seqnum.Size
+ sndQueue segmentList `state:"wait"`
+ sndWaker sleep.Waker `state:"manual"`
+ sndCloseWaker sleep.Waker `state:"manual"`
+
+ // cc stores the name of the Congestion Control algorithm to use for
+ // this endpoint.
+ cc tcpip.CongestionControlOption
+
+ // The following are used when a "packet too big" control packet is
+ // received. They are protected by sndBufMu. They are used to
+ // communicate to the main protocol goroutine how many such control
+ // messages have been received since the last notification was processed
+ // and what was the smallest MTU seen.
+ packetTooBigCount int
+ sndMTU int
+
+ // newSegmentWaker is used to indicate to the protocol goroutine that
+ // it needs to wake up and handle new segments queued to it.
+ newSegmentWaker sleep.Waker `state:"manual"`
+
+ // notificationWaker is used to indicate to the protocol goroutine that
+ // it needs to wake up and check for notifications.
+ notificationWaker sleep.Waker `state:"manual"`
+
+ // notifyFlags is a bitmask of flags used to indicate to the protocol
+ // goroutine which events it has been notified of; it is only accessed
+ // atomically.
+ notifyFlags uint32 `state:"nosave"`
+
+ // keepalive manages TCP keepalive state. When the connection is idle
+ // (no data sent or received) for keepalive.idle, we start sending
+ // keepalives every keepalive.interval. If we send keepalive.count
+ // keepalives without hearing a response, the connection is closed.
+ keepalive keepalive
+
+ // userTimeout, if non-zero, specifies a user specified timeout for
+ // a connection with pending data to send. A connection that has pending
+ // unacked data will be forcibly aborted if the timeout is reached
+ // without any data being acked.
+ userTimeout time.Duration
+
+ // deferAccept, if non-zero, specifies a user specified time during
+ // which the final ACK of a handshake will be dropped provided the
+ // ACK is a bare ACK and carries no data. If the timeout is crossed then
+ // the bare ACK is accepted and the connection is delivered to the
+ // listener.
+ deferAccept time.Duration
+
+ // pendingAccepted is a synchronization primitive used to track number
+ // of connections that are queued up to be delivered to the accepted
+ // channel. We use this to ensure that all goroutines blocked on writing
+ // to the acceptedChan below terminate before we close acceptedChan.
+ pendingAccepted sync.WaitGroup `state:"nosave"`
+
+ // acceptMu protects acceptedChan.
+ acceptMu sync.Mutex `state:"nosave"`
+
+ // acceptCond is a condition variable that can be used to block on when
+ // acceptedChan is full and an endpoint is ready to be delivered.
+ //
+ // This condition variable is required because just blocking on sending
+ // to acceptedChan does not work in cases where endpoint.Listen is
+ // called twice with different backlog values. In such cases the channel
+ // is closed and a new one created. Any pending goroutines blocking on
+ // the write to the channel will panic.
+ //
+ // We use this condition variable to block/unblock goroutines which
+ // tried to deliver an endpoint but couldn't because the accept backlog
+ // was full (see: endpoint.deliverAccepted).
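+ //
+ // A delivery-side sketch of the intended pattern (illustrative only,
+ // not the actual deliverAccepted implementation; n stands for a
+ // hypothetical newly accepted endpoint):
+ //
+ // e.acceptMu.Lock()
+ // for len(e.acceptedChan) == cap(e.acceptedChan) {
+ // e.acceptCond.Wait() // releases acceptMu while blocked
+ // }
+ // e.acceptedChan <- n
+ // e.acceptMu.Unlock()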
+ acceptCond *sync.Cond `state:"nosave"`
+
+ // acceptedChan is used by a listening endpoint protocol goroutine to
+ // send newly accepted connections to the endpoint so that they can be
+ // read by Accept() calls.
+ acceptedChan chan *endpoint `state:".([]*endpoint)"`
+
+ // The following are only used from the protocol goroutine, and
+ // therefore don't need locks to protect them.
+ rcv *receiver `state:"wait"`
+ snd *sender `state:"wait"`
+
+ // The goroutine drain completion notification channel.
+ drainDone chan struct{} `state:"nosave"`
+
+ // The goroutine undrain notification channel. This is currently used as
+ // a way to block the worker goroutines. Today nothing closes/writes
+ // this channel and this causes any goroutines waiting on this to just
+ // block. This is used during save/restore to prevent worker goroutines
+ // from mutating state as it's being saved.
+ undrain chan struct{} `state:"nosave"`
+
+ // probe if not nil is invoked on every received segment. It is passed
+ // a copy of the current state of the endpoint.
+ probe stack.TCPProbeFunc `state:"nosave"`
+
+ // The following are only used to assist the restore run to re-connect.
+ connectingAddress tcpip.Address
+
+ // amss is the advertised MSS to the peer by this endpoint.
+ amss uint16
+
+ // sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+ // applied while sending packets. Defaults to 0 as on Linux.
+ sendTOS uint8
+
+ gso *stack.GSO
+
+ // TODO(b/142022063): Add ability to save and restore per endpoint stats.
+ stats Stats `state:"nosave"`
+
+ // tcpLingerTimeout is the maximum amount of time a socket
+ // stays in TIME_WAIT state before being marked closed.
+ tcpLingerTimeout time.Duration
+
+ // closed indicates that the user has called Close on the
+ // endpoint and at this point the endpoint is only around
+ // to complete the TCP shutdown.
+ closed bool
+
+ // txHash is the transport layer hash to be set on outbound packets
+ // emitted by this endpoint.
+ txHash uint32
+
+ // owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner
+}
+
+// UniqueID implements stack.TransportEndpoint.UniqueID.
+func (e *endpoint) UniqueID() uint64 {
+ return e.uniqueID
+}
+
+// calculateAdvertisedMSS calculates the MSS to advertise.
+//
+// If userMSS is non-zero and is not greater than the maximum possible MSS for
+// r, it will be used; otherwise, the maximum possible MSS will be used.
+func calculateAdvertisedMSS(userMSS uint16, r stack.Route) uint16 {
+ // The maximum possible MSS is dependent on the route.
+ maxMSS := mssForRoute(&r)
+
+ if userMSS != 0 && userMSS < maxMSS {
+ return userMSS
+ }
+
+ return maxMSS
+}
+
+// LockUser tries to lock e.mu, and if that fails it checks whether the lock
+// is held by another syscall goroutine. If so, it goes to sleep waiting for
+// the lock to be released; if not, it spins until it either acquires the lock
+// or another syscall goroutine acquires it, in which case it goes to sleep as
+// described above.
+//
+// The assumption behind spinning here is that background packet processing
+// should not be holding the lock for long, and spinning reduces latency as we
+// avoid an expensive sleep/wakeup of the syscall goroutine.
+func (e *endpoint) LockUser() {
+ for {
+ // First try to lock; if the sock is locked, check whether it is
+ // owned by another user goroutine. If not, spin; otherwise just
+ // go to sleep on the Lock() and wait.
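+ //
+ // Illustrative trace of the two paths below: if the lock is held
+ // by another syscall goroutine (ownedByUser == 1), block in
+ // e.mu.Lock(); if it is held by the protocol goroutine
+ // (ownedByUser == 0), runtime.Gosched() and retry, since the
+ // lower half is expected to release the lock quickly.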
+ if !e.mu.TryLock() {
+ // If the socket is owned by the user then just go to sleep
+ // as the lock could be held for a reasonably long time.
+ if atomic.LoadUint32(&e.ownedByUser) == 1 {
+ e.mu.Lock()
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+ // Spin but yield the processor since the lower half
+ // should yield the lock soon.
+ runtime.Gosched()
+ continue
+ }
+ atomic.StoreUint32(&e.ownedByUser, 1)
+ return
+ }
+}
+
+// UnlockUser will check if there are any segments already queued for
+// processing and process any such segments before unlocking e.mu. This is
+// required because when packets arrive while the endpoint lock is already
+// held, they are queued up to be processed. If the lock is held by the
+// endpoint goroutine then it will process these packets, but if the lock is
+// instead held by a syscall goroutine then we can have the syscall goroutine
+// process the backlog before unlocking.
+//
+// This avoids an unnecessary wakeup of the endpoint protocol goroutine for the
+// endpoint. It's also required eventually when we get rid of the endpoint
+// protocol goroutine altogether.
+//
+// Precondition: e.LockUser() must have been called before calling
+// e.UnlockUser().
+func (e *endpoint) UnlockUser() {
+ // Lock the segment queue before checking so that we avoid a race where
+ // segments can be queued between the time we check if the queue is
+ // empty and when we actually unlock the endpoint mutex.
+ for {
+ e.segmentQueue.mu.Lock()
+ if e.segmentQueue.emptyLocked() {
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ e.segmentQueue.mu.Unlock()
+ return
+ }
+ e.segmentQueue.mu.Unlock()
+
+ switch e.EndpointState() {
+ case StateEstablished:
+ if err := e.handleSegments(true /* fastPath */); err != nil {
+ e.notifyProtocolGoroutine(notifyTickleWorker)
+ }
+ default:
+ // Since we are waking the endpoint goroutine here, just unlock
+ // and let it process the queued segments.
+ e.newSegmentWaker.Assert()
+ if atomic.SwapUint32(&e.ownedByUser, 0) != 1 {
+ panic("e.UnlockUser() called without calling e.LockUser()")
+ }
+ e.mu.Unlock()
+ return
+ }
+ }
+}
+
+// StopWork halts packet processing. Only to be used in tests.
+func (e *endpoint) StopWork() {
+ e.mu.Lock()
+}
+
+// ResumeWork resumes packet processing. Only to be used in tests.
+func (e *endpoint) ResumeWork() {
+ e.mu.Unlock()
+}
+
+// setEndpointState updates the state of the endpoint to state atomically. This
+// method is unexported as the only place we should update the state is in this
+// package, but we allow the state to be read freely without holding e.mu.
+//
+// Precondition: e.mu must be held to call this method.
+func (e *endpoint) setEndpointState(state EndpointState) {
+ oldstate := EndpointState(atomic.LoadUint32((*uint32)(&e.state)))
+ switch state {
+ case StateEstablished:
+ e.stack.Stats().TCP.CurrentEstablished.Increment()
+ e.stack.Stats().TCP.CurrentConnected.Increment()
+ case StateError:
+ fallthrough
+ case StateClose:
+ if oldstate == StateCloseWait || oldstate == StateEstablished {
+ e.stack.Stats().TCP.EstablishedResets.Increment()
+ }
+ fallthrough
+ default:
+ if oldstate == StateEstablished {
+ e.stack.Stats().TCP.CurrentEstablished.Decrement()
+ }
+ }
+ atomic.StoreUint32((*uint32)(&e.state), uint32(state))
+}
+
+// EndpointState returns the current state of the endpoint.
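+// It is safe to call without holding e.mu since the state is stored and read
+// atomically, though the returned value may be stale by the time the caller
+// acts on it.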
+func (e *endpoint) EndpointState() EndpointState { + return EndpointState(atomic.LoadUint32((*uint32)(&e.state))) +} + +// setRecentTimestamp atomically sets the recentTS field to the +// provided value. +func (e *endpoint) setRecentTimestamp(recentTS uint32) { + atomic.StoreUint32(&e.recentTS, recentTS) +} + +// recentTimestamp atomically reads and returns the value of the recentTS field. +func (e *endpoint) recentTimestamp() uint32 { + return atomic.LoadUint32(&e.recentTS) +} + +// keepalive is a synchronization wrapper used to appease stateify. See the +// comment in endpoint, where it is used. +// +// +stateify savable +type keepalive struct { + sync.Mutex `state:"nosave"` + enabled bool + idle time.Duration + interval time.Duration + count int + unacked int + timer timer `state:"nosave"` + waker sleep.Waker `state:"nosave"` +} + +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { + e := &endpoint{ + stack: s, + EndpointInfo: EndpointInfo{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.TCPProtocolNumber, + }, + }, + waiterQueue: waiterQueue, + state: StateInitial, + rcvBufSize: DefaultReceiveBufferSize, + sndBufSize: DefaultSendBufferSize, + sndMTU: int(math.MaxInt32), + keepalive: keepalive{ + // Linux defaults. + idle: 2 * time.Hour, + interval: 75 * time.Second, + count: 9, + }, + uniqueID: s.UniqueID(), + txHash: s.Rand().Uint32(), + windowClamp: DefaultReceiveBufferSize, + maxSynRetries: DefaultSynRetries, + } + + var ss SendBufferSizeOption + if err := s.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + e.sndBufSize = ss.Default + } + + var rs ReceiveBufferSizeOption + if err := s.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + e.rcvBufSize = rs.Default + } + + var cs tcpip.CongestionControlOption + if err := s.TransportProtocolOption(ProtocolNumber, &cs); err == nil { + e.cc = cs + } + + var mrb tcpip.ModerateReceiveBufferOption + if err := s.TransportProtocolOption(ProtocolNumber, &mrb); err == nil { + e.rcvAutoParams.disabled = !bool(mrb) + } + + var de DelayEnabled + if err := s.TransportProtocolOption(ProtocolNumber, &de); err == nil && de { + e.SetSockOptBool(tcpip.DelayOption, true) + } + + var tcpLT tcpip.TCPLingerTimeoutOption + if err := s.TransportProtocolOption(ProtocolNumber, &tcpLT); err == nil { + e.tcpLingerTimeout = time.Duration(tcpLT) + } + + var synRetries tcpip.TCPSynRetriesOption + if err := s.TransportProtocolOption(ProtocolNumber, &synRetries); err == nil { + e.maxSynRetries = uint8(synRetries) + } + + if p := s.GetTCPProbe(); p != nil { + e.probe = p + } + + e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.tsOffset = timeStampOffset() + e.acceptCond = sync.NewCond(&e.acceptMu) + + return e +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + result := waiter.EventMask(0) + + switch e.EndpointState() { + case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: + // Ready for nothing. + + case StateClose, StateError: + // Ready for anything. + result = mask + + case StateListen: + // Check if there's anything in the accepted channel. 
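+ // A non-empty acceptedChan means a subsequent Accept() would
+ // succeed without blocking, so the endpoint is readable.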
+ if (mask & waiter.EventIn) != 0 { + e.acceptMu.Lock() + if len(e.acceptedChan) > 0 { + result |= waiter.EventIn + } + e.acceptMu.Unlock() + } + } + if e.EndpointState().connected() { + // Determine if the endpoint is writable if requested. + if (mask & waiter.EventOut) != 0 { + e.sndBufMu.Lock() + if e.sndClosed || e.sndBufUsed < e.sndBufSize { + result |= waiter.EventOut + } + e.sndBufMu.Unlock() + } + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvListMu.Lock() + if e.rcvBufUsed > 0 || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvListMu.Unlock() + } + } + + return result +} + +func (e *endpoint) fetchNotifications() uint32 { + return atomic.SwapUint32(&e.notifyFlags, 0) +} + +func (e *endpoint) notifyProtocolGoroutine(n uint32) { + for { + v := atomic.LoadUint32(&e.notifyFlags) + if v&n == n { + // The flags are already set. + return + } + + if atomic.CompareAndSwapUint32(&e.notifyFlags, v, v|n) { + if v == 0 { + // We are causing a transition from no flags to + // at least one flag set, so we must cause the + // protocol goroutine to wake up. + e.notificationWaker.Assert() + } + return + } + } +} + +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + // The abort notification is not processed synchronously, so no + // synchronization is needed. + // + // If the endpoint becomes connected after this check, we still close + // the endpoint. This worst case results in a slower abort. + // + // If the endpoint disconnected after the check, nothing needs to be + // done, so sending a notification which will potentially be ignored is + // fine. + // + // If the endpoint connecting finishes after the check, the endpoint + // is either in a connected state (where we would notifyAbort anyway), + // SYN-RECV (where we would also notifyAbort anyway), or in an error + // state where nothing is required and the notification can be safely + // ignored. + // + // Endpoints where a Close during connecting or SYN-RECV state would be + // problematic are set to state connecting before being registered (and + // thus possible to be Aborted). They are never available in initial + // state. + // + // Endpoints transitioning from initial to connecting state may be + // safely either closed or sent notifyAbort. + if s := e.EndpointState(); s == StateConnecting || s == StateSynRecv || s.connected() { + e.notifyProtocolGoroutine(notifyAbort) + return + } + e.Close() +} + +// Close puts the endpoint in a closed state and frees all resources associated +// with it. It must be called only once and with no other concurrent calls to +// the endpoint. +func (e *endpoint) Close() { + e.LockUser() + defer e.UnlockUser() + if e.closed { + return + } + + // Issue a shutdown so that the peer knows we won't send any more data + // if we're connected, or stop accepting if we're listening. + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + e.closeNoShutdownLocked() +} + +// closeNoShutdown closes the endpoint without doing a full shutdown. This is +// used when a connection needs to be aborted with a RST and we want to skip +// a full 4 way TCP shutdown. +func (e *endpoint) closeNoShutdownLocked() { + // For listening sockets, we always release ports inline so that they + // are immediately available for reuse after Close() is called. If also + // registered, we unregister as well otherwise the next user would fail + // in Listen() when trying to register. 
+ if e.EndpointState() == StateListen && e.isPortReserved { + if e.isRegistered { + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.isRegistered = false + } + + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) + e.isPortReserved = false + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} + } + + // Mark endpoint as closed. + e.closed = true + + switch e.EndpointState() { + case StateClose, StateError: + return + } + + // Either perform the local cleanup or kick the worker to make sure it + // knows it needs to cleanup. + if e.workerRunning { + e.workerCleanup = true + tcpip.AddDanglingEndpoint(e) + // Worker will remove the dangling endpoint when the endpoint + // goroutine terminates. + e.notifyProtocolGoroutine(notifyClose) + } else { + e.transitionToStateCloseLocked() + } +} + +// closePendingAcceptableConnections closes all connections that have completed +// handshake but not yet been delivered to the application. +func (e *endpoint) closePendingAcceptableConnectionsLocked() { + e.acceptMu.Lock() + if e.acceptedChan == nil { + e.acceptMu.Unlock() + return + } + close(e.acceptedChan) + ch := e.acceptedChan + e.acceptedChan = nil + e.acceptCond.Broadcast() + e.acceptMu.Unlock() + + // Reset all connections that are waiting to be accepted. + for n := range ch { + n.notifyProtocolGoroutine(notifyReset) + } + // Wait for reset of all endpoints that are still waiting to be delivered to + // the now closed acceptedChan. + e.pendingAccepted.Wait() +} + +// cleanupLocked frees all resources associated with the endpoint. It is called +// after Close() is called and the worker goroutine (if any) is done with its +// work. +func (e *endpoint) cleanupLocked() { + // Close all endpoints that might have been accepted by TCP but not by + // the client. + e.closePendingAcceptableConnectionsLocked() + + e.workerCleanup = false + + if e.isRegistered { + e.stack.StartTransportEndpointCleanup(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.isRegistered = false + } + + if e.isPortReserved { + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, e.boundDest) + e.isPortReserved = false + } + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} + e.boundDest = tcpip.FullAddress{} + + e.route.Release() + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) +} + +// initialReceiveWindow returns the initial receive window to advertise in the +// SYN/SYN-ACK. +func (e *endpoint) initialReceiveWindow() int { + rcvWnd := e.receiveBufferAvailable() + if rcvWnd > math.MaxUint16 { + rcvWnd = math.MaxUint16 + } + + // Use the user supplied MSS, if available. + routeWnd := InitialCwnd * int(calculateAdvertisedMSS(e.userMSS, e.route)) * 2 + if rcvWnd > routeWnd { + rcvWnd = routeWnd + } + rcvWndScale := e.rcvWndScaleForHandshake() + + // Round-down the rcvWnd to a multiple of wndScale. This ensures that the + // window offered in SYN won't be reduced due to the loss of precision if + // window scaling is enabled after the handshake. + rcvWnd = (rcvWnd >> uint8(rcvWndScale)) << uint8(rcvWndScale) + + // Ensure we can always accept at least 1 byte if the scale specified + // was too high for the provided rcvWnd. 
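+ // For example, with illustrative numbers: rcvWnd = 65535 and a window
+ // scale of 3 rounds down to 65528 (8191 << 3), while rcvWnd = 3 with a
+ // scale of 2 rounds down to 0, which the check below bumps back to 1.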
+ if rcvWnd == 0 {
+ rcvWnd = 1
+ }
+
+ return rcvWnd
+}
+
+// ModerateRecvBuf adjusts the receive buffer and the advertised window
+// based on the number of bytes copied to userspace.
+func (e *endpoint) ModerateRecvBuf(copied int) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ e.rcvListMu.Lock()
+ if e.rcvAutoParams.disabled {
+ e.rcvListMu.Unlock()
+ return
+ }
+ now := time.Now()
+ if rtt := e.rcvAutoParams.rtt; rtt == 0 || now.Sub(e.rcvAutoParams.measureTime) < rtt {
+ e.rcvAutoParams.copied += copied
+ e.rcvListMu.Unlock()
+ return
+ }
+ prevRTTCopied := e.rcvAutoParams.copied + copied
+ prevCopied := e.rcvAutoParams.prevCopied
+ rcvWnd := 0
+ if prevRTTCopied > prevCopied {
+ // The minimal receive window based on what was copied by the app
+ // in the immediately preceding RTT and some extra buffer for 16
+ // segments to account for variations.
+ // We multiply by 2 to account for packet losses.
+ rcvWnd = prevRTTCopied*2 + 16*int(e.amss)
+
+ // Scale for slow start based on bytes copied in this RTT vs previous.
+ grow := (rcvWnd * (prevRTTCopied - prevCopied)) / prevCopied
+
+ // Multiply the growth factor by 2 again to account for the sender
+ // being in slow-start, where the sender grows its congestion window
+ // by 100% per RTT.
+ rcvWnd += grow * 2
+
+ // Make sure the auto-tuned buffer size can always receive up to 2x
+ // the initial window of 10 segments.
+ if minRcvWnd := int(e.amss) * InitialCwnd * 2; rcvWnd < minRcvWnd {
+ rcvWnd = minRcvWnd
+ }
+
+ // Cap the auto-tuned buffer size by the maximum permissible
+ // receive buffer size.
+ if max := e.maxReceiveBufferSize(); rcvWnd > max {
+ rcvWnd = max
+ }
+
+ // We do not adjust downwards as that can cause the receiver to
+ // reject valid data that might already be in flight as the
+ // acceptable window will shrink.
+ if rcvWnd > e.rcvBufSize {
+ availBefore := e.receiveBufferAvailableLocked()
+ e.rcvBufSize = rcvWnd
+ availAfter := e.receiveBufferAvailableLocked()
+ mask := uint32(notifyReceiveWindowChanged)
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
+ mask |= notifyNonZeroReceiveWindow
+ }
+ e.notifyProtocolGoroutine(mask)
+ }
+
+ // We only update prevCopied when we grow the buffer because in cases
+ // where prevCopied > prevRTTCopied the existing buffer is already big
+ // enough to handle the current rate and we don't need to do any
+ // adjustments.
+ e.rcvAutoParams.prevCopied = prevRTTCopied
+ }
+ e.rcvAutoParams.measureTime = now
+ e.rcvAutoParams.copied = 0
+ e.rcvListMu.Unlock()
+}
+
+func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.owner = owner
+}
+
+// Read reads data from the endpoint.
+func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ // When in SYN-SENT state, let the caller block on the receive.
+ // An application can initiate a non-blocking connect and then block
+ // on a receive. It can expect to read any data after the handshake
+ // is complete. RFC793, section 3.9, p58.
+ if e.EndpointState() == StateSynSent {
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data. Also note that a RST being received
+ // would cause the state to become StateError, so we should allow the
+ // reads to proceed before returning an ECONNRESET.
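+ // In other words, drain any data that arrived before the reset first,
+ // and only surface the hard error once the receive buffer is empty.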
+ e.rcvListMu.Lock() + bufUsed := e.rcvBufUsed + if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { + e.rcvListMu.Unlock() + he := e.HardError + if s == StateError { + return buffer.View{}, tcpip.ControlMessages{}, he + } + e.stats.ReadErrors.NotConnected.Increment() + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected + } + + v, err := e.readLocked() + e.rcvListMu.Unlock() + + if err == tcpip.ErrClosedForReceive { + e.stats.ReadErrors.ReadClosed.Increment() + } + return v, tcpip.ControlMessages{}, err +} + +func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { + if e.rcvBufUsed == 0 { + if e.rcvClosed || !e.EndpointState().connected() { + return buffer.View{}, tcpip.ErrClosedForReceive + } + return buffer.View{}, tcpip.ErrWouldBlock + } + + s := e.rcvList.Front() + views := s.data.Views() + v := views[s.viewToDeliver] + s.viewToDeliver++ + + if s.viewToDeliver >= len(views) { + e.rcvList.Remove(s) + s.decRef() + } + + e.rcvBufUsed -= len(v) + + // If the window was small before this read and if the read freed up + // enough buffer space, to either fit an aMSS or half a receive buffer + // (whichever smaller), then notify the protocol goroutine to send a + // window update. + if crossed, above := e.windowCrossedACKThresholdLocked(len(v)); crossed && above { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } + + return v, nil +} + +// isEndpointWritableLocked checks if a given endpoint is writable +// and also returns the number of bytes that can be written at this +// moment. If the endpoint is not writable then it returns an error +// indicating the reason why it's not writable. +// Caller must hold e.mu and e.sndBufMu +func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) { + // The endpoint cannot be written to if it's not connected. + if !e.EndpointState().connected() { + switch e.EndpointState() { + case StateError: + return 0, e.HardError + default: + return 0, tcpip.ErrClosedForSend + } + } + + // Check if the connection has already been closed for sends. + if e.sndClosed { + return 0, tcpip.ErrClosedForSend + } + + avail := e.sndBufSize - e.sndBufUsed + if avail <= 0 { + return 0, tcpip.ErrWouldBlock + } + return avail, nil +} + +// Write writes data to the endpoint's peer. +func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.LockUser() + e.sndBufMu.Lock() + + avail, err := e.isEndpointWritableLocked() + if err != nil { + e.sndBufMu.Unlock() + e.UnlockUser() + e.stats.WriteErrors.WriteClosed.Increment() + return 0, nil, err + } + + // We can release locks while copying data. + // + // This is not possible if atomic is set, because we can't allow the + // available buffer space to be consumed by some other caller while we + // are copying data in. + if !opts.Atomic { + e.sndBufMu.Unlock() + e.UnlockUser() + } + + // Fetch data. + v, perr := p.Payload(avail) + if perr != nil || len(v) == 0 { + // Note that perr may be nil if len(v) == 0. + if opts.Atomic { + e.sndBufMu.Unlock() + e.UnlockUser() + } + return 0, nil, perr + } + + queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { + // Add data to the send queue. 
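+ // The whole view becomes a single entry in sndQueue; the sender
+ // is then responsible for carving it into MSS-sized segments
+ // when it actually transmits.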
+ s := newSegmentFromView(&e.route, e.ID, v)
+ e.sndBufUsed += len(v)
+ e.sndBufInQueue += seqnum.Size(len(v))
+ e.sndQueue.PushBack(s)
+ e.sndBufMu.Unlock()
+
+ // Do the work inline.
+ e.handleWrite()
+ e.UnlockUser()
+ return int64(len(v)), nil, nil
+ }
+
+ if opts.Atomic {
+ // Locks released in queueAndSend()
+ return queueAndSend()
+ }
+
+ // Since we released the locks in between, it's possible that the
+ // endpoint transitioned to a CLOSED/ERROR state, so make
+ // sure the endpoint is still writable before trying to write.
+ e.LockUser()
+ e.sndBufMu.Lock()
+ avail, err = e.isEndpointWritableLocked()
+ if err != nil {
+ e.sndBufMu.Unlock()
+ e.UnlockUser()
+ e.stats.WriteErrors.WriteClosed.Increment()
+ return 0, nil, err
+ }
+
+ // Discard any excess data copied in because avail was reduced by
+ // a simultaneous write call to the socket.
+ if avail < len(v) {
+ v = v[:avail]
+ }
+
+ // Locks released in queueAndSend()
+ return queueAndSend()
+}
+
+// Peek reads data without consuming it from the endpoint.
+//
+// This method does not block if there is no data pending.
+func (e *endpoint) Peek(vec [][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data.
+ if s := e.EndpointState(); !s.connected() && s != StateClose {
+ if s == StateError {
+ return 0, tcpip.ControlMessages{}, e.HardError
+ }
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
+ return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ }
+
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
+ if e.rcvBufUsed == 0 {
+ if e.rcvClosed || !e.EndpointState().connected() {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ }
+ return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
+ // Make a copy of vec so we can modify the slice headers.
+ vec = append([][]byte(nil), vec...)
+
+ var num int64
+ for s := e.rcvList.Front(); s != nil; s = s.Next() {
+ views := s.data.Views()
+
+ for i := s.viewToDeliver; i < len(views); i++ {
+ v := views[i]
+
+ for len(v) > 0 {
+ if len(vec) == 0 {
+ return num, tcpip.ControlMessages{}, nil
+ }
+ if len(vec[0]) == 0 {
+ vec = vec[1:]
+ continue
+ }
+
+ n := copy(vec[0], v)
+ v = v[n:]
+ vec[0] = vec[0][n:]
+ num += int64(n)
+ }
+ }
+ }
+
+ return num, tcpip.ControlMessages{}, nil
+}
+
+// windowCrossedACKThresholdLocked checks if the receive window to be announced
+// now would be under the aMSS or under half the receive buffer, whichever is
+// smaller. This is useful as a receive-side silly window syndrome prevention
+// mechanism. If the window grows to a reasonable value, we should send an ACK
+// to inform the sender that the rx space is now large. We also want to ensure
+// that a series of small read()s won't trigger a flood of spurious tiny ACKs.
+//
+// For large receive buffers, the threshold is the aMSS - once the reader reads
+// more than an aMSS we'll send an ACK. For tiny receive buffers, the threshold
+// is half of the receive buffer size. This is chosen arbitrarily.
+// crossed will be true if the window size crossed the ACK threshold.
+// above will be true if the new window is >= ACK threshold and false
+// otherwise.
+//
+// Precondition: e.mu and e.rcvListMu must be held.
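+// A worked example (assumed values): with aMSS = 1460 and a 64 KB receive
+// buffer, the threshold is min(1460, 32768) = 1460 bytes. A read that moves
+// the available space from 1000 to 2000 bytes returns (true, true) and so
+// warrants a window update, while a queued segment that moves it from 2000
+// down to 500 bytes returns (true, false).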
+func (e *endpoint) windowCrossedACKThresholdLocked(deltaBefore int) (crossed bool, above bool) {
+ newAvail := e.receiveBufferAvailableLocked()
+ oldAvail := newAvail - deltaBefore
+ if oldAvail < 0 {
+ oldAvail = 0
+ }
+
+ threshold := int(e.amss)
+ if threshold > e.rcvBufSize/2 {
+ threshold = e.rcvBufSize / 2
+ }
+
+ switch {
+ case oldAvail < threshold && newAvail >= threshold:
+ return true, true
+ case oldAvail >= threshold && newAvail < threshold:
+ return true, false
+ }
+ return false, false
+}
+
+// SetSockOptBool sets a socket option.
+func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+
+ case tcpip.BroadcastOption:
+ e.LockUser()
+ e.broadcast = v
+ e.UnlockUser()
+
+ case tcpip.CorkOption:
+ e.LockUser()
+ if !v {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ e.UnlockUser()
+
+ case tcpip.DelayOption:
+ if v {
+ atomic.StoreUint32(&e.delay, 1)
+ } else {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ }
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ e.keepalive.enabled = v
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.QuickAckOption:
+ o := uint32(1)
+ if v {
+ o = 0
+ }
+ atomic.StoreUint32(&e.slowAck, o)
+
+ case tcpip.ReuseAddressOption:
+ e.LockUser()
+ e.portFlags.TupleOnly = v
+ e.UnlockUser()
+
+ case tcpip.ReusePortOption:
+ e.LockUser()
+ e.portFlags.LoadBalanced = v
+ e.UnlockUser()
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.NetProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // We only allow this to be set when we're in the initial state.
+ if e.EndpointState() != StateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.LockUser()
+ e.v6only = v
+ e.UnlockUser()
+ }
+
+ return nil
+}
+
+// SetSockOptInt sets a socket option.
+func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ // Lower 2 bits represent the ECN bits. RFC 3168, section 23.1.
+ const inetECNMask = 3
+
+ switch opt {
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ e.keepalive.count = v
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.IPv4TOSOption:
+ e.LockUser()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.UnlockUser()
+
+ case tcpip.IPv6TrafficClassOption:
+ e.LockUser()
+ // TODO(gvisor.dev/issue/995): ECN is not currently supported,
+ // ignore the bits for now.
+ e.sendTOS = uint8(v) & ^uint8(inetECNMask)
+ e.UnlockUser()
+
+ case tcpip.MaxSegOption:
+ userMSS := v
+ if userMSS < header.TCPMinimumMSS || userMSS > header.TCPMaximumMSS {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.LockUser()
+ e.userMSS = uint16(userMSS)
+ e.UnlockUser()
+ e.notifyProtocolGoroutine(notifyMSSChanged)
+
+ case tcpip.MTUDiscoverOption:
+ // Return not supported if attempting to set this option to
+ // anything other than path MTU discovery disabled.
+ if v != tcpip.PMTUDiscoveryDont {
+ return tcpip.ErrNotSupported
+ }
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
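+ // Worked example of the clamping below (assuming the stack defaults
+ // rs = {Min: 4 << 10, Max: 4 << 20}): a request of v = 1 << 30 is
+ // clamped to 4 << 20, and with a receive window scale of 7 a request
+ // of v = 64 is raised to 1 << 7 so that the advertised (scaled)
+ // window cannot be permanently zero.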
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if v < rs.Min {
+ v = rs.Min
+ }
+ if v > rs.Max {
+ v = rs.Max
+ }
+ }
+
+ mask := uint32(notifyReceiveWindowChanged)
+
+ e.LockUser()
+ e.rcvListMu.Lock()
+
+ // Make sure the receive buffer size allows us to send a
+ // non-zero window size.
+ scale := uint8(0)
+ if e.rcv != nil {
+ scale = e.rcv.rcvWndScale
+ }
+ if v>>scale == 0 {
+ v = 1 << scale
+ }
+
+ // Make sure 2*size doesn't overflow.
+ if v > math.MaxInt32/2 {
+ v = math.MaxInt32 / 2
+ }
+
+ availBefore := e.receiveBufferAvailableLocked()
+ e.rcvBufSize = v
+ availAfter := e.receiveBufferAvailableLocked()
+
+ e.rcvAutoParams.disabled = true
+
+ // Immediately send an ACK to uncork the sender from our silly window
+ // syndrome prevention, when our available space grows above an aMSS
+ // or half the receive buffer, whichever is smaller.
+ if crossed, above := e.windowCrossedACKThresholdLocked(availAfter - availBefore); crossed && above {
+ mask |= notifyNonZeroReceiveWindow
+ }
+
+ e.rcvListMu.Unlock()
+ e.UnlockUser()
+ e.notifyProtocolGoroutine(mask)
+
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if v < ss.Min {
+ v = ss.Min
+ }
+ if v > ss.Max {
+ v = ss.Max
+ }
+ }
+
+ e.sndBufMu.Lock()
+ e.sndBufSize = v
+ e.sndBufMu.Unlock()
+
+ case tcpip.TTLOption:
+ e.LockUser()
+ e.ttl = uint8(v)
+ e.UnlockUser()
+
+ case tcpip.TCPSynCountOption:
+ if v < 1 || v > 255 {
+ return tcpip.ErrInvalidOptionValue
+ }
+ e.LockUser()
+ e.maxSynRetries = uint8(v)
+ e.UnlockUser()
+
+ case tcpip.TCPWindowClampOption:
+ if v == 0 {
+ e.LockUser()
+ switch e.EndpointState() {
+ case StateClose, StateInitial:
+ e.windowClamp = 0
+ e.UnlockUser()
+ return nil
+ default:
+ e.UnlockUser()
+ return tcpip.ErrInvalidOptionValue
+ }
+ }
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if v < rs.Min/2 {
+ v = rs.Min / 2
+ }
+ }
+ e.LockUser()
+ e.windowClamp = uint32(v)
+ e.UnlockUser()
+ }
+ return nil
+}
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.BindToDeviceOption:
+ id := tcpip.NICID(v)
+ if id != 0 && !e.stack.HasNIC(id) {
+ return tcpip.ErrUnknownDevice
+ }
+ e.LockUser()
+ e.bindToDevice = id
+ e.UnlockUser()
+
+ case tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ e.keepalive.idle = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ e.keepalive.interval = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+
+ case tcpip.OutOfBandInlineOption:
+ // We don't currently support disabling this option.
+
+ case tcpip.TCPUserTimeoutOption:
+ e.LockUser()
+ e.userTimeout = time.Duration(v)
+ e.UnlockUser()
+
+ case tcpip.CongestionControlOption:
+ // Query the available cc algorithms in the stack and
+ // validate that the specified algorithm is actually
+ // supported in the stack.
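+ // For example, a caller can select Reno (assuming "reno" is in the
+ // stack's available algorithm list, as it is by default) with:
+ //
+ //   if err := ep.SetSockOpt(tcpip.CongestionControlOption("reno")); err != nil {
+ //       // tcpip.ErrNoSuchFile means the algorithm is unknown.
+ //   }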
+ var avail tcpip.AvailableCongestionControlOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil { + return err + } + availCC := strings.Split(string(avail), " ") + for _, cc := range availCC { + if v == tcpip.CongestionControlOption(cc) { + e.LockUser() + state := e.EndpointState() + e.cc = v + switch state { + case StateEstablished: + if e.EndpointState() == state { + e.snd.cc = e.snd.initCongestionControl(e.cc) + } + } + e.UnlockUser() + return nil + } + } + + // Linux returns ENOENT when an invalid congestion + // control algorithm is specified. + return tcpip.ErrNoSuchFile + + case tcpip.TCPLingerTimeoutOption: + e.LockUser() + if v < 0 { + // Same as effectively disabling TCPLinger timeout. + v = 0 + } + var stkTCPLingerTimeout tcpip.TCPLingerTimeoutOption + if err := e.stack.TransportProtocolOption(header.TCPProtocolNumber, &stkTCPLingerTimeout); err != nil { + // We were unable to retrieve a stack config, just use + // the DefaultTCPLingerTimeout. + if v > tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) { + stkTCPLingerTimeout = tcpip.TCPLingerTimeoutOption(DefaultTCPLingerTimeout) + } + } + // Cap it to the stack wide TCPLinger timeout. + if v > stkTCPLingerTimeout { + v = stkTCPLingerTimeout + } + e.tcpLingerTimeout = time.Duration(v) + e.UnlockUser() + + case tcpip.TCPDeferAcceptOption: + e.LockUser() + if time.Duration(v) > MaxRTO { + v = tcpip.TCPDeferAcceptOption(MaxRTO) + } + e.deferAccept = time.Duration(v) + e.UnlockUser() + + default: + return nil + } + return nil +} + +// readyReceiveSize returns the number of bytes ready to be received. +func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { + e.LockUser() + defer e.UnlockUser() + + // The endpoint cannot be in listen state. + if e.EndpointState() == StateListen { + return 0, tcpip.ErrInvalidEndpointState + } + + e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() + + return e.rcvBufUsed, nil +} + +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.BroadcastOption: + e.LockUser() + v := e.broadcast + e.UnlockUser() + return v, nil + + case tcpip.CorkOption: + return atomic.LoadUint32(&e.cork) != 0, nil + + case tcpip.DelayOption: + return atomic.LoadUint32(&e.delay) != 0, nil + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + v := e.keepalive.enabled + e.keepalive.Unlock() + + return v, nil + + case tcpip.QuickAckOption: + v := atomic.LoadUint32(&e.slowAck) == 0 + return v, nil + + case tcpip.ReuseAddressOption: + e.LockUser() + v := e.portFlags.TupleOnly + e.UnlockUser() + + return v, nil + + case tcpip.ReusePortOption: + e.LockUser() + v := e.portFlags.LoadBalanced + e.UnlockUser() + + return v, nil + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.LockUser() + v := e.v6only + e.UnlockUser() + + return v, nil + + case tcpip.MulticastLoopOption: + return true, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. 
+func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { + switch opt { + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + v := e.keepalive.count + e.keepalive.Unlock() + return v, nil + + case tcpip.IPv4TOSOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.LockUser() + v := int(e.sendTOS) + e.UnlockUser() + return v, nil + + case tcpip.MaxSegOption: + // This is just stubbed out. Linux never returns the user_mss + // value as it either returns the defaultMSS or returns the + // actual current MSS. Netstack just returns the defaultMSS + // always for now. + v := header.TCPDefaultMSS + return v, nil + + case tcpip.MTUDiscoverOption: + // Always return the path MTU discovery disabled setting since + // it's the only one supported. + return tcpip.PMTUDiscoveryDont, nil + + case tcpip.ReceiveQueueSizeOption: + return e.readyReceiveSize() + + case tcpip.SendBufferSizeOption: + e.sndBufMu.Lock() + v := e.sndBufSize + e.sndBufMu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvListMu.Lock() + v := e.rcvBufSize + e.rcvListMu.Unlock() + return v, nil + + case tcpip.TTLOption: + e.LockUser() + v := int(e.ttl) + e.UnlockUser() + return v, nil + + case tcpip.TCPSynCountOption: + e.LockUser() + v := int(e.maxSynRetries) + e.UnlockUser() + return v, nil + + case tcpip.TCPWindowClampOption: + e.LockUser() + v := int(e.windowClamp) + e.UnlockUser() + return v, nil + + case tcpip.MulticastTTLOption: + return 1, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + e.lastErrorMu.Lock() + err := e.lastError + e.lastError = nil + e.lastErrorMu.Unlock() + return err + + case *tcpip.BindToDeviceOption: + e.LockUser() + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.UnlockUser() + + case *tcpip.TCPInfoOption: + *o = tcpip.TCPInfoOption{} + e.LockUser() + snd := e.snd + e.UnlockUser() + if snd != nil { + snd.rtt.Lock() + o.RTT = snd.rtt.srtt + o.RTTVar = snd.rtt.rttvar + snd.rtt.Unlock() + } + + case *tcpip.KeepaliveIdleOption: + e.keepalive.Lock() + *o = tcpip.KeepaliveIdleOption(e.keepalive.idle) + e.keepalive.Unlock() + + case *tcpip.KeepaliveIntervalOption: + e.keepalive.Lock() + *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval) + e.keepalive.Unlock() + + case *tcpip.TCPUserTimeoutOption: + e.LockUser() + *o = tcpip.TCPUserTimeoutOption(e.userTimeout) + e.UnlockUser() + + case *tcpip.OutOfBandInlineOption: + // We don't currently support disabling this option. + *o = 1 + + case *tcpip.CongestionControlOption: + e.LockUser() + *o = e.cc + e.UnlockUser() + + case *tcpip.TCPLingerTimeoutOption: + e.LockUser() + *o = tcpip.TCPLingerTimeoutOption(e.tcpLingerTimeout) + e.UnlockUser() + + case *tcpip.TCPDeferAcceptOption: + e.LockUser() + *o = tcpip.TCPDeferAcceptOption(e.deferAccept) + e.UnlockUser() + + default: + return tcpip.ErrUnknownProtocolOption + } + return nil +} + +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. 
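+// For example, on a dual-stack endpoint (v6only == false), the v4-mapped
+// address ::ffff:192.0.2.1 unwraps to the IPv4 address 192.0.2.1 and the
+// effective network protocol becomes header.IPv4ProtocolNumber.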
+func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only)
+ if err != nil {
+ return tcpip.FullAddress{}, 0, err
+ }
+ return unwrapped, netProto, nil
+}
+
+// Disconnect implements tcpip.Endpoint.Disconnect.
+func (*endpoint) Disconnect() *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Connect connects the endpoint to its peer.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ err := e.connect(addr, true, true)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
+ return err
+}
+
+// connect connects the endpoint to its peer. In the normal non-S/R case, the
+// new connection is expected to run the main goroutine and perform the
+// handshake. When restoring previously connected endpoints, both ends will be
+// passively created (so no new handshaking is done); stack-accepted
+// connections not yet accepted by the app are restored without running the
+// main goroutine here.
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ connectingAddr := addr.Addr
+
+ addr, netProto, err := e.checkV4MappedLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ if e.EndpointState().connected() {
+ // The endpoint is already connected. If the caller hasn't been
+ // notified yet, return success.
+ if !e.isConnectNotified {
+ e.isConnectNotified = true
+ return nil
+ }
+ // Otherwise return that it's already connected.
+ return tcpip.ErrAlreadyConnected
+ }
+
+ nicID := addr.NIC
+ switch e.EndpointState() {
+ case StateBound:
+ // If we're already bound to a NIC but the caller is requesting
+ // that we use a different one now, we cannot proceed.
+ if e.boundNICID == 0 {
+ break
+ }
+
+ if nicID != 0 && nicID != e.boundNICID {
+ return tcpip.ErrNoRoute
+ }
+
+ nicID = e.boundNICID
+
+ case StateInitial:
+ // Nothing to do. We'll eventually fill in the gaps in the ID (if any)
+ // when we find a route.
+
+ case StateConnecting, StateSynSent, StateSynRecv:
+ // A connection request has already been issued but hasn't completed
+ // yet.
+ return tcpip.ErrAlreadyConnecting
+
+ case StateError:
+ return e.HardError
+
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicID, e.ID.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ e.ID.LocalAddress = r.LocalAddress
+ e.ID.RemoteAddress = r.RemoteAddress
+ e.ID.RemotePort = addr.Port
+
+ if e.ID.LocalPort != 0 {
+ // The endpoint is bound to a port, attempt to register it.
+ err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice)
+ if err != nil {
+ return err
+ }
+ } else {
+ // The endpoint doesn't have a local port yet, so try to get
+ // one. Make sure that it isn't one that will result in the same
+ // address/port for both local and remote (otherwise this
+ // endpoint would be trying to connect to itself).
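+ // The hash-derived offset computed below keeps the search order stable
+ // for a given (srcIP, dstIP, dstPort) tuple, in the spirit of RFC
+ // 6056's hash-based algorithms. Conceptually the stack then runs a
+ // search of the following shape (a sketch; names are assumed):
+ //
+ //   for i := uint32(0); i < numEphemeral; i++ {
+ //       port := firstEphemeral + uint16((portOffset+i)%numEphemeral)
+ //       if ok, err := testPort(port); ok || err != nil {
+ //           return port, err
+ //       }
+ //   }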
+ sameAddr := e.ID.LocalAddress == e.ID.RemoteAddress + + // Calculate a port offset based on the destination IP/port and + // src IP to ensure that for a given tuple (srcIP, destIP, + // destPort) the offset used as a starting point is the same to + // ensure that we can cycle through the port space effectively. + h := jenkins.Sum32(e.stack.Seed()) + h.Write([]byte(e.ID.LocalAddress)) + h.Write([]byte(e.ID.RemoteAddress)) + portBuf := make([]byte, 2) + binary.LittleEndian.PutUint16(portBuf, e.ID.RemotePort) + h.Write(portBuf) + portOffset := h.Sum32() + + if _, err := e.stack.PickEphemeralPortStable(portOffset, func(p uint16) (bool, *tcpip.Error) { + if sameAddr && p == e.ID.RemotePort { + return false, nil + } + if _, err := e.stack.ReservePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr); err != nil { + return false, nil + } + + id := e.ID + id.LocalPort = p + if err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.portFlags, e.bindToDevice); err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, e.ID.LocalAddress, p, e.portFlags, e.bindToDevice, addr) + if err == tcpip.ErrPortInUse { + return false, nil + } + return false, err + } + + // Port picking successful. Save the details of + // the selected port. + e.ID = id + e.isPortReserved = true + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + e.boundDest = addr + return true, nil + }); err != nil { + return err + } + } + + e.isRegistered = true + e.setEndpointState(StateConnecting) + e.route = r.Clone() + e.boundNICID = nicID + e.effectiveNetProtos = netProtos + e.connectingAddress = connectingAddr + + e.initGSO() + + // Connect in the restore phase does not perform handshake. Restore its + // connection setting here. + if !handshake { + e.segmentQueue.mu.Lock() + for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { + for s := l.Front(); s != nil; s = s.Next() { + s.id = e.ID + s.route = r.Clone() + e.sndWaker.Assert() + } + } + e.segmentQueue.mu.Unlock() + e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) + e.setEndpointState(StateEstablished) + } + + if run { + e.workerRunning = true + e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() + go e.protocolMainLoop(handshake, nil) // S/R-SAFE: will be drained before save. + } + + return tcpip.ErrConnectStarted +} + +// ConnectEndpoint is not supported. +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// Shutdown closes the read and/or write end of the endpoint connection to its +// peer. +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.LockUser() + defer e.UnlockUser() + return e.shutdownLocked(flags) +} + +func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { + e.shutdownFlags |= flags + switch { + case e.EndpointState().connected(): + // Close for read. + if e.shutdownFlags&tcpip.ShutdownRead != 0 { + // Mark read side as closed. + e.rcvListMu.Lock() + e.rcvClosed = true + rcvBufUsed := e.rcvBufUsed + e.rcvListMu.Unlock() + + // If we're fully closed and we have unread data we need to abort + // the connection with a RST. + if e.shutdownFlags&tcpip.ShutdownWrite != 0 && rcvBufUsed > 0 { + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + // Wake up worker to terminate loop. + e.notifyProtocolGoroutine(notifyTickleWorker) + return nil + } + } + + // Close for write. 
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ e.sndBufMu.Lock()
+ if e.sndClosed {
+ // Already closed.
+ e.sndBufMu.Unlock()
+ if e.EndpointState() == StateTimeWait {
+ return tcpip.ErrNotConnected
+ }
+ return nil
+ }
+
+ // Queue fin segment.
+ s := newSegmentFromView(&e.route, e.ID, nil)
+ e.sndQueue.PushBack(s)
+ e.sndBufInQueue++
+ // Mark endpoint as closed.
+ e.sndClosed = true
+ e.sndBufMu.Unlock()
+ e.handleClose()
+ }
+
+ return nil
+ case e.EndpointState() == StateListen:
+ if e.shutdownFlags&tcpip.ShutdownRead != 0 {
+ // Reset all connections from the accept queue and keep the
+ // worker running so that it can continue handling incoming
+ // segments by replying with RST.
+ //
+ // By not removing this endpoint from the demuxer mapping, we
+ // ensure that any other bind to the same port fails, as on Linux.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ e.rcvListMu.Unlock()
+ e.closePendingAcceptableConnectionsLocked()
+ // Notify waiters that the endpoint is shutdown.
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut | waiter.EventHUp | waiter.EventErr)
+ }
+ return nil
+ default:
+ return tcpip.ErrNotConnected
+ }
+}
+
+// Listen puts the endpoint in "listen" mode, which allows it to accept
+// new connections.
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
+ err := e.listen(backlog)
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ e.stats.FailedConnectionAttempts.Increment()
+ }
+ return err
+}
+
+func (e *endpoint) listen(backlog int) *tcpip.Error {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ if e.EndpointState() == StateListen && !e.closed {
+ e.acceptMu.Lock()
+ defer e.acceptMu.Unlock()
+ if e.acceptedChan == nil {
+ // listen is called after shutdown.
+ e.acceptedChan = make(chan *endpoint, backlog)
+ e.shutdownFlags = 0
+ e.rcvListMu.Lock()
+ e.rcvClosed = false
+ e.rcvListMu.Unlock()
+ } else {
+ // Adjust the size of the channel iff we can fit the
+ // existing pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return tcpip.ErrInvalidEndpointState
+ }
+ if cap(e.acceptedChan) == backlog {
+ return nil
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *endpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
+ }
+
+ // Notify any blocked goroutines that they can attempt to
+ // deliver endpoints again.
+ e.acceptCond.Broadcast()
+
+ return nil
+ }
+
+ if e.EndpointState() == StateInitial {
+ // Listen is being called on an unbound socket; bind it to a
+ // random free port with the local address set to INADDR_ANY.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return err
+ }
+ }
+
+ // Endpoint must be bound before it can transition to listen mode.
+ if e.EndpointState() != StateBound {
+ e.stats.ReadErrors.InvalidEndpointState.Increment()
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Register the endpoint.
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ return err
+ }
+
+ e.isRegistered = true
+ e.setEndpointState(StateListen)
+
+ // The channel may be non-nil when we're restoring the endpoint, and it
+ // may be pre-populated with some previously accepted (but not Accepted)
+ // endpoints.
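+ // (For reference, the caller-side sequence this implements is the
+ // usual bind/listen/accept loop - a sketch, where ch comes from
+ // waiter.NewChannelEntry registered on the endpoint's waiter.Queue
+ // for waiter.EventIn, and handle is caller-defined:
+ //
+ //   ep.Bind(tcpip.FullAddress{Port: 80})
+ //   ep.Listen(10)
+ //   for {
+ //       n, nwq, err := ep.Accept()
+ //       if err == tcpip.ErrWouldBlock {
+ //           <-ch // wait for an incoming connection
+ //           continue
+ //       }
+ //       go handle(n, nwq)
+ //   }
+ // )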
+ e.acceptMu.Lock() + if e.acceptedChan == nil { + e.acceptedChan = make(chan *endpoint, backlog) + } + e.acceptMu.Unlock() + + e.workerRunning = true + go e.protocolListenLoop( // S/R-SAFE: drained on save. + seqnum.Size(e.receiveBufferAvailable())) + return nil +} + +// startAcceptedLoop sets up required state and starts a goroutine with the +// main loop for accepted connections. +func (e *endpoint) startAcceptedLoop() { + e.workerRunning = true + e.mu.Unlock() + wakerInitDone := make(chan struct{}) + go e.protocolMainLoop(false, wakerInitDone) // S/R-SAFE: drained on save. + <-wakerInitDone +} + +// Accept returns a new endpoint if a peer has established a connection +// to an endpoint previously set to listen mode. +func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + e.LockUser() + defer e.UnlockUser() + + e.rcvListMu.Lock() + rcvClosed := e.rcvClosed + e.rcvListMu.Unlock() + // Endpoint must be in listen state before it can accept connections. + if rcvClosed || e.EndpointState() != StateListen { + return nil, nil, tcpip.ErrInvalidEndpointState + } + + // Get the new accepted endpoint. + e.acceptMu.Lock() + defer e.acceptMu.Unlock() + var n *endpoint + select { + case n = <-e.acceptedChan: + e.acceptCond.Signal() + default: + return nil, nil, tcpip.ErrWouldBlock + } + return n, n.waiterQueue, nil +} + +// Bind binds the endpoint to a specific local port and optionally address. +func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { + e.LockUser() + defer e.UnlockUser() + + return e.bindLocked(addr) +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { + // Don't allow binding once endpoint is not in the initial state + // anymore. This is because once the endpoint goes into a connected or + // listen state, it is already bound. + if e.EndpointState() != StateInitial { + return tcpip.ErrAlreadyBound + } + + e.BindAddr = addr.Addr + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } + } + + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) + if err != nil { + return err + } + + e.boundBindToDevice = e.bindToDevice + e.boundPortFlags = e.portFlags + e.isPortReserved = true + e.effectiveNetProtos = netProtos + e.ID.LocalPort = port + + // Any failures beyond this point must remove the port registration. + defer func(portFlags ports.Flags, bindToDevice tcpip.NICID) { + if err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port, portFlags, bindToDevice, tcpip.FullAddress{}) + e.isPortReserved = false + e.effectiveNetProtos = nil + e.ID.LocalPort = 0 + e.ID.LocalAddress = "" + e.boundNICID = 0 + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} + } + }(e.boundPortFlags, e.boundBindToDevice) + + // If an address is specified, we must ensure that it's one of our + // local addresses. 
+ if len(addr.Addr) != 0 {
+ nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nic == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ e.boundNICID = nic
+ e.ID.LocalAddress = addr.Addr
+ }
+
+ if err := e.stack.CheckRegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e.boundPortFlags, e.boundBindToDevice); err != nil {
+ return err
+ }
+
+ // Mark endpoint as bound.
+ e.setEndpointState(StateBound)
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ return tcpip.FullAddress{
+ Addr: e.ID.LocalAddress,
+ Port: e.ID.LocalPort,
+ NIC: e.boundNICID,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.LockUser()
+ defer e.UnlockUser()
+
+ if !e.EndpointState().connected() {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ Addr: e.ID.RemoteAddress,
+ Port: e.ID.RemotePort,
+ NIC: e.boundNICID,
+ }, nil
+}
+
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+ // TCP HandlePacket is no longer required, as inbound packets first
+ // land at the Dispatcher, which can then either deliver them via the
+ // worker goroutine or invoke the TCP processing inline, based on the
+ // state of the endpoint.
+}
+
+func (e *endpoint) enqueueSegment(s *segment) bool {
+ // Send packet to worker goroutine.
+ if !e.segmentQueue.enqueue(s) {
+ // The queue is full, so we drop the segment.
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.SegmentQueueDropped.Increment()
+ return false
+ }
+ return true
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) {
+ switch typ {
+ case stack.ControlPacketTooBig:
+ e.sndBufMu.Lock()
+ e.packetTooBigCount++
+ if v := int(extra); v < e.sndMTU {
+ e.sndMTU = v
+ }
+ e.sndBufMu.Unlock()
+
+ e.notifyProtocolGoroutine(notifyMTUChanged)
+ }
+}
+
+// updateSndBufferUsage is called by the protocol goroutine when room opens up
+// in the send buffer. The number of newly available bytes is v.
+func (e *endpoint) updateSndBufferUsage(v int) {
+ e.sndBufMu.Lock()
+ notify := e.sndBufUsed >= e.sndBufSize>>1
+ e.sndBufUsed -= v
+ // We only notify when there is half the sndBufSize available after
+ // a full buffer event occurs. This ensures that we don't wake up
+ // writers to queue just 1-2 segments and go back to sleep.
+ notify = notify && e.sndBufUsed < e.sndBufSize>>1
+ e.sndBufMu.Unlock()
+
+ if notify {
+ e.waiterQueue.Notify(waiter.EventOut)
+ }
+}
+
+// readyToRead is called by the protocol goroutine when a new segment is ready
+// to be read, or when the connection is closed for receiving (in which case
+// s will be nil).
+func (e *endpoint) readyToRead(s *segment) {
+ e.rcvListMu.Lock()
+ if s != nil {
+ s.incRef()
+ e.rcvBufUsed += s.data.Size()
+ // Increase the counter if the receive window falls below the MSS
+ // or half the receive buffer size, whichever is smaller.
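+ // For example (assumed values): with amss = 1460 and rcvBufSize =
+ // 64 << 10, the threshold is min(1460, 32768) = 1460 bytes; a queued
+ // segment that takes the available space from 2000 down to 500 bytes
+ // crosses it from above, so the zero-window counter is incremented.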
+ if crossed, above := e.windowCrossedACKThresholdLocked(-s.data.Size()); crossed && !above { + e.stats.ReceiveErrors.ZeroRcvWindowState.Increment() + } + e.rcvList.PushBack(s) + } else { + e.rcvClosed = true + } + e.rcvListMu.Unlock() + e.waiterQueue.Notify(waiter.EventIn) +} + +// receiveBufferAvailableLocked calculates how many bytes are still available +// in the receive buffer. +// rcvListMu must be held when this function is called. +func (e *endpoint) receiveBufferAvailableLocked() int { + // We may use more bytes than the buffer size when the receive buffer + // shrinks. + if e.rcvBufUsed >= e.rcvBufSize { + return 0 + } + + return e.rcvBufSize - e.rcvBufUsed +} + +// receiveBufferAvailable calculates how many bytes are still available in the +// receive buffer. +func (e *endpoint) receiveBufferAvailable() int { + e.rcvListMu.Lock() + available := e.receiveBufferAvailableLocked() + e.rcvListMu.Unlock() + return available +} + +func (e *endpoint) receiveBufferSize() int { + e.rcvListMu.Lock() + size := e.rcvBufSize + e.rcvListMu.Unlock() + + return size +} + +func (e *endpoint) maxReceiveBufferSize() int { + var rs ReceiveBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err != nil { + // As a fallback return the hardcoded max buffer size. + return MaxBufferSize + } + return rs.Max +} + +// rcvWndScaleForHandshake computes the receive window scale to offer to the +// peer when window scaling is enabled (true by default). If auto-tuning is +// disabled then the window scaling factor is based on the size of the +// receiveBuffer otherwise we use the max permissible receive buffer size to +// compute the scale. +func (e *endpoint) rcvWndScaleForHandshake() int { + bufSizeForScale := e.receiveBufferSize() + + e.rcvListMu.Lock() + autoTuningDisabled := e.rcvAutoParams.disabled + e.rcvListMu.Unlock() + if autoTuningDisabled { + return FindWndScale(seqnum.Size(bufSizeForScale)) + } + + return FindWndScale(seqnum.Size(e.maxReceiveBufferSize())) +} + +// updateRecentTimestamp updates the recent timestamp using the algorithm +// described in https://tools.ietf.org/html/rfc7323#section-4.3 +func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { + if e.sendTSOk && seqnum.Value(e.recentTimestamp()).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + e.setRecentTimestamp(tsVal) + } +} + +// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if +// the SYN options indicate that timestamp option was negotiated. It also +// initializes the recentTS with the value provided in synOpts.TSval. +func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { + if synOpts.TS { + e.sendTSOk = true + e.setRecentTimestamp(synOpts.TSVal) + } +} + +// timestamp returns the timestamp value to be used in the TSVal field of the +// timestamp option for outgoing TCP segments for a given endpoint. +func (e *endpoint) timestamp() uint32 { + return tcpTimeStamp(e.tsOffset) +} + +// tcpTimeStamp returns a timestamp offset by the provided offset. This is +// not inlined above as it's used when SYN cookies are in use and endpoint +// is not created at the time when the SYN cookie is sent. +func tcpTimeStamp(offset uint32) uint32 { + now := time.Now() + return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset +} + +// timeStampOffset returns a randomized timestamp offset to be used when sending +// timestamp values in a timestamp option for a TCP segment. 
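+// The TSVal carried on the wire is then tcpTimeStamp(tsOffset) above, i.e.
+// the current time in milliseconds plus this per-endpoint random offset, so
+// raw wall-clock timestamps are never exposed directly on the wire.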
+func timeStampOffset() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ // Initialize a random tsOffset that will be added to the recentTS
+ // every time the timestamp is sent when the Timestamp option is enabled.
+ //
+ // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
+ // why this is required.
+ //
+ // NOTE: This is not completely to spec, as normally this should be
+ // initialized on a per-connection basis, analogous to how sequence
+ // numbers are randomized. But for now this is sufficient.
+ return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
+// if the SYN options indicate that the SACK option was negotiated and the TCP
+// stack is configured to enable the TCP SACK option.
+func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+ var v SACKEnabled
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+ // The stack doesn't support SACK, so just return.
+ return
+ }
+ if bool(v) && synOpts.SACKPermitted {
+ e.sackPermitted = true
+ }
+}
+
+// maxOptionSize returns the maximum size of TCP options.
+func (e *endpoint) maxOptionSize() (size int) {
+ var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock
+ options := e.makeOptions(maxSackBlocks[:])
+ size = len(options)
+ putOptions(options)
+
+ return size
+}
+
+// completeState makes a full copy of the endpoint and returns it. This is used
+// before invoking the probe. The state returned may not be fully consistent if
+// there are intervening syscalls when the state is being copied.
+func (e *endpoint) completeState() stack.TCPEndpointState {
+ var s stack.TCPEndpointState
+ s.SegTime = time.Now()
+
+ // Copy EndpointID.
+ s.ID = stack.TCPEndpointID(e.ID)
+
+ // Copy endpoint rcv state.
+ e.rcvListMu.Lock()
+ s.RcvBufSize = e.rcvBufSize
+ s.RcvBufUsed = e.rcvBufUsed
+ s.RcvClosed = e.rcvClosed
+ s.RcvAutoParams.MeasureTime = e.rcvAutoParams.measureTime
+ s.RcvAutoParams.CopiedBytes = e.rcvAutoParams.copied
+ s.RcvAutoParams.PrevCopiedBytes = e.rcvAutoParams.prevCopied
+ s.RcvAutoParams.RTT = e.rcvAutoParams.rtt
+ s.RcvAutoParams.RTTMeasureSeqNumber = e.rcvAutoParams.rttMeasureSeqNumber
+ s.RcvAutoParams.RTTMeasureTime = e.rcvAutoParams.rttMeasureTime
+ s.RcvAutoParams.Disabled = e.rcvAutoParams.disabled
+ e.rcvListMu.Unlock()
+
+ // Endpoint TCP Option state.
+ s.SendTSOk = e.sendTSOk
+ s.RecentTS = e.recentTimestamp()
+ s.TSOffset = e.tsOffset
+ s.SACKPermitted = e.sackPermitted
+ s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
+ copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks])
+ s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy()
+
+ // Copy endpoint send state.
+ e.sndBufMu.Lock()
+ s.SndBufSize = e.sndBufSize
+ s.SndBufUsed = e.sndBufUsed
+ s.SndClosed = e.sndClosed
+ s.SndBufInQueue = e.sndBufInQueue
+ s.PacketTooBigCount = e.packetTooBigCount
+ s.SndMTU = e.sndMTU
+ e.sndBufMu.Unlock()
+
+ // Copy receiver state.
+ s.Receiver = stack.TCPReceiverState{
+ RcvNxt: e.rcv.rcvNxt,
+ RcvAcc: e.rcv.rcvAcc,
+ RcvWndScale: e.rcv.rcvWndScale,
+ PendingBufUsed: e.rcv.pendingBufUsed,
+ PendingBufSize: e.rcv.pendingBufSize,
+ }
+
+ // Copy sender state.
+ s.Sender = stack.TCPSenderState{ + LastSendTime: e.snd.lastSendTime, + DupAckCount: e.snd.dupAckCount, + FastRecovery: stack.TCPFastRecoveryState{ + Active: e.snd.fr.active, + First: e.snd.fr.first, + Last: e.snd.fr.last, + MaxCwnd: e.snd.fr.maxCwnd, + HighRxt: e.snd.fr.highRxt, + RescueRxt: e.snd.fr.rescueRxt, + }, + SndCwnd: e.snd.sndCwnd, + Ssthresh: e.snd.sndSsthresh, + SndCAAckCount: e.snd.sndCAAckCount, + Outstanding: e.snd.outstanding, + SndWnd: e.snd.sndWnd, + SndUna: e.snd.sndUna, + SndNxt: e.snd.sndNxt, + RTTMeasureSeqNum: e.snd.rttMeasureSeqNum, + RTTMeasureTime: e.snd.rttMeasureTime, + Closed: e.snd.closed, + RTO: e.snd.rto, + MaxPayloadSize: e.snd.maxPayloadSize, + SndWndScale: e.snd.sndWndScale, + MaxSentAck: e.snd.maxSentAck, + } + e.snd.rtt.Lock() + s.Sender.SRTT = e.snd.rtt.srtt + s.Sender.SRTTInited = e.snd.rtt.srttInited + e.snd.rtt.Unlock() + + if cubic, ok := e.snd.cc.(*cubicState); ok { + s.Sender.Cubic = stack.TCPCubicState{ + WMax: cubic.wMax, + WLastMax: cubic.wLastMax, + T: cubic.t, + TimeSinceLastCongestion: time.Since(cubic.t), + C: cubic.c, + K: cubic.k, + Beta: cubic.beta, + WC: cubic.wC, + WEst: cubic.wEst, + } + } + return s +} + +func (e *endpoint) initHardwareGSO() { + gso := &stack.GSO{} + switch e.route.NetProto { + case header.IPv4ProtocolNumber: + gso.Type = stack.GSOTCPv4 + gso.L3HdrLen = header.IPv4MinimumSize + case header.IPv6ProtocolNumber: + gso.Type = stack.GSOTCPv6 + gso.L3HdrLen = header.IPv6MinimumSize + default: + panic(fmt.Sprintf("Unknown netProto: %v", e.NetProto)) + } + gso.NeedsCsum = true + gso.CsumOffset = header.TCPChecksumOffset + gso.MaxSize = e.route.GSOMaxSize() + e.gso = gso +} + +func (e *endpoint) initGSO() { + if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + e.initHardwareGSO() + } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + e.gso = &stack.GSO{ + MaxSize: e.route.GSOMaxSize(), + Type: stack.GSOSW, + NeedsCsum: false, + } + } +} + +// State implements tcpip.Endpoint.State. It exports the endpoint's protocol +// state for diagnostics. +func (e *endpoint) State() uint32 { + return uint32(e.EndpointState()) +} + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.LockUser() + // Make a copy of the endpoint info. + ret := e.EndpointInfo + e.UnlockUser() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements stack.TransportEndpoint.Wait. +func (e *endpoint) Wait() { + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + e.waiterQueue.EventRegister(&waitEntry, waiter.EventHUp) + defer e.waiterQueue.EventUnregister(&waitEntry) + for { + e.LockUser() + running := e.workerRunning + e.UnlockUser() + if !running { + break + } + <-notifyCh + } +} + +func mssForRoute(r *stack.Route) uint16 { + // TODO(b/143359391): Respect TCP Min and Max size. + return uint16(r.MTU() - header.TCPMinimumSize) +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go new file mode 100644 index 000000000..abf1ac5c9 --- /dev/null +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -0,0 +1,348 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "fmt" + "sync/atomic" + "time" + + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +func (e *endpoint) drainSegmentLocked() { + // Drain only up to once. + if e.drainDone != nil { + return + } + + e.drainDone = make(chan struct{}) + e.undrain = make(chan struct{}) + e.mu.Unlock() + + e.notifyProtocolGoroutine(notifyDrain) + <-e.drainDone + + e.mu.Lock() +} + +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + // Stop incoming packets. + e.segmentQueue.setLimit(0) + + e.mu.Lock() + defer e.mu.Unlock() + + epState := e.EndpointState() + switch { + case epState == StateInitial || epState == StateBound: + case epState.connected() || epState.handshake(): + if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { + if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { + panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) + } + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.mu.Unlock() + e.Close() + e.mu.Lock() + } + if !e.workerRunning { + // The endpoint must be in acceptedChan or has been just + // disconnected and closed. + break + } + fallthrough + case epState == StateListen || epState == StateConnecting: + e.drainSegmentLocked() + // Refresh epState, since drainSegmentLocked may have changed it. + epState = e.EndpointState() + if !epState.closed() { + if !e.workerRunning { + panic("endpoint has no worker running in listen, connecting, or connected state") + } + } + case epState.closed(): + for e.workerRunning { + e.mu.Unlock() + time.Sleep(100 * time.Millisecond) + e.mu.Lock() + } + if e.workerRunning { + panic(fmt.Sprintf("endpoint: %+v still has worker running in closed or error state", e.ID)) + } + default: + panic(fmt.Sprintf("endpoint in unknown state %v", e.EndpointState())) + } + + if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { + panic("endpoint still has waiters upon save") + } +} + +// saveAcceptedChan is invoked by stateify. +func (e *endpoint) saveAcceptedChan() []*endpoint { + if e.acceptedChan == nil { + return nil + } + acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan)) + for i := 0; i < len(acceptedEndpoints); i++ { + select { + case ep := <-e.acceptedChan: + acceptedEndpoints[i] = ep + default: + panic("endpoint acceptedChan buffer got consumed by background context") + } + } + for i := 0; i < len(acceptedEndpoints); i++ { + select { + case e.acceptedChan <- acceptedEndpoints[i]: + default: + panic("endpoint acceptedChan buffer got populated by background context") + } + } + return acceptedEndpoints +} + +// loadAcceptedChan is invoked by stateify. 
+func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
+ if cap(acceptedEndpoints) > 0 {
+ e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints))
+ for _, ep := range acceptedEndpoints {
+ e.acceptedChan <- ep
+ }
+ }
+}
+
+// saveState is invoked by stateify.
+func (e *endpoint) saveState() EndpointState {
+ return e.EndpointState()
+}
+
+// Endpoint loading must be done in the following order by state, to avoid a
+// dangling connecting endpoint without a listening peer, and to avoid
+// conflicts in port reservation.
+var connectedLoading sync.WaitGroup
+var listenLoading sync.WaitGroup
+var connectingLoading sync.WaitGroup
+
+// Bound endpoint loading happens last.
+
+// loadState is invoked by stateify.
+func (e *endpoint) loadState(epState EndpointState) {
+ // This is to ensure that the loading wait groups include all applicable
+ // endpoints before any asynchronous calls to the Wait() methods.
+ // For restore purposes we treat TimeWait like a connected endpoint.
+ if epState.connected() || epState == StateTimeWait {
+ connectedLoading.Add(1)
+ }
+ switch {
+ case epState == StateListen:
+ listenLoading.Add(1)
+ case epState.connecting():
+ connectingLoading.Add(1)
+ }
+ // Directly update the state here rather than using e.setEndpointState,
+ // as the endpoint is still being loaded and the stack reference is not
+ // yet initialized.
+ atomic.StoreUint32((*uint32)(&e.state), uint32(epState))
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.origEndpointState = e.state
+ // Restore the endpoint to InitialState as it will be moved to
+ // its origEndpointState during Resume.
+ e.state = StateInitial
+ // Condition variables and mutexes are not S/R'ed, so reinitialize
+ // acceptCond with e.acceptMu.
+ e.acceptCond = sync.NewCond(&e.acceptMu)
+ stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+ e.stack = s
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ epState := e.origEndpointState
+ switch epState {
+ case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished:
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
+ }
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
+ }
+ }
+ }
+
+ bind := func() {
+ addr, _, err := e.checkV4MappedLocked(tcpip.FullAddress{Addr: e.BindAddr, Port: e.ID.LocalPort})
+ if err != nil {
+ panic("unable to parse BindAddr: " + err.String())
+ }
+ if ok := e.stack.ReserveTuple(e.effectiveNetProtos, ProtocolNumber, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest); !ok {
+ panic(fmt.Sprintf("unable to re-reserve tuple (%v, %q, %d, %+v, %d, %v)", e.effectiveNetProtos, addr.Addr, addr.Port, e.boundPortFlags, e.boundBindToDevice, e.boundDest))
+ }
+ e.isPortReserved = true
+
+ // Mark endpoint as bound.
+ e.setEndpointState(StateBound) + } + + switch { + case epState.connected(): + bind() + if len(e.connectingAddress) == 0 { + e.connectingAddress = e.ID.RemoteAddress + // This endpoint is accepted by netstack but not yet by + // the app. If the endpoint is IPv6 but the remote + // address is IPv4, we need to connect as IPv6 so that + // dual-stack mode can be properly activated. + if e.NetProto == header.IPv6ProtocolNumber && len(e.ID.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.ID.RemoteAddress + } + } + // Reset the scoreboard to reinitialize the sack information as + // we do not restore SACK information. + e.scoreboard.Reset() + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + e.mu.Lock() + e.state = e.origEndpointState + closed := e.closed + e.mu.Unlock() + e.notifyProtocolGoroutine(notifyTickleWorker) + if epState == StateFinWait2 && closed { + // If the endpoint has been closed then make sure we notify so + // that the FIN_WAIT2 timer is started after a restore. + e.notifyProtocolGoroutine(notifyClose) + } + connectedLoading.Done() + case epState == StateListen: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + bind() + backlog := cap(e.acceptedChan) + if err := e.Listen(backlog); err != nil { + panic("endpoint listening failed: " + err.String()) + } + e.LockUser() + if e.shutdownFlags != 0 { + e.shutdownLocked(e.shutdownFlags) + } + e.UnlockUser() + listenLoading.Done() + tcpip.AsyncLoading.Done() + }() + case epState.connecting(): + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + bind() + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.ID.RemotePort}); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectingLoading.Done() + tcpip.AsyncLoading.Done() + }() + case epState == StateBound: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + tcpip.AsyncLoading.Done() + }() + case epState == StateClose: + e.isPortReserved = false + e.state = StateClose + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) + case epState == StateError: + e.state = StateError + e.stack.CompleteTransportEndpointCleanup(e) + tcpip.DeleteDanglingEndpoint(e) + } +} + +// saveLastError is invoked by stateify. +func (e *endpoint) saveLastError() string { + if e.lastError == nil { + return "" + } + + return e.lastError.String() +} + +// loadLastError is invoked by stateify. +func (e *endpoint) loadLastError(s string) { + if s == "" { + return + } + + e.lastError = tcpip.StringToError(s) +} + +// saveHardError is invoked by stateify. +func (e *EndpointInfo) saveHardError() string { + if e.HardError == nil { + return "" + } + + return e.HardError.String() +} + +// loadHardError is invoked by stateify. +func (e *EndpointInfo) loadHardError(s string) { + if s == "" { + return + } + + e.HardError = tcpip.StringToError(s) +} + +// saveMeasureTime is invoked by stateify. +func (r *rcvBufAutoTuneParams) saveMeasureTime() unixTime { + return unixTime{r.measureTime.Unix(), r.measureTime.UnixNano()} +} + +// loadMeasureTime is invoked by stateify. 
+func (r *rcvBufAutoTuneParams) loadMeasureTime(unix unixTime) {
+ r.measureTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRttMeasureTime is invoked by stateify.
+func (r *rcvBufAutoTuneParams) saveRttMeasureTime() unixTime {
+ return unixTime{r.rttMeasureTime.Unix(), r.rttMeasureTime.UnixNano()}
+}
+
+// loadRttMeasureTime is invoked by stateify.
+func (r *rcvBufAutoTuneParams) loadRttMeasureTime(unix unixTime) {
+ r.rttMeasureTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
new file mode 100644
index 000000000..070b634b4
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -0,0 +1,169 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/seqnum"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Forwarder is a connection request forwarder, which allows clients to decide
+// what to do with a connection request, for example: ignore it, send a RST, or
+// attempt to complete the 3-way handshake.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+ maxInFlight int
+ handler func(*ForwarderRequest)
+
+ mu sync.Mutex
+ inFlight map[stack.TransportEndpointID]struct{}
+ listen *listenContext
+}
+
+// NewForwarder allocates and initializes a new forwarder with the given
+// maximum number of in-flight connection attempts. Once the maximum is
+// reached, new incoming connection requests will be ignored.
+//
+// If rcvWnd is set to zero, the default buffer size is used instead.
+func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
+ if rcvWnd == 0 {
+ rcvWnd = DefaultReceiveBufferSize
+ }
+ return &Forwarder{
+ maxInFlight: maxInFlight,
+ handler: handler,
+ inFlight: make(map[stack.TransportEndpointID]struct{}),
+ listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
+ }
+}
+
+// HandlePacket handles a packet if it is of interest to the forwarder (i.e.,
+// if it's a SYN packet), returning true in that case. Otherwise the packet
+// is not handled and false is returned.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+ s := newSegment(r, id, pkt)
+ defer s.decRef()
+
+ // We only care about well-formed SYN packets.
+ if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn {
+ return false
+ }
+
+ opts := parseSynSegmentOptions(s)
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
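+ // (For reference, the canonical wiring that delivers packets here is
+ // the following - a sketch, with serve being a caller-defined function
+ // and 1024 an arbitrary in-flight limit:
+ //
+ //   fwd := tcp.NewForwarder(s, 0 /* rcvWnd */, 1024, func(fr *tcp.ForwarderRequest) {
+ //       var wq waiter.Queue
+ //       ep, err := fr.CreateEndpoint(&wq)
+ //       if err != nil {
+ //           fr.Complete(true) // true => reply with a RST
+ //           return
+ //       }
+ //       fr.Complete(false)
+ //       go serve(ep, &wq)
+ //   })
+ //   s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
+ // )
+
+ // We have an inflight request for this id, ignore this one for now.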
+ if _, ok := f.inFlight[id]; ok { + return true + } + + // Ignore the segment if we're beyond the limit. + if len(f.inFlight) >= f.maxInFlight { + return true + } + + // Launch a new goroutine to handle the request. + f.inFlight[id] = struct{}{} + s.incRef() + go f.handler(&ForwarderRequest{ // S/R-SAFE: not used by Sentry. + forwarder: f, + segment: s, + synOptions: opts, + }) + + return true +} + +// ForwarderRequest represents a connection request received by the forwarder +// and passed to the client. Clients must eventually call Complete() on it, and +// may optionally create an endpoint to represent it via CreateEndpoint. +type ForwarderRequest struct { + mu sync.Mutex + forwarder *Forwarder + segment *segment + synOptions header.TCPSynOptions +} + +// ID returns the 4-tuple (src address, src port, dst address, dst port) that +// represents the connection request. +func (r *ForwarderRequest) ID() stack.TransportEndpointID { + return r.segment.id +} + +// Complete completes the request, and optionally sends a RST segment back to the +// sender. +func (r *ForwarderRequest) Complete(sendReset bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.segment == nil { + panic("Completing already completed forwarder request") + } + + // Remove request from the forwarder. + r.forwarder.mu.Lock() + delete(r.forwarder.inFlight, r.segment.id) + r.forwarder.mu.Unlock() + + // If the caller requested, send a reset. + if sendReset { + replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL()) + } + + // Release all resources. + r.segment.decRef() + r.segment = nil + r.forwarder = nil +} + +// CreateEndpoint creates a TCP endpoint for the connection request, performing +// the 3-way handshake in the process. +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.segment == nil { + return nil, tcpip.ErrInvalidEndpointState + } + + f := r.forwarder + ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{ + MSS: r.synOptions.MSS, + WS: r.synOptions.WS, + TS: r.synOptions.TS, + TSVal: r.synOptions.TSVal, + TSEcr: r.synOptions.TSEcr, + SACKPermitted: r.synOptions.SACKPermitted, + }, queue, nil) + if err != nil { + return nil, err + } + + // Start the protocol goroutine. + ep.startAcceptedLoop() + + return ep, nil +} diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go new file mode 100644 index 000000000..b34e47bbd --- /dev/null +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -0,0 +1,541 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tcp contains the implementation of the TCP transport protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing tcp.NewProtocol() as one of the +// transport protocols when calling stack.New(). 
Then endpoints can be created
+// by passing tcp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package tcp
+
+import (
+	"fmt"
+	"runtime"
+	"strings"
+	"time"
+
+	"gvisor.dev/gvisor/pkg/sync"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+	// ProtocolNumber is the tcp protocol number.
+	ProtocolNumber = header.TCPProtocolNumber
+
+	// MinBufferSize is the smallest size of a receive or send buffer.
+	MinBufferSize = 4 << 10 // 4096 bytes.
+
+	// DefaultSendBufferSize is the default size of the send buffer for
+	// an endpoint.
+	DefaultSendBufferSize = 1 << 20 // 1MB
+
+	// DefaultReceiveBufferSize is the default size of the receive buffer
+	// for an endpoint.
+	DefaultReceiveBufferSize = 1 << 20 // 1MB
+
+	// MaxBufferSize is the largest size a receive/send buffer can grow to.
+	MaxBufferSize = 4 << 20 // 4MB
+
+	// MaxUnprocessedSegments is the maximum number of unprocessed segments
+	// that can be queued for a given endpoint.
+	MaxUnprocessedSegments = 300
+
+	// DefaultTCPLingerTimeout is the amount of time that sockets linger in
+	// FIN_WAIT_2 state before being marked closed.
+	DefaultTCPLingerTimeout = 60 * time.Second
+
+	// DefaultTCPTimeWaitTimeout is the amount of time that sockets linger
+	// in TIME_WAIT state before being marked closed.
+	DefaultTCPTimeWaitTimeout = 60 * time.Second
+
+	// DefaultSynRetries is the default value for the number of SYN retransmits
+	// before a connect is aborted.
+	DefaultSynRetries = 6
+)
+
+const (
+	ccReno  = "reno"
+	ccCubic = "cubic"
+)
+
+// SACKEnabled is used by stack.(*Stack).TransportProtocolOption to
+// enable/disable SACK support in TCP. See: https://tools.ietf.org/html/rfc2018.
+type SACKEnabled bool
+
+// DelayEnabled is used by stack.(*Stack).TransportProtocolOption to
+// enable/disable Nagle's algorithm in TCP.
+type DelayEnabled bool
+
+// SendBufferSizeOption is used by stack.(*Stack).TransportProtocolOption
+// to get/set the default, min and max TCP send buffer sizes.
+type SendBufferSizeOption struct {
+	Min     int
+	Default int
+	Max     int
+}
+
+// ReceiveBufferSizeOption is used by
+// stack.(*Stack).TransportProtocolOption to get/set the default, min and max
+// TCP receive buffer sizes.
+type ReceiveBufferSizeOption struct {
+	Min     int
+	Default int
+	Max     int
+}
+
+// synRcvdCounter tracks the number of endpoints in the SYN-RCVD state. The
+// value is protected by a mutex so that we can increment only when it's
+// guaranteed not to go above a threshold.
+type synRcvdCounter struct {
+	sync.Mutex
+	value     uint64
+	pending   sync.WaitGroup
+	threshold uint64
+}
+
+// inc tries to increment the global number of endpoints in SYN-RCVD state. It
+// succeeds if the increment doesn't make the count go beyond the threshold, and
+// fails otherwise.
+func (s *synRcvdCounter) inc() bool {
+	s.Lock()
+	defer s.Unlock()
+	if s.value >= s.threshold {
+		return false
+	}
+
+	s.pending.Add(1)
+	s.value++
+
+	return true
+}
+
+// dec atomically decrements the global number of endpoints in SYN-RCVD
+// state. It must only be called if a previous call to inc succeeded.
+func (s *synRcvdCounter) dec() {
+	s.Lock()
+	defer s.Unlock()
+	s.value--
+	s.pending.Done()
+}
+
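// Illustrative sketch (not part of this change): the inc/dec discipline the
// counter above is designed for; the listener-side call sites shown here
// (sendSynCookie, trackHandshake) are hypothetical.
//
//	if !counter.inc() {
//		// Threshold reached: answer the SYN with a cookie instead of
//		// allocating tracked SYN-RCVD state.
//		sendSynCookie()
//	} else {
//		trackHandshake()
//		counter.dec() // once the handshake completes or is aborted
//	}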
+// synCookiesInUse returns true if the synRcvdCount has reached the current
+// threshold (SynRcvdCountThreshold by default), at which point SYN cookies
+// are used instead of tracked SYN-RCVD state.
+func (s *synRcvdCounter) synCookiesInUse() bool {
+	s.Lock()
+	defer s.Unlock()
+	return s.value >= s.threshold
+}
+
+// SetThreshold sets synRcvdCounter.threshold to the new threshold.
+func (s *synRcvdCounter) SetThreshold(threshold uint64) {
+	s.Lock()
+	defer s.Unlock()
+	s.threshold = threshold
+}
+
+// Threshold returns the current value of synRcvdCounter.threshold.
+func (s *synRcvdCounter) Threshold() uint64 {
+	s.Lock()
+	defer s.Unlock()
+	return s.threshold
+}
+
+type protocol struct {
+	mu                         sync.RWMutex
+	sackEnabled                bool
+	delayEnabled               bool
+	sendBufferSize             SendBufferSizeOption
+	recvBufferSize             ReceiveBufferSizeOption
+	congestionControl          string
+	availableCongestionControl []string
+	moderateReceiveBuffer      bool
+	tcpLingerTimeout           time.Duration
+	tcpTimeWaitTimeout         time.Duration
+	minRTO                     time.Duration
+	maxRTO                     time.Duration
+	maxRetries                 uint32
+	synRcvdCount               synRcvdCounter
+	synRetries                 uint8
+	dispatcher                 dispatcher
+}
+
+// Number returns the tcp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+	return ProtocolNumber
+}
+
+// NewEndpoint creates a new tcp endpoint.
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid tcp packet size.
+func (*protocol) MinimumPacketSize() int {
+	return header.TCPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given tcp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+	h := header.TCP(v)
+	return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// QueuePacket queues packets targeted at an endpoint after hashing the packet
+// to a specific processing queue. Each queue is serviced by its own processor
+// goroutine which is responsible for dequeuing and doing full TCP dispatch of
+// the packet.
+func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+	p.dispatcher.queuePacket(r, ep, id, pkt)
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+//
+// RFC 793, page 36, states that "If the connection does not exist (CLOSED) then
+// a reset is sent in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected by this
+// means."
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+	s := newSegment(r, id, pkt)
+	defer s.decRef()
+
+	if !s.parse() || !s.csumValid {
+		return false
+	}
+
+	// There's nothing to do if this is already a reset packet.
+	if s.flagIsSet(header.TCPFlagRst) {
+		return true
+	}
+
+	replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL())
+	return true
+}
+
+// replyWithReset replies to the given segment with a reset segment.
+func replyWithReset(s *segment, tos, ttl uint8) {
+	// Get the seqnum from the packet if the ack flag is set.
+ seq := seqnum.Value(0) + ack := seqnum.Value(0) + flags := byte(header.TCPFlagRst) + // As per RFC 793 page 35 (Reset Generation) + // 1. If the connection does not exist (CLOSED) then a reset is sent + // in response to any incoming segment except another reset. In + // particular, SYNs addressed to a non-existent connection are rejected + // by this means. + + // If the incoming segment has an ACK field, the reset takes its + // sequence number from the ACK field of the segment, otherwise the + // reset has sequence number zero and the ACK field is set to the sum + // of the sequence number and segment length of the incoming segment. + // The connection remains in the CLOSED state. + if s.flagIsSet(header.TCPFlagAck) { + seq = s.ackNumber + } else { + flags |= header.TCPFlagAck + ack = s.sequenceNumber.Add(s.logicalLen()) + } + sendTCP(&s.route, tcpFields{ + id: s.id, + ttl: ttl, + tos: tos, + flags: flags, + seq: seq, + ack: ack, + rcvWnd: 0, + }, buffer.VectorisedView{}, nil /* gso */, nil /* PacketOwner */) +} + +// SetOption implements stack.TransportProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case SACKEnabled: + p.mu.Lock() + p.sackEnabled = bool(v) + p.mu.Unlock() + return nil + + case DelayEnabled: + p.mu.Lock() + p.delayEnabled = bool(v) + p.mu.Unlock() + return nil + + case SendBufferSizeOption: + if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.sendBufferSize = v + p.mu.Unlock() + return nil + + case ReceiveBufferSizeOption: + if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.recvBufferSize = v + p.mu.Unlock() + return nil + + case tcpip.CongestionControlOption: + for _, c := range p.availableCongestionControl { + if string(v) == c { + p.mu.Lock() + p.congestionControl = string(v) + p.mu.Unlock() + return nil + } + } + // linux returns ENOENT when an invalid congestion control + // is specified. + return tcpip.ErrNoSuchFile + + case tcpip.ModerateReceiveBufferOption: + p.mu.Lock() + p.moderateReceiveBuffer = bool(v) + p.mu.Unlock() + return nil + + case tcpip.TCPLingerTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpLingerTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPTimeWaitTimeoutOption: + if v < 0 { + v = 0 + } + p.mu.Lock() + p.tcpTimeWaitTimeout = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPMinRTOOption: + if v < 0 { + v = tcpip.TCPMinRTOOption(MinRTO) + } + p.mu.Lock() + p.minRTO = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPMaxRTOOption: + if v < 0 { + v = tcpip.TCPMaxRTOOption(MaxRTO) + } + p.mu.Lock() + p.maxRTO = time.Duration(v) + p.mu.Unlock() + return nil + + case tcpip.TCPMaxRetriesOption: + p.mu.Lock() + p.maxRetries = uint32(v) + p.mu.Unlock() + return nil + + case tcpip.TCPSynRcvdCountThresholdOption: + p.mu.Lock() + p.synRcvdCount.SetThreshold(uint64(v)) + p.mu.Unlock() + return nil + + case tcpip.TCPSynRetriesOption: + if v < 1 || v > 255 { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.synRetries = uint8(v) + p.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option implements stack.TransportProtocol.Option. 
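// Illustrative sketch (not part of this change): how a stack user typically
// reaches SetOption above, assuming a configured *stack.Stack named s;
// stack.(*Stack).SetTransportProtocolOption forwards the value here.
//
//	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(true)); err != nil {
//		log.Fatalf("enabling SACK: %v", err)
//	}
//	opt := tcp.ReceiveBufferSizeOption{Min: tcp.MinBufferSize, Default: 1 << 21, Max: tcp.MaxBufferSize}
//	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
//		log.Fatalf("setting receive buffer sizes: %v", err)
//	}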
+func (p *protocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *SACKEnabled: + p.mu.RLock() + *v = SACKEnabled(p.sackEnabled) + p.mu.RUnlock() + return nil + + case *DelayEnabled: + p.mu.RLock() + *v = DelayEnabled(p.delayEnabled) + p.mu.RUnlock() + return nil + + case *SendBufferSizeOption: + p.mu.RLock() + *v = p.sendBufferSize + p.mu.RUnlock() + return nil + + case *ReceiveBufferSizeOption: + p.mu.RLock() + *v = p.recvBufferSize + p.mu.RUnlock() + return nil + + case *tcpip.CongestionControlOption: + p.mu.RLock() + *v = tcpip.CongestionControlOption(p.congestionControl) + p.mu.RUnlock() + return nil + + case *tcpip.AvailableCongestionControlOption: + p.mu.RLock() + *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + p.mu.RUnlock() + return nil + + case *tcpip.ModerateReceiveBufferOption: + p.mu.RLock() + *v = tcpip.ModerateReceiveBufferOption(p.moderateReceiveBuffer) + p.mu.RUnlock() + return nil + + case *tcpip.TCPLingerTimeoutOption: + p.mu.RLock() + *v = tcpip.TCPLingerTimeoutOption(p.tcpLingerTimeout) + p.mu.RUnlock() + return nil + + case *tcpip.TCPTimeWaitTimeoutOption: + p.mu.RLock() + *v = tcpip.TCPTimeWaitTimeoutOption(p.tcpTimeWaitTimeout) + p.mu.RUnlock() + return nil + + case *tcpip.TCPMinRTOOption: + p.mu.RLock() + *v = tcpip.TCPMinRTOOption(p.minRTO) + p.mu.RUnlock() + return nil + + case *tcpip.TCPMaxRTOOption: + p.mu.RLock() + *v = tcpip.TCPMaxRTOOption(p.maxRTO) + p.mu.RUnlock() + return nil + + case *tcpip.TCPMaxRetriesOption: + p.mu.RLock() + *v = tcpip.TCPMaxRetriesOption(p.maxRetries) + p.mu.RUnlock() + return nil + + case *tcpip.TCPSynRcvdCountThresholdOption: + p.mu.RLock() + *v = tcpip.TCPSynRcvdCountThresholdOption(p.synRcvdCount.Threshold()) + p.mu.RUnlock() + return nil + + case *tcpip.TCPSynRetriesOption: + p.mu.RLock() + *v = tcpip.TCPSynRetriesOption(p.synRetries) + p.mu.RUnlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Close implements stack.TransportProtocol.Close. +func (p *protocol) Close() { + p.dispatcher.close() +} + +// Wait implements stack.TransportProtocol.Wait. +func (p *protocol) Wait() { + p.dispatcher.wait() +} + +// SynRcvdCounter returns a reference to the synRcvdCount for this protocol +// instance. +func (p *protocol) SynRcvdCounter() *synRcvdCounter { + return &p.synRcvdCount +} + +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + hdr, ok := pkt.Data.PullUp(header.TCPMinimumSize) + if !ok { + return false + } + + // If the header has options, pull those up as well. + if offset := int(header.TCP(hdr).DataOffset()); offset > header.TCPMinimumSize && offset <= pkt.Data.Size() { + hdr, ok = pkt.Data.PullUp(offset) + if !ok { + panic(fmt.Sprintf("There should be at least %d bytes in pkt.Data.", offset)) + } + } + + pkt.TransportHeader = hdr + pkt.Data.TrimFront(len(hdr)) + return true +} + +// NewProtocol returns a TCP transport protocol. 
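// Illustrative sketch (not part of this change): reading a protocol option
// back through the stack, which calls Option above with a pointer; assumes a
// configured *stack.Stack named s.
//
//	var sackEnabled tcp.SACKEnabled
//	if err := s.TransportProtocolOption(tcp.ProtocolNumber, &sackEnabled); err == nil && bool(sackEnabled) {
//		// SACK is enabled on the TCP protocol instance.
//	}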
+func NewProtocol() stack.TransportProtocol { + p := protocol{ + sendBufferSize: SendBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultSendBufferSize, + Max: MaxBufferSize, + }, + recvBufferSize: ReceiveBufferSizeOption{ + Min: MinBufferSize, + Default: DefaultReceiveBufferSize, + Max: MaxBufferSize, + }, + congestionControl: ccReno, + availableCongestionControl: []string{ccReno, ccCubic}, + tcpLingerTimeout: DefaultTCPLingerTimeout, + tcpTimeWaitTimeout: DefaultTCPTimeWaitTimeout, + synRcvdCount: synRcvdCounter{threshold: SynRcvdCountThreshold}, + synRetries: DefaultSynRetries, + minRTO: MinRTO, + maxRTO: MaxRTO, + maxRetries: MaxRetries, + } + p.dispatcher.init(runtime.GOMAXPROCS(0)) + return &p +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go new file mode 100644 index 000000000..dd89a292a --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -0,0 +1,475 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "container/heap" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// receiver holds the state necessary to receive TCP segments and turn them +// into a stream of bytes. +// +// +stateify savable +type receiver struct { + ep *endpoint + + rcvNxt seqnum.Value + + // rcvAcc is one beyond the last acceptable sequence number. That is, + // the "largest" sequence value that the receiver has announced to the + // its peer that it's willing to accept. This may be different than + // rcvNxt + rcvWnd if the receive window is reduced; in that case we + // have to reduce the window as we receive more data instead of + // shrinking it. + rcvAcc seqnum.Value + + // rcvWnd is the non-scaled receive window last advertised to the peer. + rcvWnd seqnum.Size + + rcvWndScale uint8 + + closed bool + + pendingRcvdSegments segmentHeap + pendingBufUsed seqnum.Size + pendingBufSize seqnum.Size + + // Time when the last ack was received. + lastRcvdAckTime time.Time `state:".(unixTime)"` +} + +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8, pendingBufSize seqnum.Size) *receiver { + return &receiver{ + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWnd: rcvWnd, + rcvWndScale: rcvWndScale, + pendingBufSize: pendingBufSize, + lastRcvdAckTime: time.Now(), + } +} + +// acceptable checks if the segment sequence number range is acceptable +// according to the table on page 26 of RFC 793. +func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { + // r.rcvWnd could be much larger than the window size we advertised in our + // outgoing packets, we should use what we have advertised for acceptability + // test. + scaledWindowSize := r.rcvWnd >> r.rcvWndScale + if scaledWindowSize > 0xffff { + // This is what we actually put in the Window field. 
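// Worked example (illustrative, not part of this change) for the clamping
// above: with rcvWnd = 1<<20 and rcvWndScale = 7, the value placed in the
// 16-bit Window field is 1<<20 >> 7 = 8192 (below the 0xffff cap), and the
// window used for the acceptability check is 8192 << 7 = 1<<20. With
// rcvWndScale = 2, the shifted value 1<<18 exceeds 0xffff, so it is clamped
// and the effective window becomes 0xffff << 2.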
+ scaledWindowSize = 0xffff + } + advertisedWindowSize := scaledWindowSize << r.rcvWndScale + return header.Acceptable(segSeq, segLen, r.rcvNxt, r.rcvNxt.Add(advertisedWindowSize)) +} + +// getSendParams returns the parameters needed by the sender when building +// segments to send. +func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { + // Calculate the window size based on the available buffer space. + receiveBufferAvailable := r.ep.receiveBufferAvailable() + acc := r.rcvNxt.Add(seqnum.Size(receiveBufferAvailable)) + if r.rcvAcc.LessThan(acc) { + r.rcvAcc = acc + } + // Stash away the non-scaled receive window as we use it for measuring + // receiver's estimated RTT. + r.rcvWnd = r.rcvNxt.Size(r.rcvAcc) + return r.rcvNxt, r.rcvWnd >> r.rcvWndScale +} + +// nonZeroWindow is called when the receive window grows from zero to nonzero; +// in such cases we may need to send an ack to indicate to our peer that it can +// resume sending data. +func (r *receiver) nonZeroWindow() { + // Immediately send an ack. + r.ep.snd.sendAck() +} + +// consumeSegment attempts to consume a segment that was received by r. The +// segment may have just been received or may have been received earlier but +// wasn't ready to be consumed then. +// +// Returns true if the segment was consumed, false if it cannot be consumed +// yet because of a missing segment. +func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool { + if segLen > 0 { + // If the segment doesn't include the seqnum we're expecting to + // consume now, we're missing a segment. We cannot proceed until + // we receive that segment though. + if !r.rcvNxt.InWindow(segSeq, segLen) { + return false + } + + // Trim segment to eliminate already acknowledged data. + if segSeq.LessThan(r.rcvNxt) { + diff := segSeq.Size(r.rcvNxt) + segLen -= diff + segSeq.UpdateForward(diff) + s.sequenceNumber.UpdateForward(diff) + s.data.TrimFront(int(diff)) + } + + // Move segment to ready-to-deliver list. Wakeup any waiters. + r.ep.readyToRead(s) + + } else if segSeq != r.rcvNxt { + return false + } + + // Update the segment that we're expecting to consume. + r.rcvNxt = segSeq.Add(segLen) + + // In cases of a misbehaving sender which could send more than the + // advertised window, we could end up in a situation where we get a + // segment that exceeds the window advertised. Instead of partially + // accepting the segment and discarding bytes beyond the advertised + // window, we accept the whole segment and make sure r.rcvAcc is moved + // forward to match r.rcvNxt to indicate that the window is now closed. + // + // In absence of this check the r.acceptable() check fails and accepts + // segments that should be dropped because rcvWnd is calculated as + // the size of the interval (rcvNxt, rcvAcc] which becomes extremely + // large if rcvAcc is ever less than rcvNxt. + if r.rcvAcc.LessThan(r.rcvNxt) { + r.rcvAcc = r.rcvNxt + } + + // Trim SACK Blocks to remove any SACK information that covers + // sequence numbers that have been consumed. + TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + + // Handle FIN or FIN-ACK. + if s.flagIsSet(header.TCPFlagFin) { + r.rcvNxt++ + + // Send ACK immediately. + r.ep.snd.sendAck() + + // Tell any readers that no more data will come. + r.closed = true + r.ep.readyToRead(nil) + + // We just received a FIN, our next state depends on whether we sent a + // FIN already or not. 
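// Illustrative summary (not part of this change) of the transitions taken by
// the switch below when the FIN is consumed:
//
//	StateEstablished -> StateCloseWait (peer closed first)
//	StateFinWait1    -> StateTimeWait  (FIN also carries an ACK)
//	StateFinWait1    -> StateClosing   (pure FIN: simultaneous close)
//	StateFinWait2    -> StateTimeWait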
+ switch r.ep.EndpointState() { + case StateEstablished: + r.ep.setEndpointState(StateCloseWait) + case StateFinWait1: + if s.flagIsSet(header.TCPFlagAck) { + // FIN-ACK, transition to TIME-WAIT. + r.ep.setEndpointState(StateTimeWait) + } else { + // Simultaneous close, expecting a final ACK. + r.ep.setEndpointState(StateClosing) + } + case StateFinWait2: + r.ep.setEndpointState(StateTimeWait) + } + + // Flush out any pending segments, except the very first one if + // it happens to be the one we're handling now because the + // caller is using it. + first := 0 + if len(r.pendingRcvdSegments) != 0 && r.pendingRcvdSegments[0] == s { + first = 1 + } + + for i := first; i < len(r.pendingRcvdSegments); i++ { + r.pendingRcvdSegments[i].decRef() + // Note that slice truncation does not allow garbage collection of + // truncated items, thus truncated items must be set to nil to avoid + // memory leaks. + r.pendingRcvdSegments[i] = nil + } + r.pendingRcvdSegments = r.pendingRcvdSegments[:first] + + return true + } + + // Handle ACK (not FIN-ACK, which we handled above) during one of the + // shutdown states. + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == r.ep.snd.sndNxt { + switch r.ep.EndpointState() { + case StateFinWait1: + r.ep.setEndpointState(StateFinWait2) + // Notify protocol goroutine that we have received an + // ACK to our FIN so that it can start the FIN_WAIT2 + // timer to abort connection if the other side does + // not close within 2MSL. + r.ep.notifyProtocolGoroutine(notifyClose) + case StateClosing: + r.ep.setEndpointState(StateTimeWait) + case StateLastAck: + r.ep.transitionToStateCloseLocked() + } + } + + return true +} + +// updateRTT updates the receiver RTT measurement based on the sequence number +// of the received segment. +func (r *receiver) updateRTT() { + // From: https://public.lanl.gov/radiant/pubs/drs/sc2001-poster.pdf + // + // A system that is only transmitting acknowledgements can still + // estimate the round-trip time by observing the time between when a byte + // is first acknowledged and the receipt of data that is at least one + // window beyond the sequence number that was acknowledged. + r.ep.rcvListMu.Lock() + if r.ep.rcvAutoParams.rttMeasureTime.IsZero() { + // New measurement. + r.ep.rcvAutoParams.rttMeasureTime = time.Now() + r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) + r.ep.rcvListMu.Unlock() + return + } + if r.rcvNxt.LessThan(r.ep.rcvAutoParams.rttMeasureSeqNumber) { + r.ep.rcvListMu.Unlock() + return + } + rtt := time.Since(r.ep.rcvAutoParams.rttMeasureTime) + // We only store the minimum observed RTT here as this is only used in + // absence of a SRTT available from either timestamps or a sender + // measurement of RTT. + if r.ep.rcvAutoParams.rtt == 0 || rtt < r.ep.rcvAutoParams.rtt { + r.ep.rcvAutoParams.rtt = rtt + } + r.ep.rcvAutoParams.rttMeasureTime = time.Now() + r.ep.rcvAutoParams.rttMeasureSeqNumber = r.rcvNxt.Add(r.rcvWnd) + r.ep.rcvListMu.Unlock() +} + +func (r *receiver) handleRcvdSegmentClosing(s *segment, state EndpointState, closed bool) (drop bool, err *tcpip.Error) { + r.ep.rcvListMu.Lock() + rcvClosed := r.ep.rcvClosed || r.closed + r.ep.rcvListMu.Unlock() + + // If we are in one of the shutdown states then we need to do + // additional checks before we try and process the segment. + switch state { + case StateCloseWait: + // If the ACK acks something not yet sent then we send an ACK. 
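// Worked example (illustrative, not part of this change) for updateRTT
// above: if a measurement starts at rcvNxt = 1000 with rcvWnd = 65535, then
// rttMeasureSeqNumber = 66535. Data at or beyond 66535 can only have been
// sent after the peer received an ACK covering 1000, so the elapsed time
// when rcvNxt crosses 66535 spans at least one round trip and is used as an
// RTT sample when no sender-side SRTT is available.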
+		if r.ep.snd.sndNxt.LessThan(s.ackNumber) {
+			r.ep.snd.sendAck()
+			return true, nil
+		}
+		fallthrough
+	case StateClosing, StateLastAck:
+		if !s.sequenceNumber.LessThanEq(r.rcvNxt) {
+			// Just drop the segment as we have
+			// already received a FIN and this
+			// segment is after the sequence number
+			// for the FIN.
+			return true, nil
+		}
+		fallthrough
+	case StateFinWait1:
+		fallthrough
+	case StateFinWait2:
+		// If we are closed for reads (either due to an
+		// incoming FIN or the user calling shutdown(..,
+		// SHUT_RD)) then any data past rcvNxt should
+		// trigger a RST.
+		endDataSeq := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+		if state != StateCloseWait && rcvClosed && r.rcvNxt.LessThan(endDataSeq) {
+			return true, tcpip.ErrConnectionAborted
+		}
+		if state == StateFinWait1 {
+			break
+		}
+
+		// If it's a retransmission of an old data segment
+		// or a pure ACK then allow it.
+		if s.sequenceNumber.Add(s.logicalLen()).LessThanEq(r.rcvNxt) ||
+			s.logicalLen() == 0 {
+			break
+		}
+
+		// In FIN-WAIT2, if the socket is fully closed (not
+		// owned by the application on our end), then the only
+		// acceptable segment is a FIN. Since a FIN can
+		// technically also carry data, we verify that the
+		// segment carrying a FIN ends at exactly r.rcvNxt+1.
+		//
+		// From RFC 793, page 25:
+		//
+		// For sequence number purposes, the SYN is
+		// considered to occur before the first actual
+		// data octet of the segment in which it occurs,
+		// while the FIN is considered to occur after
+		// the last actual data octet in a segment in
+		// which it occurs.
+		if closed && (!s.flagIsSet(header.TCPFlagFin) || s.sequenceNumber.Add(s.logicalLen()) != r.rcvNxt+1) {
+			return true, tcpip.ErrConnectionAborted
+		}
+	}
+
+	// We don't care about receive processing anymore if the receive side
+	// is closed.
+	//
+	// NOTE: We still want to permit a FIN as it's possible only our
+	// end has closed and the peer is yet to send a FIN. Hence we
+	// compare only the payload.
+	segEnd := s.sequenceNumber.Add(seqnum.Size(s.data.Size()))
+	if rcvClosed && !segEnd.LessThanEq(r.rcvNxt) {
+		return true, nil
+	}
+	return false, nil
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) (drop bool, err *tcpip.Error) {
+	state := r.ep.EndpointState()
+	closed := r.ep.closed
+
+	segLen := seqnum.Size(s.data.Size())
+	segSeq := s.sequenceNumber
+
+	// If the sequence number range is outside the acceptable range, just
+	// send an ACK and stop further processing of the segment.
+	// This is according to RFC 793, page 68.
+	if !r.acceptable(segSeq, segLen) {
+		r.ep.snd.sendAck()
+		return true, nil
+	}
+
+	if state != StateEstablished {
+		drop, err := r.handleRcvdSegmentClosing(s, state, closed)
+		if drop || err != nil {
+			return drop, err
+		}
+	}
+
+	// Store the time of the last ack.
+	r.lastRcvdAckTime = time.Now()
+
+	// Defer segment processing if it can't be consumed now.
+	if !r.consumeSegment(s, segSeq, segLen) {
+		if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
+			// We only store the segment if it's within our buffer
+			// size limit.
+			if r.pendingBufUsed < r.pendingBufSize {
+				r.pendingBufUsed += s.logicalLen()
+				s.incRef()
+				heap.Push(&r.pendingRcvdSegments, s)
+				UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
+			}
+
+			// Immediately send an ack so that the peer knows it may
+			// have to retransmit.
+			r.ep.snd.sendAck()
+		}
+		return false, nil
+	}
+
+	// Since we consumed a segment, update the receiver's RTT estimate
+	// if required.
+	if segLen > 0 {
+		r.updateRTT()
+	}
+
+	// By consuming the current segment, we may have filled a gap in the
+	// sequence number domain that allows pending segments to be consumed
+	// now. So try to do it.
+	for !r.closed && r.pendingRcvdSegments.Len() > 0 {
+		s := r.pendingRcvdSegments[0]
+		segLen := seqnum.Size(s.data.Size())
+		segSeq := s.sequenceNumber
+
+		// Skip segment altogether if it has already been acknowledged.
+		if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) &&
+			!r.consumeSegment(s, segSeq, segLen) {
+			break
+		}
+
+		heap.Pop(&r.pendingRcvdSegments)
+		r.pendingBufUsed -= s.logicalLen()
+		s.decRef()
+	}
+	return false, nil
+}
+
+// handleTimeWaitSegment handles inbound segments received when the endpoint
+// has entered the TIME_WAIT state.
+func (r *receiver) handleTimeWaitSegment(s *segment) (resetTimeWait bool, newSyn bool) {
+	segSeq := s.sequenceNumber
+	segLen := seqnum.Size(s.data.Size())
+
+	// Just silently drop any RST packets in TIME_WAIT. We do not support
+	// TIME_WAIT assassination; as a result we conform to fix 1 as described
+	// in https://tools.ietf.org/html/rfc1337#section-3.
+	if s.flagIsSet(header.TCPFlagRst) {
+		return false, false
+	}
+
+	// If it's a SYN and the sequence number is higher than any seen before
+	// for this connection then try to redirect it to a listening endpoint
+	// if available.
+	//
+	// RFC 1122:
+	//	"When a connection is [...] on TIME-WAIT state [...]
+	//	[a TCP] MAY accept a new SYN from the remote TCP to
+	//	reopen the connection directly, if it:
+
+	//	(1) assigns its initial sequence number for the new
+	//	connection to be larger than the largest sequence
+	//	number it used on the previous connection incarnation,
+	//	and
+
+	//	(2) returns to TIME-WAIT state if the SYN turns out
+	//	to be an old duplicate".
+	if s.flagIsSet(header.TCPFlagSyn) && r.rcvNxt.LessThan(segSeq) {
+		return false, true
+	}
+
+	// Drop the segment if it does not contain an ACK.
+	if !s.flagIsSet(header.TCPFlagAck) {
+		return false, false
+	}
+
+	// Update Timestamp if required. See RFC 7323, section 4.3.
+	if r.ep.sendTSOk && s.parsedOptions.TS {
+		r.ep.updateRecentTimestamp(s.parsedOptions.TSVal, r.ep.snd.maxSentAck, segSeq)
+	}
+
+	if segSeq.Add(1) == r.rcvNxt && s.flagIsSet(header.TCPFlagFin) {
+		// If it's a FIN-ACK then resetTimeWait and send an ACK, as it
+		// indicates our final ACK could have been lost.
+		r.ep.snd.sendAck()
+		return true, false
+	}
+
+	// If the sequence number range is outside the acceptable range or
+	// carries data then just send an ACK. This is according to RFC 793,
+	// page 37.
+	//
+	// NOTE: In TIME_WAIT the only acceptable sequence number is rcvNxt.
+	if segSeq != r.rcvNxt || segLen != 0 {
+		r.ep.snd.sendAck()
+	}
+	return false, false
+}
diff --git a/pkg/tcpip/transport/tcp/rcv_state.go b/pkg/tcpip/transport/tcp/rcv_state.go
new file mode 100644
index 000000000..2bf21a2e7
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv_state.go
@@ -0,0 +1,29 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveLastRcvdAckTime is invoked by stateify. +func (r *receiver) saveLastRcvdAckTime() unixTime { + return unixTime{r.lastRcvdAckTime.Unix(), r.lastRcvdAckTime.UnixNano()} +} + +// loadLastRcvdAckTime is invoked by stateify. +func (r *receiver) loadLastRcvdAckTime(unix unixTime) { + r.lastRcvdAckTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/rcv_test.go b/pkg/tcpip/transport/tcp/rcv_test.go new file mode 100644 index 000000000..8a026ec46 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv_test.go @@ -0,0 +1,74 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package rcv_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +func TestAcceptable(t *testing.T) { + for _, tt := range []struct { + segSeq seqnum.Value + segLen seqnum.Size + rcvNxt, rcvAcc seqnum.Value + want bool + }{ + // The segment is smaller than the window. + {105, 2, 100, 104, false}, + {105, 2, 101, 105, true}, + {105, 2, 102, 106, true}, + {105, 2, 103, 107, true}, + {105, 2, 104, 108, true}, + {105, 2, 105, 109, true}, + {105, 2, 106, 110, true}, + {105, 2, 107, 111, false}, + + // The segment is larger than the window. + {105, 4, 103, 105, true}, + {105, 4, 104, 106, true}, + {105, 4, 105, 107, true}, + {105, 4, 106, 108, true}, + {105, 4, 107, 109, true}, + {105, 4, 108, 110, true}, + {105, 4, 109, 111, false}, + {105, 4, 110, 112, false}, + + // The segment has no width. + {105, 0, 100, 102, false}, + {105, 0, 101, 103, false}, + {105, 0, 102, 104, false}, + {105, 0, 103, 105, true}, + {105, 0, 104, 106, true}, + {105, 0, 105, 107, true}, + {105, 0, 106, 108, false}, + {105, 0, 107, 109, false}, + + // The receive window has no width. + {105, 2, 103, 103, false}, + {105, 2, 104, 104, false}, + {105, 2, 105, 105, false}, + {105, 2, 106, 106, false}, + {105, 2, 107, 107, false}, + {105, 2, 108, 108, false}, + {105, 2, 109, 109, false}, + } { + if got := header.Acceptable(tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc); got != tt.want { + t.Errorf("header.Acceptable(%d, %d, %d, %d) = %t, want %t", tt.segSeq, tt.segLen, tt.rcvNxt, tt.rcvAcc, got, tt.want) + } + } +} diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go new file mode 100644 index 000000000..f83ebc717 --- /dev/null +++ b/pkg/tcpip/transport/tcp/reno.go @@ -0,0 +1,103 @@ +// Copyright 2018 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +// renoState stores the variables related to TCP New Reno congestion +// control algorithm. +// +// +stateify savable +type renoState struct { + s *sender +} + +// newRenoCC initializes the state for the NewReno congestion control algorithm. +func newRenoCC(s *sender) *renoState { + return &renoState{s: s} +} + +// updateSlowStart will update the congestion window as per the slow-start +// algorithm used by NewReno. If after adjusting the congestion window +// we cross the SSthreshold then it will return the number of packets that +// must be consumed in congestion avoidance mode. +func (r *renoState) updateSlowStart(packetsAcked int) int { + // Don't let the congestion window cross into the congestion + // avoidance range. + newcwnd := r.s.sndCwnd + packetsAcked + if newcwnd >= r.s.sndSsthresh { + newcwnd = r.s.sndSsthresh + r.s.sndCAAckCount = 0 + } + + packetsAcked -= newcwnd - r.s.sndCwnd + r.s.sndCwnd = newcwnd + return packetsAcked +} + +// updateCongestionAvoidance will update congestion window in congestion +// avoidance mode as described in RFC5681 section 3.1 +func (r *renoState) updateCongestionAvoidance(packetsAcked int) { + // Consume the packets in congestion avoidance mode. + r.s.sndCAAckCount += packetsAcked + if r.s.sndCAAckCount >= r.s.sndCwnd { + r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd + r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd + } +} + +// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, +// page 6, eq. 4. It is called when we detect congestion in the network. +func (r *renoState) reduceSlowStartThreshold() { + r.s.sndSsthresh = r.s.outstanding / 2 + if r.s.sndSsthresh < 2 { + r.s.sndSsthresh = 2 + } + +} + +// Update updates the congestion state based on the number of packets that +// were acknowledged. +// Update implements congestionControl.Update. +func (r *renoState) Update(packetsAcked int) { + if r.s.sndCwnd < r.s.sndSsthresh { + packetsAcked = r.updateSlowStart(packetsAcked) + if packetsAcked == 0 { + return + } + } + r.updateCongestionAvoidance(packetsAcked) +} + +// HandleNDupAcks implements congestionControl.HandleNDupAcks. +func (r *renoState) HandleNDupAcks() { + // A retransmit was triggered due to nDupAckThreshold + // being hit. Reduce our slow start threshold. + r.reduceSlowStartThreshold() +} + +// HandleRTOExpired implements congestionControl.HandleRTOExpired. +func (r *renoState) HandleRTOExpired() { + // We lost a packet, so reduce ssthresh. + r.reduceSlowStartThreshold() + + // Reduce the congestion window to 1, i.e., enter slow-start. Per + // RFC 5681, page 7, we must use 1 regardless of the value of the + // initial congestion window. + r.s.sndCwnd = 1 +} + +// PostRecovery implements congestionControl.PostRecovery. +func (r *renoState) PostRecovery() { + // noop. 
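// Worked example (illustrative, not part of this change) for Update above:
// with sndSsthresh = 10 and sndCwnd = 8, an ACK covering 4 packets first
// grows the window in slow start from 8 to the threshold of 10 (consuming 2
// of the 4 ACKed packets) and then credits the remaining 2 to sndCAAckCount;
// once sndCAAckCount reaches sndCwnd, the window grows by a single segment,
// i.e. roughly one MSS per round trip in congestion avoidance.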
+} diff --git a/pkg/tcpip/transport/tcp/sack.go b/pkg/tcpip/transport/tcp/sack.go new file mode 100644 index 000000000..7be86d68e --- /dev/null +++ b/pkg/tcpip/transport/tcp/sack.go @@ -0,0 +1,105 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +const ( + // MaxSACKBlocks is the maximum number of SACK blocks stored + // at receiver side. + MaxSACKBlocks = 6 +) + +// UpdateSACKBlocks updates the list of SACK blocks to include the segment +// specified by segStart->segEnd. If the segment happens to be an out of order +// delivery then the first block in the sack.blocks always includes the +// segment identified by segStart->segEnd. +func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) { + newSB := header.SACKBlock{Start: segStart, End: segEnd} + + // Ignore any invalid SACK blocks or blocks that are before rcvNxt as + // those bytes have already been acked. + if newSB.End.LessThanEq(newSB.Start) || newSB.End.LessThan(rcvNxt) { + return + } + + if sack.NumBlocks == 0 { + sack.Blocks[0] = newSB + sack.NumBlocks = 1 + return + } + var n = 0 + for i := 0; i < sack.NumBlocks; i++ { + start, end := sack.Blocks[i].Start, sack.Blocks[i].End + if end.LessThanEq(rcvNxt) { + // Discard any sack blocks that are before rcvNxt as + // those have already been acked. + continue + } + if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) { + // Merge this SACK block into newSB and discard this SACK + // block. + if start.LessThan(newSB.Start) { + newSB.Start = start + } + if newSB.End.LessThan(end) { + newSB.End = end + } + } else { + // Save this block. + sack.Blocks[n] = sack.Blocks[i] + n++ + } + } + if rcvNxt.LessThan(newSB.Start) { + // If this was an out of order segment then make sure that the + // first SACK block is the one that includes the segment. + // + // See the first bullet point in + // https://tools.ietf.org/html/rfc2018#section-4 + if n == MaxSACKBlocks { + // If the number of SACK blocks is equal to + // MaxSACKBlocks then discard the last SACK block. + n-- + } + for i := n - 1; i >= 0; i-- { + sack.Blocks[i+1] = sack.Blocks[i] + } + sack.Blocks[0] = newSB + n++ + } + sack.NumBlocks = n +} + +// TrimSACKBlockList updates the sack block list by removing/modifying any block +// where start is < rcvNxt. +func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) { + n := 0 + for i := 0; i < sack.NumBlocks; i++ { + if sack.Blocks[i].End.LessThanEq(rcvNxt) { + continue + } + if sack.Blocks[i].Start.LessThan(rcvNxt) { + // Shrink this SACK block. 
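// Worked example (illustrative, not part of this change) for the two
// routines above: with rcvNxt = 100, an out-of-order segment [120, 140)
// produces blocks {[120, 140)}; a later segment [150, 200) produces
// {[150, 200), [120, 140)}, the newest block first per RFC 2018 section 4.
// If rcvNxt then advances to 130, TrimSACKBlockList shrinks [120, 140) to
// [130, 140) and leaves [150, 200) untouched.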
+ sack.Blocks[i].Start = rcvNxt + } + sack.Blocks[n] = sack.Blocks[i] + n++ + } + sack.NumBlocks = n +} diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go new file mode 100644 index 000000000..7ef2df377 --- /dev/null +++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go @@ -0,0 +1,306 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "fmt" + "strings" + + "github.com/google/btree" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +const ( + // maxSACKBlocks is the maximum number of distinct SACKBlocks the + // scoreboard will track. Once there are 100 distinct blocks, new + // insertions will fail. + maxSACKBlocks = 100 + + // defaultBtreeDegree is set to 2 as btree.New(2) results in a 2-3-4 + // tree. + defaultBtreeDegree = 2 +) + +// SACKScoreboard stores a set of disjoint SACK ranges. +// +// +stateify savable +type SACKScoreboard struct { + // smss is defined in RFC5681 as following: + // + // The SMSS is the size of the largest segment that the sender can + // transmit. This value can be based on the maximum transmission unit + // of the network, the path MTU discovery [RFC1191, RFC4821] algorithm, + // RMSS (see next item), or other factors. The size does not include + // the TCP/IP headers and options. + smss uint16 + maxSACKED seqnum.Value + sacked seqnum.Size `state:"nosave"` + ranges *btree.BTree `state:"nosave"` +} + +// NewSACKScoreboard returns a new SACK Scoreboard. +func NewSACKScoreboard(smss uint16, iss seqnum.Value) *SACKScoreboard { + return &SACKScoreboard{ + smss: smss, + ranges: btree.New(defaultBtreeDegree), + maxSACKED: iss, + } +} + +// Reset erases all known range information from the SACK scoreboard. +func (s *SACKScoreboard) Reset() { + s.ranges = btree.New(defaultBtreeDegree) + s.sacked = 0 +} + +// Insert inserts/merges the provided SACKBlock into the scoreboard. +func (s *SACKScoreboard) Insert(r header.SACKBlock) { + if s.ranges.Len() >= maxSACKBlocks { + return + } + + // Check if we can merge the new range with a range before or after it. + var toDelete []btree.Item + if s.maxSACKED.LessThan(r.End - 1) { + s.maxSACKED = r.End - 1 + } + s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool { + if i == r { + return true + } + sacked := i.(header.SACKBlock) + // There is a hole between these two SACK blocks, so we can't + // merge anymore. + if r.End.LessThan(sacked.Start) { + return false + } + // There is some overlap at this point, merge the blocks and + // delete the other one. + // + // ----sS--------sE + // r.S---------------rE + // -------sE + if sacked.End.LessThan(r.End) { + // sacked is contained in the newly inserted range. + // Delete this block. + toDelete = append(toDelete, i) + return true + } + // sacked covers a range past end of the newly inserted + // block. 
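// Worked example (illustrative, not part of this change) for Insert:
// inserting [10, 20) and then the overlapping block [18, 30) into an empty
// scoreboard leaves a single merged range [10, 30), so Sacked() reports 20
// bytes and MaxSACKED() reports 29, the last SACKed sequence number.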
+		r.End = sacked.End
+		toDelete = append(toDelete, i)
+		return true
+	})
+
+	s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+		if i == r {
+			return true
+		}
+		sacked := i.(header.SACKBlock)
+		//   sA------sE
+		//            rA----rE
+		if sacked.End.LessThan(r.Start) {
+			return false
+		}
+		// The previous range extends into the current block. Merge it
+		// into the newly inserted range and delete the other one.
+		//
+		//   <-rA---rE----<---rE--->
+		//   sA--------------sE
+		r.Start = sacked.Start
+		// Extend r to cover sacked if sacked extends past r.
+		if r.End.LessThan(sacked.End) {
+			r.End = sacked.End
+		}
+		toDelete = append(toDelete, i)
+		return true
+	})
+	for _, i := range toDelete {
+		if sb := s.ranges.Delete(i); sb != nil {
+			sb := i.(header.SACKBlock)
+			s.sacked -= sb.Start.Size(sb.End)
+		}
+	}
+
+	replaced := s.ranges.ReplaceOrInsert(r)
+	if replaced == nil {
+		s.sacked += r.Start.Size(r.End)
+	}
+}
+
+// IsSACKED returns true if a given range of sequence numbers denoted by r is
+// already covered by SACK information in the scoreboard.
+func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool {
+	if s.Empty() {
+		return false
+	}
+
+	found := false
+	s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+		sacked := i.(header.SACKBlock)
+		if sacked.End.LessThan(r.Start) {
+			return false
+		}
+		if sacked.Contains(r) {
+			found = true
+			return false
+		}
+		return true
+	})
+	return found
+}
+
+// String returns a string representation of the scoreboard structure.
+func (s *SACKScoreboard) String() string {
+	var str strings.Builder
+	str.WriteString("SACKScoreboard: {")
+	s.ranges.Ascend(func(i btree.Item) bool {
+		str.WriteString(fmt.Sprintf("%v,", i))
+		return true
+	})
+	str.WriteString("}\n")
+	return str.String()
+}
+
+// Delete removes all SACK information prior to seq.
+func (s *SACKScoreboard) Delete(seq seqnum.Value) {
+	if s.Empty() {
+		return
+	}
+	toDelete := []btree.Item{}
+	toInsert := []btree.Item{}
+	r := header.SACKBlock{seq, seq.Add(1)}
+	s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+		if i == r {
+			return true
+		}
+		sb := i.(header.SACKBlock)
+		toDelete = append(toDelete, i)
+		if sb.End.LessThanEq(seq) {
+			s.sacked -= sb.Start.Size(sb.End)
+		} else {
+			newSB := header.SACKBlock{seq, sb.End}
+			toInsert = append(toInsert, newSB)
+			s.sacked -= sb.Start.Size(seq)
+		}
+		return true
+	})
+	for _, sb := range toDelete {
+		s.ranges.Delete(sb)
+	}
+	for _, sb := range toInsert {
+		s.ranges.ReplaceOrInsert(sb)
+	}
+}
+
+// Copy provides a copy of the SACK scoreboard.
+func (s *SACKScoreboard) Copy() (sackBlocks []header.SACKBlock, maxSACKED seqnum.Value) {
+	s.ranges.Ascend(func(i btree.Item) bool {
+		sackBlocks = append(sackBlocks, i.(header.SACKBlock))
+		return true
+	})
+	return sackBlocks, s.maxSACKED
+}
+
+// IsRangeLost implements the IsLost(SeqNum) operation defined in RFC 6675
+// section 4 but operates on a range of sequence numbers and returns true if
+// there are at least nDupAckThreshold SACK blocks greater than the range being
+// checked or if at least (nDupAckThreshold-1)*s.smss bytes have been SACKED
+// with sequence numbers greater than the block being checked.
+func (s *SACKScoreboard) IsRangeLost(r header.SACKBlock) bool {
+	if s.Empty() {
+		return false
+	}
+	nDupSACK := 0
+	nDupSACKBytes := seqnum.Size(0)
+	isLost := false
+
+	// We need to check if the immediate lower (if any) sacked
+	// range contains or partially overlaps with r.
+	searchMore := true
+	s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+		sacked := i.(header.SACKBlock)
+		if sacked.Contains(r) {
+			searchMore = false
+			return false
+		}
+		if sacked.End.LessThanEq(r.Start) {
+			// All sequence numbers covered by sacked are below
+			// r, so we continue searching.
+			return false
+		}
+		// There is a partial overlap. In this case r.Start is
+		// between sacked.Start & sacked.End and r.End extends beyond
+		// sacked.End.
+		// Move r.Start to sacked.End and continue searching blocks
+		// above r.Start.
+		r.Start = sacked.End
+		return false
+	})
+
+	if !searchMore {
+		return isLost
+	}
+
+	s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool {
+		sacked := i.(header.SACKBlock)
+		if sacked.Contains(r) {
+			return false
+		}
+		nDupSACKBytes += sacked.Start.Size(sacked.End)
+		nDupSACK++
+		if nDupSACK >= nDupAckThreshold || nDupSACKBytes >= seqnum.Size((nDupAckThreshold-1)*s.smss) {
+			isLost = true
+			return false
+		}
+		return true
+	})
+	return isLost
+}
+
+// IsLost implements the IsLost(SeqNum) operation defined in RFC 3517 section
+// 4.
+//
+// This routine returns whether the given sequence number is considered to be
+// lost. The routine returns true when either nDupAckThreshold discontiguous
+// SACKed sequences have arrived above 'SeqNum' or ((nDupAckThreshold-1) * SMSS)
+// bytes with sequence numbers greater than 'SeqNum' have been SACKed, matching
+// the byte threshold implemented by IsRangeLost above. Otherwise, the routine
+// returns false.
+func (s *SACKScoreboard) IsLost(seq seqnum.Value) bool {
+	return s.IsRangeLost(header.SACKBlock{seq, seq.Add(1)})
+}
+
+// Empty returns true if the SACK scoreboard has no entries, false otherwise.
+func (s *SACKScoreboard) Empty() bool {
+	return s.ranges.Len() == 0
+}
+
+// Sacked returns the current number of bytes held in the SACK scoreboard.
+func (s *SACKScoreboard) Sacked() seqnum.Size {
+	return s.sacked
+}
+
+// MaxSACKED returns the highest sequence number ever inserted in the SACK
+// scoreboard.
+func (s *SACKScoreboard) MaxSACKED() seqnum.Value {
+	return s.maxSACKED
+}
+
+// SMSS returns the sender's MSS as held by the SACK scoreboard.
+func (s *SACKScoreboard) SMSS() uint16 {
+	return s.smss
+}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard_test.go b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
new file mode 100644
index 000000000..b4e5ba0df
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard_test.go
@@ -0,0 +1,249 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
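// Worked example (illustrative, not part of this change; assumes the
// conventional dup-ACK threshold of 3 for nDupAckThreshold): with smss = 10,
// a hole is declared lost once at least 3 SACKed blocks lie wholly above it,
// or once at least (3-1)*10 = 20 SACKed bytes lie above it; the tests below
// exercise exactly these two triggers.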
+ +package tcp_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" +) + +const smss = 1500 + +func initScoreboard(blocks []header.SACKBlock, iss seqnum.Value) *tcp.SACKScoreboard { + s := tcp.NewSACKScoreboard(smss, iss) + for _, blk := range blocks { + s.Insert(blk) + } + return s +} + +func TestSACKScoreboardIsSACKED(t *testing.T) { + type blockTest struct { + block header.SACKBlock + sacked bool + } + testCases := []struct { + comment string + scoreboardBlocks []header.SACKBlock + blockTests []blockTest + iss seqnum.Value + }{ + { + "Test holes and unsacked SACK blocks in SACKed ranges and insertion of overlapping SACK blocks", + []header.SACKBlock{{10, 20}, {10, 30}, {30, 40}, {41, 50}, {5, 10}, {1, 50}, {111, 120}, {101, 110}, {52, 120}}, + []blockTest{ + {header.SACKBlock{15, 21}, true}, + {header.SACKBlock{200, 201}, false}, + {header.SACKBlock{50, 51}, false}, + {header.SACKBlock{53, 120}, true}, + }, + 0, + }, + { + "Test disjoint SACKBlocks", + []header.SACKBlock{{2288624809, 2288810057}, {2288811477, 2288838565}}, + []blockTest{ + {header.SACKBlock{2288624809, 2288810057}, true}, + {header.SACKBlock{2288811477, 2288838565}, true}, + {header.SACKBlock{2288810057, 2288811477}, false}, + }, + 2288624809, + }, + { + "Test sequence number wrap around", + []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}, + []blockTest{ + {header.SACKBlock{4294254144, 4294254145}, true}, + {header.SACKBlock{4294254143, 4294254144}, false}, + {header.SACKBlock{4294254144, 1}, true}, + {header.SACKBlock{225652, 5350509}, false}, + {header.SACKBlock{5340409, 5350509}, true}, + {header.SACKBlock{5350509, 5350609}, false}, + }, + 4294254144, + }, + { + "Test disjoint SACKBlocks out of order", + []header.SACKBlock{{827450276, 827454536}, {827426028, 827428868}}, + []blockTest{ + {header.SACKBlock{827426028, 827428867}, true}, + {header.SACKBlock{827450168, 827450275}, false}, + }, + 827426000, + }, + } + for _, tc := range testCases { + sb := initScoreboard(tc.scoreboardBlocks, tc.iss) + for _, blkTest := range tc.blockTests { + if want, got := blkTest.sacked, sb.IsSACKED(blkTest.block); got != want { + t.Errorf("%s: s.IsSACKED(%v) = %v, want %v", tc.comment, blkTest.block, got, want) + } + } + } +} + +func TestSACKScoreboardIsRangeLost(t *testing.T) { + s := tcp.NewSACKScoreboard(10, 0) + s.Insert(header.SACKBlock{1, 25}) + s.Insert(header.SACKBlock{25, 50}) + s.Insert(header.SACKBlock{51, 100}) + s.Insert(header.SACKBlock{111, 120}) + s.Insert(header.SACKBlock{101, 110}) + s.Insert(header.SACKBlock{121, 141}) + s.Insert(header.SACKBlock{145, 146}) + s.Insert(header.SACKBlock{147, 148}) + s.Insert(header.SACKBlock{149, 150}) + s.Insert(header.SACKBlock{153, 154}) + s.Insert(header.SACKBlock{155, 156}) + testCases := []struct { + block header.SACKBlock + lost bool + }{ + // Block not covered by SACK block and has more than + // nDupAckThreshold discontiguous SACK blocks after it as well + // as (nDupAckThreshold -1) * 10 (smss) bytes that have been + // SACKED above the sequence number covered by this block. + {block: header.SACKBlock{0, 1}, lost: true}, + + // These blocks have all been SACKed and should not be + // considered lost. + {block: header.SACKBlock{1, 2}, lost: false}, + {block: header.SACKBlock{25, 26}, lost: false}, + {block: header.SACKBlock{1, 45}, lost: false}, + + // Same as the first case above. 
+ {block: header.SACKBlock{50, 51}, lost: true}, + + // This block has been SACKed and should not be considered lost. + {block: header.SACKBlock{119, 120}, lost: false}, + + // This one should return true because there are > + // (nDupAckThreshold - 1) * 10 (smss) bytes that have been + // sacked above this sequence number. + {block: header.SACKBlock{120, 121}, lost: true}, + + // This block has been SACKed and should not be considered lost. + {block: header.SACKBlock{125, 126}, lost: false}, + + // This block has not been SACKed and there are nDupAckThreshold + // number of SACKed blocks after it. + {block: header.SACKBlock{141, 145}, lost: true}, + + // This block has not been SACKed and there are less than + // nDupAckThreshold SACKed sequences after it. + {block: header.SACKBlock{151, 152}, lost: false}, + } + for _, tc := range testCases { + if want, got := tc.lost, s.IsRangeLost(tc.block); got != want { + t.Errorf("s.IsRangeLost(%v) = %v, want %v", tc.block, got, want) + } + } +} + +func TestSACKScoreboardIsLost(t *testing.T) { + s := tcp.NewSACKScoreboard(10, 0) + s.Insert(header.SACKBlock{1, 25}) + s.Insert(header.SACKBlock{25, 50}) + s.Insert(header.SACKBlock{51, 100}) + s.Insert(header.SACKBlock{111, 120}) + s.Insert(header.SACKBlock{101, 110}) + s.Insert(header.SACKBlock{121, 141}) + s.Insert(header.SACKBlock{121, 141}) + s.Insert(header.SACKBlock{145, 146}) + s.Insert(header.SACKBlock{147, 148}) + s.Insert(header.SACKBlock{149, 150}) + s.Insert(header.SACKBlock{153, 154}) + s.Insert(header.SACKBlock{155, 156}) + testCases := []struct { + seq seqnum.Value + lost bool + }{ + // Sequence number not covered by SACK block and has more than + // nDupAckThreshold discontiguous SACK blocks after it as well + // as (nDupAckThreshold -1) * 10 (smss) bytes that have been + // SACKED above the sequence number. + {seq: 0, lost: true}, + + // These sequence numbers have all been SACKed and should not be + // considered lost. + {seq: 1, lost: false}, + {seq: 25, lost: false}, + {seq: 45, lost: false}, + + // Same as first case above. + {seq: 50, lost: true}, + + // This block has been SACKed and should not be considered lost. + {seq: 119, lost: false}, + + // This one should return true because there are > + // (nDupAckThreshold - 1) * 10 (smss) bytes that have been + // sacked above this sequence number. + {seq: 120, lost: true}, + + // This sequence number has been SACKed and should not be + // considered lost. + {seq: 125, lost: false}, + + // This sequence number has not been SACKed and there are + // nDupAckThreshold number of SACKed blocks after it. + {seq: 141, lost: true}, + + // This sequence number has not been SACKed and there are less + // than nDupAckThreshold SACKed sequences after it. 
+		{seq: 151, lost: false},
+	}
+	for _, tc := range testCases {
+		if want, got := tc.lost, s.IsLost(tc.seq); got != want {
+			t.Errorf("s.IsLost(%v) = %v, want %v", tc.seq, got, want)
+		}
+	}
+}
+
+func TestSACKScoreboardDelete(t *testing.T) {
+	blocks := []header.SACKBlock{{4294254144, 225652}, {5340409, 5350509}}
+	s := initScoreboard(blocks, 4294254143)
+	s.Delete(5340408)
+	if s.Empty() {
+		t.Fatalf("s.Empty() = true, want false")
+	}
+	if got, want := s.Sacked(), blocks[1].Start.Size(blocks[1].End); got != want {
+		t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want)
+	}
+	s.Delete(5340410)
+	if s.Empty() {
+		t.Fatal("s.Empty() = true, want false")
+	}
+	newSB := header.SACKBlock{5340410, 5350509}
+	if !s.IsSACKED(newSB) {
+		t.Fatalf("s.IsSACKED(%v) = false, want true, scoreboard: %v", newSB, s)
+	}
+	s.Delete(5350509)
+	lastOctet := header.SACKBlock{5350508, 5350509}
+	if s.IsSACKED(lastOctet) {
+		t.Fatalf("s.IsSACKED(%v) = true, want false", lastOctet)
+	}
+
+	s.Delete(5350510)
+	if !s.Empty() {
+		t.Fatal("s.Empty() = false, want true")
+	}
+	if got, want := s.Sacked(), seqnum.Size(0); got != want {
+		t.Fatalf("incorrect sacked bytes in scoreboard got: %v, want: %v", got, want)
+	}
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
new file mode 100644
index 000000000..0280892a8
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -0,0 +1,194 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+	"sync/atomic"
+	"time"
+
+	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// segment represents a TCP segment. It holds the payload and parsed TCP segment
+// information, and can be added to intrusive lists.
+// segment is mostly immutable; the only field allowed to change is viewToDeliver.
+//
+// +stateify savable
+type segment struct {
+	segmentEntry
+	refCnt int32
+	id     stack.TransportEndpointID `state:"manual"`
+	route  stack.Route               `state:"manual"`
+	data   buffer.VectorisedView     `state:".(buffer.VectorisedView)"`
+	hdr    header.TCP
+	// views is used as buffer for data when its length is large
+	// enough to store a VectorisedView.
+	views [8]buffer.View `state:"nosave"`
+	// viewToDeliver keeps track of the next View that should be
+	// delivered by the Read endpoint.
+	viewToDeliver  int
+	sequenceNumber seqnum.Value
+	ackNumber      seqnum.Value
+	flags          uint8
+	window         seqnum.Size
+	// csum is only populated for received segments.
+	csum uint16
+	// csumValid is true if the csum in the received segment is valid.
+	csumValid bool
+
+	// parsedOptions stores the parsed values from the options in the segment.
+	parsedOptions  header.TCPOptions
+	options        []byte `state:".([]byte)"`
+	hasNewSACKInfo bool
+	rcvdTime       time.Time `state:".(unixTime)"`
+	// xmitTime is the last transmit time of this segment.
+ xmitTime time.Time `state:".(unixTime)"` + xmitCount uint32 +} + +func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { + s := &segment{ + refCnt: 1, + id: id, + route: r.Clone(), + } + s.data = pkt.Data.Clone(s.views[:]) + s.hdr = header.TCP(pkt.TransportHeader) + s.rcvdTime = time.Now() + return s +} + +func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { + s := &segment{ + refCnt: 1, + id: id, + route: r.Clone(), + } + s.rcvdTime = time.Now() + if len(v) != 0 { + s.views[0] = v + s.data = buffer.NewVectorisedView(len(v), s.views[:1]) + } + return s +} + +func (s *segment) clone() *segment { + t := &segment{ + refCnt: 1, + id: s.id, + sequenceNumber: s.sequenceNumber, + ackNumber: s.ackNumber, + flags: s.flags, + window: s.window, + route: s.route.Clone(), + viewToDeliver: s.viewToDeliver, + rcvdTime: s.rcvdTime, + xmitTime: s.xmitTime, + xmitCount: s.xmitCount, + } + t.data = s.data.Clone(t.views[:]) + return t +} + +// flagIsSet checks if at least one flag in flags is set in s.flags. +func (s *segment) flagIsSet(flags uint8) bool { + return s.flags&flags != 0 +} + +// flagsAreSet checks if all flags in flags are set in s.flags. +func (s *segment) flagsAreSet(flags uint8) bool { + return s.flags&flags == flags +} + +func (s *segment) decRef() { + if atomic.AddInt32(&s.refCnt, -1) == 0 { + s.route.Release() + } +} + +func (s *segment) incRef() { + atomic.AddInt32(&s.refCnt, 1) +} + +// logicalLen is the segment length in the sequence number space. It's defined +// as the data length plus one for each of the SYN and FIN bits set. +func (s *segment) logicalLen() seqnum.Size { + l := seqnum.Size(s.data.Size()) + if s.flagIsSet(header.TCPFlagSyn) { + l++ + } + if s.flagIsSet(header.TCPFlagFin) { + l++ + } + return l +} + +// parse populates the sequence & ack numbers, flags, and window fields of the +// segment from the TCP header stored in the data. It then updates the view to +// skip the header. +// +// Returns boolean indicating if the parsing was successful. +// +// If checksum verification is not offloaded then parse also verifies the +// TCP checksum and stores the checksum and result of checksum verification in +// the csum and csumValid fields of the segment. +func (s *segment) parse() bool { + // h is the header followed by the payload. We check that the offset to + // the data respects the following constraints: + // 1. That it's at least the minimum header size; if we don't do this + // then part of the header would be delivered to user. + // 2. That the header fits within the buffer; if we don't do this, we + // would panic when we tried to access data beyond the buffer. + // + // N.B. The segment has already been validated as having at least the + // minimum TCP size before reaching here, so it's safe to read the + // fields. + offset := int(s.hdr.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(s.hdr) { + return false + } + + s.options = []byte(s.hdr[header.TCPMinimumSize:]) + s.parsedOptions = header.ParseTCPOptions(s.options) + + // Query the link capabilities to decide if checksum validation is + // required. 
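+	//
+	// When the link endpoint advertises RX checksum offload, the checksum
+	// has already been validated by the time the segment reaches us, so
+	// the software verification below is skipped. Otherwise the standard
+	// Internet checksum is verified: start from the pseudo-header checksum
+	// (addresses, protocol number and TCP length), add in the header and
+	// payload, and expect the one's complement sum of a valid segment to
+	// fold to 0xffff.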
+ verifyChecksum := true + if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { + s.csumValid = true + verifyChecksum = false + } + if verifyChecksum { + s.csum = s.hdr.Checksum() + xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum = s.hdr.CalculateChecksum(xsum) + xsum = header.ChecksumVV(s.data, xsum) + s.csumValid = xsum == 0xffff + } + + s.sequenceNumber = seqnum.Value(s.hdr.SequenceNumber()) + s.ackNumber = seqnum.Value(s.hdr.AckNumber()) + s.flags = s.hdr.Flags() + s.window = seqnum.Size(s.hdr.WindowSize()) + return true +} + +// sackBlock returns a header.SACKBlock that represents this segment. +func (s *segment) sackBlock() header.SACKBlock { + return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())} +} diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go new file mode 100644 index 000000000..8d3ddce4b --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_heap.go @@ -0,0 +1,51 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import "container/heap" + +type segmentHeap []*segment + +var _ heap.Interface = (*segmentHeap)(nil) + +// Len returns the length of h. +func (h *segmentHeap) Len() int { + return len(*h) +} + +// Less determines whether the i-th element of h is less than the j-th element. +func (h *segmentHeap) Less(i, j int) bool { + return (*h)[i].sequenceNumber.LessThan((*h)[j].sequenceNumber) +} + +// Swap swaps the i-th and j-th elements of h. +func (h *segmentHeap) Swap(i, j int) { + (*h)[i], (*h)[j] = (*h)[j], (*h)[i] +} + +// Push adds x as the last element of h. +func (h *segmentHeap) Push(x interface{}) { + *h = append(*h, x.(*segment)) +} + +// Pop removes the last element of h and returns it. +func (h *segmentHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + old[n-1] = nil + *h = old[:n-1] + return x +} diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go new file mode 100644 index 000000000..48a257137 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -0,0 +1,85 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "gvisor.dev/gvisor/pkg/sync" +) + +// segmentQueue is a bounded, thread-safe queue of TCP segments. 
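+// The queue bounds how much memory undelivered segments may consume for an
+// endpoint: when it is full, enqueue fails and (per the enqueue contract
+// below) the caller retains ownership of the rejected segment.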
+// +// +stateify savable +type segmentQueue struct { + mu sync.Mutex `state:"nosave"` + list segmentList `state:"wait"` + limit int + used int +} + +// emptyLocked determines if the queue is empty. +// Preconditions: q.mu must be held. +func (q *segmentQueue) emptyLocked() bool { + return q.used == 0 +} + +// empty determines if the queue is empty. +func (q *segmentQueue) empty() bool { + q.mu.Lock() + r := q.emptyLocked() + q.mu.Unlock() + + return r +} + +// setLimit updates the limit. No segments are immediately dropped in case the +// queue becomes full due to the new limit. +func (q *segmentQueue) setLimit(limit int) { + q.mu.Lock() + q.limit = limit + q.mu.Unlock() +} + +// enqueue adds the given segment to the queue. +// +// Returns true when the segment is successfully added to the queue, in which +// case ownership of the reference is transferred to the queue. And returns +// false if the queue is full, in which case ownership is retained by the +// caller. +func (q *segmentQueue) enqueue(s *segment) bool { + q.mu.Lock() + r := q.used < q.limit + if r { + q.list.PushBack(s) + q.used++ + } + q.mu.Unlock() + + return r +} + +// dequeue removes and returns the next segment from queue, if one exists. +// Ownership is transferred to the caller, who is responsible for decrementing +// the ref count when done. +func (q *segmentQueue) dequeue() *segment { + q.mu.Lock() + s := q.list.Front() + if s != nil { + q.list.Remove(s) + q.used-- + } + q.mu.Unlock() + + return s +} diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go new file mode 100644 index 000000000..7dc2741a6 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_state.go @@ -0,0 +1,82 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" + + "gvisor.dev/gvisor/pkg/tcpip/buffer" +) + +// saveData is invoked by stateify. +func (s *segment) saveData() buffer.VectorisedView { + // We cannot save s.data directly as s.data.views may alias to s.views, + // which is not allowed by state framework (in-struct pointer). + v := make([]buffer.View, len(s.data.Views())) + // For views already delivered, we cannot save them directly as they may + // have already been sliced and saved elsewhere (e.g., readViews). + for i := 0; i < s.viewToDeliver; i++ { + v[i] = append([]byte(nil), s.data.Views()[i]...) + } + for i := s.viewToDeliver; i < len(v); i++ { + v[i] = s.data.Views()[i] + } + return buffer.NewVectorisedView(s.data.Size(), v) +} + +// loadData is invoked by stateify. +func (s *segment) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the s.data = data.Clone(s.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing s.views for data.views. + s.data = data +} + +// saveOptions is invoked by stateify. 
+func (s *segment) saveOptions() []byte {
+	// We cannot save s.options directly as it may point to s.data's trimmed
+	// tail, which is not allowed by state framework (in-struct pointer).
+	b := make([]byte, 0, cap(s.options))
+	return append(b, s.options...)
+}
+
+// loadOptions is invoked by stateify.
+func (s *segment) loadOptions(options []byte) {
+	// NOTE: We cannot point s.options back into s.data's trimmed tail. But
+	// it is OK as they do not need to be aliased. Plus, options is already
+	// allocated so there is no cost here.
+	s.options = options
+}
+
+// saveRcvdTime is invoked by stateify.
+func (s *segment) saveRcvdTime() unixTime {
+	return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
+}
+
+// loadRcvdTime is invoked by stateify.
+func (s *segment) loadRcvdTime(unix unixTime) {
+	s.rcvdTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveXmitTime is invoked by stateify.
+func (s *segment) saveXmitTime() unixTime {
+	return unixTime{s.xmitTime.Unix(), s.xmitTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (s *segment) loadXmitTime(unix unixTime) {
+	s.xmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
new file mode 100644
index 000000000..5862c32f2
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -0,0 +1,1487 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+	"fmt"
+	"math"
+	"sync/atomic"
+	"time"
+
+	"gvisor.dev/gvisor/pkg/sleep"
+	"gvisor.dev/gvisor/pkg/sync"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+	// MinRTO is the minimum allowed value for the retransmit timeout.
+	MinRTO = 200 * time.Millisecond
+
+	// MaxRTO is the maximum allowed value for the retransmit timeout.
+	MaxRTO = 120 * time.Second
+
+	// InitialCwnd is the initial congestion window.
+	InitialCwnd = 10
+
+	// nDupAckThreshold is the number of duplicate ACKs required
+	// before fast-retransmit is entered.
+	nDupAckThreshold = 3
+
+	// MaxRetries is the maximum number of probe retries the sender does
+	// before timing out the connection.
+	// Linux default TCP_RETR2, net.ipv4.tcp_retries2.
+	MaxRetries = 15
+)
+
+// ccState indicates the current congestion control state for this sender.
+type ccState int
+
+const (
+	// Open indicates that the sender is receiving acks in order and
+	// no loss or dupACKs etc. have been detected.
+	Open ccState = iota
+	// RTORecovery indicates that an RTO has occurred and the sender
+	// has entered an RTO based recovery phase.
+	RTORecovery
+	// FastRecovery indicates that the sender has entered FastRecovery
+	// based on receiving nDupAckThreshold duplicate ACKs. This state is
+	// entered only when SACK is not in use.
+	FastRecovery
+	// SACKRecovery indicates that the sender has entered SACK based
+	// recovery.
+	SACKRecovery
+	// Disorder indicates the sender either received some SACK blocks
+	// or dupACKs.
+	Disorder
+)
+
+// congestionControl is an interface that must be implemented by any supported
+// congestion control algorithm.
+type congestionControl interface {
+	// HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold
+	// just before entering fast retransmit.
+	HandleNDupAcks()
+
+	// HandleRTOExpired is invoked when the retransmit timer expires.
+	HandleRTOExpired()
+
+	// Update is invoked when processing inbound acks. It's passed the
+	// number of packets that were acked by the most recent cumulative
+	// acknowledgement.
+	Update(packetsAcked int)
+
+	// PostRecovery is invoked when the sender is exiting a fast retransmit/
+	// recovery phase. This provides congestion control algorithms a way
+	// to adjust their state when exiting recovery.
+	PostRecovery()
+}
+
+// sender holds the state necessary to send TCP segments.
+//
+// +stateify savable
+type sender struct {
+	ep *endpoint
+
+	// lastSendTime is the timestamp when the last packet was sent.
+	lastSendTime time.Time `state:".(unixTime)"`
+
+	// dupAckCount is the number of duplicated acks received. It is used for
+	// fast retransmit.
+	dupAckCount int
+
+	// fr holds state related to fast recovery.
+	fr fastRecovery
+
+	// sndCwnd is the congestion window, in packets.
+	sndCwnd int
+
+	// sndSsthresh is the threshold between slow start and congestion
+	// avoidance.
+	sndSsthresh int
+
+	// sndCAAckCount is the number of packets acknowledged during congestion
+	// avoidance. When enough packets have been ack'd (typically cwnd
+	// packets), the congestion window is incremented by one.
+	sndCAAckCount int
+
+	// outstanding is the number of outstanding packets, that is, packets
+	// that have been sent but not yet acknowledged.
+	outstanding int
+
+	// sndWnd is the send window size.
+	sndWnd seqnum.Size
+
+	// sndUna is the next unacknowledged sequence number.
+	sndUna seqnum.Value
+
+	// sndNxt is the sequence number of the next segment to be sent.
+	sndNxt seqnum.Value
+
+	// rttMeasureSeqNum is the sequence number being used for the latest RTT
+	// measurement.
+	rttMeasureSeqNum seqnum.Value
+
+	// rttMeasureTime is the time when the rttMeasureSeqNum was sent.
+	rttMeasureTime time.Time `state:".(unixTime)"`
+
+	// firstRetransmittedSegXmitTime is the original transmit time of
+	// the first segment that was retransmitted due to RTO expiration.
+	firstRetransmittedSegXmitTime time.Time `state:".(unixTime)"`
+
+	// zeroWindowProbing is set if the sender is currently probing
+	// for zero receive window.
+	zeroWindowProbing bool `state:"nosave"`
+
+	// unackZeroWindowProbes is the number of unacknowledged zero
+	// window probes.
+	unackZeroWindowProbes uint32 `state:"nosave"`
+
+	closed      bool
+	writeNext   *segment
+	writeList   segmentList
+	resendTimer timer       `state:"nosave"`
+	resendWaker sleep.Waker `state:"nosave"`
+
+	// rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
+	// "round-trip time variation" and "retransmit timeout", as defined in
+	// section 2 of RFC 6298.
+	rtt rtt
+	rto time.Duration
+
+	// minRTO is the minimum permitted value for sender.rto.
+	minRTO time.Duration
+
+	// maxRTO is the maximum permitted value for sender.rto.
+	maxRTO time.Duration
+
+	// maxRetries is the maximum permitted retransmissions.
+	maxRetries uint32
+
+	// maxPayloadSize is the maximum size of the payload of a given segment.
+	// It is initialized on demand.
+	maxPayloadSize int
+
+	// gso is set if generic segmentation offload is enabled.
+	gso bool
+
+	// sndWndScale is the number of bits to shift left when reading the send
+	// window size from a segment.
+	sndWndScale uint8
+
+	// maxSentAck is the maximum acknowledgement actually sent.
+	maxSentAck seqnum.Value
+
+	// state is the current state of congestion control for this endpoint.
+	state ccState
+
+	// cc is the congestion control algorithm in use for this sender.
+	cc congestionControl
+}
+
+// rtt is a synchronization wrapper used to appease stateify. See the comment
+// in sender, where it is used.
+//
+// +stateify savable
+type rtt struct {
+	sync.Mutex `state:"nosave"`
+
+	srtt       time.Duration
+	rttvar     time.Duration
+	srttInited bool
+}
+
+// fastRecovery holds information related to fast recovery from a packet loss.
+//
+// +stateify savable
+type fastRecovery struct {
+	// active is whether the endpoint is in fast recovery. The following
+	// fields are only meaningful when active is true.
+	active bool
+
+	// first and last represent the inclusive sequence number range being
+	// recovered.
+	first seqnum.Value
+	last  seqnum.Value
+
+	// maxCwnd is the maximum value the congestion window may be inflated to
+	// due to duplicate acks. This exists to avoid attacks where the
+	// receiver intentionally sends duplicate acks to artificially inflate
+	// the sender's cwnd.
+	maxCwnd int
+
+	// highRxt is the highest sequence number which has been retransmitted
+	// during the current loss recovery phase.
+	// See: RFC 6675 Section 2 for details.
+	highRxt seqnum.Value
+
+	// rescueRxt is the highest sequence number which has been
+	// optimistically retransmitted to prevent stalling of the ACK clock
+	// when there is loss at the end of the window and no new data is
+	// available for transmission.
+	// See: RFC 6675 Section 2 for details.
+	rescueRxt seqnum.Value
+}
+
+func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
+	// The sender MUST reduce the TCP data length to account for any IP or
+	// TCP options that it is including in the packets that it sends.
+	// See: https://tools.ietf.org/html/rfc6691#section-2
+	maxPayloadSize := int(mss) - ep.maxOptionSize()
+
+	s := &sender{
+		ep:               ep,
+		sndWnd:           sndWnd,
+		sndUna:           iss + 1,
+		sndNxt:           iss + 1,
+		rto:              1 * time.Second,
+		rttMeasureSeqNum: iss + 1,
+		lastSendTime:     time.Now(),
+		maxPayloadSize:   maxPayloadSize,
+		maxSentAck:       irs + 1,
+		fr: fastRecovery{
+			// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
+			last:      iss,
+			highRxt:   iss,
+			rescueRxt: iss,
+		},
+		gso: ep.gso != nil,
+	}
+
+	if s.gso {
+		s.ep.gso.MSS = uint16(maxPayloadSize)
+	}
+
+	s.cc = s.initCongestionControl(ep.cc)
+
+	// A negative sndWndScale means that no scaling is in use; otherwise we
+	// store the scaling value.
+	if sndWndScale > 0 {
+		s.sndWndScale = uint8(sndWndScale)
+	}
+
+	s.resendTimer.init(&s.resendWaker)
+
+	s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
+
+	// Initialize SACK Scoreboard after updating max payload size as we use
+	// the maxPayloadSize as the smss when determining if a segment is lost
+	// etc.
+	s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+
+	// Get Stack wide config.
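+	// These stack-wide options are expected to always be present; a
+	// failure to read them indicates a programming error rather than a
+	// runtime condition, hence the panics below.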
+	var minRTO tcpip.TCPMinRTOOption
+	if err := ep.stack.TransportProtocolOption(ProtocolNumber, &minRTO); err != nil {
+		panic(fmt.Sprintf("unable to get minRTO from stack: %s", err))
+	}
+	s.minRTO = time.Duration(minRTO)
+
+	var maxRTO tcpip.TCPMaxRTOOption
+	if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRTO); err != nil {
+		panic(fmt.Sprintf("unable to get maxRTO from stack: %s", err))
+	}
+	s.maxRTO = time.Duration(maxRTO)
+
+	var maxRetries tcpip.TCPMaxRetriesOption
+	if err := ep.stack.TransportProtocolOption(ProtocolNumber, &maxRetries); err != nil {
+		panic(fmt.Sprintf("unable to get maxRetries from stack: %s", err))
+	}
+	s.maxRetries = uint32(maxRetries)
+
+	return s
+}
+
+// initCongestionControl initializes the specified congestion control module and
+// returns a handle to it. It also initializes the sndCwnd and sndSsthresh to
+// their initial values.
+func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl {
+	s.sndCwnd = InitialCwnd
+	s.sndSsthresh = math.MaxInt64
+
+	switch congestionControlName {
+	case ccCubic:
+		return newCubicCC(s)
+	case ccReno:
+		fallthrough
+	default:
+		return newRenoCC(s)
+	}
+}
+
+// updateMaxPayloadSize updates the maximum payload size based on the given
+// MTU. If this is in response to "packet too big" control packets (indicated
+// by the count argument), it also reduces the number of outstanding packets and
+// attempts to retransmit the first packet above the MTU size.
+func (s *sender) updateMaxPayloadSize(mtu, count int) {
+	m := mtu - header.TCPMinimumSize
+
+	m -= s.ep.maxOptionSize()
+
+	// We don't adjust up for now.
+	if m >= s.maxPayloadSize {
+		return
+	}
+
+	// Make sure we can transmit at least one byte.
+	if m <= 0 {
+		m = 1
+	}
+
+	s.maxPayloadSize = m
+	if s.gso {
+		s.ep.gso.MSS = uint16(m)
+	}
+
+	if count == 0 {
+		// updateMaxPayloadSize is also called when the sender is created,
+		// and there is no data to send in such cases. Return immediately.
+		return
+	}
+
+	// Update the scoreboard's smss to reflect the new lowered
+	// maxPayloadSize.
+	s.ep.scoreboard.smss = uint16(m)
+
+	s.outstanding -= count
+	if s.outstanding < 0 {
+		s.outstanding = 0
+	}
+
+	// Rewind writeNext to the first segment exceeding the MTU. Do nothing
+	// if it is already before such a packet.
+	for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+		if seg == s.writeNext {
+			// We got to writeNext before we could find a segment
+			// exceeding the MTU.
+			break
+		}
+
+		if seg.data.Size() > m {
+			// We found a segment exceeding the MTU. Rewind
+			// writeNext and try to retransmit it.
+			s.writeNext = seg
+			break
+		}
+	}
+
+	// Since we likely reduced the number of outstanding packets, we may be
+	// ready to send some more.
+	s.sendData()
+}
+
+// sendAck sends an ACK segment.
+func (s *sender) sendAck() {
+	s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt)
+}
+
+// updateRTO updates the retransmit timeout when a new round-trip time is
+// available. This is done in accordance with section 2 of RFC 6298.
+func (s *sender) updateRTO(rtt time.Duration) {
+	s.rtt.Lock()
+	if !s.rtt.srttInited {
+		s.rtt.rttvar = rtt / 2
+		s.rtt.srtt = rtt
+		s.rtt.srttInited = true
+	} else {
+		diff := s.rtt.srtt - rtt
+		if diff < 0 {
+			diff = -diff
+		}
+		// Use RFC6298 standard algorithm to update rttvar and srtt when
+		// no timestamps are available.
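+		//
+		//	RTTVAR = (1 - beta) * RTTVAR + beta * |SRTT - R'|
+		//	SRTT   = (1 - alpha) * SRTT + alpha * R'
+		//
+		// with the standard alpha = 1/8 and beta = 1/4. The integer
+		// arithmetic below, (3*rttvar + diff) / 4 and
+		// (7*srtt + rtt) / 8, computes exactly this update.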
+ if !s.ep.sendTSOk { + s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4 + s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8 + } else { + // When we are taking RTT measurements of every ACK then + // we need to use a modified method as specified in + // https://tools.ietf.org/html/rfc7323#appendix-G + if s.outstanding == 0 { + s.rtt.Unlock() + return + } + // Netstack measures congestion window/inflight all in + // terms of packets and not bytes. This is similar to + // how linux also does cwnd and inflight. In practice + // this approximation works as expected. + expectedSamples := math.Ceil(float64(s.outstanding) / 2) + + // alpha & beta values are the original values as recommended in + // https://tools.ietf.org/html/rfc6298#section-2.3. + const alpha = 0.125 + const beta = 0.25 + + alphaPrime := alpha / expectedSamples + betaPrime := beta / expectedSamples + rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds() + srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds() + s.rtt.rttvar = time.Duration(rttVar * float64(time.Second)) + s.rtt.srtt = time.Duration(srtt * float64(time.Second)) + } + } + + s.rto = s.rtt.srtt + 4*s.rtt.rttvar + s.rtt.Unlock() + if s.rto < s.minRTO { + s.rto = s.minRTO + } +} + +// resendSegment resends the first unacknowledged segment. +func (s *sender) resendSegment() { + // Don't use any segments we already sent to measure RTT as they may + // have been affected by packets being lost. + s.rttMeasureSeqNum = s.sndNxt + + // Resend the segment. + if seg := s.writeList.Front(); seg != nil { + if seg.data.Size() > s.maxPayloadSize { + s.splitSeg(seg, s.maxPayloadSize) + } + + // See: RFC 6675 section 5 Step 4.3 + // + // To prevent retransmission, set both the HighRXT and RescueRXT + // to the highest sequence number in the retransmitted segment. + s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.sendSegment(seg) + s.ep.stack.Stats().TCP.FastRetransmit.Increment() + s.ep.stats.SendErrors.FastRetransmit.Increment() + + // Run SetPipe() as per RFC 6675 section 5 Step 4.4 + s.SetPipe() + } +} + +// retransmitTimerExpired is called when the retransmit timer expires, and +// unacknowledged segments are assumed lost, and thus need to be resent. +// Returns true if the connection is still usable, or false if the connection +// is deemed lost. +func (s *sender) retransmitTimerExpired() bool { + // Check if the timer actually expired or if it's a spurious wake due + // to a previously orphaned runtime timer. + if !s.resendTimer.checkExpiration() { + return true + } + + // TODO(b/147297758): Band-aid fix, retransmitTimer can fire in some edge cases + // when writeList is empty. Remove this once we have a proper fix for this + // issue. + if s.writeList.Front() == nil { + return true + } + + s.ep.stack.Stats().TCP.Timeouts.Increment() + s.ep.stats.SendErrors.Timeouts.Increment() + + // Give up if we've waited more than a minute since the last resend or + // if a user time out is set and we have exceeded the user specified + // timeout since the first retransmission. + uto := s.ep.userTimeout + + if s.firstRetransmittedSegXmitTime.IsZero() { + // We store the original xmitTime of the segment that we are + // about to retransmit as the retransmission time. This is + // required as by the time the retransmitTimer has expired the + // segment has already been sent and unacked for the RTO at the + // time the segment was sent. 
+ s.firstRetransmittedSegXmitTime = s.writeList.Front().xmitTime + } + + elapsed := time.Since(s.firstRetransmittedSegXmitTime) + remaining := s.maxRTO + if uto != 0 { + // Cap to the user specified timeout if one is specified. + remaining = uto - elapsed + } + + // Always honor the user-timeout irrespective of whether the zero + // window probes were acknowledged. + // net/ipv4/tcp_timer.c::tcp_probe_timer() + if remaining <= 0 || s.unackZeroWindowProbes >= s.maxRetries { + return false + } + + // Set new timeout. The timer will be restarted by the call to sendData + // below. + s.rto *= 2 + // Cap the RTO as per RFC 1122 4.2.3.1, RFC 6298 5.5 + if s.rto > s.maxRTO { + s.rto = s.maxRTO + } + + // Cap RTO to remaining time. + if s.rto > remaining { + s.rto = remaining + } + + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. + // + // Retransmit timeouts: + // After a retransmit timeout, record the highest sequence number + // transmitted in the variable recover, and exit the fast recovery + // procedure if applicable. + s.fr.last = s.sndNxt - 1 + + if s.fr.active { + // We were attempting fast recovery but were not successful. + // Leave the state. We don't need to update ssthresh because it + // has already been updated when entered fast-recovery. + s.leaveFastRecovery() + } + + s.state = RTORecovery + s.cc.HandleRTOExpired() + + // Mark the next segment to be sent as the first unacknowledged one and + // start sending again. Set the number of outstanding packets to 0 so + // that we'll be able to retransmit. + // + // We'll keep on transmitting (or retransmitting) as we get acks for + // the data we transmit. + s.outstanding = 0 + + // Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1 + // + // In order to avoid memory deadlocks, the TCP receiver is allowed to + // discard data that has already been selectively acknowledged. As a + // result, [RFC2018] suggests that a TCP sender SHOULD expunge the SACK + // information gathered from a receiver upon a retransmission timeout + // (RTO) "since the timeout might indicate that the data receiver has + // reneged." Additionally, a TCP sender MUST "ignore prior SACK + // information in determining which data to retransmit." + // + // NOTE: We take the stricter interpretation and just expunge all + // information as we lack more rigorous checks to validate if the SACK + // information is usable after an RTO. + s.ep.scoreboard.Reset() + s.writeNext = s.writeList.Front() + + // RFC 1122 4.2.2.17: Start sending zero window probes when we still see a + // zero receive window after retransmission interval and we have data to + // send. + if s.zeroWindowProbing { + s.sendZeroWindowProbe() + // RFC 1122 4.2.2.17: A TCP MAY keep its offered receive window closed + // indefinitely. As long as the receiving TCP continues to send + // acknowledgments in response to the probe segments, the sending TCP + // MUST allow the connection to stay open. + return true + } + + seg := s.writeNext + // RFC 1122 4.2.3.5: Close the connection when the number of + // retransmissions for this segment is beyond a limit. + if seg != nil && seg.xmitCount > s.maxRetries { + return false + } + + s.sendData() + + return true +} + +// pCount returns the number of packets in the segment. Due to GSO, a segment +// can be composed of multiple packets. 
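+// For example, with maxPayloadSize = 1460, a 4000 byte segment counts as
+// (4000-1)/1460 + 1 == 3 packets, i.e. ceil(4000/1460).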
+func (s *sender) pCount(seg *segment) int {
+	size := seg.data.Size()
+	if size == 0 {
+		return 1
+	}
+
+	return (size-1)/s.maxPayloadSize + 1
+}
+
+// splitSeg splits a given segment at the size specified and inserts the
+// remainder as a new segment after the current one in the write list.
+func (s *sender) splitSeg(seg *segment, size int) {
+	if seg.data.Size() <= size {
+		return
+	}
+	// Split this segment up.
+	nSeg := seg.clone()
+	nSeg.data.TrimFront(size)
+	nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
+	s.writeList.InsertAfter(seg, nSeg)
+
+	// The segment being split does not carry PUSH flag because it is
+	// followed by the newly split segment.
+	// RFC1122 section 4.2.2.2: MUST set the PSH bit in the last buffered
+	// segment (i.e., when there is no more queued data to be sent).
+	// Linux removes PSH flag only when the segment is being split over MSS
+	// and retains it when we are splitting the segment over lack of sender
+	// window space.
+	// ref: net/ipv4/tcp_output.c::tcp_write_xmit(), tcp_mss_split_point()
+	// ref: net/ipv4/tcp_output.c::tcp_write_wakeup(), tcp_snd_wnd_test()
+	if seg.data.Size() > s.maxPayloadSize {
+		seg.flags ^= header.TCPFlagPsh
+	}
+
+	seg.data.CapLength(size)
+}
+
+// NextSeg implements the RFC6675 NextSeg() operation.
+//
+// NextSeg starts scanning the writeList starting from nextSegHint and returns
+// the hint to be passed on the next call to NextSeg. This is required to avoid
+// iterating the write list repeatedly when NextSeg is invoked in a loop during
+// recovery. The returned hint will be nil if there are no more segments that
+// can match rules defined by NextSeg operation in RFC6675.
+//
+// rescueRtx will be true only if nextSeg is a rescue retransmission as
+// described by Step 4) of the NextSeg algorithm.
+func (s *sender) NextSeg(nextSegHint *segment) (nextSeg, hint *segment, rescueRtx bool) {
+	var s3 *segment
+	var s4 *segment
+	// Step 1.
+	for seg := nextSegHint; seg != nil; seg = seg.Next() {
+		// Stop iteration if we hit a segment that has never been
+		// transmitted (i.e. either it has no assigned sequence number
+		// or if it does have one, it's >= the next sequence number
+		// to be sent [i.e. >= s.sndNxt]).
+		if !s.isAssignedSequenceNumber(seg) || s.sndNxt.LessThanEq(seg.sequenceNumber) {
+			hint = nil
+			break
+		}
+		segSeq := seg.sequenceNumber
+		if smss := s.ep.scoreboard.SMSS(); seg.data.Size() > int(smss) {
+			s.splitSeg(seg, int(smss))
+		}
+
+		// See RFC 6675 Section 4
+		//
+		// 1. If there exists a smallest unSACKED sequence number
+		// 'S2' that meets the following 3 criteria for determining
+		// loss, the sequence range of one segment of up to SMSS
+		// octets starting with S2 MUST be returned.
+		if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) {
+			// NextSeg():
+			//
+			//    (1.a) S2 is greater than HighRxt
+			//    (1.b) S2 is less than highest octet covered by
+			//    any received SACK.
+			if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) {
+				// NextSeg():
+				//    (1.c) IsLost(S2) returns true.
+				if s.ep.scoreboard.IsLost(segSeq) {
+					return seg, seg.Next(), false
+				}
+
+				// NextSeg():
+				//
+				// (3): If the conditions for rules (1) and (2)
+				// fail, but there exists an unSACKed sequence
+				// number S3 that meets the criteria for
+				// detecting loss given in steps 1.a and 1.b
+				// above (specifically excluding (1.c)) then one
+				// segment of up to SMSS octets starting with S3
+				// SHOULD be returned.
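+				//
+				// Only the first such S3 candidate (and the
+				// hint pointing past it) is remembered; the
+				// scan continues because a later segment may
+				// still satisfy rule (1), which takes
+				// precedence over rule (3).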
+				if s3 == nil {
+					s3 = seg
+					hint = seg.Next()
+				}
+			}
+			// NextSeg():
+			//
+			// (4) If the conditions for (1), (2) and (3) fail,
+			// but there exists outstanding unSACKED data, we
+			// provide the opportunity for a single "rescue"
+			// retransmission per entry into loss recovery. If
+			// HighACK is greater than RescueRxt (or RescueRxt
+			// is undefined), then one segment of up to SMSS
+			// octets that MUST include the highest outstanding
+			// unSACKed sequence number SHOULD be returned, and
+			// RescueRxt set to RecoveryPoint. HighRxt MUST NOT
+			// be updated.
+			if s.fr.rescueRxt.LessThan(s.sndUna - 1) {
+				if s4 != nil {
+					if s4.sequenceNumber.LessThan(segSeq) {
+						s4 = seg
+					}
+				} else {
+					s4 = seg
+				}
+			}
+		}
+	}
+
+	// If we got here then no segment matched step (1).
+	// Step (2): "If no sequence number 'S2' per rule (1)
+	// exists but there exists available unsent data and the
+	// receiver's advertised window allows, the sequence
+	// range of one segment of up to SMSS octets of
+	// previously unsent data starting with sequence number
+	// HighData+1 MUST be returned."
+	for seg := s.writeNext; seg != nil; seg = seg.Next() {
+		if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) {
+			continue
+		}
+		// We do not split the segment here to <= smss as it has
+		// potentially not been assigned a sequence number yet.
+		return seg, nil, false
+	}
+
+	if s3 != nil {
+		return s3, hint, false
+	}
+
+	return s4, nil, true
+}
+
+// maybeSendSegment tries to send the specified segment and either coalesces
+// other segments into this one or splits the specified segment based on the
+// lower of the specified limit value or the receiver's window size specified
+// by end.
+func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (sent bool) {
+	// We abuse the flags field to determine if we have already
+	// assigned a sequence number to this segment.
+	if !s.isAssignedSequenceNumber(seg) {
+		// Merge segments if allowed.
+		if seg.data.Size() != 0 {
+			available := int(s.sndNxt.Size(end))
+			if available > limit {
+				available = limit
+			}
+
+			// nextTooBig indicates that the next segment was too
+			// large to entirely fit in the current segment. It
+			// would be possible to split the next segment and merge
+			// the portion that fits, but unexpectedly splitting
+			// segments can have user visible side-effects which can
+			// break applications. For example, RFC 7766 section 8
+			// says that the length and data of a DNS response
+			// should be sent in the same TCP segment to avoid
+			// triggering bugs in poorly written DNS
+			// implementations.
+			var nextTooBig bool
+			for seg.Next() != nil && seg.Next().data.Size() != 0 {
+				if seg.data.Size()+seg.Next().data.Size() > available {
+					nextTooBig = true
+					break
+				}
+				seg.data.Append(seg.Next().data)
+
+				// Consume the segment that we just merged in.
+				s.writeList.Remove(seg.Next())
+			}
+			if !nextTooBig && seg.data.Size() < available {
+				// Segment is not full.
+				if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
+					// Nagle's algorithm. From Wikipedia:
+					//   Nagle's algorithm works by
+					//   combining a number of small
+					//   outgoing messages and sending them
+					//   all at once. Specifically, as long
+					//   as there is a sent packet for which
+					//   the sender has received no
+					//   acknowledgment, the sender should
+					//   keep buffering its output until it
+					//   has a full packet's worth of
+					//   output, thus allowing output to be
+					//   sent all at once.
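+					//
+					// s.ep.delay is non-zero when the
+					// Nagle algorithm is enabled (i.e.
+					// TCP_NODELAY is not set), so this
+					// sub-MSS segment is held back while
+					// data is still outstanding.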
+ return false + } + // With TCP_CORK, hold back until minimum of the available + // send space and MSS. + // TODO(gvisor.dev/issue/2833): Drain the held segments after a + // timeout. + if seg.data.Size() < s.maxPayloadSize && atomic.LoadUint32(&s.ep.cork) != 0 { + return false + } + } + } + + // Assign flags. We don't do it above so that we can merge + // additional data if Nagle holds the segment. + seg.sequenceNumber = s.sndNxt + seg.flags = header.TCPFlagAck | header.TCPFlagPsh + } + + var segEnd seqnum.Value + if seg.data.Size() == 0 { + if s.writeList.Back() != seg { + panic("FIN segments must be the final segment in the write list.") + } + seg.flags = header.TCPFlagAck | header.TCPFlagFin + segEnd = seg.sequenceNumber.Add(1) + // Update the state to reflect that we have now + // queued a FIN. + switch s.ep.EndpointState() { + case StateCloseWait: + s.ep.setEndpointState(StateLastAck) + default: + s.ep.setEndpointState(StateFinWait1) + } + } else { + // We're sending a non-FIN segment. + if seg.flags&header.TCPFlagFin != 0 { + panic("Netstack queues FIN segments without data.") + } + + if !seg.sequenceNumber.LessThan(end) { + return false + } + + available := int(seg.sequenceNumber.Size(end)) + if available == 0 { + return false + } + + // If the whole segment or at least 1MSS sized segment cannot + // be accomodated in the receiver advertized window, skip + // splitting and sending of the segment. ref: + // net/ipv4/tcp_output.c::tcp_snd_wnd_test() + // + // Linux checks this for all segment transmits not triggered by + // a probe timer. On this condition, it defers the segment split + // and transmit to a short probe timer. + // + // ref: include/net/tcp.h::tcp_check_probe_timer() + // ref: net/ipv4/tcp_output.c::tcp_write_wakeup() + // + // Instead of defining a new transmit timer, we attempt to split + // the segment right here if there are no pending segments. If + // there are pending segments, segment transmits are deferred to + // the retransmit timer handler. + if s.sndUna != s.sndNxt { + switch { + case available >= seg.data.Size(): + // OK to send, the whole segments fits in the + // receiver's advertised window. + case available >= s.maxPayloadSize: + // OK to send, at least 1 MSS sized segment fits + // in the receiver's advertised window. + default: + return false + } + } + + // The segment size limit is computed as a function of sender + // congestion window and MSS. When sender congestion window is > + // 1, this limit can be larger than MSS. Ensure that the + // currently available send space is not greater than minimum of + // this limit and MSS. + if available > limit { + available = limit + } + + // If GSO is not in use then cap available to + // maxPayloadSize. When GSO is in use the gVisor GSO logic or + // the host GSO logic will cap the segment to the correct size. + if s.ep.gso == nil && available > s.maxPayloadSize { + available = s.maxPayloadSize + } + + if seg.data.Size() > available { + s.splitSeg(seg, available) + } + + segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + } + + s.sendSegment(seg) + + // Update sndNxt if we actually sent new data (as opposed to + // retransmitting some previously sent data). + if s.sndNxt.LessThan(segEnd) { + s.sndNxt = segEnd + } + + return true +} + +// handleSACKRecovery implements the loss recovery phase as described in RFC6675 +// section 5, step C. 
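+//
+// In outline: while the congestion window has room, pick the next segment to
+// transmit via NextSeg() (rules (1)-(4) above), send it, and update pipe,
+// HighRxt and RescueRxt as prescribed by steps C.1-C.5.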
+func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { + s.SetPipe() + + if smss := int(s.ep.scoreboard.SMSS()); limit > smss { + // Cap segment size limit to s.smss as SACK recovery requires + // that all retransmissions or new segments send during recovery + // be of <= SMSS. + limit = smss + } + + nextSegHint := s.writeList.Front() + for s.outstanding < s.sndCwnd { + var nextSeg *segment + var rescueRtx bool + nextSeg, nextSegHint, rescueRtx = s.NextSeg(nextSegHint) + if nextSeg == nil { + return dataSent + } + if !s.isAssignedSequenceNumber(nextSeg) || s.sndNxt.LessThanEq(nextSeg.sequenceNumber) { + // New data being sent. + + // Step C.3 described below is handled by + // maybeSendSegment which increments sndNxt when + // a segment is transmitted. + // + // Step C.3 "If any of the data octets sent in + // (C.1) are above HighData, HighData must be + // updated to reflect the transmission of + // previously unsent data." + // + // We pass s.smss as the limit as the Step 2) requires that + // new data sent should be of size s.smss or less. + if sent := s.maybeSendSegment(nextSeg, limit, end); !sent { + return dataSent + } + dataSent = true + s.outstanding++ + s.writeNext = nextSeg.Next() + continue + } + + // Now handle the retransmission case where we matched either step 1,3 or 4 + // of the NextSeg algorithm. + // RFC 6675, Step C.4. + // + // "The estimate of the amount of data outstanding in the network + // must be updated by incrementing pipe by the number of octets + // transmitted in (C.1)." + s.outstanding++ + dataSent = true + s.sendSegment(nextSeg) + + segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) + if rescueRtx { + // We do the last part of rule (4) of NextSeg here to update + // RescueRxt as until this point we don't know if we are going + // to use the rescue transmission. + s.fr.rescueRxt = s.fr.last + } else { + // RFC 6675, Step C.2 + // + // "If any of the data octets sent in (C.1) are below + // HighData, HighRxt MUST be set to the highest sequence + // number of the retransmitted segment unless NextSeg () + // rule (4) was invoked for this retransmission." + s.fr.highRxt = segEnd - 1 + } + } + return dataSent +} + +func (s *sender) sendZeroWindowProbe() { + ack, win := s.ep.rcv.getSendParams() + s.unackZeroWindowProbes++ + // Send a zero window probe with sequence number pointing to + // the last acknowledged byte. + s.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, s.sndUna-1, ack, win) + // Rearm the timer to continue probing. + s.resendTimer.enable(s.rto) +} + +func (s *sender) enableZeroWindowProbing() { + s.zeroWindowProbing = true + // We piggyback the probing on the retransmit timer with the + // current retranmission interval, as we may start probing while + // segment retransmissions. + if s.firstRetransmittedSegXmitTime.IsZero() { + s.firstRetransmittedSegXmitTime = time.Now() + } + s.resendTimer.enable(s.rto) +} + +func (s *sender) disableZeroWindowProbing() { + s.zeroWindowProbing = false + s.unackZeroWindowProbes = 0 + s.firstRetransmittedSegXmitTime = time.Time{} + s.resendTimer.disable() +} + +// sendData sends new data segments. It is called when data becomes available or +// when the send window opens up. +func (s *sender) sendData() { + limit := s.maxPayloadSize + if s.gso { + limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize) + } + end := s.sndUna.Add(s.sndWnd) + + // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. 
+ // "A TCP SHOULD set cwnd to no more than RW before beginning + // transmission if the TCP has not sent data in the interval exceeding + // the retrasmission timeout." + if !s.fr.active && s.state != RTORecovery && time.Now().Sub(s.lastSendTime) > s.rto { + if s.sndCwnd > InitialCwnd { + s.sndCwnd = InitialCwnd + } + } + + var dataSent bool + + // RFC 6675 recovery algorithm step C 1-5. + if s.fr.active && s.ep.sackPermitted { + dataSent = s.handleSACKRecovery(s.maxPayloadSize, end) + } else { + for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + if cwndLimit < limit { + limit = cwndLimit + } + if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + // Move writeNext along so that we don't try and scan data that + // has already been SACKED. + s.writeNext = seg.Next() + continue + } + if sent := s.maybeSendSegment(seg, limit, end); !sent { + break + } + dataSent = true + s.outstanding += s.pCount(seg) + s.writeNext = seg.Next() + } + } + + if dataSent { + // We sent data, so we should stop the keepalive timer to ensure + // that no keepalives are sent while there is pending data. + s.ep.disableKeepaliveTimer() + } + + // If the sender has advertized zero receive window and we have + // data to be sent out, start zero window probing to query the + // the remote for it's receive window size. + if s.writeNext != nil && s.sndWnd == 0 { + s.enableZeroWindowProbing() + } + + // Enable the timer if we have pending data and it's not enabled yet. + if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { + s.resendTimer.enable(s.rto) + } + // If we have no more pending data, start the keepalive timer. + if s.sndUna == s.sndNxt { + s.ep.resetKeepaliveTimer(false) + } +} + +func (s *sender) enterFastRecovery() { + s.fr.active = true + // Save state to reflect we're now in fast recovery. + // + // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. + // We inflate the cwnd by 3 to account for the 3 packets which triggered + // the 3 duplicate ACKs and are now not in flight. + s.sndCwnd = s.sndSsthresh + 3 + s.fr.first = s.sndUna + s.fr.last = s.sndNxt - 1 + s.fr.maxCwnd = s.sndCwnd + s.outstanding + s.fr.highRxt = s.sndUna + s.fr.rescueRxt = s.sndUna + if s.ep.sackPermitted { + s.state = SACKRecovery + s.ep.stack.Stats().TCP.SACKRecovery.Increment() + return + } + s.state = FastRecovery + s.ep.stack.Stats().TCP.FastRecovery.Increment() +} + +func (s *sender) leaveFastRecovery() { + s.fr.active = false + s.fr.maxCwnd = 0 + s.dupAckCount = 0 + + // Deflate cwnd. It had been artificially inflated when new dups arrived. + s.sndCwnd = s.sndSsthresh + + s.cc.PostRecovery() +} + +func (s *sender) handleFastRecovery(seg *segment) (rtx bool) { + ack := seg.ackNumber + // We are in fast recovery mode. Ignore the ack if it's out of + // range. + if !ack.InRange(s.sndUna, s.sndNxt+1) { + return false + } + + // Leave fast recovery if it acknowledges all the data covered by + // this fast recovery session. + if s.fr.last.LessThan(ack) { + s.leaveFastRecovery() + return false + } + + if s.ep.sackPermitted { + // When SACK is enabled we let retransmission be governed by + // the SACK logic. + return false + } + + // Don't count this as a duplicate if it is carrying data or + // updating the window. 
+	if seg.logicalLen() != 0 || s.sndWnd != seg.window {
+		return false
+	}
+
+	// Inflate the congestion window if we're getting duplicate acks
+	// for the packet we retransmitted.
+	if ack == s.fr.first {
+		// We received a dup, inflate the congestion window by 1 packet
+		// if we're not at the max yet. Only inflate the window if
+		// regular FastRecovery is in use, RFC6675 does not require
+		// inflating cwnd on duplicate ACKs.
+		if s.sndCwnd < s.fr.maxCwnd {
+			s.sndCwnd++
+		}
+		return false
+	}
+
+	// A partial ack was received. Retransmit this packet and
+	// remember it so that we don't retransmit it again. We don't
+	// inflate the window because we're putting the same packet back
+	// onto the wire.
+	//
+	// N.B. The retransmit timer will be reset by the caller.
+	s.fr.first = ack
+	s.dupAckCount = 0
+	return true
+}
+
+// isAssignedSequenceNumber relies on the fact that we only set flags once a
+// sequence number is assigned and that is only done right before we send the
+// segment. As a result any segment that has a non-zero flag has a valid
+// sequence number assigned to it.
+func (s *sender) isAssignedSequenceNumber(seg *segment) bool {
+	return seg.flags != 0
+}
+
+// SetPipe implements the SetPipe() function described in RFC6675. Netstack
+// maintains the congestion window in number of packets and not bytes, so
+// SetPipe() here measures number of outstanding packets rather than actual
+// outstanding bytes in the network.
+func (s *sender) SetPipe() {
+	// If SACK isn't permitted or it is permitted but recovery is not active
+	// then ignore pipe calculations.
+	if !s.ep.sackPermitted || !s.fr.active {
+		return
+	}
+	pipe := 0
+	smss := seqnum.Size(s.ep.scoreboard.SMSS())
+	for s1 := s.writeList.Front(); s1 != nil && s1.data.Size() != 0 && s.isAssignedSequenceNumber(s1); s1 = s1.Next() {
+		// With GSO each segment can be much larger than SMSS. So check the segment
+		// in SMSS sized ranges.
+		segEnd := s1.sequenceNumber.Add(seqnum.Size(s1.data.Size()))
+		for startSeq := s1.sequenceNumber; startSeq.LessThan(segEnd); startSeq = startSeq.Add(smss) {
+			endSeq := startSeq.Add(smss)
+			if segEnd.LessThan(endSeq) {
+				endSeq = segEnd
+			}
+			sb := header.SACKBlock{startSeq, endSeq}
+			// SetPipe():
+			//
+			// After initializing pipe to zero, the following steps are
+			// taken for each octet 'S1' in the sequence space between
+			// HighACK and HighData that has not been SACKed:
+			if !s1.sequenceNumber.LessThan(s.sndNxt) {
+				break
+			}
+			if s.ep.scoreboard.IsSACKED(sb) {
+				continue
+			}
+
+			// SetPipe():
+			//
+			// (a) If IsLost(S1) returns false, Pipe is incremented by 1.
+			//
+			// NOTE: here we mark the whole segment as lost. We do not try
+			// and test every byte in our write buffer as we maintain our
+			// pipe in terms of outstanding packets and not bytes.
+			if !s.ep.scoreboard.IsRangeLost(sb) {
+				pipe++
+			}
+			// SetPipe():
+			// (b) If S1 <= HighRxt, Pipe is incremented by 1.
+			if s1.sequenceNumber.LessThanEq(s.fr.highRxt) {
+				pipe++
+			}
+		}
+	}
+	s.outstanding = pipe
+}
+
+// checkDuplicateAck is called when an ack is received. It manages the state
+// related to duplicate acks and determines if a retransmit is needed according
+// to the rules in RFC 6582 (NewReno).
+func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
+	ack := seg.ackNumber
+	if s.fr.active {
+		return s.handleFastRecovery(seg)
+	}
+
+	// We're not in fast recovery yet. 
+	// only if it doesn't carry any data and doesn't update the send window,
+	// because if it does, it wasn't sent in response to an out-of-order
+	// segment. If SACK is enabled then we have an additional check to see
+	// if the segment carries new SACK information. If it does then it is
+	// considered a duplicate ACK as per RFC6675.
+	if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt {
+		if !s.ep.sackPermitted || !seg.hasNewSACKInfo {
+			s.dupAckCount = 0
+			return false
+		}
+	}
+
+	s.dupAckCount++
+
+	// Do not enter fast recovery until we reach nDupAckThreshold or the
+	// first unacknowledged byte is considered lost as per SACK scoreboard.
+	if s.dupAckCount < nDupAckThreshold || (s.ep.sackPermitted && !s.ep.scoreboard.IsLost(s.sndUna)) {
+		// RFC 6675 Step 3.
+		s.fr.highRxt = s.sndUna - 1
+		// Do run SetPipe() to calculate the outstanding segments.
+		s.SetPipe()
+		s.state = Disorder
+		return false
+	}
+
+	// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
+	//
+	// We only do the check here, the incrementing of last to the highest
+	// sequence number transmitted so far is done when enterFastRecovery
+	// is invoked.
+	if !s.fr.last.LessThan(seg.ackNumber) {
+		s.dupAckCount = 0
+		return false
+	}
+	s.cc.HandleNDupAcks()
+	s.enterFastRecovery()
+	s.dupAckCount = 0
+	return true
+}
+
+// handleRcvdSegment is called when a segment is received; it is responsible for
+// updating the send-related state.
+func (s *sender) handleRcvdSegment(seg *segment) {
+	// Check if we can extract an RTT measurement from this ack.
+	if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+		s.updateRTO(time.Now().Sub(s.rttMeasureTime))
+		s.rttMeasureSeqNum = s.sndNxt
+	}
+
+	// Update Timestamp if required. See RFC7323, section-4.3.
+	if s.ep.sendTSOk && seg.parsedOptions.TS {
+		s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+	}
+
+	// Insert SACKBlock information into our scoreboard.
+	if s.ep.sackPermitted {
+		for _, sb := range seg.parsedOptions.SACKBlocks {
+			// Only insert the SACK block if the following holds
+			// true:
+			//  * SACK block acks data after the ack number in the
+			//    current segment.
+			//  * SACK block represents a sequence
+			//    between sndUna and sndNxt (i.e. data that is
+			//    currently unacked and in-flight).
+			//  * SACK block that has not been SACKed already.
+			//
+			// NOTE: This check specifically excludes DSACK blocks
+			// which have start/end before sndUna and are used to
+			// indicate spurious retransmissions.
+			if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+				s.ep.scoreboard.Insert(sb)
+				seg.hasNewSACKInfo = true
+			}
+		}
+		s.SetPipe()
+	}
+
+	// Count the duplicates and do the fast retransmit if needed.
+	rtx := s.checkDuplicateAck(seg)
+
+	// Stash away the current window size.
+	s.sndWnd = seg.window
+
+	ack := seg.ackNumber
+
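// ---------------------------------------------------------------------------
// [Editor's sketch: illustrative only, not part of this change.] The
// duplicate-ACK test in checkDuplicateAck above can be restated as a single
// predicate. The struct and field names are hypothetical stand-ins for the
// values the sender actually consults.
//
//	type ackInfo struct {
//		ackNum, sndUna, sndNxt seqnum.Value
//		payloadLen             int
//		window, sndWnd         seqnum.Size
//		sackPermitted, newSACK bool
//	}
//
//	func isDupAck(a ackInfo) bool {
//		// Classic rule: acknowledges nothing new, carries no payload,
//		// and leaves the advertised window unchanged...
//		classic := a.ackNum == a.sndUna && a.payloadLen == 0 &&
//			a.window == a.sndWnd && a.ackNum != a.sndNxt
//		// ...or, with SACK negotiated, any ACK carrying new SACK
//		// information (RFC 6675).
//		return classic || (a.sackPermitted && a.newSACK)
//	}
// ---------------------------------------------------------------------------
+	// Disable zero window probing if the remote advertises a non-zero
+	// receive window. This can be with an ACK to the zero window probe
+	// (where the ack number refers to the already acknowledged byte) OR to
+	// any previously unacknowledged segment.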
+	if s.zeroWindowProbing && seg.window > 0 &&
+		(ack == s.sndUna || (ack-1).InRange(s.sndUna, s.sndNxt)) {
+		s.disableZeroWindowProbing()
+	}
+
+	// On receiving the ACK for the zero window probe, account for it and
+	// skip trying to send any segment as we are still probing for
+	// receive window to become non-zero.
+	if s.zeroWindowProbing && s.unackZeroWindowProbes > 0 && ack == s.sndUna {
+		s.unackZeroWindowProbes--
+		return
+	}
+
+	// Ignore ack if it doesn't acknowledge any new data.
+	if (ack - 1).InRange(s.sndUna, s.sndNxt) {
+		s.dupAckCount = 0
+
+		// See: https://tools.ietf.org/html/rfc1323#section-3.3.
+		// Specifically we should only update the RTO using TSEcr if the
+		// following condition holds:
+		//
+		//    A TSecr value received in a segment is used to update the
+		//    averaged RTT measurement only if the segment acknowledges
+		//    some new data, i.e., only if it advances the left edge of
+		//    the send window.
+		if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+			// TSVal/Ecr values sent by Netstack are at a millisecond
+			// granularity.
+			elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+			s.updateRTO(elapsed)
+		}
+
+		// When an ack is received we must rearm the timer.
+		// RFC 6298 5.3
+		s.resendTimer.enable(s.rto)
+
+		// Remove all acknowledged data from the write list.
+		acked := s.sndUna.Size(ack)
+		s.sndUna = ack
+
+		ackLeft := acked
+		originalOutstanding := s.outstanding
+		for ackLeft > 0 {
+			// We use logicalLen here because we can have FIN
+			// segments (which are always at the end of list) that
+			// have no data, but do consume a sequence number.
+			seg := s.writeList.Front()
+			datalen := seg.logicalLen()
+
+			if datalen > ackLeft {
+				prevCount := s.pCount(seg)
+				seg.data.TrimFront(int(ackLeft))
+				seg.sequenceNumber.UpdateForward(ackLeft)
+				s.outstanding -= prevCount - s.pCount(seg)
+				break
+			}
+
+			if s.writeNext == seg {
+				s.writeNext = seg.Next()
+			}
+
+			s.writeList.Remove(seg)
+
+			// If SACK is enabled then only reduce outstanding if
+			// the segment was not previously SACKED as these have
+			// already been accounted for in SetPipe().
+			if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+				s.outstanding -= s.pCount(seg)
+			}
+			seg.decRef()
+			ackLeft -= datalen
+		}
+
+		// Update the send buffer usage and notify potential waiters.
+		s.ep.updateSndBufferUsage(int(acked))
+
+		// Clear SACK information for all acked data.
+		s.ep.scoreboard.Delete(s.sndUna)
+
+		// If we are not in fast recovery then update the congestion
+		// window based on the number of acknowledged packets.
+		if !s.fr.active {
+			s.cc.Update(originalOutstanding - s.outstanding)
+			if s.fr.last.LessThan(s.sndUna) {
+				s.state = Open
+			}
+		}
+
+		// It is possible for s.outstanding to drop below zero if we get
+		// a retransmit timeout, reset outstanding to zero, but later
+		// get an ack that covers previously sent data.
+		if s.outstanding < 0 {
+			s.outstanding = 0
+		}
+
+		s.SetPipe()
+
+		// If all outstanding data was acknowledged then disable the timer.
+		// RFC 6298 Rule 5.3
+		if s.sndUna == s.sndNxt {
+			s.outstanding = 0
+			// Reset firstRetransmittedSegXmitTime to the zero value.
+			s.firstRetransmittedSegXmitTime = time.Time{}
+			s.resendTimer.disable()
+		}
+	}
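// ---------------------------------------------------------------------------
// [Editor's sketch: illustrative only, not part of this change.] The
// TSEcr-based RTT sample above is just a subtraction at millisecond
// granularity: TSVal and TSEcr come from the same millisecond clock, so the
// elapsed time is (now - TSEcr) scaled to a time.Duration. nowTS is a
// hypothetical stand-in for the endpoint's timestamp() clock.
//
//	func rttFromTSEcr(nowTS, tsEcr uint32) time.Duration {
//		// Unsigned subtraction handles wraparound of the 32-bit
//		// timestamp clock.
//		return time.Duration(nowTS-tsEcr) * time.Millisecond
//	}
// ---------------------------------------------------------------------------
+	// Now that we've popped all acknowledged data from the retransmit
+	// queue, retransmit if needed.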
+	if rtx {
+		s.resendSegment()
+	}
+
+	// Send more data now that some of the pending data has been ack'd, or
+	// that the window opened up, or the congestion window was inflated due
+	// to a duplicate ack during fast recovery. This will also re-enable
+	// the retransmit timer if needed.
+	if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+		s.sendData()
+	}
+}
+
+// sendSegment sends the specified segment.
+func (s *sender) sendSegment(seg *segment) *tcpip.Error {
+	if seg.xmitCount > 0 {
+		s.ep.stack.Stats().TCP.Retransmits.Increment()
+		s.ep.stats.SendErrors.Retransmits.Increment()
+		if s.sndCwnd < s.sndSsthresh {
+			s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
+		}
+	}
+	seg.xmitTime = time.Now()
+	seg.xmitCount++
+	err := s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
+
+	// Every time a packet containing data is sent (including a
+	// retransmission), if SACK is enabled and we are retransmitting data
+	// then use the conservative timer described in RFC6675 Section 6.0,
+	// otherwise follow the standard timer described in RFC6298 Section 5.1.
+	if err != nil && seg.data.Size() != 0 {
+		if s.fr.active && seg.xmitCount > 1 && s.ep.sackPermitted {
+			s.resendTimer.enable(s.rto)
+		} else {
+			if !s.resendTimer.enabled() {
+				s.resendTimer.enable(s.rto)
+			}
+		}
+	}
+
+	return err
+}
+
+// sendSegmentFromView sends a new segment containing the given payload, flags
+// and sequence number.
+func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
+	s.lastSendTime = time.Now()
+	if seq == s.rttMeasureSeqNum {
+		s.rttMeasureTime = s.lastSendTime
+	}
+
+	rcvNxt, rcvWnd := s.ep.rcv.getSendParams()
+
+	// Remember the max sent ack.
+	s.maxSentAck = rcvNxt
+
+	return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
+}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
new file mode 100644
index 000000000..8b20c3455
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -0,0 +1,60 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+	"time"
+)
+
+// +stateify savable
+type unixTime struct {
+	second int64
+	nano   int64
+}
+
+// saveLastSendTime is invoked by stateify.
+func (s *sender) saveLastSendTime() unixTime {
+	return unixTime{s.lastSendTime.Unix(), int64(s.lastSendTime.Nanosecond())}
+}
+
+// loadLastSendTime is invoked by stateify.
+func (s *sender) loadLastSendTime(unix unixTime) {
+	s.lastSendTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveRttMeasureTime is invoked by stateify.
+func (s *sender) saveRttMeasureTime() unixTime {
+	return unixTime{s.rttMeasureTime.Unix(), int64(s.rttMeasureTime.Nanosecond())}
+}
+
+// loadRttMeasureTime is invoked by stateify.
+func (s *sender) loadRttMeasureTime(unix unixTime) {
+	s.rttMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
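// ---------------------------------------------------------------------------
// [Editor's note: illustrative only, not part of this change.] time.Time has
// unexported fields that stateify cannot serialize, so the sender flattens
// times into a (seconds, nanoseconds) pair and rebuilds them with time.Unix
// on load. Note that time.Unix(sec, nsec) *adds* its two arguments, so the
// saved nanosecond component must be the sub-second remainder (Nanosecond()),
// not the absolute UnixNano() value; the save functions above have been
// adjusted accordingly so the round trip is exact:
//
//	t := time.Now()
//	saved := unixTime{t.Unix(), int64(t.Nanosecond())}
//	restored := time.Unix(saved.second, saved.nano) // same instant as t
// ---------------------------------------------------------------------------
+// afterLoad is invoked by stateify.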
+func (s *sender) afterLoad() {
+	s.resendTimer.init(&s.resendWaker)
+}
+
+// saveFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) saveFirstRetransmittedSegXmitTime() unixTime {
+	return unixTime{s.firstRetransmittedSegXmitTime.Unix(), int64(s.firstRetransmittedSegXmitTime.Nanosecond())}
+}
+
+// loadFirstRetransmittedSegXmitTime is invoked by stateify.
+func (s *sender) loadFirstRetransmittedSegXmitTime(unix unixTime) {
+	s.firstRetransmittedSegXmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
new file mode 100644
index 000000000..b9993ce1a
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_noracedetector_test.go
@@ -0,0 +1,550 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+//
+// These tests are flaky when run under the go race detector: some iterations
+// take long enough that the retransmit timer can kick in, causing the
+// congestion window measurements to fail due to extra packets, etc.
+//
+// +build !race
+
+package tcp_test
+
+import (
+	"fmt"
+	"math"
+	"testing"
+	"time"
+
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+	"gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+func TestFastRecovery(t *testing.T) {
+	maxPayload := 32
+	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+	const iterations = 3
+	data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+	for i := range data {
+		data[i] = byte(i)
+	}
+
+	// Write all the data in one shot. Packets will only be written at the
+	// MTU size though.
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	// Do slow start for a few iterations.
+	expected := tcp.InitialCwnd
+	bytesRead := 0
+	for i := 0; i < iterations; i++ {
+		expected = tcp.InitialCwnd << uint(i)
+		if i > 0 {
+			// Acknowledge all the data received so far if not on
+			// first iteration.
+			c.SendAck(790, bytesRead)
+		}
+
+		// Read all packets expected on this iteration. Don't
+		// acknowledge any of them just yet, so that we can measure the
+		// congestion window.
+		for j := 0; j < expected; j++ {
+			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+			bytesRead += maxPayload
+		}
+
+		// Check we don't receive any more packets on this iteration.
+		// The timeout can't be too high or we'll trigger a timeout.
+		c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+	}
+
+	// Send 3 duplicate acks. This should force an immediate retransmit of
+	// the pending packet and put the sender into fast recovery.
+	rtxOffset := bytesRead - maxPayload*expected
+	for i := 0; i < 3; i++ {
+		c.SendAck(790, rtxOffset)
+	}
+
+	// Receive the retransmitted packet.
+	c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+	// Wait before checking metrics.
+	metricPollFn := func() error {
+		if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+			return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
+		}
+		if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want {
+			return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
+		}
+
+		if got, want := c.Stack().Stats().TCP.FastRecovery.Value(), uint64(1); got != want {
+			return fmt.Errorf("got stats.TCP.FastRecovery.Value = %d, want = %d", got, want)
+		}
+		return nil
+	}
+
+	if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+		t.Error(err)
+	}
+
+	// Now send 7 more duplicate acks. Each of these should cause a window
+	// inflation by 1 and cause the sender to send an extra packet.
+	for i := 0; i < 7; i++ {
+		c.SendAck(790, rtxOffset)
+	}
+
+	recover := bytesRead
+
+	// Ensure no new packets arrive.
+	c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
+		50*time.Millisecond)
+
+	// Acknowledge half of the pending data.
+	rtxOffset = bytesRead - expected*maxPayload/2
+	c.SendAck(790, rtxOffset)
+
+	// Receive the retransmit due to partial ack.
+	c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+	// Wait before checking metrics.
+	metricPollFn = func() error {
+		if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(2); got != want {
+			return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
+		}
+		if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(2); got != want {
+			return fmt.Errorf("got stats.TCP.Retransmit.Value = %d, want = %d", got, want)
+		}
+		return nil
+	}
+	if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+		t.Error(err)
+	}
+
+	// Receive the 10 extra packets that should have been released due to
+	// the congestion window inflation in recovery.
+	for i := 0; i < 10; i++ {
+		c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+		bytesRead += maxPayload
+	}
+
+	// A partial ACK during recovery should reduce congestion window by the
+	// number acked. Since we had "expected" packets outstanding before
+	// sending the partial ack and we acked expected/2, the cwnd and
+	// outstanding should be expected/2 + 10 (7 dupAcks + 3 for the original
+	// 3 dupacks that triggered fast recovery), which means the sender
+	// should not send any more packets until we ack this one.
+	c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
+		50*time.Millisecond)
+
+	// Acknowledge all pending data to recover point.
+	c.SendAck(790, recover)
+
+	// At this point, the cwnd should reset to expected/2 and there are 10
+	// packets outstanding.
+	//
+	// NOTE: Technically netstack is incorrect in that we adjust the cwnd on
+	// the same segment that takes us out of recovery. But because of that
+	// the actual cwnd at exit of recovery will be expected/2 + 1 as we
+	// acked a cwnd worth of packets which will increase the cwnd further by
+	// 1 in congestion avoidance.
+	//
+	// Now in the first iteration, since there are 10 packets outstanding,
+	// we would expect to get expected/2 + 1 - 10 packets. But subsequent
+	// iterations will send us expected/2 + 1 + 1 (per iteration).
+	expected = expected/2 + 1 - 10
+	for i := 0; i < iterations; i++ {
+		// Read all packets expected on this iteration. Don't
+		// acknowledge any of them just yet, so that we can measure the
+		// congestion window.
+		for j := 0; j < expected; j++ {
+			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+			bytesRead += maxPayload
+		}
+
+		// Check we don't receive any more packets on this iteration.
+		// The timeout can't be too high or we'll trigger a timeout.
+		c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond)
+
+		// Acknowledge all the data received so far.
+		c.SendAck(790, bytesRead)
+
+		// In congestion avoidance, the packet trains increase by 1 in
+		// each iteration.
+		if i == 0 {
+			// After the first iteration we expect to get the full
+			// congestion window worth of packets in every
+			// iteration.
+			expected += 10
+		}
+		expected++
+	}
+}
+
+func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
+	maxPayload := 32
+	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+	const iterations = 3
+	data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+	for i := range data {
+		data[i] = byte(i)
+	}
+
+	// Write all the data in one shot. Packets will only be written at the
+	// MTU size though.
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	expected := tcp.InitialCwnd
+	bytesRead := 0
+	for i := 0; i < iterations; i++ {
+		// Read all packets expected on this iteration. Don't
+		// acknowledge any of them just yet, so that we can measure the
+		// congestion window.
+		for j := 0; j < expected; j++ {
+			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+			bytesRead += maxPayload
+		}
+
+		// Check we don't receive any more packets on this iteration.
+		// The timeout can't be too high or we'll trigger a timeout.
+		c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+
+		// Acknowledge all the data received so far.
+		c.SendAck(790, bytesRead)
+
+		// Double the number of expected packets for the next iteration.
+		expected *= 2
+	}
+}
+
+func TestCongestionAvoidance(t *testing.T) {
+	maxPayload := 32
+	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+	const iterations = 3
+	data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+	for i := range data {
+		data[i] = byte(i)
+	}
+
+	// Write all the data in one shot. Packets will only be written at the
+	// MTU size though.
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	// Do slow start for a few iterations.
+	expected := tcp.InitialCwnd
+	bytesRead := 0
+	for i := 0; i < iterations; i++ {
+		expected = tcp.InitialCwnd << uint(i)
+		if i > 0 {
+			// Acknowledge all the data received so far if not on
+			// first iteration.
+			c.SendAck(790, bytesRead)
+		}
+
+		// Read all packets expected on this iteration. Don't
+		// acknowledge any of them just yet, so that we can measure the
+		// congestion window.
+		for j := 0; j < expected; j++ {
+			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+			bytesRead += maxPayload
+		}
+
+		// Check we don't receive any more packets on this iteration.
+		// The timeout can't be too high or we'll trigger a timeout.
+		c.CheckNoPacketTimeout("More packets received than expected for this cwnd (slow start phase).", 50*time.Millisecond)
+	}
+
+	// Don't acknowledge the first packet of the last packet train. Let's
+	// wait for them to time out, which will trigger a restart of slow
+	// start, and initialization of ssthresh to cwnd/2.
+	rtxOffset := bytesRead - maxPayload*expected
+	c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+	// Acknowledge all the data received so far.
+	c.SendAck(790, bytesRead)
+
+	// This part is tricky: when the timeout happened, we had "expected"
+	// packets pending, cwnd reset to 1, and ssthresh set to expected/2.
+	// By acknowledging "expected" packets, the slow-start part will
+	// increase cwnd to expected/2 (which "consumes" expected/2-1 of the
+	// acknowledgements), then the congestion avoidance part will consume
+	// an extra expected/2 acks to take cwnd to expected/2 + 1. One ack
+	// remains in the "ack count" (which will cause cwnd to be incremented
+	// once it reaches cwnd acks).
+	//
+	// So we're straight into congestion avoidance with cwnd set to
+	// expected/2 + 1.
+	//
+	// Check that packet trains of cwnd packets are sent, and that cwnd is
+	// incremented by 1 after we acknowledge each packet.
+	expected = expected/2 + 1
+	for i := 0; i < iterations; i++ {
+		// Read all packets expected on this iteration. Don't
+		// acknowledge any of them just yet, so that we can measure the
+		// congestion window.
+		for j := 0; j < expected; j++ {
+			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+			bytesRead += maxPayload
+		}
+
+		// Check we don't receive any more packets on this iteration.
+		// The timeout can't be too high or we'll trigger a timeout.
+		c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance phase).", 50*time.Millisecond)
+
+		// Acknowledge all the data received so far.
+		c.SendAck(790, bytesRead)
+
+		// In congestion avoidance, the packet trains increase by 1 in
+		// each iteration.
+		expected++
+	}
+}
+
+// cubicCwnd returns an estimate of a cubic window given the
+// originalCwnd, wMax, last congestion event time and sRTT.
+func cubicCwnd(origCwnd int, wMax int, congEventTime time.Time, sRTT time.Duration) int {
+	cwnd := float64(origCwnd)
+	// We wait 50ms between each iteration so sRTT as computed by cubic
+	// should be close to 50ms.
+	elapsed := (time.Since(congEventTime) + sRTT).Seconds()
+	k := math.Cbrt(float64(wMax) * 0.3 / 0.7)
+	wtRTT := 0.4*math.Pow(elapsed-k, 3) + float64(wMax)
+	cwnd += (wtRTT - cwnd) / cwnd
+	return int(cwnd)
+}
+
+func TestCubicCongestionAvoidance(t *testing.T) {
+	maxPayload := 32
+	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+	defer c.Cleanup()
+
+	enableCUBIC(t, c)
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+	const iterations = 3
+	data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+
+	for i := range data {
+		data[i] = byte(i)
+	}
+
+	// Write all the data in one shot. Packets will only be written at the
+	// MTU size though.
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
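// ---------------------------------------------------------------------------
// [Editor's note: illustrative worked example, not part of this change.]
// cubicCwnd above evaluates the CUBIC growth curve
//
//	W(t) = C*(t - K)^3 + Wmax, with C = 0.4,
//
// using K = cbrt(wMax * 0.3 / 0.7) as the test writes it (standard CUBIC,
// per RFC 8312, uses K = cbrt(Wmax*(1-beta)/C) with beta = 0.7). With a
// hypothetical wMax of 80 packets:
//
//	K = cbrt(80 * 0.3 / 0.7) = cbrt(34.3) which is about 3.25 seconds,
//
// so the estimated window sits below wMax just after the congestion event,
// returns to wMax roughly 3.25s later, and grows cubically beyond it after
// that. The final `cwnd += (wtRTT - cwnd) / cwnd` step nudges the running
// estimate toward the curve rather than jumping straight onto it.
// ---------------------------------------------------------------------------
+	// Do slow start for a few iterations.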
+	expected := tcp.InitialCwnd
+	bytesRead := 0
+	for i := 0; i < iterations; i++ {
+		expected = tcp.InitialCwnd << uint(i)
+		if i > 0 {
+			// Acknowledge all the data received so far if not on
+			// first iteration.
+			c.SendAck(790, bytesRead)
+		}
+
+		// Read all packets expected on this iteration. Don't
+		// acknowledge any of them just yet, so that we can measure the
+		// congestion window.
+		for j := 0; j < expected; j++ {
+			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+			bytesRead += maxPayload
+		}
+
+		// Check we don't receive any more packets on this iteration.
+		// The timeout can't be too high or we'll trigger a timeout.
+		c.CheckNoPacketTimeout("More packets received than expected for this cwnd (during slow-start phase).", 50*time.Millisecond)
+	}
+
+	// Don't acknowledge the first packet of the last packet train. Let's
+	// wait for them to time out, which will trigger a restart of slow
+	// start, and initialization of ssthresh to cwnd * 0.7.
+	rtxOffset := bytesRead - maxPayload*expected
+	c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)
+
+	// Acknowledge all pending data.
+	c.SendAck(790, bytesRead)
+
+	// Store away the time we sent the ACK; assuming a 200ms RTO, we
+	// estimate that the sender will have an RTO 200ms from now and go back
+	// into slow start.
+	packetDropTime := time.Now().Add(200 * time.Millisecond)
+
+	// This part is tricky: when the timeout happened, we had "expected"
+	// packets pending, cwnd reset to 1, and ssthresh set to expected * 0.7.
+	// By acknowledging "expected" packets, the slow-start part will
+	// increase cwnd to expected/2, essentially putting the connection
+	// straight into congestion avoidance.
+	wMax := expected
+	// Lower expected as per cubic spec after a congestion event.
+	expected = int(float64(expected) * 0.7)
+	cwnd := expected
+	for i := 0; i < iterations; i++ {
+		// Cubic grows the window independently of ACKs. Cubic window
+		// growth is a function of the time elapsed since the last
+		// congestion event. As a result the congestion window does not
+		// grow deterministically in response to ACKs.
+		//
+		// We need to roughly estimate what the cwnd of the sender is
+		// based on when we sent the dupacks.
+		cwnd := cubicCwnd(cwnd, wMax, packetDropTime, 50*time.Millisecond)
+
+		packetsExpected := cwnd
+		for j := 0; j < packetsExpected; j++ {
+			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
+			bytesRead += maxPayload
+		}
+		t.Logf("expected packets received, next trying to receive any extra packets that may come")
+
+		// If our estimate was correct there should be no more pending packets.
+		// We attempt to read a packet a few times with a short sleep in between
+		// to ensure that we don't see the sender send any unexpected packets.
+		unexpectedPackets := 0
+		for {
+			gotPacket := c.ReceiveNonBlockingAndCheckPacket(data, bytesRead, maxPayload)
+			if !gotPacket {
+				break
+			}
+			bytesRead += maxPayload
+			unexpectedPackets++
+			time.Sleep(1 * time.Millisecond)
+		}
+		if unexpectedPackets != 0 {
+			t.Fatalf("received %d unexpected packets for iteration %d", unexpectedPackets, i)
+		}
+		// Check we don't receive any more packets on this iteration.
+		// The timeout can't be too high or we'll trigger a timeout.
+		c.CheckNoPacketTimeout("More packets received than expected for this cwnd (congestion avoidance)", 5*time.Millisecond)
+
+		// Acknowledge all the data received so far.
+ c.SendAck(790, bytesRead) + } +} + +func TestRetransmit(t *testing.T) { + maxPayload := 32 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + const iterations = 3 + data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + for i := range data { + data[i] = byte(i) + } + + // Write all the data in two shots. Packets will only be written at the + // MTU size though. + half := data[:len(data)/2] + if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + half = data[len(data)/2:] + if _, _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + } + + // Wait for a timeout and retransmit. + rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + metricPollFn := func() error { + if got, want := c.Stack().Stats().TCP.Timeouts.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Timeouts.Value = %d, want = %d", got, want) + } + + if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Timeouts.Value(), uint64(1); got != want { + return fmt.Errorf("got EP SendErrors.Timeouts.Value = %d, want = %d", got, want) + } + + if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got EP stats SendErrors.Retransmits.Value = %d, want = %d", got, want) + } + + if got, want := c.Stack().Stats().TCP.SlowStartRetransmits.Value(), uint64(1); got != want { + return fmt.Errorf("got stats.TCP.SlowStartRetransmits.Value = %d, want = %d", got, want) + } + + return nil + } + + // Poll when checking metrics. + if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil { + t.Error(err) + } + + // Acknowledge half of the pending data. + rtxOffset = bytesRead - expected*maxPayload/2 + c.SendAck(790, rtxOffset) + + // Receive the remaining data, making sure that acknowledged data is not + // retransmitted. 
+	for offset := rtxOffset; offset < len(data); offset += maxPayload {
+		c.ReceiveAndCheckPacket(data, offset, maxPayload)
+		c.SendAck(790, offset+maxPayload)
+	}
+
+	c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
new file mode 100644
index 000000000..99521f0c1
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -0,0 +1,589 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+	"fmt"
+	"log"
+	"reflect"
+	"testing"
+	"time"
+
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+	"gvisor.dev/gvisor/pkg/test/testutil"
+)
+
+// createConnectedWithSACKPermittedOption creates and connects c.ep with the
+// SACKPermitted option enabled if the stack in the context has SACK support
+// enabled.
+func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
+	return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
+}
+
+// createConnectedWithSACKAndTS creates and connects c.ep with the SACK & TS
+// option enabled if the stack in the context has SACK and TS enabled.
+func createConnectedWithSACKAndTS(c *context.Context) *context.RawEndpoint {
+	return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled(), TS: true})
+}
+
+func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
+	t.Helper()
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
+		t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%t)) = %s", enable, err)
+	}
+}
+
+// TestSackPermittedConnect establishes a connection with the SACK option
+// enabled.
+func TestSackPermittedConnect(t *testing.T) {
+	for _, sackEnabled := range []bool{false, true} {
+		t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+
+			setStackSACKPermitted(t, c, sackEnabled)
+			rep := createConnectedWithSACKPermittedOption(c)
+			data := []byte{1, 2, 3}
+
+			rep.SendPacket(data, nil)
+			savedSeqNum := rep.NextSeqNum
+			rep.VerifyACKNoSACK()
+
+			// Make an out of order packet and send it.
+			rep.NextSeqNum += 3
+			sackBlocks := []header.SACKBlock{
+				{rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
+			}
+			rep.SendPacket(data, nil)
+
+			// Restore the saved sequence number so that the
+			// VerifyXXX calls use the right sequence number for
+			// checking ACK numbers.
+			rep.NextSeqNum = savedSeqNum
+			if sackEnabled {
+				rep.VerifyACKHasSACK(sackBlocks)
+			} else {
+				rep.VerifyACKNoSACK()
+			}
+
+			// Send the missing segment.
+			rep.SendPacket(data, nil)
+			// The ACK should contain the cumulative ACK for all 9
+			// bytes sent and no SACK blocks.
+			rep.NextSeqNum += 3
+			// Check that no SACK block is returned in the ACK.
+			rep.VerifyACKNoSACK()
+		})
+	}
+}
+
+// TestSackDisabledConnect establishes a connection with the SACK option
+// disabled and verifies that no SACKs are sent for out of order segments.
+func TestSackDisabledConnect(t *testing.T) {
+	for _, sackEnabled := range []bool{false, true} {
+		t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) {
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+
+			setStackSACKPermitted(t, c, sackEnabled)
+
+			rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
+
+			data := []byte{1, 2, 3}
+
+			rep.SendPacket(data, nil)
+			savedSeqNum := rep.NextSeqNum
+			rep.VerifyACKNoSACK()
+
+			// Make an out of order packet and send it.
+			rep.NextSeqNum += 3
+			rep.SendPacket(data, nil)
+
+			// The ACK should contain the older sequence number and
+			// no SACK blocks.
+			rep.NextSeqNum = savedSeqNum
+			rep.VerifyACKNoSACK()
+
+			// Send the missing segment.
+			rep.SendPacket(data, nil)
+			// The ACK should contain the cumulative ACK for all 9
+			// bytes sent and no SACK blocks.
+			rep.NextSeqNum += 3
+			// Check that no SACK block is returned in the ACK.
+			rep.VerifyACKNoSACK()
+		})
+	}
+}
+
+// TestSackPermittedAccept accepts and establishes a connection with the
+// SACKPermitted option enabled if the connection request specifies the
+// SACKPermitted option. When SYN cookies are in use, SACK should be disabled
+// as we don't encode the SACK information in the cookie.
+func TestSackPermittedAccept(t *testing.T) {
+	type testCase struct {
+		cookieEnabled bool
+		sackPermitted bool
+		wndScale      int
+		wndSize       uint16
+	}
+
+	testCases := []testCase{
+		{true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
+		{false, true, 5, 0x8000},  // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+	}
+
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
+			for _, sackEnabled := range []bool{false, true} {
+				t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
+					c := context.New(t, defaultMTU)
+					defer c.Cleanup()
+
+					if tc.cookieEnabled {
+						// Set the SynRcvd threshold to
+						// zero to force a syn cookie
+						// based accept to happen.
+						if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+							t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+						}
+					}
+					setStackSACKPermitted(t, c, sackEnabled)
+
+					rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
+					// Now verify no SACK blocks are
+					// received when sack is disabled.
+					data := []byte{1, 2, 3}
+					rep.SendPacket(data, nil)
+					rep.VerifyACKNoSACK()
+
+					savedSeqNum := rep.NextSeqNum
+
+					// Make an out of order packet and send
+					// it.
+					rep.NextSeqNum += 3
+					sackBlocks := []header.SACKBlock{
+						{rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
+					}
+					rep.SendPacket(data, nil)
+
+					// The ACK should contain the older
+					// sequence number.
+					rep.NextSeqNum = savedSeqNum
+					if sackEnabled && tc.sackPermitted {
+						rep.VerifyACKHasSACK(sackBlocks)
+					} else {
+						rep.VerifyACKNoSACK()
+					}
+
+					// Send the missing segment.
+					rep.SendPacket(data, nil)
+					// The ACK should contain the cumulative
+					// ACK for all 9 bytes sent and no SACK
+					// blocks.
+					rep.NextSeqNum += 3
+					// Check that no SACK block is returned
+					// in the ACK.
+					rep.VerifyACKNoSACK()
+				})
+			}
+		})
+	}
+}
+
+// TestSackDisabledAccept accepts and establishes a connection with
+// the SACKPermitted option disabled and verifies that no SACKs are
+// sent for out of order packets.
+func TestSackDisabledAccept(t *testing.T) {
+	type testCase struct {
+		cookieEnabled bool
+		wndScale      int
+		wndSize       uint16
+	}
+
+	testCases := []testCase{
+		{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
+		{false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+	}
+
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
+			for _, sackEnabled := range []bool{false, true} {
+				t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
+					c := context.New(t, defaultMTU)
+					defer c.Cleanup()
+
+					if tc.cookieEnabled {
+						// Set the SynRcvd threshold to
+						// zero to force a syn cookie
+						// based accept to happen.
+						if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+							t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+						}
+					}
+
+					setStackSACKPermitted(t, c, sackEnabled)
+
+					rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+					// Now verify no SACK blocks are
+					// received when sack is disabled.
+					data := []byte{1, 2, 3}
+					rep.SendPacket(data, nil)
+					rep.VerifyACKNoSACK()
+					savedSeqNum := rep.NextSeqNum
+
+					// Make an out of order packet and send
+					// it.
+					rep.NextSeqNum += 3
+					rep.SendPacket(data, nil)
+
+					// The ACK should contain the older
+					// sequence number and no SACK blocks.
+					rep.NextSeqNum = savedSeqNum
+					rep.VerifyACKNoSACK()
+
+					// Send the missing segment.
+					rep.SendPacket(data, nil)
+					// The ACK should contain the cumulative
+					// ACK for all 9 bytes sent and no SACK
+					// blocks.
+					rep.NextSeqNum += 3
+					// Check that no SACK block is returned
+					// in the ACK.
+					rep.VerifyACKNoSACK()
+				})
+			}
+		})
+	}
+}
+
+func TestUpdateSACKBlocks(t *testing.T) {
+	testCases := []struct {
+		segStart   seqnum.Value
+		segEnd     seqnum.Value
+		rcvNxt     seqnum.Value
+		sackBlocks []header.SACKBlock
+		updated    []header.SACKBlock
+	}{
+		// Trivial cases where current SACK block list is empty and we
+		// have an out of order delivery.
+		{10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}},
+		{10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}},
+		{10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}},
+
+		// Cases where current SACK block list is not empty and we have
+		// an out of order delivery. Tests that the updated SACK block
+		// list has the first block as the one that contains the new
+		// SACK block representing the segment that was just delivered.
+		{10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}},
+		{24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}},
+		{24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}},
+
+		// Ensure that we only retain header.MaxSACKBlocks and drop the
+		// oldest one if adding a new block exceeds
+		// header.MaxSACKBlocks.
+		{24, 30, 9,
+			[]header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}},
+			[]header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}},
+
+		// Cases where segment extends an existing SACK block.
+ {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, + {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}}, + {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}}, + {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}}, + {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, + {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}}, + {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}}, + {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}}, + + // Cases where segment contains rcvNxt. + {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}}, + } + + for _, tc := range testCases { + var sack tcp.SACKInfo + copy(sack.Blocks[:], tc.sackBlocks) + sack.NumBlocks = len(tc.sackBlocks) + tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt) + if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) { + t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want) + } + + } +} + +func TestTrimSackBlockList(t *testing.T) { + testCases := []struct { + rcvNxt seqnum.Value + sackBlocks []header.SACKBlock + trimmed []header.SACKBlock + }{ + // Simple cases where we trim whole entries. + {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}}, + {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}}, + {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}}, + {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, + // Cases where we need to update a block. + {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}}, + {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}}, + {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}}, + {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, + } + for _, tc := range testCases { + var sack tcp.SACKInfo + copy(sack.Blocks[:], tc.sackBlocks) + sack.NumBlocks = len(tc.sackBlocks) + tcp.TrimSACKBlockList(&sack, tc.rcvNxt) + if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) { + t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want) + } + } +} + +func TestSACKRecovery(t *testing.T) { + const maxPayload = 10 + // See: tcp.makeOptions for why tsOptionSize is set to 12 here. + const tsOptionSize = 12 + // Enabling SACK means the payload size is reduced to account + // for the extra space required for the TCP options. + // + // We increase the MTU by 40 bytes to account for SACK and Timestamp + // options. 
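// ---------------------------------------------------------------------------
// [Editor's note: illustrative arithmetic, not part of this change.] The MTU
// passed to context.New below budgets one full-sized segment:
//
//	IPv4 header (20) + TCP header (20) + maximum TCP option space (40)
//	+ maxPayload (10)
//
// so every data segment fits even with options present. The 12-byte
// tsOptionSize is the 10-byte Timestamp option padded with two NOP bytes to a
// 4-byte boundary, which is why the payload checks in this test pass
// tsOptionSize alongside maxPayload.
// ---------------------------------------------------------------------------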
+	const maxTCPOptionSize = 40
+
+	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxTCPOptionSize+maxPayload))
+	defer c.Cleanup()
+
+	c.Stack().AddTCPProbe(func(s stack.TCPEndpointState) {
+		// We use log.Printf instead of t.Logf here because this probe
+		// can fire even when the test function has finished. This is
+		// because closing the endpoint in cleanup() does not mean the
+		// actual worker loop terminates immediately as it still has to
+		// do a full TCP shutdown. But this test can finish running
+		// before the shutdown is done. Using t.Logf in such a case
+		// causes the test to panic due to logging after the test has
+		// finished.
+		log.Printf("state: %+v\n", s)
+	})
+	setStackSACKPermitted(t, c, true)
+	createConnectedWithSACKAndTS(c)
+
+	const iterations = 3
+	data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
+	for i := range data {
+		data[i] = byte(i)
+	}
+
+	// Write all the data in one shot. Packets will only be written at the
+	// MTU size though.
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	// Do slow start for a few iterations.
+	expected := tcp.InitialCwnd
+	bytesRead := 0
+	for i := 0; i < iterations; i++ {
+		expected = tcp.InitialCwnd << uint(i)
+		if i > 0 {
+			// Acknowledge all the data received so far if not on
+			// first iteration.
+			c.SendAck(790, bytesRead)
+		}
+
+		// Read all packets expected on this iteration. Don't
+		// acknowledge any of them just yet, so that we can measure the
+		// congestion window.
+		for j := 0; j < expected; j++ {
+			c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+			bytesRead += maxPayload
+		}
+
+		// Check we don't receive any more packets on this iteration.
+		// The timeout can't be too high or we'll trigger a timeout.
+		c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
+	}
+
+	// Send 3 duplicate acks. This should force an immediate retransmit of
+	// the pending packet and put the sender into fast recovery.
+	rtxOffset := bytesRead - maxPayload*expected
+	start := c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1)
+	end := start.Add(10)
+	for i := 0; i < 3; i++ {
+		c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}})
+		end = end.Add(10)
+	}
+
+	// Receive the retransmitted packet.
+	c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize)
+
+	metricPollFn := func() error {
+		tcpStats := c.Stack().Stats().TCP
+		stats := []struct {
+			stat *tcpip.StatCounter
+			name string
+			want uint64
+		}{
+			{tcpStats.FastRetransmit, "stats.TCP.FastRetransmit", 1},
+			{tcpStats.Retransmits, "stats.TCP.Retransmits", 1},
+			{tcpStats.SACKRecovery, "stats.TCP.SACKRecovery", 1},
+			{tcpStats.FastRecovery, "stats.TCP.FastRecovery", 0},
+		}
+		for _, s := range stats {
+			if got, want := s.stat.Value(), s.want; got != want {
+				return fmt.Errorf("got %s.Value() = %d, want = %d", s.name, got, want)
+			}
+		}
+		return nil
+	}
+
+	if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+		t.Error(err)
+	}
+
+	// Now send 7 more duplicate ACKs. In SACK TCP dupAcks do not cause
+	// window inflation and sending of packets is completely handled by the
+	// SACK Recovery algorithm. We should see no packets being released, as
+	// the cwnd at this point after entering recovery should be half of the
+	// outstanding number of packets in flight.
+ for i := 0; i < 7; i++ { + c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) + end = end.Add(10) + } + + recover := bytesRead + + // Ensure no new packets arrive. + c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", + 50*time.Millisecond) + + // Acknowledge half of the pending data. This along with the 10 sacked + // segments above should reduce the outstanding below the current + // congestion window allowing the sender to transmit data. + rtxOffset = bytesRead - expected*maxPayload/2 + + // Now send a partial ACK w/ a SACK block that indicates that the next 3 + // segments are lost and we have received 6 segments after the lost + // segments. This should cause the sender to immediately transmit all 3 + // segments in response to this ACK unlike in FastRecovery where only 1 + // segment is retransmitted per ACK. + start = c.IRS.Add(seqnum.Size(rtxOffset) + 30 + 1) + end = start.Add(60) + c.SendAckWithSACK(790, rtxOffset, []header.SACKBlock{{start, end}}) + + // At this point, we acked expected/2 packets and we SACKED 6 packets and + // 3 segments were considered lost due to the SACK block we sent. + // + // So total packets outstanding can be calculated as follows after 7 + // iterations of slow start -> 10/20/40/80/160/320/640. So expected + // should be 640 at start, then we went to recover at which point the + // cwnd should be set to 320 + 3 (for the 3 dupAcks which have left the + // network). + // Outstanding at this point after acking half the window + // (320 packets) will be: + // outstanding = 640-320-6(due to SACK block)-3 = 311 + // + // The last 3 is due to the fact that the first 3 packets after + // rtxOffset will be considered lost due to the SACK blocks sent. + // Receive the retransmit due to partial ack. + + c.ReceiveAndCheckPacketWithOptions(data, rtxOffset, maxPayload, tsOptionSize) + // Receive the 2 extra packets that should have been retransmitted as + // those should be considered lost and immediately retransmitted based + // on the SACK information in the previous ACK sent above. + for i := 0; i < 2; i++ { + c.ReceiveAndCheckPacketWithOptions(data, rtxOffset+maxPayload*(i+1), maxPayload, tsOptionSize) + } + + // Now we should get 9 more new unsent packets as the cwnd is 323 and + // outstanding is 311. + for i := 0; i < 9; i++ { + c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize) + bytesRead += maxPayload + } + + metricPollFn = func() error { + // In SACK recovery only the first segment is fast retransmitted when + // entering recovery. 
+		if got, want := c.Stack().Stats().TCP.FastRetransmit.Value(), uint64(1); got != want {
+			return fmt.Errorf("got stats.TCP.FastRetransmit.Value = %d, want = %d", got, want)
+		}
+
+		if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.FastRetransmit.Value(), uint64(1); got != want {
+			return fmt.Errorf("got EP stats SendErrors.FastRetransmit = %d, want = %d", got, want)
+		}
+
+		if got, want := c.Stack().Stats().TCP.Retransmits.Value(), uint64(4); got != want {
+			return fmt.Errorf("got stats.TCP.Retransmits.Value = %d, want = %d", got, want)
+		}
+
+		if got, want := c.EP.Stats().(*tcp.Stats).SendErrors.Retransmits.Value(), uint64(4); got != want {
+			return fmt.Errorf("got EP stats Stats.SendErrors.Retransmits = %d, want = %d", got, want)
+		}
+		return nil
+	}
+	if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+		t.Error(err)
+	}
+
+	c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", 50*time.Millisecond)
+
+	// Acknowledge all pending data to recover point.
+	c.SendAck(790, recover)
+
+	// At this point, the cwnd should reset to expected/2 and there are 9
+	// packets outstanding.
+	//
+	// Now in the first iteration, since there are 9 packets outstanding,
+	// we would expect to get expected/2 - 9 packets. But subsequent
+	// iterations will send us expected/2 + 1 (per iteration).
+	expected = expected/2 - 9
+	for i := 0; i < iterations; i++ {
+		// Read all packets expected on this iteration. Don't
+		// acknowledge any of them just yet, so that we can measure the
+		// congestion window.
+		for j := 0; j < expected; j++ {
+			c.ReceiveAndCheckPacketWithOptions(data, bytesRead, maxPayload, tsOptionSize)
+			bytesRead += maxPayload
+		}
+		// Check we don't receive any more packets on this iteration.
+		// The timeout can't be too high or we'll trigger a timeout.
+		c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd and iteration: %d.", expected, i), 50*time.Millisecond)
+
+		// Acknowledge all the data received so far.
+		c.SendAck(790, bytesRead)
+
+		// In congestion avoidance, the packet trains increase by 1 in
+		// each iteration.
+		if i == 0 {
+			// After the first iteration we expect to get the full
+			// congestion window worth of packets in every
+			// iteration.
+			expected += 9
+		}
+		expected++
+	}
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
new file mode 100644
index 000000000..e67ec42b1
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -0,0 +1,7258 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp_test
+
+import (
+	"bytes"
+	"fmt"
+	"math"
+	"testing"
+	"time"
+
+	"gvisor.dev/gvisor/pkg/sync"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/checker"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/link/loopback"
+	"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+	"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+	"gvisor.dev/gvisor/pkg/tcpip/ports"
+	"gvisor.dev/gvisor/pkg/tcpip/seqnum"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+	"gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context"
+	"gvisor.dev/gvisor/pkg/test/testutil"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+	// defaultMTU is the MTU, in bytes, used throughout the tests, except
+	// where another value is explicitly used. It is chosen to match the MTU
+	// of loopback interfaces on Linux systems.
+	defaultMTU = 65535
+
+	// defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an
+	// IPv4 endpoint when the MTU is set to defaultMTU in the test.
+	defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
+)
+
+func TestGiveUpConnect(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	var wq waiter.Queue
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+
+	// Register for notification, then start connection attempt.
+	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&waitEntry, waiter.EventOut)
+	defer wq.EventUnregister(&waitEntry)
+
+	if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+		t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+	}
+
+	// Close the connection, wait for completion.
+	ep.Close()
+
+	// Wait for ep to become writable.
+	<-notifyCh
+	if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
+		t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted)
+	}
+
+	// Call Connect again to retrieve the handshake failure status
+	// and stats updates.
+	if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAborted {
+		t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrAborted)
+	}
+
+	if got := c.Stack().Stats().TCP.FailedConnectionAttempts.Value(); got != 1 {
+		t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = 1", got)
+	}
+
+	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+	}
+}
+
+func TestConnectIncrementActiveConnection(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	stats := c.Stack().Stats()
+	want := stats.TCP.ActiveConnectionOpenings.Value() + 1
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+	if got := stats.TCP.ActiveConnectionOpenings.Value(); got != want {
+		t.Errorf("got stats.TCP.ActiveConnectionOpenings.Value() = %d, want = %d", got, want)
+	}
+}
+
+func TestConnectDoesNotIncrementFailedConnectionAttempts(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	stats := c.Stack().Stats()
+	want := stats.TCP.FailedConnectionAttempts.Value()
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+	if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
+		t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
+	}
+	if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
+		t.Errorf("got EP stats.FailedConnectionAttempts = %d, want = %d", got, want)
+	}
+}
+
+func TestActiveFailedConnectionAttemptIncrement(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	stats := c.Stack().Stats()
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	c.EP = ep
+	want := stats.TCP.FailedConnectionAttempts.Value() + 1
+
+	if err := c.EP.Connect(tcpip.FullAddress{NIC: 2, Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrNoRoute {
+		t.Errorf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrNoRoute)
+	}
+
+	if got := stats.TCP.FailedConnectionAttempts.Value(); got != want {
+		t.Errorf("got stats.TCP.FailedConnectionAttempts.Value() = %d, want = %d", got, want)
+	}
+	if got := c.EP.Stats().(*tcp.Stats).FailedConnectionAttempts.Value(); got != want {
+		t.Errorf("got EP stats FailedConnectionAttempts = %d, want = %d", got, want)
+	}
+}
+
+func TestTCPSegmentsSentIncrement(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	stats := c.Stack().Stats()
+	// SYN and ACK
+	want := stats.TCP.SegmentsSent.Value() + 2
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+	if got := stats.TCP.SegmentsSent.Value(); got != want {
+		t.Errorf("got stats.TCP.SegmentsSent.Value() = %d, want = %d", got, want)
+	}
+	if got := c.EP.Stats().(*tcp.Stats).SegmentsSent.Value(); got != want {
+		t.Errorf("got EP stats SegmentsSent.Value() = %d, want = %d", got, want)
+	}
+}
+
+func TestTCPResetsSentIncrement(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+	stats := c.Stack().Stats()
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	want := stats.TCP.SegmentsSent.Value() + 1
+
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+
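// ---------------------------------------------------------------------------
// [Editor's sketch: illustrative only, not part of this change.] The exchange
// that follows exercises RFC 793's reset rule for a half-open handshake: an
// ACK whose acknowledgment number does not equal ISS+1 is unacceptable, and
// the stack answers it with a RST. The hypothetical helper below states the
// acceptability check for a freshly sent SYN-ACK.
//
//	func ackAcceptable(ackNum, iss seqnum.Value) bool {
//		// For a SYN-ACK carrying sequence number ISS, the only
//		// acceptable acknowledgment is ISS+1.
//		return ackNum == iss+1
//	}
// ---------------------------------------------------------------------------
+	// Send a SYN request.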
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ // An AckNum that is not exactly one past the sequence number of the
+ // SYN-ACK causes a RST segment to be sent back in response.
+ AckNum: c.IRS + 2,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ c.GetPacket()
+
+ metricPollFn := func() error {
+ if got := stats.TCP.ResetsSent.Value(); got != want {
+ return fmt.Errorf("got stats.TCP.ResetsSent.Value() = %d, want = %d", got, want)
+ }
+ return nil
+ }
+ if err := testutil.Poll(metricPollFn, 1*time.Second); err != nil {
+ t.Error(err)
+ }
+}
+
+// TestTCPResetSentForACKWhenNotUsingSynCookies checks that the stack generates
+// a RST if an ACK is received on the listening socket for which there is no
+// active handshake in progress and we are not using SYN cookies.
+func TestTCPResetSentForACKWhenNotUsingSynCookies(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Send a SYN request.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: iss,
+ })
+
+ // Receive the SYN-ACK reply.
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+ ackHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 1,
+ }
+
+ // Send ACK.
+ c.SendPacket(nil, ackHeaders)
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Lower the stack-wide TIME_WAIT timeout so that the reservations
+ // are released instantly on Close.
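+ // With the default TIME_WAIT timeout the accepted endpoint would
+ // linger after the active close below and keep matching the replayed
+ // handshake ACK, so the RST this test expects would never be sent.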
+ tcpTW := tcpip.TCPTimeWaitTimeoutOption(1 * time.Millisecond)
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpTW); err != nil {
+ t.Fatalf("c.Stack().SetTransportProtocolOption(%d, %#v) = %s", tcp.ProtocolNumber, tcpTW, err)
+ }
+
+ c.EP.Close()
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(uint32(iss)+1),
+ checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+ finHeaders := &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: iss + 1,
+ AckNum: c.IRS + 2,
+ }
+
+ c.SendPacket(nil, finHeaders)
+
+ // Get the ACK to the FIN we just sent.
+ c.GetPacket()
+
+ // Since an active close was done, we need to wait for a little more
+ // than the TIME_WAIT timeout set above for the port reservations to be
+ // released and the socket to move to the CLOSED state.
+ time.Sleep(20 * time.Millisecond)
+
+ // Now resend the same ACK, this ACK should generate a RST as there
+ // should be no endpoint in SYN-RCVD state and we are not using
+ // SYN cookies yet. The reason we resend the same ACK is that we need
+ // a valid cookie (IRS) generated by netstack, without which the ACK
+ // will be rejected.
+ c.SendPacket(nil, ackHeaders)
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.SrcPort(context.StackPort),
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS+1)),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst)))
+}
+
+func TestTCPResetsReceivedIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.ResetsReceived.Value() + 1
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ Flags: header.TCPFlagRst,
+ })
+
+ if got := stats.TCP.ResetsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
+ }
+}
+
+func TestTCPResetsDoNotGenerateResets(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ stats := c.Stack().Stats()
+ want := stats.TCP.ResetsReceived.Value() + 1
+ iss := seqnum.Value(789)
+ rcvWnd := seqnum.Size(30000)
+ c.CreateConnected(iss, rcvWnd, -1 /* epRcvBuf */)
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ SeqNum: iss.Add(1),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: rcvWnd,
+ Flags: header.TCPFlagRst,
+ })
+
+ if got := stats.TCP.ResetsReceived.Value(); got != want {
+ t.Errorf("got stats.TCP.ResetsReceived.Value() = %d, want = %d", got, want)
+ }
+ c.CheckNoPacketTimeout("got an unexpected packet", 100*time.Millisecond)
+}
+
+func TestActiveHandshake(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+}
+
+func TestNonBlockingClose(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ // Close the endpoint and measure how long it takes.
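+ // Close should return without waiting on the peer; the generous 3s
+ // bound below only guards against Close blocking outright and is not
+ // a precise timing assertion.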
+ t0 := time.Now()
+ ep.Close()
+ if diff := time.Since(t0); diff > 3*time.Second {
+ t.Fatalf("Took too long to close: %s", diff)
+ }
+}
+
+func TestConnectResetAfterClose(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPLingerTimeout to 3 seconds so that sockets are marked closed
+ // after 3 seconds in the FIN_WAIT2 state.
+ tcpLingerTimeout := 3 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPLingerTimeoutOption(tcpLingerTimeout)); err != nil {
+ t.Fatalf("c.Stack().SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%s)) failed: %s", tcpLingerTimeout, err)
+ }
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ // Close the endpoint, make sure we get a FIN segment, then acknowledge
+ // to complete closure of sender, but don't send our own FIN.
+ ep.Close()
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Wait for the ep to give up waiting for a FIN.
+ time.Sleep(tcpLingerTimeout + 1*time.Second)
+
+ // Now send an ACK and it should trigger a RST as the endpoint should
+ // not exist anymore.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ for {
+ b := c.GetPacket()
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin {
+ // This is a retransmit of the FIN, ignore it.
+ continue
+ }
+
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence number
+ // of the FIN itself.
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+ break
+ }
+}
+
+// TestCurrentConnectedIncrement tests the increment of the current
+// established and connected counters.
+func TestCurrentConnectedIncrement(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 1 second so that sockets are marked closed
+ // after 1 second in TIME_WAIT state.
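+ // CurrentEstablished drops as soon as the first FIN moves the endpoint
+ // out of ESTABLISHED, while CurrentConnected stays up until the
+ // endpoint finally leaves TIME_WAIT; the checks below track exactly
+ // that difference.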
+ tcpTimeWaitTimeout := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.Stack().SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeoutOption(%s)) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 1 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 1", got)
+ }
+ gotConnected := c.Stack().Stats().TCP.CurrentConnected.Value()
+ if gotConnected != 1 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 1", gotConnected)
+ }
+
+ ep.Close()
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != gotConnected {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = %d", got, gotConnected)
+ }
+
+ // Ack and send FIN as well.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Check that the stack acks the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Wait for a little more than the TIME-WAIT duration for the socket to
+ // transition to CLOSED state.
+ time.Sleep(1200 * time.Millisecond)
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+ if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+ }
+}
+
+// TestClosingWithEnqueuedSegments tests handling of still enqueued segments
+// when the endpoint transitions to StateClose. The in-flight segments would
+// be re-enqueued to any listening endpoint.
+func TestClosingWithEnqueuedSegments(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ ep := c.EP
+ c.EP = nil
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateEstablished; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Send a FIN for ESTABLISHED --> CLOSE_WAIT.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Get the ACK for the FIN we sent.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Give the stack a few ms to transition the endpoint out of ESTABLISHED
+ // state.
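+ // State transitions are driven by the endpoint's protocolMainLoop
+ // goroutine, so the test briefly sleeps instead of synchronizing with
+ // it directly.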
+ time.Sleep(10 * time.Millisecond)
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateCloseWait; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Close the application endpoint for CLOSE_WAIT --> LAST_ACK.
+ ep.Close()
+
+ // Get the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateLastAck; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Pause the endpoint's protocolMainLoop.
+ ep.(interface{ StopWork() }).StopWork()
+
+ // Enqueue the last ACK followed by an ACK matching the endpoint.
+ //
+ // Send the last ACK for LAST_ACK --> CLOSED.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 791,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Send a packet with ACK set; this generates a RST when SYN cookies
+ // are not in use, as in this test.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 792,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Unpause the endpoint's protocolMainLoop.
+ ep.(interface{ ResumeWork() }).ResumeWork()
+
+ // Wait for the protocolMainLoop to resume and update state.
+ time.Sleep(10 * time.Millisecond)
+
+ // Expect the endpoint to be closed.
+ if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != 1 {
+ t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = 1", got)
+ }
+
+ if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+ t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+ }
+
+ // Check that the endpoint was moved to CLOSED and that netstack sent a
+ // reset in response to the ACK packet that we sent after the last ACK.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(0),
+ checker.TCPFlags(header.TCPFlagRst),
+ ),
+ )
+}
+
+func TestSimpleReceive(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Receive data.
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ if !bytes.Equal(data, v) {
+ t.Fatalf("got data = %v, want = %v", v, data)
+ }
+
+ // Check that ACK is received.
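+ // The stack acknowledges in-sequence data on its own: this ACK is
+ // generated on receipt rather than by the application Read, and its
+ // AckNum advances by len(data).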
+ checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +// TestUserSuppliedMSSOnConnectV4 tests that the user supplied MSS is used when +// creating a new active IPv4 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV4(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv4MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS int + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.Create(-1) + + // Set the MSS socket option. + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, test.setMSS); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv4 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("unexpected return value from Connect: %s", err) + } + + // Receive SYN packet with our user supplied MSS. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws}))) + }) + } +} + +// TestUserSuppliedMSSOnConnectV6 tests that the user supplied MSS is used when +// creating a new active IPv6 TCP socket. It should be present in the sent TCP +// SYN segment. +func TestUserSuppliedMSSOnConnectV6(t *testing.T) { + const mtu = 5000 + const maxMSS = mtu - header.IPv6MinimumSize - header.TCPMinimumSize + tests := []struct { + name string + setMSS uint16 + expMSS uint16 + }{ + { + "EqualToMaxMSS", + maxMSS, + maxMSS, + }, + { + "LessThanMTU", + maxMSS - 1, + maxMSS - 1, + }, + { + "GreaterThanMTU", + maxMSS + 1, + maxMSS, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, mtu) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Set the MSS socket option. + if err := c.EP.SetSockOptInt(tcpip.MaxSegOption, int(test.setMSS)); err != nil { + t.Fatalf("SetSockOptInt(MaxSegOption, %d) failed: %s", test.setMSS, err) + } + + // Get expected window size. + rcvBufSize, err := c.EP.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) + } + ws := tcp.FindWndScale(seqnum.Size(rcvBufSize)) + + // Start connection attempt to IPv6 address. + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("unexpected return value from Connect: %s", err) + } + + // Receive SYN packet with our user supplied MSS. 
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: test.expMSS, WS: ws})))
+ })
+ }
+}
+
+func TestSendRstOnListenerRxSynAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxSynAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+// TestTCPAckBeforeAcceptV4 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable amount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv4.
+func TestTCPAckBeforeAcceptV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ // Send data before accepting the connection.
+ c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+// TestTCPAckBeforeAcceptV6 tests that once the 3-way handshake is complete,
+// peers can send data and expect a response within a reasonable amount of time
+// without calling Accept on the listening endpoint first.
+//
+// This test uses IPv6.
+func TestTCPAckBeforeAcceptV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ irs, iss := executeV6Handshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+ // Send data before accepting the connection.
+ c.SendV6Packet([]byte{1, 2, 3, 4}, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagAck,
+ SeqNum: irs + 1,
+ AckNum: iss + 1,
+ })
+
+ // Receive ACK for the data we sent.
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(iss+1)),
+ checker.AckNum(uint32(irs+5))))
+}
+
+func TestSendRstOnListenerRxAckV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv4(t, c.GetPacket(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+func TestSendRstOnListenerRxAckV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(true /* v6Only */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(10 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ c.SendV6Packet(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagFin | header.TCPFlagAck,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ checker.IPv6(t, c.GetV6Packet(), checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagRst),
+ checker.SeqNum(200)))
+}
+
+// TestListenShutdown tests that the listening endpoint replies with a RST
+// once its read side has been shut down.
+func TestListenShutdown(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatal("Shutdown failed:", err)
+ }
+
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: context.StackPort,
+ Flags: header.TCPFlagSyn,
+ SeqNum: 100,
+ AckNum: 200,
+ })
+
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
+}
+
+// TestListenCloseWhileConnect tests that the listening endpoint drains
+// its accept queue when closed. This should reset all of the pending
+// connections that are waiting to be accepted.
+func TestListenCloseWhileConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.Create(-1 /* epRcvBuf */)
+
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatal("Bind failed:", err)
+ }
+
+ if err := c.EP.Listen(1 /* backlog */); err != nil {
+ t.Fatal("Listen failed:", err)
+ }
+
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&waitEntry, waiter.EventIn)
+ defer c.WQ.EventUnregister(&waitEntry)
+
+ executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+ // Wait for the new endpoint created because of handshake to be delivered
+ // to the listening endpoint's accept queue.
+ <-notifyCh
+
+ // Close the listening endpoint.
+ c.EP.Close()
+
+ // Expect the listening endpoint to reset the connection.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ ))
+}
+
+func TestTOSV4(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ c.EP = ep
+
+ const tos = 0xC0
+ if err := c.EP.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
+ t.Errorf("SetSockOptInt(IPv4TOSOption, %d) failed: %s", tos, err)
+ }
+
+ v, err := c.EP.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
+ }
+
+ if v != tos {
+ t.Errorf("got GetSockOptInt(IPv4TOSOption) = %d, want = %d", v, tos)
+ }
+
+ testV4Connect(t, c, checker.TOS(tos, 0))
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790), // AckNum is the initial sequence number + 1.
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ checker.TOS(tos, 0),
+ )
+
+ if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+ t.Errorf("got data = %x, want = %x", p, data)
+ }
+}
+
+func TestTrafficClassV6(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateV6Endpoint(false)
+
+ const tos = 0xC0
+ if err := c.EP.SetSockOptInt(tcpip.IPv6TrafficClassOption, tos); err != nil {
+ t.Errorf("SetSockOptInt(IPv6TrafficClassOption, %d) failed: %s", tos, err)
+ }
+
+ v, err := c.EP.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ t.Fatalf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
+ }
+
+ if v != tos {
+ t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = %d, want = %d", v, tos)
+ }
+
+ // Test the connection request.
+ testV6Connect(t, c, checker.TOS(tos, 0))
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received.
+ b := c.GetV6Packet() + checker.IPv6(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + checker.TOS(tos, 0), + ) + + if p := b[header.IPv6MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } +} + +func TestConnectBindToDevice(t *testing.T) { + for _, test := range []struct { + name string + device tcpip.NICID + want tcp.EndpointState + }{ + {"RightDevice", 1, tcp.StateEstablished}, + {"WrongDevice", 2, tcp.StateSynSent}, + {"AnyDevice", 0, tcp.StateEstablished}, + } { + t.Run(test.name, func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.Create(-1) + bindToDevice := tcpip.BindToDeviceOption(test.device) + c.EP.SetSockOpt(bindToDevice) + // Start connection attempt. + waitEntry, _ := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("unexpected return value from Connect: %s", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + c.GetPacket() + if got, want := tcp.EndpointState(c.EP.State()), test.want; got != want { + t.Fatalf("unexpected endpoint state: want %s, got %s", want, got) + } + }) + } +} + +func TestRstOnSynSent(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create an endpoint, don't handshake because we want to interfere with the + // handshake process. + c.Create(-1) + + // Start connection attempt. + waitEntry, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + addr := tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort} + if err := c.EP.Connect(addr); err != tcpip.ErrConnectStarted { + t.Fatalf("got Connect(%+v) = %s, want %s", addr, err, tcpip.ErrConnectStarted) + } + + // Receive SYN packet. 
+ b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + // Ensure that we've reached SynSent state + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + // Send a packet with a proper ACK and a RST flag to cause the socket + // to Error and close out + iss := seqnum.Value(789) + rcvWnd := seqnum.Size(30000) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagRst | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: nil, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatal("timed out waiting for packet to arrive") + } + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused) + } + + // Due to the RST the endpoint should be in an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Fatalf("got State() = %s, want %s", got, want) + } +} + +func TestOutOfOrderReceive(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + } + + // Send second half of data first, with seqnum 3 ahead of expected. + data := []byte{1, 2, 3, 4, 5, 6} + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 793, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Check that we get an ACK specifying which seqnum is expected. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Wait 200ms and check that no data has been received. + time.Sleep(200 * time.Millisecond) + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + } + + // Send the first 3 bytes now. + c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive data. + read := make([]byte, 0, 6) + for len(read) < len(data) { + v, _, err := c.EP.Read(nil) + if err != nil { + if err == tcpip.ErrWouldBlock { + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + continue + } + t.Fatalf("Read failed: %s", err) + } + + read = append(read, v...) + } + + // Check that we received the data in proper order. + if !bytes.Equal(data, read) { + t.Fatalf("got data = %v, want = %v", read, data) + } + + // Check that the whole data is acknowledged. 
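+ // The cumulative ACK covers both the out-of-order segment and the
+ // in-sequence bytes that filled the gap, so the AckNum below advances
+ // by all six bytes at once.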
+ checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestOutOfOrderFlood(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create a new connection with initial window size of 10. + c.CreateConnected(789, 30000, 10) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + } + + // Send 100 packets before the actual one that is expected. + data := []byte{1, 2, 3, 4, 5, 6} + for i := 0; i < 100; i++ { + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 796, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Send packet with seqnum 793. It must be discarded because the + // out-of-order buffer was filled by the previous packets. + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 793, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Now send the expected packet, seqnum 790. + c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Check that only packet 790 is acknowledged. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(793), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestRstOnCloseWithUnreadData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) + } + + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that ACK is received, this happens regardless of the read. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Now that we know we have unread data, let's just close the connection + // and verify that netstack sends an RST rather than a FIN. + c.EP.Close() + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + // We shouldn't consume a sequence number on RST. + checker.SeqNum(uint32(c.IRS)+1), + )) + // The RST puts the endpoint into an error state. 
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // This final ACK should be ignored because an ACK on a reset doesn't mean
+ // anything.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + len(data)),
+ AckNum: c.IRS.Add(seqnum.Size(2)),
+ RcvWnd: 30000,
+ })
+}
+
+func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(3 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that ACK is received, this happens regardless of the read.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Cause a FIN to be generated.
+ c.EP.Shutdown(tcpip.ShutdownWrite)
+
+ // Make sure we get the FIN but DON'T ACK IT.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ checker.SeqNum(uint32(c.IRS)+1),
+ ))
+
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Cause a RST to be generated by closing the read end now since we have
+ // unread data.
+ c.EP.Shutdown(tcpip.ShutdownRead)
+
+ // Make sure we get the RST.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+ // RST is always generated with sndNxt which if the FIN
+ // has been sent will be 1 higher than the sequence
+ // number of the FIN itself.
+ checker.SeqNum(uint32(c.IRS)+2),
+ ))
+ // The RST puts the endpoint into an error state.
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // The ACK to the FIN should now be rejected since the connection has been
+ // closed by a RST.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + len(data)),
+ AckNum: c.IRS.Add(seqnum.Size(2)),
+ RcvWnd: 30000,
+ })
+}
+
+func TestShutdownRead(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
+ }
+ var want uint64 = 1
+ if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
+ t.Fatalf("got EP stats ReadErrors.ReadClosed = %d, want = %d", got, want)
+ }
+}
+
+func TestFullWindowReceive(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, 10)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Fill up the window.
+ data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(5 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that data is acknowledged, and window goes to zero.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.Window(0),
+ ),
+ )
+
+ // Receive data and check it.
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ if !bytes.Equal(data, v) {
+ t.Fatalf("got data = %v, want = %v", v, data)
+ }
+
+ var want uint64 = 1
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ZeroRcvWindowState.Value(); got != want {
+ t.Fatalf("got EP stats ReceiveErrors.ZeroRcvWindowState = %d, want = %d", got, want)
+ }
+
+ // Check that we get an ACK for the newly non-zero window.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.Window(10),
+ ),
+ )
+}
+
+func TestNoWindowShrinking(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Start off with a window size of 10, then shrink it to 5.
+ c.CreateConnected(789, 30000, 10)
+
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 5); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 5) failed: %s", err)
+ }
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Send 3 bytes, check that the peer acknowledges them.
+ data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that data is acknowledged, and that window doesn't go to zero + // just yet because it was previously set to 10. It must go to 7 now. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(793), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(7), + ), + ) + + // Send 7 more bytes, check that the window fills up. + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 793, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(0), + ), + ) + + // Receive data and check it. + read := make([]byte, 0, 10) + for len(read) < len(data) { + v, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + + read = append(read, v...) + } + + if !bytes.Equal(data, read) { + t.Fatalf("got data = %v, want = %v", read, data) + } + + // Check that we get an ACK for the newly non-zero window, which is the + // new size. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(5), + ), + ) +} + +func TestSimpleSend(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Fatalf("got data = %v, want = %v", p, data) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), + RcvWnd: 30000, + }) +} + +func TestZeroWindowSend(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789 /* iss */, 0 /* rcvWnd */, -1 /* epRcvBuf */) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check if we got a zero-window probe. 
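+ // With a zero advertised window at connect time, the write above cannot
+ // be transmitted; instead the endpoint sends a zero-window probe that
+ // carries no payload and reuses sequence number IRS, as the checkers
+ // below verify.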
+ b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Open up the window. Data should be received now. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Check that data is received. + b = c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Fatalf("got data = %v, want = %v", p, data) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), + RcvWnd: 30000, + }) +} + +func TestScaledWindowConnect(t *testing.T) { + // This test ensures that window scaling is used when the peer + // does advertise it and connection is established with Connect(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set the window size greater than the maximum non-scaled window. + c.CreateConnectedWithRawOptions(789, 30000, 65535*3, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received, and that advertised window is 0xbfff, + // that is, that it is scaled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xbfff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestNonScaledWindowConnect(t *testing.T) { + // This test ensures that window scaling is not used when the peer + // doesn't advertise it and connection is established with Connect(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set the window size greater than the maximum non-scaled window. + c.CreateConnected(789, 30000, 65535*3) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received, and that advertised window is 0xffff, + // that is, that it's not scaled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xffff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestScaledWindowAccept(t *testing.T) { + // This test ensures that window scaling is used when the peer + // does advertise it and connection is established with Accept(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create EP and start listening. 
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ // Set the window size greater than the maximum non-scaled window.
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed: %s", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Do 3-way handshake.
+ c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+
+ // Check that data is received, and that advertised window is 0xbfff,
+ // that is, that it is scaled.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.Window(0xbfff),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+}
+
+func TestNonScaledWindowAccept(t *testing.T) {
+ // This test ensures that window scaling is not used when the peer
+ // doesn't advertise it and connection is established with Accept().
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create EP and start listening.
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ // Set the window size greater than the maximum non-scaled window.
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 65535*3); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 65535*3) failed: %s", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Do 3-way handshake w/ window scaling disabled. The SYN-ACK to the SYN
+ // should not carry the window scaling option.
+ c.PassiveConnect(100, -1, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received, and that advertised window is 0xffff, + // that is, that it's not scaled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xffff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestZeroScaledWindowReceive(t *testing.T) { + // This test ensures that the endpoint sends a non-zero window size + // advertisement when the scaled window transitions from 0 to non-zero, + // but the actual window (not scaled) hasn't gotten to zero. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set the window size such that a window scale of 4 will be used. + const wnd = 65535 * 10 + const ws = uint32(4) + c.CreateConnectedWithRawOptions(789, 30000, wnd, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) + + // Write chunks of 50000 bytes. + remain := wnd + sent := 0 + data := make([]byte, 50000) + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(remain>>ws)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Make the window non-zero, but the scaled window zero. + if remain >= 16 { + data = data[:remain-15] + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(0), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Read at least 1MSS of data. An ack should be sent in response to that. 
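+ // This ACK is the window update: once the application frees at least
+ // one MSS of receive buffer, the endpoint re-advertises the now
+ // non-zero (scaled) window.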
+ sz := 0
+ for sz < defaultMTU {
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+ sz += len(v)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(790+sent)),
+ checker.Window(uint16(sz>>ws)),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
+func TestSegmentMerging(t *testing.T) {
+ tests := []struct {
+ name string
+ stop func(tcpip.Endpoint)
+ resume func(tcpip.Endpoint)
+ }{
+ {
+ "stop work",
+ func(ep tcpip.Endpoint) {
+ ep.(interface{ StopWork() }).StopWork()
+ },
+ func(ep tcpip.Endpoint) {
+ ep.(interface{ ResumeWork() }).ResumeWork()
+ },
+ },
+ {
+ "cork",
+ func(ep tcpip.Endpoint) {
+ ep.SetSockOptBool(tcpip.CorkOption, true)
+ },
+ func(ep tcpip.Endpoint) {
+ ep.SetSockOptBool(tcpip.CorkOption, false)
+ },
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Send tcp.InitialCwnd number of segments to fill up
+ // InitialWindow but don't ACK. That should prevent
+ // any more packets from going out.
+ for i := 0; i < tcp.InitialCwnd; i++ {
+ view := buffer.NewViewFromBytes([]byte{0})
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %s", i+1, err)
+ }
+ }
+
+ // Now send the segments that should get merged as the congestion
+ // window is full and we won't be able to send any more packets.
+ var allData []byte
+ for i, data := range [][]byte{{1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} {
+ allData = append(allData, data...)
+ view := buffer.NewViewFromBytes(data)
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write #%d failed: %s", i+1, err)
+ }
+ }
+
+ // Check that we get tcp.InitialCwnd packets.
+ for i := 0; i < tcp.InitialCwnd; i++ {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(header.TCPMinimumSize+1),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+uint32(i)+1),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + 10), // 10 for the 10 bytes of payload.
+ RcvWnd: 30000,
+ })
+
+ // Check that data is received.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(allData)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+11),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+
+ if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, allData) {
+ t.Fatalf("got data = %v, want = %v", got, allData)
+ }
+
+ // Acknowledge the data.
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(11 + seqnum.Size(len(allData))), + RcvWnd: 30000, + }) + }) + } +} + +func TestDelay(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + c.EP.SetSockOptBool(tcpip.DelayOption, true) + + var allData []byte + for i, data := range [][]byte{{0}, {1, 2, 3, 4}, {5, 6, 7}, {8, 9}, {10}, {11}} { + allData = append(allData, data...) + view := buffer.NewViewFromBytes(data) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %s", i+1, err) + } + } + + seq := c.IRS.Add(1) + for _, want := range [][]byte{allData[:1], allData[1:]} { + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(want)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(seq)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if got := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(got, want) { + t.Fatalf("got data = %v, want = %v", got, want) + } + + seq = seq.Add(seqnum.Size(len(want))) + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seq, + RcvWnd: 30000, + }) + } +} + +func TestUndelay(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + c.EP.SetSockOptBool(tcpip.DelayOption, true) + + allData := [][]byte{{0}, {1, 2, 3}} + for i, data := range allData { + view := buffer.NewViewFromBytes(data) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %s", i+1, err) + } + } + + seq := c.IRS.Add(1) + + // Check that data is received. + first := c.GetPacket() + checker.IPv4(t, first, + checker.PayloadLen(len(allData[0])+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(seq)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if got, want := first[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[0]; !bytes.Equal(got, want) { + t.Fatalf("got first packet's data = %v, want = %v", got, want) + } + + seq = seq.Add(seqnum.Size(len(allData[0]))) + + // Check that we don't get the second packet yet. + c.CheckNoPacketTimeout("delayed second packet transmitted", 100*time.Millisecond) + + c.EP.SetSockOptBool(tcpip.DelayOption, false) + + // Check that data is received. + second := c.GetPacket() + checker.IPv4(t, second, + checker.PayloadLen(len(allData[1])+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(seq)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if got, want := second[header.IPv4MinimumSize+header.TCPMinimumSize:], allData[1]; !bytes.Equal(got, want) { + t.Fatalf("got second packet's data = %v, want = %v", got, want) + } + + seq = seq.Add(seqnum.Size(len(allData[1]))) + + // Acknowledge the data. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seq, + RcvWnd: 30000, + }) +} + +func TestMSSNotDelayed(t *testing.T) { + tests := []struct { + name string + fn func(tcpip.Endpoint) + }{ + {"no-op", func(tcpip.Endpoint) {}}, + {"delay", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.DelayOption, true) }}, + {"cork", func(ep tcpip.Endpoint) { ep.SetSockOptBool(tcpip.CorkOption, true) }}, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + const maxPayload = 100 + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + }) + + test.fn(c.EP) + + allData := [][]byte{{0}, make([]byte, maxPayload), make([]byte, maxPayload)} + for i, data := range allData { + view := buffer.NewViewFromBytes(data) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write #%d failed: %s", i+1, err) + } + } + + seq := c.IRS.Add(1) + + for i, data := range allData { + // Check that data is received. + packet := c.GetPacket() + checker.IPv4(t, packet, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(seq)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if got, want := packet[header.IPv4MinimumSize+header.TCPMinimumSize:], data; !bytes.Equal(got, want) { + t.Fatalf("got packet #%d's data = %v, want = %v", i+1, got, want) + } + + seq = seq.Add(seqnum.Size(len(data))) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seq, + RcvWnd: 30000, + }) + }) + } +} + +func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { + payloadMultiplier := 10 + dataLen := payloadMultiplier * maxPayload + data := make([]byte, dataLen) + for i := range data { + data[i] = byte(i) + } + + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received in chunks. + bytesReceived := 0 + numPackets := 0 + for bytesReceived != dataLen { + b := c.GetPacket() + numPackets++ + tcpHdr := header.TCP(header.IPv4(b).Payload()) + payloadLen := len(tcpHdr.Payload()) + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + pdata := data[bytesReceived : bytesReceived+payloadLen] + if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { + t.Fatalf("got data = %v, want = %v", p, pdata) + } + bytesReceived += payloadLen + var options []byte + if c.TimeStampEnabled { + // If timestamp option is enabled, echo back the timestamp and increment + // the TSEcr value included in the packet and send that back as the TSVal. + parsedOpts := tcpHdr.ParsedOptions() + tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) + options = tsOpt[:] + } + // Acknowledge the data. 
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)),
+ RcvWnd: 30000,
+ TCPOpts: options,
+ })
+ }
+ if numPackets == 1 {
+ t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet")
+ }
+}
+
+func TestSendGreaterThanMTU(t *testing.T) {
+ const maxPayload = 100
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestSetTTL(t *testing.T) {
+ for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} {
+ t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) {
+ c := context.New(t, 65535)
+ defer c.Cleanup()
+
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ if err := c.EP.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil {
+ t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err)
+ }
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("unexpected return value from Connect: %s", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+
+ checker.IPv4(t, b, checker.TTL(wantTTL))
+ })
+ }
+}
+
+func TestActiveSendMSSLessThanMTU(t *testing.T) {
+ const maxPayload = 100
+ c := context.New(t, 65535)
+ defer c.Cleanup()
+
+ c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{
+ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
+ })
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestPassiveSendMSSLessThanMTU(t *testing.T) {
+ const maxPayload = 100
+ const mtu = 1200
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Create EP and start listening.
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ // Set the buffer size to a deterministic size so that we can check the
+ // window scaling option.
+ const rcvBufferSize = 0x20000
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", rcvBufferSize, err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Do 3-way handshake.
+ c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Check that data gets properly segmented.
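+ // (testBrokenUpWrite writes 10*maxPayload bytes in a single Write and
+ // fails if everything arrives in one packet.)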
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
+ const maxPayload = 536
+ const mtu = 2000
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Set the SynRcvd threshold to zero to force a syn cookie based accept
+ // to happen.
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil {
+ t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err)
+ }
+
+ // Create EP and start listening.
+ wq := &waiter.Queue{}
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %s", err)
+ }
+
+ // Do 3-way handshake.
+ c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+
+ // Try to accept the connection.
+ we, ch := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&we, waiter.EventIn)
+ defer wq.EventUnregister(&we)
+
+ c.EP, _, err = ep.Accept()
+ if err == tcpip.ErrWouldBlock {
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ c.EP, _, err = ep.Accept()
+ if err != nil {
+ t.Fatalf("Accept failed: %s", err)
+ }
+
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for accept")
+ }
+ }
+
+ // Check that data gets properly segmented.
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestForwarderSendMSSLessThanMTU(t *testing.T) {
+ const maxPayload = 100
+ const mtu = 1200
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ s := c.Stack()
+ ch := make(chan *tcpip.Error, 1)
+ f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
+ var err *tcpip.Error
+ c.EP, err = r.CreateEndpoint(&c.WQ)
+ ch <- err
+ })
+ s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
+
+ // Do 3-way handshake.
+ c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+
+ // Wait for connection to be available.
+ select {
+ case err := <-ch:
+ if err != nil {
+ t.Fatalf("Error creating endpoint: %s", err)
+ }
+ case <-time.After(2 * time.Second):
+ t.Fatalf("Timed out waiting for connection")
+ }
+
+ // Check that data gets properly segmented.
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestSynOptionsOnActiveConnect(t *testing.T) {
+ const mtu = 1400
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %s", err)
+ }
+
+ // Set the buffer size to a deterministic size so that we can check the
+ // window scaling option.
+ const rcvBufferSize = 0x20000
+ const wndScale = 2
+ if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBufferSize); err != nil {
+ t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, %d) failed: %s", rcvBufferSize, err)
+ }
+
+ // Start connection attempt.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventOut)
+ defer c.WQ.EventUnregister(&we)
+
+ if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+ t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted)
+ }
+
+ // Receive SYN packet.
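+ // The SYN should carry an MSS derived from the link MTU and the
+ // window scale (wndScale) implied by the receive buffer size set
+ // above.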
+ b := c.GetPacket() + mss := uint16(mtu - header.IPv4MinimumSize - header.TCPMinimumSize) + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), + ), + ) + + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + // Wait for retransmit. + time.Sleep(1 * time.Second) + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.SrcPort(tcpHdr.SourcePort()), + checker.SeqNum(tcpHdr.SequenceNumber()), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), + ), + ) + + // Send SYN-ACK. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK packet. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + ), + ) + + // Wait for connection to be established. + select { + case <-ch: + if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { + t.Fatalf("GetSockOpt failed: %s", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } +} + +func TestCloseListener(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create listener. + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + if err := ep.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Close the listener and measure how long it takes. + t0 := time.Now() + ep.Close() + if diff := time.Now().Sub(t0); diff > 3*time.Second { + t.Fatalf("Took too long to close: %s", diff) + } +} + +func TestReceiveOnResetConnection(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + // Send RST segment. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: 790, + RcvWnd: 30000, + }) + + // Try to read. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + +loop: + for { + switch _, _, err := c.EP.Read(nil); err { + case tcpip.ErrWouldBlock: + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for reset to arrive") + } + case tcpip.ErrConnectionReset: + break loop + default: + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) + } + } + // Expect the state to be StateError and subsequent Reads to fail with HardError. 
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) + } + if tcp.EndpointState(c.EP.State()) != tcp.StateError { + t.Fatalf("got EP state is not StateError") + } + + if got := c.Stack().Stats().TCP.EstablishedResets.Value(); got != 1 { + t.Errorf("got stats.TCP.EstablishedResets.Value() = %d, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } +} + +func TestSendOnResetConnection(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + // Send RST segment. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: 790, + RcvWnd: 30000, + }) + + // Wait for the RST to be received. + time.Sleep(1 * time.Second) + + // Try to write. + view := buffer.NewView(10) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Write(...) = %s, want = %s", err, tcpip.ErrConnectionReset) + } +} + +// TestMaxRetransmitsTimeout tests if the connection is timed out after +// a segment has been retransmitted MaxRetries times. +func TestMaxRetransmitsTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const numRetries = 2 + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRetriesOption(numRetries)); err != nil { + t.Fatalf("could not set protocol option MaxRetries.\n") + } + + c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */) + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + + _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Expect first transmit and MaxRetries retransmits. + for i := 0; i < numRetries+1; i++ { + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh), + ), + ) + } + // Wait for the connection to timeout after MaxRetries retransmits. + initRTO := 1 * time.Second + select { + case <-notifyCh: + case <-time.After((2 << numRetries) * initRTO): + t.Fatalf("connection still alive after maximum retransmits.\n") + } + + // Send an ACK and expect a RST as the connection would have been closed. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } +} + +// TestMaxRTO tests if the retransmit interval caps to MaxRTO. 
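+//
+// Roughly, the RTO is expected to double after each unacked retransmit
+// but stay clamped, i.e. rto = min(2*rto, MaxRTO); with MaxRTO set to 1s
+// below, successive retransmits should then be about 1s apart.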
+func TestMaxRTO(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ rto := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMaxRTOOption(rto)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPMaxRTO(%d)) failed: %s", rto, err)
+ }
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(1)), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ const numRetransmits = 2
+ for i := 0; i < numRetransmits; i++ {
+ start := time.Now()
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ if time.Since(start).Round(time.Second).Seconds() != rto.Seconds() {
+ t.Errorf("Retransmit interval not capped to MaxRTO")
+ }
+ }
+}
+
+// TestRetransmitIPv4IDUniqueness tests that the IPv4 Identification field is
+// unique on retransmits.
+func TestRetransmitIPv4IDUniqueness(t *testing.T) {
+ for _, tc := range []struct {
+ name string
+ size int
+ }{
+ {"1Byte", 1},
+ {"512Bytes", 512},
+ } {
+ t.Run(tc.name, func(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */)
+
+ // Disabling PMTU discovery causes all packets sent from this socket to
+ // have DF=0. This needs to be done because the IPv4 ID uniqueness
+ // applies only to non-atomic IPv4 datagrams as defined in RFC 6864
+ // Section 4, and datagrams with DF=0 are non-atomic.
+ if err := c.EP.SetSockOptInt(tcpip.MTUDiscoverOption, tcpip.PMTUDiscoveryDont); err != nil {
+ t.Fatalf("disabling PMTU discovery via sockopt to force DF=0 failed: %s", err)
+ }
+
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(buffer.NewView(tc.size)), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %s", err)
+ }
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ idSet := map[uint16]struct{}{header.IPv4(pkt).ID(): struct{}{}}
+ // Expect three retransmitted packets, and that all packets received
+ // have unique IPv4 ID values.
+ for i := 0; i <= 2; i++ {
+ pkt := c.GetPacket()
+ checker.IPv4(t, pkt,
+ checker.FragmentFlags(0),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ id := header.IPv4(pkt).ID()
+ if _, exists := idSet[id]; exists {
+ t.Fatalf("duplicate IPv4 ID=%d found in retransmitted packet", id)
+ }
+ idSet[id] = struct{}{}
+ }
+ })
+ }
+}
+
+func TestFinImmediately(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Shutdown immediately, check that we get a FIN.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ // Ack and send FIN as well.
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinRetransmit(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + // Shutdown immediately, check that we get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + // Don't acknowledge yet. We should get a retransmit of the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithNoPendingData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + // Write something out, and have it acknowledged. + view := buffer.NewView(10) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Shutdown, check that we get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. 
+ checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithPendingDataCwndFull(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + // Write enough segments to fill the congestion window before ACK'ing + // any of them. + view := buffer.NewView(10) + for i := tcp.InitialCwnd; i > 0; i-- { + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + } + + next := uint32(c.IRS) + 1 + for i := tcp.InitialCwnd; i > 0; i-- { + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + } + + // Shutdown the connection, check that the FIN segment isn't sent + // because the congestion window doesn't allow it. Wait until a + // retransmit is received. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Send the ACK that will allow the FIN to be sent as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Send a FIN that acknowledges everything. Get an ACK back. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithPendingData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + // Write something out, and acknowledge it to get cwnd to 2. + view := buffer.NewView(10) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Write new data, but don't acknowledge it. 
+ if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + // Shutdown the connection, check that we do get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Send a FIN that acknowledges everything. Get an ACK back. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithPartialAck(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + // Write something out, and acknowledge it to get cwnd to 2. Also send + // FIN from the test side. + view := buffer.NewView(10) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Check that we get an ACK for the fin. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Write new data, but don't acknowledge it. + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + // Shutdown the connection, check that we do get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Shutdown failed: %s", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Send an ACK for the data, but not for the FIN yet. 
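+ // (AckNum of next-1 acknowledges all of the data but falls one
+ // sequence number short of the FIN, so the FIN itself remains
+ // unacknowledged.)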
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 791, + AckNum: seqnum.Value(next - 1), + RcvWnd: 30000, + }) + + // Check that we don't get a retransmit of the FIN. + c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) + + // Ack the FIN. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 791, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) +} + +func TestUpdateListenBacklog(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create listener. + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + if err := ep.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Update the backlog with another Listen() on the same endpoint. + if err := ep.Listen(20); err != nil { + t.Fatalf("Listen failed to update backlog: %s", err) + } + + ep.Close() +} + +func scaledSendWindow(t *testing.T, scale uint8) { + // This test ensures that the endpoint is using the right scaling by + // sending a buffer that is larger than the window size, and ensuring + // that the endpoint doesn't send more than allowed. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize + c.CreateConnectedWithRawOptions(789, 0, -1 /* epRcvBuf */, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + header.TCPOptionWS, 3, scale, header.TCPOptionNOP, + }) + + // Open up the window with a scaled value. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 1, + }) + + // Send some data. Check that it's capped by the window size. + view := buffer.NewView(65535) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that only data that fits in the scaled window is sent. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen((1<<scale)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Reset the connection to free resources. 
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: 790, + }) +} + +func TestScaledSendWindow(t *testing.T) { + for scale := uint8(0); scale <= 14; scale++ { + scaledSendWindow(t, scale) + } +} + +func TestReceivedValidSegmentCountIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + stats := c.Stack().Stats() + want := stats.TCP.ValidSegmentsReceived.Value() + 1 + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + if got := stats.TCP.ValidSegmentsReceived.Value(); got != want { + t.Errorf("got stats.TCP.ValidSegmentsReceived.Value() = %d, want = %d", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).SegmentsReceived.Value(); got != want { + t.Errorf("got EP stats Stats.SegmentsReceived = %d, want = %d", got, want) + } + // Ensure there were no errors during handshake. If these stats have + // incremented, then the connection should not have been established. + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoRoute.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoRoute = %d, want = %d", got, 0) + } + if got := c.EP.Stats().(*tcp.Stats).SendErrors.NoLinkAddr.Value(); got != 0 { + t.Errorf("got EP stats Stats.SendErrors.NoLinkAddr = %d, want = %d", got, 0) + } +} + +func TestReceivedInvalidSegmentCountIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + stats := c.Stack().Stats() + want := stats.TCP.InvalidSegmentsReceived.Value() + 1 + vv := c.BuildSegment(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + tcpbuf[header.TCPDataOffset] = ((header.TCPMinimumSize - 1) / 4) << 4 + + c.SendSegment(vv) + + if got := stats.TCP.InvalidSegmentsReceived.Value(); got != want { + t.Errorf("got stats.TCP.InvalidSegmentsReceived.Value() = %d, want = %d", got, want) + } + if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want { + t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want) + } +} + +func TestReceivedIncorrectChecksumIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + stats := c.Stack().Stats() + want := stats.TCP.ChecksumErrors.Value() + 1 + vv := c.BuildSegment([]byte{0x1, 0x2, 0x3}, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + tcpbuf := vv.ToView()[header.IPv4MinimumSize:] + // Overwrite a byte in the payload which should cause checksum + // verification to fail. 
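+ // (The data offset byte stores the TCP header length in 4-byte words
+ // in its high nibble, so (tcpbuf[header.TCPDataOffset]>>4)*4 is the
+ // header length in bytes, i.e. the index of the first payload byte.)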
+ tcpbuf[(tcpbuf[header.TCPDataOffset]>>4)*4] = 0x4
+
+ c.SendSegment(vv)
+
+ if got := stats.TCP.ChecksumErrors.Value(); got != want {
+ t.Errorf("got stats.TCP.ChecksumErrors.Value() = %d, want = %d", got, want)
+ }
+ if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+ t.Errorf("got EP stats Stats.ReceiveErrors.ChecksumErrors = %d, want = %d", got, want)
+ }
+}
+
+func TestReceivedSegmentQueuing(t *testing.T) {
+ // This test sends 200 segments containing a few bytes each to an
+ // endpoint and checks that they're all received and acknowledged by
+ // the endpoint, that is, that none of the segments are dropped by
+ // internal queues.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ // Send 200 segments.
+ data := []byte{1, 2, 3}
+ for i := 0; i < 200; i++ {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + i*len(data)),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ }
+
+ // Receive ACKs for all segments.
+ last := seqnum.Value(790 + 200*len(data))
+ for {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ tcpHdr := header.TCP(header.IPv4(b).Payload())
+ ack := seqnum.Value(tcpHdr.AckNumber())
+ if ack == last {
+ break
+ }
+
+ if last.LessThan(ack) {
+ t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last)
+ }
+ }
+}
+
+func TestReadAfterClosedState(t *testing.T) {
+ // This test ensures that calling Read() or Peek() after the endpoint
+ // has transitioned to closedState still works if there is pending
+ // data. To transition to stateClosed without calling Close(), we must
+ // shutdown the send path and the peer must send its own FIN.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Set TCPTimeWaitTimeout to 1 second so that sockets are marked closed
+ // after 1 second in TIME_WAIT state.
+ tcpTimeWaitTimeout := 1 * time.Second
+ if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil {
+ t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPTimeWaitTimeout(%d)) failed: %s", tcpTimeWaitTimeout, err)
+ }
+
+ c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+ }
+
+ // Shutdown immediately for write, check that we get a FIN.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Shutdown failed: %s", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Send some data and acknowledge the FIN.
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Check that ACK is received.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(uint32(791+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Give the stack the chance to transition to closed state from
+ // TIME_WAIT.
+ time.Sleep(tcpTimeWaitTimeout * 2)
+
+ if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want {
+ t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+ }
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that peek works.
+ peekBuf := make([]byte, 10)
+ n, _, err := c.EP.Peek([][]byte{peekBuf})
+ if err != nil {
+ t.Fatalf("Peek failed: %s", err)
+ }
+
+ peekBuf = peekBuf[:n]
+ if !bytes.Equal(data, peekBuf) {
+ t.Fatalf("got data = %v, want = %v", peekBuf, data)
+ }
+
+ // Receive data.
+ v, _, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %s", err)
+ }
+
+ if !bytes.Equal(data, v) {
+ t.Fatalf("got data = %v, want = %v", v, data)
+ }
+
+ // Now that we drained the queue, check that functions fail with the
+ // right error code.
+ if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
+ }
+
+ if _, _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
+ }
+}
+
+func TestReusePort(t *testing.T) {
+ // This test ensures that ports are immediately available for reuse
+ // after Close on the endpoints using them returns.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // First case, just an endpoint that was bound.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+
+ c.EP.Close()
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %s", err)
+ }
+ if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil {
+ t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+ t.Fatalf("Bind failed: %s", err)
+ }
+ c.EP.Close()
+
+ // Second case, an endpoint that was bound and is connecting.
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("got c.EP.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + } + c.EP.Close() + + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + c.EP.Close() + + // Third case, an endpoint that was bound and is listening. + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + c.EP.Close() + + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + if err := c.EP.SetSockOptBool(tcpip.ReuseAddressOption, true); err != nil { + t.Fatalf("SetSockOptBool ReuseAddressOption failed: %s", err) + } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } +} + +func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { + t.Helper() + + s, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt failed: %s", err) + } + + if int(s) != v { + t.Fatalf("got receive buffer size = %d, want = %d", s, v) + } +} + +func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { + t.Helper() + + s, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { + t.Fatalf("GetSockOpt failed: %s", err) + } + + if int(s) != v { + t.Fatalf("got send buffer size = %d, want = %d", s, v) + } +} + +func TestDefaultBufferSizes(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) + + // Check the default values. + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + defer func() { + if ep != nil { + ep.Close() + } + }() + + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize) + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) + + // Change the default send buffer size. 
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{ + Min: 1, + Default: tcp.DefaultSendBufferSize * 2, + Max: tcp.DefaultSendBufferSize * 20}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) + } + + ep.Close() + ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize) + + // Change the default receive buffer size. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{ + Min: 1, + Default: tcp.DefaultReceiveBufferSize * 3, + Max: tcp.DefaultReceiveBufferSize * 30}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %v", err) + } + + ep.Close() + ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*2) + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*3) +} + +func TestMinMaxBufferSizes(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) + + // Check the default values. + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + defer ep.Close() + + // Change the min/max values for send/receive + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 200, Default: tcp.DefaultReceiveBufferSize * 2, Max: tcp.DefaultReceiveBufferSize * 20}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) + } + + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 300, Default: tcp.DefaultSendBufferSize * 3, Max: tcp.DefaultSendBufferSize * 30}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) + } + + // Set values below the min. + if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 199); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption, 199) failed: %s", err) + } + + checkRecvBufferSize(t, ep, 200) + + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 299); err != nil { + t.Fatalf("SetSockOptInt(SendBufferSizeOption, 299) failed: %s", err) + } + + checkSendBufferSize(t, ep, 300) + + // Set values above the max. 
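+ // Values above the max should be clamped down to the configured max,
+ // mirroring how values below the min were clamped up to the min.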
+ if err := ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, 1+tcp.DefaultReceiveBufferSize*20); err != nil { + t.Fatalf("SetSockOptInt(ReceiveBufferSizeOption) failed: %s", err) + } + + checkRecvBufferSize(t, ep, tcp.DefaultReceiveBufferSize*20) + + if err := ep.SetSockOptInt(tcpip.SendBufferSizeOption, 1+tcp.DefaultSendBufferSize*30); err != nil { + t.Fatalf("SetSockOptInt(SendBufferSizeOption) failed: %s", err) + } + + checkSendBufferSize(t, ep, tcp.DefaultSendBufferSize*30) +} + +func TestBindToDeviceOption(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}}) + + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %s", err) + } + defer ep.Close() + + if err := s.CreateNIC(321, loopback.New()); err != nil { + t.Errorf("CreateNIC failed: %s", err) + } + + // nicIDPtr is used instead of taking the address of NICID literals, which is + // a compiler error. + nicIDPtr := func(s tcpip.NICID) *tcpip.NICID { + return &s + } + + testActions := []struct { + name string + setBindToDevice *tcpip.NICID + setBindToDeviceError *tcpip.Error + getBindToDevice tcpip.BindToDeviceOption + }{ + {"GetDefaultValue", nil, nil, 0}, + {"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0}, + {"BindToExistent", nicIDPtr(321), nil, 321}, + {"UnbindToDevice", nicIDPtr(0), nil, 0}, + } + for _, testAction := range testActions { + t.Run(testAction.name, func(t *testing.T) { + if testAction.setBindToDevice != nil { + bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice) + if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr { + t.Errorf("SetSockOpt(%#v) got %v, want %v", bindToDevice, gotErr, wantErr) + } + } + bindToDevice := tcpip.BindToDeviceOption(88888) + if err := ep.GetSockOpt(&bindToDevice); err != nil { + t.Errorf("GetSockOpt got %s, want %v", err, nil) + } + if got, want := bindToDevice, testAction.getBindToDevice; got != want { + t.Errorf("bindToDevice got %d, want %d", got, want) + } + }) + } +} + +func makeStack() (*stack.Stack, *tcpip.Error) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ + ipv4.NewProtocol(), + ipv6.NewProtocol(), + }, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) + + id := loopback.New() + if testing.Verbose() { + id = sniffer.New(id) + } + + if err := s.CreateNIC(1, id); err != nil { + return nil, err + } + + for _, ct := range []struct { + number tcpip.NetworkProtocolNumber + address tcpip.Address + }{ + {ipv4.ProtocolNumber, context.StackAddr}, + {ipv6.ProtocolNumber, context.StackV6Addr}, + } { + if err := s.AddAddress(1, ct.number, ct.address); err != nil { + return nil, err + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: 1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + return s, nil +} + +func TestSelfConnect(t *testing.T) { + // This test ensures that intentional self-connects work. In particular, + // it checks that if an endpoint binds to say 127.0.0.1:1000 then + // connects to 127.0.0.1:1000, then it will be connected to itself, and + // is able to send and receive data through the same endpoint. 
+ s, err := makeStack() + if err != nil { + t.Fatal(err) + } + + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + // Register for notification, then start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventOut) + defer wq.EventUnregister(&waitEntry) + + if err := ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("got ep.Connect(...) = %s, want = %s", err, tcpip.ErrConnectStarted) + } + + <-notifyCh + if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil { + t.Fatalf("Connect failed: %s", err) + } + + // Write something. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + if _, _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Read back what was written. + wq.EventUnregister(&waitEntry) + wq.EventRegister(&waitEntry, waiter.EventIn) + rd, _, err := ep.Read(nil) + if err != nil { + if err != tcpip.ErrWouldBlock { + t.Fatalf("Read failed: %s", err) + } + <-notifyCh + rd, _, err = ep.Read(nil) + if err != nil { + t.Fatalf("Read failed: %s", err) + } + } + + if !bytes.Equal(data, rd) { + t.Fatalf("got data = %v, want = %v", rd, data) + } +} + +func TestConnectAvoidsBoundPorts(t *testing.T) { + addressTypes := func(t *testing.T, network string) []string { + switch network { + case "ipv4": + return []string{"v4"} + case "ipv6": + return []string{"v6"} + case "dual": + return []string{"v6", "mapped"} + default: + t.Fatalf("unknown network: '%s'", network) + } + + panic("unreachable") + } + + address := func(t *testing.T, addressType string, isAny bool) tcpip.Address { + switch addressType { + case "v4": + if isAny { + return "" + } + return context.StackAddr + case "v6": + if isAny { + return "" + } + return context.StackV6Addr + case "mapped": + if isAny { + return context.V4MappedWildcardAddr + } + return context.StackV4MappedAddr + default: + t.Fatalf("unknown address type: '%s'", addressType) + } + + panic("unreachable") + } + // This test ensures that Endpoint.Connect doesn't select already-bound ports. 
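+	// Quick sketch of what the "mapped" flavor denotes: an IPv4 address
+	// carried inside the IPv6 ::ffff:0:0/96 prefix, i.e. a 16-byte address
+	// whose bytes 10 and 11 are 0xff (assuming the context package's usual
+	// v4-mapped constants).
+	{
+		mapped := address(t, "mapped", false)
+		if len(mapped) != header.IPv6AddressSize || mapped[10] != 0xff || mapped[11] != 0xff {
+			t.Fatalf("v4-mapped sketch: unexpected layout % x", []byte(mapped))
+		}
+	}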
+	networks := []string{"ipv4", "ipv6", "dual"}
+	for _, exhaustedNetwork := range networks {
+		t.Run(fmt.Sprintf("exhaustedNetwork=%s", exhaustedNetwork), func(t *testing.T) {
+			for _, exhaustedAddressType := range addressTypes(t, exhaustedNetwork) {
+				t.Run(fmt.Sprintf("exhaustedAddressType=%s", exhaustedAddressType), func(t *testing.T) {
+					for _, isAny := range []bool{false, true} {
+						t.Run(fmt.Sprintf("isAny=%t", isAny), func(t *testing.T) {
+							for _, candidateNetwork := range networks {
+								t.Run(fmt.Sprintf("candidateNetwork=%s", candidateNetwork), func(t *testing.T) {
+									for _, candidateAddressType := range addressTypes(t, candidateNetwork) {
+										t.Run(fmt.Sprintf("candidateAddressType=%s", candidateAddressType), func(t *testing.T) {
+											s, err := makeStack()
+											if err != nil {
+												t.Fatal(err)
+											}
+
+											var wq waiter.Queue
+											var eps []tcpip.Endpoint
+											defer func() {
+												for _, ep := range eps {
+													ep.Close()
+												}
+											}()
+											makeEP := func(network string) tcpip.Endpoint {
+												var networkProtocolNumber tcpip.NetworkProtocolNumber
+												switch network {
+												case "ipv4":
+													networkProtocolNumber = ipv4.ProtocolNumber
+												case "ipv6", "dual":
+													networkProtocolNumber = ipv6.ProtocolNumber
+												default:
+													t.Fatalf("unknown network: '%s'", network)
+												}
+												ep, err := s.NewEndpoint(tcp.ProtocolNumber, networkProtocolNumber, &wq)
+												if err != nil {
+													t.Fatalf("NewEndpoint failed: %s", err)
+												}
+												eps = append(eps, ep)
+												switch network {
+												case "ipv4":
+												case "ipv6":
+													if err := ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil {
+														t.Fatalf("SetSockOptBool(V6OnlyOption(true)) failed: %s", err)
+													}
+												case "dual":
+													if err := ep.SetSockOptBool(tcpip.V6OnlyOption, false); err != nil {
+														t.Fatalf("SetSockOptBool(V6OnlyOption(false)) failed: %s", err)
+													}
+												default:
+													t.Fatalf("unknown network: '%s'", network)
+												}
+												return ep
+											}
+
+											var v4reserved, v6reserved bool
+											switch exhaustedAddressType {
+											case "v4", "mapped":
+												v4reserved = true
+											case "v6":
+												v6reserved = true
+												// Dual-stack sockets bound to the v6 any
+												// address reserve the port on v4 as well.
+												if isAny {
+													switch exhaustedNetwork {
+													case "ipv6":
+													case "dual":
+														v4reserved = true
+													default:
+														t.Fatalf("unknown network: '%s'", exhaustedNetwork)
+													}
+												}
+											default:
+												t.Fatalf("unknown address type: '%s'", exhaustedAddressType)
+											}
+											var collides bool
+											switch candidateAddressType {
+											case "v4", "mapped":
+												collides = v4reserved
+											case "v6":
+												collides = v6reserved
+											default:
+												t.Fatalf("unknown address type: '%s'", candidateAddressType)
+											}
+
+											for i := ports.FirstEphemeral; i <= math.MaxUint16; i++ {
+												if err := makeEP(exhaustedNetwork).Bind(tcpip.FullAddress{Addr: address(t, exhaustedAddressType, isAny), Port: uint16(i)}); err != nil {
+													t.Fatalf("Bind(%d) failed: %s", i, err)
+												}
+											}
+											want := tcpip.ErrConnectStarted
+											if collides {
+												want = tcpip.ErrNoPortAvailable
+											}
+											if err := makeEP(candidateNetwork).Connect(tcpip.FullAddress{Addr: address(t, candidateAddressType, false), Port: 31337}); err != want {
+												t.Fatalf("got ep.Connect(..) = %s, want = %s", err, want)
+											}
+										})
+									}
+								})
+							}
+						})
+					}
+				})
+			}
+		})
+	}
+}
+
+func TestPathMTUDiscovery(t *testing.T) {
+	// This test verifies that the stack retransmits packets after it
+	// receives an ICMP packet indicating that the path MTU has been
+	// exceeded.
+	c := context.New(t, 1500)
+	defer c.Cleanup()
+
+	// Create a new connection with an MSS of 1460.
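+	// The raw option bytes handed to CreateConnectedWithRawOptions below
+	// encode that MSS as kind=2 (header.TCPOptionMSS), length=4, value in
+	// network byte order; an illustrative check of the arithmetic:
+	{
+		mss := 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
+		opt := []byte{header.TCPOptionMSS, 4, byte(mss >> 8), byte(mss & 0xff)}
+		if opt[2] != 0x05 || opt[3] != 0xb4 {
+			t.Fatalf("MSS option sketch: got % x, want 05 b4", opt[2:])
+		}
+	}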
+ const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize + c.CreateConnectedWithRawOptions(789, 30000, -1 /* epRcvBuf */, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + }) + + // Send 3200 bytes of data. + const writeSize = 3200 + data := buffer.NewView(writeSize) + for i := range data { + data[i] = byte(i) + } + + if _, _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { + var ret []byte + for i, size := range sizes { + p := c.GetPacket() + if i == which { + ret = p + } + checker.IPv4(t, p, + checker.PayloadLen(size+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(seqNum), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + seqNum += uint32(size) + } + return ret + } + + // Receive three packets. + sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} + first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) + + // Send "packet too big" messages back to netstack. + const newMTU = 1200 + const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize + mtu := []byte{0, 0, newMTU / 256, newMTU % 256} + c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) + + // See retransmitted packets. None exceeding the new max. + sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} + receivePackets(c, sizes, -1, uint32(c.IRS)+1) +} + +func TestTCPEndpointProbe(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + invoked := make(chan struct{}) + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint ID is what we expect. + // + // We don't do an extensive validation of every field but a + // basic sanity test. 
+ if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { + t.Fatalf("got LocalAddress: %q, want: %q", got, want) + } + if got, want := state.ID.LocalPort, c.Port; got != want { + t.Fatalf("got LocalPort: %d, want: %d", got, want) + } + if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { + t.Fatalf("got RemoteAddress: %q, want: %q", got, want) + } + if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { + t.Fatalf("got RemotePort: %d, want: %d", got, want) + } + + invoked <- struct{}{} + }) + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + select { + case <-invoked: + case <-time.After(100 * time.Millisecond): + t.Fatalf("TCP Probe function was not called") + } +} + +func TestStackSetCongestionControl(t *testing.T) { + testCases := []struct { + cc tcpip.CongestionControlOption + err *tcpip.Error + }{ + {"reno", nil}, + {"cubic", nil}, + {"blahblah", tcpip.ErrNoSuchFile}, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("SetTransportProtocolOption(.., %v)", tc.cc), func(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + s := c.Stack() + + var oldCC tcpip.CongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %s", tcp.ProtocolNumber, &oldCC, err) + } + + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) + } + + var cc tcpip.CongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) + } + + got, want := cc, oldCC + // If SetTransportProtocolOption is expected to succeed + // then the returned value for congestion control should + // match the one specified in the + // SetTransportProtocolOption call above, else it should + // be what it was before the call to + // SetTransportProtocolOption. + if tc.err == nil { + want = tc.cc + } + if got != want { + t.Fatalf("got congestion control: %v, want: %v", got, want) + } + }) + } +} + +func TestStackAvailableCongestionControl(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + s := c.Stack() + + // Query permitted congestion control algorithms. + var aCC tcpip.AvailableCongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) + } + if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + } +} + +func TestStackSetAvailableCongestionControl(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + s := c.Stack() + + // Setting AvailableCongestionControlOption should fail. + aCC := tcpip.AvailableCongestionControlOption("xyz") + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) + } + + // Verify that we still get the expected list of congestion control options. 
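+	// This mirrors Linux, where the available list is read-only while the
+	// active algorithm accepts only names on that list; e.g. (illustrative
+	// shell, not part of this test):
+	//
+	//	sysctl -w net.ipv4.tcp_congestion_control=cubic           # ok
+	//	sysctl -w net.ipv4.tcp_available_congestion_control=xyz   # rejected: read-only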
+	var cc tcpip.AvailableCongestionControlOption
+	if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil {
+		t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err)
+	}
+	if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want {
+		t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want)
+	}
+}
+
+func TestEndpointSetCongestionControl(t *testing.T) {
+	testCases := []struct {
+		cc  tcpip.CongestionControlOption
+		err *tcpip.Error
+	}{
+		{"reno", nil},
+		{"cubic", nil},
+		{"blahblah", tcpip.ErrNoSuchFile},
+	}
+
+	for _, connected := range []bool{false, true} {
+		for _, tc := range testCases {
+			t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) {
+				c := context.New(t, 1500)
+				defer c.Cleanup()
+
+				// Create TCP endpoint.
+				var err *tcpip.Error
+				c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+				if err != nil {
+					t.Fatalf("NewEndpoint failed: %s", err)
+				}
+
+				var oldCC tcpip.CongestionControlOption
+				if err := c.EP.GetSockOpt(&oldCC); err != nil {
+					t.Fatalf("c.EP.GetSockOpt(%v) = %s", &oldCC, err)
+				}
+
+				if connected {
+					c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil)
+				}
+
+				if err := c.EP.SetSockOpt(tc.cc); err != tc.err {
+					t.Fatalf("c.EP.SetSockOpt(%v) = %s, want %s", tc.cc, err, tc.err)
+				}
+
+				var cc tcpip.CongestionControlOption
+				if err := c.EP.GetSockOpt(&cc); err != nil {
+					t.Fatalf("c.EP.GetSockOpt(%v) = %s", &cc, err)
+				}
+
+				got, want := cc, oldCC
+				// If SetSockOpt is expected to succeed then the
+				// returned value for congestion control should match
+				// the one specified in the SetSockOpt above, else it
+				// should be what it was before the call to SetSockOpt.
+				if tc.err == nil {
+					want = tc.cc
+				}
+				if got != want {
+					t.Fatalf("got congestion control: %v, want: %v", got, want)
+				}
+			})
+		}
+	}
+}
+
+func enableCUBIC(t *testing.T, c *context.Context) {
+	t.Helper()
+	opt := tcpip.CongestionControlOption("cubic")
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil {
+		t.Fatalf("c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, %s) = %s", opt, err)
+	}
+}
+
+func TestKeepalive(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+	const keepAliveInterval = 3 * time.Second
+	c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
+	c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+	c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 5)
+	c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+
+	// ACK ten keepalive probes, one at a time; an acknowledged probe resets
+	// the unacked-probe count, so the connection must stay alive well past
+	// the count limit of 5.
+	for i := 0; i < 10; i++ {
+		b := c.GetPacket()
+		checker.IPv4(t, b,
+			checker.TCP(
+				checker.DstPort(context.TestPort),
+				checker.SeqNum(uint32(c.IRS)),
+				checker.AckNum(uint32(790)),
+				checker.TCPFlags(header.TCPFlagAck),
+			),
+		)
+
+		// Acknowledge the keepalive.
+		c.SendPacket(nil, &context.Headers{
+			SrcPort: context.TestPort,
+			DstPort: c.Port,
+			Flags:   header.TCPFlagAck,
+			SeqNum:  790,
+			AckNum:  c.IRS,
+			RcvWnd:  30000,
+		})
+	}
+
+	// Check that the connection is still alive.
+	if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+		t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+	}
+
+	// Send some data and wait before ACKing it. Keepalives should be disabled
+	// during this period.
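+	//
+	// Timeline sketch of the knobs configured above:
+	//
+	//	<100ms idle>  probe, probe, ...   (3s apart, count limit 5)
+	//	ACKed probe      => unacked count resets; connection stays up
+	//	unACKed data     => probes suppressed; the normal RTO path handles loss
+	//	5 unACKed probes => the connection is timed out and reset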
+	view := buffer.NewView(3)
+	if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	next := uint32(c.IRS) + 1
+	checker.IPv4(t, c.GetPacket(),
+		checker.PayloadLen(len(view)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(next),
+			checker.AckNum(790),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+
+	// Wait for the packet to be retransmitted. Verify that no keepalives
+	// were sent.
+	checker.IPv4(t, c.GetPacket(),
+		checker.PayloadLen(len(view)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(next),
+			checker.AckNum(790),
+			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagPsh),
+		),
+	)
+	c.CheckNoPacket("Keepalive packet received while unACKed data is pending")
+
+	next += uint32(len(view))
+
+	// Send ACK. Keepalives should start sending again.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  seqnum.Value(next),
+		RcvWnd:  30000,
+	})
+
+	// Now receive 5 keepalives, but don't ACK them. The connection
+	// should be reset after 5.
+	for i := 0; i < 5; i++ {
+		b := c.GetPacket()
+		checker.IPv4(t, b,
+			checker.TCP(
+				checker.DstPort(context.TestPort),
+				checker.SeqNum(uint32(next-1)),
+				checker.AckNum(uint32(790)),
+				checker.TCPFlags(header.TCPFlagAck),
+			),
+		)
+	}
+
+	// Sleep for a little over the keepalive interval to make sure
+	// the timer has time to fire after the last ACK and close the
+	// socket.
+	time.Sleep(keepAliveInterval + keepAliveInterval/2)
+
+	// The connection should be terminated after 5 unacked keepalives.
+	// Send an ACK to trigger a RST from the stack as the endpoint should
+	// be dead.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  seqnum.Value(next),
+		RcvWnd:  30000,
+	})
+
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(next)),
+			checker.AckNum(uint32(0)),
+			checker.TCPFlags(header.TCPFlagRst),
+		),
+	)
+
+	if got := c.Stack().Stats().TCP.EstablishedTimedout.Value(); got != 1 {
+		t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
+	}
+
+	if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+		t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
+	}
+
+	if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
+		t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
+	}
+	if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+		t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+	}
+}
+
+func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+	// Send a SYN request.
+	irs = seqnum.Value(789)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: srcPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  irs,
+		RcvWnd:  30000,
+	})
+
+	// Receive the SYN-ACK reply.
+	b := c.GetPacket()
+	tcp := header.TCP(header.IPv4(b).Payload())
+	iss = seqnum.Value(tcp.SequenceNumber())
+	tcpCheckers := []checker.TransportChecker{
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(srcPort),
+		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+		checker.AckNum(uint32(irs) + 1),
+	}
+
+	if synCookieInUse {
+		// When cookies are in use, window scaling is disabled.
+		tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
+			WS:  -1,
+			MSS: c.MSSWithoutOptions(),
+		}))
+	}
+
+	checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+	// Send ACK.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: srcPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  irs + 1,
+		AckNum:  iss + 1,
+		RcvWnd:  30000,
+	})
+	return irs, iss
+}
+
+func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) {
+	// Send a SYN request.
+	irs = seqnum.Value(789)
+	c.SendV6Packet(nil, &context.Headers{
+		SrcPort: srcPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  irs,
+		RcvWnd:  30000,
+	})
+
+	// Receive the SYN-ACK reply.
+	b := c.GetV6Packet()
+	tcp := header.TCP(header.IPv6(b).Payload())
+	iss = seqnum.Value(tcp.SequenceNumber())
+	tcpCheckers := []checker.TransportChecker{
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(srcPort),
+		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+		checker.AckNum(uint32(irs) + 1),
+	}
+
+	if synCookieInUse {
+		// When cookies are in use, window scaling is disabled.
+		tcpCheckers = append(tcpCheckers, checker.TCPSynOptions(header.TCPSynOptions{
+			WS:  -1,
+			MSS: c.MSSWithoutOptionsV6(),
+		}))
+	}
+
+	checker.IPv6(t, b, checker.TCP(tcpCheckers...))
+
+	// Send ACK.
+	c.SendV6Packet(nil, &context.Headers{
+		SrcPort: srcPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  irs + 1,
+		AckNum:  iss + 1,
+		RcvWnd:  30000,
+	})
+	return irs, iss
+}
+
+// TestListenBacklogFull tests that netstack does not complete handshakes if the
+// listen backlog for the endpoint is full.
+func TestListenBacklogFull(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Create TCP endpoint.
+	var err *tcpip.Error
+	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+
+	// Bind to wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+
+	// Test acceptance.
+	// Start listening.
+	listenBacklog := 2
+	if err := c.EP.Listen(listenBacklog); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+
+	for i := 0; i < listenBacklog; i++ {
+		executeHandshake(t, c, context.TestPort+uint16(i), false /* synCookieInUse */)
+	}
+
+	time.Sleep(50 * time.Millisecond)
+
+	// Now send one more SYN. The stack should not respond as the backlog
+	// is full at this point.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort + 2,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  seqnum.Value(789),
+		RcvWnd:  30000,
+	})
+	c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
+
+	// Try to accept the connections in the backlog.
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	for i := 0; i < listenBacklog; i++ {
+		_, _, err = c.EP.Accept()
+		if err == tcpip.ErrWouldBlock {
+			// Wait for connection to be established.
+			select {
+			case <-ch:
+				_, _, err = c.EP.Accept()
+				if err != nil {
+					t.Fatalf("Accept failed: %s", err)
+				}
+
+			case <-time.After(1 * time.Second):
+				t.Fatalf("Timed out waiting for accept")
+			}
+		}
+	}
+
+	// Now verify that there are no more connections that can be accepted.
+	_, _, err = c.EP.Accept()
+	if err != tcpip.ErrWouldBlock {
+		select {
+		case <-ch:
+			t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP)
+		case <-time.After(1 * time.Second):
+		}
+	}
+
+	// Now a new handshake must succeed.
+	executeHandshake(t, c, context.TestPort+2, false /* synCookieInUse */)
+
+	newEP, _, err := c.EP.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			newEP, _, err = c.EP.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+
+	// Now verify that the TCP socket is usable and in a connected state.
+	data := "Don't panic"
+	newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+	b := c.GetPacket()
+	tcp := header.TCP(header.IPv4(b).Payload())
+	if string(tcp.Payload()) != data {
+		t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
+	}
+}
+
+// TestListenNoAcceptNonUnicastV4 makes sure that TCP segments with a
+// non-unicast IPv4 address are not accepted.
+func TestListenNoAcceptNonUnicastV4(t *testing.T) {
+	multicastAddr := tcpip.Address("\xe0\x00\x01\x02")
+	otherMulticastAddr := tcpip.Address("\xe0\x00\x01\x03")
+
+	tests := []struct {
+		name    string
+		srcAddr tcpip.Address
+		dstAddr tcpip.Address
+	}{
+		{
+			"SourceUnspecified",
+			header.IPv4Any,
+			context.StackAddr,
+		},
+		{
+			"SourceBroadcast",
+			header.IPv4Broadcast,
+			context.StackAddr,
+		},
+		{
+			"SourceOurMulticast",
+			multicastAddr,
+			context.StackAddr,
+		},
+		{
+			"SourceOtherMulticast",
+			otherMulticastAddr,
+			context.StackAddr,
+		},
+		{
+			"DestUnspecified",
+			context.TestAddr,
+			header.IPv4Any,
+		},
+		{
+			"DestBroadcast",
+			context.TestAddr,
+			header.IPv4Broadcast,
+		},
+		{
+			"DestOurMulticast",
+			context.TestAddr,
+			multicastAddr,
+		},
+		{
+			"DestOtherMulticast",
+			context.TestAddr,
+			otherMulticastAddr,
+		},
+	}
+
+	for _, test := range tests {
+		test := test // capture range variable
+
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+
+			c.Create(-1)
+
+			if err := c.Stack().JoinGroup(header.IPv4ProtocolNumber, 1, multicastAddr); err != nil {
+				t.Fatalf("JoinGroup failed: %s", err)
+			}
+
+			if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+				t.Fatalf("Bind failed: %s", err)
+			}
+
+			if err := c.EP.Listen(1); err != nil {
+				t.Fatalf("Listen failed: %s", err)
+			}
+
+			irs := seqnum.Value(789)
+			c.SendPacketWithAddrs(nil, &context.Headers{
+				SrcPort: context.TestPort,
+				DstPort: context.StackPort,
+				Flags:   header.TCPFlagSyn,
+				SeqNum:  irs,
+				RcvWnd:  30000,
+			}, test.srcAddr, test.dstAddr)
+			c.CheckNoPacket("Should not have received a response")
+
+			// Handle normal packet.
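+			// (The multicast groups above are the raw encodings of 224.0.1.2
+			// and 224.0.1.3.) The same SYN sent again with ordinary unicast
+			// source and destination addresses is the only probe that may
+			// elicit a SYN-ACK.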
+			c.SendPacketWithAddrs(nil, &context.Headers{
+				SrcPort: context.TestPort,
+				DstPort: context.StackPort,
+				Flags:   header.TCPFlagSyn,
+				SeqNum:  irs,
+				RcvWnd:  30000,
+			}, context.TestAddr, context.StackAddr)
+			checker.IPv4(t, c.GetPacket(),
+				checker.TCP(
+					checker.SrcPort(context.StackPort),
+					checker.DstPort(context.TestPort),
+					checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+					checker.AckNum(uint32(irs)+1)))
+		})
+	}
+}
+
+// TestListenNoAcceptNonUnicastV6 makes sure that TCP segments with a
+// non-unicast IPv6 address are not accepted.
+func TestListenNoAcceptNonUnicastV6(t *testing.T) {
+	multicastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x01")
+	otherMulticastAddr := tcpip.Address("\xff\x0e\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01\x02")
+
+	tests := []struct {
+		name    string
+		srcAddr tcpip.Address
+		dstAddr tcpip.Address
+	}{
+		{
+			"SourceUnspecified",
+			header.IPv6Any,
+			context.StackV6Addr,
+		},
+		{
+			"SourceAllNodes",
+			header.IPv6AllNodesMulticastAddress,
+			context.StackV6Addr,
+		},
+		{
+			"SourceOurMulticast",
+			multicastAddr,
+			context.StackV6Addr,
+		},
+		{
+			"SourceOtherMulticast",
+			otherMulticastAddr,
+			context.StackV6Addr,
+		},
+		{
+			"DestUnspecified",
+			context.TestV6Addr,
+			header.IPv6Any,
+		},
+		{
+			"DestAllNodes",
+			context.TestV6Addr,
+			header.IPv6AllNodesMulticastAddress,
+		},
+		{
+			"DestOurMulticast",
+			context.TestV6Addr,
+			multicastAddr,
+		},
+		{
+			"DestOtherMulticast",
+			context.TestV6Addr,
+			otherMulticastAddr,
+		},
+	}
+
+	for _, test := range tests {
+		test := test // capture range variable
+
+		t.Run(test.name, func(t *testing.T) {
+			t.Parallel()
+
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+
+			c.CreateV6Endpoint(true)
+
+			if err := c.Stack().JoinGroup(header.IPv6ProtocolNumber, 1, multicastAddr); err != nil {
+				t.Fatalf("JoinGroup failed: %s", err)
+			}
+
+			if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+				t.Fatalf("Bind failed: %s", err)
+			}
+
+			if err := c.EP.Listen(1); err != nil {
+				t.Fatalf("Listen failed: %s", err)
+			}
+
+			irs := seqnum.Value(789)
+			c.SendV6PacketWithAddrs(nil, &context.Headers{
+				SrcPort: context.TestPort,
+				DstPort: context.StackPort,
+				Flags:   header.TCPFlagSyn,
+				SeqNum:  irs,
+				RcvWnd:  30000,
+			}, test.srcAddr, test.dstAddr)
+			c.CheckNoPacket("Should not have received a response")
+
+			// Handle normal packet.
+			c.SendV6PacketWithAddrs(nil, &context.Headers{
+				SrcPort: context.TestPort,
+				DstPort: context.StackPort,
+				Flags:   header.TCPFlagSyn,
+				SeqNum:  irs,
+				RcvWnd:  30000,
+			}, context.TestV6Addr, context.StackV6Addr)
+			checker.IPv6(t, c.GetV6Packet(),
+				checker.TCP(
+					checker.SrcPort(context.StackPort),
+					checker.DstPort(context.TestPort),
+					checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+					checker.AckNum(uint32(irs)+1)))
+		})
+	}
+}
+
+func TestListenSynRcvdQueueFull(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Create TCP endpoint.
+	var err *tcpip.Error
+	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+
+	// Bind to wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+
+	// Test acceptance.
+	// Start listening.
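+	//
+	// Two queues matter here: connections still mid-handshake (SYN-RCVD)
+	// and established connections waiting to be Accept()ed. With a backlog
+	// of 1, a single in-flight handshake fills the SYN-RCVD quota, so a
+	// second SYN is silently dropped until the first connection's final
+	// ACK moves it to the accept queue.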
+	listenBacklog := 1
+	if err := c.EP.Listen(listenBacklog); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+
+	// Send two SYNs: the first one should get a SYN-ACK; the second one
+	// should get no response and is dropped, as the synRcvd count will be
+	// equal to the backlog.
+	irs := seqnum.Value(789)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  irs,
+		RcvWnd:  30000,
+	})
+
+	// Receive the SYN-ACK reply.
+	b := c.GetPacket()
+	tcp := header.TCP(header.IPv4(b).Payload())
+	iss := seqnum.Value(tcp.SequenceNumber())
+	tcpCheckers := []checker.TransportChecker{
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+		checker.AckNum(uint32(irs) + 1),
+	}
+	checker.IPv4(t, b, checker.TCP(tcpCheckers...))
+
+	// Now send one more SYN. The stack should not respond as the backlog
+	// is full at this point.
+	//
+	// NOTE: we did not complete the handshake for the previous one so the
+	// accept backlog should be empty and there should be one connection in
+	// synRcvd state.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort + 1,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  seqnum.Value(889),
+		RcvWnd:  30000,
+	})
+	c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond)
+
+	// Now complete the previous connection and verify that there is a connection
+	// to accept.
+	// Send ACK.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  irs + 1,
+		AckNum:  iss + 1,
+		RcvWnd:  30000,
+	})
+
+	// Try to accept the connections in the backlog.
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	newEP, _, err := c.EP.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			newEP, _, err = c.EP.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+
+	// Now verify that the TCP socket is usable and in a connected state.
+	data := "Don't panic"
+	newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+	pkt := c.GetPacket()
+	tcp = header.TCP(header.IPv4(pkt).Payload())
+	if string(tcp.Payload()) != data {
+		t.Fatalf("unexpected data: got %s, want %s", string(tcp.Payload()), data)
+	}
+}
+
+func TestListenBacklogFullSynCookieInUse(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(1)); err != nil {
+		t.Fatalf("setting TCPSynRcvdCountThresholdOption to 1 failed: %s", err)
+	}
+
+	// Create TCP endpoint.
+	var err *tcpip.Error
+	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+
+	// Bind to wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+
+	// Test acceptance.
+	// Start listening.
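+	//
+	// With TCPSynRcvdCountThresholdOption set to 1 above, any handshake
+	// beyond the first in-flight one is answered statelessly: the half-open
+	// state is folded into the SYN-ACK's initial sequence number (a SYN
+	// cookie), roughly
+	//
+	//	iss ~ hash(secret, saddr, daddr, ports) + encoded MSS index
+	//
+	// which is also why cookie-based handshakes advertise no window scaling.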
+ listenBacklog := 1 + portOffset := uint16(0) + if err := c.EP.Listen(listenBacklog); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + executeHandshake(t, c, context.TestPort+portOffset, false) + portOffset++ + // Wait for this to be delivered to the accept queue. + time.Sleep(50 * time.Millisecond) + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + // pick a different src port for new SYN. + SrcPort: context.TestPort + 1, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + // The Syn should be dropped as the endpoint's backlog is full. + c.CheckNoPacketTimeout("unexpected packet received", 50*time.Millisecond) + + // Verify that there is only one acceptable connection at this point. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + _, _, err = c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + _, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now verify that there are no more connections that can be accepted. + _, _, err = c.EP.Accept() + if err != tcpip.ErrWouldBlock { + select { + case <-ch: + t.Fatalf("unexpected endpoint delivered on Accept: %+v", c.EP) + case <-time.After(1 * time.Second): + } + } +} + +func TestSynRcvdBadSeqNumber(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN to get a SYN-ACK. This should put the ep into SYN-RCVD state + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. 
+ b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + iss := seqnum.Value(tcpHdr.SequenceNumber()) + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(irs) + 1), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Now send a packet with an out-of-window sequence number + largeSeqnum := irs + seqnum.Value(tcpHdr.WindowSize()) + 1 + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: largeSeqnum, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Should receive an ACK with the expected SEQ number + b = c.GetPacket() + tcpCheckers = []checker.TransportChecker{ + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.AckNum(uint32(irs) + 1), + checker.SeqNum(uint32(iss + 1)), + } + checker.IPv4(t, b, checker.TCP(tcpCheckers...)) + + // Now that the socket replied appropriately with the ACK, + // complete the connection to test that the large SEQ num + // did not change the state from SYN-RCVD. + + // Send ACK to move to ESTABLISHED state. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + newEP, _, err := c.EP.Accept() + + if err != nil && err != tcpip.ErrWouldBlock { + t.Fatalf("Accept failed: %s", err) + } + + if err == tcpip.ErrWouldBlock { + // Try to accept the connections in the backlog. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + // Wait for connection to be established. + select { + case <-ch: + newEP, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now verify that the TCP socket is usable and in a connected state. 
+	data := "Don't panic"
+	_, _, err = newEP.Write(tcpip.SlicePayload(buffer.NewViewFromBytes([]byte(data))), tcpip.WriteOptions{})
+
+	if err != nil {
+		t.Fatalf("Write failed: %s", err)
+	}
+
+	pkt := c.GetPacket()
+	tcpHdr = header.TCP(header.IPv4(pkt).Payload())
+	if string(tcpHdr.Payload()) != data {
+		t.Fatalf("unexpected data: got %s, want %s", string(tcpHdr.Payload()), data)
+	}
+}
+
+func TestPassiveConnectionAttemptIncrement(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	c.EP = ep
+	if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+	}
+	if err := c.EP.Listen(1); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want {
+		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+	}
+
+	stats := c.Stack().Stats()
+	want := stats.TCP.PassiveConnectionOpenings.Value() + 1
+
+	srcPort := uint16(context.TestPort)
+	executeHandshake(t, c, srcPort+1, false)
+
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	// Verify that there is only one acceptable connection at this point.
+	_, _, err = c.EP.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			_, _, err = c.EP.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+
+	if got := stats.TCP.PassiveConnectionOpenings.Value(); got != want {
+		t.Errorf("got stats.TCP.PassiveConnectionOpenings.Value() = %d, want = %d", got, want)
+	}
+}
+
+func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	stats := c.Stack().Stats()
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	c.EP = ep
+	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+	if err := c.EP.Listen(1); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+
+	srcPort := uint16(context.TestPort)
+	// Now attempt a handshake; it will fill up the accept backlog.
+	executeHandshake(t, c, srcPort, false)
+
+	// Give time for the final ACK to be processed, as otherwise the next
+	// handshake could get accepted before the previous one based on
+	// goroutine scheduling.
+	time.Sleep(50 * time.Millisecond)
+
+	want := stats.TCP.ListenOverflowSynDrop.Value() + 1
+
+	// Now send one more SYN; this one should get dropped.
+	// Send a SYN request.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: srcPort + 2,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  seqnum.Value(789),
+		RcvWnd:  30000,
+	})
+
+	time.Sleep(50 * time.Millisecond)
+	if got := stats.TCP.ListenOverflowSynDrop.Value(); got != want {
+		t.Errorf("got stats.TCP.ListenOverflowSynDrop.Value() = %d, want = %d", got, want)
+	}
+	if got := c.EP.Stats().(*tcp.Stats).ReceiveErrors.ListenOverflowSynDrop.Value(); got != want {
+		t.Errorf("got EP stats Stats.ReceiveErrors.ListenOverflowSynDrop = %d, want = %d", got, want)
+	}
+
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	// Now check that there is one acceptable connection.
+	_, _, err = c.EP.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			_, _, err = c.EP.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+}
+
+func TestEndpointBindListenAcceptState(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatalf("Bind failed: %s", err)
+	}
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want {
+		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+	}
+
+	if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
+		t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
+	}
+	if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
+		t.Errorf("got EP stats Stats.ReadErrors.NotConnected = %d, want = %d", got, 1)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %s", err)
+	}
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+	}
+
+	c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	aep, _, err := ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			aep, _, err = ep.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+	if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
+		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+	}
+	if err := aep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrAlreadyConnected {
+		t.Errorf("unexpected error attempting to call connect on an established endpoint, got: %s, want: %s", err, tcpip.ErrAlreadyConnected)
+	}
+	// Listening endpoint remains in listen state.
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want {
+		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+	}
+
+	ep.Close()
+	// Give worker goroutines time to receive the close notification.
+	time.Sleep(1 * time.Second)
+	if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want {
+		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+	}
+	// Accepted endpoint remains open when the listen endpoint is closed.
+	if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want {
+		t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
+	}
+}
+
+// This test verifies that the auto tuning does not grow the receive buffer if
+// the application is not reading the data actively.
+func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
+	const mtu = 1500
+	const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+
+	c := context.New(t, mtu)
+	defer c.Cleanup()
+
+	stk := c.Stack()
+	// Set lower limits for auto-tuning tests. This is required because the
+	// test stops the worker which can cause packets to be dropped because
+	// the segment queue holding unprocessed packets is limited to 500.
+	const receiveBufferSize = 80 << 10 // 80KB.
+	const maxReceiveBufferSize = receiveBufferSize * 10
+	if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+		t.Fatalf("SetTransportProtocolOption failed: %s", err)
+	}
+
+	// Enable auto-tuning.
+	if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+		t.Fatalf("SetTransportProtocolOption failed: %s", err)
+	}
+	// Change the expected window scale to match the value needed for the
+	// maximum buffer size defined above.
+	c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
+
+	rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+
+	// NOTE: The timestamp values in the sent packets are meaningless to the
+	// peer, so we just increment the timestamp value by 1 every batch as we
+	// are not really using them for anything.
+	// Send a single byte to verify the advertised window.
+	tsVal := rawEP.TSVal + 1
+
+	// Introduce a 25ms latency by delaying the first byte.
+	latency := 25 * time.Millisecond
+	time.Sleep(latency)
+	rawEP.SendPacketWithTS([]byte{1}, tsVal)
+
+	// Verify that the ACK has the expected window.
+	wantRcvWnd := receiveBufferSize
+	wantRcvWnd = (wantRcvWnd >> uint32(c.WindowScale))
+	rawEP.VerifyACKRcvWnd(uint16(wantRcvWnd - 1))
+	time.Sleep(25 * time.Millisecond)
+
+	// Allocate a large enough payload for the test.
+	b := make([]byte, int(receiveBufferSize)*2)
+	offset := 0
+	payloadSize := receiveBufferSize - 1
+	worker := (c.EP).(interface {
+		StopWork()
+		ResumeWork()
+	})
+	tsVal++
+
+	// Stop the worker goroutine.
+	worker.StopWork()
+	start := offset
+	end := offset + payloadSize
+	packetsSent := 0
+	for ; start < end; start += mss {
+		rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+		packetsSent++
+	}
+
+	// Resume the worker so that it only sees the packets once all of them
+	// are waiting to be read.
+	worker.ResumeWork()
+
+	// Since we read no bytes, the window should go to zero until the
+	// application reads some of the data.
+	// Discard all intermediate ACKs except the last one.
+	if packetsSent > 100 {
+		for i := 0; i < (packetsSent / 100); i++ {
+			_ = c.GetPacket()
+		}
+	}
+	rawEP.VerifyACKRcvWnd(0)
+
+	time.Sleep(25 * time.Millisecond)
+	// Verify that data sent while the window is closed is dropped and
+	// not acked.
+	rawEP.SendPacketWithTS(b[start:start+mss], tsVal)
+
+	// Verify that the stack sends us back an ACK with the sequence number
+	// of the last packet sent, indicating it was dropped.
+	p := c.GetPacket()
+	checker.IPv4(t, p, checker.TCP(
+		checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+		checker.Window(0),
+	))
+
+	// Now read all the data from the endpoint and verify that advertised
+	// window increases to the full available buffer size.
+	for {
+		_, _, err := c.EP.Read(nil)
+		if err == tcpip.ErrWouldBlock {
+			break
+		}
+	}
+
+	// Verify that we receive a non-zero window update ACK. When running
+	// under the thread sanitizer the stack can end up sending more than
+	// one ACK here; whichever one we see must carry a non-zero window no
+	// larger than the configured receive buffer.
+	p = c.GetPacket()
+	checker.IPv4(t, p, checker.TCP(
+		checker.AckNum(uint32(rawEP.NextSeqNum)-uint32(mss)),
+		func(t *testing.T, h header.Transport) {
+			tcp, ok := h.(header.TCP)
+			if !ok {
+				return
+			}
+			if w := tcp.WindowSize(); w == 0 || w > uint16(wantRcvWnd) {
+				t.Errorf("got window %d, want non-zero and <= %d", w, wantRcvWnd)
+			}
+		},
+	))
+}
+
+// This test verifies that auto tuning grows the receive buffer in step with
+// the data actively read by the application in every RTT.
+func TestReceiveBufferAutoTuning(t *testing.T) {
+	const mtu = 1500
+	const mss = mtu - header.IPv4MinimumSize - header.TCPMinimumSize
+
+	c := context.New(t, mtu)
+	defer c.Cleanup()
+
+	stk := c.Stack()
+	// Set lower limits for auto-tuning tests. This is required because the
+	// test stops the worker which can cause packets to be dropped because
+	// the segment queue holding unprocessed packets is limited to 300.
+	const receiveBufferSize = 80 << 10 // 80KB.
+	const maxReceiveBufferSize = receiveBufferSize * 10
+	if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: receiveBufferSize, Max: maxReceiveBufferSize}); err != nil {
+		t.Fatalf("SetTransportProtocolOption failed: %s", err)
+	}
+
+	// Enable auto-tuning.
+	if err := stk.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.ModerateReceiveBufferOption(true)); err != nil {
+		t.Fatalf("SetTransportProtocolOption failed: %s", err)
+	}
+	// Change the expected window scale to match the value needed for the
+	// maximum buffer size used by the stack.
+	c.WindowScale = uint8(tcp.FindWndScale(maxReceiveBufferSize))
+
+	rawEP := c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, WS: 4})
+
+	wantRcvWnd := receiveBufferSize
+	scaleRcvWnd := func(rcvWnd int) uint16 {
+		return uint16(rcvWnd >> uint16(c.WindowScale))
+	}
+	// Allocate a large array to send to the endpoint.
+	b := make([]byte, receiveBufferSize*48)
+
+	// In every iteration we will send double the number of bytes sent in
+	// the previous iteration and read the same from the app. The received
+	// window should grow by at least 2x of bytes read by the app in every
+	// RTT.
+	offset := 0
+	payloadSize := receiveBufferSize / 8
+	worker := (c.EP).(interface {
+		StopWork()
+		ResumeWork()
+	})
+	tsVal := rawEP.TSVal
+	// We are going to do our own computation of what the moderated receive
+	// buffer should be based on sent/copied data per RTT and verify that
+	// the window advertised by the stack matches our calculations.
+	prevCopied := 0
+	done := false
+	latency := 1 * time.Millisecond
+	for i := 0; !done; i++ {
+		tsVal++
+
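+		// Worked example of the moderation model computed later in this
+		// loop (illustrative numbers only): if the application copied
+		// rttCopied = 64KB this RTT after prevCopied = 32KB, the new window
+		// starts at rttCopied*2 + 16*mss and is then scaled up by the
+		// growth ratio, so it should more than double.
+		{
+			const exMSS = 1460
+			rtt, prev := 64<<10, 32<<10
+			w := rtt<<1 + 16*exMSS
+			w += ((w * (rtt - prev)) / prev) << 1
+			if w <= 2*rtt {
+				t.Fatalf("moderation sketch: got %d, want > %d", w, 2*rtt)
+			}
+		}
+
+		// Stop the worker goroutine.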
+ worker.StopWork() + start := offset + end := offset + payloadSize + totalSent := 0 + packetsSent := 0 + for ; start < end; start += mss { + rawEP.SendPacketWithTS(b[start:start+mss], tsVal) + totalSent += mss + packetsSent++ + } + + // Resume it so that it only sees the packets once all of them + // are waiting to be read. + worker.ResumeWork() + + // Give 1ms for the worker to process the packets. + time.Sleep(1 * time.Millisecond) + + // Verify that the advertised window on the ACK is reduced by + // the total bytes sent. + expectedWnd := wantRcvWnd - totalSent + if packetsSent > 100 { + for i := 0; i < (packetsSent / 100); i++ { + _ = c.GetPacket() + } + } + rawEP.VerifyACKRcvWnd(scaleRcvWnd(expectedWnd)) + + // Now read all the data from the endpoint and invoke the + // moderation API to allow for receive buffer auto-tuning + // to happen before we measure the new window. + totalCopied := 0 + for { + b, _, err := c.EP.Read(nil) + if err == tcpip.ErrWouldBlock { + break + } + totalCopied += len(b) + } + + // Invoke the moderation API. This is required for auto-tuning + // to happen. This method is normally expected to be invoked + // from a higher layer than tcpip.Endpoint. So we simulate + // copying to userspace by invoking it explicitly here. + c.EP.ModerateRecvBuf(totalCopied) + + // Now send a keep-alive packet to trigger an ACK so that we can + // measure the new window. + rawEP.NextSeqNum-- + rawEP.SendPacketWithTS(nil, tsVal) + rawEP.NextSeqNum++ + + if i == 0 { + // In the first iteration the receiver based RTT is not + // yet known as a result the moderation code should not + // increase the advertised window. + rawEP.VerifyACKRcvWnd(scaleRcvWnd(wantRcvWnd)) + prevCopied = totalCopied + } else { + rttCopied := totalCopied + if i == 1 { + // The moderation code accumulates copied bytes till + // RTT is established. So add in the bytes sent in + // the first iteration to the total bytes for this + // RTT. + rttCopied += prevCopied + // Now reset it to the initial value used by the + // auto tuning logic. + prevCopied = tcp.InitialCwnd * mss * 2 + } + newWnd := rttCopied<<1 + 16*mss + grow := (newWnd * (rttCopied - prevCopied)) / prevCopied + newWnd += (grow << 1) + if newWnd > maxReceiveBufferSize { + newWnd = maxReceiveBufferSize + done = true + } + rawEP.VerifyACKRcvWnd(scaleRcvWnd(newWnd)) + wantRcvWnd = newWnd + prevCopied = rttCopied + // Increase the latency after first two iterations to + // establish a low RTT value in the receiver since it + // only tracks the lowest value. This ensures that when + // ModerateRcvBuf is called the elapsed time is always > + // rtt. Without this the test is flaky due to delays due + // to scheduling/wakeup etc. + latency += 50 * time.Millisecond + } + time.Sleep(latency) + offset += payloadSize + payloadSize *= 2 + } +} + +func TestDelayEnabled(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + checkDelayOption(t, c, false, false) // Delay is disabled by default. 
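+	// DelayOption is netstack's handle on Nagle's algorithm: when enabled,
+	// small writes may be held back and coalesced. In standard-library
+	// terms it is the inverse of TCP_NODELAY (illustrative only):
+	//
+	//	conn.SetNoDelay(false) // delay on, equivalent to DelayOption = true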
+ + for _, v := range []struct { + delayEnabled tcp.DelayEnabled + wantDelayOption bool + }{ + {delayEnabled: false, wantDelayOption: false}, + {delayEnabled: true, wantDelayOption: true}, + } { + c := context.New(t, defaultMTU) + defer c.Cleanup() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, v.delayEnabled); err != nil { + t.Fatalf("SetTransportProtocolOption(tcp, %t) failed: %s", v.delayEnabled, err) + } + checkDelayOption(t, c, v.delayEnabled, v.wantDelayOption) + } +} + +func checkDelayOption(t *testing.T, c *context.Context, wantDelayEnabled tcp.DelayEnabled, wantDelayOption bool) { + t.Helper() + + var gotDelayEnabled tcp.DelayEnabled + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &gotDelayEnabled); err != nil { + t.Fatalf("TransportProtocolOption(tcp, &gotDelayEnabled) failed: %s", err) + } + if gotDelayEnabled != wantDelayEnabled { + t.Errorf("TransportProtocolOption(tcp, &gotDelayEnabled) got %t, want %t", gotDelayEnabled, wantDelayEnabled) + } + + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, new(waiter.Queue)) + if err != nil { + t.Fatalf("NewEndPoint(tcp, ipv4, new(waiter.Queue)) failed: %s", err) + } + gotDelayOption, err := ep.GetSockOptBool(tcpip.DelayOption) + if err != nil { + t.Fatalf("ep.GetSockOptBool(tcpip.DelayOption) failed: %s", err) + } + if gotDelayOption != wantDelayOption { + t.Errorf("ep.GetSockOptBool(tcpip.DelayOption) got: %t, want: %t", gotDelayOption, wantDelayOption) + } +} + +func TestTCPLingerTimeout(t *testing.T) { + c := context.New(t, 1500 /* mtu */) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + testCases := []struct { + name string + tcpLingerTimeout time.Duration + want time.Duration + }{ + {"NegativeLingerTimeout", -123123, 0}, + {"ZeroLingerTimeout", 0, 0}, + {"InRangeLingerTimeout", 10 * time.Second, 10 * time.Second}, + // Values > stack's TCPLingerTimeout are capped to the stack's + // value. Defaults to tcp.DefaultTCPLingerTimeout(60 seconds) + {"AboveMaxLingerTimeout", 65 * time.Second, 60 * time.Second}, + } + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + if err := c.EP.SetSockOpt(tcpip.TCPLingerTimeoutOption(tc.tcpLingerTimeout)); err != nil { + t.Fatalf("SetSockOpt(%s) = %s", tc.tcpLingerTimeout, err) + } + var v tcpip.TCPLingerTimeoutOption + if err := c.EP.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt(tcpip.TCPLingerTimeoutOption) = %s", err) + } + if got, want := time.Duration(v), tc.want; got != want { + t.Fatalf("unexpected linger timeout got: %s, want: %s", got, want) + } + }) + } +} + +func TestTCPTimeWaitRSTIgnored(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. 
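+	// (Where this test is headed: once the endpoint reaches TIME_WAIT, an
+	// in-window RST must be ignored, per the RFC 1337 "TIME-WAIT
+	// assassination" hazards, rather than cutting TIME_WAIT short, while an
+	// out-of-window ACK still draws an immediate challenge ACK.)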
+ b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now send a RST and this should be ignored and not + // generate an ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + }) + + c.CheckNoPacketTimeout("unexpected packet received in TIME_WAIT state", 1*time.Second) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitOutOfOrder(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. 
+ we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Out of order ACK should generate an immediate ACK in + // TIME_WAIT. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 3, + }) + + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) +} + +func TestTCPTimeWaitNewSyn(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. 
+ select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + c.EP.Close() + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+1), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+2)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Send a SYN request w/ sequence number lower than + // the highest sequence number sent. We just reuse + // the same number. + iss = seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + + // Send a SYN request w/ sequence number higher than + // the highest sequence number sent. + iss = seqnum.Value(792) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b = c.GetPacket() + tcpHdr = header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } +} + +func TestTCPTimeWaitDuplicateFINExtendsTimeWait(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + want := c.Stack().Stats().TCP.EstablishedClosed.Value() + 1 + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. 
+	iss := seqnum.Value(789)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  iss,
+		RcvWnd:  30000,
+	})
+
+	// Receive the SYN-ACK reply.
+	b := c.GetPacket()
+	tcpHdr := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+	ackHeaders := &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 1,
+	}
+
+	// Send ACK.
+	c.SendPacket(nil, ackHeaders)
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %s", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+
+	c.EP.Close()
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+1)),
+		checker.AckNum(uint32(iss)+1),
+		checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck)))
+
+	finHeaders := &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck | header.TCPFlagFin,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 2,
+	}
+
+	c.SendPacket(nil, finHeaders)
+
+	// Get the ACK to the FIN we just sent.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+2)),
+		checker.AckNum(uint32(iss)+2),
+		checker.TCPFlags(header.TCPFlagAck)))
+
+	time.Sleep(2 * time.Second)
+
+	// Now send a duplicate FIN. This should cause the TIME_WAIT to extend
+	// by another 5 seconds and also send us a duplicate ACK as it should
+	// indicate that the final ACK was potentially lost.
+	c.SendPacket(nil, finHeaders)
+
+	// Get the ACK to the FIN we just sent.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.SeqNum(uint32(c.IRS+2)),
+		checker.AckNum(uint32(iss)+2),
+		checker.TCPFlags(header.TCPFlagAck)))
+
+	// Sleep for 4 seconds so at this point we are 1 second past the
+	// original tcpTimeWaitTimeout of 5 seconds.
+	time.Sleep(4 * time.Second)
+
+	// Send an ACK and it should not generate any packet as the socket
+	// should still be in TIME_WAIT for another 5 seconds due to the
+	// duplicate FIN we sent earlier.
+	*ackHeaders = *finHeaders
+	ackHeaders.SeqNum = ackHeaders.SeqNum + 1
+	ackHeaders.Flags = header.TCPFlagAck
+	c.SendPacket(nil, ackHeaders)
+
+	c.CheckNoPacketTimeout("unexpected packet received from endpoint in TIME_WAIT", 1*time.Second)
+	// Now sleep for another 2 seconds so that we are past the
+	// extended TIME_WAIT of 7 seconds (2 + 5).
+	time.Sleep(2 * time.Second)
+
+	// Resend the same ACK.
+	c.SendPacket(nil, ackHeaders)
+
+	// Receive the RST that should be generated as there is no valid
+	// endpoint.
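+	// To recap the timeline (approximate, assuming the sleeps fire on
+	// schedule): the FIN/ACK exchange at t=0s starts the 5s TIME_WAIT;
+	// the duplicate FIN at t=2s restarts the timer, moving expiry out to
+	// t=7s; the ACK at t=6s is still inside the extended TIME_WAIT and
+	// is silently dropped; the ACK resent 2 seconds after the quiet
+	// check is past the expiry, finds no endpoint, and draws the RST.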
+ checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) + + if got := c.Stack().Stats().TCP.EstablishedClosed.Value(); got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedClosed = %d, want = %d", got, want) + } + if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) + } +} + +func TestTCPCloseWithData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set TCPTimeWaitTimeout to 5 seconds so that sockets are marked closed + // after 5 seconds in TIME_WAIT state. + tcpTimeWaitTimeout := 5 * time.Second + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPTimeWaitTimeoutOption(tcpTimeWaitTimeout)); err != nil { + t.Fatalf("c.stack.SetTransportProtocolOption(tcp, tcpip.TCPLingerTimeoutOption(%d) failed: %s", tcpTimeWaitTimeout, err) + } + + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %s", err) + } + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + // Send a SYN request. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) + + ackHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + RcvWnd: 30000, + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %s", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Now trigger a passive close by sending a FIN. + finHeaders := &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: iss + 1, + AckNum: c.IRS + 2, + RcvWnd: 30000, + } + + c.SendPacket(nil, finHeaders) + + // Get the ACK to the FIN we just sent. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(iss)+2), + checker.TCPFlags(header.TCPFlagAck))) + + // Now write a few bytes and then close the endpoint. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + // Check that data is received. 
+ b = c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+2), // Acknum is initial sequence number + 1 + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) { + t.Errorf("got data = %x, want = %x", p, data) + } + + c.EP.Close() + // Check the FIN. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)+uint32(len(data))), + checker.AckNum(uint32(iss+2)), + checker.TCPFlags(header.TCPFlagFin|header.TCPFlagAck))) + + // First send a partial ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)-1), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now send a full ACK. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Now ACK the FIN. + ackHeaders.AckNum++ + c.SendPacket(nil, ackHeaders) + + // Now send an ACK and we should get a RST back as the endpoint should + // be in CLOSED state. + ackHeaders = &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 2, + AckNum: c.IRS + 1 + seqnum.Value(len(data)), + RcvWnd: 30000, + } + c.SendPacket(nil, ackHeaders) + + // Check the RST. + checker.IPv4(t, c.GetPacket(), checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(ackHeaders.AckNum)), + checker.AckNum(0), + checker.TCPFlags(header.TCPFlagRst))) +} + +func TestTCPUserTimeout(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, -1 /* epRcvBuf */) + + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventHUp) + defer c.WQ.EventUnregister(&waitEntry) + + origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value() + + // Ensure that on the next retransmit timer fire, the user timeout has + // expired. + initRTO := 1 * time.Second + userTimeout := initRTO / 2 + c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout)) + + // Send some data and wait before ACKing it. + view := buffer.NewView(3) + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %s", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Wait for the retransmit timer to be fired and the user timeout to cause + // close of the connection. + select { + case <-notifyCh: + case <-time.After(2 * initRTO): + t.Fatalf("connection still alive after %s, should have been closed after :%s", 2*initRTO, userTimeout) + } + + // No packet should be received as the connection should be silently + // closed due to timeout. 
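+	// (This matches TCP_USER_TIMEOUT-style semantics, described in RFC
+	// 5482: the timeout bounds how long transmitted data may remain
+	// unacknowledged before the connection is forcibly closed.)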
+	c.CheckNoPacket("unexpected packet received after userTimeout has expired")
+
+	next += uint32(len(view))
+
+	// The connection should be terminated after userTimeout has expired.
+	// Send an ACK to trigger a RST from the stack as the endpoint should
+	// be dead.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  seqnum.Value(next),
+		RcvWnd:  30000,
+	})
+
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(next)),
+			checker.AckNum(uint32(0)),
+			checker.TCPFlags(header.TCPFlagRst),
+		),
+	)
+
+	if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
+		t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
+	}
+
+	if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
+		t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
+	}
+	if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
+		t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got)
+	}
+}
+
+func TestKeepaliveWithUserTimeout(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
+
+	origEstablishedTimedout := c.Stack().Stats().TCP.EstablishedTimedout.Value()
+
+	const keepAliveInterval = 3 * time.Second
+	c.EP.SetSockOpt(tcpip.KeepaliveIdleOption(100 * time.Millisecond))
+	c.EP.SetSockOpt(tcpip.KeepaliveIntervalOption(keepAliveInterval))
+	c.EP.SetSockOptInt(tcpip.KeepaliveCountOption, 10)
+	c.EP.SetSockOptBool(tcpip.KeepaliveEnabledOption, true)
+
+	// Set userTimeout to the duration of exactly one keepalive interval.
+	// This means that after the first probe is sent, the second probe
+	// should cause the connection to be closed due to the user timeout
+	// being hit.
+	userTimeout := 1 * keepAliveInterval
+	c.EP.SetSockOpt(tcpip.TCPUserTimeoutOption(userTimeout))
+
+	// Check that the connection is still alive.
+	if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+		t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
+	}
+
+	// Now receive 1 keepalive probe, but don't ACK it.
+	b := c.GetPacket()
+	checker.IPv4(t, b,
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)),
+			checker.AckNum(uint32(790)),
+			checker.TCPFlags(header.TCPFlagAck),
+		),
+	)
+
+	// Sleep for a little over the KeepAlive interval to make sure
+	// the timer has time to fire after the last ACK and close the
+	// socket.
+	time.Sleep(keepAliveInterval + keepAliveInterval/2)
+
+	// The connection should be closed with a timeout.
+	// Send an ACK to trigger a RST from the stack as the endpoint should
+	// be dead.
+ c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(c.IRS + 1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS+1)), + checker.AckNum(uint32(0)), + checker.TCPFlags(header.TCPFlagRst), + ), + ) + + if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { + t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) + } + if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { + t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) + } + if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { + t.Errorf("got stats.TCP.CurrentConnected.Value() = %d, want = 0", got) + } +} + +func TestIncreaseWindowOnReceive(t *testing.T) { + // This test ensures that the endpoint sends an ack, + // after recv() when the window grows to more than 1 MSS. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const rcvBuf = 65535 * 10 + c.CreateConnected(789, 30000, rcvBuf) + + // Write chunks of ~30000 bytes. It's important that two + // payloads make it equal or longer than MSS. + remain := rcvBuf + sent := 0 + data := make([]byte, defaultMTU/2) + lastWnd := uint16(0) + + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + + lastWnd = uint16(remain) + if remain > 0xffff { + lastWnd = 0xffff + } + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(lastWnd), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + if lastWnd == 0xffff || lastWnd == 0 { + t.Fatalf("expected small, non-zero window: %d", lastWnd) + } + + // We now have < 1 MSS in the buffer space. Read the data! An + // ack should be sent in response to that. The window was not + // zero, but it grew to larger than MSS. + if _, _, err := c.EP.Read(nil); err != nil { + t.Fatalf("Read failed: %s", err) + } + + if _, _, err := c.EP.Read(nil); err != nil { + t.Fatalf("Read failed: %s", err) + } + + // After reading two packets, we surely crossed MSS. See the ack: + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(0xffff)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestIncreaseWindowOnBufferResize(t *testing.T) { + // This test ensures that the endpoint sends an ack, + // after available recv buffer grows to more than 1 MSS. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + const rcvBuf = 65535 * 10 + c.CreateConnected(789, 30000, rcvBuf) + + // Write chunks of ~30000 bytes. It's important that two + // payloads make it equal or longer than MSS. 
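+	// (Worked numbers, assuming defaultMTU is 65535 as elsewhere in these
+	// tests: each chunk is defaultMTU/2 = 32767 bytes against a receive
+	// buffer of 65535*10 = 655350 bytes, so the advertised window, capped
+	// at 0xffff, eventually shrinks below one MSS without reaching zero.)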
+	remain := rcvBuf
+	sent := 0
+	data := make([]byte, defaultMTU/2)
+	lastWnd := uint16(0)
+
+	for remain > len(data) {
+		c.SendPacket(data, &context.Headers{
+			SrcPort: context.TestPort,
+			DstPort: c.Port,
+			Flags:   header.TCPFlagAck,
+			SeqNum:  seqnum.Value(790 + sent),
+			AckNum:  c.IRS.Add(1),
+			RcvWnd:  30000,
+		})
+		sent += len(data)
+		remain -= len(data)
+
+		lastWnd = uint16(remain)
+		if remain > 0xffff {
+			lastWnd = 0xffff
+		}
+		checker.IPv4(t, c.GetPacket(),
+			checker.PayloadLen(header.TCPMinimumSize),
+			checker.TCP(
+				checker.DstPort(context.TestPort),
+				checker.SeqNum(uint32(c.IRS)+1),
+				checker.AckNum(uint32(790+sent)),
+				checker.Window(lastWnd),
+				checker.TCPFlags(header.TCPFlagAck),
+			),
+		)
+	}
+
+	if lastWnd == 0xffff || lastWnd == 0 {
+		t.Fatalf("expected small, non-zero window: %d", lastWnd)
+	}
+
+	// Increasing the buffer size should generate an ACK, since the window
+	// grew from a small value to one larger than or equal to the MSS.
+	c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, rcvBuf*2)
+
+	// After the buffer resize the window has grown by more than 1 MSS.
+	// Expect the corresponding ACK:
+	checker.IPv4(t, c.GetPacket(),
+		checker.PayloadLen(header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(790+sent)),
+			checker.Window(uint16(0xffff)),
+			checker.TCPFlags(header.TCPFlagAck),
+		),
+	)
+}
+
+func TestTCPDeferAccept(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.Create(-1)
+
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatal("Bind failed:", err)
+	}
+
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatal("Listen failed:", err)
+	}
+
+	const tcpDeferAccept = 1 * time.Second
+	if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+		t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s)) failed: %s", tcpDeferAccept, err)
+	}
+
+	irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+	if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+		t.Fatalf("c.EP.Accept() returned unexpected error: got %s, want %s", err, tcpip.ErrWouldBlock)
+	}
+
+	// Send data. This should result in an acceptable endpoint.
+	c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  irs + 1,
+		AckNum:  iss + 1,
+	})
+
+	// Receive ACK for the data we sent.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagAck),
+		checker.SeqNum(uint32(iss+1)),
+		checker.AckNum(uint32(irs+5))))
+
+	// Give a bit of time for the socket to be delivered to the accept queue.
+	time.Sleep(50 * time.Millisecond)
+	aep, _, err := c.EP.Accept()
+	if err != nil {
+		t.Fatalf("c.EP.Accept() returned unexpected error: got %s, want nil", err)
+	}
+
+	aep.Close()
+	// Closing aep without reading the data should trigger a RST.
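+	// (Closing an endpoint that still holds unread data is expected to
+	// send a RST rather than a FIN so the peer learns the data was
+	// discarded; see RFC 2525, section 2.17.)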
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+		checker.SeqNum(uint32(iss+1)),
+		checker.AckNum(uint32(irs+5))))
+}
+
+func TestTCPDeferAcceptTimeout(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.Create(-1)
+
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil {
+		t.Fatal("Bind failed:", err)
+	}
+
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatal("Listen failed:", err)
+	}
+
+	const tcpDeferAccept = 1 * time.Second
+	if err := c.EP.SetSockOpt(tcpip.TCPDeferAcceptOption(tcpDeferAccept)); err != nil {
+		t.Fatalf("c.EP.SetSockOpt(TCPDeferAcceptOption(%s)) failed: %s", tcpDeferAccept, err)
+	}
+
+	irs, iss := executeHandshake(t, c, context.TestPort, false /* synCookiesInUse */)
+
+	if _, _, err := c.EP.Accept(); err != tcpip.ErrWouldBlock {
+		t.Fatalf("c.EP.Accept() returned unexpected error: got %s, want %s", err, tcpip.ErrWouldBlock)
+	}
+
+	// Sleep for a little over the tcpDeferAccept timeout.
+	time.Sleep(tcpDeferAccept + 100*time.Millisecond)
+
+	// On timeout expiry we should get a SYN-ACK retransmission.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+		checker.AckNum(uint32(irs)+1)))
+
+	// Send data. This should result in an acceptable endpoint.
+	c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  irs + 1,
+		AckNum:  iss + 1,
+	})
+
+	// Receive ACK for the data we sent.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagAck),
+		checker.SeqNum(uint32(iss+1)),
+		checker.AckNum(uint32(irs+5))))
+
+	// Give some time for the endpoint to be delivered to the accept queue.
+	time.Sleep(50 * time.Millisecond)
+	aep, _, err := c.EP.Accept()
+	if err != nil {
+		t.Fatalf("c.EP.Accept() returned unexpected error: got %s, want nil", err)
+	}
+
+	aep.Close()
+	// Closing aep without reading the data should trigger a RST.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.SrcPort(context.StackPort),
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+		checker.SeqNum(uint32(iss+1)),
+		checker.AckNum(uint32(irs+5))))
+}
+
+func TestResetDuringClose(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	iss := seqnum.Value(789)
+	c.CreateConnected(iss, 30000, -1 /* epRcvBuf */)
+	// Send some data to make sure there is some unread
+	// data to trigger a reset on c.Close.
+	irs := c.IRS
+	c.SendPacket([]byte{1, 2, 3, 4}, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss.Add(1),
+		AckNum:  irs.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Receive ACK for the data we sent.
+	checker.IPv4(t, c.GetPacket(), checker.TCP(
+		checker.DstPort(context.TestPort),
+		checker.TCPFlags(header.TCPFlagAck),
+		checker.SeqNum(uint32(irs.Add(1))),
+		checker.AckNum(uint32(iss.Add(5)))))
+
+	// Close in a separate goroutine so that we can trigger a race with the
+	// RST we send below. Regardless of whether Close() sends an active RST
+	// or the RST sent below is processed by the worker first, releasing
+	// the route must not cause a panic.
+ var wg sync.WaitGroup + + wg.Add(1) + go func() { + defer wg.Done() + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + SeqNum: iss.Add(5), + AckNum: c.IRS.Add(5), + RcvWnd: 30000, + Flags: header.TCPFlagRst, + }) + }() + + wg.Add(1) + go func() { + defer wg.Done() + c.EP.Close() + }() + + wg.Wait() +} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go new file mode 100644 index 000000000..8edbff964 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -0,0 +1,291 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp_test + +import ( + "bytes" + "math/rand" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.dev/gvisor/pkg/waiter" +) + +// createConnectedWithTimestampOption creates and connects c.ep with the +// timestamp option enabled. +func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint { + return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1}) +} + +// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on +// an active connect and sets the TS Echo Reply fields correctly when the +// SYN-ACK also indicates support for the TS option and provides a TSVal. +func TestTimeStampEnabledConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + rep := createConnectedWithTimestampOption(c) + + // Register for read and validate that we have data to read. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + // The following tests ensure that TS option once enabled behaves + // correctly as described in + // https://tools.ietf.org/html/rfc7323#section-4.3. + // + // We are not testing delayed ACKs here, but we do test out of order + // packet delivery and filling the sequence number hole created due to + // the out of order packet. + // + // The test also verifies that the sequence numbers and timestamps are + // as expected. + data := []byte{1, 2, 3} + + // First we increment tsVal by a small amount. + tsVal := rep.TSVal + 100 + rep.SendPacketWithTS(data, tsVal) + rep.VerifyACKWithTS(tsVal) + + // Next we send an out of order packet. + rep.NextSeqNum += 3 + tsVal += 200 + rep.SendPacketWithTS(data, tsVal) + + // The ACK should contain the original sequenceNumber and an older TS. + rep.NextSeqNum -= 6 + rep.VerifyACKWithTS(tsVal - 200) + + // Next we fill the hole and the returned ACK should contain the + // cumulative sequence number acking all data sent till now and have the + // latest timestamp sent below in its TSEcr field. 
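+	// (Per RFC 7323, TS.Recent is only updated from a segment that is in
+	// sequence, i.e. SEG.SEQ <= Last.ACK.sent, and whose TSval is no
+	// older than the current TS.Recent; the out-of-order segment above
+	// therefore leaves it untouched, while the hole-filling segment below
+	// updates it.)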
+ tsVal -= 100 + rep.SendPacketWithTS(data, tsVal) + rep.NextSeqNum += 3 + rep.VerifyACKWithTS(tsVal) + + // Increment tsVal by a large value that doesn't result in a wrap around. + tsVal += 0x7fffffff + rep.SendPacketWithTS(data, tsVal) + rep.VerifyACKWithTS(tsVal) + + // Increment tsVal again by a large value which should cause the + // timestamp value to wrap around. The returned ACK should contain the + // wrapped around timestamp in its tsEcr field and not the tsVal from + // the previous packet sent above. + tsVal += 0x7fffffff + rep.SendPacketWithTS(data, tsVal) + rep.VerifyACKWithTS(tsVal) + + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // There should be 5 views to read and each of them should + // contain the same data. + for i := 0; i < 5; i++ { + got, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + if want := data; bytes.Compare(got, want) != 0 { + t.Fatalf("Data is different: got: %v, want: %v", got, want) + } + } +} + +// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an +// active connect but if the SYN-ACK doesn't specify the TS option then +// timestamp option is not enabled and future packets do not contain a +// timestamp. +func TestTimeStampDisabledConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnectedWithOptions(header.TCPSynOptions{}) +} + +func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + if cookieEnabled { + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } + + t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) + tsVal := rand.Uint32() + c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) + + // Now send some data and validate that timestamp is echoed correctly in the ACK. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %s", err) + } + + // Check that data is received and that the timestamp option TSEcr field + // matches the expected value. + b := c.GetPacket() + checker.IPv4(t, b, + // Add 12 bytes for the timestamp option + 2 NOPs to align at 4 + // byte boundary. + checker.PayloadLen(len(data)+header.TCPMinimumSize+12), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(wndSize), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPTimestampChecker(true, 0, tsVal+1), + ), + ) +} + +// TestTimeStampEnabledAccept tests that if the SYN on a passive connect +// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK +// and echoes the tsVal field of the original SYN in the tcEcr field of the +// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify +// that Timestamp option is enabled in both cases if requested in the original +// SYN. +func TestTimeStampEnabledAccept(t *testing.T) { + testCases := []struct { + cookieEnabled bool + wndScale int + wndSize uint16 + }{ + {true, -1, 0xffff}, // When cookie is used window scaling is disabled. 
+ {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + } + for _, tc := range testCases { + timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) + } +} + +func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + if cookieEnabled { + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPSynRcvdCountThresholdOption(0)); err != nil { + t.Fatalf("setting TCPSynRcvdCountThresholdOption to 0 failed: %s", err) + } + } + + t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) + c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Now send some data with the accepted connection endpoint and validate + // that no timestamp option is sent in the TCP segment. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %s", err) + } + + // Check that data is received and that the timestamp option is disabled + // when SYN cookies are enabled/disabled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(wndSize), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPTimestampChecker(false, 0, 0), + ), + ) +} + +// TestTimeStampDisabledAccept tests that Timestamp option is not used when the +// peer doesn't advertise it and connection is established with Accept(). +func TestTimeStampDisabledAccept(t *testing.T) { + testCases := []struct { + cookieEnabled bool + wndScale int + wndSize uint16 + }{ + {true, -1, 0xffff}, // When cookie is used window scaling is disabled. + {false, 5, 0x8000}, // DefaultReceiveBufferSize is 1MB >> 5. + } + for _, tc := range testCases { + timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) + } +} + +func TestSendGreaterThanMTUWithOptions(t *testing.T) { + const maxPayload = 100 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + createConnectedWithTimestampOption(c) + testBrokenUpWrite(t, c, maxPayload) +} + +func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { + const maxPayload = 100 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + rep := createConnectedWithTimestampOption(c) + + // Register for read. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + droppedPacketsStat := c.Stack().Stats().DroppedPackets + droppedPackets := droppedPacketsStat.Value() + data := []byte{1, 2, 3} + // Send a packet with no TCP options/timestamp. + rep.SendPacket(data, nil) + + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Assert that DroppedPackets was not incremented. + if got, want := droppedPacketsStat.Value(), droppedPackets; got != want { + t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) + } + + // Issue a read and we should data. 
+ got, _, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + if want := data; bytes.Compare(got, want) != 0 { + t.Fatalf("Data is different: got: %v, want: %v", got, want) + } +} diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD new file mode 100644 index 000000000..ce6a2c31d --- /dev/null +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -0,0 +1,26 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "context", + testonly = 1, + srcs = ["context.go"], + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/seqnum", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go new file mode 100644 index 000000000..06fde2a79 --- /dev/null +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -0,0 +1,1121 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package context provides a test context for use in tcp tests. It also +// provides helper methods to assert/check certain behaviours. +package context + +import ( + "bytes" + "context" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // StackAddr is the IPv4 address assigned to the stack. + StackAddr = "\x0a\x00\x00\x01" + + // StackPort is used as the listening port in tests for passive + // connects. + StackPort = 1234 + + // TestAddr is the source address for packets sent to the stack via the + // link layer endpoint. + TestAddr = "\x0a\x00\x00\x02" + + // TestPort is the TCP port used for packets sent to the stack + // via the link layer endpoint. + TestPort = 4096 + + // StackV6Addr is the IPv6 address assigned to the stack. + StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + + // TestV6Addr is the source address for packets sent to the stack via + // the link layer endpoint. + TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // StackV4MappedAddr is StackAddr as a mapped v6 address. + StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr + + // TestV4MappedAddr is TestAddr as a mapped v6 address. 
+ TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr + + // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. + V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + + // testInitialSequenceNumber is the initial sequence number sent in packets that + // are sent in response to a SYN or in the initial SYN sent to the stack. + testInitialSequenceNumber = 789 +) + +// Headers is used to represent the TCP header fields when building a +// new packet. +type Headers struct { + // SrcPort holds the src port value to be used in the packet. + SrcPort uint16 + + // DstPort holds the destination port value to be used in the packet. + DstPort uint16 + + // SeqNum is the value of the sequence number field in the TCP header. + SeqNum seqnum.Value + + // AckNum represents the acknowledgement number field in the TCP header. + AckNum seqnum.Value + + // Flags are the TCP flags in the TCP header. + Flags int + + // RcvWnd is the window to be advertised in the ReceiveWindow field of + // the TCP header. + RcvWnd seqnum.Size + + // TCPOpts holds the options to be sent in the option field of the TCP + // header. + TCPOpts []byte +} + +// Context provides an initialized Network stack and a link layer endpoint +// for use in TCP tests. +type Context struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack + + // IRS holds the initial sequence number in the SYN sent by endpoint in + // case of an active connect or the sequence number sent by the endpoint + // in the SYN-ACK sent in response to a SYN when listening in passive + // mode. + IRS seqnum.Value + + // Port holds the port bound by EP below in case of an active connect or + // the listening port number in case of a passive connect. + Port uint16 + + // EP is the test endpoint in the stack owned by this context. This endpoint + // is used in various tests to either initiate an active connect or is used + // as a passive listening endpoint to accept inbound connections. + EP tcpip.Endpoint + + // Wq is the wait queue associated with EP and is used to block for events + // on EP. + WQ waiter.Queue + + // TimeStampEnabled is true if ep is connected with the timestamp option + // enabled. + TimeStampEnabled bool + + // WindowScale is the expected window scale in SYN packets sent by + // the stack. + WindowScale uint8 +} + +// New allocates and initializes a test context containing a new +// stack and a link-layer endpoint. +func New(t *testing.T, mtu uint32) *Context { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{tcp.NewProtocol()}, + }) + + // Allow minimum send/receive buffer sizes to be 1 during tests. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{Min: 1, Default: tcp.DefaultSendBufferSize, Max: 10 * tcp.DefaultSendBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) + } + + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{Min: 1, Default: tcp.DefaultReceiveBufferSize, Max: 10 * tcp.DefaultReceiveBufferSize}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %s", err) + } + + // Increase minimum RTO in tests to avoid test flakes due to early + // retransmit in case the test executors are overloaded and cause timers + // to fire earlier than expected. 
+	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcpip.TCPMinRTOOption(3*time.Second)); err != nil {
+		t.Fatalf("failed to set stack-wide minRTO: %s", err)
+	}
+
+	// Some of the congestion control tests send up to 640 packets, so we
+	// set the channel size to 1000.
+	ep := channel.New(1000, mtu, "")
+	wep := stack.LinkEndpoint(ep)
+	if testing.Verbose() {
+		wep = sniffer.New(ep)
+	}
+	opts := stack.NICOptions{Name: "nic1"}
+	if err := s.CreateNICWithOptions(1, wep, opts); err != nil {
+		t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
+	}
+	ep2 := channel.New(1000, mtu, "")
+	wep2 := stack.LinkEndpoint(ep2)
+	if testing.Verbose() {
+		wep2 = sniffer.New(ep2)
+	}
+	opts2 := stack.NICOptions{Name: "nic2"}
+	if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil {
+		t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err)
+	}
+
+	if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{
+		{
+			Destination: header.IPv4EmptySubnet,
+			NIC:         1,
+		},
+		{
+			Destination: header.IPv6EmptySubnet,
+			NIC:         1,
+		},
+	})
+
+	return &Context{
+		t:           t,
+		s:           s,
+		linkEP:      ep,
+		WindowScale: uint8(tcp.FindWndScale(tcp.DefaultReceiveBufferSize)),
+	}
+}
+
+// Cleanup closes the context endpoint if required.
+func (c *Context) Cleanup() {
+	if c.EP != nil {
+		c.EP.Close()
+	}
+	c.Stack().Close()
+}
+
+// Stack returns a reference to the stack in the Context.
+func (c *Context) Stack() *stack.Stack {
+	return c.s
+}
+
+// CheckNoPacketTimeout verifies that no packet is received during the time
+// specified by wait.
+func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
+	c.t.Helper()
+
+	ctx, cancel := context.WithTimeout(context.Background(), wait)
+	defer cancel()
+	if _, ok := c.linkEP.ReadContext(ctx); ok {
+		c.t.Fatal(errMsg)
+	}
+}
+
+// CheckNoPacket verifies that no packet is received for 1 second.
+func (c *Context) CheckNoPacket(errMsg string) {
+	c.CheckNoPacketTimeout(errMsg, 1*time.Second)
+}
+
+// GetPacket reads a packet from the link layer endpoint and verifies
+// that it is an IPv4 packet with the expected source and destination
+// addresses. It will fail with an error if no packet is received for
+// 5 seconds.
+func (c *Context) GetPacket() []byte {
+	c.t.Helper()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	p, ok := c.linkEP.ReadContext(ctx)
+	if !ok {
+		c.t.Fatalf("Packet wasn't written out")
+		return nil
+	}
+
+	if p.Proto != ipv4.ProtocolNumber {
+		c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+	}
+
+	hdr := p.Pkt.Header.View()
+	b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...)
+
+	if p.GSO != nil && p.GSO.L3HdrLen != header.IPv4MinimumSize {
+		c.t.Errorf("L3HdrLen %v (expected %v)", p.GSO.L3HdrLen, header.IPv4MinimumSize)
+	}
+
+	checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+	return b
+}
+
+// GetPacketNonBlocking reads a packet from the link layer endpoint
+// and verifies that it is an IPv4 packet with the expected source
+// and destination address. If no packet is available it will return
+// nil immediately.
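+//
+// A typical (hypothetical) polling use:
+//
+//	if b := c.GetPacketNonBlocking(); b != nil {
+//		// Inspect b with the checker helpers.
+//	}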
+func (c *Context) GetPacketNonBlocking() []byte { + c.t.Helper() + + p, ok := c.linkEP.Read() + if !ok { + return nil + } + + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b +} + +// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. +func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { + // Allocate a buffer data and headers. + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4PayloadOffset + len(p2)) + if len(buf) > maxTotalSize { + buf = buf[:maxTotalSize] + } + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: TestAddr, + DstAddr: StackAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) + icmp.SetType(typ) + icmp.SetCode(code) + const icmpv4VariableHeaderOffset = 4 + copy(icmp[icmpv4VariableHeaderOffset:], p1) + copy(icmp[header.ICMPv4PayloadOffset:], p2) + + // Inject packet. + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) +} + +// BuildSegment builds a TCP segment based on the given Headers and payload. +func (c *Context) BuildSegment(payload []byte, h *Headers) buffer.VectorisedView { + return c.BuildSegmentWithAddrs(payload, h, TestAddr, StackAddr) +} + +// BuildSegmentWithAddrs builds a TCP segment based on the given Headers, +// payload and source and destination IPv4 addresses. +func (c *Context) BuildSegmentWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) buffer.VectorisedView { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(tcp.ProtocolNumber), + SrcAddr: src, + DstAddr: dst, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the TCP header. + t := header.TCP(buf[header.IPv4MinimumSize:]) + t.Encode(&header.TCPFields{ + SrcPort: h.SrcPort, + DstPort: h.DstPort, + SeqNum: uint32(h.SeqNum), + AckNum: uint32(h.AckNum), + DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), + Flags: uint8(h.Flags), + WindowSize: uint16(h.RcvWnd), + }) + + // Calculate the TCP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t))) + + // Calculate the TCP checksum and set it. + xsum = header.Checksum(payload, xsum) + t.SetChecksum(^t.CalculateChecksum(xsum)) + + // Inject packet. + return buf.ToVectorisedView() +} + +// SendSegment sends a TCP segment that has already been built and written to a +// buffer.VectorisedView. +func (c *Context) SendSegment(s buffer.VectorisedView) { + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: s, + }) +} + +// SendPacket builds and sends a TCP segment(with the provided payload & TCP +// headers) in an IPv4 packet via the link layer endpoint. 
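+//
+// For example (field values are illustrative only):
+//
+//	c.SendPacket(nil, &Headers{
+//		SrcPort: TestPort,
+//		DstPort: c.Port,
+//		Flags:   header.TCPFlagAck,
+//		SeqNum:  790,
+//		AckNum:  c.IRS.Add(1),
+//		RcvWnd:  30000,
+//	})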
+func (c *Context) SendPacket(payload []byte, h *Headers) {
+	c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+		Data: c.BuildSegment(payload, h),
+	})
+}
+
+// SendPacketWithAddrs builds and sends a TCP segment (with the provided
+// payload and TCP headers) in an IPv4 packet via the link layer endpoint
+// using the provided source and destination IPv4 addresses.
+func (c *Context) SendPacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
+	c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+		Data: c.BuildSegmentWithAddrs(payload, h, src, dst),
+	})
+}
+
+// SendAck sends an ACK packet.
+func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
+	c.SendAckWithSACK(seq, bytesReceived, nil)
+}
+
+// SendAckWithSACK sends an ACK packet which includes the sackBlocks specified.
+func (c *Context) SendAckWithSACK(seq seqnum.Value, bytesReceived int, sackBlocks []header.SACKBlock) {
+	options := make([]byte, 40)
+	offset := 0
+	if len(sackBlocks) > 0 {
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
+	}
+
+	c.SendPacket(nil, &Headers{
+		SrcPort: TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  seq,
+		AckNum:  c.IRS.Add(1 + seqnum.Size(bytesReceived)),
+		RcvWnd:  30000,
+		TCPOpts: options[:offset],
+	})
+}
+
+// ReceiveAndCheckPacket reads a packet from the link layer endpoint and
+// verifies that the packet's payload matches the slice of data indicated
+// by offset & size.
+func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
+	c.t.Helper()
+
+	c.ReceiveAndCheckPacketWithOptions(data, offset, size, 0)
+}
+
+// ReceiveAndCheckPacketWithOptions reads a packet from the link layer endpoint
+// and verifies that the packet's payload matches the slice of data indicated
+// by offset & size. It skips optlen bytes in addition to the IP and TCP
+// headers when comparing the data.
+func (c *Context) ReceiveAndCheckPacketWithOptions(data []byte, offset, size, optlen int) {
+	c.t.Helper()
+
+	b := c.GetPacket()
+	checker.IPv4(c.t, b,
+		checker.PayloadLen(size+header.TCPMinimumSize+optlen),
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+			checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+
+	pdata := data[offset:][:size]
+	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize+optlen:]; bytes.Compare(pdata, p) != 0 {
+		c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
+	}
+}
+
+// ReceiveNonBlockingAndCheckPacket reads a packet from the link layer endpoint
+// and verifies that the packet's payload matches the slice of data indicated
+// by offset & size. It returns true if a packet was received and processed.
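+//
+// A caller that has just written data through c.EP might poll for the
+// resulting segment with (sketch):
+//
+//	if !c.ReceiveNonBlockingAndCheckPacket(data, 0, len(data)) {
+//		t.Fatal("expected a packet to have been written out")
+//	}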
+func (c *Context) ReceiveNonBlockingAndCheckPacket(data []byte, offset, size int) bool {
+	c.t.Helper()
+
+	b := c.GetPacketNonBlocking()
+	if b == nil {
+		return false
+	}
+	checker.IPv4(c.t, b,
+		checker.PayloadLen(size+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+			checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+
+	pdata := data[offset:][:size]
+	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 {
+		c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
+	}
+	return true
+}
+
+// CreateV6Endpoint creates and initializes c.ep as an IPv6 endpoint. If v6only
+// is true then it sets the V6Only option on the endpoint to make it an
+// IPv6-only endpoint instead of a default dual-stack socket.
+func (c *Context) CreateV6Endpoint(v6only bool) {
+	var err *tcpip.Error
+	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	if err := c.EP.SetSockOptBool(tcpip.V6OnlyOption, v6only); err != nil {
+		c.t.Fatalf("SetSockOpt failed: %v", err)
+	}
+}
+
+// GetV6Packet reads a single packet from the link layer endpoint of the context
+// and asserts that it is an IPv6 packet with the expected src/dest addresses.
+func (c *Context) GetV6Packet() []byte {
+	c.t.Helper()
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+	p, ok := c.linkEP.ReadContext(ctx)
+	if !ok {
+		c.t.Fatalf("Packet wasn't written out")
+		return nil
+	}
+
+	if p.Proto != ipv6.ProtocolNumber {
+		c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
+	}
+	b := make([]byte, p.Pkt.Header.UsedLength()+p.Pkt.Data.Size())
+	copy(b, p.Pkt.Header.View())
+	copy(b[p.Pkt.Header.UsedLength():], p.Pkt.Data.ToView())
+
+	checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
+	return b
+}
+
+// SendV6Packet builds and sends an IPv6 packet via the link layer endpoint of
+// the context.
+func (c *Context) SendV6Packet(payload []byte, h *Headers) {
+	c.SendV6PacketWithAddrs(payload, h, TestV6Addr, StackV6Addr)
+}
+
+// SendV6PacketWithAddrs builds and sends an IPv6 packet via the link layer
+// endpoint of the context using the provided source and destination IPv6
+// addresses.
+func (c *Context) SendV6PacketWithAddrs(payload []byte, h *Headers, src, dst tcpip.Address) {
+	// Allocate a buffer for data and headers.
+	buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
+	copy(buf[len(buf)-len(payload):], payload)
+
+	// Initialize the IP header.
+	ip := header.IPv6(buf)
+	ip.Encode(&header.IPv6Fields{
+		PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
+		NextHeader:    uint8(tcp.ProtocolNumber),
+		HopLimit:      65,
+		SrcAddr:       src,
+		DstAddr:       dst,
+	})
+
+	// Initialize the TCP header.
+	t := header.TCP(buf[header.IPv6MinimumSize:])
+	t.Encode(&header.TCPFields{
+		SrcPort:    h.SrcPort,
+		DstPort:    h.DstPort,
+		SeqNum:     uint32(h.SeqNum),
+		AckNum:     uint32(h.AckNum),
+		DataOffset: header.TCPMinimumSize,
+		Flags:      uint8(h.Flags),
+		WindowSize: uint16(h.RcvWnd),
+	})
+
+	// Calculate the TCP pseudo-header checksum.
+	xsum := header.PseudoHeaderChecksum(tcp.ProtocolNumber, src, dst, uint16(len(t)))
+
+	// Calculate the TCP checksum and set it.
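+	// (The checksum covers the pseudo-header computed above, that is the
+	// source and destination addresses, the upper-layer protocol and the
+	// TCP length, plus the TCP header and payload; see RFC 793 and, for
+	// the IPv6 pseudo-header, RFC 2460.)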
+	xsum = header.Checksum(payload, xsum)
+	t.SetChecksum(^t.CalculateChecksum(xsum))
+
+	// Inject packet.
+	c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
+}
+
+// CreateConnected creates a connected TCP endpoint.
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int) {
+	c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
+}
+
+// Connect performs the 3-way handshake for c.EP with the provided Initial
+// Sequence Number (iss) and receive window (rcvWnd) and any options if
+// specified.
+//
+// PreCondition: c.EP must already be created.
+func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) {
+	c.t.Helper()
+
+	// Start connection attempt.
+	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+	defer c.WQ.EventUnregister(&waitEntry)
+
+	if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted {
+		c.t.Fatalf("Unexpected return value from Connect: %v", err)
+	}
+
+	// Receive SYN packet.
+	b := c.GetPacket()
+	checker.IPv4(c.t, b,
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.TCPFlags(header.TCPFlagSyn),
+		),
+	)
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
+	tcpHdr := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+	c.SendPacket(nil, &Headers{
+		SrcPort: tcpHdr.DestinationPort(),
+		DstPort: tcpHdr.SourcePort(),
+		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+		SeqNum:  iss,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  rcvWnd,
+		TCPOpts: options,
+	})
+
+	// Receive ACK packet.
+	checker.IPv4(c.t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(iss)+1),
+		),
+	)
+
+	// Wait for connection to be established.
+	select {
+	case <-notifyCh:
+		if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+			c.t.Fatalf("Unexpected error when connecting: %v", err)
+		}
+	case <-time.After(1 * time.Second):
+		c.t.Fatalf("Timed out waiting for connection")
+	}
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want {
+		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
+	c.Port = tcpHdr.SourcePort()
+}
+
+// Create creates a TCP endpoint.
+func (c *Context) Create(epRcvBuf int) {
+	// Create TCP endpoint.
+	var err *tcpip.Error
+	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	if epRcvBuf != -1 {
+		if err := c.EP.SetSockOptInt(tcpip.ReceiveBufferSizeOption, epRcvBuf); err != nil {
+			c.t.Fatalf("SetSockOpt failed: %v", err)
+		}
+	}
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
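+//
+// For example, a test could inject a raw MSS option into the SYN like this
+// (a sketch; the MSS value is illustrative only):
+//
+//	opts := make([]byte, 4)
+//	header.EncodeMSSOption(1400, opts)
+//	c.CreateConnectedWithRawOptions(789 /* iss */, 30000 /* rcvWnd */, -1 /* epRcvBuf */, opts)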
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf int, options []byte) {
+	c.Create(epRcvBuf)
+	c.Connect(iss, rcvWnd, options)
+}
+
+// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
+// sending data and ACK packets easy while being able to manipulate the
+// sequence numbers and timestamp values as needed.
+type RawEndpoint struct {
+	C          *Context
+	SrcPort    uint16
+	DstPort    uint16
+	Flags      int
+	NextSeqNum seqnum.Value
+	AckNum     seqnum.Value
+	WndSize    seqnum.Size
+	RecentTS   uint32 // Stores the latest timestamp to echo back.
+	TSVal      uint32 // TSVal stores the last timestamp sent by this endpoint.
+
+	// SACKPermitted is true if the SACKPermitted option was negotiated for
+	// this endpoint.
+	SACKPermitted bool
+}
+
+// SendPacketWithTS embeds the provided tsVal in the Timestamp option
+// for the packet to be sent out.
+func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) {
+	r.TSVal = tsVal
+	tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+	header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:])
+	r.SendPacket(payload, tsOpt[:])
+}
+
+// SendPacket is a small wrapper function to build and send packets.
+func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
+	packetHeaders := &Headers{
+		SrcPort: r.SrcPort,
+		DstPort: r.DstPort,
+		Flags:   r.Flags,
+		SeqNum:  r.NextSeqNum,
+		AckNum:  r.AckNum,
+		RcvWnd:  r.WndSize,
+		TCPOpts: opts,
+	}
+	r.C.SendPacket(payload, packetHeaders)
+	r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
+}
+
+// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
+// tsVal.
+func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+	// Read the ACK and verify that the tsEcr of the ACK packet matches the
+	// provided tsVal.
+	ackPacket := r.C.GetPacket()
+	checker.IPv4(r.C.t, ackPacket,
+		checker.TCP(
+			checker.DstPort(r.SrcPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(r.AckNum)),
+			checker.AckNum(uint32(r.NextSeqNum)),
+			checker.TCPTimestampChecker(true, 0, tsVal),
+		),
+	)
+	// Store the parsed TSVal from the ack as recentTS.
+	tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
+	opts := tcpSeg.ParsedOptions()
+	r.RecentTS = opts.TSVal
+}
+
+// VerifyACKRcvWnd verifies that the window advertised by the incoming ACK
+// matches the provided rcvWnd.
+func (r *RawEndpoint) VerifyACKRcvWnd(rcvWnd uint16) {
+	ackPacket := r.C.GetPacket()
+	checker.IPv4(r.C.t, ackPacket,
+		checker.TCP(
+			checker.DstPort(r.SrcPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(r.AckNum)),
+			checker.AckNum(uint32(r.NextSeqNum)),
+			checker.Window(rcvWnd),
+		),
+	)
+}
+
+// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
+func (r *RawEndpoint) VerifyACKNoSACK() {
+	r.VerifyACKHasSACK(nil)
+}
+
+// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks.
+func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
+	// Read the ACK and verify that the TCP options in the segment contain
+	// the expected SACK blocks.
+	ackPacket := r.C.GetPacket()
+	checker.IPv4(r.C.t, ackPacket,
+		checker.TCP(
+			checker.DstPort(r.SrcPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(r.AckNum)),
+			checker.AckNum(uint32(r.NextSeqNum)),
+			checker.TCPSACKBlockChecker(sackBlocks),
+		),
+	)
+}
+
+// CreateConnectedWithOptions creates and connects c.EP with the specified TCP
+// options enabled and returns a RawEndpoint which represents the other end of
+// the connection.
+//
+// It also verifies, where required (e.g. Timestamp), that the ACK to the
+// SYN-ACK does not carry an option that was not requested.
+func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
+	var err *tcpip.Error
+	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
+	}
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want {
+		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
+	// Start connection attempt.
+	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+	defer c.WQ.EventUnregister(&waitEntry)
+
+	testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
+	err = c.EP.Connect(testFullAddr)
+	if err != tcpip.ErrConnectStarted {
+		c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
+	}
+	// Receive SYN packet.
+	b := c.GetPacket()
+	// Validate that the SYN has the timestamp option and a valid
+	// TS value.
+	mss := uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
+
+	checker.IPv4(c.t, b,
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.TCPFlags(header.TCPFlagSyn),
+			checker.TCPSynOptions(header.TCPSynOptions{
+				MSS:           mss,
+				TS:            true,
+				WS:            int(c.WindowScale),
+				SACKPermitted: c.SACKEnabled(),
+			}),
+		),
+	)
+	if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want {
+		c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got)
+	}
+
+	tcpSeg := header.TCP(header.IPv4(b).Payload())
+	synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
+
+	// Build options w/ tsVal to be sent in the SYN-ACK.
+	synAckOptions := make([]byte, header.TCPOptionsMaximumSize)
+	offset := 0
+	if wantOptions.WS != -1 {
+		offset += header.EncodeWSOption(wantOptions.WS, synAckOptions[offset:])
+	}
+	if wantOptions.TS {
+		offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
+	}
+	if wantOptions.SACKPermitted {
+		offset += header.EncodeSACKPermittedOption(synAckOptions[offset:])
+	}
+
+	offset += header.AddTCPOptionPadding(synAckOptions, offset)
+
+	// Build SYN-ACK.
+	c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
+	iss := seqnum.Value(testInitialSequenceNumber)
+	c.SendPacket(nil, &Headers{
+		SrcPort: tcpSeg.DestinationPort(),
+		DstPort: tcpSeg.SourcePort(),
+		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+		SeqNum:  iss,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+		TCPOpts: synAckOptions[:offset],
+	})
+
+	// Read ACK.
+	ackPacket := c.GetPacket()
+
+	// Verify TCP header fields.
+	tcpCheckers := []checker.TransportChecker{
+		checker.DstPort(TestPort),
+		checker.TCPFlags(header.TCPFlagAck),
+		checker.SeqNum(uint32(c.IRS) + 1),
+		checker.AckNum(uint32(iss) + 1),
+	}
+
+	// Verify that tsEcr of ACK packet is wantOptions.TSVal if the
+	// timestamp option was enabled, if not then we verify that
+	// there is no timestamp in the ACK packet.
+	if wantOptions.TS {
+		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal))
+	} else {
+		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
+	}
+
+	checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...))
+
+	ackSeg := header.TCP(header.IPv4(ackPacket).Payload())
+	ackOptions := ackSeg.ParsedOptions()
+
+	// Wait for connection to be established.
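+	// notifyCh was registered above for waiter.EventOut, so it becomes
+	// readable once the endpoint leaves the connecting state; the
+	// GetSockOpt(tcpip.ErrorOption{}) call below then fetches the result
+	// of the connect attempt.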
+ select { + case <-notifyCh: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + c.t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for connection") + } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + + // Store the source port in use by the endpoint. + c.Port = tcpSeg.SourcePort() + + // Mark in context that timestamp option is enabled for this endpoint. + c.TimeStampEnabled = true + + return &RawEndpoint{ + C: c, + SrcPort: tcpSeg.DestinationPort(), + DstPort: tcpSeg.SourcePort(), + Flags: header.TCPFlagAck | header.TCPFlagPsh, + NextSeqNum: iss + 1, + AckNum: c.IRS.Add(1), + WndSize: 30000, + RecentTS: ackOptions.TSVal, + TSVal: wantOptions.TSVal, + SACKPermitted: wantOptions.SACKPermitted, + } +} + +// AcceptWithOptions initializes a listening endpoint and connects to it with the +// provided options enabled. It also verifies that the SYN-ACK has the expected +// values for the provided options. +// +// The function returns a RawEndpoint representing the other end of the accepted +// endpoint. +func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + if err := ep.Listen(10); err != nil { + c.t.Fatalf("Listen failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + c.t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for accept") + } + } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + return rep +} + +// PassiveConnect just disables WindowScaling and delegates the call to +// PassiveConnectWithOptions. +func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { + synOptions.WS = -1 + c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) +} + +// PassiveConnectWithOptions initiates a new connection (with the specified TCP +// options enabled) to the port on which the Context.ep is listening for new +// connections. It also validates that the SYN-ACK has the expected values for +// the enabled options. +// +// NOTE: MSS is not a negotiated option and it can be asymmetric +// in each direction. 
This function uses the maxPayload to set the MSS to be +// sent to the peer on a connect and validates that the MSS in the SYN-ACK +// response is equal to the MTU - (tcphdr len + iphdr len). +// +// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the +// value of the window scaling option to be sent in the SYN. If synOptions.WS > +// 0 then we send the WindowScale option. +func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + opts := make([]byte, header.TCPOptionsMaximumSize) + offset := 0 + offset += header.EncodeMSSOption(uint32(maxPayload), opts) + + if synOptions.WS >= 0 { + offset += header.EncodeWSOption(3, opts[offset:]) + } + if synOptions.TS { + offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:]) + } + + if synOptions.SACKPermitted { + offset += header.EncodeSACKPermittedOption(opts[offset:]) + } + + paddingToAdd := 4 - offset%4 + // Now add any padding bytes that might be required to quad align the + // options. + for i := offset; i < offset+paddingToAdd; i++ { + opts[i] = header.TCPOptionNOP + } + offset += paddingToAdd + + // Send a SYN request. + iss := seqnum.Value(testInitialSequenceNumber) + c.SendPacket(nil, &Headers{ + SrcPort: TestPort, + DstPort: StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + TCPOpts: opts[:offset], + }) + + // Receive the SYN-ACK reply. Make sure MSS and other expected options + // are present. + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(StackPort), + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(iss) + 1), + checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), + } + + // If TS option was enabled in the original SYN then add a checker to + // validate the Timestamp option in the SYN-ACK. + if synOptions.TS { + tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal)) + } else { + tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) + } + + checker.IPv4(c.t, b, checker.TCP(tcpCheckers...)) + rcvWnd := seqnum.Size(30000) + ackHeaders := &Headers{ + SrcPort: TestPort, + DstPort: StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + RcvWnd: rcvWnd, + } + + // If WS was expected to be in effect then scale the advertised window + // correspondingly. + if synOptions.WS > 0 { + ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS) + } + + parsedOpts := tcp.ParsedOptions() + if synOptions.TS { + // Echo the tsVal back to the peer in the tsEcr field of the + // timestamp option. + // Increment TSVal by 1 from the value sent in the SYN and echo + // the TSVal in the SYN-ACK in the TSEcr field. + opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:]) + ackHeaders.TCPOpts = opts[:] + } + + // Send ACK. 
+	c.SendPacket(nil, ackHeaders)
+
+	c.Port = StackPort
+
+	return &RawEndpoint{
+		C:             c,
+		SrcPort:       TestPort,
+		DstPort:       StackPort,
+		Flags:         header.TCPFlagPsh | header.TCPFlagAck,
+		NextSeqNum:    iss + 1,
+		AckNum:        c.IRS + 1,
+		WndSize:       rcvWnd,
+		SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(),
+		RecentTS:      parsedOpts.TSVal,
+		TSVal:         synOptions.TSVal + 1,
+	}
+}
+
+// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to
+// true for the Stack in the context.
+func (c *Context) SACKEnabled() bool {
+	var v tcp.SACKEnabled
+	if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
+		// Stack doesn't support SACK. So just return.
+		return false
+	}
+	return bool(v)
+}
+
+// SetGSOEnabled enables or disables generic segmentation offload.
+func (c *Context) SetGSOEnabled(enable bool) {
+	if enable {
+		c.linkEP.LinkEPCapabilities |= stack.CapabilityHardwareGSO
+	} else {
+		c.linkEP.LinkEPCapabilities &^= stack.CapabilityHardwareGSO
+	}
+}
+
+// MSSWithoutOptions returns the value for the MSS used by the stack when no
+// options are in use.
+func (c *Context) MSSWithoutOptions() uint16 {
+	return uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize)
+}
+
+// MSSWithoutOptionsV6 returns the value for the MSS used by the stack when no
+// options are in use for IPv6 packets.
+func (c *Context) MSSWithoutOptionsV6() uint16 {
+	return uint16(c.linkEP.MTU() - header.IPv6MinimumSize - header.TCPMinimumSize)
+}
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
new file mode 100644
index 000000000..7981d469b
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -0,0 +1,142 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+	"time"
+
+	"gvisor.dev/gvisor/pkg/sleep"
+)
+
+type timerState int
+
+const (
+	timerStateDisabled timerState = iota
+	timerStateEnabled
+	timerStateOrphaned
+)
+
+// timer is a timer implementation that reduces the interactions with the
+// runtime timer infrastructure by letting timers run (and potentially
+// eventually expire) even if they are stopped. It makes it cheaper to
+// disable/reenable timers at the expense of spurious wakes. This is useful for
+// cases when the same timer is disabled/reenabled repeatedly with relatively
+// long timeouts farther into the future.
+//
+// TCP retransmit timers benefit from this because the timeouts are long
+// (currently at least 200ms): they get disabled when acks are received, and
+// reenabled when new pending segments are sent.
+//
+// It is advantageous to avoid interacting with the runtime because it acquires
+// a global mutex and performs O(log n) operations, where n is the global number
+// of timers, whenever a timer is enabled or disabled, and may make a syscall.
+//
+// This struct is thread-compatible.
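+//
+// The intended usage pattern, sketched against this file's API (the details
+// of waking up on the waker are elided):
+//
+//	tmr.enable(timeout)
+//	// ... later, on a wake attributed to the timer's waker:
+//	if !tmr.checkExpiration() {
+//		// Spurious wake from an orphaned runtime timer; keep waiting.
+//	} else {
+//		// The deadline really passed; e.g. retransmit and back off.
+//	}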
+type timer struct {
+	// state is the current state of the timer, it can be one of the
+	// following values:
+	//     disabled - the timer is disabled.
+	//     orphaned - the timer is disabled, but the runtime timer is
+	//                enabled, which means that it will eventually cause a
+	//                spurious wake (unless it gets enabled again before
+	//                then).
+	//     enabled  - the timer is enabled, but the runtime timer may be set
+	//                to an earlier expiration time due to a previous
+	//                orphaned state.
+	state timerState
+
+	// target is the expiration time of the current timer. It is only
+	// meaningful in the enabled state.
+	target time.Time
+
+	// runtimeTarget is the expiration time of the runtime timer. It is
+	// meaningful in the enabled and orphaned states.
+	runtimeTarget time.Time
+
+	// timer is the runtime timer used to wait on.
+	timer *time.Timer
+}
+
+// init initializes the timer. Once it expires, the given waker will be
+// asserted.
+func (t *timer) init(w *sleep.Waker) {
+	t.state = timerStateDisabled
+
+	// Initialize a runtime timer that will assert the waker, then
+	// immediately stop it.
+	t.timer = time.AfterFunc(time.Hour, func() {
+		w.Assert()
+	})
+	t.timer.Stop()
+}
+
+// cleanup frees all resources associated with the timer.
+func (t *timer) cleanup() {
+	t.timer.Stop()
+	*t = timer{}
+}
+
+// checkExpiration checks if the given timer has actually expired. It should be
+// called whenever a sleeper wakes up due to the waker being asserted, and is
+// used to check if it's a spurious wake (due to a previously orphaned timer)
+// or a legitimate one.
+func (t *timer) checkExpiration() bool {
+	// Transition to fully disabled state if we're just consuming an
+	// orphaned timer.
+	if t.state == timerStateOrphaned {
+		t.state = timerStateDisabled
+		return false
+	}
+
+	// The timer is enabled, but it may have expired early. Check if that's
+	// the case, and if so, reset the runtime timer to the correct time.
+	now := time.Now()
+	if now.Before(t.target) {
+		t.runtimeTarget = t.target
+		t.timer.Reset(t.target.Sub(now))
+		return false
+	}
+
+	// The timer has actually expired, disable it for now and inform the
+	// caller.
+	t.state = timerStateDisabled
+	return true
+}
+
+// disable disables the timer, leaving it in an orphaned state if it wasn't
+// already disabled.
+func (t *timer) disable() {
+	if t.state != timerStateDisabled {
+		t.state = timerStateOrphaned
+	}
+}
+
+// enabled returns true if the timer is currently enabled, false otherwise.
+func (t *timer) enabled() bool {
+	return t.state == timerStateEnabled
+}
+
+// enable enables the timer, programming the runtime timer if necessary.
+func (t *timer) enable(d time.Duration) {
+	t.target = time.Now().Add(d)
+
+	// Check if we need to set the runtime timer.
+	if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) {
+		t.runtimeTarget = t.target
+		t.timer.Reset(d)
+	}
+
+	t.state = timerStateEnabled
+}
diff --git a/pkg/tcpip/transport/tcp/timer_test.go b/pkg/tcpip/transport/tcp/timer_test.go
new file mode 100644
index 000000000..dbd6dff54
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer_test.go
@@ -0,0 +1,47 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "testing" + "time" + + "gvisor.dev/gvisor/pkg/sleep" +) + +func TestCleanup(t *testing.T) { + const ( + timerDurationSeconds = 2 + isAssertedTimeoutSeconds = timerDurationSeconds + 1 + ) + + tmr := timer{} + w := sleep.Waker{} + tmr.init(&w) + tmr.enable(timerDurationSeconds * time.Second) + tmr.cleanup() + + if want := (timer{}); tmr != want { + t.Errorf("got tmr = %+v, want = %+v", tmr, want) + } + + // The waker should not be asserted. + for i := 0; i < isAssertedTimeoutSeconds; i++ { + time.Sleep(time.Second) + if w.IsAsserted() { + t.Fatalf("waker asserted unexpectedly") + } + } +} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD new file mode 100644 index 000000000..3ad6994a7 --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -0,0 +1,23 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "tcpconntrack", + srcs = ["tcp_conntrack.go"], + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + ], +) + +go_test( + name = "tcpconntrack_test", + size = "small", + srcs = ["tcp_conntrack_test.go"], + deps = [ + ":tcpconntrack", + "//pkg/tcpip/header", + ], +) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go new file mode 100644 index 000000000..12bc1b5b5 --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -0,0 +1,352 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tcpconntrack implements a TCP connection tracking object. It allows +// users with access to a segment stream to figure out when a connection is +// established, reset, and closed (and in the last case, who closed first). +package tcpconntrack + +import ( + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/seqnum" +) + +// Result is returned when the state of a TCB is updated in response to an +// inbound or outbound segment. +type Result int + +const ( + // ResultDrop indicates that the segment should be dropped. + ResultDrop Result = iota + + // ResultConnecting indicates that the connection remains in a + // connecting state. + ResultConnecting + + // ResultAlive indicates that the connection remains alive (connected). + ResultAlive + + // ResultReset indicates that the connection was reset. + ResultReset + + // ResultClosedByPeer indicates that the connection was gracefully + // closed, and the inbound stream was closed first. 
+	ResultClosedByPeer
+
+	// ResultClosedBySelf indicates that the connection was gracefully
+	// closed, and the outbound stream was closed first.
+	ResultClosedBySelf
+)
+
+// TCB is a TCP Control Block. It holds state necessary to keep track of a TCP
+// connection and inform the caller when the connection has been closed.
+type TCB struct {
+	inbound  stream
+	outbound stream
+
+	// State handlers.
+	handlerInbound  func(*TCB, header.TCP) Result
+	handlerOutbound func(*TCB, header.TCP) Result
+
+	// firstFin holds a pointer to the first stream to send a FIN.
+	firstFin *stream
+
+	// state is the current state of the connection.
+	state Result
+}
+
+// Init initializes the state of the TCB according to the initial SYN.
+func (t *TCB) Init(initialSyn header.TCP) Result {
+	t.handlerInbound = synSentStateInbound
+	t.handlerOutbound = synSentStateOutbound
+
+	iss := seqnum.Value(initialSyn.SequenceNumber())
+	t.outbound.una = iss
+	t.outbound.nxt = iss.Add(logicalLen(initialSyn))
+	t.outbound.end = t.outbound.nxt
+
+	// Even though "end" is a sequence number, we don't know the initial
+	// receive sequence number yet, so we store the window size until we get
+	// a SYN from the peer.
+	t.inbound.una = 0
+	t.inbound.nxt = 0
+	t.inbound.end = seqnum.Value(initialSyn.WindowSize())
+	t.state = ResultConnecting
+	return t.state
+}
+
+// UpdateStateInbound updates the state of the TCB based on the supplied inbound
+// segment.
+func (t *TCB) UpdateStateInbound(tcp header.TCP) Result {
+	st := t.handlerInbound(t, tcp)
+	if st != ResultDrop {
+		t.state = st
+	}
+	return st
+}
+
+// UpdateStateOutbound updates the state of the TCB based on the supplied
+// outbound segment.
+func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
+	st := t.handlerOutbound(t, tcp)
+	if st != ResultDrop {
+		t.state = st
+	}
+	return st
+}
+
+// IsAlive returns true as long as the connection is in the established (Alive)
+// or connecting state.
+func (t *TCB) IsAlive() bool {
+	return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed())
+}
+
+// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream.
+func (t *TCB) OutboundSendSequenceNumber() seqnum.Value {
+	return t.outbound.nxt
+}
+
+// InboundSendSequenceNumber returns the snd.NXT for the inbound stream.
+func (t *TCB) InboundSendSequenceNumber() seqnum.Value {
+	return t.inbound.nxt
+}
+
+// adaptResult modifies the supplied "Result" according to the state of the TCB;
+// if r is anything other than "Alive", or if one of the streams isn't closed
+// yet, it is returned unmodified. Otherwise it's converted to either
+// ClosedBySelf or ClosedByPeer depending on which stream was closed first.
+func (t *TCB) adaptResult(r Result) Result {
+	// Check the unmodified case.
+	if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() {
+		return r
+	}
+
+	// Find out which was closed first.
+	if t.firstFin == &t.outbound {
+		return ResultClosedBySelf
+	}
+
+	return ResultClosedByPeer
+}
+
+// synSentStateInbound is the state handler for inbound segments when the
+// connection is in SYN-SENT state.
+func synSentStateInbound(t *TCB, tcp header.TCP) Result {
+	flags := tcp.Flags()
+	ackPresent := flags&header.TCPFlagAck != 0
+	ack := seqnum.Value(tcp.AckNumber())
+
+	// Ignore segment if ack is present but not acceptable.
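+	// An acceptable ACK lies in the half-open interval (una, nxt]; the
+	// check below expresses the same condition as ack-1 falling in
+	// [una, nxt) using the InRange helper.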
+ if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) { + return ResultConnecting + } + + // If reset is specified, we will let the packet through no matter what + // but we will also destroy the connection if the ACK is present (and + // implicitly acceptable). + if flags&header.TCPFlagRst != 0 { + if ackPresent { + t.inbound.rstSeen = true + return ResultReset + } + return ResultConnecting + } + + // Ignore segment if SYN is not set. + if flags&header.TCPFlagSyn == 0 { + return ResultConnecting + } + + // Update state informed by this SYN. + irs := seqnum.Value(tcp.SequenceNumber()) + t.inbound.una = irs + t.inbound.nxt = irs.Add(logicalLen(tcp)) + t.inbound.end += irs + + t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize())) + + // If the ACK was set (it is acceptable), update our unacknowledgement + // tracking. + if ackPresent { + // Advance the "una" and "end" indices of the outbound stream. + if t.outbound.una.LessThan(ack) { + t.outbound.una = ack + } + + if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) { + t.outbound.end = end + } + } + + // Update handlers so that new calls will be handled by new state. + t.handlerInbound = allOtherInbound + t.handlerOutbound = allOtherOutbound + + return ResultAlive +} + +// synSentStateOutbound is the state handler for outbound segments when the +// connection is in SYN-SENT state. +func synSentStateOutbound(t *TCB, tcp header.TCP) Result { + // Drop outbound segments that aren't retransmits of the original one. + if tcp.Flags() != header.TCPFlagSyn || + tcp.SequenceNumber() != uint32(t.outbound.una) { + return ResultDrop + } + + // Update the receive window. We only remember the largest value seen. + if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end { + t.inbound.end = wnd + } + + return ResultConnecting +} + +// update updates the state of inbound and outbound streams, given the supplied +// inbound segment. For outbound segments, this same function can be called with +// swapped inbound/outbound streams. +func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result { + // Ignore segments out of the window. + s := seqnum.Value(tcp.SequenceNumber()) + if !inbound.acceptable(s, dataLen(tcp)) { + return ResultAlive + } + + flags := tcp.Flags() + if flags&header.TCPFlagRst != 0 { + inbound.rstSeen = true + return ResultReset + } + + // Ignore segments that don't have the ACK flag, and those with the SYN + // flag. + if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 { + return ResultAlive + } + + // Ignore segments that acknowledge not yet sent data. + ack := seqnum.Value(tcp.AckNumber()) + if outbound.nxt.LessThan(ack) { + return ResultAlive + } + + // Advance the "una" and "end" indices of the outbound stream. + if outbound.una.LessThan(ack) { + outbound.una = ack + } + + if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) { + outbound.end = end + } + + // Advance the "nxt" index of the inbound stream. + end := s.Add(logicalLen(tcp)) + if inbound.nxt.LessThan(end) { + inbound.nxt = end + } + + // Note the index of the FIN segment. And stash away a pointer to the + // first stream to see a FIN. + if flags&header.TCPFlagFin != 0 && !inbound.finSeen { + inbound.finSeen = true + inbound.fin = end - 1 + + if *firstFin == nil { + *firstFin = inbound + } + } + + return ResultAlive +} + +// allOtherInbound is the state handler for inbound segments in all states +// except SYN-SENT. 
+func allOtherInbound(t *TCB, tcp header.TCP) Result {
+	return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin))
+}
+
+// allOtherOutbound is the state handler for outbound segments in all states
+// except SYN-SENT.
+func allOtherOutbound(t *TCB, tcp header.TCP) Result {
+	return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin))
+}
+
+// stream holds the state of a TCP unidirectional stream.
+type stream struct {
+	// The interval [una, end) is the allowed interval as defined by the
+	// receiver, i.e., anything less than una has already been acknowledged
+	// and anything greater than or equal to end is beyond the receiver
+	// window. The interval [una, nxt) is the acknowledgable range, whose
+	// right edge indicates the sequence number of the next byte to be sent
+	// by the sender, i.e., anything greater than or equal to nxt hasn't
+	// been sent yet.
+	una seqnum.Value
+	nxt seqnum.Value
+	end seqnum.Value
+
+	// finSeen indicates if a FIN has already been sent on this stream.
+	finSeen bool
+
+	// fin is the sequence number of the FIN. It is only valid after finSeen
+	// is set to true.
+	fin seqnum.Value
+
+	// rstSeen indicates if a RST has already been sent on this stream.
+	rstSeen bool
+}
+
+// acceptable determines if the segment with the given sequence number and data
+// length is acceptable, i.e., if it's within the [una, end) window or, in case
+// the window is zero, if it's a packet with no payload and sequence number
+// equal to una.
+func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
+	return header.Acceptable(segSeq, segLen, s.una, s.end)
+}
+
+// closed determines if the stream has already been closed. This happens when
+// a FIN has been set by the sender and acknowledged by the receiver.
+func (s *stream) closed() bool {
+	return s.finSeen && s.fin.LessThan(s.una)
+}
+
+// dataLen returns the length of the TCP segment payload.
+func dataLen(tcp header.TCP) seqnum.Size {
+	return seqnum.Size(len(tcp) - int(tcp.DataOffset()))
+}
+
+// logicalLen calculates the logical length of the TCP segment.
+func logicalLen(tcp header.TCP) seqnum.Size {
+	l := dataLen(tcp)
+	flags := tcp.Flags()
+	if flags&header.TCPFlagSyn != 0 {
+		l++
+	}
+	if flags&header.TCPFlagFin != 0 {
+		l++
+	}
+	return l
+}
+
+// IsEmpty returns true if tcb is not initialized.
+func (t *TCB) IsEmpty() bool {
+	if t.inbound != (stream{}) || t.outbound != (stream{}) {
+		return false
+	}
+
+	if t.firstFin != nil || t.state != ResultDrop {
+		return false
+	}
+
+	return true
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
new file mode 100644
index 000000000..5e271b7ca
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
@@ -0,0 +1,511 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
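+
+// The tests below all follow the same TCB lifecycle, sketched here with
+// hypothetical segment variables (construction elided):
+//
+//	tcb := tcpconntrack.TCB{}
+//	tcb.Init(outboundSYN)          // ResultConnecting
+//	tcb.UpdateStateInbound(synAck) // ResultAlive
+//	tcb.UpdateStateOutbound(ack)   // ResultAlive
+//	// ... a FIN exchange then yields ResultClosedBySelf or ResultClosedByPeer.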
+ +package tcpconntrack_test + +import ( + "testing" + + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/transport/tcpconntrack" +) + +// connected creates a connection tracker TCB and sets it to a connected state +// by performing a 3-way handshake. +func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: iss, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: irw, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN-ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: irs, + AckNum: iss + 1, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: isw, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: iss + 1, + AckNum: irs + 1, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: irw, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + return &tcb +} + +func TestConnectionRefused(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive RST. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst | header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestConnectionRefusedInSynRcvd(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive RST with no ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestConnectionResetInSynRcvd(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. 
+ tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send RST with no ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestRetransmitOnSynSent(t *testing.T) { + // Send initial SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Retransmit the same SYN. + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) + } +} + +func TestRetransmitOnSynRcvd(t *testing.T) { + // Send initial SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. This will cause the state to go to SYN-RCVD. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Retransmit the original SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Transmit a SYN-ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } +} + +func TestClosedBySelf(t *testing.T) { + tcb := connected(t, 1234, 789, 30000, 50000) + + // Send FIN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive FIN/ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 1236, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. 
+ tcp.Encode(&header.TCPFields{ + SeqNum: 1236, + AckNum: 791, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) + } +} + +func TestClosedByPeer(t *testing.T) { + tcb := connected(t, 1234, 789, 30000, 50000) + + // Receive FIN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send FIN/ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 791, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 791, + AckNum: 1236, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer) + } +} + +func TestSendAndReceiveDataClosedBySelf(t *testing.T) { + sseq := uint32(1234) + rseq := uint32(789) + tcb := connected(t, sseq, rseq, 30000, 50000) + sseq++ + rseq++ + + // Send some data. + tcp := make(header.TCP, header.TCPMinimumSize+1024) + + for i := uint32(0); i < 10; i++ { + // Send some data. + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + sseq += uint32(len(tcp)) - header.TCPMinimumSize + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive ack for data. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + } + + for i := uint32(0); i < 10; i++ { + // Receive some data. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + rseq += uint32(len(tcp)) - header.TCPMinimumSize + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ack for data. + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + } + + // Send FIN. 
+	tcp = tcp[:header.TCPMinimumSize]
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     sseq,
+		AckNum:     rseq,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck | header.TCPFlagFin,
+		WindowSize: 30000,
+	})
+	sseq++
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Receive FIN/ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     rseq,
+		AckNum:     sseq,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck | header.TCPFlagFin,
+		WindowSize: 50000,
+	})
+	rseq++
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Send ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     sseq,
+		AckNum:     rseq,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck,
+		WindowSize: 30000,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
+	}
+}
+
+func TestIgnoreBadResetOnSynSent(t *testing.T) {
+	// Send SYN.
+	tcp := make(header.TCP, header.TCPMinimumSize)
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1234,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 30000,
+	})
+
+	tcb := tcpconntrack.TCB{}
+	tcb.Init(tcp)
+
+	// Receive a RST with a bad ACK, it should not cause the connection to
+	// be reset.
+	acks := []uint32{1234, 1236, 1000, 5000}
+	flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
+	for _, a := range acks {
+		for _, f := range flags {
+			tcp.Encode(&header.TCPFields{
+				SeqNum:     789,
+				AckNum:     a,
+				DataOffset: header.TCPMinimumSize,
+				Flags:      f,
+				WindowSize: 50000,
+			})
+
+			if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting {
+				t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting)
+			}
+		}
+	}
+
+	// Complete the handshake.
+	// Receive SYN-ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     789,
+		AckNum:     1235,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn | header.TCPFlagAck,
+		WindowSize: 50000,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Send ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1235,
+		AckNum:     790,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck,
+		WindowSize: 30000,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
new file mode 100644
index 000000000..b5d2d0ba6
--- /dev/null
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -0,0 +1,60 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+    name = "udp_packet_list",
+    out = "udp_packet_list.go",
+    package = "udp",
+    prefix = "udpPacket",
+    template = "//pkg/ilist:generic_list",
+    types = {
+        "Element": "*udpPacket",
+        "Linker": "*udpPacket",
+    },
+)
+
+go_library(
+    name = "udp",
+    srcs = [
+        "endpoint.go",
+        "endpoint_state.go",
+        "forwarder.go",
+        "protocol.go",
+        "udp_packet_list.go",
+    ],
+    imports = ["gvisor.dev/gvisor/pkg/tcpip/buffer"],
+    visibility = ["//visibility:public"],
+    deps = [
+        "//pkg/sleep",
+        "//pkg/sync",
+        "//pkg/tcpip",
+        "//pkg/tcpip/buffer",
+        "//pkg/tcpip/header",
+        "//pkg/tcpip/ports",
+        "//pkg/tcpip/stack",
+        "//pkg/tcpip/transport/raw",
+        "//pkg/waiter",
+    ],
+)
+
+go_test(
+    name = "udp_x_test",
+    size = "small",
+    srcs = ["udp_test.go"],
+    deps = [
+        ":udp",
+        "//pkg/tcpip",
+        "//pkg/tcpip/buffer",
+        "//pkg/tcpip/checker",
+        "//pkg/tcpip/header",
+        "//pkg/tcpip/link/channel",
+        "//pkg/tcpip/link/loopback",
+        "//pkg/tcpip/link/sniffer",
+        "//pkg/tcpip/network/ipv4",
+        "//pkg/tcpip/network/ipv6",
+        "//pkg/tcpip/stack",
+        "//pkg/waiter",
+    ],
+)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
new file mode 100644
index 000000000..0584ec8dc
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -0,0 +1,1497 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+	"fmt"
+
+	"gvisor.dev/gvisor/pkg/sleep"
+	"gvisor.dev/gvisor/pkg/sync"
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/buffer"
+	"gvisor.dev/gvisor/pkg/tcpip/header"
+	"gvisor.dev/gvisor/pkg/tcpip/ports"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type udpPacket struct {
+	udpPacketEntry
+	senderAddress tcpip.FullAddress
+	packetInfo    tcpip.IPPacketInfo
+	data          buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+	timestamp     int64
+	// tos stores either the receiveTOS or receiveTClass value.
+	tos uint8
+}
+
+// EndpointState represents the state of a UDP endpoint.
+type EndpointState uint32
+
+// Endpoint states. Note that they are represented in a netstack-specific
+// manner and may not be meaningful externally. Specifically, they need to be
+// translated to Linux's representation for these states if presented to
+// userspace.
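+//
+// For reference, Linux reports UDP socket state using TCP state values: a
+// connected UDP socket shows as TCP_ESTABLISHED (1) and an unconnected one as
+// TCP_CLOSE (7) in /proc/net/udp, so a translation would map StateConnected
+// and the remaining states accordingly.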
+const (
+	StateInitial EndpointState = iota
+	StateBound
+	StateConnected
+	StateClosed
+)
+
+// String implements fmt.Stringer.String.
+func (s EndpointState) String() string {
+	switch s {
+	case StateInitial:
+		return "INITIAL"
+	case StateBound:
+		return "BOUND"
+	case StateConnected:
+		return "CONNECTED"
+	case StateClosed:
+		return "CLOSED"
+	default:
+		return "UNKNOWN"
+	}
+}
+
+// endpoint represents a UDP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, as they are properly
+// synchronized.
+//
+// It implements tcpip.Endpoint.
+//
+// +stateify savable
+type endpoint struct {
+	stack.TransportEndpointInfo
+
+	// The following fields are initialized at creation time and do not
+	// change throughout the lifetime of the endpoint.
+	stack       *stack.Stack `state:"manual"`
+	waiterQueue *waiter.Queue
+	uniqueID    uint64
+
+	// The following fields are used to manage the receive queue, and are
+	// protected by rcvMu.
+	rcvMu         sync.Mutex `state:"nosave"`
+	rcvReady      bool
+	rcvList       udpPacketList
+	rcvBufSizeMax int `state:".(int)"`
+	rcvBufSize    int
+	rcvClosed     bool
+
+	// The following fields are protected by the mu mutex.
+	mu             sync.RWMutex `state:"nosave"`
+	sndBufSize     int
+	sndBufSizeMax  int
+	state          EndpointState
+	route          stack.Route `state:"manual"`
+	dstPort        uint16
+	v6only         bool
+	ttl            uint8
+	multicastTTL   uint8
+	multicastAddr  tcpip.Address
+	multicastNICID tcpip.NICID
+	multicastLoop  bool
+	portFlags      ports.Flags
+	bindToDevice   tcpip.NICID
+	broadcast      bool
+	noChecksum     bool
+
+	lastErrorMu sync.Mutex   `state:"nosave"`
+	lastError   *tcpip.Error `state:".(string)"`
+
+	// Values used to reserve a port or register a transport endpoint
+	// (whichever happens first).
+	boundBindToDevice tcpip.NICID
+	boundPortFlags    ports.Flags
+
+	// sendTOS represents IPv4 TOS or IPv6 TrafficClass,
+	// applied while sending packets. Defaults to 0 as on Linux.
+	sendTOS uint8
+
+	// receiveTOS determines if the incoming IPv4 TOS header field is passed
+	// as ancillary data to ControlMessages on Read.
+	receiveTOS bool
+
+	// receiveTClass determines if the incoming IPv6 TClass header field is
+	// passed as ancillary data to ControlMessages on Read.
+	receiveTClass bool
+
+	// receiveIPPacketInfo determines if the packet info is returned by Read.
+	receiveIPPacketInfo bool
+
+	// shutdownFlags represent the current shutdown state of the endpoint.
+	shutdownFlags tcpip.ShutdownFlags
+
+	// multicastMemberships that need to be removed when the endpoint is
+	// closed. Protected by the mu mutex.
+	multicastMemberships []multicastMembership
+
+	// effectiveNetProtos contains the network protocols actually in use. In
+	// most cases it will only contain "netProto", but in cases like IPv6
+	// endpoints with v6only set to false, this could include multiple
+	// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+	// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+	// address).
+	effectiveNetProtos []tcpip.NetworkProtocolNumber
+
+	// TODO(b/142022063): Add ability to save and restore per endpoint stats.
+	stats tcpip.TransportEndpointStats `state:"nosave"`
+
+	// owner is used to get uid and gid of the packet.
+ owner tcpip.PacketOwner +} + +// +stateify savable +type multicastMembership struct { + nicID tcpip.NICID + multicastAddr tcpip.Address +} + +func newEndpoint(s *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { + e := &endpoint{ + stack: s, + TransportEndpointInfo: stack.TransportEndpointInfo{ + NetProto: netProto, + TransProto: header.UDPProtocolNumber, + }, + waiterQueue: waiterQueue, + // RFC 1075 section 5.4 recommends a TTL of 1 for membership + // requests. + // + // RFC 5135 4.2.1 appears to assume that IGMP messages have a + // TTL of 1. + // + // RFC 5135 Appendix A defines TTL=1: A multicast source that + // wants its traffic to not traverse a router (e.g., leave a + // home network) may find it useful to send traffic with IP + // TTL=1. + // + // Linux defaults to TTL=1. + multicastTTL: 1, + multicastLoop: true, + rcvBufSizeMax: 32 * 1024, + sndBufSizeMax: 32 * 1024, + state: StateInitial, + uniqueID: s.UniqueID(), + } + + // Override with stack defaults. + var ss stack.SendBufferSizeOption + if err := s.Option(&ss); err == nil { + e.sndBufSizeMax = ss.Default + } + + var rs stack.ReceiveBufferSizeOption + if err := s.Option(&rs); err == nil { + e.rcvBufSizeMax = rs.Default + } + + return e +} + +// UniqueID implements stack.TransportEndpoint.UniqueID. +func (e *endpoint) UniqueID() uint64 { + return e.uniqueID +} + +func (e *endpoint) takeLastError() *tcpip.Error { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + + err := e.lastError + e.lastError = nil + return err +} + +// Abort implements stack.TransportEndpoint.Abort. +func (e *endpoint) Abort() { + e.Close() +} + +// Close puts the endpoint in a closed state and frees all resources +// associated with it. +func (e *endpoint) Close() { + e.mu.Lock() + e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite + + switch e.state { + case StateBound, StateConnected: + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, e.boundPortFlags, e.boundBindToDevice) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, e.boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) + e.boundBindToDevice = 0 + e.boundPortFlags = ports.Flags{} + } + + for _, mem := range e.multicastMemberships { + e.stack.LeaveGroup(e.NetProto, mem.nicID, mem.multicastAddr) + } + e.multicastMemberships = nil + + // Close the receive list and drain it. + e.rcvMu.Lock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } + e.rcvMu.Unlock() + + e.route.Release() + + // Update the state. + e.state = StateClosed + + e.mu.Unlock() + + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. +func (e *endpoint) ModerateRecvBuf(copied int) {} + +// Read reads data from the endpoint. This method does not block if +// there is no data pending. 
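+//
+// Callers that want blocking semantics typically register for waiter.EventIn
+// and retry on ErrWouldBlock; a sketch (wq is the endpoint's waiter.Queue):
+//
+//	we, ch := waiter.NewChannelEntry(nil)
+//	wq.EventRegister(&we, waiter.EventIn)
+//	defer wq.EventUnregister(&we)
+//	for {
+//		v, _, err := ep.Read(nil)
+//		if err == tcpip.ErrWouldBlock {
+//			<-ch // Wait for a packet to arrive.
+//			continue
+//		}
+//		// Use v / handle err.
+//		break
+//	}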
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+	if err := e.takeLastError(); err != nil {
+		return buffer.View{}, tcpip.ControlMessages{}, err
+	}
+
+	e.rcvMu.Lock()
+
+	if e.rcvList.Empty() {
+		err := tcpip.ErrWouldBlock
+		if e.rcvClosed {
+			e.stats.ReadErrors.ReadClosed.Increment()
+			err = tcpip.ErrClosedForReceive
+		}
+		e.rcvMu.Unlock()
+		return buffer.View{}, tcpip.ControlMessages{}, err
+	}
+
+	p := e.rcvList.Front()
+	e.rcvList.Remove(p)
+	e.rcvBufSize -= p.data.Size()
+	e.rcvMu.Unlock()
+
+	if addr != nil {
+		*addr = p.senderAddress
+	}
+
+	cm := tcpip.ControlMessages{
+		HasTimestamp: true,
+		Timestamp:    p.timestamp,
+	}
+	e.mu.RLock()
+	receiveTOS := e.receiveTOS
+	receiveTClass := e.receiveTClass
+	receiveIPPacketInfo := e.receiveIPPacketInfo
+	e.mu.RUnlock()
+	if receiveTOS {
+		cm.HasTOS = true
+		cm.TOS = p.tos
+	}
+	if receiveTClass {
+		cm.HasTClass = true
+		// Although TClass is an 8-bit value, it's read in the CMsg as a uint32.
+		cm.TClass = uint32(p.tos)
+	}
+	if receiveIPPacketInfo {
+		cm.HasIPPacketInfo = true
+		cm.PacketInfo = p.packetInfo
+	}
+	return p.data.ToView(), cm, nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+	switch e.state {
+	case StateInitial:
+	case StateConnected:
+		return false, nil
+
+	case StateBound:
+		if to == nil {
+			return false, tcpip.ErrDestinationRequired
+		}
+		return false, nil
+	default:
+		return false, tcpip.ErrInvalidEndpointState
+	}
+
+	e.mu.RUnlock()
+	defer e.mu.RLock()
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// The state may have changed when we released the shared lock and
+	// re-acquired it in exclusive mode. If so, try again.
+	if e.state != StateInitial {
+		return true, nil
+	}
+
+	// The state is still StateInitial, so try to bind the endpoint.
+	if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+		return false, err
+	}
+
+	return true, nil
+}
+
+// connectRoute establishes a route to the specified interface or the
+// configured multicast interface if no interface is specified and the
+// specified address is a multicast address.
+func (e *endpoint) connectRoute(nicID tcpip.NICID, addr tcpip.FullAddress, netProto tcpip.NetworkProtocolNumber) (stack.Route, tcpip.NICID, *tcpip.Error) {
+	localAddr := e.ID.LocalAddress
+	if isBroadcastOrMulticast(localAddr) {
+		// A packet can only originate from a unicast address (i.e., an interface).
+		localAddr = ""
+	}
+
+	if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
+		if nicID == 0 {
+			nicID = e.multicastNICID
+		}
+		if localAddr == "" && nicID == 0 {
+			localAddr = e.multicastAddr
+		}
+	}
+
+	// Find a route to the desired destination.
+	r, err := e.stack.FindRoute(nicID, localAddr, addr.Addr, netProto, e.multicastLoop)
+	if err != nil {
+		return stack.Route{}, 0, err
+	}
+	return r, nicID, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
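+//
+// A minimal sendto-style sketch (illustrative: ep is a UDP endpoint, and
+// tcpip.SlicePayload wraps a byte slice as a tcpip.Payloader):
+//
+//	dst := tcpip.FullAddress{Addr: "\x0a\x00\x00\x02", Port: 4096}
+//	n, _, err := ep.Write(tcpip.SlicePayload([]byte("ping")), tcpip.WriteOptions{To: &dst})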
+func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + n, ch, err := e.write(p, opts) + switch err { + case nil: + e.stats.PacketsSent.Increment() + case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue: + e.stats.WriteErrors.InvalidArgs.Increment() + case tcpip.ErrClosedForSend: + e.stats.WriteErrors.WriteClosed.Increment() + case tcpip.ErrInvalidEndpointState: + e.stats.WriteErrors.InvalidEndpointState.Increment() + case tcpip.ErrNoLinkAddress: + e.stats.SendErrors.NoLinkAddr.Increment() + case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable: + // Errors indicating any problem with IP routing of the packet. + e.stats.SendErrors.NoRoute.Increment() + default: + // For all other errors when writing to the network layer. + e.stats.SendErrors.SendToNetworkFailed.Increment() + } + return n, ch, err +} + +func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { + if err := e.takeLastError(); err != nil { + return 0, nil, err + } + + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, nil, tcpip.ErrInvalidOptionValue + } + + to := opts.To + + e.mu.RLock() + defer e.mu.RUnlock() + + // If we've shutdown with SHUT_WR we are in an invalid state for sending. + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { + return 0, nil, tcpip.ErrClosedForSend + } + + // Prepare for write. + for { + retry, err := e.prepareForWrite(to) + if err != nil { + return 0, nil, err + } + + if !retry { + break + } + } + + var route *stack.Route + var resolve func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) + var dstPort uint16 + if to == nil { + route = &e.route + dstPort = e.dstPort + resolve = func(waker *sleep.Waker) (ch <-chan struct{}, err *tcpip.Error) { + // Promote lock to exclusive if using a shared route, given that it may + // need to change in Route.Resolve() call below. + e.mu.RUnlock() + e.mu.Lock() + + // Recheck state after lock was re-acquired. + if e.state != StateConnected { + err = tcpip.ErrInvalidEndpointState + } + if err == nil && route.IsResolutionRequired() { + ch, err = route.Resolve(waker) + } + + e.mu.Unlock() + e.mu.RLock() + + // Recheck state after lock was re-acquired. + if e.state != StateConnected { + err = tcpip.ErrInvalidEndpointState + } + return + } + } else { + // Reject destination address if it goes through a different + // NIC than the endpoint was bound to. + nicID := to.NIC + if e.BindNICID != 0 { + if nicID != 0 && nicID != e.BindNICID { + return 0, nil, tcpip.ErrNoRoute + } + + nicID = e.BindNICID + } + + if to.Addr == header.IPv4Broadcast && !e.broadcast { + return 0, nil, tcpip.ErrBroadcastDisabled + } + + dst, netProto, err := e.checkV4MappedLocked(*to) + if err != nil { + return 0, nil, err + } + + r, _, err := e.connectRoute(nicID, dst, netProto) + if err != nil { + return 0, nil, err + } + defer r.Release() + + route = &r + dstPort = dst.Port + resolve = route.Resolve + } + + if route.IsResolutionRequired() { + if ch, err := resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + v, err := p.FullPayload() + if err != nil { + return 0, nil, err + } + if len(v) > header.UDPMaximumPacketSize { + // Payload can't possibly fit in a packet. 
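+		// (The UDP length field is 16 bits and counts the 8-byte
+		// header as well, so no datagram larger than 65535 bytes can
+		// be expressed on the wire; see RFC 768.)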
+ return 0, nil, tcpip.ErrMessageTooLong + } + + ttl := e.ttl + useDefaultTTL := ttl == 0 + + if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { + ttl = e.multicastTTL + // Multicast allows a 0 TTL. + useDefaultTTL = false + } + + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.ID.LocalPort, dstPort, ttl, useDefaultTTL, e.sendTOS, e.owner, e.noChecksum); err != nil { + return 0, nil, err + } + return int64(len(v)), nil, nil +} + +// Peek only returns data from a single datagram, so do nothing here. +func (e *endpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool. +func (e *endpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.BroadcastOption: + e.mu.Lock() + e.broadcast = v + e.mu.Unlock() + + case tcpip.MulticastLoopOption: + e.mu.Lock() + e.multicastLoop = v + e.mu.Unlock() + + case tcpip.NoChecksumOption: + e.mu.Lock() + e.noChecksum = v + e.mu.Unlock() + + case tcpip.ReceiveTOSOption: + e.mu.Lock() + e.receiveTOS = v + e.mu.Unlock() + + case tcpip.ReceiveTClassOption: + // We only support this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrNotSupported + } + + e.mu.Lock() + e.receiveTClass = v + e.mu.Unlock() + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.Lock() + e.receiveIPPacketInfo = v + e.mu.Unlock() + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.portFlags.MostRecent = v + e.mu.Unlock() + + case tcpip.ReusePortOption: + e.mu.Lock() + e.portFlags.LoadBalanced = v + e.mu.Unlock() + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != StateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v + } + + return nil +} + +// SetSockOptInt implements tcpip.Endpoint.SetSockOptInt. +func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.MTUDiscoverOption: + // Return not supported if the value is not disabling path + // MTU discovery. + if v != tcpip.PMTUDiscoveryDont { + return tcpip.ErrNotSupported + } + + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() + + case tcpip.TTLOption: + e.mu.Lock() + e.ttl = uint8(v) + e.mu.Unlock() + + case tcpip.IPv4TOSOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + + case tcpip.IPv6TrafficClassOption: + e.mu.Lock() + e.sendTOS = uint8(v) + e.mu.Unlock() + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs stack.ReceiveBufferSizeOption + if err := e.stack.Option(&rs); err != nil { + panic(fmt.Sprintf("e.stack.Option(%#v) = %s", rs, err)) + } + + if v < rs.Min { + v = rs.Min + } + if v > rs.Max { + v = rs.Max + } + + e.mu.Lock() + e.rcvBufSizeMax = v + e.mu.Unlock() + return nil + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. 
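+		// For example, with limits like this package's MinBufferSize
+		// (4 KiB) and MaxBufferSize (4 MiB), a request for 1 GiB would
+		// be silently clamped to 4 MiB, much as Linux caps
+		// setsockopt(SO_SNDBUF) against net.core.wmem_max.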
+		var ss stack.SendBufferSizeOption
+		if err := e.stack.Option(&ss); err != nil {
+			panic(fmt.Sprintf("e.stack.Option(%#v) = %s", ss, err))
+		}
+
+		if v < ss.Min {
+			v = ss.Min
+		}
+		if v > ss.Max {
+			v = ss.Max
+		}
+
+		e.mu.Lock()
+		e.sndBufSizeMax = v
+		e.mu.Unlock()
+		return nil
+	}
+
+	return nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+	switch v := opt.(type) {
+	case tcpip.MulticastInterfaceOption:
+		e.mu.Lock()
+		defer e.mu.Unlock()
+
+		fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
+		fa, netProto, err := e.checkV4MappedLocked(fa)
+		if err != nil {
+			return err
+		}
+		nic := v.NIC
+		addr := fa.Addr
+
+		if nic == 0 && addr == "" {
+			e.multicastAddr = ""
+			e.multicastNICID = 0
+			break
+		}
+
+		if nic != 0 {
+			if !e.stack.CheckNIC(nic) {
+				return tcpip.ErrBadLocalAddress
+			}
+		} else {
+			nic = e.stack.CheckLocalAddress(0, netProto, addr)
+			if nic == 0 {
+				return tcpip.ErrBadLocalAddress
+			}
+		}
+
+		if e.BindNICID != 0 && e.BindNICID != nic {
+			return tcpip.ErrInvalidEndpointState
+		}
+
+		e.multicastNICID = nic
+		e.multicastAddr = addr
+
+	case tcpip.AddMembershipOption:
+		if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+			return tcpip.ErrInvalidOptionValue
+		}
+
+		nicID := v.NIC
+
+		// The interface address is considered not-set if it is empty or
+		// contains all-zeros. The former represents the zero value in Go,
+		// the latter the unset value in a
+		// setsockopt(IP_ADD_MEMBERSHIP, &ip_mreqn) syscall.
+		allZeros := header.IPv4Any
+		if len(v.InterfaceAddr) == 0 || v.InterfaceAddr == allZeros {
+			if nicID == 0 {
+				r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+				if err == nil {
+					nicID = r.NICID()
+					r.Release()
+				}
+			}
+		} else {
+			nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
+		}
+		if nicID == 0 {
+			return tcpip.ErrUnknownDevice
+		}
+
+		memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+		e.mu.Lock()
+		defer e.mu.Unlock()
+
+		for _, mem := range e.multicastMemberships {
+			if mem == memToInsert {
+				return tcpip.ErrPortInUse
+			}
+		}
+
+		if err := e.stack.JoinGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
+			return err
+		}
+
+		e.multicastMemberships = append(e.multicastMemberships, memToInsert)
+
+	case tcpip.RemoveMembershipOption:
+		if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+			return tcpip.ErrInvalidOptionValue
+		}
+
+		nicID := v.NIC
+		if v.InterfaceAddr == header.IPv4Any {
+			if nicID == 0 {
+				r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+				if err == nil {
+					nicID = r.NICID()
+					r.Release()
+				}
+			}
+		} else {
+			nicID = e.stack.CheckLocalAddress(nicID, e.NetProto, v.InterfaceAddr)
+		}
+		if nicID == 0 {
+			return tcpip.ErrUnknownDevice
+		}
+
+		memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+		memToRemoveIndex := -1
+
+		e.mu.Lock()
+		defer e.mu.Unlock()
+
+		for i, mem := range e.multicastMemberships {
+			if mem == memToRemove {
+				memToRemoveIndex = i
+				break
+			}
+		}
+		if memToRemoveIndex == -1 {
+			return tcpip.ErrBadLocalAddress
+		}
+
+		if err := e.stack.LeaveGroup(e.NetProto, nicID, v.MulticastAddr); err != nil {
+			return err
+		}
+
+		e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
+		e.multicastMemberships =
e.multicastMemberships[:len(e.multicastMemberships)-1] + + case tcpip.BindToDeviceOption: + id := tcpip.NICID(v) + if id != 0 && !e.stack.HasNIC(id) { + return tcpip.ErrUnknownDevice + } + e.mu.Lock() + e.bindToDevice = id + e.mu.Unlock() + } + return nil +} + +// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. +func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.BroadcastOption: + e.mu.RLock() + v := e.broadcast + e.mu.RUnlock() + return v, nil + + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.MulticastLoopOption: + e.mu.RLock() + v := e.multicastLoop + e.mu.RUnlock() + return v, nil + + case tcpip.NoChecksumOption: + e.mu.RLock() + v := e.noChecksum + e.mu.RUnlock() + return v, nil + + case tcpip.ReceiveTOSOption: + e.mu.RLock() + v := e.receiveTOS + e.mu.RUnlock() + return v, nil + + case tcpip.ReceiveTClassOption: + // We only support this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrNotSupported + } + + e.mu.RLock() + v := e.receiveTClass + e.mu.RUnlock() + return v, nil + + case tcpip.ReceiveIPPacketInfoOption: + e.mu.RLock() + v := e.receiveIPPacketInfo + e.mu.RUnlock() + return v, nil + + case tcpip.ReuseAddressOption: + e.mu.RLock() + v := e.portFlags.MostRecent + e.mu.RUnlock() + + return v, nil + + case tcpip.ReusePortOption: + e.mu.RLock() + v := e.portFlags.LoadBalanced + e.mu.RUnlock() + + return v, nil + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.NetProto != header.IPv6ProtocolNumber { + return false, tcpip.ErrUnknownProtocolOption + } + + e.mu.RLock() + v := e.v6only + e.mu.RUnlock() + + return v, nil + + default: + return false, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. +func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { + switch opt { + case tcpip.IPv4TOSOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.IPv6TrafficClassOption: + e.mu.RLock() + v := int(e.sendTOS) + e.mu.RUnlock() + return v, nil + + case tcpip.MTUDiscoverOption: + // The only supported setting is path MTU discovery disabled. + return tcpip.PMTUDiscoveryDont, nil + + case tcpip.MulticastTTLOption: + e.mu.Lock() + v := int(e.multicastTTL) + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveQueueSizeOption: + v := 0 + e.rcvMu.Lock() + if !e.rcvList.Empty() { + p := e.rcvList.Front() + v = p.data.Size() + } + e.rcvMu.Unlock() + return v, nil + + case tcpip.SendBufferSizeOption: + e.mu.Lock() + v := e.sndBufSizeMax + e.mu.Unlock() + return v, nil + + case tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + v := e.rcvBufSizeMax + e.rcvMu.Unlock() + return v, nil + + case tcpip.TTLOption: + e.mu.Lock() + v := int(e.ttl) + e.mu.Unlock() + return v, nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. 
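+//
+// Pointer-typed options are filled in on success. A small sketch (ep is
+// illustrative):
+//
+//	var v tcpip.BindToDeviceOption
+//	if err := ep.GetSockOpt(&v); err == nil {
+//		// tcpip.NICID(v) is the NIC the endpoint is bound to; 0 means
+//		// it is not bound to any device.
+//	}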
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return e.takeLastError() + case *tcpip.MulticastInterfaceOption: + e.mu.Lock() + *o = tcpip.MulticastInterfaceOption{ + e.multicastNICID, + e.multicastAddr, + } + e.mu.Unlock() + + case *tcpip.BindToDeviceOption: + e.mu.RLock() + *o = tcpip.BindToDeviceOption(e.bindToDevice) + e.mu.RUnlock() + + default: + return tcpip.ErrUnknownProtocolOption + } + return nil +} + +// sendUDP sends a UDP segment via the provided network endpoint and under the +// provided identity. +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8, useDefaultTTL bool, tos uint8, owner tcpip.PacketOwner, noChecksum bool) *tcpip.Error { + // Allocate a buffer for the UDP header. + hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) + + // Initialize the header. + udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + + length := uint16(hdr.UsedLength() + data.Size()) + udp.Encode(&header.UDPFields{ + SrcPort: localPort, + DstPort: remotePort, + Length: length, + }) + + // Set the checksum field unless TX checksum offload is enabled. + // On IPv4, UDP checksum is optional, and a zero value indicates the + // transmitter skipped the checksum generation (RFC768). + // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). + if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { + xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) + for _, v := range data.Views() { + xsum = header.Checksum(v, xsum) + } + udp.SetChecksum(^udp.CalculateChecksum(xsum)) + } + + if useDefaultTTL { + ttl = r.DefaultTTL() + } + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: ProtocolNumber, + TTL: ttl, + TOS: tos, + }, &stack.PacketBuffer{ + Header: hdr, + Data: data, + TransportHeader: buffer.View(udp), + Owner: owner, + }); err != nil { + r.Stats().UDP.PacketSendErrors.Increment() + return err + } + + // Track count of packets sent. + r.Stats().UDP.PacketsSent.Increment() + return nil +} + +// checkV4MappedLocked determines the effective network protocol and converts +// addr to its canonical form. +func (e *endpoint) checkV4MappedLocked(addr tcpip.FullAddress) (tcpip.FullAddress, tcpip.NetworkProtocolNumber, *tcpip.Error) { + unwrapped, netProto, err := e.TransportEndpointInfo.AddrNetProtoLocked(addr, e.v6only) + if err != nil { + return tcpip.FullAddress{}, 0, err + } + return unwrapped, netProto, nil +} + +// Disconnect implements tcpip.Endpoint.Disconnect. +func (e *endpoint) Disconnect() *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + if e.state != StateConnected { + return nil + } + var ( + id stack.TransportEndpointID + btd tcpip.NICID + ) + + // We change this value below and we need the old value to unregister + // the endpoint. + boundPortFlags := e.boundPortFlags + + // Exclude ephemerally bound endpoints. + if e.BindNICID != 0 || e.ID.LocalAddress == "" { + var err *tcpip.Error + id = stack.TransportEndpointID{ + LocalPort: e.ID.LocalPort, + LocalAddress: e.ID.LocalAddress, + } + id, btd, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id) + if err != nil { + return err + } + e.state = StateBound + boundPortFlags = e.boundPortFlags + } else { + if e.ID.LocalPort != 0 { + // Release the ephemeral port. 
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.ID.LocalAddress, e.ID.LocalPort, boundPortFlags, e.boundBindToDevice, tcpip.FullAddress{}) + e.boundPortFlags = ports.Flags{} + } + e.state = StateInitial + } + + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, boundPortFlags, e.boundBindToDevice) + e.ID = id + e.boundBindToDevice = btd + e.route.Release() + e.route = stack.Route{} + e.dstPort = 0 + + return nil +} + +// Connect connects the endpoint to its peer. Specifying a NIC is optional. +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + if addr.Port == 0 { + // We don't support connecting to port zero. + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + nicID := addr.NIC + var localPort uint16 + switch e.state { + case StateInitial: + case StateBound, StateConnected: + localPort = e.ID.LocalPort + if e.BindNICID == 0 { + break + } + + if nicID != 0 && nicID != e.BindNICID { + return tcpip.ErrInvalidEndpointState + } + + nicID = e.BindNICID + default: + return tcpip.ErrInvalidEndpointState + } + + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + + r, nicID, err := e.connectRoute(nicID, addr, netProto) + if err != nil { + return err + } + defer r.Release() + + id := stack.TransportEndpointID{ + LocalAddress: e.ID.LocalAddress, + LocalPort: localPort, + RemotePort: addr.Port, + RemoteAddress: r.RemoteAddress, + } + + if e.state == StateInitial { + id.LocalAddress = r.LocalAddress + } + + // Even if we're connected, this endpoint can still be used to send + // packets on a different network protocol, so we register both even if + // v6only is set to false and this is an ipv6 endpoint. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv4ProtocolNumber, + header.IPv6ProtocolNumber, + } + } + + oldPortFlags := e.boundPortFlags + + id, btd, err := e.registerWithStack(nicID, netProtos, id) + if err != nil { + return err + } + + // Remove the old registration. + if e.ID.LocalPort != 0 { + e.stack.UnregisterTransportEndpoint(e.RegisterNICID, e.effectiveNetProtos, ProtocolNumber, e.ID, e, oldPortFlags, e.boundBindToDevice) + } + + e.ID = id + e.boundBindToDevice = btd + e.route = r.Clone() + e.dstPort = addr.Port + e.RegisterNICID = nicID + e.effectiveNetProtos = netProtos + + e.state = StateConnected + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// ConnectEndpoint is not supported. +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// Shutdown closes the read and/or write end of the endpoint connection +// to its peer. +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // A socket in the bound state can still receive multicast messages, + // so we need to notify waiters on shutdown. + if e.state != StateBound && e.state != StateConnected { + return tcpip.ErrNotConnected + } + + e.shutdownFlags |= flags + + if flags&tcpip.ShutdownRead != 0 { + e.rcvMu.Lock() + wasClosed := e.rcvClosed + e.rcvClosed = true + e.rcvMu.Unlock() + + if !wasClosed { + e.waiterQueue.Notify(waiter.EventIn) + } + } + + return nil +} + +// Listen is not supported by UDP, it just fails. 
+func (*endpoint) Listen(int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept is not supported by UDP, it just fails. +func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +func (e *endpoint) registerWithStack(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, tcpip.NICID, *tcpip.Error) { + if e.ID.LocalPort == 0 { + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.portFlags, e.bindToDevice, tcpip.FullAddress{}) + if err != nil { + return id, e.bindToDevice, err + } + id.LocalPort = port + } + e.boundPortFlags = e.portFlags + + err := e.stack.RegisterTransportEndpoint(nicID, netProtos, ProtocolNumber, id, e, e.boundPortFlags, e.bindToDevice) + if err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.boundPortFlags, e.bindToDevice, tcpip.FullAddress{}) + e.boundPortFlags = ports.Flags{} + } + return id, e.bindToDevice, err +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { + // Don't allow binding once endpoint is not in the initial state + // anymore. + if e.state != StateInitial { + return tcpip.ErrInvalidEndpointState + } + + addr, netProto, err := e.checkV4MappedLocked(addr) + if err != nil { + return err + } + + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } + } + + nicID := addr.NIC + if len(addr.Addr) != 0 && !isBroadcastOrMulticast(addr.Addr) { + // A local unicast address was specified, verify that it's valid. + nicID = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nicID == 0 { + return tcpip.ErrBadLocalAddress + } + } + + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, btd, err := e.registerWithStack(nicID, netProtos, id) + if err != nil { + return err + } + + e.ID = id + e.boundBindToDevice = btd + e.RegisterNICID = nicID + e.effectiveNetProtos = netProtos + + // Mark endpoint as bound. + e.state = StateBound + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// Bind binds the endpoint to a specific local address and port. +// Specifying a NIC is optional. +func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + err := e.bindLocked(addr) + if err != nil { + return err + } + + // Save the effective NICID generated by bindLocked. + e.BindNICID = e.RegisterNICID + + return nil +} + +// GetLocalAddress returns the address to which the endpoint is bound. +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + addr := e.ID.LocalAddress + if e.state == StateConnected { + addr = e.route.LocalAddress + } + + return tcpip.FullAddress{ + NIC: e.RegisterNICID, + Addr: addr, + Port: e.ID.LocalPort, + }, nil +} + +// GetRemoteAddress returns the address to which the endpoint is connected. 
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	if e.state != StateConnected {
+		return tcpip.FullAddress{}, tcpip.ErrNotConnected
+	}
+
+	return tcpip.FullAddress{
+		NIC:  e.RegisterNICID,
+		Addr: e.ID.RemoteAddress,
+		Port: e.ID.RemotePort,
+	}, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+	// The endpoint is always writable.
+	result := waiter.EventOut & mask
+
+	// Determine if the endpoint is readable if requested.
+	if (mask & waiter.EventIn) != 0 {
+		e.rcvMu.Lock()
+		if !e.rcvList.Empty() || e.rcvClosed {
+			result |= waiter.EventIn
+		}
+		e.rcvMu.Unlock()
+	}
+
+	return result
+}
+
+// HandlePacket is called by the stack when new packets arrive at this
+// transport endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) {
+	// Get the header then trim it from the view.
+	hdr := header.UDP(pkt.TransportHeader)
+	if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
+		// Malformed packet.
+		e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
+		e.stats.ReceiveErrors.MalformedPacketsReceived.Increment()
+		return
+	}
+
+	// Verify checksum unless RX checksum offload is enabled.
+	// On IPv4, UDP checksum is optional, and a zero value means
+	// the transmitter omitted the checksum generation (RFC 768).
+	// On IPv6, UDP checksum is not optional (RFC 2460 Section 8.1).
+	if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 &&
+		(hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) {
+		xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length())
+		for _, v := range pkt.Data.Views() {
+			xsum = header.Checksum(v, xsum)
+		}
+		if hdr.CalculateChecksum(xsum) != 0xffff {
+			// Checksum error.
+			e.stack.Stats().UDP.ChecksumErrors.Increment()
+			e.stats.ReceiveErrors.ChecksumErrors.Increment()
+			return
+		}
+	}
+
+	e.rcvMu.Lock()
+	e.stack.Stats().UDP.PacketsReceived.Increment()
+	e.stats.PacketsReceived.Increment()
+
+	// Drop the packet if the endpoint is not ready to receive or has
+	// already been closed.
+	if !e.rcvReady || e.rcvClosed {
+		e.rcvMu.Unlock()
+		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+		e.stats.ReceiveErrors.ClosedReceiver.Increment()
+		return
+	}
+
+	// Drop the packet if our buffer is currently full.
+	if e.rcvBufSize >= e.rcvBufSizeMax {
+		e.rcvMu.Unlock()
+		e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+		e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+		return
+	}
+
+	wasEmpty := e.rcvBufSize == 0
+
+	// Push new packet into receive list and increment the buffer size.
+	packet := &udpPacket{
+		senderAddress: tcpip.FullAddress{
+			NIC:  r.NICID(),
+			Addr: id.RemoteAddress,
+			Port: hdr.SourcePort(),
+		},
+	}
+	packet.data = pkt.Data
+	e.rcvList.PushBack(packet)
+	e.rcvBufSize += pkt.Data.Size()
+
+	// Save any useful information from the network header to the packet.
+	switch r.NetProto {
+	case header.IPv4ProtocolNumber:
+		packet.tos, _ = header.IPv4(pkt.NetworkHeader).TOS()
+		packet.packetInfo.LocalAddr = r.LocalAddress
+		packet.packetInfo.DestinationAddr = r.RemoteAddress
+		packet.packetInfo.NIC = r.NICID()
+	case header.IPv6ProtocolNumber:
+		packet.tos, _ = header.IPv6(pkt.NetworkHeader).TOS()
+	}
+
+	packet.timestamp = e.stack.NowNanoseconds()
+
+	e.rcvMu.Unlock()
+
+	// Notify any waiters that there's data to be read now.
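+	// Only the empty-to-non-empty transition needs a wakeup: if the list
+	// was already non-empty, waiters were notified when the first packet
+	// arrived, and pollers observe EventIn via Readiness either way.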
+ if wasEmpty { + e.waiterQueue.Notify(waiter.EventIn) + } +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { + if typ == stack.ControlPortUnreachable { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state == StateConnected { + e.lastErrorMu.Lock() + defer e.lastErrorMu.Unlock() + + e.lastError = tcpip.ErrConnectionRefused + } + } +} + +// State implements tcpip.Endpoint.State. +func (e *endpoint) State() uint32 { + e.mu.Lock() + defer e.mu.Unlock() + return uint32(e.state) +} + +// Info returns a copy of the endpoint info. +func (e *endpoint) Info() tcpip.EndpointInfo { + e.mu.RLock() + // Make a copy of the endpoint info. + ret := e.TransportEndpointInfo + e.mu.RUnlock() + return &ret +} + +// Stats returns a pointer to the endpoint stats. +func (e *endpoint) Stats() tcpip.EndpointStats { + return &e.stats +} + +// Wait implements tcpip.Endpoint.Wait. +func (*endpoint) Wait() {} + +func isBroadcastOrMulticast(a tcpip.Address) bool { + return a == header.IPv4Broadcast || header.IsV4MulticastAddress(a) || header.IsV6MulticastAddress(a) +} + +func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.owner = owner +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go new file mode 100644 index 000000000..851e6b635 --- /dev/null +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -0,0 +1,137 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) + +// saveData saves udpPacket.data field. +func (u *udpPacket) saveData() buffer.VectorisedView { + // We cannot save u.data directly as u.data.views may alias to u.views, + // which is not allowed by state framework (in-struct pointer). + return u.data.Clone(nil) +} + +// loadData loads udpPacket.data field. +func (u *udpPacket) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing u.views for data.views. + u.data = data +} + +// saveLastError is invoked by stateify. +func (e *endpoint) saveLastError() string { + if e.lastError == nil { + return "" + } + + return e.lastError.String() +} + +// loadLastError is invoked by stateify. +func (e *endpoint) loadLastError(s string) { + if s == "" { + return + } + + e.lastError = tcpip.StringToError(s) +} + +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). 
+	// The lock will be released after saveRcvBufSizeMax(), which would have
+	// saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
+	// packets.
+	e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) saveRcvBufSizeMax() int {
+	max := e.rcvBufSizeMax
+	// Make sure no new packets will be handled regardless of the lock.
+	e.rcvBufSizeMax = 0
+	// Release the lock acquired in beforeSave() so regular endpoint closing
+	// logic can proceed after save.
+	e.rcvMu.Unlock()
+	return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+	e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+	stack.StackFromEnv.RegisterRestoredEndpoint(e)
+}
+
+// Resume implements tcpip.ResumableEndpoint.Resume.
+func (e *endpoint) Resume(s *stack.Stack) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	e.stack = s
+
+	for _, m := range e.multicastMemberships {
+		if err := e.stack.JoinGroup(e.NetProto, m.nicID, m.multicastAddr); err != nil {
+			panic(err)
+		}
+	}
+
+	if e.state != StateBound && e.state != StateConnected {
+		return
+	}
+
+	netProto := e.effectiveNetProtos[0]
+	// Connect() and bindLocked() both assert
+	//
+	//     netProto == header.IPv6ProtocolNumber
+	//
+	// before creating a multi-entry effectiveNetProtos.
+	if len(e.effectiveNetProtos) > 1 {
+		netProto = header.IPv6ProtocolNumber
+	}
+
+	var err *tcpip.Error
+	if e.state == StateConnected {
+		e.route, err = e.stack.FindRoute(e.RegisterNICID, e.ID.LocalAddress, e.ID.RemoteAddress, netProto, e.multicastLoop)
+		if err != nil {
+			panic(err)
+		}
+	} else if len(e.ID.LocalAddress) != 0 && !isBroadcastOrMulticast(e.ID.LocalAddress) { // StateBound
+		// A local unicast address is specified; verify that it's valid.
+		if e.stack.CheckLocalAddress(e.RegisterNICID, netProto, e.ID.LocalAddress) == 0 {
+			panic(tcpip.ErrBadLocalAddress)
+		}
+	}
+
+	// Our saved state had a port, but we don't actually have a
+	// reservation. We need to remove the port from our state, but still
+	// pass it to the reservation machinery.
+	id := e.ID
+	e.ID.LocalPort = 0
+	e.ID, e.boundBindToDevice, err = e.registerWithStack(e.RegisterNICID, e.effectiveNetProtos, id)
+	if err != nil {
+		panic(err)
+	}
+}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
new file mode 100644
index 000000000..c67e0ba95
--- /dev/null
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -0,0 +1,96 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+	"gvisor.dev/gvisor/pkg/tcpip"
+	"gvisor.dev/gvisor/pkg/tcpip/stack"
+	"gvisor.dev/gvisor/pkg/waiter"
+)
+
+// Forwarder is a session request forwarder, which allows clients to decide
+// what to do with a session request, for example: ignore it, or process it.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
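+//
+// A minimal sketch (the handler body is illustrative):
+//
+//	var wq waiter.Queue
+//	f := udp.NewForwarder(s, func(r *udp.ForwarderRequest) {
+//		ep, err := r.CreateEndpoint(&wq)
+//		if err != nil {
+//			return // The request is dropped.
+//		}
+//		// ep is a connected UDP endpoint: Read/Write it as usual and
+//		// Close it when done.
+//	})
+//	s.SetTransportProtocolHandler(udp.ProtocolNumber, f.HandlePacket)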
+type Forwarder struct { + handler func(*ForwarderRequest) + + stack *stack.Stack +} + +// NewForwarder allocates and initializes a new forwarder. +func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { + return &Forwarder{ + stack: s, + handler: handler, + } +} + +// HandlePacket handles all packets. +// +// This function is expected to be passed as an argument to the +// stack.SetTransportProtocolHandler function. +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + f.handler(&ForwarderRequest{ + stack: f.stack, + route: r, + id: id, + pkt: pkt, + }) + + return true +} + +// ForwarderRequest represents a session request received by the forwarder and +// passed to the client. Clients may optionally create an endpoint to represent +// it via CreateEndpoint. +type ForwarderRequest struct { + stack *stack.Stack + route *stack.Route + id stack.TransportEndpointID + pkt *stack.PacketBuffer +} + +// ID returns the 4-tuple (src address, src port, dst address, dst port) that +// represents the session request. +func (r *ForwarderRequest) ID() stack.TransportEndpointID { + return r.id +} + +// CreateEndpoint creates a connected UDP endpoint for the session request. +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := newEndpoint(r.stack, r.route.NetProto, queue) + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + ep.Close() + return nil, err + } + + ep.ID = r.id + ep.route = r.route.Clone() + ep.dstPort = r.id.RemotePort + ep.RegisterNICID = r.route.NICID() + ep.boundPortFlags = ep.portFlags + + ep.state = StateConnected + + ep.rcvMu.Lock() + ep.rcvReady = true + ep.rcvMu.Unlock() + + ep.HandlePacket(r.route, r.id, r.pkt) + + return ep, nil +} diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go new file mode 100644 index 000000000..0e7464e3a --- /dev/null +++ b/pkg/tcpip/transport/udp/protocol.go @@ -0,0 +1,231 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package udp contains the implementation of the UDP transport protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing udp.NewProtocol() as one of the +// transport protocols when calling stack.New(). Then endpoints can be created +// by passing udp.ProtocolNumber as the transport protocol number when calling +// Stack.NewEndpoint(). +package udp + +import ( + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/raw" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + // ProtocolNumber is the udp protocol number. 
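+	// Its value is 17, the IANA-assigned IP protocol number for UDP
+	// (RFC 768).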
+	ProtocolNumber = header.UDPProtocolNumber
+
+	// MinBufferSize is the smallest size of a receive or send buffer.
+	MinBufferSize = 4 << 10 // 4 KiB
+
+	// DefaultSendBufferSize is the default size of the send buffer for
+	// an endpoint.
+	DefaultSendBufferSize = 32 << 10 // 32 KiB
+
+	// DefaultReceiveBufferSize is the default size of the receive buffer
+	// for an endpoint.
+	DefaultReceiveBufferSize = 32 << 10 // 32 KiB
+
+	// MaxBufferSize is the largest size a receive/send buffer can grow to.
+	MaxBufferSize = 4 << 20 // 4 MiB
+)
+
+type protocol struct {
+}
+
+// Number returns the udp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+	return ProtocolNumber
+}
+
+// NewEndpoint creates a new udp endpoint.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// NewRawEndpoint creates a new raw UDP endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid udp packet size.
+func (*protocol) MinimumPacketSize() int {
+	return header.UDPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given udp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+	h := header.UDP(v)
+	return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool {
+	hdr := header.UDP(pkt.TransportHeader)
+	if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize {
+		// Malformed packet.
+		r.Stack().Stats().UDP.MalformedPacketsReceived.Increment()
+		return true
+	}
+	// TODO(b/129426613): only send an ICMP message if UDP checksum is valid.
+
+	// Only send an ICMP error if the address is not a multicast/broadcast
+	// v4/v6 address and the source is not the unspecified address.
+	//
+	// See: point e) in https://tools.ietf.org/html/rfc4443#section-2.4
+	if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) || id.RemoteAddress == header.IPv6Any || id.RemoteAddress == header.IPv4Any {
+		return true
+	}
+
+	// As per RFC 1122 Section 3.2.2.1, a host SHOULD generate Destination
+	// Unreachable messages with code:
+	//
+	//     2 (Protocol Unreachable), when the designated transport protocol
+	//     is not supported; or
+	//
+	//     3 (Port Unreachable), when the designated transport protocol
+	//     (e.g., UDP) is unable to demultiplex the datagram but has no
+	//     protocol mechanism to inform the sender.
+	switch len(id.LocalAddress) {
+	case header.IPv4AddressSize:
+		if !r.Stack().AllowICMPMessage() {
+			r.Stack().Stats().ICMP.V4PacketsSent.RateLimited.Increment()
+			return true
+		}
+		// As per RFC 1812 Section 4.3.2.3
+		//
+		//   ICMP datagram SHOULD contain as much of the original
+		//   datagram as possible without the length of the ICMP
+		//   datagram exceeding 576 bytes
+		//
+		// NOTE: The RFC referenced above differs from the original
+		// recommendation in RFC 1122, which required that at least 8
+		// bytes of the payload be included. Today Linux and other
+		// systems implement the RFC 1812 definition and not the
+		// original RFC 1122 requirement.
+		mtu := int(r.MTU())
+		if mtu > header.IPv4MinimumProcessableDatagramSize {
+			mtu = header.IPv4MinimumProcessableDatagramSize
+		}
+		headerLen := int(r.MaxHeaderLength()) + header.ICMPv4MinimumSize
+		available := int(mtu) - headerLen
+		payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size()
+		if payloadLen > available {
+			payloadLen = available
+		}
+
+		// The buffers used by pkt may be used elsewhere in the system.
+		// For example, a raw or packet socket may use what UDP
+		// considers an unreachable destination. Thus we deep copy pkt
+		// to prevent multiple ownership and SR errors.
+		newHeader := append(buffer.View(nil), pkt.NetworkHeader...)
+		newHeader = append(newHeader, pkt.TransportHeader...)
+		payload := newHeader.ToVectorisedView()
+		payload.AppendView(pkt.Data.ToView())
+		payload.CapLength(payloadLen)
+
+		hdr := buffer.NewPrependable(headerLen)
+		pkt := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+		pkt.SetType(header.ICMPv4DstUnreachable)
+		pkt.SetCode(header.ICMPv4PortUnreachable)
+		pkt.SetChecksum(header.ICMPv4Checksum(pkt, payload))
+		r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv4ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{
+			Header:          hdr,
+			TransportHeader: buffer.View(pkt),
+			Data:            payload,
+		})
+
+	case header.IPv6AddressSize:
+		if !r.Stack().AllowICMPMessage() {
+			r.Stack().Stats().ICMP.V6PacketsSent.RateLimited.Increment()
+			return true
+		}
+
+		// As per RFC 4443 section 2.4
+		//
+		//    (c) Every ICMPv6 error message (type < 128) MUST include
+		//    as much of the IPv6 offending (invoking) packet (the
+		//    packet that caused the error) as possible without making
+		//    the error message packet exceed the minimum IPv6 MTU
+		//    [IPv6].
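+		//
+		// That minimum is 1280 bytes (header.IPv6MinimumMTU), which
+		// is the cap enforced below.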
+ mtu := int(r.MTU()) + if mtu > header.IPv6MinimumMTU { + mtu = header.IPv6MinimumMTU + } + headerLen := int(r.MaxHeaderLength()) + header.ICMPv6DstUnreachableMinimumSize + available := int(mtu) - headerLen + payloadLen := len(pkt.NetworkHeader) + len(pkt.TransportHeader) + pkt.Data.Size() + if payloadLen > available { + payloadLen = available + } + payload := buffer.NewVectorisedView(len(pkt.NetworkHeader)+len(pkt.TransportHeader), []buffer.View{pkt.NetworkHeader, pkt.TransportHeader}) + payload.Append(pkt.Data) + payload.CapLength(payloadLen) + + hdr := buffer.NewPrependable(headerLen) + pkt := header.ICMPv6(hdr.Prepend(header.ICMPv6DstUnreachableMinimumSize)) + pkt.SetType(header.ICMPv6DstUnreachable) + pkt.SetCode(header.ICMPv6PortUnreachable) + pkt.SetChecksum(header.ICMPv6Checksum(pkt, r.LocalAddress, r.RemoteAddress, payload)) + r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, &stack.PacketBuffer{ + Header: hdr, + TransportHeader: buffer.View(pkt), + Data: payload, + }) + } + return true +} + +// SetOption implements stack.TransportProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements stack.TransportProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Close implements stack.TransportProtocol.Close. +func (*protocol) Close() {} + +// Wait implements stack.TransportProtocol.Wait. +func (*protocol) Wait() {} + +// Parse implements stack.TransportProtocol.Parse. +func (*protocol) Parse(pkt *stack.PacketBuffer) bool { + h, ok := pkt.Data.PullUp(header.UDPMinimumSize) + if !ok { + // Packet is too small + return false + } + pkt.TransportHeader = h + pkt.Data.TrimFront(header.UDPMinimumSize) + return true +} + +// NewProtocol returns a UDP transport protocol. +func NewProtocol() stack.TransportProtocol { + return &protocol{} +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go new file mode 100644 index 000000000..91ba031fa --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -0,0 +1,2072 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp_test + +import ( + "bytes" + "context" + "fmt" + "math/rand" + "testing" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// Addresses and ports used for testing. 
It is recommended that tests stick to
+// using these addresses as it allows using the testFlow helper.
+// Naming rules: 'stack*' denotes local addresses and ports, while 'test*'
+// represents the remote endpoint.
+const (
+	v4MappedAddrPrefix    = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
+	stackV6Addr           = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+	testV6Addr            = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+	stackV4MappedAddr     = v4MappedAddrPrefix + stackAddr
+	testV4MappedAddr      = v4MappedAddrPrefix + testAddr
+	multicastV4MappedAddr = v4MappedAddrPrefix + multicastAddr
+	broadcastV4MappedAddr = v4MappedAddrPrefix + broadcastAddr
+	v4MappedWildcardAddr  = v4MappedAddrPrefix + "\x00\x00\x00\x00"
+
+	stackAddr       = "\x0a\x00\x00\x01"
+	stackPort       = 1234
+	testAddr        = "\x0a\x00\x00\x02"
+	testPort        = 4096
+	multicastAddr   = "\xe8\x2b\xd3\xea"
+	multicastV6Addr = "\xff\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"
+	broadcastAddr   = header.IPv4Broadcast
+	testTOS         = 0x80
+
+	// defaultMTU is the MTU, in bytes, used throughout the tests, except
+	// where another value is explicitly used. It is chosen to match the MTU
+	// of loopback interfaces on Linux systems.
+	defaultMTU = 65536
+)
+
+// header4Tuple stores the 4-tuple {src-IP, src-port, dst-IP, dst-port} used in
+// a packet header. These values are used to populate a header or verify one.
+// Note that because they are used in packet headers, the addresses are never in
+// a V4-mapped format.
+type header4Tuple struct {
+	srcAddr tcpip.FullAddress
+	dstAddr tcpip.FullAddress
+}
+
+// testFlow implements a helper type used for sending and receiving test
+// packets. A given test flow value defines 1) the socket endpoint used for the
+// test and 2) the type of packet sent or received on the endpoint. E.g., a
+// multicastV6Only flow is a V6 multicast packet passing through a V6-only
+// endpoint. The type provides helper methods to characterize the flow (e.g.,
+// isV4) as well as return a proper header4Tuple for it.
+type testFlow int
+
+const (
+	unicastV4       testFlow = iota // V4 unicast on a V4 socket
+	unicastV4in6                    // V4-mapped unicast on a V6-dual socket
+	unicastV6                       // V6 unicast on a V6 socket
+	unicastV6Only                   // V6 unicast on a V6-only socket
+	multicastV4                     // V4 multicast on a V4 socket
+	multicastV4in6                  // V4-mapped multicast on a V6-dual socket
+	multicastV6                     // V6 multicast on a V6 socket
+	multicastV6Only                 // V6 multicast on a V6-only socket
+	broadcast                       // V4 broadcast on a V4 socket
+	broadcastIn6                    // V4-mapped broadcast on a V6-dual socket
+)
+
+func (flow testFlow) String() string {
+	switch flow {
+	case unicastV4:
+		return "unicastV4"
+	case unicastV6:
+		return "unicastV6"
+	case unicastV6Only:
+		return "unicastV6Only"
+	case unicastV4in6:
+		return "unicastV4in6"
+	case multicastV4:
+		return "multicastV4"
+	case multicastV6:
+		return "multicastV6"
+	case multicastV6Only:
+		return "multicastV6Only"
+	case multicastV4in6:
+		return "multicastV4in6"
+	case broadcast:
+		return "broadcast"
+	case broadcastIn6:
+		return "broadcastIn6"
+	default:
+		return "unknown"
+	}
+}
+
+// packetDirection explains if a flow is incoming (read) or outgoing (write).
+type packetDirection int
+
+const (
+	incoming packetDirection = iota
+	outgoing
+)
+
+// header4Tuple returns the header4Tuple for the given flow and direction. Note
+// that the tuple contains no mapped addresses as those only exist at the socket
+// level but not at the packet header level.
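+//
+// For example, multicastV4in6 with direction incoming yields
+// src = testAddr:testPort and dst = multicastAddr:stackPort: the socket layer
+// deals in V4-mapped addresses, but the packet on the wire carries plain V4
+// addresses.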
+func (flow testFlow) header4Tuple(d packetDirection) header4Tuple { + var h header4Tuple + if flow.isV4() { + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testAddr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackAddr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastAddr + } else if flow.isBroadcast() { + h.dstAddr.Addr = broadcastAddr + } + } else { // IPv6 + if d == outgoing { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + dstAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + } + } else { + h = header4Tuple{ + srcAddr: tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + dstAddr: tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}, + } + } + if flow.isMulticast() { + h.dstAddr.Addr = multicastV6Addr + } + } + return h +} + +func (flow testFlow) getMcastAddr() tcpip.Address { + if flow.isV4() { + return multicastAddr + } + return multicastV6Addr +} + +// mapAddrIfApplicable converts the given V4 address into its V4-mapped version +// if it is applicable to the flow. +func (flow testFlow) mapAddrIfApplicable(v4Addr tcpip.Address) tcpip.Address { + if flow.isMapped() { + return v4MappedAddrPrefix + v4Addr + } + return v4Addr +} + +// netProto returns the protocol number used for the network packet. +func (flow testFlow) netProto() tcpip.NetworkProtocolNumber { + if flow.isV4() { + return ipv4.ProtocolNumber + } + return ipv6.ProtocolNumber +} + +// sockProto returns the protocol number used when creating the socket +// endpoint for this flow. +func (flow testFlow) sockProto() tcpip.NetworkProtocolNumber { + switch flow { + case unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, multicastV6Only, broadcastIn6: + return ipv6.ProtocolNumber + case unicastV4, multicastV4, broadcast: + return ipv4.ProtocolNumber + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) checkerFn() func(*testing.T, []byte, ...checker.NetworkChecker) { + if flow.isV4() { + return checker.IPv4 + } + return checker.IPv6 +} + +func (flow testFlow) isV6() bool { return !flow.isV4() } +func (flow testFlow) isV4() bool { + return flow.sockProto() == ipv4.ProtocolNumber || flow.isMapped() +} + +func (flow testFlow) isV6Only() bool { + switch flow { + case unicastV6Only, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMulticast() bool { + switch flow { + case multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isBroadcast() bool { + switch flow { + case broadcast, broadcastIn6: + return true + case unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, multicastV6Only: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +func (flow testFlow) isMapped() bool { + switch flow { + case unicastV4in6, multicastV4in6, broadcastIn6: + return true + case unicastV4, unicastV6, unicastV6Only, 
multicastV4, multicastV6, multicastV6Only, broadcast: + return false + default: + panic(fmt.Sprintf("invalid testFlow given: %d", flow)) + } +} + +type testContext struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack + + ep tcpip.Endpoint + wq waiter.Queue +} + +func newDualTestContext(t *testing.T, mtu uint32) *testContext { + t.Helper() + return newDualTestContextWithOptions(t, mtu, stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + }) +} + +func newDualTestContextWithOptions(t *testing.T, mtu uint32, options stack.Options) *testContext { + t.Helper() + + s := stack.New(options) + ep := channel.New(256, mtu, "") + wep := stack.LinkEndpoint(ep) + + if testing.Verbose() { + wep = sniffer.New(ep) + } + if err := s.CreateNIC(1, wep); err != nil { + t.Fatalf("CreateNIC failed: %s", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress failed: %s", err) + } + + if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress failed: %s", err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: 1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + return &testContext{ + t: t, + s: s, + linkEP: ep, + } +} + +func (c *testContext) cleanup() { + if c.ep != nil { + c.ep.Close() + } +} + +func (c *testContext) createEndpoint(proto tcpip.NetworkProtocolNumber) { + c.t.Helper() + + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, proto, &c.wq) + if err != nil { + c.t.Fatal("NewEndpoint failed: ", err) + } +} + +func (c *testContext) createEndpointForFlow(flow testFlow) { + c.t.Helper() + + c.createEndpoint(flow.sockProto()) + if flow.isV6Only() { + if err := c.ep.SetSockOptBool(tcpip.V6OnlyOption, true); err != nil { + c.t.Fatalf("SetSockOptBool failed: %s", err) + } + } else if flow.isBroadcast() { + if err := c.ep.SetSockOptBool(tcpip.BroadcastOption, true); err != nil { + c.t.Fatalf("SetSockOptBool failed: %s", err) + } + } +} + +// getPacketAndVerify reads a packet from the link endpoint and verifies the +// header against expected values from the given test flow. In addition, it +// calls any extra checker functions provided. +func (c *testContext) getPacketAndVerify(flow testFlow, checkers ...checker.NetworkChecker) []byte { + c.t.Helper() + + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + p, ok := c.linkEP.ReadContext(ctx) + if !ok { + c.t.Fatalf("Packet wasn't written out") + return nil + } + + if p.Proto != flow.netProto() { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, flow.netProto()) + } + + hdr := p.Pkt.Header.View() + b := append(hdr[:len(hdr):len(hdr)], p.Pkt.Data.ToView()...) + + h := flow.header4Tuple(outgoing) + checkers = append( + checkers, + checker.SrcAddr(h.srcAddr.Addr), + checker.DstAddr(h.dstAddr.Addr), + checker.UDP(checker.DstPort(h.dstAddr.Port)), + ) + flow.checkerFn()(c.t, b, checkers...) + return b +} + +// injectPacket creates a packet of the given flow and with the given payload, +// and injects it into the link endpoint. 
+func (c *testContext) injectPacket(flow testFlow, payload []byte) { + c.t.Helper() + + h := flow.header4Tuple(incoming) + if flow.isV4() { + buf := c.buildV4Packet(payload, &h) + c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } else { + buf := c.buildV6Packet(payload, &h) + c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{ + Data: buf.ToVectorisedView(), + }) + } +} + +// buildV6Packet creates a V6 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV6Packet(payload []byte, h *header4Tuple) buffer.View { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + TrafficClass: testTOS, + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, + }) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + return buf +} + +// buildV4Packet creates a V4 test packet with the given payload and header +// values in a buffer. +func (c *testContext) buildV4Packet(payload []byte, h *header4Tuple) buffer.View { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) + payloadStart := len(buf) - len(payload) + copy(buf[payloadStart:], payload) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: testTOS, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: h.srcAddr.Addr, + DstAddr: h.dstAddr.Addr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcAddr.Port, + DstPort: h.dstAddr.Port, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(u))) + + // Calculate the UDP checksum and set it. 
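+	// As in buildV6Packet above: the payload is folded into the
+	// pseudo-header sum, and the one's-complement of the checksum taken
+	// over the UDP header is then stored.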
+	xsum = header.Checksum(payload, xsum)
+	u.SetChecksum(^u.CalculateChecksum(xsum))
+
+	return buf
+}
+
+func newPayload() []byte {
+	return newMinPayload(30)
+}
+
+func newMinPayload(minSize int) []byte {
+	b := make([]byte, minSize+rand.Intn(100))
+	for i := range b {
+		b[i] = byte(rand.Intn(256))
+	}
+	return b
+}
+
+func TestBindToDeviceOption(t *testing.T) {
+	s := stack.New(stack.Options{
+		NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol()},
+		TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}})
+
+	ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %s", err)
+	}
+	defer ep.Close()
+
+	opts := stack.NICOptions{Name: "my_device"}
+	if err := s.CreateNICWithOptions(321, loopback.New(), opts); err != nil {
+		t.Errorf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err)
+	}
+
+	// nicIDPtr is used instead of taking the address of NICID literals, which
+	// is a compiler error.
+	nicIDPtr := func(s tcpip.NICID) *tcpip.NICID {
+		return &s
+	}
+
+	testActions := []struct {
+		name                 string
+		setBindToDevice      *tcpip.NICID
+		setBindToDeviceError *tcpip.Error
+		getBindToDevice      tcpip.BindToDeviceOption
+	}{
+		{"GetDefaultValue", nil, nil, 0},
+		{"BindToNonExistent", nicIDPtr(999), tcpip.ErrUnknownDevice, 0},
+		{"BindToExistent", nicIDPtr(321), nil, 321},
+		{"UnbindToDevice", nicIDPtr(0), nil, 0},
+	}
+	for _, testAction := range testActions {
+		t.Run(testAction.name, func(t *testing.T) {
+			if testAction.setBindToDevice != nil {
+				bindToDevice := tcpip.BindToDeviceOption(*testAction.setBindToDevice)
+				if gotErr, wantErr := ep.SetSockOpt(bindToDevice), testAction.setBindToDeviceError; gotErr != wantErr {
+					t.Errorf("SetSockOpt(%v) got %v, want %v", bindToDevice, gotErr, wantErr)
+				}
+			}
+			bindToDevice := tcpip.BindToDeviceOption(88888)
+			if err := ep.GetSockOpt(&bindToDevice); err != nil {
+				t.Errorf("GetSockOpt got %v, want %v", err, nil)
+			}
+			if got, want := bindToDevice, testAction.getBindToDevice; got != want {
+				t.Errorf("bindToDevice got %d, want %d", got, want)
+			}
+		})
+	}
+}
+
+// testReadInternal sends a packet of the given test flow into the stack by
+// injecting it into the link endpoint. It then attempts to read it from the
+// UDP endpoint and, depending on whether the packet is expected to be
+// received, verifies its correctness, including any additional checker
+// functions provided.
+func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expectReadError bool, checkers ...checker.ControlMessagesChecker) {
+	c.t.Helper()
+
+	payload := newPayload()
+	c.injectPacket(flow, payload)
+
+	// Try to receive the data.
+	we, ch := waiter.NewChannelEntry(nil)
+	c.wq.EventRegister(&we, waiter.EventIn)
+	defer c.wq.EventUnregister(&we)
+
+	// Take a snapshot of the stats to validate them at the end of the test.
+	epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+
+	var addr tcpip.FullAddress
+	v, cm, err := c.ep.Read(&addr)
+	if err == tcpip.ErrWouldBlock {
+		// Wait for data to become available.
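+		// The receive path notifies waiter.EventIn once a packet is
+		// queued, which wakes the channel entry registered above.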
+		select {
+		case <-ch:
+			v, cm, err = c.ep.Read(&addr)
+
+		case <-time.After(300 * time.Millisecond):
+			if packetShouldBeDropped {
+				return // expected to time out
+			}
+			c.t.Fatal("timed out waiting for data")
+		}
+	}
+
+	if expectReadError && err != nil {
+		c.checkEndpointReadStats(1, epstats, err)
+		return
+	}
+
+	if err != nil {
+		c.t.Fatal("Read failed:", err)
+	}
+
+	if packetShouldBeDropped {
+		c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
+	}
+
+	// Check the peer address.
+	h := flow.header4Tuple(incoming)
+	if addr.Addr != h.srcAddr.Addr {
+		c.t.Fatalf("unexpected remote address: got %s, want %s", addr.Addr, h.srcAddr.Addr)
+	}
+
+	// Check the payload.
+	if !bytes.Equal(payload, v) {
+		c.t.Fatalf("bad payload: got %x, want %x", v, payload)
+	}
+
+	// Run any checkers against the ControlMessages.
+	for _, f := range checkers {
+		f(c.t, cm)
+	}
+
+	c.checkEndpointReadStats(1, epstats, err)
+}
+
+// testRead sends a packet of the given test flow into the stack by injecting
+// it into the link endpoint. It then reads it from the UDP endpoint and
+// verifies its correctness, including any additional checker functions
+// provided.
+func testRead(c *testContext, flow testFlow, checkers ...checker.ControlMessagesChecker) {
+	c.t.Helper()
+	testReadInternal(c, flow, false /* packetShouldBeDropped */, false /* expectReadError */, checkers...)
+}
+
+// testFailingRead sends a packet of the given test flow into the stack by
+// injecting it into the link endpoint. It then tries to read it from the UDP
+// endpoint and expects this to fail.
+func testFailingRead(c *testContext, flow testFlow, expectReadError bool) {
+	c.t.Helper()
+	testReadInternal(c, flow, true /* packetShouldBeDropped */, expectReadError)
+}
+
+func TestBindEphemeralPort(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	if err := c.ep.Bind(tcpip.FullAddress{}); err != nil {
+		t.Fatalf("ep.Bind(...) failed: %s", err)
+	}
+}
+
+func TestBindReservedPort(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %s", err)
+	}
+
+	addr, err := c.ep.GetLocalAddress()
+	if err != nil {
+		t.Fatalf("GetLocalAddress failed: %s", err)
+	}
+
+	// We can't bind the address reserved by the connected endpoint above.
+	{
+		ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+		if err != nil {
+			t.Fatalf("NewEndpoint failed: %s", err)
+		}
+		defer ep.Close()
+		if got, want := ep.Bind(addr), tcpip.ErrPortInUse; got != want {
+			t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+		}
+	}
+
+	func() {
+		ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
+		if err != nil {
+			t.Fatalf("NewEndpoint failed: %s", err)
+		}
+		defer ep.Close()
+		// We can't bind ipv4-any on the port reserved by the connected endpoint
+		// above, since the endpoint is dual-stack.
+		if got, want := ep.Bind(tcpip.FullAddress{Port: addr.Port}), tcpip.ErrPortInUse; got != want {
+			t.Fatalf("got ep.Bind(...) = %v, want = %v", got, want)
+		}
+		// We can bind an ipv4 address on this port, though.
+		if err := ep.Bind(tcpip.FullAddress{Addr: stackAddr, Port: addr.Port}); err != nil {
+			t.Fatalf("ep.Bind(...) failed: %s", err)
+		}
+	}()
+
+	// Once the connected endpoint releases its port reservation, we are able to
+	// bind ipv4-any once again.
+	c.ep.Close()
+	func() {
+		ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
+		if err != nil {
+			t.Fatalf("NewEndpoint failed: %s", err)
+		}
+		defer ep.Close()
+		if err := ep.Bind(tcpip.FullAddress{Port: addr.Port}); err != nil {
+			t.Fatalf("ep.Bind(...) failed: %s", err)
+		}
+	}()
+}
+
+func TestV4ReadOnV6(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpointForFlow(unicastV4in6)
+
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	// Test acceptance.
+	testRead(c, unicastV4in6)
+}
+
+func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpointForFlow(unicastV4in6)
+
+	// Bind to v4 mapped wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Addr: v4MappedWildcardAddr, Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	// Test acceptance.
+	testRead(c, unicastV4in6)
+}
+
+func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpointForFlow(unicastV4in6)
+
+	// Bind to local address.
+	if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	// Test acceptance.
+	testRead(c, unicastV4in6)
+}
+
+func TestV6ReadOnV6(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpointForFlow(unicastV6)
+
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	// Test acceptance.
+	testRead(c, unicastV6)
+}
+
+// TestV4ReadSelfSource checks that packets coming from a local IP address are
+// correctly dropped when handleLocal is true and not otherwise.
+func TestV4ReadSelfSource(t *testing.T) {
+	for _, tt := range []struct {
+		name              string
+		handleLocal       bool
+		wantErr           *tcpip.Error
+		wantInvalidSource uint64
+	}{
+		{"NoHandleLocal", false, nil, 0},
+		{"HandleLocal", true, tcpip.ErrWouldBlock, 1},
+	} {
+		t.Run(tt.name, func(t *testing.T) {
+			c := newDualTestContextWithOptions(t, defaultMTU, stack.Options{
+				NetworkProtocols:   []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()},
+				TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()},
+				HandleLocal:        tt.handleLocal,
+			})
+			defer c.cleanup()
+
+			c.createEndpointForFlow(unicastV4)
+
+			if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+				t.Fatalf("Bind failed: %s", err)
+			}
+
+			payload := newPayload()
+			h := unicastV4.header4Tuple(incoming)
+			h.srcAddr = h.dstAddr
+
+			buf := c.buildV4Packet(payload, &h)
+			c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+				Data: buf.ToVectorisedView(),
+			})
+
+			if got := c.s.Stats().IP.InvalidSourceAddressesReceived.Value(); got != tt.wantInvalidSource {
+				t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
+			}
+
+			if _, _, err := c.ep.Read(nil); err != tt.wantErr {
+				t.Errorf("c.ep.Read() got error %v, want %v", err, tt.wantErr)
+			}
+		})
+	}
+}
+
+func TestV4ReadOnV4(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpointForFlow(unicastV4)
+
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	// Test acceptance.
+	testRead(c, unicastV4)
+}
+
+// TestReadOnBoundToMulticast checks that an endpoint can bind to a multicast
+// address and receive data sent to that address.
+func TestReadOnBoundToMulticast(t *testing.T) {
+	// FIXME(b/128189410): multicastV4in6 currently doesn't work as
+	// AddMembershipOption doesn't handle V4in6 addresses.
+	for _, flow := range []testFlow{multicastV4, multicastV6, multicastV6Only} {
+		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to multicast address.
+			mcastAddr := flow.mapAddrIfApplicable(flow.getMcastAddr())
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: mcastAddr, Port: stackPort}); err != nil {
+				c.t.Fatal("Bind failed:", err)
+			}
+
+			// Join multicast group.
+			ifoptSet := tcpip.AddMembershipOption{NIC: 1, MulticastAddr: mcastAddr}
+			if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+				c.t.Fatal("SetSockOpt failed:", err)
+			}
+
+			// Check that we receive multicast packets but not unicast or broadcast
+			// ones.
+			testRead(c, flow)
+			testFailingRead(c, broadcast, false /* expectReadError */)
+			testFailingRead(c, unicastV4, false /* expectReadError */)
+		})
+	}
+}
+
+// TestV4ReadOnBoundToBroadcast checks that an endpoint can bind to a broadcast
+// address and can receive only broadcast data.
+func TestV4ReadOnBoundToBroadcast(t *testing.T) {
+	for _, flow := range []testFlow{broadcast, broadcastIn6} {
+		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to broadcast address.
+			bcastAddr := flow.mapAddrIfApplicable(broadcastAddr)
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: bcastAddr, Port: stackPort}); err != nil {
+				c.t.Fatalf("Bind failed: %s", err)
+			}
+
+			// Check that we receive broadcast packets but not unicast ones.
+			testRead(c, flow)
+			testFailingRead(c, unicastV4, false /* expectReadError */)
+		})
+	}
+}
+
+// TestV4ReadBroadcastOnBoundToWildcard checks that an endpoint can bind to ANY
+// and receive broadcast and unicast data.
+func TestV4ReadBroadcastOnBoundToWildcard(t *testing.T) {
+	for _, flow := range []testFlow{broadcast, broadcastIn6} {
+		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to wildcard.
+			if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+				c.t.Fatalf("Bind failed: %s", err)
+			}
+
+			// Check that we receive both broadcast and unicast packets.
+			testRead(c, flow)
+			testRead(c, unicastV4)
+		})
+	}
+}
+
+// testFailingWrite sends a packet of the given test flow into the UDP endpoint
+// and verifies that the write fails with the provided error code.
+func testFailingWrite(c *testContext, flow testFlow, wantErr *tcpip.Error) {
+	c.t.Helper()
+	// Take a snapshot of the stats to validate them at the end of the test.
+	epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+	h := flow.header4Tuple(outgoing)
+	writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+
+	payload := buffer.View(newPayload())
+	_, _, gotErr := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+		To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+	})
+	c.checkEndpointWriteStats(1, epstats, gotErr)
+	if gotErr != wantErr {
+		c.t.Fatalf("Write returned unexpected error: got %v, want %v", gotErr, wantErr)
+	}
+}
+
+// testWrite sends a packet of the given test flow from the UDP endpoint to the
+// flow's destination address:port. It then receives it from the link endpoint
+// and verifies its correctness, including any additional checker functions
+// provided.
+func testWrite(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+	c.t.Helper()
+	return testWriteInternal(c, flow, true, checkers...)
+}
+
+// testWriteWithoutDestination sends a packet of the given test flow from the
+// UDP endpoint without giving a destination address:port. It then receives it
+// from the link endpoint and verifies its correctness, including any
+// additional checker functions provided.
+func testWriteWithoutDestination(c *testContext, flow testFlow, checkers ...checker.NetworkChecker) uint16 {
+	c.t.Helper()
+	return testWriteInternal(c, flow, false, checkers...)
+}
+
+func testWriteInternal(c *testContext, flow testFlow, setDest bool, checkers ...checker.NetworkChecker) uint16 {
+	c.t.Helper()
+	// Take a snapshot of the stats to validate them at the end of the test.
+	epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+
+	writeOpts := tcpip.WriteOptions{}
+	if setDest {
+		h := flow.header4Tuple(outgoing)
+		writeDstAddr := flow.mapAddrIfApplicable(h.dstAddr.Addr)
+		writeOpts = tcpip.WriteOptions{
+			To: &tcpip.FullAddress{Addr: writeDstAddr, Port: h.dstAddr.Port},
+		}
+	}
+	payload := buffer.View(newPayload())
+	n, _, err := c.ep.Write(tcpip.SlicePayload(payload), writeOpts)
+	if err != nil {
+		c.t.Fatalf("Write failed: %s", err)
+	}
+	if n != int64(len(payload)) {
+		c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+	}
+	c.checkEndpointWriteStats(1, epstats, err)
+	// Receive the packet and check the payload.
+	b := c.getPacketAndVerify(flow, checkers...)
+	var udp header.UDP
+	if flow.isV4() {
+		udp = header.UDP(header.IPv4(b).Payload())
+	} else {
+		udp = header.UDP(header.IPv6(b).Payload())
+	}
+	if !bytes.Equal(payload, udp.Payload()) {
+		c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+	}
+
+	return udp.SourcePort()
+}
+
+func testDualWrite(c *testContext) uint16 {
+	c.t.Helper()
+
+	v4Port := testWrite(c, unicastV4in6)
+	v6Port := testWrite(c, unicastV6)
+	if v4Port != v6Port {
+		c.t.Fatalf("expected v4 and v6 ports to be equal: got v4Port = %d, v6Port = %d", v4Port, v6Port)
+	}
+
+	return v4Port
+}
+
+func TestDualWriteUnbound(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	testDualWrite(c)
+}
+
+func TestDualWriteBoundToWildcard(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	p := testDualWrite(c)
+	if p != stackPort {
+		c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
+	}
+}
+
+func TestDualWriteConnectedToV6(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	// Connect to v6 address.
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %s", err)
+	}
+
+	testWrite(c, unicastV6)
+
+	// Write to V4 mapped address.
+	testFailingWrite(c, unicastV4in6, tcpip.ErrNetworkUnreachable)
+	const want = 1
+	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).SendErrors.NoRoute.Value(); got != want {
+		c.t.Fatalf("Endpoint stat not updated: got %d, want %d", got, want)
+	}
+}
+
+func TestDualWriteConnectedToV4Mapped(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	// Connect to v4 mapped address.
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %s", err)
+	}
+
+	testWrite(c, unicastV4in6)
+
+	// Write to v6 address.
+	testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
+}
+
+func TestV4WriteOnV6Only(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpointForFlow(unicastV6Only)
+
+	// Write to V4 mapped address.
+	testFailingWrite(c, unicastV4in6, tcpip.ErrNoRoute)
+}
+
+func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	// Bind to v4 mapped address.
+	if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	// Write to v6 address.
+	testFailingWrite(c, unicastV6, tcpip.ErrInvalidEndpointState)
+}
+
+func TestV6WriteOnConnected(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	// Connect to v6 address.
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %s", err)
+	}
+
+	testWriteWithoutDestination(c, unicastV6)
+}
+
+func TestV4WriteOnConnected(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	// Connect to v4 mapped address.
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %s", err)
+	}
+
+	testWriteWithoutDestination(c, unicastV4)
+}
+
+// TestWriteOnBoundToV4Multicast checks that we can send packets out of a socket
+// that is bound to a V4 multicast address.
+func TestWriteOnBoundToV4Multicast(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to V4 mcast address.
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastAddr, Port: stackPort}); err != nil {
+				c.t.Fatal("Bind failed:", err)
+			}
+
+			testWrite(c, flow)
+		})
+	}
+}
+
+// TestWriteOnBoundToV4MappedMulticast checks that we can send packets out of a
+// socket that is bound to a V4-mapped multicast address.
+func TestWriteOnBoundToV4MappedMulticast(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to V4-mapped mcast address.
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV4MappedAddr, Port: stackPort}); err != nil {
+				c.t.Fatalf("Bind failed: %s", err)
+			}
+
+			testWrite(c, flow)
+		})
+	}
+}
+
+// TestWriteOnBoundToV6Multicast checks that we can send packets out of a
+// socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6Multicast(t *testing.T) {
+	for _, flow := range []testFlow{unicastV6, multicastV6} {
+		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to V6 mcast address.
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+				c.t.Fatalf("Bind failed: %s", err)
+			}
+
+			testWrite(c, flow)
+		})
+	}
+}
+
+// TestWriteOnBoundToV6OnlyMulticast checks that we can send packets out of a
+// V6-only socket that is bound to a V6 multicast address.
+func TestWriteOnBoundToV6OnlyMulticast(t *testing.T) {
+	for _, flow := range []testFlow{unicastV6Only, multicastV6Only} {
+		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to V6 mcast address.
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: multicastV6Addr, Port: stackPort}); err != nil {
+				c.t.Fatalf("Bind failed: %s", err)
+			}
+
+			testWrite(c, flow)
+		})
+	}
+}
+
+// TestWriteOnBoundToBroadcast checks that we can send packets out of a
+// socket that is bound to the broadcast address.
+func TestWriteOnBoundToBroadcast(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to V4 broadcast address.
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastAddr, Port: stackPort}); err != nil {
+				c.t.Fatal("Bind failed:", err)
+			}
+
+			testWrite(c, flow)
+		})
+	}
+}
+
+// TestWriteOnBoundToV4MappedBroadcast checks that we can send packets out of a
+// socket that is bound to the V4-mapped broadcast address.
+func TestWriteOnBoundToV4MappedBroadcast(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4in6, multicastV4in6, broadcastIn6} {
+		t.Run(fmt.Sprintf("%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			// Bind to V4-mapped broadcast address.
+			if err := c.ep.Bind(tcpip.FullAddress{Addr: broadcastV4MappedAddr, Port: stackPort}); err != nil {
+				c.t.Fatalf("Bind failed: %s", err)
+			}
+
+			testWrite(c, flow)
+		})
+	}
+}
+
+func TestReadIncrementsPacketsReceived(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	// Create a dual-stack UDP endpoint.
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	// Bind to wildcard.
+ if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil { + c.t.Fatalf("Bind failed: %s", err) + } + + testRead(c, unicastV4) + + var want uint64 = 1 + if got := c.s.Stats().UDP.PacketsReceived.Value(); got != want { + c.t.Fatalf("Read did not increment PacketsReceived: got %v, want %v", got, want) + } +} + +func TestWriteIncrementsPacketsSent(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpoint(ipv6.ProtocolNumber) + + testDualWrite(c) + + var want uint64 = 2 + if got := c.s.Stats().UDP.PacketsSent.Value(); got != want { + c.t.Fatalf("Write did not increment PacketsSent: got %v, want %v", got, want) + } +} + +func TestNoChecksum(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + // Disable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, true); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + // This option is effective on IPv4 only. + testWrite(c, flow, checker.UDP(checker.NoChecksum(flow.isV4()))) + + // Enable the checksum generation. + if err := c.ep.SetSockOptBool(tcpip.NoChecksumOption, false); err != nil { + t.Fatalf("SetSockOptBool failed: %s", err) + } + testWrite(c, flow, checker.UDP(checker.NoChecksum(false))) + }) + } +} + +func TestTTL(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + const multicastTTL = 42 + if err := c.ep.SetSockOptInt(tcpip.MulticastTTLOption, multicastTTL); err != nil { + c.t.Fatalf("SetSockOptInt failed: %s", err) + } + + var wantTTL uint8 + if flow.isMulticast() { + wantTTL = multicastTTL + } else { + var p stack.NetworkProtocol + if flow.isV4() { + p = ipv4.NewProtocol() + } else { + p = ipv6.NewProtocol() + } + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + })) + if err != nil { + t.Fatal(err) + } + wantTTL = ep.DefaultTTL() + ep.Close() + } + + testWrite(c, flow, checker.TTL(wantTTL)) + }) + } +} + +func TestSetTTL(t *testing.T) { + for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, broadcast, broadcastIn6} { + t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { + for _, wantTTL := range []uint8{1, 2, 50, 64, 128, 254, 255} { + t.Run(fmt.Sprintf("TTL:%d", wantTTL), func(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createEndpointForFlow(flow) + + if err := c.ep.SetSockOptInt(tcpip.TTLOption, int(wantTTL)); err != nil { + c.t.Fatalf("SetSockOptInt(TTLOption, %d) failed: %s", wantTTL, err) + } + + var p stack.NetworkProtocol + if flow.isV4() { + p = ipv4.NewProtocol() + } else { + p = ipv6.NewProtocol() + } + ep, err := p.NewEndpoint(0, tcpip.AddressWithPrefix{}, nil, nil, nil, stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}, + })) + if err != nil { + t.Fatal(err) + } + ep.Close() + + testWrite(c, 
flow, checker.TTL(wantTTL))
+				})
+			}
+		})
+	}
+}
+
+func TestSetTOS(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4, multicastV4, broadcast} {
+		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			const tos = testTOS
+			v, err := c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
+			if err != nil {
+				c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
+			}
+			// Test for expected default value.
+			if v != 0 {
+				c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, 0)
+			}
+
+			if err := c.ep.SetSockOptInt(tcpip.IPv4TOSOption, tos); err != nil {
+				c.t.Errorf("SetSockOptInt(IPv4TOSOption, 0x%x) failed: %s", tos, err)
+			}
+
+			v, err = c.ep.GetSockOptInt(tcpip.IPv4TOSOption)
+			if err != nil {
+				c.t.Errorf("GetSockOptInt(IPv4TOSOption) failed: %s", err)
+			}
+
+			if v != tos {
+				c.t.Errorf("got GetSockOptInt(IPv4TOSOption) = 0x%x, want = 0x%x", v, tos)
+			}
+
+			testWrite(c, flow, checker.TOS(tos, 0))
+		})
+	}
+}
+
+func TestSetTClass(t *testing.T) {
+	for _, flow := range []testFlow{unicastV4in6, unicastV6, unicastV6Only, multicastV4in6, multicastV6, broadcastIn6} {
+		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+			c := newDualTestContext(t, defaultMTU)
+			defer c.cleanup()
+
+			c.createEndpointForFlow(flow)
+
+			const tClass = testTOS
+			v, err := c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+			if err != nil {
+				c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
+			}
+			// Test for expected default value.
+			if v != 0 {
+				c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, 0)
+			}
+
+			if err := c.ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, tClass); err != nil {
+				c.t.Errorf("SetSockOptInt(IPv6TrafficClassOption, 0x%x) failed: %s", tClass, err)
+			}
+
+			v, err = c.ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+			if err != nil {
+				c.t.Errorf("GetSockOptInt(IPv6TrafficClassOption) failed: %s", err)
+			}
+
+			if v != tClass {
+				c.t.Errorf("got GetSockOptInt(IPv6TrafficClassOption) = 0x%x, want = 0x%x", v, tClass)
+			}
+
+			// The header getter for TClass is called TOS, so use that checker.
+			testWrite(c, flow, checker.TOS(tClass, 0))
+		})
+	}
+}
+
+func TestReceiveTosTClass(t *testing.T) {
+	testCases := []struct {
+		name             string
+		getReceiveOption tcpip.SockOptBool
+		tests            []testFlow
+	}{
+		{"ReceiveTosOption", tcpip.ReceiveTOSOption, []testFlow{unicastV4, broadcast}},
+		{"ReceiveTClassOption", tcpip.ReceiveTClassOption, []testFlow{unicastV4in6, unicastV6, unicastV6Only, broadcastIn6}},
+	}
+	for _, testCase := range testCases {
+		for _, flow := range testCase.tests {
+			t.Run(fmt.Sprintf("%s:flow:%s", testCase.name, flow), func(t *testing.T) {
+				c := newDualTestContext(t, defaultMTU)
+				defer c.cleanup()
+
+				c.createEndpointForFlow(flow)
+				option := testCase.getReceiveOption
+				name := testCase.name
+
+				// Verify that setting and reading the option works.
+				v, err := c.ep.GetSockOptBool(option)
+				if err != nil {
+					c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+				}
+				// Test for expected default value.
+				if v {
+					c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, v, false)
+				}
+
+				want := true
+				if err := c.ep.SetSockOptBool(option, want); err != nil {
+					c.t.Fatalf("SetSockOptBool(%s, %t) failed: %s", name, want, err)
+				}
+
+				got, err := c.ep.GetSockOptBool(option)
+				if err != nil {
+					c.t.Errorf("GetSockOptBool(%s) failed: %s", name, err)
+				}
+
+				if got != want {
+					c.t.Errorf("got GetSockOptBool(%s) = %t, want = %t", name, got, want)
+				}
+
+				// Verify that the correct received TOS or TClass is handed through as
+				// ancillary data to the ControlMessages struct.
+				if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+					c.t.Fatalf("Bind failed: %s", err)
+				}
+				switch option {
+				case tcpip.ReceiveTClassOption:
+					testRead(c, flow, checker.ReceiveTClass(testTOS))
+				case tcpip.ReceiveTOSOption:
+					testRead(c, flow, checker.ReceiveTOS(testTOS))
+				default:
+					t.Fatalf("unknown test variant: %s", name)
+				}
+			})
+		}
+	}
+}
+
+func TestMulticastInterfaceOption(t *testing.T) {
+	for _, flow := range []testFlow{multicastV4, multicastV4in6, multicastV6, multicastV6Only} {
+		t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) {
+			for _, bindTyp := range []string{"bound", "unbound"} {
+				t.Run(bindTyp, func(t *testing.T) {
+					for _, optTyp := range []string{"use local-addr", "use NICID", "use local-addr and NIC"} {
+						t.Run(optTyp, func(t *testing.T) {
+							h := flow.header4Tuple(outgoing)
+							mcastAddr := h.dstAddr.Addr
+							localIfAddr := h.srcAddr.Addr
+
+							var ifoptSet tcpip.MulticastInterfaceOption
+							switch optTyp {
+							case "use local-addr":
+								ifoptSet.InterfaceAddr = localIfAddr
+							case "use NICID":
+								ifoptSet.NIC = 1
+							case "use local-addr and NIC":
+								ifoptSet.InterfaceAddr = localIfAddr
+								ifoptSet.NIC = 1
+							default:
+								t.Fatal("unknown test variant")
+							}
+
+							c := newDualTestContext(t, defaultMTU)
+							defer c.cleanup()
+
+							c.createEndpoint(flow.sockProto())
+
+							if bindTyp == "bound" {
+								// Bind the socket by connecting to the multicast address.
+								// This may have an influence on how the multicast interface
+								// is set.
+								addr := tcpip.FullAddress{
+									Addr: flow.mapAddrIfApplicable(mcastAddr),
+									Port: stackPort,
+								}
+								if err := c.ep.Connect(addr); err != nil {
+									c.t.Fatalf("Connect failed: %s", err)
+								}
+							}
+
+							if err := c.ep.SetSockOpt(ifoptSet); err != nil {
+								c.t.Fatalf("SetSockOpt failed: %s", err)
+							}
+
+							// Verify multicast interface addr and NIC were set correctly.
+							// Note that NIC must be 1 since this is our outgoing interface.
+							ifoptWant := tcpip.MulticastInterfaceOption{NIC: 1, InterfaceAddr: ifoptSet.InterfaceAddr}
+							var ifoptGot tcpip.MulticastInterfaceOption
+							if err := c.ep.GetSockOpt(&ifoptGot); err != nil {
+								c.t.Fatalf("GetSockOpt failed: %s", err)
+							}
+							if ifoptGot != ifoptWant {
+								c.t.Errorf("got GetSockOpt() = %#v, want = %#v", ifoptGot, ifoptWant)
+							}
+						})
+					}
+				})
+			}
+		})
+	}
+}
+
+// TestV4UnknownDestination verifies that we generate an ICMPv4 Destination
+// Unreachable message when a UDP datagram is received on ports for which there
+// is no bound UDP socket.
+func TestV4UnknownDestination(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	testCases := []struct {
+		flow         testFlow
+		icmpRequired bool
+		// largePayload, if true, results in a payload large enough that the
+		// generated IPv4 packet is larger than
+		// header.IPv4MinimumProcessableDatagramSize.
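+		// (576 bytes, the minimum datagram size every IPv4 host must be
+		// prepared to accept per RFC 791; ICMP errors are truncated to
+		// fit within it.)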
+		largePayload bool
+	}{
+		{unicastV4, true, false},
+		{unicastV4, true, true},
+		{multicastV4, false, false},
+		{multicastV4, false, true},
+		{broadcast, false, false},
+		{broadcast, false, true},
+	}
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+			payload := newPayload()
+			if tc.largePayload {
+				payload = newMinPayload(576)
+			}
+			c.injectPacket(tc.flow, payload)
+			if !tc.icmpRequired {
+				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+				defer cancel()
+				if p, ok := c.linkEP.ReadContext(ctx); ok {
+					t.Fatalf("unexpected packet received: %+v", p)
+				}
+				return
+			}
+
+			// ICMP required.
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+			defer cancel()
+			p, ok := c.linkEP.ReadContext(ctx)
+			if !ok {
+				t.Fatalf("packet wasn't written out")
+				return
+			}
+
+			var pkt []byte
+			pkt = append(pkt, p.Pkt.Header.View()...)
+			pkt = append(pkt, p.Pkt.Data.ToView()...)
+			if got, want := len(pkt), header.IPv4MinimumProcessableDatagramSize; got > want {
+				t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+			}
+
+			hdr := header.IPv4(pkt)
+			checker.IPv4(t, hdr, checker.ICMPv4(
+				checker.ICMPv4Type(header.ICMPv4DstUnreachable),
+				checker.ICMPv4Code(header.ICMPv4PortUnreachable)))
+
+			icmpPkt := header.ICMPv4(hdr.Payload())
+			payloadIPHeader := header.IPv4(icmpPkt.Payload())
+			wantLen := len(payload)
+			if tc.largePayload {
+				wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MinimumSize*2 - header.ICMPv4MinimumSize - header.UDPMinimumSize
+			}
+
+			// In case of large payloads the IP packet may be truncated. Update
+			// the length field before retrieving the UDP datagram payload.
+			payloadIPHeader.SetTotalLength(uint16(wantLen + header.UDPMinimumSize + header.IPv4MinimumSize))
+
+			origDgram := header.UDP(payloadIPHeader.Payload())
+			if got, want := len(origDgram.Payload()), wantLen; got != want {
+				t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+			}
+			if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+				t.Fatalf("unexpected payload got: %x, want: %x", got, want)
+			}
+		})
+	}
+}
+
+// TestV6UnknownDestination verifies that we generate an ICMPv6 Destination
+// Unreachable message when a UDP datagram is received on ports for which there
+// is no bound UDP socket.
+func TestV6UnknownDestination(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	testCases := []struct {
+		flow         testFlow
+		icmpRequired bool
+		// largePayload, if true, results in a payload large enough to
+		// create an IPv6 packet > header.IPv6MinimumMTU bytes.
+		largePayload bool
+	}{
+		{unicastV6, true, false},
+		{unicastV6, true, true},
+		{multicastV6, false, false},
+		{multicastV6, false, true},
+	}
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("flow:%s icmpRequired:%t largePayload:%t", tc.flow, tc.icmpRequired, tc.largePayload), func(t *testing.T) {
+			payload := newPayload()
+			if tc.largePayload {
+				payload = newMinPayload(1280)
+			}
+			c.injectPacket(tc.flow, payload)
+			if !tc.icmpRequired {
+				ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+				defer cancel()
+				if p, ok := c.linkEP.ReadContext(ctx); ok {
+					t.Fatalf("unexpected packet received: %+v", p)
+				}
+				return
+			}
+
+			// ICMP required.
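+			// Per RFC 4443, an ICMPv6 error message must not exceed the
+			// minimum IPv6 MTU; the size check below relies on this.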
+			ctx, cancel := context.WithTimeout(context.Background(), time.Second)
+			defer cancel()
+			p, ok := c.linkEP.ReadContext(ctx)
+			if !ok {
+				t.Fatalf("packet wasn't written out")
+				return
+			}
+
+			var pkt []byte
+			pkt = append(pkt, p.Pkt.Header.View()...)
+			pkt = append(pkt, p.Pkt.Data.ToView()...)
+			if got, want := len(pkt), header.IPv6MinimumMTU; got > want {
+				t.Fatalf("got an ICMP packet of size: %d, want: sz <= %d", got, want)
+			}
+
+			hdr := header.IPv6(pkt)
+			checker.IPv6(t, hdr, checker.ICMPv6(
+				checker.ICMPv6Type(header.ICMPv6DstUnreachable),
+				checker.ICMPv6Code(header.ICMPv6PortUnreachable)))
+
+			icmpPkt := header.ICMPv6(hdr.Payload())
+			payloadIPHeader := header.IPv6(icmpPkt.Payload())
+			wantLen := len(payload)
+			if tc.largePayload {
+				wantLen = header.IPv6MinimumMTU - header.IPv6MinimumSize*2 - header.ICMPv6MinimumSize - header.UDPMinimumSize
+			}
+			// In case of large payloads the IP packet may be truncated. Update
+			// the length field before retrieving the UDP datagram payload.
+			payloadIPHeader.SetPayloadLength(uint16(wantLen + header.UDPMinimumSize))
+
+			origDgram := header.UDP(payloadIPHeader.Payload())
+			if got, want := len(origDgram.Payload()), wantLen; got != want {
+				t.Fatalf("unexpected payload length got: %d, want: %d", got, want)
+			}
+			if got, want := origDgram.Payload(), payload[:wantLen]; !bytes.Equal(got, want) {
+				t.Fatalf("unexpected payload got: %v, want: %v", got, want)
+			}
+		})
+	}
+}
+
+// TestIncrementMalformedPacketsReceived verifies that the global and endpoint
+// malformed-packets-received stats are incremented when a corrupt UDP packet
+// is received.
+func TestIncrementMalformedPacketsReceived(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	payload := newPayload()
+	h := unicastV6.header4Tuple(incoming)
+	buf := c.buildV6Packet(payload, &h)
+
+	// Invalidate the UDP header length field.
+	u := header.UDP(buf[header.IPv6MinimumSize:])
+	u.SetLength(u.Length() + 1)
+
+	c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
+
+	const want = 1
+	if got := c.s.Stats().UDP.MalformedPacketsReceived.Value(); got != want {
+		t.Errorf("got stats.UDP.MalformedPacketsReceived.Value() = %d, want = %d", got, want)
+	}
+	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.MalformedPacketsReceived.Value(); got != want {
+		t.Errorf("got EP Stats.ReceiveErrors.MalformedPacketsReceived stats = %d, want = %d", got, want)
+	}
+}
+
+// TestShortHeader verifies that when a packet with a too-short UDP header is
+// received, the malformed received global stat gets incremented.
+func TestShortHeader(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	h := unicastV6.header4Tuple(incoming)
+
+	// Allocate a buffer for an IPv6 and too-short UDP header.
+	const udpSize = header.UDPMinimumSize - 1
+	buf := buffer.NewView(header.IPv6MinimumSize + udpSize)
+	// Initialize the IP header.
+	ip := header.IPv6(buf)
+	ip.Encode(&header.IPv6Fields{
+		TrafficClass:  testTOS,
+		PayloadLength: uint16(udpSize),
+		NextHeader:    uint8(udp.ProtocolNumber),
+		HopLimit:      65,
+		SrcAddr:       h.srcAddr.Addr,
+		DstAddr:       h.dstAddr.Addr,
+	})
+
+	// Initialize the UDP header.
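+	// Encode a full-sized UDP header into a scratch view first; only the
+	// part that fits in buf (one byte short) is copied in below.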
+	udpHdr := header.UDP(buffer.NewView(header.UDPMinimumSize))
+	udpHdr.Encode(&header.UDPFields{
+		SrcPort: h.srcAddr.Port,
+		DstPort: h.dstAddr.Port,
+		Length:  header.UDPMinimumSize,
+	})
+	// Calculate the UDP pseudo-header checksum.
+	xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, h.srcAddr.Addr, h.dstAddr.Addr, uint16(len(udpHdr)))
+	udpHdr.SetChecksum(^udpHdr.CalculateChecksum(xsum))
+	// Copy all but the last byte of the UDP header into the packet.
+	copy(buf[header.IPv6MinimumSize:], udpHdr)
+
+	// Inject packet.
+	c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
+
+	if got, want := c.s.Stats().MalformedRcvdPackets.Value(), uint64(1); got != want {
+		t.Errorf("got c.s.Stats().MalformedRcvdPackets.Value() = %d, want = %d", got, want)
+	}
+}
+
+// TestIncrementChecksumErrorsV4 verifies that, when a checksum error is
+// detected, the global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV4(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv4.ProtocolNumber)
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	payload := newPayload()
+	h := unicastV4.header4Tuple(incoming)
+	buf := c.buildV4Packet(payload, &h)
+
+	// Invalidate the UDP header checksum field, taking care to avoid
+	// overflow to zero, which would disable checksum validation.
+	for u := header.UDP(buf[header.IPv4MinimumSize:]); ; {
+		u.SetChecksum(u.Checksum() + 1)
+		if u.Checksum() != 0 {
+			break
+		}
+	}
+
+	c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
+
+	const want = 1
+	if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+	}
+	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+	}
+}
+
+// TestIncrementChecksumErrorsV6 verifies that, when a checksum error is
+// detected, the global and endpoint stats are incremented.
+func TestIncrementChecksumErrorsV6(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	payload := newPayload()
+	h := unicastV6.header4Tuple(incoming)
+	buf := c.buildV6Packet(payload, &h)
+
+	// Invalidate the UDP header checksum field.
+	u := header.UDP(buf[header.IPv6MinimumSize:])
+	u.SetChecksum(u.Checksum() + 1)
+
+	c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
+
+	const want = 1
+	if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+	}
+	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+	}
+}
+
+// TestPayloadModifiedV4 verifies that a modified payload produces a checksum
+// error and that the global and endpoint stats are incremented.
+func TestPayloadModifiedV4(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv4.ProtocolNumber)
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	payload := newPayload()
+	h := unicastV4.header4Tuple(incoming)
+	buf := c.buildV4Packet(payload, &h)
+	// Modify the payload so that the checksum value in the UDP header will be incorrect.
+	buf[len(buf)-1]++
+	c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
+
+	const want = 1
+	if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+	}
+	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+	}
+}
+
+// TestPayloadModifiedV6 verifies that a modified payload produces a checksum
+// error and that the global and endpoint stats are incremented.
+func TestPayloadModifiedV6(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	payload := newPayload()
+	h := unicastV6.header4Tuple(incoming)
+	buf := c.buildV6Packet(payload, &h)
+	// Modify the payload so that the checksum value in the UDP header will be incorrect.
+	buf[len(buf)-1]++
+	c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
+
+	const want = 1
+	if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+	}
+	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+	}
+}
+
+// TestChecksumZeroV4 verifies that when the checksum value is zero, the global
+// and endpoint stats are *not* incremented (the UDP checksum is optional on
+// IPv4).
+func TestChecksumZeroV4(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv4.ProtocolNumber)
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	payload := newPayload()
+	h := unicastV4.header4Tuple(incoming)
+	buf := c.buildV4Packet(payload, &h)
+	// Set the checksum field in the UDP header to zero.
+	u := header.UDP(buf[header.IPv4MinimumSize:])
+	u.SetChecksum(0)
+	c.linkEP.InjectInbound(ipv4.ProtocolNumber, &stack.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
+
+	const want = 0
+	if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+	}
+	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+	}
+}
+
+// TestChecksumZeroV6 verifies that when the checksum value is zero, the global
+// and endpoint stats are incremented (the UDP checksum is *not* optional on
+// IPv6).
+func TestChecksumZeroV6(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	payload := newPayload()
+	h := unicastV6.header4Tuple(incoming)
+	buf := c.buildV6Packet(payload, &h)
+	// Set the checksum field in the UDP header to zero.
+	u := header.UDP(buf[header.IPv6MinimumSize:])
+	u.SetChecksum(0)
+	c.linkEP.InjectInbound(ipv6.ProtocolNumber, &stack.PacketBuffer{
+		Data: buf.ToVectorisedView(),
+	})
+
+	const want = 1
+	if got := c.s.Stats().UDP.ChecksumErrors.Value(); got != want {
+		t.Errorf("got stats.UDP.ChecksumErrors.Value() = %d, want = %d", got, want)
+	}
+	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ChecksumErrors.Value(); got != want {
+		t.Errorf("got EP Stats.ReceiveErrors.ChecksumErrors stats = %d, want = %d", got, want)
+	}
+}
+
+// TestShutdownRead verifies that reads fail after a read shutdown and that the
+// closed-receiver error stats are incremented when a packet is subsequently
+// received.
+func TestShutdownRead(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %s", err)
+	}
+
+	if err := c.ep.Shutdown(tcpip.ShutdownRead); err != nil {
+		t.Fatalf("Shutdown failed: %s", err)
+	}
+
+	testFailingRead(c, unicastV6, true /* expectReadError */)
+
+	var want uint64 = 1
+	if got := c.s.Stats().UDP.ReceiveBufferErrors.Value(); got != want {
+		t.Errorf("got stats.UDP.ReceiveBufferErrors.Value() = %v, want = %v", got, want)
+	}
+	if got := c.ep.Stats().(*tcpip.TransportEndpointStats).ReceiveErrors.ClosedReceiver.Value(); got != want {
+		t.Errorf("got EP Stats.ReceiveErrors.ClosedReceiver stats = %v, want = %v", got, want)
+	}
+}
+
+// TestShutdownWrite verifies that writes fail after a write shutdown and that
+// the write-error stats are incremented on a subsequent packet write.
+func TestShutdownWrite(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %s", err)
+	}
+
+	if err := c.ep.Shutdown(tcpip.ShutdownWrite); err != nil {
+		t.Fatalf("Shutdown failed: %s", err)
+	}
+
+	testFailingWrite(c, unicastV6, tcpip.ErrClosedForSend)
+}
+
+func (c *testContext) checkEndpointWriteStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+	got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+	switch err {
+	case nil:
+		want.PacketsSent.IncrementBy(incr)
+	case tcpip.ErrMessageTooLong, tcpip.ErrInvalidOptionValue:
+		want.WriteErrors.InvalidArgs.IncrementBy(incr)
+	case tcpip.ErrClosedForSend:
+		want.WriteErrors.WriteClosed.IncrementBy(incr)
+	case tcpip.ErrInvalidEndpointState:
+		want.WriteErrors.InvalidEndpointState.IncrementBy(incr)
+	case tcpip.ErrNoLinkAddress:
+		want.SendErrors.NoLinkAddr.IncrementBy(incr)
+	case tcpip.ErrNoRoute, tcpip.ErrBroadcastDisabled, tcpip.ErrNetworkUnreachable:
+		want.SendErrors.NoRoute.IncrementBy(incr)
+	default:
+		want.SendErrors.SendToNetworkFailed.IncrementBy(incr)
+	}
+	if got != want {
+		c.t.Errorf("Endpoint stats not matching for error %s: got %+v, want %+v", err, got, want)
+	}
+}
+
+func (c *testContext) checkEndpointReadStats(incr uint64, want tcpip.TransportEndpointStats, err *tcpip.Error) {
+	got := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
+	switch err {
+	case nil, tcpip.ErrWouldBlock:
+	case tcpip.ErrClosedForReceive:
+		want.ReadErrors.ReadClosed.IncrementBy(incr)
+	default:
+		c.t.Errorf("Endpoint error %v missing stats update", err)
+	}
+	if got != want {
+		c.t.Errorf("Endpoint stats not matching for error %s: got %+v, want %+v", err, got, want)
+	}
+}
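+
+// TestConnectedRoundTrip is an illustrative sketch (not part of the original
+// change) showing how the read and write helpers above compose: bind to a
+// known local port, connect to the test peer, send one datagram without a
+// destination, then receive one injected back from the connected peer. It
+// reuses only constants and helpers defined earlier in this file.
+func TestConnectedRoundTrip(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createEndpoint(ipv6.ProtocolNumber)
+
+	// Bind to a known local port so the injected reply matches the
+	// connected endpoint's 4-tuple.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}); err != nil {
+		c.t.Fatalf("Bind failed: %s", err)
+	}
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %s", err)
+	}
+
+	// Outgoing datagram to the connected peer, then an incoming one from it.
+	testWriteWithoutDestination(c, unicastV6)
+	testRead(c, unicastV6)
+}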