author    gVisor bot <gvisor-bot@google.com> 2019-06-02 06:44:55 +0000
committer gVisor bot <gvisor-bot@google.com> 2019-06-02 06:44:55 +0000
commit    ceb0d792f328d1fc0692197d8856a43c3936a571 (patch)
tree      83155f302eff44a78bcc30a3a08f4efe59a79379 /pkg/tcpip/transport
parent    deb7ecf1e46862d54f4b102f2d163cfbcfc37f3b (diff)
parent    216da0b733dbed9aad9b2ab92ac75bcb906fd7ee (diff)
Merge 216da0b7 (automated)
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint.go            710
-rw-r--r--  pkg/tcpip/transport/icmp/endpoint_state.go       90
-rwxr-xr-x  pkg/tcpip/transport/icmp/icmp_packet_list.go    173
-rwxr-xr-x  pkg/tcpip/transport/icmp/icmp_state_autogen.go   98
-rw-r--r--  pkg/tcpip/transport/icmp/protocol.go            136
-rw-r--r--  pkg/tcpip/transport/raw/endpoint.go             521
-rw-r--r--  pkg/tcpip/transport/raw/endpoint_state.go        88
-rwxr-xr-x  pkg/tcpip/transport/raw/packet_list.go          173
-rwxr-xr-x  pkg/tcpip/transport/raw/raw_state_autogen.go     96
-rw-r--r--  pkg/tcpip/transport/tcp/accept.go               499
-rw-r--r--  pkg/tcpip/transport/tcp/connect.go             1066
-rw-r--r--  pkg/tcpip/transport/tcp/cubic.go                233
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint.go            1741
-rw-r--r--  pkg/tcpip/transport/tcp/endpoint_state.go       362
-rw-r--r--  pkg/tcpip/transport/tcp/forwarder.go            171
-rw-r--r--  pkg/tcpip/transport/tcp/protocol.go             250
-rw-r--r--  pkg/tcpip/transport/tcp/rcv.go                  221
-rw-r--r--  pkg/tcpip/transport/tcp/reno.go                 103
-rw-r--r--  pkg/tcpip/transport/tcp/sack.go                  99
-rw-r--r--  pkg/tcpip/transport/tcp/sack_scoreboard.go      306
-rw-r--r--  pkg/tcpip/transport/tcp/segment.go              186
-rw-r--r--  pkg/tcpip/transport/tcp/segment_heap.go          46
-rw-r--r--  pkg/tcpip/transport/tcp/segment_queue.go         79
-rw-r--r--  pkg/tcpip/transport/tcp/segment_state.go         82
-rw-r--r--  pkg/tcpip/transport/tcp/snd.go                 1180
-rw-r--r--  pkg/tcpip/transport/tcp/snd_state.go             50
-rwxr-xr-x  pkg/tcpip/transport/tcp/tcp_segment_list.go     173
-rwxr-xr-x  pkg/tcpip/transport/tcp/tcp_state_autogen.go    400
-rw-r--r--  pkg/tcpip/transport/tcp/timer.go                141
-rw-r--r--  pkg/tcpip/transport/udp/endpoint.go            1002
-rw-r--r--  pkg/tcpip/transport/udp/endpoint_state.go       112
-rw-r--r--  pkg/tcpip/transport/udp/forwarder.go             96
-rw-r--r--  pkg/tcpip/transport/udp/protocol.go              90
-rwxr-xr-x  pkg/tcpip/transport/udp/udp_packet_list.go      173
-rwxr-xr-x  pkg/tcpip/transport/udp/udp_state_autogen.go    128
35 files changed, 11074 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
new file mode 100644
index 000000000..e2b90ef10
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -0,0 +1,710 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp
+
+import (
+ "encoding/binary"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type icmpPacket struct {
+ icmpPacketEntry
+ senderAddress tcpip.FullAddress
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
+ // views is pre-allocated space to back data. As long as the packet is
+ // made up of fewer than 8 buffer.Views, no extra allocation is
+ // necessary to store packet data.
+ views [8]buffer.View `state:"nosave"`
+}
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateConnected
+ stateClosed
+)
+
+// endpoint represents an ICMP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal
+// to have concurrent goroutines make calls into the endpoint, as they are
+// properly synchronized.
+//
+// +stateify savable
+type endpoint struct {
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue, and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList icmpPacketList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+ id stack.TransportEndpointID
+ state endpointState
+ // bindNICID and bindAddr are set via calls to Bind(). They are used to
+ // reject attempts to send data or connect via a different NIC or
+ // address.
+ bindNICID tcpip.NICID
+ bindAddr tcpip.Address
+ // regNICID is the default NIC to be used when callers don't specify a
+ // NIC.
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return &endpoint{
+ stack: stack,
+ netProto: netProto,
+ transProto: transProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }, nil
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
+ switch e.state {
+ case stateBound, stateConnected:
+ e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e)
+ }
+
+ // Close the receive list and drain it.
+ e.rcvMu.Lock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
+ e.rcvMu.Unlock()
+
+ e.route.Release()
+
+ // Update the state.
+ e.state = stateClosed
+
+ e.mu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.rcvMu.Lock()
+
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = p.senderAddress
+ }
+
+ return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+ switch e.state {
+ case stateInitial:
+ case stateConnected:
+ return false, nil
+
+ case stateBound:
+ if to == nil {
+ return false, tcpip.ErrDestinationRequired
+ }
+ return false, nil
+ default:
+ return false, tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // The state changed when we released the shared lock and re-acquired
+ // it in exclusive mode. Try again.
+ if e.state != stateInitial {
+ return true, nil
+ }
+
+ // The state is still 'initial', so try to bind the endpoint.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ to := opts.To
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // If we've shutdown with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Prepare for write.
+ for {
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !retry {
+ break
+ }
+ }
+
+ var route *stack.Route
+ if to == nil {
+ route = &e.route
+
+ if route.IsResolutionRequired() {
+ // Promote the lock to exclusive if using a shared route, given
+ // that it may need to change in the Route.Resolve() call below.
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != stateConnected {
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+ }
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicid := to.NIC
+ if e.bindNICID != 0 {
+ if nicid != 0 && nicid != e.bindNICID {
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ nicid = e.bindNICID
+ }
+
+ toCopy := *to
+ to = &toCopy
+ netProto, err := e.checkV4Mapped(to, true)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ // Find a route to the destination.
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer r.Release()
+
+ route = &r
+ }
+
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ v, err := p.Get(p.Size())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch e.netProto {
+ case header.IPv4ProtocolNumber:
+ err = e.send4(route, v)
+
+ case header.IPv6ProtocolNumber:
+ err = send6(route, e.id.LocalPort, v)
+ }
+
+ if err != nil {
+ return 0, nil, err
+ }
+
+ return uintptr(len(v)), nil, nil
+}
+
+// Peek only returns data from a single datagram, so do nothing here.
+func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ e.rcvMu.Lock()
+ if e.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := e.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error {
+ if len(data) < header.ICMPv4EchoMinimumSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Set the ident to the user-specified port. Sequence number should
+ // already be set by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort)
+
+ hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize))
+ copy(icmpv4, data)
+ data = data[header.ICMPv4EchoMinimumSize:]
+
+ // Linux performs these basic checks.
+ if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ icmpv4.SetChecksum(0)
+ icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+
+ return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL())
+}
+
+func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error {
+ if len(data) < header.ICMPv6EchoMinimumSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Set the ident. Sequence number is provided by the user.
+ binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident)
+
+ hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize))
+ copy(icmpv6, data)
+ data = data[header.ICMPv6EchoMinimumSize:]
+
+ if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ icmpv6.SetChecksum(0)
+ icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0)))
+
+ return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL())
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicid := addr.NIC
+ localPort := uint16(0)
+ switch e.state {
+ case stateBound, stateConnected:
+ localPort = e.id.LocalPort
+ if e.bindNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.bindNICID {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nicid = e.bindNICID
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ id := stack.TransportEndpointID{
+ LocalAddress: r.LocalAddress,
+ LocalPort: localPort,
+ RemoteAddress: r.RemoteAddress,
+ }
+
+ // Even if we're connected, this endpoint can still be used to send
+ // packets on a different network protocol, so we register both even if
+ // v6only is set to false and this is an ipv6 endpoint.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.id = id
+ e.route = r.Clone()
+ e.regNICID = nicid
+
+ e.state = stateConnected
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.shutdownFlags |= flags
+
+ if e.state != stateConnected {
+ return tcpip.ErrNotConnected
+ }
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.rcvMu.Lock()
+ wasClosed := e.rcvClosed
+ e.rcvClosed = true
+ e.rcvMu.Unlock()
+
+ if !wasClosed {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen is not supported by ICMP endpoints; it just fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by ICMP endpoints; it just fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if id.LocalPort != 0 {
+ // The endpoint already has a local port, just attempt to
+ // register it.
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ return id, err
+ }
+
+ // We need to find a port for the endpoint.
+ _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ id.LocalPort = p
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false)
+ switch err {
+ case nil:
+ return true, nil
+ case tcpip.ErrPortInUse:
+ return false, nil
+ default:
+ return false, err
+ }
+ })
+
+ return id, err
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+ // Don't allow binding once the endpoint is no longer in the initial
+ // state.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+
+ if len(addr.Addr) != 0 {
+ // A local address was specified, verify that it's valid.
+ if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err = e.registerWithStack(addr.NIC, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.id = id
+ e.regNICID = addr.NIC
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional.
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ err := e.bindLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ e.bindNICID = addr.NIC
+ e.bindAddr = addr.Addr
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ }, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// HandlePacket is called by the stack when new packets arrive at this
+// transport endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ // Only accept echo replies.
+ switch e.netProto {
+ case header.IPv4ProtocolNumber:
+ h := header.ICMPv4(vv.First())
+ if h.Type() != header.ICMPv4EchoReply {
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ case header.IPv6ProtocolNumber:
+ h := header.ICMPv6(vv.First())
+ if h.Type() != header.ICMPv6EchoReply {
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ }
+
+ e.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ pkt := &icmpPacket{
+ senderAddress: tcpip.FullAddress{
+ NIC: r.NICID(),
+ Addr: id.RemoteAddress,
+ },
+ }
+
+ pkt.data = vv.Clone(pkt.views[:])
+
+ e.rcvList.PushBack(pkt)
+ e.rcvBufSize += pkt.data.Size()
+
+ pkt.timestamp = e.stack.NowNanoseconds()
+
+ e.rcvMu.Unlock()
+
+ // Notify any waiters that there's data to be read now.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+}
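
A note on the send path above: Write hands send4/send6 a payload that must already be a complete echo request, since the endpoint only rewrites the ident (with its local port) and the checksum before transmitting. A minimal sketch of a conforming ICMPv4 payload builder, assuming the header package's ICMPv4 setters and placement inside package icmp (which already imports buffer and header); buildEchoRequest itself is hypothetical and not part of this diff:

    // buildEchoRequest returns a view that send4 will accept. Ident and
    // checksum are left zero because the endpoint overwrites both.
    func buildEchoRequest(payload []byte) buffer.View {
        v := buffer.NewView(header.ICMPv4EchoMinimumSize + len(payload))
        h := header.ICMPv4(v)
        h.SetType(header.ICMPv4Echo) // send4 rejects any other type
        h.SetCode(0)                 // send4 requires code 0
        copy(v[header.ICMPv4EchoMinimumSize:], payload)
        return v
    }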
diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go
new file mode 100644
index 000000000..332b3cd33
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/endpoint_state.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package icmp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves icmpPacket.data field.
+func (p *icmpPacket) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads icmpPacket.data field.
+func (p *icmpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and from mutating endpoint
+ // state). The lock will be released after saveRcvBufSizeMax(), which
+ // will have saved e.rcvBufSizeMax and set it to 0 to continue blocking
+ // incoming packets.
+ e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ e.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ e.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */)
+ if err != nil {
+ panic(*err)
+ }
+
+ e.id.LocalAddress = e.route.LocalAddress
+ } else if len(e.id.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id)
+ if err != nil {
+ panic(*err)
+ }
+}
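
The save hooks above use an asymmetric lock hand-off: beforeSave() acquires rcvMu, saveRcvBufSizeMax() releases it, and zeroing rcvBufSizeMax keeps HandlePacket dropping packets even after the unlock. A self-contained sketch of the same pattern using toy types (not gVisor APIs):

    package main

    import (
        "fmt"
        "sync"
    )

    // queue mimics the endpoint receive path: push drops items whenever
    // capMax is 0, which is how the endpoint stays quiescent after save.
    type queue struct {
        mu     sync.Mutex
        capMax int
        size   int
    }

    func (q *queue) push() bool {
        q.mu.Lock()
        defer q.mu.Unlock()
        if q.capMax == 0 || q.size >= q.capMax {
            return false // dropped, like DroppedPackets above
        }
        q.size++
        return true
    }

    // beforeSave locks; saveCapMax snapshots the value, zeroes it, and
    // only then unlocks, so push keeps failing after the hand-off.
    func (q *queue) beforeSave() { q.mu.Lock() }

    func (q *queue) saveCapMax() int {
        max := q.capMax
        q.capMax = 0
        q.mu.Unlock()
        return max
    }

    func main() {
        q := &queue{capMax: 2}
        fmt.Println(q.push())       // true
        q.beforeSave()
        fmt.Println(q.saveCapMax()) // 2
        fmt.Println(q.push())       // false: still quiescent
    }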
diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go
new file mode 100755
index 000000000..1b35e5b4a
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go
@@ -0,0 +1,173 @@
+package icmp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type icmpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type icmpPacketList struct {
+ head *icmpPacket
+ tail *icmpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *icmpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *icmpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *icmpPacketList) Front() *icmpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *icmpPacketList) Back() *icmpPacket {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *icmpPacketList) PushFront(e *icmpPacket) {
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(l.head)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *icmpPacketList) PushBack(e *icmpPacket) {
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(nil)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *icmpPacketList) PushBackList(m *icmpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) {
+ a := icmpPacketElementMapper{}.linkerFor(b).Next()
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ icmpPacketElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ icmpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) {
+ b := icmpPacketElementMapper{}.linkerFor(a).Prev()
+ icmpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ icmpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ icmpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ icmpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *icmpPacketList) Remove(e *icmpPacket) {
+ prev := icmpPacketElementMapper{}.linkerFor(e).Prev()
+ next := icmpPacketElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ icmpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type icmpPacketEntry struct {
+ next *icmpPacket
+ prev *icmpPacket
+}
+
+// Next returns the entry that follows e in the list.
+func (e *icmpPacketEntry) Next() *icmpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *icmpPacketEntry) Prev() *icmpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'elem' as the entry that follows e in the list.
+func (e *icmpPacketEntry) SetNext(elem *icmpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'elem' as the entry that precedes e in the list.
+func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) {
+ e.prev = elem
+}
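
Because icmpPacket embeds icmpPacketEntry, every packet carries its own next/prev links and can be queued with no per-node allocation. A minimal usage sketch; countQueued is a hypothetical function inside package icmp:

    // countQueued exercises the generated API: O(1) PushBack and an
    // iteration that reaches Next() through the embedded entry.
    func countQueued() int {
        var l icmpPacketList // the zero value is an empty, usable list
        l.PushBack(&icmpPacket{})
        l.PushBack(&icmpPacket{})

        n := 0
        for e := l.Front(); e != nil; e = e.Next() {
            n++
        }
        return n // 2
    }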
diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
new file mode 100755
index 000000000..b66857348
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go
@@ -0,0 +1,98 @@
+// automatically generated by stateify.
+
+package icmp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *icmpPacket) beforeSave() {}
+func (x *icmpPacket) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("icmpPacketEntry", &x.icmpPacketEntry)
+ m.Save("senderAddress", &x.senderAddress)
+ m.Save("timestamp", &x.timestamp)
+}
+
+func (x *icmpPacket) afterLoad() {}
+func (x *icmpPacket) load(m state.Map) {
+ m.Load("icmpPacketEntry", &x.icmpPacketEntry)
+ m.Load("senderAddress", &x.senderAddress)
+ m.Load("timestamp", &x.timestamp)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("netProto", &x.netProto)
+ m.Save("transProto", &x.transProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvReady", &x.rcvReady)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("id", &x.id)
+ m.Save("state", &x.state)
+ m.Save("bindNICID", &x.bindNICID)
+ m.Save("bindAddr", &x.bindAddr)
+ m.Save("regNICID", &x.regNICID)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.Load("transProto", &x.transProto)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("rcvReady", &x.rcvReady)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("id", &x.id)
+ m.Load("state", &x.state)
+ m.Load("bindNICID", &x.bindNICID)
+ m.Load("bindAddr", &x.bindAddr)
+ m.Load("regNICID", &x.regNICID)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *icmpPacketList) beforeSave() {}
+func (x *icmpPacketList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *icmpPacketList) afterLoad() {}
+func (x *icmpPacketList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *icmpPacketEntry) beforeSave() {}
+func (x *icmpPacketEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *icmpPacketEntry) afterLoad() {}
+func (x *icmpPacketEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("icmp.icmpPacket", (*icmpPacket)(nil), state.Fns{Save: (*icmpPacket).save, Load: (*icmpPacket).load})
+ state.Register("icmp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("icmp.icmpPacketList", (*icmpPacketList)(nil), state.Fns{Save: (*icmpPacketList).save, Load: (*icmpPacketList).load})
+ state.Register("icmp.icmpPacketEntry", (*icmpPacketEntry)(nil), state.Fns{Save: (*icmpPacketEntry).save, Load: (*icmpPacketEntry).load})
+}
diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go
new file mode 100644
index 000000000..954fde9d8
--- /dev/null
+++ b/pkg/tcpip/transport/icmp/protocol.go
@@ -0,0 +1,136 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport
+// protocols for use in ping. To use it in the networking stack, this package
+// must be added to the project, and activated on the stack by passing
+// icmp.ProtocolName4 (or "icmp4") and/or icmp.ProtocolName6 (or "icmp6") as
+// one of the transport protocols when calling stack.New(). Then endpoints can
+// be created by passing icmp.ProtocolNumber4 or icmp.ProtocolNumber6 as the
+// transport protocol number when calling Stack.NewEndpoint().
+package icmp
+
+import (
+ "encoding/binary"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolName4 is the string representation of the ICMPv4 protocol name.
+ ProtocolName4 = "icmp4"
+
+ // ProtocolNumber4 is the ICMP protocol number.
+ ProtocolNumber4 = header.ICMPv4ProtocolNumber
+
+ // ProtocolName6 is the string representation of the ICMPv6 protocol name.
+ ProtocolName6 = "icmp6"
+
+ // ProtocolNumber6 is the IPv6-ICMP protocol number.
+ ProtocolNumber6 = header.ICMPv6ProtocolNumber
+)
+
+// protocol implements stack.TransportProtocol.
+type protocol struct {
+ number tcpip.TransportProtocolNumber
+}
+
+// Number returns the ICMP protocol number.
+func (p *protocol) Number() tcpip.TransportProtocolNumber {
+ return p.number
+}
+
+func (p *protocol) netProto() tcpip.NetworkProtocolNumber {
+ switch p.number {
+ case ProtocolNumber4:
+ return header.IPv4ProtocolNumber
+ case ProtocolNumber6:
+ return header.IPv6ProtocolNumber
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// NewEndpoint creates a new icmp endpoint. It implements
+// stack.TransportProtocol.NewEndpoint.
+func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return newEndpoint(stack, netProto, p.number, waiterQueue)
+}
+
+// NewRawEndpoint creates a new raw icmp endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != p.netProto() {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+ return raw.NewEndpoint(stack, netProto, p.number, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid icmp packet size.
+func (p *protocol) MinimumPacketSize() int {
+ switch p.number {
+ case ProtocolNumber4:
+ return header.ICMPv4EchoMinimumSize
+ case ProtocolNumber6:
+ return header.ICMPv6EchoMinimumSize
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// ParsePorts returns the source and destination ports stored in the given icmp
+// packet.
+func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ switch p.number {
+ case ProtocolNumber4:
+ return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil
+ case ProtocolNumber6:
+ return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil
+ }
+ panic(fmt.Sprint("unknown protocol number: ", p.number))
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+ return true
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol {
+ return &protocol{ProtocolNumber4}
+ })
+
+ stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol {
+ return &protocol{ProtocolNumber6}
+ })
+}
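
Wiring this up follows the package comment: register the protocol by name when building the stack, then create endpoints by number. A minimal sketch, assuming the stack.New(network, transport, opts) signature and the ipv4 package from the same tree:

    package main

    import (
        "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
        "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
        "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/icmp"
        "gvisor.googlesource.com/gvisor/pkg/waiter"
    )

    func main() {
        // The init() above registered "icmp4" and "icmp6" as transport
        // protocol factories, so they can be requested by name here.
        s := stack.New([]string{ipv4.ProtocolName}, []string{icmp.ProtocolName4}, stack.Options{})

        // Endpoints are created by protocol number (a NIC and address
        // would still need to be added before the endpoint is usable).
        var wq waiter.Queue
        ep, err := s.NewEndpoint(icmp.ProtocolNumber4, ipv4.ProtocolNumber, &wq)
        if err != nil {
            panic(err)
        }
        defer ep.Close()
    }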
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
new file mode 100644
index 000000000..1daf5823f
--- /dev/null
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -0,0 +1,521 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package raw provides the implementation of raw sockets (see raw(7)). Raw
+// sockets allow applications to:
+//
+// * manually write and inspect transport layer headers and payloads
+// * receive all traffic of a given transport protocol (e.g. ICMP or UDP)
+// * optionally write and inspect network layer and link layer headers for
+// packets
+//
+// Raw sockets don't have any notion of ports, and incoming packets are
+// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will
+// receive every UDP packet received by netstack. bind(2) and connect(2) can be
+// used to filter incoming packets by source and destination.
+package raw
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type packet struct {
+ packetEntry
+ // data holds the actual packet data, including any headers and
+ // payload.
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ // views is pre-allocated space to back data. As long as the packet is
+ // made up of fewer than 8 buffer.Views, no extra allocation is
+ // necessary to store packet data.
+ views [8]buffer.View `state:"nosave"`
+ // timestampNS is the Unix time, in nanoseconds, at which the packet was received.
+ timestampNS int64
+ // senderAddr is the network address of the sender.
+ senderAddr tcpip.FullAddress
+}
+
+// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to
+// have goroutines make concurrent calls into the endpoint.
+//
+// Lock order:
+// endpoint.mu
+// endpoint.rcvMu
+//
+// +stateify savable
+type endpoint struct {
+ // The following fields are initialized at creation time and are
+ // immutable.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ transProto tcpip.TransportProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvList packetList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by mu.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ closed bool
+ connected bool
+ bound bool
+ // registeredNIC is the NIC to which the endpoint is explicitly
+ // registered. It is set when Connect or Bind is used to specify a NIC.
+ registeredNIC tcpip.NICID
+ // boundNIC and boundAddr are set on calls to Bind(). When callers
+ // attempt actions that would invalidate the binding data (e.g. sending
+ // data via a NIC other than boundNIC), the endpoint will return an
+ // error.
+ boundNIC tcpip.NICID
+ boundAddr tcpip.Address
+ // route is the route to a remote network endpoint. It is set via
+ // Connect(), and is valid only when connected is true.
+ route stack.Route `state:"manual"`
+}
+
+// NewEndpoint returns a raw endpoint for the given protocols.
+// TODO(b/129292371): IP_HDRINCL, IPPROTO_RAW, and AF_PACKET.
+func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ if netProto != header.IPv4ProtocolNumber {
+ return nil, tcpip.ErrUnknownProtocol
+ }
+
+ ep := &endpoint{
+ stack: stack,
+ netProto: netProto,
+ transProto: transProto,
+ waiterQueue: waiterQueue,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ return nil, err
+ }
+
+ return ep, nil
+}
+
+// Close implements tcpip.Endpoint.Close.
+func (ep *endpoint) Close() {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return
+ }
+
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ ep.rcvMu.Lock()
+ defer ep.rcvMu.Unlock()
+
+ // Clear the receive list.
+ ep.rcvClosed = true
+ ep.rcvBufSize = 0
+ for !ep.rcvList.Empty() {
+ ep.rcvList.Remove(ep.rcvList.Front())
+ }
+
+ if ep.connected {
+ ep.route.Release()
+ }
+
+ ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ ep.rcvMu.Lock()
+
+ // If there's no data to read, return that read would block or that the
+ // endpoint is closed.
+ if ep.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if ep.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ ep.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ packet := ep.rcvList.Front()
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = packet.senderAddr
+ }
+
+ return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
+}
+
+// Write implements tcpip.Endpoint.Write.
+func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op.
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ ep.mu.RLock()
+
+ if ep.closed {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Did the caller provide a destination? If not, use the connected
+ // destination.
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if !ep.connected {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrDestinationRequired
+ }
+
+ if ep.route.IsResolutionRequired() {
+ savedRoute := &ep.route
+ // Promote lock to exclusive if using a shared route,
+ // given that it may need to change in finishWrite.
+ ep.mu.RUnlock()
+ ep.mu.Lock()
+
+ // Make sure that the route didn't change during the
+ // time we didn't hold the lock.
+ if !ep.connected || savedRoute != &ep.route {
+ ep.mu.Unlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ n, ch, err := ep.finishWrite(payload, savedRoute)
+ ep.mu.Unlock()
+ return n, ch, err
+ }
+
+ n, ch, err := ep.finishWrite(payload, &ep.route)
+ ep.mu.RUnlock()
+ return n, ch, err
+ }
+
+ // The caller provided a destination. Reject destination address if it
+ // goes through a different NIC than the endpoint was bound to.
+ nic := opts.To.NIC
+ if ep.bound && nic != 0 && nic != ep.boundNIC {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ // We don't support IPv6 yet, so this has to be an IPv4 address.
+ if len(opts.To.Addr) != header.IPv4AddressSize {
+ ep.mu.RUnlock()
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Find the route to the destination. If boundAddr is empty,
+ // FindRoute will choose an appropriate source address.
+ route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false)
+ if err != nil {
+ ep.mu.RUnlock()
+ return 0, nil, err
+ }
+
+ n, ch, err := ep.finishWrite(payload, &route)
+ route.Release()
+ ep.mu.RUnlock()
+ return n, ch, err
+}
+
+// finishWrite writes the payload to a route. It resolves the route if
+// necessary. It's really just a helper to make defer unnecessary in Write.
+func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // We may need to resolve the route (match a link layer address to the
+ // network address). If that requires blocking (e.g. to use ARP),
+ // return a channel on which the caller can wait.
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ payloadBytes, err := payload.Get(payload.Size())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ switch ep.netProto {
+ case header.IPv4ProtocolNumber:
+ hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength()))
+ if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil {
+ return 0, nil, err
+ }
+
+ default:
+ return 0, nil, tcpip.ErrUnknownProtocol
+ }
+
+ return uintptr(len(payloadBytes)), nil, nil
+}
+
+// Peek implements tcpip.Endpoint.Peek.
+func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// Connect implements tcpip.Endpoint.Connect.
+func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if ep.closed {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // We don't support IPv6 yet.
+ if len(addr.Addr) != header.IPv4AddressSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nic := addr.NIC
+ if ep.bound {
+ if ep.boundNIC == 0 {
+ // If we're bound, but not to a specific NIC, the NIC
+ // in addr will be used. Nothing to do here.
+ } else if addr.NIC == 0 {
+ // If we're bound to a specific NIC, but addr doesn't
+ // specify a NIC, use the bound NIC.
+ nic = ep.boundNIC
+ } else if addr.NIC != ep.boundNIC {
+ // We're bound and addr specifies a NIC. They must be
+ // the same.
+ return tcpip.ErrInvalidEndpointState
+ }
+ }
+
+ // Find a route to the destination.
+ route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false)
+ if err != nil {
+ return err
+ }
+ defer route.Release()
+
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ // Save the route and NIC we've connected via.
+ ep.route = route.Clone()
+ ep.registeredNIC = nic
+ ep.connected = true
+
+ return nil
+}
+
+// Shutdown implements tcpip.Endpoint.Shutdown. It's a no-op for raw sockets.
+func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ if !ep.connected {
+ return tcpip.ErrNotConnected
+ }
+ return nil
+}
+
+// Listen implements tcpip.Endpoint.Listen.
+func (ep *endpoint) Listen(backlog int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept implements tcpip.Endpoint.Accept.
+func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+// Bind implements tcpip.Endpoint.Bind.
+func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ ep.mu.Lock()
+ defer ep.mu.Unlock()
+
+ // Callers must provide an IPv4 address or no network address (for
+ // binding to a NIC, but not an address).
+ if len(addr.Addr) != 0 && len(addr.Addr) != header.IPv4AddressSize {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // If a local address was specified, verify that it's valid.
+ if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ // Re-register the endpoint with the appropriate NIC.
+ if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil {
+ return err
+ }
+ ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep)
+
+ ep.registeredNIC = addr.NIC
+ ep.boundNIC = addr.NIC
+ ep.boundAddr = addr.Addr
+ ep.bound = true
+
+ return nil
+}
+
+// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress.
+func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, tcpip.ErrNotSupported
+}
+
+// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress.
+func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ // Even a connected socket doesn't return a remote address.
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Readiness implements tcpip.Endpoint.Readiness.
+func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine whether the endpoint is readable.
+ if (mask & waiter.EventIn) != 0 {
+ ep.rcvMu.Lock()
+ if !ep.rcvList.Empty() || ep.rcvClosed {
+ result |= waiter.EventIn
+ }
+ ep.rcvMu.Unlock()
+ }
+
+ return result
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ ep.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(ep.sndBufSize)
+ ep.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ ep.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax)
+ ep.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ ep.rcvMu.Lock()
+ if ep.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := ep.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ ep.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
+func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) {
+ ep.rcvMu.Lock()
+
+ // Drop the packet if our buffer is currently full.
+ if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax {
+ ep.stack.Stats().DroppedPackets.Increment()
+ ep.rcvMu.Unlock()
+ return
+ }
+
+ if ep.bound {
+ // If bound to a NIC, only accept data for that NIC.
+ if ep.boundNIC != 0 && ep.boundNIC != route.NICID() {
+ ep.rcvMu.Unlock()
+ return
+ }
+ // If bound to an address, only accept data for that address.
+ if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress {
+ ep.rcvMu.Unlock()
+ return
+ }
+ }
+
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if ep.connected && ep.route.RemoteAddress != route.RemoteAddress {
+ ep.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := ep.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ packet := &packet{
+ senderAddr: tcpip.FullAddress{
+ NIC: route.NICID(),
+ Addr: route.RemoteAddress,
+ },
+ }
+
+ combinedVV := netHeader.ToVectorisedView()
+ combinedVV.Append(vv)
+ packet.data = combinedVV.Clone(packet.views[:])
+ packet.timestampNS = ep.stack.NowNanoseconds()
+
+ ep.rcvList.PushBack(packet)
+ ep.rcvBufSize += packet.data.Size()
+
+ ep.rcvMu.Unlock()
+
+ // Notify waiters that there's data to be read.
+ if wasEmpty {
+ ep.waiterQueue.Notify(waiter.EventIn)
+ }
+}
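
As the package comment says, this endpoint sees every packet of its transport protocol, and connect(2)-style filtering narrows that by source. A minimal sketch, reusing the *stack.Stack from the earlier example; rawICMP is hypothetical, and "\x0a\x00\x00\x02" spells 10.0.0.2:

    // rawICMP opens a raw ICMPv4 endpoint and filters it to one peer.
    // Assumes the imports from the earlier sketch plus raw and header.
    func rawICMP(s *stack.Stack) (tcpip.Endpoint, *tcpip.Error) {
        var wq waiter.Queue
        ep, err := raw.NewEndpoint(s, ipv4.ProtocolNumber, header.ICMPv4ProtocolNumber, &wq)
        if err != nil {
            return nil, err
        }

        // After Connect, HandlePacket above drops packets whose source
        // differs from the connected route's RemoteAddress.
        if err := ep.Connect(tcpip.FullAddress{Addr: "\x0a\x00\x00\x02"}); err != nil {
            ep.Close()
            return nil, err
        }

        // Read returns the IPv4 header followed by the ICMP message,
        // since HandlePacket prepends netHeader to the packet data.
        return ep, nil
    }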
diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go
new file mode 100644
index 000000000..e8907ebb1
--- /dev/null
+++ b/pkg/tcpip/transport/raw/endpoint_state.go
@@ -0,0 +1,88 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package raw
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves packet.data field.
+func (p *packet) saveData() buffer.VectorisedView {
+ // We cannot save p.data directly as p.data.views may alias to p.views,
+ // which is not allowed by state framework (in-struct pointer).
+ return p.data.Clone(nil)
+}
+
+// loadData loads packet.data field.
+func (p *packet) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+ // of utilizing p.views for data.views.
+ p.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (ep *endpoint) beforeSave() {
+ // Stop incoming packets from being handled (and from mutating endpoint
+ // state). The lock will be released after saveRcvBufSizeMax(), which
+ // will have saved ep.rcvBufSizeMax and set it to 0 to continue blocking
+ // incoming packets.
+ ep.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) saveRcvBufSizeMax() int {
+ max := ep.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ ep.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ ep.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (ep *endpoint) loadRcvBufSizeMax(max int) {
+ ep.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (ep *endpoint) afterLoad() {
+ // StackFromEnv is a stack used specifically for save/restore.
+ ep.stack = stack.StackFromEnv
+
+ // If the endpoint is connected, re-connect via the save/restore stack.
+ if ep.connected {
+ var err *tcpip.Error
+ ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false)
+ if err != nil {
+ panic(*err)
+ }
+ }
+
+ // If the endpoint is bound, re-bind via the save/restore stack.
+ if ep.bound {
+ if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/raw/packet_list.go b/pkg/tcpip/transport/raw/packet_list.go
new file mode 100755
index 000000000..2e9074934
--- /dev/null
+++ b/pkg/tcpip/transport/raw/packet_list.go
@@ -0,0 +1,173 @@
+package raw
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type packetElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (packetElementMapper) linkerFor(elem *packet) *packet { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type packetList struct {
+ head *packet
+ tail *packet
+}
+
+// Reset resets list l to the empty state.
+func (l *packetList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *packetList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *packetList) Front() *packet {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *packetList) Back() *packet {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *packetList) PushFront(e *packet) {
+ packetElementMapper{}.linkerFor(e).SetNext(l.head)
+ packetElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ packetElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *packetList) PushBack(e *packet) {
+ packetElementMapper{}.linkerFor(e).SetNext(nil)
+ packetElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *packetList) PushBackList(m *packetList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ packetElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *packetList) InsertAfter(b, e *packet) {
+ a := packetElementMapper{}.linkerFor(b).Next()
+ packetElementMapper{}.linkerFor(e).SetNext(a)
+ packetElementMapper{}.linkerFor(e).SetPrev(b)
+ packetElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ packetElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *packetList) InsertBefore(a, e *packet) {
+ b := packetElementMapper{}.linkerFor(a).Prev()
+ packetElementMapper{}.linkerFor(e).SetNext(a)
+ packetElementMapper{}.linkerFor(e).SetPrev(b)
+ packetElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ packetElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *packetList) Remove(e *packet) {
+ prev := packetElementMapper{}.linkerFor(e).Prev()
+ next := packetElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ packetElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ packetElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type packetEntry struct {
+ next *packet
+ prev *packet
+}
+
+// Next returns the entry that follows e in the list.
+func (e *packetEntry) Next() *packet {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *packetEntry) Prev() *packet {
+ return e.prev
+}
+
+// SetNext assigns 'elem' as the entry that follows e in the list.
+func (e *packetEntry) SetNext(elem *packet) {
+ e.next = elem
+}
+
+// SetPrev assigns 'elem' as the entry that precedes e in the list.
+func (e *packetEntry) SetPrev(elem *packet) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go
new file mode 100755
index 000000000..3327811b4
--- /dev/null
+++ b/pkg/tcpip/transport/raw/raw_state_autogen.go
@@ -0,0 +1,96 @@
+// automatically generated by stateify.
+
+package raw
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *packet) beforeSave() {}
+func (x *packet) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("packetEntry", &x.packetEntry)
+ m.Save("timestampNS", &x.timestampNS)
+ m.Save("senderAddr", &x.senderAddr)
+}
+
+func (x *packet) afterLoad() {}
+func (x *packet) load(m state.Map) {
+ m.Load("packetEntry", &x.packetEntry)
+ m.Load("timestampNS", &x.timestampNS)
+ m.Load("senderAddr", &x.senderAddr)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("netProto", &x.netProto)
+ m.Save("transProto", &x.transProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("closed", &x.closed)
+ m.Save("connected", &x.connected)
+ m.Save("bound", &x.bound)
+ m.Save("registeredNIC", &x.registeredNIC)
+ m.Save("boundNIC", &x.boundNIC)
+ m.Save("boundAddr", &x.boundAddr)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.Load("transProto", &x.transProto)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("closed", &x.closed)
+ m.Load("connected", &x.connected)
+ m.Load("bound", &x.bound)
+ m.Load("registeredNIC", &x.registeredNIC)
+ m.Load("boundNIC", &x.boundNIC)
+ m.Load("boundAddr", &x.boundAddr)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *packetList) beforeSave() {}
+func (x *packetList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *packetList) afterLoad() {}
+func (x *packetList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *packetEntry) beforeSave() {}
+func (x *packetEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *packetEntry) afterLoad() {}
+func (x *packetEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("raw.packet", (*packet)(nil), state.Fns{Save: (*packet).save, Load: (*packet).load})
+ state.Register("raw.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("raw.packetList", (*packetList)(nil), state.Fns{Save: (*packetList).save, Load: (*packetList).load})
+ state.Register("raw.packetEntry", (*packetEntry)(nil), state.Fns{Save: (*packetEntry).save, Load: (*packetEntry).load})
+}
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
new file mode 100644
index 000000000..d4b860975
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -0,0 +1,499 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "crypto/sha1"
+ "encoding/binary"
+ "hash"
+ "io"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // tsLen is the length, in bits, of the timestamp in the SYN cookie.
+ tsLen = 8
+
+ // tsMask is a mask for timestamp values (i.e., tsLen bits).
+ tsMask = (1 << tsLen) - 1
+
+ // tsOffset is the offset, in bits, of the timestamp in the SYN cookie.
+ tsOffset = 24
+
+ // hashMask is the mask for hash values (i.e., tsOffset bits).
+ hashMask = (1 << tsOffset) - 1
+
+ // maxTSDiff is the maximum allowed difference between a received cookie
+ // timestamp and the current timestamp. If the difference is greater
+ // than maxTSDiff, the cookie is expired.
+ maxTSDiff = 2
+)
+
+var (
+ // SynRcvdCountThreshold is the global maximum number of connections
+ // that are allowed to be in SYN-RCVD state before TCP starts using SYN
+ // cookies to accept connections.
+ //
+ // It is an exported variable only for testing, and should not otherwise
+ // be used by importers of this package.
+ SynRcvdCountThreshold uint64 = 1000
+
+ // mssTable is a slice containing the possible MSS values that we
+ // encode in the SYN cookie with two bits.
+ mssTable = []uint16{536, 1300, 1440, 1460}
+)
+
+func encodeMSS(mss uint16) uint32 {
+ for i := len(mssTable) - 1; i > 0; i-- {
+ if mss >= mssTable[i] {
+ return uint32(i)
+ }
+ }
+ return 0
+}
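+
+// Worked examples (editorial sketch, not part of the original change):
+// encodeMSS buckets an arbitrary MSS into the index of the largest mssTable
+// entry that does not exceed it, so only two bits are needed in the cookie:
+//
+// encodeMSS(1460) == 3 // mssTable[3] == 1460
+// encodeMSS(1450) == 2 // mssTable[2] == 1440
+// encodeMSS(100) == 0 // anything below 1300 maps to index 0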
+
+// synRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
+// protected by a mutex so that we can increment it only when it's guaranteed
+// not to go above a threshold.
+var synRcvdCount struct {
+ sync.Mutex
+ value uint64
+ pending sync.WaitGroup
+}
+
+// listenContext is used by a listening endpoint to store state used while
+// listening for connections. This struct is allocated by the listen goroutine
+// and must not be accessed or have its methods called concurrently as they
+// may mutate the stored objects.
+type listenContext struct {
+ stack *stack.Stack
+ rcvWnd seqnum.Size
+ nonce [2][sha1.BlockSize]byte
+ listenEP *endpoint
+
+ hasherMu sync.Mutex
+ hasher hash.Hash
+ v6only bool
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
+func timeStamp() uint32 {
+ return uint32(time.Now().Unix()>>6) & tsMask
+}
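+
+// Editorial note: with tsLen == 8 and a 64-second granularity, the cookie
+// timestamp wraps roughly every 4.5 hours (256 * 64s). Since isCookieValid
+// below rejects cookies more than maxTSDiff ticks old, a SYN cookie stays
+// acceptable for at most three ticks, i.e. just under 192 seconds.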
+
+// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD
+// state. It succeeds if the increment doesn't make the count go beyond the
+// threshold, and fails otherwise.
+func incSynRcvdCount() bool {
+ synRcvdCount.Lock()
+
+ if synRcvdCount.value >= SynRcvdCountThreshold {
+ synRcvdCount.Unlock()
+ return false
+ }
+
+ synRcvdCount.pending.Add(1)
+ synRcvdCount.value++
+
+ synRcvdCount.Unlock()
+ return true
+}
+
+// decSynRcvdCount atomically decrements the global number of endpoints in
+// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount
+// succeeded.
+func decSynRcvdCount() {
+ synRcvdCount.Lock()
+
+ synRcvdCount.value--
+ synRcvdCount.pending.Done()
+ synRcvdCount.Unlock()
+}
+
+// newListenContext creates a new listen context.
+func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+ l := &listenContext{
+ stack: stack,
+ rcvWnd: rcvWnd,
+ hasher: sha1.New(),
+ v6only: v6only,
+ netProto: netProto,
+ listenEP: listenEP,
+ }
+
+ rand.Read(l.nonce[0][:])
+ rand.Read(l.nonce[1][:])
+
+ return l
+}
+
+// cookieHash calculates the cookieHash for the given id, timestamp and nonce
+// index. The hash is used to create and validate cookies.
+func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {
+ // Initialize the block with the fixed-size data: both ports and the
+ // timestamp.
+ var payload [8]byte
+ binary.BigEndian.PutUint16(payload[0:], id.LocalPort)
+ binary.BigEndian.PutUint16(payload[2:], id.RemotePort)
+ binary.BigEndian.PutUint32(payload[4:], ts)
+
+ // Feed everything to the hasher.
+ l.hasherMu.Lock()
+ l.hasher.Reset()
+ l.hasher.Write(payload[:])
+ l.hasher.Write(l.nonce[nonceIndex][:])
+ io.WriteString(l.hasher, string(id.LocalAddress))
+ io.WriteString(l.hasher, string(id.RemoteAddress))
+
+ // Finalize the calculation of the hash and return the first 4 bytes.
+ h := make([]byte, 0, sha1.Size)
+ h = l.hasher.Sum(h)
+ l.hasherMu.Unlock()
+
+ return binary.BigEndian.Uint32(h[:])
+}
+
+// createCookie creates a SYN cookie for the given id and incoming sequence
+// number.
+func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value {
+ ts := timeStamp()
+ v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
+ v += (l.cookieHash(id, ts, 1) + data) & hashMask
+ return seqnum.Value(v)
+}
+
+// isCookieValid checks if the supplied cookie is valid for the given id and
+// sequence number. If it is, it also returns the data originally encoded in the
+// cookie when createCookie was called.
+func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
+ ts := timeStamp()
+ v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
+ cookieTS := v >> tsOffset
+ if ((ts - cookieTS) & tsMask) > maxTSDiff {
+ return 0, false
+ }
+
+ return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
+}
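+
+// Round-trip sketch (editorial, assuming a listenContext l and an arbitrary
+// id and irs): the data encoded by createCookie is recovered by
+// isCookieValid as long as at most maxTSDiff timestamp ticks have elapsed:
+//
+// cookie := l.createCookie(id, irs, encodeMSS(mss))
+// data, ok := l.isCookieValid(id, cookie, irs)
+// // ok == true, data == encodeMSS(mss); mssTable[data] recovers the MSS.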
+
+// createConnectingEndpoint creates a new endpoint in a connecting state, with
+// the connection parameters given by the arguments.
+func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+ // Create a new endpoint.
+ netProto := l.netProto
+ if netProto == 0 {
+ netProto = s.route.NetProto
+ }
+ n := newEndpoint(l.stack, netProto, nil)
+ n.v6only = l.v6only
+ n.id = s.id
+ n.boundNICID = s.route.NICID()
+ n.route = s.route.Clone()
+ n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
+ n.rcvBufSize = int(l.rcvWnd)
+
+ n.maybeEnableTimestamp(rcvdSynOpts)
+ n.maybeEnableSACKPermitted(rcvdSynOpts)
+
+ n.initGSO()
+
+ // Register new endpoint so that packets are routed to it.
+ if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil {
+ n.Close()
+ return nil, err
+ }
+
+ n.isRegistered = true
+ n.state = stateConnecting
+
+ // Create sender and receiver.
+ //
+ // The receiver at least temporarily has a zero receive window scale,
+ // but the caller may change it (before starting the protocol loop).
+ n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
+ n.rcv = newReceiver(n, irs, l.rcvWnd, 0)
+
+ return n, nil
+}
+
+// createEndpointAndPerformHandshake creates a new endpoint in a connected
+// state and then performs the TCP 3-way handshake.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+ // Create new endpoint.
+ irs := s.sequenceNumber
+ cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
+ ep, err := l.createConnectingEndpoint(s, cookie, irs, opts)
+ if err != nil {
+ return nil, err
+ }
+
+ // Perform the 3-way handshake.
+ h := newHandshake(ep, l.rcvWnd)
+
+ h.resetToSynRcvd(cookie, irs, opts, l.listenEP)
+ if err := h.execute(); err != nil {
+ ep.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ ep.Close()
+ return nil, err
+ }
+
+ ep.state = stateConnected
+
+ // Update the receive window scaling. We can't do it before the
+ // handshake because it's possible that the peer doesn't support window
+ // scaling.
+ ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
+
+ return ep, nil
+}
+
+// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
+// endpoint has transitioned out of the listen state, the new endpoint is closed
+// instead.
+func (e *endpoint) deliverAccepted(n *endpoint) {
+ e.mu.RLock()
+ state := e.state
+ e.mu.RUnlock()
+ if state == stateListen {
+ e.acceptedChan <- n
+ e.waiterQueue.Notify(waiter.EventIn)
+ } else {
+ n.Close()
+ }
+}
+
+// handleSynSegment is called in its own goroutine once the listening endpoint
+// receives a SYN segment. It is responsible for completing the handshake and
+// queueing the new endpoint for acceptance.
+//
+// A limited number of these goroutines are allowed before TCP starts using SYN
+// cookies to accept connections.
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
+ defer decSynRcvdCount()
+ defer e.decSynRcvdCount()
+ defer s.decRef()
+
+ n, err := ctx.createEndpointAndPerformHandshake(s, opts)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ e.deliverAccepted(n)
+}
+
+func (e *endpoint) incSynRcvdCount() bool {
+ e.mu.Lock()
+ if l, c := len(e.acceptedChan), cap(e.acceptedChan); l == c && e.synRcvdCount >= c {
+ e.mu.Unlock()
+ return false
+ }
+ e.synRcvdCount++
+ e.mu.Unlock()
+ return true
+}
+
+func (e *endpoint) decSynRcvdCount() {
+ e.mu.Lock()
+ e.synRcvdCount--
+ e.mu.Unlock()
+}
+
+// handleListenSegment is called when a listening endpoint receives a segment
+// and needs to handle it.
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+ switch s.flags {
+ case header.TCPFlagSyn:
+ opts := parseSynSegmentOptions(s)
+ if incSynRcvdCount() {
+ // Drop the SYN if the listen endpoint's accept queue is
+ // overflowing.
+ if e.incSynRcvdCount() {
+ s.incRef()
+ go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier.
+ return
+ }
+ e.stack.Stats().TCP.ListenOverflowSynDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ } else {
+ // TODO(bhaskerh): Increment syncookie sent stat.
+ cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+ // Send a SYN-ACK without window scaling, because we
+ // currently don't encode this information in the
+ // cookie.
+ //
+ // Enable the timestamp option if the original SYN
+ // carried the timestamp option.
+ synOpts := header.TCPSynOptions{
+ WS: -1,
+ TS: opts.TS,
+ TSVal: tcpTimeStamp(timeStampOffset()),
+ TSEcr: opts.TSVal,
+ }
+ sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+ e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment()
+ }
+
+ case header.TCPFlagAck:
+ if len(e.acceptedChan) == cap(e.acceptedChan) {
+ // Silently drop the ACK, as the application can't
+ // accept the connection at this point. The sender will
+ // retransmit the ACK anyway, and we can complete the
+ // connection at the time of the retransmit if the
+ // backlog has space.
+ e.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+
+ // Validate the cookie.
+ data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1)
+ if !ok || int(data) >= len(mssTable) {
+ e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment()
+ e.stack.Stats().DroppedPackets.Increment()
+ return
+ }
+ e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment()
+ // Create newly accepted endpoint and deliver it.
+ rcvdSynOptions := &header.TCPSynOptions{
+ MSS: mssTable[data],
+ // Disable window scaling, as the original
+ // SYN is lost.
+ WS: -1,
+ }
+
+ // When SYN cookies are in use, we enable the
+ // timestamp option only if the ACK carries it,
+ // assuming that the other end did in fact negotiate
+ // the timestamp option in the original SYN.
+ if s.parsedOptions.TS {
+ rcvdSynOptions.TS = true
+ rcvdSynOptions.TSVal = s.parsedOptions.TSVal
+ rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
+ }
+
+ n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
+ if err != nil {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ return
+ }
+
+ // Clear the tsOffset for the newly created
+ // endpoint, as the timestamp was already
+ // randomly offset when the original SYN-ACK was
+ // sent above.
+ n.tsOffset = 0
+
+ // Switch state to connected.
+ n.state = stateConnected
+
+ // Do the delivery in a separate goroutine so
+ // that we don't block the listen loop in case
+ // the application is slow to accept or stops
+ // accepting.
+ //
+ // NOTE: This won't result in an unbounded
+ // number of goroutines as we do check before
+ // entering here that there was at least some
+ // space available in the backlog.
+ go e.deliverAccepted(n)
+ }
+}
+
+// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
+// its own goroutine and is responsible for handling connection requests.
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+ defer func() {
+ // Mark endpoint as closed. This will prevent goroutines running
+ // handleSynSegment() from attempting to queue new connections
+ // to the endpoint.
+ e.mu.Lock()
+ e.state = stateClosed
+
+ // Do cleanup if needed.
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+ e.mu.Unlock()
+
+ // Notify waiters that the endpoint is shutdown.
+ e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+ }()
+
+ e.mu.Lock()
+ v6only := e.v6only
+ e.mu.Unlock()
+
+ ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto)
+
+ s := sleep.Sleeper{}
+ s.AddWaker(&e.notificationWaker, wakerForNotification)
+ s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
+ for {
+ switch index, _ := s.Fetch(true); index {
+ case wakerForNotification:
+ n := e.fetchNotifications()
+ if n&notifyClose != 0 {
+ return nil
+ }
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ s := e.segmentQueue.dequeue()
+ e.handleListenSegment(ctx, s)
+ s.decRef()
+ }
+ synRcvdCount.pending.Wait()
+ close(e.drainDone)
+ <-e.undrain
+ }
+
+ case wakerForNewSegment:
+ // Process at most maxSegmentsPerWake segments.
+ mayRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ mayRequeue = false
+ break
+ }
+
+ e.handleListenSegment(ctx, s)
+ s.decRef()
+ }
+
+ // If the queue is not empty, make sure we'll wake up
+ // in the next iteration.
+ if mayRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
new file mode 100644
index 000000000..2aed6f286
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -0,0 +1,1066 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// maxSegmentsPerWake is the maximum number of segments to process in the main
+// protocol goroutine per wake-up. Yielding after this many segments are
+// processed allows other events to be processed as well (e.g., timeouts,
+// resets, etc.).
+const maxSegmentsPerWake = 100
+
+type handshakeState int
+
+// The following are the possible states of the TCP connection during a 3-way
+// handshake. A depiction of the states and transitions can be found in RFC 793,
+// page 23.
+const (
+ handshakeSynSent handshakeState = iota
+ handshakeSynRcvd
+ handshakeCompleted
+)
+
+// The following are used to set up sleepers.
+const (
+ wakerForNotification = iota
+ wakerForNewSegment
+ wakerForResend
+ wakerForResolution
+)
+
+const (
+ // Maximum space available for options.
+ maxOptionSize = 40
+)
+
+// handshake holds the state used during a TCP 3-way handshake.
+type handshake struct {
+ ep *endpoint
+ listenEP *endpoint // only non-nil when doing passive connects.
+ state handshakeState
+ active bool
+ flags uint8
+ ackNum seqnum.Value
+
+ // iss is the initial send sequence number, as defined in RFC 793.
+ iss seqnum.Value
+
+ // rcvWnd is the receive window, as defined in RFC 793.
+ rcvWnd seqnum.Size
+
+ // sndWnd is the send window, as defined in RFC 793.
+ sndWnd seqnum.Size
+
+ // mss is the maximum segment size received from the peer.
+ mss uint16
+
+ // sndWndScale is the send window scale, as defined in RFC 1323. A
+ // negative value means no scaling is supported by the peer.
+ sndWndScale int
+
+ // rcvWndScale is the receive window scale, as defined in RFC 1323.
+ rcvWndScale int
+}
+
+func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake {
+ h := handshake{
+ ep: ep,
+ active: true,
+ rcvWnd: rcvWnd,
+ rcvWndScale: FindWndScale(rcvWnd),
+ }
+ h.resetState()
+ return h
+}
+
+// FindWndScale determines the window scale to use for the given maximum window
+// size.
+func FindWndScale(wnd seqnum.Size) int {
+ if wnd < 0x10000 {
+ return 0
+ }
+
+ max := seqnum.Size(0xffff)
+ s := 0
+ for wnd > max && s < header.MaxWndScale {
+ s++
+ max <<= 1
+ }
+
+ return s
+}
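+
+// Worked examples (editorial sketch): FindWndScale returns the smallest
+// scale that lets wnd fit in the 16-bit window field, capped at
+// header.MaxWndScale:
+//
+// FindWndScale(0xffff) == 0 // fits without scaling
+// FindWndScale(0x10000) == 1 // 64KiB needs one shift
+// FindWndScale(1 << 20) == 5 // a 1MiB window needs five shifts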
+
+// resetState resets the state of the handshake object such that it becomes
+// ready for a new 3-way handshake.
+func (h *handshake) resetState() {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+
+ h.state = handshakeSynSent
+ h.flags = header.TCPFlagSyn
+ h.ackNum = 0
+ h.mss = 0
+ h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24)
+}
+
+// effectiveRcvWndScale returns the effective receive window scale to be used.
+// If the peer doesn't support window scaling, the effective rcv wnd scale is
+// zero; otherwise it's the value calculated based on the initial rcv wnd.
+func (h *handshake) effectiveRcvWndScale() uint8 {
+ if h.sndWndScale < 0 {
+ return 0
+ }
+ return uint8(h.rcvWndScale)
+}
+
+// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
+// state.
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, listenEP *endpoint) {
+ h.active = false
+ h.state = handshakeSynRcvd
+ h.flags = header.TCPFlagSyn | header.TCPFlagAck
+ h.iss = iss
+ h.ackNum = irs + 1
+ h.mss = opts.MSS
+ h.sndWndScale = opts.WS
+ h.listenEP = listenEP
+}
+
+// checkAck checks if the ACK number, if present, of a segment received during
+// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in
+// response.
+func (h *handshake) checkAck(s *segment) bool {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 {
+ // RFC 793, page 36, states that a reset must be generated when
+ // the connection is in any non-synchronized state and an
+ // incoming segment acknowledges something not yet sent. The
+ // connection remains in the same state.
+ ack := s.sequenceNumber.Add(s.logicalLen())
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, s.ackNumber, ack, 0)
+ return false
+ }
+
+ return true
+}
+
+// synSentState handles a segment received when the TCP 3-way handshake is in
+// the SYN-SENT state.
+func (h *handshake) synSentState(s *segment) *tcpip.Error {
+ // RFC 793, page 37, states that in the SYN-SENT state, a reset is
+ // acceptable if the ack field acknowledges the SYN.
+ if s.flagIsSet(header.TCPFlagRst) {
+ if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 {
+ return tcpip.ErrConnectionRefused
+ }
+ return nil
+ }
+
+ if !h.checkAck(s) {
+ return nil
+ }
+
+ // We are in the SYN-SENT state. We only care about segments that have
+ // the SYN flag.
+ if !s.flagIsSet(header.TCPFlagSyn) {
+ return nil
+ }
+
+ // Parse the SYN options.
+ rcvSynOpts := parseSynSegmentOptions(s)
+
+ // Remember if the Timestamp option was negotiated.
+ h.ep.maybeEnableTimestamp(&rcvSynOpts)
+
+ // Remember if the SACKPermitted option was negotiated.
+ h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+
+ // Remember the sequence we'll ack from now on.
+ h.ackNum = s.sequenceNumber + 1
+ h.flags |= header.TCPFlagAck
+ h.mss = rcvSynOpts.MSS
+ h.sndWndScale = rcvSynOpts.WS
+
+ // If this is a SYN ACK response, we only need to acknowledge the SYN
+ // and the handshake is completed.
+ if s.flagIsSet(header.TCPFlagAck) {
+ h.state = handshakeCompleted
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
+ return nil
+ }
+
+ // A SYN segment was received, but no ACK in it. We acknowledge the SYN
+ // but resend our own SYN and wait for it to be acknowledged in the
+ // SYN-RCVD state.
+ h.state = handshakeSynRcvd
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: rcvSynOpts.TS,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+
+ // We only send SACKPermitted if the other side indicated it
+ // permits SACK. This is not explicitly defined in the RFC but
+ // this is the behaviour implemented by Linux.
+ SACKPermitted: rcvSynOpts.SACKPermitted,
+ }
+ sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+ return nil
+}
+
+// synRcvdState handles a segment received when the TCP 3-way handshake is in
+// the SYN-RCVD state.
+func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
+ if s.flagIsSet(header.TCPFlagRst) {
+ // RFC 793, page 37, states that in the SYN-RCVD state, a reset
+ // is acceptable if the sequence number is in the window.
+ if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
+ return tcpip.ErrConnectionRefused
+ }
+ return nil
+ }
+
+ if !h.checkAck(s) {
+ return nil
+ }
+
+ if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 {
+ // We received two SYN segments with different sequence
+ // numbers, so we send a reset and restart the whole
+ // handshake, except that we don't reset the timer.
+ ack := s.sequenceNumber.Add(s.logicalLen())
+ seq := seqnum.Value(0)
+ if s.flagIsSet(header.TCPFlagAck) {
+ seq = s.ackNumber
+ }
+ h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0)
+
+ if !h.active {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ h.resetState()
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: h.ep.sendTSOk,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+ SACKPermitted: h.ep.sackPermitted,
+ }
+ sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ return nil
+ }
+
+ // We have previously received (and acknowledged) the peer's SYN. If the
+ // peer acknowledges our SYN, the handshake is completed.
+ if s.flagIsSet(header.TCPFlagAck) {
+ // listenContext is also used by a tcp.Forwarder and in that
+ // context we do not have a listening endpoint to check the
+ // backlog. So skip this check if listenEP is nil.
+ if h.listenEP != nil && len(h.listenEP.acceptedChan) == cap(h.listenEP.acceptedChan) {
+ // If there is no space in the accept queue to accept
+ // this endpoint, then silently drop this ACK. The peer
+ // will resend the ACK anyway, and we can complete the
+ // connection the next time it's retransmitted.
+ h.ep.stack.Stats().TCP.ListenOverflowAckDrop.Increment()
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+ // If the timestamp option is negotiated and the segment does
+ // not carry a timestamp option then the segment must be dropped
+ // as per https://tools.ietf.org/html/rfc7323#section-3.2.
+ if h.ep.sendTSOk && !s.parsedOptions.TS {
+ h.ep.stack.Stats().DroppedPackets.Increment()
+ return nil
+ }
+
+ // Update timestamp if required. See RFC7323, section-4.3.
+ if h.ep.sendTSOk && s.parsedOptions.TS {
+ h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
+ }
+ h.state = handshakeCompleted
+ return nil
+ }
+
+ return nil
+}
+
+func (h *handshake) handleSegment(s *segment) *tcpip.Error {
+ h.sndWnd = s.window
+ if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 {
+ h.sndWnd <<= uint8(h.sndWndScale)
+ }
+
+ switch h.state {
+ case handshakeSynRcvd:
+ return h.synRcvdState(s)
+ case handshakeSynSent:
+ return h.synSentState(s)
+ }
+ return nil
+}
+
+// processSegments goes through the segment queue and processes up to
+// maxSegmentsPerWake (if they're available).
+func (h *handshake) processSegments() *tcpip.Error {
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := h.ep.segmentQueue.dequeue()
+ if s == nil {
+ return nil
+ }
+
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+
+ // We stop processing packets once the handshake is completed,
+ // otherwise we may process packets meant to be processed by
+ // the main protocol goroutine.
+ if h.state == handshakeCompleted {
+ break
+ }
+ }
+
+ // If the queue is not empty, make sure we'll wake up in the next
+ // iteration.
+ if !h.ep.segmentQueue.empty() {
+ h.ep.newSegmentWaker.Assert()
+ }
+
+ return nil
+}
+
+func (h *handshake) resolveRoute() *tcpip.Error {
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ resolutionWaker := &sleep.Waker{}
+ s.AddWaker(resolutionWaker, wakerForResolution)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ defer s.Done()
+
+ // Initial action is to resolve route.
+ index := wakerForResolution
+ for {
+ switch index {
+ case wakerForResolution:
+ if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+ // Either success (err == nil) or failure.
+ return err
+ }
+ // Resolution not completed. Keep trying...
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if n&notifyClose != 0 {
+ h.ep.route.RemoveWaker(resolutionWaker)
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
+ }
+
+ // Wait for notification.
+ index, _ = s.Fetch(true)
+ }
+}
+
+// execute executes the TCP 3-way handshake.
+func (h *handshake) execute() *tcpip.Error {
+ if h.ep.route.IsResolutionRequired() {
+ if err := h.resolveRoute(); err != nil {
+ return err
+ }
+ }
+
+ // Initialize the resend timer.
+ resendWaker := sleep.Waker{}
+ timeOut := time.Duration(time.Second)
+ rt := time.AfterFunc(timeOut, func() {
+ resendWaker.Assert()
+ })
+ defer rt.Stop()
+
+ // Set up the wakers.
+ s := sleep.Sleeper{}
+ s.AddWaker(&resendWaker, wakerForResend)
+ s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+ s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+ defer s.Done()
+
+ var sackEnabled SACKEnabled
+ if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
+ // If the stack returned an error when checking the SACKEnabled
+ // status, then just default to switching off SACK negotiation.
+ sackEnabled = false
+ }
+
+ // Send the initial SYN segment and loop until the handshake is
+ // completed.
+ synOpts := header.TCPSynOptions{
+ WS: h.rcvWndScale,
+ TS: true,
+ TSVal: h.ep.timestamp(),
+ TSEcr: h.ep.recentTS,
+ SACKPermitted: bool(sackEnabled),
+ }
+
+ // Execute is also called in a listen context so we want to make sure we
+ // only send the TS/SACK option when we received the TS/SACK in the
+ // initial SYN.
+ if h.state == handshakeSynRcvd {
+ synOpts.TS = h.ep.sendTSOk
+ synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
+ }
+ sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+ for h.state != handshakeCompleted {
+ switch index, _ := s.Fetch(true); index {
+ case wakerForResend:
+ timeOut *= 2
+ if timeOut > 60*time.Second {
+ return tcpip.ErrTimeout
+ }
+ rt.Reset(timeOut)
+ sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+ case wakerForNotification:
+ n := h.ep.fetchNotifications()
+ if n&notifyClose != 0 {
+ return tcpip.ErrAborted
+ }
+ if n&notifyDrain != 0 {
+ for !h.ep.segmentQueue.empty() {
+ s := h.ep.segmentQueue.dequeue()
+ err := h.handleSegment(s)
+ s.decRef()
+ if err != nil {
+ return err
+ }
+ if h.state == handshakeCompleted {
+ return nil
+ }
+ }
+ close(h.ep.drainDone)
+ <-h.ep.undrain
+ }
+
+ case wakerForNewSegment:
+ if err := h.processSegments(); err != nil {
+ return err
+ }
+ }
+ }
+
+ return nil
+}
+
+func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
+ synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck))
+ if synOpts.TS {
+ s.parsedOptions.TSVal = synOpts.TSVal
+ s.parsedOptions.TSEcr = synOpts.TSEcr
+ }
+ return synOpts
+}
+
+var optionPool = sync.Pool{
+ New: func() interface{} {
+ return make([]byte, maxOptionSize)
+ },
+}
+
+func getOptions() []byte {
+ return optionPool.Get().([]byte)
+}
+
+func putOptions(options []byte) {
+ // Reslice to full capacity.
+ optionPool.Put(options[0:cap(options)])
+}
+
+func makeSynOptions(opts header.TCPSynOptions) []byte {
+ // Emulate Linux option order. This is as follows:
+ //
+ // if md5: NOP NOP MD5SIG 18 md5sig(16)
+ // if mss: MSS 4 mss(2)
+ // if ts and sack_advertise:
+ // SACK 2 TIMESTAMP 2 timestamp(8)
+ // elif ts: NOP NOP TIMESTAMP 10 timestamp(8)
+ // elif sack: NOP NOP SACK 2
+ // if wscale: NOP WINDOW 3 ws(1)
+ // if sack_blocks: NOP NOP SACK (2 + (#blocks * 8))
+ // [for each block] start_seq(4) end_seq(4)
+ // if fastopen_cookie:
+ // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2)
+ // else: FASTOPEN (2 + len(cookie))
+ // cookie(variable) [padding to four bytes]
+ //
+ options := getOptions()
+
+ // Always encode the mss.
+ offset := header.EncodeMSSOption(uint32(opts.MSS), options)
+
+ // Special ordering is required here. If both TS and SACK are enabled,
+ // then the SACK option precedes TS, with no padding. If they are
+ // enabled individually, then we see padding before the option.
+ if opts.TS && opts.SACKPermitted {
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.TS {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+ } else if opts.SACKPermitted {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKPermittedOption(options[offset:])
+ }
+
+ // Initialize the WS option.
+ if opts.WS >= 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeWSOption(opts.WS, options[offset:])
+ }
+
+ // Padding to the end; note that this never applies unless we add a
+ // fastopen option, so we always expect the offset to remain the same.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
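+
+// Size sketch (editorial): for an active SYN carrying MSS, TS, SACKPermitted
+// and WS, the encoding above occupies 4 (MSS) + 2 (SACKPermitted) + 10 (TS)
+// + 1 (NOP) + 3 (WS) = 20 bytes. Every branch keeps the offset 4-byte
+// aligned, which is why the AddTCPOptionPadding delta is asserted to be
+// zero.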
+
+func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
+ // The MSS in opts is calculated here, because this function is called
+ // from many places and we don't want the MSS calculation embedded at
+ // every call site.
+ if opts.MSS == 0 {
+ opts.MSS = uint16(r.MTU() - header.TCPMinimumSize)
+ }
+
+ options := makeSynOptions(opts)
+ err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil)
+ putOptions(options)
+ return err
+}
+
+// sendTCP sends a TCP segment with the provided options via the provided
+// network endpoint and under the provided identity.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error {
+ optLen := len(opts)
+ // Allocate a buffer for the TCP header.
+ hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
+
+ if rcvWnd > 0xffff {
+ rcvWnd = 0xffff
+ }
+
+ // Initialize the header.
+ tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
+ tcp.Encode(&header.TCPFields{
+ SrcPort: id.LocalPort,
+ DstPort: id.RemotePort,
+ SeqNum: uint32(seq),
+ AckNum: uint32(ack),
+ DataOffset: uint8(header.TCPMinimumSize + optLen),
+ Flags: flags,
+ WindowSize: uint16(rcvWnd),
+ })
+ copy(tcp[header.TCPMinimumSize:], opts)
+
+ length := uint16(hdr.UsedLength() + data.Size())
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ // Only calculate the checksum if offloading isn't supported.
+ if gso != nil && gso.NeedsCsum {
+ // This is called CHECKSUM_PARTIAL in the Linux kernel. We
+ // calculate a checksum of the pseudo-header and save it in the
+ // TCP header, then the kernel calculates a checksum of the
+ // header and data to get the right sum of the TCP packet.
+ tcp.SetChecksum(xsum)
+ } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ xsum = header.ChecksumVV(data, xsum)
+ tcp.SetChecksum(^tcp.CalculateChecksum(xsum))
+ }
+
+ r.Stats().TCP.SegmentsSent.Increment()
+ if (flags & header.TCPFlagRst) != 0 {
+ r.Stats().TCP.ResetsSent.Increment()
+ }
+
+ return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl)
+}
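+
+// Checksum decision sketch (editorial): sendTCP picks one of three modes:
+//
+// gso != nil && gso.NeedsCsum -> pseudo-header sum only (CHECKSUM_PARTIAL)
+// CapabilityTXChecksumOffload -> no software checksum at all
+// otherwise -> full checksum over header and payload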
+
+// makeOptions makes an options slice.
+func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
+ options := getOptions()
+ offset := 0
+
+ // N.B. the ordering here matches the ordering used by Linux internally
+ // and described in makeSynOptions above. We don't include cases that
+ // are unnecessary here (post connection).
+ if e.sendTSOk {
+ // Embed the timestamp if timestamp has been enabled.
+ //
+ // We only use the lower 32 bits of the unix time in
+ // milliseconds. This is similar to what Linux does where it
+ // uses the lower 32 bits of the jiffies value in the tsVal
+ // field of the timestamp option.
+ //
+ // Further, RFC7323 section-5.4 recommends millisecond
+ // resolution as the lowest recommended resolution for the
+ // timestamp clock.
+ //
+ // Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:])
+ }
+ if e.sackPermitted && len(sackBlocks) > 0 {
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeNOP(options[offset:])
+ offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
+ }
+
+ // We expect the above to produce an aligned offset.
+ if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+ panic("unexpected option encoding")
+ }
+
+ return options[:offset]
+}
+
+// sendRaw sends a TCP segment to the endpoint's peer.
+func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+ var sackBlocks []header.SACKBlock
+ if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) {
+ sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
+ }
+ options := e.makeOptions(sackBlocks)
+ err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso)
+ putOptions(options)
+ return err
+}
+
+func (e *endpoint) handleWrite() *tcpip.Error {
+ // Move packets from send queue to send list. The queue is accessible
+ // from other goroutines and protected by the send mutex, while the send
+ // list is only accessible from the handler goroutine, so it needs no
+ // mutexes.
+ e.sndBufMu.Lock()
+
+ first := e.sndQueue.Front()
+ if first != nil {
+ e.snd.writeList.PushBackList(&e.sndQueue)
+ e.snd.sndNxtList.UpdateForward(e.sndBufInQueue)
+ e.sndBufInQueue = 0
+ }
+
+ e.sndBufMu.Unlock()
+
+ // Initialize the next segment to write if it's currently nil.
+ if e.snd.writeNext == nil {
+ e.snd.writeNext = first
+ }
+
+ // Push out any new packets.
+ e.snd.sendData()
+
+ return nil
+}
+
+func (e *endpoint) handleClose() *tcpip.Error {
+ // Drain the send queue.
+ e.handleWrite()
+
+ // Mark send side as closed.
+ e.snd.closed = true
+
+ return nil
+}
+
+// resetConnectionLocked sends a RST segment and puts the endpoint in an error
+// state with the given error code. This method must only be called from the
+// protocol goroutine.
+func (e *endpoint) resetConnectionLocked(err *tcpip.Error) {
+ e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+
+ e.state = stateError
+ e.hardError = err
+}
+
+// completeWorkerLocked is called by the worker goroutine when it's about to
+// exit. It marks the worker as completed and performs cleanup work if requested
+// by Close().
+func (e *endpoint) completeWorkerLocked() {
+ e.workerRunning = false
+ if e.workerCleanup {
+ e.cleanupLocked()
+ }
+}
+
+// handleSegments pulls segments from the queue and processes them. It returns
+// no error if the protocol loop should continue, an error otherwise.
+func (e *endpoint) handleSegments() *tcpip.Error {
+ checkRequeue := true
+ for i := 0; i < maxSegmentsPerWake; i++ {
+ s := e.segmentQueue.dequeue()
+ if s == nil {
+ checkRequeue = false
+ break
+ }
+
+ // Invoke the tcp probe if installed.
+ if e.probe != nil {
+ e.probe(e.completeState())
+ }
+
+ if s.flagIsSet(header.TCPFlagRst) {
+ if e.rcv.acceptable(s.sequenceNumber, 0) {
+ // RFC 793, page 37 states that "in all states
+ // except SYN-SENT, all reset (RST) segments are
+ // validated by checking their SEQ-fields." So
+ // we only process it if it's acceptable.
+ s.decRef()
+ return tcpip.ErrConnectionReset
+ }
+ } else if s.flagIsSet(header.TCPFlagAck) {
+ // Patch the window size in the segment according to the
+ // send window scale.
+ s.window <<= e.snd.sndWndScale
+
+ // RFC 793, page 41 states that "once in the ESTABLISHED
+ // state all segments must carry current acknowledgment
+ // information."
+ e.rcv.handleRcvdSegment(s)
+ e.snd.handleRcvdSegment(s)
+ }
+ s.decRef()
+ }
+
+ // If the queue is not empty, make sure we'll wake up in the next
+ // iteration.
+ if checkRequeue && !e.segmentQueue.empty() {
+ e.newSegmentWaker.Assert()
+ }
+
+ // Send an ACK for all processed packets if needed.
+ if e.rcv.rcvNxt != e.snd.maxSentAck {
+ e.snd.sendAck()
+ }
+
+ e.resetKeepaliveTimer(true)
+
+ return nil
+}
+
+// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP
+// keepalive packets periodically when the connection is idle. If we don't hear
+// from the other side after a number of tries, we terminate the connection.
+func (e *endpoint) keepaliveTimerExpired() *tcpip.Error {
+ e.keepalive.Lock()
+ if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() {
+ e.keepalive.Unlock()
+ return nil
+ }
+
+ if e.keepalive.unacked >= e.keepalive.count {
+ e.keepalive.Unlock()
+ return tcpip.ErrConnectionReset
+ }
+
+ // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with
+ // seg.seq = snd.nxt-1.
+ e.keepalive.unacked++
+ e.keepalive.Unlock()
+ e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1)
+ e.resetKeepaliveTimer(false)
+ return nil
+}
+
+// resetKeepaliveTimer restarts or stops the keepalive timer, depending on
+// whether it is enabled for this endpoint.
+func (e *endpoint) resetKeepaliveTimer(receivedData bool) {
+ e.keepalive.Lock()
+ defer e.keepalive.Unlock()
+ if receivedData {
+ e.keepalive.unacked = 0
+ }
+ // Start the keepalive timer IFF it's enabled and there is no pending
+ // data to send.
+ if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt {
+ e.keepalive.timer.disable()
+ return
+ }
+ if e.keepalive.unacked > 0 {
+ e.keepalive.timer.enable(e.keepalive.interval)
+ } else {
+ e.keepalive.timer.enable(e.keepalive.idle)
+ }
+}
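+
+// Timeline sketch (editorial): on an idle connection with keepalive enabled,
+// the first probe fires after keepalive.idle and each subsequent probe after
+// keepalive.interval; once keepalive.count probes go unacknowledged,
+// keepaliveTimerExpired tears the connection down with ErrConnectionReset.
+// Newly acknowledged data resets the cycle via resetKeepaliveTimer(true).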
+
+// disableKeepaliveTimer stops the keepalive timer.
+func (e *endpoint) disableKeepaliveTimer() {
+ e.keepalive.Lock()
+ e.keepalive.timer.disable()
+ e.keepalive.Unlock()
+}
+
+// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
+// goroutine and is responsible for sending segments and handling received
+// segments.
+func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error {
+ var closeTimer *time.Timer
+ var closeWaker sleep.Waker
+
+ epilogue := func() {
+ // e.mu is expected to be held upon entering this section.
+
+ if e.snd != nil {
+ e.snd.resendTimer.cleanup()
+ }
+
+ if closeTimer != nil {
+ closeTimer.Stop()
+ }
+
+ e.completeWorkerLocked()
+
+ if e.drainDone != nil {
+ close(e.drainDone)
+ }
+
+ e.mu.Unlock()
+
+ // When the protocol loop exits we should wake up our waiters.
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+ }
+
+ if handshake {
+ // This is an active connection, so we must initiate the 3-way
+ // handshake, and then inform potential waiters about its
+ // completion.
+ h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable()))
+ if err := h.execute(); err != nil {
+ e.lastErrorMu.Lock()
+ e.lastError = err
+ e.lastErrorMu.Unlock()
+
+ e.mu.Lock()
+ e.state = stateError
+ e.hardError = err
+ // Lock released below.
+ epilogue()
+
+ return err
+ }
+
+ // Transfer handshake state to TCP connection. We disable
+ // receive window scaling if the peer doesn't support it
+ // (indicated by a negative send window scale).
+ e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+ e.rcvListMu.Lock()
+ e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+ e.rcvListMu.Unlock()
+ }
+
+ e.keepalive.timer.init(&e.keepalive.waker)
+ defer e.keepalive.timer.cleanup()
+
+ // Tell waiters that the endpoint is connected and writable.
+ e.mu.Lock()
+ e.state = stateConnected
+ drained := e.drainDone != nil
+ e.mu.Unlock()
+ if drained {
+ close(e.drainDone)
+ <-e.undrain
+ }
+
+ e.waiterQueue.Notify(waiter.EventOut)
+
+ // Set up the functions that will be called when the main protocol loop
+ // wakes up.
+ funcs := []struct {
+ w *sleep.Waker
+ f func() *tcpip.Error
+ }{
+ {
+ w: &e.sndWaker,
+ f: e.handleWrite,
+ },
+ {
+ w: &e.sndCloseWaker,
+ f: e.handleClose,
+ },
+ {
+ w: &e.newSegmentWaker,
+ f: e.handleSegments,
+ },
+ {
+ w: &closeWaker,
+ f: func() *tcpip.Error {
+ return tcpip.ErrConnectionAborted
+ },
+ },
+ {
+ w: &e.snd.resendWaker,
+ f: func() *tcpip.Error {
+ if !e.snd.retransmitTimerExpired() {
+ return tcpip.ErrTimeout
+ }
+ return nil
+ },
+ },
+ {
+ w: &e.keepalive.waker,
+ f: e.keepaliveTimerExpired,
+ },
+ {
+ w: &e.notificationWaker,
+ f: func() *tcpip.Error {
+ n := e.fetchNotifications()
+ if n&notifyNonZeroReceiveWindow != 0 {
+ e.rcv.nonZeroWindow()
+ }
+
+ if n&notifyReceiveWindowChanged != 0 {
+ e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
+ }
+
+ if n&notifyMTUChanged != 0 {
+ e.sndBufMu.Lock()
+ count := e.packetTooBigCount
+ e.packetTooBigCount = 0
+ mtu := e.sndMTU
+ e.sndBufMu.Unlock()
+
+ e.snd.updateMaxPayloadSize(mtu, count)
+ }
+
+ if n&notifyReset != 0 {
+ e.mu.Lock()
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.mu.Unlock()
+ }
+ if n&notifyClose != 0 && closeTimer == nil {
+ // Reset the connection 3 seconds after
+ // the endpoint has been closed.
+ //
+ // The timer could fire in background
+ // when the endpoint is drained. That's
+ // OK as the loop here will not honor
+ // the firing until the undrain arrives.
+ closeTimer = time.AfterFunc(3*time.Second, func() {
+ closeWaker.Assert()
+ })
+ }
+
+ if n&notifyKeepaliveChanged != 0 {
+ // The timer could fire in background
+ // when the endpoint is drained. That's
+ // OK. See above.
+ e.resetKeepaliveTimer(true)
+ }
+
+ if n&notifyDrain != 0 {
+ for !e.segmentQueue.empty() {
+ if err := e.handleSegments(); err != nil {
+ return err
+ }
+ }
+ if e.state != stateError {
+ close(e.drainDone)
+ <-e.undrain
+ }
+ }
+
+ return nil
+ },
+ },
+ }
+
+ // Initialize the sleeper based on the wakers in funcs.
+ s := sleep.Sleeper{}
+ for i := range funcs {
+ s.AddWaker(funcs[i].w, i)
+ }
+
+ // The following assertions and notifications are needed for restored
+ // endpoints. Fresh newly created endpoints have empty states and should
+ // not invoke any.
+ e.segmentQueue.mu.Lock()
+ if !e.segmentQueue.list.Empty() {
+ e.newSegmentWaker.Assert()
+ }
+ e.segmentQueue.mu.Unlock()
+
+ e.rcvListMu.Lock()
+ if !e.rcvList.Empty() {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ e.rcvListMu.Unlock()
+
+ e.mu.RLock()
+ if e.workerCleanup {
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+ e.mu.RUnlock()
+
+ // Main loop. Handle segments until both send and receive ends of the
+ // connection have completed.
+ for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
+ e.workMu.Unlock()
+ v, _ := s.Fetch(true)
+ e.workMu.Lock()
+ if err := funcs[v].f(); err != nil {
+ e.mu.Lock()
+ e.resetConnectionLocked(err)
+ // Lock released below.
+ epilogue()
+
+ return nil
+ }
+ }
+
+ // Mark endpoint as closed.
+ e.mu.Lock()
+ if e.state != stateError {
+ e.state = stateClosed
+ }
+ // Lock released below.
+ epilogue()
+
+ return nil
+}
diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go
new file mode 100644
index 000000000..e618cd2b9
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/cubic.go
@@ -0,0 +1,233 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "math"
+ "time"
+)
+
+// cubicState stores the variables related to TCP CUBIC congestion
+// control algorithm state.
+//
+// See: https://tools.ietf.org/html/rfc8312.
+type cubicState struct {
+ // wLastMax is the previous wMax value.
+ wLastMax float64
+
+ // wMax is the value of the congestion window at the
+ // time of last congestion event.
+ wMax float64
+
+ // t denotes the time when the current congestion avoidance
+ // was entered.
+ t time.Time
+
+ // numCongestionEvents tracks the number of congestion events since last
+ // RTO.
+ numCongestionEvents int
+
+ // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as
+ // per RFC.
+ c float64
+
+ // k is the time period that the cubic window function takes to
+ // increase the current window size to W_max if there are no further
+ // congestion events, and is calculated using the following equation:
+ //
+ // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2)
+ k float64
+
+ // beta is the CUBIC multiplicative decrease factor; that is, when a
+ // congestion event is detected, CUBIC reduces its cwnd to
+ // W_cubic(0) = W_max*beta_cubic.
+ beta float64
+
+ // wC is window computed by CUBIC at time t. It's calculated using the
+ // formula:
+ //
+ // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1)
+ wC float64
+
+ // wEst is the TCP-friendly estimate of the congestion window at time
+ // t, i.e. W_est(t) as defined in RFC 8312, section 4.2.
+ wEst float64
+
+ s *sender
+}
+
+// newCubicCC returns a partially initialized cubic state with the constants
+// beta and c set and t set to current time.
+func newCubicCC(s *sender) *cubicState {
+ return &cubicState{
+ t: time.Now(),
+ beta: 0.7,
+ c: 0.4,
+ s: s,
+ }
+}
+
+// enterCongestionAvoidance is used to initialize cubic in cases where we exit
+// SlowStart without a real congestion event taking place. This can happen when
+// a connection goes back to slow start due to a retransmit and we exceed the
+// previously lowered ssThresh without experiencing packet loss.
+//
+// Refer: https://tools.ietf.org/html/rfc8312#section-4.8
+func (c *cubicState) enterCongestionAvoidance() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.7 &
+ // https://tools.ietf.org/html/rfc8312#section-4.8
+ if c.numCongestionEvents == 0 {
+ c.k = 0
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+ }
+}
+
+// updateSlowStart updates the congestion window as per the slow-start
+// algorithm used by NewReno. If adjusting the congestion window crosses the
+// ssThresh, it returns the number of packets that must be consumed in
+// congestion avoidance mode.
+func (c *cubicState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := c.s.sndCwnd + packetsAcked
+ enterCA := false
+ if newcwnd >= c.s.sndSsthresh {
+ newcwnd = c.s.sndSsthresh
+ c.s.sndCAAckCount = 0
+ enterCA = true
+ }
+
+ packetsAcked -= newcwnd - c.s.sndCwnd
+ c.s.sndCwnd = newcwnd
+ if enterCA {
+ c.enterCongestionAvoidance()
+ }
+ return packetsAcked
+}
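+
+// Worked example (illustrative numbers, not from the original change): with
+// sndCwnd = 8, sndSsthresh = 10 and packetsAcked = 5, newcwnd is clamped to
+// 10, the window grows by 2, and updateSlowStart returns 5-2 = 3 packets to
+// be consumed by congestion avoidance.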
+
+// Update updates cubic's internal state variables. It must be called on every
+// ACK received.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) Update(packetsAcked int) {
+ if c.s.sndCwnd < c.s.sndSsthresh {
+ packetsAcked = c.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ } else {
+ c.s.rtt.Lock()
+ srtt := c.s.rtt.srtt
+ c.s.rtt.Unlock()
+ c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt)
+ }
+}
+
+// cubicCwnd computes the CUBIC congestion window t seconds after the last
+// congestion event.
+func (c *cubicState) cubicCwnd(t float64) float64 {
+ return c.c*math.Pow(t, 3.0) + c.wMax
+}
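+
+// Illustrative arithmetic (values assumed, not part of the change): with
+// wMax = 100 segments, beta = 0.7 and c = 0.4, Eq. 2 gives
+// K = cbrt(100*(1-0.7)/0.4) = cbrt(75) ≈ 4.22s, and Eq. 1 then yields
+// cubicCwnd(0-K) = 0.4*(-4.22)^3 + 100 ≈ 70, i.e., the window right after
+// the congestion event equals beta*wMax.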
+
+// getCwnd returns the current congestion window as computed by CUBIC.
+// Refer: https://tools.ietf.org/html/rfc8312#section-4
+func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int {
+ elapsed := time.Since(c.t).Seconds()
+
+ // Compute the window as per Cubic after 'elapsed' time
+ // since last congestion event.
+ c.wC = c.cubicCwnd(elapsed - c.k)
+
+ // Compute the TCP friendly estimate of the congestion window.
+ c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds())
+
+ // Make sure in the TCP friendly region CUBIC performs at least
+ // as well as Reno.
+ if c.wC < c.wEst && float64(sndCwnd) < c.wEst {
+ // TCP Friendly region of cubic.
+ return int(c.wEst)
+ }
+
+ // In Concave/Convex region of CUBIC, calculate what CUBIC window
+ // will be after 1 RTT and use that to grow congestion window
+ // for every ack.
+ tEst := (time.Since(c.t) + srtt).Seconds()
+ wtRtt := c.cubicCwnd(tEst - c.k)
+ // As per section 4.3, for each received ACK, cwnd must be incremented
+ // by (W_cubic(t+RTT) - cwnd)/cwnd.
+ cwnd := float64(sndCwnd)
+ for i := 0; i < packetsAcked; i++ {
+ // Concave/Convex regions of cubic have the same formulas.
+ // See: https://tools.ietf.org/html/rfc8312#section-4.3
+ cwnd += (wtRtt - cwnd) / cwnd
+ }
+ return int(cwnd)
+}
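+
+// Sketch of the per-ACK growth above with assumed numbers: if wtRtt = 110
+// and cwnd = 100, each ACK adds about (110-100)/100 = 0.1 segments, so
+// roughly 10 ACKs are needed to grow the window by one segment in the
+// concave/convex region.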
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (c *cubicState) HandleNDupAcks() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.5
+ c.numCongestionEvents++
+ c.t = time.Now()
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+ c.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionControl.HandleRTOExpired.
+func (c *cubicState) HandleRTOExpired() {
+ // See: https://tools.ietf.org/html/rfc8312#section-4.6
+ c.t = time.Now()
+ c.numCongestionEvents = 0
+ c.wLastMax = c.wMax
+ c.wMax = float64(c.s.sndCwnd)
+
+ c.fastConvergence()
+
+ // We lost a packet, so reduce ssthresh.
+ c.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ c.s.sndCwnd = 1
+}
+
+// fastConvergence implements the logic for Fast Convergence algorithm as
+// described in https://tools.ietf.org/html/rfc8312#section-4.6.
+func (c *cubicState) fastConvergence() {
+ if c.wMax < c.wLastMax {
+ c.wLastMax = c.wMax
+ c.wMax = c.wMax * (1.0 + c.beta) / 2.0
+ } else {
+ c.wLastMax = c.wMax
+ }
+ // Recompute k as wMax may have changed.
+ c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c)
+}
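+
+// Example with assumed values: if the window at loss (wMax = 80) is below
+// the previous peak (wLastMax = 100), fast convergence further reduces
+// wMax to 80*(1+0.7)/2 = 68 and recomputes k = cbrt(68*0.3/0.4) ≈ 3.71s,
+// releasing bandwidth for newer flows sooner.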
+
+// PostRecovery implements congestionControl.PostRecovery.
+func (c *cubicState) PostRecovery() {
+ c.t = time.Now()
+}
+
+// reduceSlowStartThreshold reduces the slow-start threshold as described in
+// https://tools.ietf.org/html/rfc8312#section-4.7.
+func (c *cubicState) reduceSlowStartThreshold() {
+ c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0))
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
new file mode 100644
index 000000000..fd697402e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -0,0 +1,1741 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "math"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/rand"
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tmutex"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateListen
+ stateConnecting
+ stateConnected
+ stateClosed
+ stateError
+)
+
+// Reasons for notifying the protocol goroutine.
+const (
+ notifyNonZeroReceiveWindow = 1 << iota
+ notifyReceiveWindowChanged
+ notifyClose
+ notifyMTUChanged
+ notifyDrain
+ notifyReset
+ notifyKeepaliveChanged
+)
+
+// SACKInfo holds TCP SACK related information for a given endpoint.
+//
+// +stateify savable
+type SACKInfo struct {
+ // Blocks is the maximum number of SACK blocks we track
+ // per endpoint.
+ Blocks [MaxSACKBlocks]header.SACKBlock
+
+ // NumBlocks is the number of valid SACK blocks stored in the
+ // blocks array above.
+ NumBlocks int
+}
+
+// endpoint represents a TCP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, as they are
+// properly synchronized. The protocol implementation, however, runs in a single
+// goroutine.
+//
+// +stateify savable
+type endpoint struct {
+ // workMu is used to arbitrate which goroutine may perform protocol
+ // work. Only the main protocol goroutine is expected to call Lock() on
+ // it, but other goroutines (e.g., send) may call TryLock() to eagerly
+ // perform work without having to wait for the main one to wake up.
+ workMu tmutex.Mutex `state:"nosave"`
+
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue `state:"wait"`
+
+ // lastError represents the last error that the endpoint reported;
+ // access to it is protected by the following mutex.
+ lastErrorMu sync.Mutex `state:"nosave"`
+ lastError *tcpip.Error `state:".(string)"`
+
+ // The following fields are used to manage the receive queue. The
+ // protocol goroutine adds ready-for-delivery segments to rcvList,
+ // which are returned by Read() calls to users.
+ //
+ // Once the peer has closed its send side, rcvClosed is set to true
+ // to indicate to users that no more data is coming.
+ //
+ // rcvListMu can be taken after the endpoint mu below.
+ rcvListMu sync.Mutex `state:"nosave"`
+ rcvList segmentList `state:"wait"`
+ rcvClosed bool
+ rcvBufSize int
+ rcvBufUsed int
+
+ // The following fields are protected by the mutex.
+ mu sync.RWMutex `state:"nosave"`
+ id stack.TransportEndpointID
+ state endpointState `state:".(endpointState)"`
+ isPortReserved bool `state:"manual"`
+ isRegistered bool
+ boundNICID tcpip.NICID `state:"manual"`
+ route stack.Route `state:"manual"`
+ v6only bool
+ isConnectNotified bool
+ // TCP should never broadcast but Linux nevertheless supports enabling/
+ // disabling SO_BROADCAST, albeit as a NOOP.
+ broadcast bool
+
+ // effectiveNetProtos contains the network protocols actually in use. In
+ // most cases it will only contain "netProto", but in cases like IPv6
+ // endpoints with v6only set to false, this could include multiple
+ // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+ // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+ // address).
+ effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"`
+
+ // hardError is meaningful only when state is stateError, it stores the
+ // error to be returned when read/write syscalls are called and the
+ // endpoint is in this state. hardError is protected by mu.
+ hardError *tcpip.Error `state:".(string)"`
+
+ // workerRunning specifies if a worker goroutine is running.
+ workerRunning bool
+
+ // workerCleanup specifies if the worker goroutine must perform cleanup
+ // before exiting. This can only be set to true when workerRunning is
+ // also true, and they're both protected by the mutex.
+ workerCleanup bool
+
+ // sendTSOk is used to indicate when the TS Option has been negotiated.
+ // When sendTSOk is true every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1
+ sendTSOk bool
+
+ // recentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field is
+ // updated if required when a new segment is received by this endpoint.
+ recentTS uint32
+
+ // tsOffset is a randomized offset added to the value of the
+ // TSVal field in the timestamp option.
+ tsOffset uint32
+
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
+ // sackPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ sackPermitted bool
+
+ // sack holds TCP SACK related information for this endpoint.
+ sack SACKInfo
+
+ // reusePort is set to true if SO_REUSEPORT is enabled.
+ reusePort bool
+
+ // delay enables Nagle's algorithm.
+ //
+ // delay is a boolean (0 is false) and must be accessed atomically.
+ delay uint32
+
+ // cork holds back segments until full.
+ //
+ // cork is a boolean (0 is false) and must be accessed atomically.
+ cork uint32
+
+ // scoreboard holds TCP SACK Scoreboard information for this endpoint.
+ scoreboard *SACKScoreboard
+
+ // The options below aren't implemented, but we remember the user
+ // settings because applications expect to be able to set/query these
+ // options.
+ reuseAddr bool
+
+ // slowAck holds the negated state of quick ack. It is stubbed out and
+ // does nothing.
+ //
+ // slowAck is a boolean (0 is false) and must be accessed atomically.
+ slowAck uint32
+
+ // segmentQueue is used to hand received segments to the protocol
+ // goroutine. Segments are queued as long as the queue is not full,
+ // and dropped when it is.
+ segmentQueue segmentQueue `state:"wait"`
+
+ // synRcvdCount is the number of connections for this endpoint that are
+ // in SYN-RCVD state.
+ synRcvdCount int
+
+ // The following fields are used to manage the send buffer. When
+ // segments are ready to be sent, they are added to sndQueue and the
+ // protocol goroutine is signaled via sndWaker.
+ //
+ // When the send side is closed, the protocol goroutine is notified via
+ // sndCloseWaker, and sndClosed is set to true.
+ sndBufMu sync.Mutex `state:"nosave"`
+ sndBufSize int
+ sndBufUsed int
+ sndClosed bool
+ sndBufInQueue seqnum.Size
+ sndQueue segmentList `state:"wait"`
+ sndWaker sleep.Waker `state:"manual"`
+ sndCloseWaker sleep.Waker `state:"manual"`
+
+ // cc stores the name of the Congestion Control algorithm to use for
+ // this endpoint.
+ cc CongestionControlOption
+
+ // The following are used when a "packet too big" control packet is
+ // received. They are protected by sndBufMu. They are used to
+ // communicate to the main protocol goroutine how many such control
+ // messages have been received since the last notification was processed
+ // and what was the smallest MTU seen.
+ packetTooBigCount int
+ sndMTU int
+
+ // newSegmentWaker is used to indicate to the protocol goroutine that
+ // it needs to wake up and handle new segments queued to it.
+ newSegmentWaker sleep.Waker `state:"manual"`
+
+ // notificationWaker is used to indicate to the protocol goroutine that
+ // it needs to wake up and check for notifications.
+ notificationWaker sleep.Waker `state:"manual"`
+
+ // notifyFlags is a bitmask of flags used to indicate to the protocol
+ // goroutine what events it was notified of; this is only accessed atomically.
+ notifyFlags uint32 `state:"nosave"`
+
+ // keepalive manages TCP keepalive state. When the connection is idle
+ // (no data sent or received) for keepaliveIdle, we start sending
+ // keepalives every keepalive.interval. If we send keepalive.count
+ // without hearing a response, the connection is closed.
+ keepalive keepalive
+
+ // acceptedChan is used by a listening endpoint protocol goroutine to
+ // send newly accepted connections to the endpoint so that they can be
+ // read by Accept() calls.
+ acceptedChan chan *endpoint `state:".([]*endpoint)"`
+
+ // The following are only used from the protocol goroutine, and
+ // therefore don't need locks to protect them.
+ rcv *receiver `state:"wait"`
+ snd *sender `state:"wait"`
+
+ // The goroutine drain completion notification channel.
+ drainDone chan struct{} `state:"nosave"`
+
+ // The goroutine undrain notification channel.
+ undrain chan struct{} `state:"nosave"`
+
+ // probe if not nil is invoked on every received segment. It is passed
+ // a copy of the current state of the endpoint.
+ probe stack.TCPProbeFunc `state:"nosave"`
+
+ // The following are only used to assist the restore run to re-connect.
+ bindAddress tcpip.Address
+ connectingAddress tcpip.Address
+
+ gso *stack.GSO
+}
+
+// StopWork halts packet processing. Only to be used in tests.
+func (e *endpoint) StopWork() {
+ e.workMu.Lock()
+}
+
+// ResumeWork resumes packet processing. Only to be used in tests.
+func (e *endpoint) ResumeWork() {
+ e.workMu.Unlock()
+}
+
+// keepalive is a synchronization wrapper used to appease stateify. See the
+// comment in endpoint, where it is used.
+//
+// +stateify savable
+type keepalive struct {
+ sync.Mutex `state:"nosave"`
+ enabled bool
+ idle time.Duration
+ interval time.Duration
+ count int
+ unacked int
+ timer timer `state:"nosave"`
+ waker sleep.Waker `state:"nosave"`
+}
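+
+// Timeline sketch (using the Linux defaults set in newEndpoint below, not a
+// guarantee of exact timing): with idle = 2h, interval = 75s and count = 9,
+// an idle connection whose peer never responds is torn down roughly
+// 2h + 9*75s ≈ 2h11m after the last exchanged data.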
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+ e := &endpoint{
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ rcvBufSize: DefaultBufferSize,
+ sndBufSize: DefaultBufferSize,
+ sndMTU: int(math.MaxInt32),
+ reuseAddr: true,
+ keepalive: keepalive{
+ // Linux defaults.
+ idle: 2 * time.Hour,
+ interval: 75 * time.Second,
+ count: 9,
+ },
+ }
+
+ var ss SendBufferSizeOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ e.sndBufSize = ss.Default
+ }
+
+ var rs ReceiveBufferSizeOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ e.rcvBufSize = rs.Default
+ }
+
+ var cs CongestionControlOption
+ if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil {
+ e.cc = cs
+ }
+
+ if p := stack.GetTCPProbe(); p != nil {
+ e.probe = p
+ }
+
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.workMu.Init()
+ e.workMu.Lock()
+ e.tsOffset = timeStampOffset()
+ return e
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ result := waiter.EventMask(0)
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ switch e.state {
+ case stateInitial, stateBound, stateConnecting:
+ // Ready for nothing.
+
+ case stateClosed, stateError:
+ // Ready for anything.
+ result = mask
+
+ case stateListen:
+ // Check if there's anything in the accepted channel.
+ if (mask & waiter.EventIn) != 0 {
+ if len(e.acceptedChan) > 0 {
+ result |= waiter.EventIn
+ }
+ }
+
+ case stateConnected:
+ // Determine if the endpoint is writable if requested.
+ if (mask & waiter.EventOut) != 0 {
+ e.sndBufMu.Lock()
+ if e.sndClosed || e.sndBufUsed < e.sndBufSize {
+ result |= waiter.EventOut
+ }
+ e.sndBufMu.Unlock()
+ }
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvListMu.Lock()
+ if e.rcvBufUsed > 0 || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvListMu.Unlock()
+ }
+ }
+
+ return result
+}
+
+func (e *endpoint) fetchNotifications() uint32 {
+ return atomic.SwapUint32(&e.notifyFlags, 0)
+}
+
+func (e *endpoint) notifyProtocolGoroutine(n uint32) {
+ for {
+ v := atomic.LoadUint32(&e.notifyFlags)
+ if v&n == n {
+ // The flags are already set.
+ return
+ }
+
+ if atomic.CompareAndSwapUint32(&e.notifyFlags, v, v|n) {
+ if v == 0 {
+ // We are causing a transition from no flags to
+ // at least one flag set, so we must cause the
+ // protocol goroutine to wake up.
+ e.notificationWaker.Assert()
+ }
+ return
+ }
+ }
+}
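+
+// Illustrative trace of the CAS loop above: if two goroutines concurrently
+// set notifyClose while notifyFlags is 0, only the CAS that performs the
+// 0 -> notifyClose transition asserts notificationWaker, so the protocol
+// goroutine wakes once and fetchNotifications drains the whole batch.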
+
+// Close puts the endpoint in a closed state and frees all resources associated
+// with it. It must be called only once and with no other concurrent calls to
+// the endpoint.
+func (e *endpoint) Close() {
+ // Issue a shutdown so that the peer knows we won't send any more data
+ // if we're connected, or stop accepting if we're listening.
+ e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+
+ e.mu.Lock()
+
+ // For listening sockets, we always release ports inline so that they
+ // are immediately available for reuse after Close() is called. If also
+ // registered, we unregister as well otherwise the next user would fail
+ // in Listen() when trying to register.
+ if e.state == stateListen && e.isPortReserved {
+ if e.isRegistered {
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.isRegistered = false
+ }
+
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.isPortReserved = false
+ }
+
+ // Either perform the local cleanup or kick the worker to make sure it
+ // knows it needs to cleanup.
+ tcpip.AddDanglingEndpoint(e)
+ if !e.workerRunning {
+ e.cleanupLocked()
+ } else {
+ e.workerCleanup = true
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+
+ e.mu.Unlock()
+}
+
+// cleanupLocked frees all resources associated with the endpoint. It is called
+// after Close() is called and the worker goroutine (if any) is done with its
+// work.
+func (e *endpoint) cleanupLocked() {
+ // Close all endpoints that might have been accepted by TCP but not by
+ // the client.
+ if e.acceptedChan != nil {
+ close(e.acceptedChan)
+ for n := range e.acceptedChan {
+ n.mu.Lock()
+ n.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ n.mu.Unlock()
+ n.Close()
+ }
+ e.acceptedChan = nil
+ }
+ e.workerCleanup = false
+
+ if e.isRegistered {
+ e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.isRegistered = false
+ }
+
+ if e.isPortReserved {
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ e.isPortReserved = false
+ }
+
+ e.route.Release()
+ tcpip.DeleteDanglingEndpoint(e)
+}
+
+// Read reads data from the endpoint.
+func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.mu.RLock()
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data. Also note that a RST being received
+ // would cause the state to become stateError so we should allow the
+ // reads to proceed before returning an ECONNRESET.
+ e.rcvListMu.Lock()
+ bufUsed := e.rcvBufUsed
+ if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 {
+ e.rcvListMu.Unlock()
+ he := e.hardError
+ e.mu.RUnlock()
+ if s == stateError {
+ return buffer.View{}, tcpip.ControlMessages{}, he
+ }
+ return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ }
+
+ v, err := e.readLocked()
+ e.rcvListMu.Unlock()
+
+ e.mu.RUnlock()
+
+ return v, tcpip.ControlMessages{}, err
+}
+
+func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
+ if e.rcvBufUsed == 0 {
+ if e.rcvClosed || e.state != stateConnected {
+ return buffer.View{}, tcpip.ErrClosedForReceive
+ }
+ return buffer.View{}, tcpip.ErrWouldBlock
+ }
+
+ s := e.rcvList.Front()
+ views := s.data.Views()
+ v := views[s.viewToDeliver]
+ s.viewToDeliver++
+
+ if s.viewToDeliver >= len(views) {
+ e.rcvList.Remove(s)
+ s.decRef()
+ }
+
+ scale := e.rcv.rcvWndScale
+ wasZero := e.zeroReceiveWindow(scale)
+ e.rcvBufUsed -= len(v)
+ if wasZero && !e.zeroReceiveWindow(scale) {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
+
+ return v, nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // Linux completely ignores any address passed to sendto(2) for TCP sockets
+ // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+ // and opts.EndOfRecord are also ignored.
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // The endpoint cannot be written to if it's not connected.
+ if e.state != stateConnected {
+ switch e.state {
+ case stateError:
+ return 0, nil, e.hardError
+ default:
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+ }
+
+ // Nothing to do if the buffer is empty.
+ if p.Size() == 0 {
+ return 0, nil, nil
+ }
+
+ e.sndBufMu.Lock()
+
+ // Check if the connection has already been closed for sends.
+ if e.sndClosed {
+ e.sndBufMu.Unlock()
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Check against the limit.
+ avail := e.sndBufSize - e.sndBufUsed
+ if avail <= 0 {
+ e.sndBufMu.Unlock()
+ return 0, nil, tcpip.ErrWouldBlock
+ }
+
+ v, perr := p.Get(avail)
+ if perr != nil {
+ e.sndBufMu.Unlock()
+ return 0, nil, perr
+ }
+
+ l := len(v)
+ s := newSegmentFromView(&e.route, e.id, v)
+
+ // Add data to the send queue.
+ e.sndBufUsed += l
+ e.sndBufInQueue += seqnum.Size(l)
+ e.sndQueue.PushBack(s)
+
+ e.sndBufMu.Unlock()
+
+ if e.workMu.TryLock() {
+ // Do the work inline.
+ e.handleWrite()
+ e.workMu.Unlock()
+ } else {
+ // Let the protocol goroutine do the work.
+ e.sndWaker.Assert()
+ }
+ return uintptr(l), nil, nil
+}
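+
+// The TryLock fallback above is the general "do the work inline or wake the
+// owner" pattern. A minimal self-contained sketch (assumed names, not the
+// gVisor tmutex API):
+//
+//	type workQueue struct {
+//		busy chan struct{} // buffered, size 1; acts as a try-lock
+//		wake func()        // stands in for sndWaker.Assert
+//	}
+//
+//	func (q *workQueue) submit(work func()) {
+//		select {
+//		case q.busy <- struct{}{}: // acquired: run the work inline
+//			work()
+//			<-q.busy
+//		default: // owner is busy: ask it to pick the work up
+//			q.wake()
+//		}
+//	}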
+
+// Peek reads data without consuming it from the endpoint.
+//
+// This method does not block if there is no data pending.
+func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // The endpoint can be read if it's connected, or if it's already closed
+ // but has some pending unread data.
+ if s := e.state; s != stateConnected && s != stateClosed {
+ if s == stateError {
+ return 0, tcpip.ControlMessages{}, e.hardError
+ }
+ return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState
+ }
+
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
+ if e.rcvBufUsed == 0 {
+ if e.rcvClosed || e.state != stateConnected {
+ return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ }
+ return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ }
+
+ // Make a copy of vec so we can modify the slice headers.
+ vec = append([][]byte(nil), vec...)
+
+ var num uintptr
+
+ for s := e.rcvList.Front(); s != nil; s = s.Next() {
+ views := s.data.Views()
+
+ for i := s.viewToDeliver; i < len(views); i++ {
+ v := views[i]
+
+ for len(v) > 0 {
+ if len(vec) == 0 {
+ return num, tcpip.ControlMessages{}, nil
+ }
+ if len(vec[0]) == 0 {
+ vec = vec[1:]
+ continue
+ }
+
+ n := copy(vec[0], v)
+ v = v[n:]
+ vec[0] = vec[0][n:]
+ num += uintptr(n)
+ }
+ }
+ }
+
+ return num, tcpip.ControlMessages{}, nil
+}
+
+// zeroReceiveWindow checks if the receive window to be announced now would be
+// zero, based on the amount of available buffer and the receive window scaling.
+//
+// It must be called with rcvListMu held.
+func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
+ if e.rcvBufUsed >= e.rcvBufSize {
+ return true
+ }
+
+ return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0
+}
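+
+// For example (numbers assumed for illustration): with rcvBufSize = 65536,
+// rcvBufUsed = 65300 and scale = 8, the window would be advertised as
+// (65536-65300)>>8 = 236>>8 = 0, so the receive window is treated as zero
+// even though 236 bytes of buffer remain.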
+
+// SetSockOpt sets a socket option.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.DelayOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.delay, 0)
+
+ // Handle delayed data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.delay, 1)
+ }
+ return nil
+
+ case tcpip.CorkOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.cork, 0)
+
+ // Handle the corked data.
+ e.sndWaker.Assert()
+ } else {
+ atomic.StoreUint32(&e.cork, 1)
+ }
+ return nil
+
+ case tcpip.ReuseAddressOption:
+ e.mu.Lock()
+ e.reuseAddr = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+ return nil
+
+ case tcpip.QuickAckOption:
+ if v == 0 {
+ atomic.StoreUint32(&e.slowAck, 1)
+ } else {
+ atomic.StoreUint32(&e.slowAck, 0)
+ }
+ return nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ // Make sure the receive buffer size is within the min and max
+ // allowed.
+ var rs ReceiveBufferSizeOption
+ size := int(v)
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+ if size < rs.Min {
+ size = rs.Min
+ }
+ if size > rs.Max {
+ size = rs.Max
+ }
+ }
+
+ mask := uint32(notifyReceiveWindowChanged)
+
+ e.rcvListMu.Lock()
+
+ // Make sure the receive buffer size allows us to send a
+ // non-zero window size.
+ scale := uint8(0)
+ if e.rcv != nil {
+ scale = e.rcv.rcvWndScale
+ }
+ if size>>scale == 0 {
+ size = 1 << scale
+ }
+
+ // Make sure 2*size doesn't overflow.
+ if size > math.MaxInt32/2 {
+ size = math.MaxInt32 / 2
+ }
+
+ wasZero := e.zeroReceiveWindow(scale)
+ e.rcvBufSize = size
+ if wasZero && !e.zeroReceiveWindow(scale) {
+ mask |= notifyNonZeroReceiveWindow
+ }
+ e.rcvListMu.Unlock()
+
+ e.notifyProtocolGoroutine(mask)
+ return nil
+
+ case tcpip.SendBufferSizeOption:
+ // Make sure the send buffer size is within the min and max
+ // allowed.
+ size := int(v)
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if size < ss.Min {
+ size = ss.Min
+ }
+ if size > ss.Max {
+ size = ss.Max
+ }
+ }
+
+ e.sndBufMu.Lock()
+ e.sndBufSize = size
+ e.sndBufMu.Unlock()
+ return nil
+
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // We only allow this to be set when we're in the initial state.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v != 0
+ return nil
+
+ case tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ e.keepalive.enabled = v != 0
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ e.keepalive.idle = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ e.keepalive.interval = time.Duration(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ e.keepalive.count = int(v)
+ e.keepalive.Unlock()
+ e.notifyProtocolGoroutine(notifyKeepaliveChanged)
+ return nil
+
+ case tcpip.BroadcastOption:
+ e.mu.Lock()
+ e.broadcast = v != 0
+ e.mu.Unlock()
+ return nil
+
+ default:
+ return nil
+ }
+}
+
+// readyReceiveSize returns the number of bytes ready to be received.
+func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // The endpoint cannot be in listen state.
+ if e.state == stateListen {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
+ return e.rcvBufUsed, nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ e.lastErrorMu.Lock()
+ err := e.lastError
+ e.lastError = nil
+ e.lastErrorMu.Unlock()
+ return err
+
+ case *tcpip.SendBufferSizeOption:
+ e.sndBufMu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.sndBufMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvListMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
+ e.rcvListMu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ v, err := e.readyReceiveSize()
+ if err != nil {
+ return err
+ }
+
+ *o = tcpip.ReceiveQueueSizeOption(v)
+ return nil
+
+ case *tcpip.DelayOption:
+ *o = 0
+ if v := atomic.LoadUint32(&e.delay); v != 0 {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.CorkOption:
+ *o = 0
+ if v := atomic.LoadUint32(&e.cork); v != 0 {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.ReuseAddressOption:
+ e.mu.RLock()
+ v := e.reuseAddr
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.QuickAckOption:
+ *o = 1
+ if v := atomic.LoadUint32(&e.slowAck); v != 0 {
+ *o = 0
+ }
+ return nil
+
+ case *tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.TCPInfoOption:
+ *o = tcpip.TCPInfoOption{}
+ e.mu.RLock()
+ snd := e.snd
+ e.mu.RUnlock()
+ if snd != nil {
+ snd.rtt.Lock()
+ o.RTT = snd.rtt.srtt
+ o.RTTVar = snd.rtt.rttvar
+ snd.rtt.Unlock()
+ }
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ e.keepalive.Lock()
+ v := e.keepalive.enabled
+ e.keepalive.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.KeepaliveIdleOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIdleOption(e.keepalive.idle)
+ e.keepalive.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveIntervalOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval)
+ e.keepalive.Unlock()
+ return nil
+
+ case *tcpip.KeepaliveCountOption:
+ e.keepalive.Lock()
+ *o = tcpip.KeepaliveCountOption(e.keepalive.count)
+ e.keepalive.Unlock()
+ return nil
+
+ case *tcpip.OutOfBandInlineOption:
+ // We don't currently support disabling this option.
+ *o = 1
+ return nil
+
+ case *tcpip.BroadcastOption:
+ e.mu.Lock()
+ v := e.broadcast
+ e.mu.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ // Fail if using a v4 mapped address on a v6only endpoint.
+ if e.v6only {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == "\x00\x00\x00\x00" {
+ addr.Addr = ""
+ }
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
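+
+// For instance, the v4-mapped address ::ffff:10.0.0.1 arrives here as 16
+// bytes; the slicing above keeps only the trailing 4 bytes (10.0.0.1) and
+// the endpoint proceeds over IPv4, while a v6only endpoint would have
+// already failed with ErrNoRoute.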
+
+// Connect connects the endpoint to its peer.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ return e.connect(addr, true, true)
+}
+
+// connect connects the endpoint to its peer. In the normal non-S/R case, the
+// new connection is expected to run the main goroutine and perform handshake.
+// In restore of previously connected endpoints, both ends will be passively
+// created (so no new handshaking is done); stack-accepted connections that
+// are not yet accepted by the app are restored without running the main
+// goroutine here.
+func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ defer func() {
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ }
+ }()
+
+ connectingAddr := addr.Addr
+
+ netProto, err := e.checkV4Mapped(&addr)
+ if err != nil {
+ return err
+ }
+
+ nicid := addr.NIC
+ switch e.state {
+ case stateBound:
+ // If we're already bound to a NIC but the caller is requesting
+ // that we use a different one now, we cannot proceed.
+ if e.boundNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.boundNICID {
+ return tcpip.ErrNoRoute
+ }
+
+ nicid = e.boundNICID
+
+ case stateInitial:
+ // Nothing to do. We'll eventually fill-in the gaps in the ID
+ // (if any) when we find a route.
+
+ case stateConnecting:
+ // A connection request has already been issued but hasn't
+ // completed yet.
+ return tcpip.ErrAlreadyConnecting
+
+ case stateConnected:
+ // The endpoint is already connected. If the caller hasn't been
+ // notified yet, return success.
+ if !e.isConnectNotified {
+ e.isConnectNotified = true
+ return nil
+ }
+ // Otherwise return that it's already connected.
+ return tcpip.ErrAlreadyConnected
+
+ case stateError:
+ return e.hardError
+
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ origID := e.id
+
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ e.id.LocalAddress = r.LocalAddress
+ e.id.RemoteAddress = r.RemoteAddress
+ e.id.RemotePort = addr.Port
+
+ if e.id.LocalPort != 0 {
+ // The endpoint is bound to a port, attempt to register it.
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort)
+ if err != nil {
+ return err
+ }
+ } else {
+ // The endpoint doesn't have a local port yet, so try to get
+ // one. Make sure that it isn't one that will result in the same
+ // address/port for both local and remote (otherwise this
+ // endpoint would be trying to connect to itself).
+ sameAddr := e.id.LocalAddress == e.id.RemoteAddress
+ if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ if sameAddr && p == e.id.RemotePort {
+ return false, nil
+ }
+ if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) {
+ return false, nil
+ }
+
+ id := e.id
+ id.LocalPort = p
+ switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) {
+ case nil:
+ e.id = id
+ return true, nil
+ case tcpip.ErrPortInUse:
+ return false, nil
+ default:
+ return false, err
+ }
+ }); err != nil {
+ return err
+ }
+ }
+
+ // Remove the port reservation. A reservation may exist when Bind was
+ // called before Connect; in that case we no longer want to hold on
+ // to it.
+ if e.isPortReserved {
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+ e.isPortReserved = false
+ }
+
+ e.isRegistered = true
+ e.state = stateConnecting
+ e.route = r.Clone()
+ e.boundNICID = nicid
+ e.effectiveNetProtos = netProtos
+ e.connectingAddress = connectingAddr
+
+ e.initGSO()
+
+ // Connect in the restore phase does not perform handshake. Restore its
+ // connection setting here.
+ if !handshake {
+ e.segmentQueue.mu.Lock()
+ for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} {
+ for s := l.Front(); s != nil; s = s.Next() {
+ s.id = e.id
+ s.route = r.Clone()
+ e.sndWaker.Assert()
+ }
+ }
+ e.segmentQueue.mu.Unlock()
+ e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0)
+ e.state = stateConnected
+ }
+
+ if run {
+ e.workerRunning = true
+ e.stack.Stats().TCP.ActiveConnectionOpenings.Increment()
+ go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save.
+ }
+
+ return tcpip.ErrConnectStarted
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection to its
+// peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ e.shutdownFlags |= flags
+
+ switch e.state {
+ case stateConnected:
+ // Close for read.
+ if (e.shutdownFlags & tcpip.ShutdownRead) != 0 {
+ // Mark read side as closed.
+ e.rcvListMu.Lock()
+ e.rcvClosed = true
+ rcvBufUsed := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ // If we're fully closed and we have unread data we need to abort
+ // the connection with a RST.
+ if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 {
+ e.notifyProtocolGoroutine(notifyReset)
+ return nil
+ }
+ }
+
+ // Close for write.
+ if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 {
+ e.sndBufMu.Lock()
+
+ if e.sndClosed {
+ // Already closed.
+ e.sndBufMu.Unlock()
+ break
+ }
+
+ // Queue fin segment.
+ s := newSegmentFromView(&e.route, e.id, nil)
+ e.sndQueue.PushBack(s)
+ e.sndBufInQueue++
+
+ // Mark endpoint as closed.
+ e.sndClosed = true
+
+ e.sndBufMu.Unlock()
+
+ // Tell protocol goroutine to close.
+ e.sndCloseWaker.Assert()
+ }
+
+ case stateListen:
+ // Tell protocolListenLoop to stop.
+ if flags&tcpip.ShutdownRead != 0 {
+ e.notifyProtocolGoroutine(notifyClose)
+ }
+
+ default:
+ return tcpip.ErrNotConnected
+ }
+
+ return nil
+}
+
+// Listen puts the endpoint in "listen" mode, which allows it to accept
+// new connections.
+func (e *endpoint) Listen(backlog int) (err *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+ defer func() {
+ if err != nil && !err.IgnoreStats() {
+ e.stack.Stats().TCP.FailedConnectionAttempts.Increment()
+ }
+ }()
+
+ // Allow the backlog to be adjusted if the endpoint is not shutting down.
+ // When the endpoint shuts down, it sets workerCleanup to true, and from
+ // that point onward, acceptedChan is the responsibility of the cleanup()
+ // method (and should not be touched anywhere else, including here).
+ if e.state == stateListen && !e.workerCleanup {
+ // Adjust the size of the channel iff we can fit existing
+ // pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return tcpip.ErrInvalidEndpointState
+ }
+ if cap(e.acceptedChan) == backlog {
+ return nil
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *endpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
+ return nil
+ }
+
+ // Endpoint must be bound before it can transition to listen mode.
+ if e.state != stateBound {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ // Register the endpoint.
+ if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil {
+ return err
+ }
+
+ e.isRegistered = true
+ e.state = stateListen
+ if e.acceptedChan == nil {
+ e.acceptedChan = make(chan *endpoint, backlog)
+ }
+ e.workerRunning = true
+
+ go e.protocolListenLoop( // S/R-SAFE: drained on save.
+ seqnum.Size(e.receiveBufferAvailable()))
+
+ return nil
+}
+
+// startAcceptedLoop sets up required state and starts a goroutine with the
+// main loop for accepted connections.
+func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+ e.waiterQueue = waiterQueue
+ e.workerRunning = true
+ go e.protocolMainLoop(false) // S/R-SAFE: drained on save.
+}
+
+// Accept returns a new endpoint if a peer has established a connection
+// to an endpoint previously set to listen mode.
+func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // Endpoint must be in listen state before it can accept connections.
+ if e.state != stateListen {
+ return nil, nil, tcpip.ErrInvalidEndpointState
+ }
+
+ // Get the new accepted endpoint.
+ var n *endpoint
+ select {
+ case n = <-e.acceptedChan:
+ default:
+ return nil, nil, tcpip.ErrWouldBlock
+ }
+
+ // Start the protocol goroutine.
+ wq := &waiter.Queue{}
+ n.startAcceptedLoop(wq)
+ e.stack.Stats().TCP.PassiveConnectionOpenings.Increment()
+
+ return n, wq, nil
+}
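+
+// A minimal caller-side sketch of the listen/accept flow (wiring and names
+// like stk and serve are assumptions, not part of this change; error
+// handling elided):
+//
+//	wq := new(waiter.Queue)
+//	ep, _ := stk.NewEndpoint(ProtocolNumber, ipv4.ProtocolNumber, wq)
+//	ep.Bind(tcpip.FullAddress{Port: 8080})
+//	ep.Listen(10)
+//	for {
+//		n, _, err := ep.Accept()
+//		if err == tcpip.ErrWouldBlock {
+//			continue // block on wq for waiter.EventIn instead of spinning
+//		}
+//		go serve(n) // serve is a hypothetical handler
+//	}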
+
+// Bind binds the endpoint to a specific local port and optionally address.
+func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Don't allow binding once endpoint is not in the initial state
+ // anymore. This is because once the endpoint goes into a connected or
+ // listen state, it is already bound.
+ if e.state != stateInitial {
+ return tcpip.ErrAlreadyBound
+ }
+
+ e.bindAddress = addr.Addr
+ netProto, err := e.checkV4Mapped(&addr)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
+ }
+
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort)
+ if err != nil {
+ return err
+ }
+
+ e.isPortReserved = true
+ e.effectiveNetProtos = netProtos
+ e.id.LocalPort = port
+
+ // Any failures beyond this point must remove the port registration.
+ defer func() {
+ if err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+ e.isPortReserved = false
+ e.effectiveNetProtos = nil
+ e.id.LocalPort = 0
+ e.id.LocalAddress = ""
+ e.boundNICID = 0
+ }
+ }()
+
+ // If an address is specified, we must ensure that it's one of our
+ // local addresses.
+ if len(addr.Addr) != 0 {
+ nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nic == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ e.boundNICID = nic
+ e.id.LocalAddress = addr.Addr
+ }
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ NIC: e.boundNICID,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ NIC: e.boundNICID,
+ }, nil
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ s := newSegment(r, id, vv)
+ if !s.parse() {
+ e.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stack.Stats().TCP.InvalidSegmentsReceived.Increment()
+ s.decRef()
+ return
+ }
+
+ if !s.csumValid {
+ e.stack.Stats().MalformedRcvdPackets.Increment()
+ e.stack.Stats().TCP.ChecksumErrors.Increment()
+ s.decRef()
+ return
+ }
+
+ e.stack.Stats().TCP.ValidSegmentsReceived.Increment()
+ if (s.flags & header.TCPFlagRst) != 0 {
+ e.stack.Stats().TCP.ResetsReceived.Increment()
+ }
+
+ // Send packet to worker goroutine.
+ if e.segmentQueue.enqueue(s) {
+ e.newSegmentWaker.Assert()
+ } else {
+ // The queue is full, so we drop the segment.
+ e.stack.Stats().DroppedPackets.Increment()
+ s.decRef()
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+ switch typ {
+ case stack.ControlPacketTooBig:
+ e.sndBufMu.Lock()
+ e.packetTooBigCount++
+ if v := int(extra); v < e.sndMTU {
+ e.sndMTU = v
+ }
+ e.sndBufMu.Unlock()
+
+ e.notifyProtocolGoroutine(notifyMTUChanged)
+ }
+}
+
+// updateSndBufferUsage is called by the protocol goroutine when room opens up
+// in the send buffer. The number of newly available bytes is v.
+func (e *endpoint) updateSndBufferUsage(v int) {
+ e.sndBufMu.Lock()
+ notify := e.sndBufUsed >= e.sndBufSize>>1
+ e.sndBufUsed -= v
+ // We only notify when there is half the sndBufSize available after
+ // a full buffer event occurs. This ensures that we don't wake up
+ // writers to queue just 1-2 segments and go back to sleep.
+ notify = notify && e.sndBufUsed < e.sndBufSize>>1
+ e.sndBufMu.Unlock()
+
+ if notify {
+ e.waiterQueue.Notify(waiter.EventOut)
+ }
+}
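+
+// Numeric sketch (values assumed): with sndBufSize = 64KiB the threshold is
+// 32KiB, so a writer that filled the buffer to 40KiB is only woken once
+// acked data drops sndBufUsed below 32KiB; small acks that leave usage
+// above the half mark do not generate wakeups.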
+
+// readyToRead is called by the protocol goroutine when a new segment is ready
+// to be read, or when the connection is closed for receiving (in which case
+// s will be nil).
+func (e *endpoint) readyToRead(s *segment) {
+ e.rcvListMu.Lock()
+ if s != nil {
+ s.incRef()
+ e.rcvBufUsed += s.data.Size()
+ e.rcvList.PushBack(s)
+ } else {
+ e.rcvClosed = true
+ }
+ e.rcvListMu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventIn)
+}
+
+// receiveBufferAvailable calculates how many bytes are still available in the
+// receive buffer.
+func (e *endpoint) receiveBufferAvailable() int {
+ e.rcvListMu.Lock()
+ size := e.rcvBufSize
+ used := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ // We may use more bytes than the buffer size when the receive buffer
+ // shrinks.
+ if used >= size {
+ return 0
+ }
+
+ return size - used
+}
+
+func (e *endpoint) receiveBufferSize() int {
+ e.rcvListMu.Lock()
+ size := e.rcvBufSize
+ e.rcvListMu.Unlock()
+
+ return size
+}
+
+// updateRecentTimestamp updates the recent timestamp using the algorithm
+// described in https://tools.ietf.org/html/rfc7323#section-4.3.
+func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
+ if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
+ e.recentTS = tsVal
+ }
+}
+
+// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
+// the SYN options indicate that timestamp option was negotiated. It also
+// initializes the recentTS with the value provided in synOpts.TSval.
+func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
+ if synOpts.TS {
+ e.sendTSOk = true
+ e.recentTS = synOpts.TSVal
+ }
+}
+
+// timestamp returns the timestamp value to be used in the TSVal field of the
+// timestamp option for outgoing TCP segments for a given endpoint.
+func (e *endpoint) timestamp() uint32 {
+ return tcpTimeStamp(e.tsOffset)
+}
+
+// tcpTimeStamp returns a timestamp offset by the provided offset. This is
+// not inlined above as it's used when SYN cookies are in use and the
+// endpoint is not yet created at the time the SYN cookie is sent.
+func tcpTimeStamp(offset uint32) uint32 {
+ now := time.Now()
+ return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset
+}
+
+// timeStampOffset returns a randomized timestamp offset to be used when sending
+// timestamp values in a timestamp option for a TCP segment.
+func timeStampOffset() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ // Initialize a random tsOffset that will be added to the recentTS
+ // every time the timestamp is sent when the Timestamp option is enabled.
+ //
+ // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
+ // why this is required.
+ //
+ // NOTE: This is not completely to spec as normally this should be
+ // initialized in a manner analogous to how sequence numbers are
+ // randomized on a per-connection basis. But for now this is sufficient.
+ return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
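+
+// The shift-and-or above assembles the four random bytes little-endian
+// first; assuming the standard library were pulled in, the equivalent would
+// be binary.LittleEndian.Uint32(b) from encoding/binary.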
+
+// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
+// if the SYN options indicate that the SACK option was negotiated and the TCP
+// stack is configured to enable TCP SACK option.
+func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+ var v SACKEnabled
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+ // Stack doesn't support SACK. So just return.
+ return
+ }
+ if bool(v) && synOpts.SACKPermitted {
+ e.sackPermitted = true
+ }
+}
+
+// maxOptionSize returns the maximum size of TCP options.
+func (e *endpoint) maxOptionSize() (size int) {
+ var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock
+ options := e.makeOptions(maxSackBlocks[:])
+ size = len(options)
+ putOptions(options)
+
+ return size
+}
+
+// completeState makes a full copy of the endpoint and returns it. This is used
+// before invoking the probe. The state returned may not be fully consistent if
+// there are intervening syscalls when the state is being copied.
+func (e *endpoint) completeState() stack.TCPEndpointState {
+ var s stack.TCPEndpointState
+ s.SegTime = time.Now()
+
+ // Copy EndpointID.
+ e.mu.Lock()
+ s.ID = stack.TCPEndpointID(e.id)
+ e.mu.Unlock()
+
+ // Copy endpoint rcv state.
+ e.rcvListMu.Lock()
+ s.RcvBufSize = e.rcvBufSize
+ s.RcvBufUsed = e.rcvBufUsed
+ s.RcvClosed = e.rcvClosed
+ e.rcvListMu.Unlock()
+
+ // Endpoint TCP Option state.
+ s.SendTSOk = e.sendTSOk
+ s.RecentTS = e.recentTS
+ s.TSOffset = e.tsOffset
+ s.SACKPermitted = e.sackPermitted
+ s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
+ copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks])
+ s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy()
+
+ // Copy endpoint send state.
+ e.sndBufMu.Lock()
+ s.SndBufSize = e.sndBufSize
+ s.SndBufUsed = e.sndBufUsed
+ s.SndClosed = e.sndClosed
+ s.SndBufInQueue = e.sndBufInQueue
+ s.PacketTooBigCount = e.packetTooBigCount
+ s.SndMTU = e.sndMTU
+ e.sndBufMu.Unlock()
+
+ // Copy receiver state.
+ s.Receiver = stack.TCPReceiverState{
+ RcvNxt: e.rcv.rcvNxt,
+ RcvAcc: e.rcv.rcvAcc,
+ RcvWndScale: e.rcv.rcvWndScale,
+ PendingBufUsed: e.rcv.pendingBufUsed,
+ PendingBufSize: e.rcv.pendingBufSize,
+ }
+
+ // Copy sender state.
+ s.Sender = stack.TCPSenderState{
+ LastSendTime: e.snd.lastSendTime,
+ DupAckCount: e.snd.dupAckCount,
+ FastRecovery: stack.TCPFastRecoveryState{
+ Active: e.snd.fr.active,
+ First: e.snd.fr.first,
+ Last: e.snd.fr.last,
+ MaxCwnd: e.snd.fr.maxCwnd,
+ HighRxt: e.snd.fr.highRxt,
+ RescueRxt: e.snd.fr.rescueRxt,
+ },
+ SndCwnd: e.snd.sndCwnd,
+ Ssthresh: e.snd.sndSsthresh,
+ SndCAAckCount: e.snd.sndCAAckCount,
+ Outstanding: e.snd.outstanding,
+ SndWnd: e.snd.sndWnd,
+ SndUna: e.snd.sndUna,
+ SndNxt: e.snd.sndNxt,
+ RTTMeasureSeqNum: e.snd.rttMeasureSeqNum,
+ RTTMeasureTime: e.snd.rttMeasureTime,
+ Closed: e.snd.closed,
+ RTO: e.snd.rto,
+ SRTTInited: e.snd.srttInited,
+ MaxPayloadSize: e.snd.maxPayloadSize,
+ SndWndScale: e.snd.sndWndScale,
+ MaxSentAck: e.snd.maxSentAck,
+ }
+ e.snd.rtt.Lock()
+ s.Sender.SRTT = e.snd.rtt.srtt
+ e.snd.rtt.Unlock()
+
+ if cubic, ok := e.snd.cc.(*cubicState); ok {
+ s.Sender.Cubic = stack.TCPCubicState{
+ WMax: cubic.wMax,
+ WLastMax: cubic.wLastMax,
+ T: cubic.t,
+ TimeSinceLastCongestion: time.Since(cubic.t),
+ C: cubic.c,
+ K: cubic.k,
+ Beta: cubic.beta,
+ WC: cubic.wC,
+ WEst: cubic.wEst,
+ }
+ }
+ return s
+}
+
+func (e *endpoint) initGSO() {
+ if e.route.Capabilities()&stack.CapabilityGSO == 0 {
+ return
+ }
+
+ gso := &stack.GSO{}
+ switch e.route.NetProto {
+ case header.IPv4ProtocolNumber:
+ gso.Type = stack.GSOTCPv4
+ gso.L3HdrLen = header.IPv4MinimumSize
+ case header.IPv6ProtocolNumber:
+ gso.Type = stack.GSOTCPv6
+ gso.L3HdrLen = header.IPv6MinimumSize
+ default:
+ panic(fmt.Sprintf("Unknown netProto: %v", e.netProto))
+ }
+ gso.NeedsCsum = true
+ gso.CsumOffset = header.TCPChecksumOffset
+ gso.MaxSize = e.route.GSOMaxSize()
+ e.gso = gso
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
new file mode 100644
index 000000000..e8aed2875
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -0,0 +1,362 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+func (e *endpoint) drainSegmentLocked() {
+ // Drain only up to once.
+ if e.drainDone != nil {
+ return
+ }
+
+ e.drainDone = make(chan struct{})
+ e.undrain = make(chan struct{})
+ e.mu.Unlock()
+
+ e.notifyProtocolGoroutine(notifyDrain)
+ <-e.drainDone
+
+ e.mu.Lock()
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+ // Stop incoming packets.
+ e.segmentQueue.setLimit(0)
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ switch e.state {
+ case stateInitial, stateBound:
+ case stateConnected:
+ if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 {
+ if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 {
+ panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
+ }
+ e.resetConnectionLocked(tcpip.ErrConnectionAborted)
+ e.mu.Unlock()
+ e.Close()
+ e.mu.Lock()
+ }
+ if !e.workerRunning {
+ // The endpoint must be in acceptedChan or has just been
+ // disconnected and closed.
+ break
+ }
+ fallthrough
+ case stateListen, stateConnecting:
+ e.drainSegmentLocked()
+ if e.state != stateClosed && e.state != stateError {
+ if !e.workerRunning {
+ panic("endpoint has no worker running in listen, connecting, or connected state")
+ }
+ break
+ }
+ fallthrough
+ case stateError, stateClosed:
+ for e.state == stateError && e.workerRunning {
+ e.mu.Unlock()
+ time.Sleep(100 * time.Millisecond)
+ e.mu.Lock()
+ }
+ if e.workerRunning {
+ panic("endpoint still has worker running in closed or error state")
+ }
+ default:
+ panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
+ }
+
+ if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() {
+ panic("endpoint still has waiters upon save")
+ }
+
+ if e.state != stateClosed && !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) {
+ panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state")
+ }
+}
+
+// saveAcceptedChan is invoked by stateify.
+func (e *endpoint) saveAcceptedChan() []*endpoint {
+ if e.acceptedChan == nil {
+ return nil
+ }
+ acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan))
+ for i := 0; i < len(acceptedEndpoints); i++ {
+ select {
+ case ep := <-e.acceptedChan:
+ acceptedEndpoints[i] = ep
+ default:
+ panic("endpoint acceptedChan buffer got consumed by background context")
+ }
+ }
+ for i := 0; i < len(acceptedEndpoints); i++ {
+ select {
+ case e.acceptedChan <- acceptedEndpoints[i]:
+ default:
+ panic("endpoint acceptedChan buffer got populated by background context")
+ }
+ }
+ return acceptedEndpoints
+}
+
+// loadAcceptedChan is invoked by stateify.
+func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) {
+ if cap(acceptedEndpoints) > 0 {
+ e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints))
+ for _, ep := range acceptedEndpoints {
+ e.acceptedChan <- ep
+ }
+ }
+}
+
+// saveState is invoked by stateify.
+func (e *endpoint) saveState() endpointState {
+ return e.state
+}
+
+// Endpoint loading must be done in the following order, determined by
+// endpoint state, to avoid a dangling connecting endpoint without a
+// listening peer, and to avoid conflicts in port reservation.
+var connectedLoading sync.WaitGroup
+var listenLoading sync.WaitGroup
+var connectingLoading sync.WaitGroup
+
+// Bound endpoint loading happens last.
+
+// loadState is invoked by stateify.
+func (e *endpoint) loadState(state endpointState) {
+ // This is to ensure that the loading wait groups include all applicable
+ // endpoints before any asynchronous calls to the Wait() methods.
+ switch state {
+ case stateConnected:
+ connectedLoading.Add(1)
+ case stateListen:
+ listenLoading.Add(1)
+ case stateConnecting:
+ connectingLoading.Add(1)
+ }
+ e.state = state
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+ e.segmentQueue.setLimit(MaxUnprocessedSegments)
+ e.workMu.Init()
+
+ state := e.state
+ switch state {
+ case stateInitial, stateBound, stateListen, stateConnecting, stateConnected:
+		var ss SendBufferSizeOption
+		if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+			if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
+				panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
+			}
+		}
+		var rs ReceiveBufferSizeOption
+		if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+			if e.rcvBufSize < rs.Min || e.rcvBufSize > rs.Max {
+				panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, rs.Min, rs.Max))
+			}
+		}
+ }
+
+ bind := func() {
+ e.state = stateInitial
+ if len(e.bindAddress) == 0 {
+ e.bindAddress = e.id.LocalAddress
+ }
+ if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil {
+ panic("endpoint binding failed: " + err.String())
+ }
+ }
+
+ switch state {
+ case stateConnected:
+ bind()
+ if len(e.connectingAddress) == 0 {
+ // This endpoint is accepted by netstack but not yet by
+ // the app. If the endpoint is IPv6 but the remote
+ // address is IPv4, we need to connect as IPv6 so that
+ // dual-stack mode can be properly activated.
+ if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize {
+ e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress
+ } else {
+ e.connectingAddress = e.id.RemoteAddress
+ }
+ }
+		// Reset the scoreboard to reinitialize the SACK state, since
+		// SACK information is not saved or restored.
+ e.scoreboard.Reset()
+ if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectedLoading.Done()
+ case stateListen:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ bind()
+ backlog := cap(e.acceptedChan)
+ if err := e.Listen(backlog); err != nil {
+ panic("endpoint listening failed: " + err.String())
+ }
+ listenLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateConnecting:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ bind()
+ if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted {
+ panic("endpoint connecting failed: " + err.String())
+ }
+ connectingLoading.Done()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateBound:
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ tcpip.AsyncLoading.Done()
+ }()
+ case stateClosed:
+ if e.isPortReserved {
+ tcpip.AsyncLoading.Add(1)
+ go func() {
+ connectedLoading.Wait()
+ listenLoading.Wait()
+ connectingLoading.Wait()
+ bind()
+ e.state = stateClosed
+ tcpip.AsyncLoading.Done()
+ }()
+ }
+ fallthrough
+ case stateError:
+ tcpip.DeleteDanglingEndpoint(e)
+ }
+}
+
+// saveLastError is invoked by stateify.
+func (e *endpoint) saveLastError() string {
+ if e.lastError == nil {
+ return ""
+ }
+
+ return e.lastError.String()
+}
+
+// loadLastError is invoked by stateify.
+func (e *endpoint) loadLastError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.lastError = loadError(s)
+}
+
+// saveHardError is invoked by stateify.
+func (e *endpoint) saveHardError() string {
+ if e.hardError == nil {
+ return ""
+ }
+
+ return e.hardError.String()
+}
+
+// loadHardError is invoked by stateify.
+func (e *endpoint) loadHardError(s string) {
+ if s == "" {
+ return
+ }
+
+ e.hardError = loadError(s)
+}
+
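+// messageToError maps the string representation of each supported tcpip error
+// back to its canonical value. It is populated lazily by loadError, guarded
+// by populate.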
+var messageToError map[string]*tcpip.Error
+
+var populate sync.Once
+
+func loadError(s string) *tcpip.Error {
+ populate.Do(func() {
+ var errors = []*tcpip.Error{
+ tcpip.ErrUnknownProtocol,
+ tcpip.ErrUnknownNICID,
+ tcpip.ErrUnknownDevice,
+ tcpip.ErrUnknownProtocolOption,
+ tcpip.ErrDuplicateNICID,
+ tcpip.ErrDuplicateAddress,
+ tcpip.ErrNoRoute,
+ tcpip.ErrBadLinkEndpoint,
+ tcpip.ErrAlreadyBound,
+ tcpip.ErrInvalidEndpointState,
+ tcpip.ErrAlreadyConnecting,
+ tcpip.ErrAlreadyConnected,
+ tcpip.ErrNoPortAvailable,
+ tcpip.ErrPortInUse,
+ tcpip.ErrBadLocalAddress,
+ tcpip.ErrClosedForSend,
+ tcpip.ErrClosedForReceive,
+ tcpip.ErrWouldBlock,
+ tcpip.ErrConnectionRefused,
+ tcpip.ErrTimeout,
+ tcpip.ErrAborted,
+ tcpip.ErrConnectStarted,
+ tcpip.ErrDestinationRequired,
+ tcpip.ErrNotSupported,
+ tcpip.ErrQueueSizeNotSupported,
+ tcpip.ErrNotConnected,
+ tcpip.ErrConnectionReset,
+ tcpip.ErrConnectionAborted,
+ tcpip.ErrNoSuchFile,
+ tcpip.ErrInvalidOptionValue,
+ tcpip.ErrNoLinkAddress,
+ tcpip.ErrBadAddress,
+ tcpip.ErrNetworkUnreachable,
+ tcpip.ErrMessageTooLong,
+ tcpip.ErrNoBufferSpace,
+ tcpip.ErrBroadcastDisabled,
+ tcpip.ErrNotPermitted,
+ }
+
+ messageToError = make(map[string]*tcpip.Error)
+ for _, e := range errors {
+ if messageToError[e.String()] != nil {
+ panic("tcpip errors with duplicated message: " + e.String())
+ }
+ messageToError[e.String()] = e
+ }
+ })
+
+ e, ok := messageToError[s]
+ if !ok {
+ panic("unknown error message: " + s)
+ }
+
+ return e
+}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
new file mode 100644
index 000000000..c30b45c2c
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -0,0 +1,171 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Forwarder is a connection request forwarder, which allows clients to decide
+// what to do with a connection request, for example: ignore it, send a RST, or
+// attempt to complete the 3-way handshake.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+ maxInFlight int
+ handler func(*ForwarderRequest)
+
+ mu sync.Mutex
+ inFlight map[stack.TransportEndpointID]struct{}
+ listen *listenContext
+}
+
+// NewForwarder allocates and initializes a new forwarder with the given
+// maximum number of in-flight connection attempts. Once the maximum is
+// reached, new incoming connection requests will be ignored.
+//
+// If rcvWnd is set to zero, the default buffer size is used instead.
+func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
+ if rcvWnd == 0 {
+ rcvWnd = DefaultBufferSize
+ }
+ return &Forwarder{
+ maxInFlight: maxInFlight,
+ handler: handler,
+ inFlight: make(map[stack.TransportEndpointID]struct{}),
+ listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0),
+ }
+}
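+
+// An illustrative sketch of the canonical usage described above (assumes a
+// configured *stack.Stack named s; error handling elided):
+//
+//	fwd := NewForwarder(s, 0 /* rcvWnd: use default */, 128 /* maxInFlight */,
+//		func(r *ForwarderRequest) {
+//			var wq waiter.Queue
+//			ep, err := r.CreateEndpoint(&wq)
+//			if err != nil {
+//				r.Complete(true /* send RST */)
+//				return
+//			}
+//			r.Complete(false)
+//			_ = ep // Hand the endpoint off to the application.
+//		})
+//	s.SetTransportProtocolHandler(ProtocolNumber, fwd.HandlePacket)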
+
+// HandlePacket handles a packet if it is of interest to the forwarder (i.e.,
+// if it's a SYN packet) and returns true in that case. Otherwise the packet
+// is not handled and false is returned.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ s := newSegment(r, id, vv)
+ defer s.decRef()
+
+ // We only care about well-formed SYN packets.
+ if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn {
+ return false
+ }
+
+ opts := parseSynSegmentOptions(s)
+
+ f.mu.Lock()
+ defer f.mu.Unlock()
+
+	// We already have an in-flight request for this id; ignore this one for now.
+ if _, ok := f.inFlight[id]; ok {
+ return true
+ }
+
+ // Ignore the segment if we're beyond the limit.
+ if len(f.inFlight) >= f.maxInFlight {
+ return true
+ }
+
+ // Launch a new goroutine to handle the request.
+ f.inFlight[id] = struct{}{}
+ s.incRef()
+ go f.handler(&ForwarderRequest{ // S/R-SAFE: not used by Sentry.
+ forwarder: f,
+ segment: s,
+ synOptions: opts,
+ })
+
+ return true
+}
+
+// ForwarderRequest represents a connection request received by the forwarder
+// and passed to the client. Clients must eventually call Complete() on it, and
+// may optionally create an endpoint to represent it via CreateEndpoint.
+type ForwarderRequest struct {
+ mu sync.Mutex
+ forwarder *Forwarder
+ segment *segment
+ synOptions header.TCPSynOptions
+}
+
+// ID returns the 4-tuple (src address, src port, dst address, dst port) that
+// represents the connection request.
+func (r *ForwarderRequest) ID() stack.TransportEndpointID {
+ return r.segment.id
+}
+
+// Complete completes the request, and optionally sends a RST segment back to the
+// sender.
+func (r *ForwarderRequest) Complete(sendReset bool) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.segment == nil {
+ panic("Completing already completed forwarder request")
+ }
+
+ // Remove request from the forwarder.
+ r.forwarder.mu.Lock()
+ delete(r.forwarder.inFlight, r.segment.id)
+ r.forwarder.mu.Unlock()
+
+ // If the caller requested, send a reset.
+ if sendReset {
+ replyWithReset(r.segment)
+ }
+
+ // Release all resources.
+ r.segment.decRef()
+ r.segment = nil
+ r.forwarder = nil
+}
+
+// CreateEndpoint creates a TCP endpoint for the connection request, performing
+// the 3-way handshake in the process.
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ if r.segment == nil {
+ return nil, tcpip.ErrInvalidEndpointState
+ }
+
+ f := r.forwarder
+ ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
+ MSS: r.synOptions.MSS,
+ WS: r.synOptions.WS,
+ TS: r.synOptions.TS,
+ TSVal: r.synOptions.TSVal,
+ TSEcr: r.synOptions.TSEcr,
+ SACKPermitted: r.synOptions.SACKPermitted,
+ })
+ if err != nil {
+ return nil, err
+ }
+
+ // Start the protocol goroutine.
+ ep.startAcceptedLoop(queue)
+
+ return ep, nil
+}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
new file mode 100644
index 000000000..b31bcccfa
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -0,0 +1,250 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package tcp contains the implementation of the TCP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing tcp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
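+//
+// A minimal setup sketch (illustrative only; assumes the ipv4 network
+// protocol package and elides error handling):
+//
+//	s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}, stack.Options{})
+//	var wq waiter.Queue
+//	ep, _ := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+//	_ = ep // Bind/Connect/Listen as needed.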
+package tcp
+
+import (
+ "strings"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolName is the string representation of the tcp protocol name.
+ ProtocolName = "tcp"
+
+ // ProtocolNumber is the tcp protocol number.
+ ProtocolNumber = header.TCPProtocolNumber
+
+	// minBufferSize is the smallest size of a receive or send buffer.
+ minBufferSize = 4 << 10 // 4096 bytes.
+
+ // DefaultBufferSize is the default size of the receive and send buffers.
+ DefaultBufferSize = 1 << 20 // 1MB
+
+	// maxBufferSize is the largest size a receive or send buffer can grow to.
+ maxBufferSize = 4 << 20 // 4MB
+
+ // MaxUnprocessedSegments is the maximum number of unprocessed segments
+ // that can be queued for a given endpoint.
+ MaxUnprocessedSegments = 300
+)
+
+// SACKEnabled is an option that can be used to enable SACK support in the TCP
+// protocol. See: https://tools.ietf.org/html/rfc2018.
+type SACKEnabled bool
+
+// SendBufferSizeOption allows the default, min and max send buffer sizes for
+// TCP endpoints to be queried or configured.
+type SendBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+// ReceiveBufferSizeOption allows the default, min and max receive buffer size
+// for TCP endpoints to be queried or configured.
+type ReceiveBufferSizeOption struct {
+ Min int
+ Default int
+ Max int
+}
+
+const (
+ ccReno = "reno"
+ ccCubic = "cubic"
+)
+
+// CongestionControlOption is used to set or query the current congestion
+// control algorithm.
+type CongestionControlOption string
+
+// AvailableCongestionControlOption is used to query the supported congestion
+// control algorithms.
+type AvailableCongestionControlOption string
+
+type protocol struct {
+ mu sync.Mutex
+ sackEnabled bool
+ sendBufferSize SendBufferSizeOption
+ recvBufferSize ReceiveBufferSizeOption
+ congestionControl string
+ availableCongestionControl []string
+ allowedCongestionControl []string
+}
+
+// Number returns the tcp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+ return ProtocolNumber
+}
+
+// NewEndpoint creates a new tcp endpoint.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently
+// unsupported. It implements stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid tcp packet size.
+func (*protocol) MinimumPacketSize() int {
+ return header.TCPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given tcp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ h := header.TCP(v)
+ return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+//
+// RFC 793, page 36, states that "If the connection does not exist (CLOSED) then
+// a reset is sent in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected by this
+// means."
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool {
+ s := newSegment(r, id, vv)
+ defer s.decRef()
+
+ if !s.parse() || !s.csumValid {
+ return false
+ }
+
+ // There's nothing to do if this is already a reset packet.
+ if s.flagIsSet(header.TCPFlagRst) {
+ return true
+ }
+
+ replyWithReset(s)
+ return true
+}
+
+// replyWithReset replies to the given segment with a reset segment.
+func replyWithReset(s *segment) {
+ // Get the seqnum from the packet if the ack flag is set.
+ seq := seqnum.Value(0)
+ if s.flagIsSet(header.TCPFlagAck) {
+ seq = s.ackNumber
+ }
+
+ ack := s.sequenceNumber.Add(s.logicalLen())
+
+ sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */)
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case SACKEnabled:
+ p.mu.Lock()
+ p.sackEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
+ case SendBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.sendBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.recvBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case CongestionControlOption:
+ for _, c := range p.availableCongestionControl {
+ if string(v) == c {
+ p.mu.Lock()
+ p.congestionControl = string(v)
+ p.mu.Unlock()
+ return nil
+ }
+ }
+ return tcpip.ErrInvalidOptionValue
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
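+
+// An illustrative sketch of how these options are typically applied through
+// the stack (assumes a configured *stack.Stack named s):
+//
+//	s.SetTransportProtocolOption(ProtocolNumber, SACKEnabled(true))
+//	s.SetTransportProtocolOption(ProtocolNumber, CongestionControlOption("cubic"))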
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SACKEnabled:
+ p.mu.Lock()
+ *v = SACKEnabled(p.sackEnabled)
+ p.mu.Unlock()
+ return nil
+
+ case *SendBufferSizeOption:
+ p.mu.Lock()
+ *v = p.sendBufferSize
+ p.mu.Unlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ p.mu.Lock()
+ *v = p.recvBufferSize
+ p.mu.Unlock()
+ return nil
+ case *CongestionControlOption:
+ p.mu.Lock()
+ *v = CongestionControlOption(p.congestionControl)
+ p.mu.Unlock()
+ return nil
+ case *AvailableCongestionControlOption:
+ p.mu.Lock()
+ *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " "))
+ p.mu.Unlock()
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
+ return &protocol{
+ sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
+ congestionControl: ccReno,
+ availableCongestionControl: []string{ccReno, ccCubic},
+ }
+ })
+}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
new file mode 100644
index 000000000..b08a0e356
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -0,0 +1,221 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "container/heap"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+// receiver holds the state necessary to receive TCP segments and turn them
+// into a stream of bytes.
+//
+// +stateify savable
+type receiver struct {
+ ep *endpoint
+
+ rcvNxt seqnum.Value
+
+ // rcvAcc is one beyond the last acceptable sequence number. That is,
+	// the "largest" sequence value that the receiver has announced to its
+	// peer that it's willing to accept. This may be different from
+ // rcvNxt + rcvWnd if the receive window is reduced; in that case we
+ // have to reduce the window as we receive more data instead of
+ // shrinking it.
+ rcvAcc seqnum.Value
+
+ rcvWndScale uint8
+
+ closed bool
+
+ pendingRcvdSegments segmentHeap
+ pendingBufUsed seqnum.Size
+ pendingBufSize seqnum.Size
+}
+
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
+ return &receiver{
+ ep: ep,
+ rcvNxt: irs + 1,
+ rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWndScale: rcvWndScale,
+ pendingBufSize: rcvWnd,
+ }
+}
+
+// acceptable checks if the segment sequence number range is acceptable
+// according to the table on page 26 of RFC 793.
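+//
+// For example (illustrative): with rcvNxt = 100 and rcvAcc = 150 (a 50-byte
+// window), a segment with segSeq = 120 and segLen = 10 is acceptable, while
+// a non-empty segment starting at segSeq = 160 is not.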
+func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
+ rcvWnd := r.rcvNxt.Size(r.rcvAcc)
+ if rcvWnd == 0 {
+ return segLen == 0 && segSeq == r.rcvNxt
+ }
+
+ return segSeq.InWindow(r.rcvNxt, rcvWnd) ||
+ seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen)
+}
+
+// getSendParams returns the parameters needed by the sender when building
+// segments to send.
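+//
+// For example (illustrative): with rcvWndScale = 2 and 1000 bytes of receive
+// buffer available, the advertised window field carries 1000 >> 2 = 250.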
+func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
+ // Calculate the window size based on the current buffer size.
+ n := r.ep.receiveBufferAvailable()
+ acc := r.rcvNxt.Add(seqnum.Size(n))
+ if r.rcvAcc.LessThan(acc) {
+ r.rcvAcc = acc
+ }
+
+ return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale
+}
+
+// nonZeroWindow is called when the receive window grows from zero to nonzero;
+// in such cases we may need to send an ack to indicate to our peer that it can
+// resume sending data.
+func (r *receiver) nonZeroWindow() {
+ if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 {
+ // We never got around to announcing a zero window size, so we
+ // don't need to immediately announce a nonzero one.
+ return
+ }
+
+ // Immediately send an ack.
+ r.ep.snd.sendAck()
+}
+
+// consumeSegment attempts to consume a segment that was received by r. The
+// segment may have just been received or may have been received earlier but
+// wasn't ready to be consumed then.
+//
+// Returns true if the segment was consumed, false if it cannot be consumed
+// yet because of a missing segment.
+func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool {
+ if segLen > 0 {
+ // If the segment doesn't include the seqnum we're expecting to
+ // consume now, we're missing a segment. We cannot proceed until
+ // we receive that segment though.
+ if !r.rcvNxt.InWindow(segSeq, segLen) {
+ return false
+ }
+
+ // Trim segment to eliminate already acknowledged data.
+ if segSeq.LessThan(r.rcvNxt) {
+ diff := segSeq.Size(r.rcvNxt)
+ segLen -= diff
+ segSeq.UpdateForward(diff)
+ s.sequenceNumber.UpdateForward(diff)
+ s.data.TrimFront(int(diff))
+ }
+
+ // Move segment to ready-to-deliver list. Wakeup any waiters.
+ r.ep.readyToRead(s)
+
+ } else if segSeq != r.rcvNxt {
+ return false
+ }
+
+ // Update the segment that we're expecting to consume.
+ r.rcvNxt = segSeq.Add(segLen)
+
+ // Trim SACK Blocks to remove any SACK information that covers
+ // sequence numbers that have been consumed.
+ TrimSACKBlockList(&r.ep.sack, r.rcvNxt)
+
+ if s.flagIsSet(header.TCPFlagFin) {
+ r.rcvNxt++
+
+ // Send ACK immediately.
+ r.ep.snd.sendAck()
+
+ // Tell any readers that no more data will come.
+ r.closed = true
+ r.ep.readyToRead(nil)
+
+ // Flush out any pending segments, except the very first one if
+ // it happens to be the one we're handling now because the
+ // caller is using it.
+ first := 0
+ if len(r.pendingRcvdSegments) != 0 && r.pendingRcvdSegments[0] == s {
+ first = 1
+ }
+
+ for i := first; i < len(r.pendingRcvdSegments); i++ {
+ r.pendingRcvdSegments[i].decRef()
+ }
+ r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
+ }
+
+ return true
+}
+
+// handleRcvdSegment handles TCP segments directed at the connection managed by
+// r as they arrive. It is called by the protocol main loop.
+func (r *receiver) handleRcvdSegment(s *segment) {
+ // We don't care about receive processing anymore if the receive side
+ // is closed.
+ if r.closed {
+ return
+ }
+
+ segLen := seqnum.Size(s.data.Size())
+ segSeq := s.sequenceNumber
+
+ // If the sequence number range is outside the acceptable range, just
+ // send an ACK. This is according to RFC 793, page 37.
+ if !r.acceptable(segSeq, segLen) {
+ r.ep.snd.sendAck()
+ return
+ }
+
+ // Defer segment processing if it can't be consumed now.
+ if !r.consumeSegment(s, segSeq, segLen) {
+ if segLen > 0 || s.flagIsSet(header.TCPFlagFin) {
+ // We only store the segment if it's within our buffer
+ // size limit.
+ if r.pendingBufUsed < r.pendingBufSize {
+ r.pendingBufUsed += s.logicalLen()
+ s.incRef()
+ heap.Push(&r.pendingRcvdSegments, s)
+ }
+
+ UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)
+
+ // Immediately send an ack so that the peer knows it may
+ // have to retransmit.
+ r.ep.snd.sendAck()
+ }
+ return
+ }
+
+ // By consuming the current segment, we may have filled a gap in the
+ // sequence number domain that allows pending segments to be consumed
+ // now. So try to do it.
+ for !r.closed && r.pendingRcvdSegments.Len() > 0 {
+ s := r.pendingRcvdSegments[0]
+ segLen := seqnum.Size(s.data.Size())
+ segSeq := s.sequenceNumber
+
+ // Skip segment altogether if it has already been acknowledged.
+ if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) &&
+ !r.consumeSegment(s, segSeq, segLen) {
+ break
+ }
+
+ heap.Pop(&r.pendingRcvdSegments)
+ r.pendingBufUsed -= s.logicalLen()
+ s.decRef()
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go
new file mode 100644
index 000000000..f83ebc717
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/reno.go
@@ -0,0 +1,103 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+// renoState stores the variables related to TCP New Reno congestion
+// control algorithm.
+//
+// +stateify savable
+type renoState struct {
+ s *sender
+}
+
+// newRenoCC initializes the state for the NewReno congestion control algorithm.
+func newRenoCC(s *sender) *renoState {
+ return &renoState{s: s}
+}
+
+// updateSlowStart will update the congestion window as per the slow-start
+// algorithm used by NewReno. If, after adjusting the congestion window, we
+// cross the ssthresh, then it returns the number of packets that must be
+// consumed in congestion avoidance mode.
+func (r *renoState) updateSlowStart(packetsAcked int) int {
+ // Don't let the congestion window cross into the congestion
+ // avoidance range.
+ newcwnd := r.s.sndCwnd + packetsAcked
+ if newcwnd >= r.s.sndSsthresh {
+ newcwnd = r.s.sndSsthresh
+ r.s.sndCAAckCount = 0
+ }
+
+ packetsAcked -= newcwnd - r.s.sndCwnd
+ r.s.sndCwnd = newcwnd
+ return packetsAcked
+}
+
+// updateCongestionAvoidance updates the congestion window in congestion
+// avoidance mode as described in RFC 5681 section 3.1.
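+//
+// For example (illustrative): with sndCwnd = 10, the window grows to 11 only
+// after ten packets (one full window) have been acknowledged, i.e. roughly
+// one cwnd increment per round-trip time.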
+func (r *renoState) updateCongestionAvoidance(packetsAcked int) {
+ // Consume the packets in congestion avoidance mode.
+ r.s.sndCAAckCount += packetsAcked
+ if r.s.sndCAAckCount >= r.s.sndCwnd {
+ r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd
+ r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd
+ }
+}
+
+// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
+// page 6, eq. 4. It is called when we detect congestion in the network.
+func (r *renoState) reduceSlowStartThreshold() {
+ r.s.sndSsthresh = r.s.outstanding / 2
+ if r.s.sndSsthresh < 2 {
+ r.s.sndSsthresh = 2
+	}
+}
+
+// Update implements congestionControl.Update. It updates the congestion state
+// based on the number of packets that were acknowledged.
+func (r *renoState) Update(packetsAcked int) {
+ if r.s.sndCwnd < r.s.sndSsthresh {
+ packetsAcked = r.updateSlowStart(packetsAcked)
+ if packetsAcked == 0 {
+ return
+ }
+ }
+ r.updateCongestionAvoidance(packetsAcked)
+}
+
+// HandleNDupAcks implements congestionControl.HandleNDupAcks.
+func (r *renoState) HandleNDupAcks() {
+ // A retransmit was triggered due to nDupAckThreshold
+ // being hit. Reduce our slow start threshold.
+ r.reduceSlowStartThreshold()
+}
+
+// HandleRTOExpired implements congestionControl.HandleRTOExpired.
+func (r *renoState) HandleRTOExpired() {
+ // We lost a packet, so reduce ssthresh.
+ r.reduceSlowStartThreshold()
+
+ // Reduce the congestion window to 1, i.e., enter slow-start. Per
+ // RFC 5681, page 7, we must use 1 regardless of the value of the
+ // initial congestion window.
+ r.s.sndCwnd = 1
+}
+
+// PostRecovery implements congestionControl.PostRecovery.
+func (r *renoState) PostRecovery() {
+ // noop.
+}
diff --git a/pkg/tcpip/transport/tcp/sack.go b/pkg/tcpip/transport/tcp/sack.go
new file mode 100644
index 000000000..6a013d99b
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack.go
@@ -0,0 +1,99 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // MaxSACKBlocks is the maximum number of SACK blocks stored
+ // at receiver side.
+ MaxSACKBlocks = 6
+)
+
+// UpdateSACKBlocks updates the list of SACK blocks to include the segment
+// specified by segStart->segEnd. If the segment happens to be an out-of-order
+// delivery then the first block in sack.Blocks always includes the
+// segment identified by segStart->segEnd.
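+//
+// For example (illustrative): with rcvNxt = 100 and an existing block
+// {110, 120}, an update with segStart = 115, segEnd = 130 merges the two
+// into the single block {110, 130}, placed at the front of sack.Blocks.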
+func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) {
+ newSB := header.SACKBlock{Start: segStart, End: segEnd}
+ if sack.NumBlocks == 0 {
+ sack.Blocks[0] = newSB
+ sack.NumBlocks = 1
+ return
+ }
+ var n = 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ start, end := sack.Blocks[i].Start, sack.Blocks[i].End
+ if end.LessThanEq(start) || start.LessThanEq(rcvNxt) {
+ // Discard any invalid blocks where end is before start
+ // and discard any sack blocks that are before rcvNxt as
+ // those have already been acked.
+ continue
+ }
+ if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) {
+ // Merge this SACK block into newSB and discard this SACK
+ // block.
+ if start.LessThan(newSB.Start) {
+ newSB.Start = start
+ }
+ if newSB.End.LessThan(end) {
+ newSB.End = end
+ }
+ } else {
+ // Save this block.
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ }
+ if rcvNxt.LessThan(newSB.Start) {
+		// If this was an out-of-order segment then make sure that the
+ // first SACK block is the one that includes the segment.
+ //
+ // See the first bullet point in
+ // https://tools.ietf.org/html/rfc2018#section-4
+ if n == MaxSACKBlocks {
+ // If the number of SACK blocks is equal to
+ // MaxSACKBlocks then discard the last SACK block.
+ n--
+ }
+ for i := n - 1; i >= 0; i-- {
+ sack.Blocks[i+1] = sack.Blocks[i]
+ }
+ sack.Blocks[0] = newSB
+ n++
+ }
+ sack.NumBlocks = n
+}
+
+// TrimSACKBlockList updates the sack block list by removing/modifying any block
+// where start is < rcvNxt.
+func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) {
+ n := 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ if sack.Blocks[i].End.LessThanEq(rcvNxt) {
+ continue
+ }
+ if sack.Blocks[i].Start.LessThan(rcvNxt) {
+ // Shrink this SACK block.
+ sack.Blocks[i].Start = rcvNxt
+ }
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ sack.NumBlocks = n
+}
diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go
new file mode 100644
index 000000000..1c5766a42
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go
@@ -0,0 +1,306 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "fmt"
+ "strings"
+
+ "github.com/google/btree"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // maxSACKBlocks is the maximum number of distinct SACKBlocks the
+	// scoreboard will track. Once there are maxSACKBlocks distinct
+	// blocks, further insertions are silently dropped.
+ maxSACKBlocks = 100
+
+ // defaultBtreeDegree is set to 2 as btree.New(2) results in a 2-3-4
+ // tree.
+ defaultBtreeDegree = 2
+)
+
+// SACKScoreboard stores a set of disjoint SACK ranges.
+//
+// +stateify savable
+type SACKScoreboard struct {
+	// smss is defined in RFC 5681 as follows:
+ //
+ // The SMSS is the size of the largest segment that the sender can
+ // transmit. This value can be based on the maximum transmission unit
+ // of the network, the path MTU discovery [RFC1191, RFC4821] algorithm,
+ // RMSS (see next item), or other factors. The size does not include
+ // the TCP/IP headers and options.
+ smss uint16
+ maxSACKED seqnum.Value
+ sacked seqnum.Size `state:"nosave"`
+ ranges *btree.BTree `state:"nosave"`
+}
+
+// NewSACKScoreboard returns a new SACK Scoreboard.
+func NewSACKScoreboard(smss uint16, iss seqnum.Value) *SACKScoreboard {
+ return &SACKScoreboard{
+ smss: smss,
+ ranges: btree.New(defaultBtreeDegree),
+ maxSACKED: iss,
+ }
+}
+
+// Reset erases all known range information from the SACK scoreboard.
+func (s *SACKScoreboard) Reset() {
+ s.ranges = btree.New(defaultBtreeDegree)
+ s.sacked = 0
+}
+
+// Insert inserts/merges the provided SACKBlock into the scoreboard.
+func (s *SACKScoreboard) Insert(r header.SACKBlock) {
+ if s.ranges.Len() >= maxSACKBlocks {
+ return
+ }
+
+ // Check if we can merge the new range with a range before or after it.
+ var toDelete []btree.Item
+ if s.maxSACKED.LessThan(r.End - 1) {
+ s.maxSACKED = r.End - 1
+ }
+ s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sacked := i.(header.SACKBlock)
+ // There is a hole between these two SACK blocks, so we can't
+ // merge anymore.
+ if r.End.LessThan(sacked.Start) {
+ return false
+ }
+ // There is some overlap at this point, merge the blocks and
+ // delete the other one.
+ //
+ // ----sS--------sE
+ // r.S---------------rE
+ // -------sE
+ if sacked.End.LessThan(r.End) {
+ // sacked is contained in the newly inserted range.
+ // Delete this block.
+ toDelete = append(toDelete, i)
+ return true
+ }
+		// sacked covers a range past the end of the newly inserted
+ // block.
+ r.End = sacked.End
+ toDelete = append(toDelete, i)
+ return true
+ })
+
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sacked := i.(header.SACKBlock)
+ // sA------sE
+ // rA----rE
+ if sacked.End.LessThan(r.Start) {
+ return false
+ }
+ // The previous range extends into the current block. Merge it
+ // into the newly inserted range and delete the other one.
+ //
+ // <-rA---rE----<---rE--->
+ // sA--------------sE
+ r.Start = sacked.Start
+ // Extend r to cover sacked if sacked extends past r.
+ if r.End.LessThan(sacked.End) {
+ r.End = sacked.End
+ }
+ toDelete = append(toDelete, i)
+ return true
+ })
+ for _, i := range toDelete {
+ if sb := s.ranges.Delete(i); sb != nil {
+ sb := i.(header.SACKBlock)
+ s.sacked -= sb.Start.Size(sb.End)
+ }
+ }
+
+ replaced := s.ranges.ReplaceOrInsert(r)
+ if replaced == nil {
+ s.sacked += r.Start.Size(r.End)
+ }
+}
+
+// IsSACKED returns true if the given range of sequence numbers denoted by r
+// is already covered by SACK information in the scoreboard.
+func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool {
+ if s.Empty() {
+ return false
+ }
+
+ found := false
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.End.LessThan(r.Start) {
+ return false
+ }
+ if sacked.Contains(r) {
+ found = true
+ return false
+ }
+ return true
+ })
+ return found
+}
+
+// String returns a human-readable representation of the scoreboard state; it
+// implements fmt.Stringer.
+func (s *SACKScoreboard) String() string {
+ var str strings.Builder
+ str.WriteString("SACKScoreboard: {")
+ s.ranges.Ascend(func(i btree.Item) bool {
+ str.WriteString(fmt.Sprintf("%v,", i))
+ return true
+ })
+ str.WriteString("}\n")
+ return str.String()
+}
+
+// Delete removes all SACK information prior to seq.
+func (s *SACKScoreboard) Delete(seq seqnum.Value) {
+ if s.Empty() {
+ return
+ }
+ toDelete := []btree.Item{}
+ toInsert := []btree.Item{}
+ r := header.SACKBlock{seq, seq.Add(1)}
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ if i == r {
+ return true
+ }
+ sb := i.(header.SACKBlock)
+ toDelete = append(toDelete, i)
+ if sb.End.LessThanEq(seq) {
+ s.sacked -= sb.Start.Size(sb.End)
+ } else {
+ newSB := header.SACKBlock{seq, sb.End}
+ toInsert = append(toInsert, newSB)
+ s.sacked -= sb.Start.Size(seq)
+ }
+ return true
+ })
+ for _, sb := range toDelete {
+ s.ranges.Delete(sb)
+ }
+ for _, sb := range toInsert {
+ s.ranges.ReplaceOrInsert(sb)
+ }
+}
+
+// Copy provides a copy of the SACK scoreboard.
+func (s *SACKScoreboard) Copy() (sackBlocks []header.SACKBlock, maxSACKED seqnum.Value) {
+ s.ranges.Ascend(func(i btree.Item) bool {
+ sackBlocks = append(sackBlocks, i.(header.SACKBlock))
+ return true
+ })
+ return sackBlocks, s.maxSACKED
+}
+
+// IsRangeLost implements the IsLost(SeqNum) operation defined in RFC 6675
+// section 4 but operates on a range of sequence numbers and returns true if
+// there are at least nDupAckThreshold SACK blocks greater than the range being
+// checked or if at least (nDupAckThreshold-1)*s.smss bytes have been SACKED
+// with sequence numbers greater than the block being checked.
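+//
+// For example, with nDupAckThreshold = 3, a range is reported lost once three
+// distinct SACKed blocks lie wholly above it, or once at least 2*SMSS bytes
+// above it have been SACKed.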
+func (s *SACKScoreboard) IsRangeLost(r header.SACKBlock) bool {
+ if s.Empty() {
+ return false
+ }
+ nDupSACK := 0
+ nDupSACKBytes := seqnum.Size(0)
+ isLost := false
+
+	// We need to check if the immediately lower sacked range (if any)
+	// contains or partially overlaps with r.
+ searchMore := true
+ s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.Contains(r) {
+ searchMore = false
+ return false
+ }
+ if sacked.End.LessThanEq(r.Start) {
+ // all sequence numbers covered by sacked are below
+ // r so we continue searching.
+ return false
+ }
+		// There is a partial overlap. In this case r.Start lies
+		// between sacked.Start & sacked.End and r.End extends beyond
+		// sacked.End.
+		// Move r.Start to sacked.End and continue searching for
+		// blocks above r.Start.
+ r.Start = sacked.End
+ return false
+ })
+
+ if !searchMore {
+ return isLost
+ }
+
+ s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool {
+ sacked := i.(header.SACKBlock)
+ if sacked.Contains(r) {
+ return false
+ }
+ nDupSACKBytes += sacked.Start.Size(sacked.End)
+ nDupSACK++
+ if nDupSACK >= nDupAckThreshold || nDupSACKBytes >= seqnum.Size((nDupAckThreshold-1)*s.smss) {
+ isLost = true
+ return false
+ }
+ return true
+ })
+ return isLost
+}
+
+// IsLost implements the IsLost(SeqNum) operation defined in RFC3517 section
+// 4.
+//
+// This routine returns whether the given sequence number is considered to be
+// lost. The routine returns true when either nDupAckThreshold discontiguous
+// SACKed sequences have arrived above 'SeqNum' or (nDupAckThreshold * SMSS)
+// bytes with sequence numbers greater than 'SeqNum' have been SACKed.
+// Otherwise, the routine returns false.
+func (s *SACKScoreboard) IsLost(seq seqnum.Value) bool {
+ return s.IsRangeLost(header.SACKBlock{seq, seq.Add(1)})
+}
+
+// Empty returns true if the SACK scoreboard has no entries, false otherwise.
+func (s *SACKScoreboard) Empty() bool {
+ return s.ranges.Len() == 0
+}
+
+// Sacked returns the current number of bytes held in the SACK scoreboard.
+func (s *SACKScoreboard) Sacked() seqnum.Size {
+ return s.sacked
+}
+
+// MaxSACKED returns the highest sequence number ever inserted in the SACK
+// scoreboard.
+func (s *SACKScoreboard) MaxSACKED() seqnum.Value {
+ return s.maxSACKED
+}
+
+// SMSS returns the sender's MSS as held by the SACK scoreboard.
+func (s *SACKScoreboard) SMSS() uint16 {
+ return s.smss
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
new file mode 100644
index 000000000..450d9fbc1
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -0,0 +1,186 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// segment represents a TCP segment. It holds the payload and parsed TCP
+// segment information, and can be added to intrusive lists.
+// segment is mostly immutable; the only field allowed to change is
+// viewToDeliver.
+//
+// +stateify savable
+type segment struct {
+ segmentEntry
+ refCnt int32
+ id stack.TransportEndpointID `state:"manual"`
+ route stack.Route `state:"manual"`
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+	// views is used as backing storage for data when the payload fits,
+	// avoiding a separate allocation (see buffer.VectorisedView.Clone).
+ views [8]buffer.View `state:"nosave"`
+ // viewToDeliver keeps track of the next View that should be
+ // delivered by the Read endpoint.
+ viewToDeliver int
+ sequenceNumber seqnum.Value
+ ackNumber seqnum.Value
+ flags uint8
+ window seqnum.Size
+ // csum is only populated for received segments.
+ csum uint16
+ // csumValid is true if the csum in the received segment is valid.
+ csumValid bool
+
+ // parsedOptions stores the parsed values from the options in the segment.
+ parsedOptions header.TCPOptions
+ options []byte `state:".([]byte)"`
+ hasNewSACKInfo bool
+ rcvdTime time.Time `state:".(unixTime)"`
+ // xmitTime is the last transmit time of this segment. A zero value
+ // indicates that the segment has yet to be transmitted.
+ xmitTime time.Time `state:".(unixTime)"`
+}
+
+func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment {
+ s := &segment{
+ refCnt: 1,
+ id: id,
+ route: r.Clone(),
+ }
+ s.data = vv.Clone(s.views[:])
+ s.rcvdTime = time.Now()
+ return s
+}
+
+func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
+ s := &segment{
+ refCnt: 1,
+ id: id,
+ route: r.Clone(),
+ }
+ s.views[0] = v
+ s.data = buffer.NewVectorisedView(len(v), s.views[:1])
+ s.rcvdTime = time.Now()
+ return s
+}
+
+func (s *segment) clone() *segment {
+ t := &segment{
+ refCnt: 1,
+ id: s.id,
+ sequenceNumber: s.sequenceNumber,
+ ackNumber: s.ackNumber,
+ flags: s.flags,
+ window: s.window,
+ route: s.route.Clone(),
+ viewToDeliver: s.viewToDeliver,
+ rcvdTime: s.rcvdTime,
+ }
+ t.data = s.data.Clone(t.views[:])
+ return t
+}
+
+func (s *segment) flagIsSet(flag uint8) bool {
+ return (s.flags & flag) != 0
+}
+
+func (s *segment) decRef() {
+ if atomic.AddInt32(&s.refCnt, -1) == 0 {
+ s.route.Release()
+ }
+}
+
+func (s *segment) incRef() {
+ atomic.AddInt32(&s.refCnt, 1)
+}
+
+// logicalLen is the segment length in the sequence number space. It's defined
+// as the data length plus one for each of the SYN and FIN bits set.
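+//
+// For example, a FIN segment carrying 10 bytes of payload occupies 11
+// sequence numbers, and a bare SYN occupies 1.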
+func (s *segment) logicalLen() seqnum.Size {
+ l := seqnum.Size(s.data.Size())
+ if s.flagIsSet(header.TCPFlagSyn) {
+ l++
+ }
+ if s.flagIsSet(header.TCPFlagFin) {
+ l++
+ }
+ return l
+}
+
+// parse populates the sequence & ack numbers, flags, and window fields of the
+// segment from the TCP header stored in the data. It then updates the view to
+// skip the header.
+//
+// Returns boolean indicating if the parsing was successful.
+//
+// If checksum verification is not offloaded then parse also verifies the
+// TCP checksum and stores the checksum and result of checksum verification in
+// the csum and csumValid fields of the segment.
+func (s *segment) parse() bool {
+ h := header.TCP(s.data.First())
+
+ // h is the header followed by the payload. We check that the offset to
+ // the data respects the following constraints:
+ // 1. That it's at least the minimum header size; if we don't do this
+	// then part of the header would be delivered to the user.
+ // 2. That the header fits within the buffer; if we don't do this, we
+ // would panic when we tried to access data beyond the buffer.
+ //
+ // N.B. The segment has already been validated as having at least the
+ // minimum TCP size before reaching here, so it's safe to read the
+ // fields.
+ offset := int(h.DataOffset())
+ if offset < header.TCPMinimumSize || offset > len(h) {
+ return false
+ }
+
+ s.options = []byte(h[header.TCPMinimumSize:offset])
+ s.parsedOptions = header.ParseTCPOptions(s.options)
+
+ // Query the link capabilities to decide if checksum validation is
+ // required.
+ verifyChecksum := true
+ if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 {
+ s.csumValid = true
+ verifyChecksum = false
+ s.data.TrimFront(offset)
+ }
+ if verifyChecksum {
+ s.csum = h.Checksum()
+ xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()))
+ xsum = h.CalculateChecksum(xsum)
+ s.data.TrimFront(offset)
+ xsum = header.ChecksumVV(s.data, xsum)
+ s.csumValid = xsum == 0xffff
+ }
+
+ s.sequenceNumber = seqnum.Value(h.SequenceNumber())
+ s.ackNumber = seqnum.Value(h.AckNumber())
+ s.flags = h.Flags()
+ s.window = seqnum.Size(h.WindowSize())
+ return true
+}
+
+// sackBlock returns a header.SACKBlock that represents this segment.
+func (s *segment) sackBlock() header.SACKBlock {
+ return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())}
+}
diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go
new file mode 100644
index 000000000..9fd061d7d
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_heap.go
@@ -0,0 +1,46 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+type segmentHeap []*segment
+
+// Len returns the length of h.
+func (h segmentHeap) Len() int {
+ return len(h)
+}
+
+// Less determines whether the i-th element of h is less than the j-th element.
+func (h segmentHeap) Less(i, j int) bool {
+ return h[i].sequenceNumber.LessThan(h[j].sequenceNumber)
+}
+
+// Swap swaps the i-th and j-th elements of h.
+func (h segmentHeap) Swap(i, j int) {
+ h[i], h[j] = h[j], h[i]
+}
+
+// Push adds x as the last element of h.
+func (h *segmentHeap) Push(x interface{}) {
+ *h = append(*h, x.(*segment))
+}
+
+// Pop removes the last element of h and returns it.
+func (h *segmentHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[:n-1]
+ return x
+}
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
new file mode 100644
index 000000000..e0759225e
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -0,0 +1,79 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "sync"
+)
+
+// segmentQueue is a bounded, thread-safe queue of TCP segments.
+//
+// +stateify savable
+type segmentQueue struct {
+ mu sync.Mutex `state:"nosave"`
+ list segmentList `state:"wait"`
+ limit int
+ used int
+}
+
+// empty determines if the queue is empty.
+func (q *segmentQueue) empty() bool {
+ q.mu.Lock()
+ r := q.used == 0
+ q.mu.Unlock()
+
+ return r
+}
+
+// setLimit updates the limit. No segments are immediately dropped, even if
+// the queue becomes full under the new limit.
+func (q *segmentQueue) setLimit(limit int) {
+ q.mu.Lock()
+ q.limit = limit
+ q.mu.Unlock()
+}
+
+// enqueue adds the given segment to the queue.
+//
+// Returns true when the segment is successfully added to the queue, in which
+// case ownership of the reference is transferred to the queue. It returns
+// false if the queue is full, in which case ownership is retained by the
+// caller.
+func (q *segmentQueue) enqueue(s *segment) bool {
+ q.mu.Lock()
+ r := q.used < q.limit
+ if r {
+ q.list.PushBack(s)
+ q.used++
+ }
+ q.mu.Unlock()
+
+ return r
+}
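+
+// An illustrative caller pattern (sketch): on enqueue failure the caller
+// keeps its reference and must release it itself, e.g.:
+//
+//	if !q.enqueue(s) {
+//		s.decRef() // Queue full; drop the segment.
+//	}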
+
+// dequeue removes and returns the next segment from queue, if one exists.
+// Ownership is transferred to the caller, who is responsible for decrementing
+// the ref count when done.
+func (q *segmentQueue) dequeue() *segment {
+ q.mu.Lock()
+ s := q.list.Front()
+ if s != nil {
+ q.list.Remove(s)
+ q.used--
+ }
+ q.mu.Unlock()
+
+ return s
+}
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
new file mode 100644
index 000000000..dd7e14aa6
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// saveData is invoked by stateify.
+func (s *segment) saveData() buffer.VectorisedView {
+	// We cannot save s.data directly as s.data.views may alias s.views,
+	// which is not allowed by the state framework (in-struct pointer).
+ v := make([]buffer.View, len(s.data.Views()))
+ // For views already delivered, we cannot save them directly as they may
+ // have already been sliced and saved elsewhere (e.g., readViews).
+ for i := 0; i < s.viewToDeliver; i++ {
+ v[i] = append([]byte(nil), s.data.Views()[i]...)
+ }
+ for i := s.viewToDeliver; i < len(v); i++ {
+ v[i] = s.data.Views()[i]
+ }
+ return buffer.NewVectorisedView(s.data.Size(), v)
+}
+
+// loadData is invoked by stateify.
+func (s *segment) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the s.data = data.Clone(s.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+ // data.views will be allocated anyway so there really is little point
+	// in utilizing s.views for data.views.
+ s.data = data
+}
+
+// saveOptions is invoked by stateify.
+func (s *segment) saveOptions() []byte {
+ // We cannot save s.options directly as it may point to s.data's trimmed
+	// tail, which is not allowed by the state framework (in-struct pointer).
+ b := make([]byte, 0, cap(s.options))
+ return append(b, s.options...)
+}
+
+// loadOptions is invoked by stateify.
+func (s *segment) loadOptions(options []byte) {
+ // NOTE: We cannot point s.options back into s.data's trimmed tail. But
+	// it is OK as they do not need to be aliased. Plus, options is already
+ // allocated so there is no cost here.
+ s.options = options
+}
+
+// saveRcvdTime is invoked by stateify.
+func (s *segment) saveRcvdTime() unixTime {
+ return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()}
+}
+
+// loadRcvdTime is invoked by stateify.
+func (s *segment) loadRcvdTime(unix unixTime) {
+ s.rcvdTime = time.Unix(unix.second, unix.nano)
+}
+
+// saveXmitTime is invoked by stateify.
+func (s *segment) saveXmitTime() unixTime {
+	return unixTime{s.xmitTime.Unix(), s.xmitTime.UnixNano()}
+}
+
+// loadXmitTime is invoked by stateify.
+func (s *segment) loadXmitTime(unix unixTime) {
+	s.xmitTime = time.Unix(unix.second, unix.nano)
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
new file mode 100644
index 000000000..afc1d0a55
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -0,0 +1,1180 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "math"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ // minRTO is the minimum allowed value for the retransmit timeout.
+ minRTO = 200 * time.Millisecond
+
+ // InitialCwnd is the initial congestion window.
+ InitialCwnd = 10
+
+ // nDupAckThreshold is the number of duplicate ACK's required
+ // before fast-retransmit is entered.
+ nDupAckThreshold = 3
+)
+
+// congestionControl is an interface that must be implemented by any supported
+// congestion control algorithm.
+type congestionControl interface {
+ // HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold
+ // just before entering fast retransmit.
+ HandleNDupAcks()
+
+ // HandleRTOExpired is invoked when the retransmit timer expires.
+ HandleRTOExpired()
+
+ // Update is invoked when processing inbound acks. It's passed the
+	// number of packets that were acked by the most recent cumulative
+ // acknowledgement.
+ Update(packetsAcked int)
+
+ // PostRecovery is invoked when the sender is exiting a fast retransmit/
+ // recovery phase. This provides congestion control algorithms a way
+ // to adjust their state when exiting recovery.
+ PostRecovery()
+}
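+
+// As an illustration (a hypothetical sketch, not part of this change), a
+// trivial implementation of this interface that never adjusts the
+// congestion window could look like:
+//
+//	type fixedCC struct{}
+//
+//	func (fixedCC) HandleNDupAcks()         {}
+//	func (fixedCC) HandleRTOExpired()       {}
+//	func (fixedCC) Update(packetsAcked int) {}
+//	func (fixedCC) PostRecovery()           {}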
+
+// sender holds the state necessary to send TCP segments.
+//
+// +stateify savable
+type sender struct {
+ ep *endpoint
+
+ // lastSendTime is the timestamp when the last packet was sent.
+ lastSendTime time.Time `state:".(unixTime)"`
+
+ // dupAckCount is the number of duplicated acks received. It is used for
+ // fast retransmit.
+ dupAckCount int
+
+ // fr holds state related to fast recovery.
+ fr fastRecovery
+
+ // sndCwnd is the congestion window, in packets.
+ sndCwnd int
+
+ // sndSsthresh is the threshold between slow start and congestion
+ // avoidance.
+ sndSsthresh int
+
+ // sndCAAckCount is the number of packets acknowledged during congestion
+ // avoidance. When enough packets have been ack'd (typically cwnd
+ // packets), the congestion window is incremented by one.
+ sndCAAckCount int
+
+ // outstanding is the number of outstanding packets, that is, packets
+ // that have been sent but not yet acknowledged.
+ outstanding int
+
+ // sndWnd is the send window size.
+ sndWnd seqnum.Size
+
+ // sndUna is the next unacknowledged sequence number.
+ sndUna seqnum.Value
+
+ // sndNxt is the sequence number of the next segment to be sent.
+ sndNxt seqnum.Value
+
+ // sndNxtList is the sequence number of the next segment to be added to
+ // the send list.
+ sndNxtList seqnum.Value
+
+ // rttMeasureSeqNum is the sequence number being used for the latest RTT
+ // measurement.
+ rttMeasureSeqNum seqnum.Value
+
+ // rttMeasureTime is the time when the rttMeasureSeqNum was sent.
+ rttMeasureTime time.Time `state:".(unixTime)"`
+
+ closed bool
+ writeNext *segment
+ writeList segmentList
+ resendTimer timer `state:"nosave"`
+ resendWaker sleep.Waker `state:"nosave"`
+
+ // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time",
+ // "round-trip time variation" and "retransmit timeout", as defined in
+ // section 2 of RFC 6298.
+ rtt rtt
+ rto time.Duration
+ srttInited bool
+
+ // maxPayloadSize is the maximum size of the payload of a given segment.
+ // It is initialized on demand.
+ maxPayloadSize int
+
+ // gso is set if generic segmentation offload is enabled.
+ gso bool
+
+ // sndWndScale is the number of bits to shift left when reading the send
+ // window size from a segment.
+ sndWndScale uint8
+
+	// maxSentAck is the maximum acknowledgement actually sent.
+ maxSentAck seqnum.Value
+
+ // cc is the congestion control algorithm in use for this sender.
+ cc congestionControl
+}
+
+// rtt is a synchronization wrapper used to appease stateify. See the comment
+// in sender, where it is used.
+//
+// +stateify savable
+type rtt struct {
+ sync.Mutex `state:"nosave"`
+
+ srtt time.Duration
+ rttvar time.Duration
+}
+
+// fastRecovery holds information related to fast recovery from a packet loss.
+//
+// +stateify savable
+type fastRecovery struct {
+	// active indicates whether the endpoint is in fast recovery. The
+	// following fields are only meaningful when active is true.
+ active bool
+
+ // first and last represent the inclusive sequence number range being
+ // recovered.
+ first seqnum.Value
+ last seqnum.Value
+
+ // maxCwnd is the maximum value the congestion window may be inflated to
+ // due to duplicate acks. This exists to avoid attacks where the
+ // receiver intentionally sends duplicate acks to artificially inflate
+ // the sender's cwnd.
+ maxCwnd int
+
+ // highRxt is the highest sequence number which has been retransmitted
+ // during the current loss recovery phase.
+ // See: RFC 6675 Section 2 for details.
+ highRxt seqnum.Value
+
+ // rescueRxt is the highest sequence number which has been
+ // optimistically retransmitted to prevent stalling of the ACK clock
+ // when there is loss at the end of the window and no new data is
+ // available for transmission.
+ // See: RFC 6675 Section 2 for details.
+ rescueRxt seqnum.Value
+}
+
+func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
+ // The sender MUST reduce the TCP data length to account for any IP or
+ // TCP options that it is including in the packets that it sends.
+ // See: https://tools.ietf.org/html/rfc6691#section-2
+ maxPayloadSize := int(mss) - ep.maxOptionSize()
+
+ s := &sender{
+ ep: ep,
+ sndCwnd: InitialCwnd,
+ sndSsthresh: math.MaxInt64,
+ sndWnd: sndWnd,
+ sndUna: iss + 1,
+ sndNxt: iss + 1,
+ sndNxtList: iss + 1,
+ rto: 1 * time.Second,
+ rttMeasureSeqNum: iss + 1,
+ lastSendTime: time.Now(),
+ maxPayloadSize: maxPayloadSize,
+ maxSentAck: irs + 1,
+ fr: fastRecovery{
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
+ last: iss,
+ highRxt: iss,
+ rescueRxt: iss,
+ },
+ gso: ep.gso != nil,
+ }
+
+ if s.gso {
+ s.ep.gso.MSS = uint16(maxPayloadSize)
+ }
+
+ s.cc = s.initCongestionControl(ep.cc)
+
+ // A negative sndWndScale means that no scaling is in use, otherwise we
+ // store the scaling value.
+ if sndWndScale > 0 {
+ s.sndWndScale = uint8(sndWndScale)
+ }
+
+ s.resendTimer.init(&s.resendWaker)
+
+ s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
+
+ // Initialize SACK Scoreboard after updating max payload size as we use
+ // the maxPayloadSize as the smss when determining if a segment is lost
+ // etc.
+ s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss)
+
+ return s
+}
+
+func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl {
+ switch congestionControlName {
+ case ccCubic:
+ return newCubicCC(s)
+ case ccReno:
+ fallthrough
+ default:
+ return newRenoCC(s)
+ }
+}
+
+// updateMaxPayloadSize updates the maximum payload size based on the given
+// MTU. If this is in response to "packet too big" control packets (indicated
+// by the count argument), it also reduces the number of outstanding packets and
+// attempts to retransmit the first packet above the MTU size.
+func (s *sender) updateMaxPayloadSize(mtu, count int) {
+ m := mtu - header.TCPMinimumSize
+
+ m -= s.ep.maxOptionSize()
+
+ // We don't adjust up for now.
+ if m >= s.maxPayloadSize {
+ return
+ }
+
+ // Make sure we can transmit at least one byte.
+ if m <= 0 {
+ m = 1
+ }
+
+ s.maxPayloadSize = m
+ if s.gso {
+ s.ep.gso.MSS = uint16(m)
+ }
+
+ if count == 0 {
+		// updateMaxPayloadSize is also called when the sender is created,
+		// and there is no data to send in such cases. Return immediately.
+ return
+ }
+
+ // Update the scoreboard's smss to reflect the new lowered
+ // maxPayloadSize.
+ s.ep.scoreboard.smss = uint16(m)
+
+ s.outstanding -= count
+ if s.outstanding < 0 {
+ s.outstanding = 0
+ }
+
+ // Rewind writeNext to the first segment exceeding the MTU. Do nothing
+ // if it is already before such a packet.
+ for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+ if seg == s.writeNext {
+ // We got to writeNext before we could find a segment
+ // exceeding the MTU.
+ break
+ }
+
+ if seg.data.Size() > m {
+ // We found a segment exceeding the MTU. Rewind
+ // writeNext and try to retransmit it.
+ s.writeNext = seg
+ break
+ }
+ }
+
+ // Since we likely reduced the number of outstanding packets, we may be
+ // ready to send some more.
+ s.sendData()
+}
+
+// sendAck sends an ACK segment.
+func (s *sender) sendAck() {
+ s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt)
+}
+
+// updateRTO updates the retransmit timeout when a new round-trip time is
+// available. This is done in accordance with section 2 of RFC 6298.
+func (s *sender) updateRTO(rtt time.Duration) {
+ s.rtt.Lock()
+ if !s.srttInited {
+ s.rtt.rttvar = rtt / 2
+ s.rtt.srtt = rtt
+ s.srttInited = true
+ } else {
+ diff := s.rtt.srtt - rtt
+ if diff < 0 {
+ diff = -diff
+ }
+ // Use RFC6298 standard algorithm to update rttvar and srtt when
+ // no timestamps are available.
+ if !s.ep.sendTSOk {
+ s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4
+ s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8
+ } else {
+ // When we are taking RTT measurements of every ACK then
+ // we need to use a modified method as specified in
+ // https://tools.ietf.org/html/rfc7323#appendix-G
+ if s.outstanding == 0 {
+ s.rtt.Unlock()
+ return
+ }
+ // Netstack measures congestion window/inflight all in
+ // terms of packets and not bytes. This is similar to
+			// how Linux also does cwnd and inflight. In practice
+ // this approximation works as expected.
+ expectedSamples := math.Ceil(float64(s.outstanding) / 2)
+
+ // alpha & beta values are the original values as recommended in
+ // https://tools.ietf.org/html/rfc6298#section-2.3.
+ const alpha = 0.125
+ const beta = 0.25
+
+ alphaPrime := alpha / expectedSamples
+ betaPrime := beta / expectedSamples
+ rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds()
+ srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds()
+ s.rtt.rttvar = time.Duration(rttVar * float64(time.Second))
+ s.rtt.srtt = time.Duration(srtt * float64(time.Second))
+ }
+ }
+
+ s.rto = s.rtt.srtt + 4*s.rtt.rttvar
+ s.rtt.Unlock()
+ if s.rto < minRTO {
+ s.rto = minRTO
+ }
+}
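+
+// As a worked example with illustrative values: a first measurement of
+// rtt = 100ms gives srtt = 100ms, rttvar = 50ms and rto = 100ms + 4*50ms =
+// 300ms. A second sample of rtt = 200ms (without timestamps) gives
+// rttvar = (3*50ms + 100ms)/4 = 62.5ms and srtt = (7*100ms + 200ms)/8 =
+// 112.5ms, so rto = 362.5ms.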
+
+// resendSegment resends the first unacknowledged segment.
+func (s *sender) resendSegment() {
+ // Don't use any segments we already sent to measure RTT as they may
+ // have been affected by packets being lost.
+ s.rttMeasureSeqNum = s.sndNxt
+
+ // Resend the segment.
+ if seg := s.writeList.Front(); seg != nil {
+ if seg.data.Size() > s.maxPayloadSize {
+ s.splitSeg(seg, s.maxPayloadSize)
+ }
+
+ // See: RFC 6675 section 5 Step 4.3
+ //
+		// To prevent this segment from being retransmitted again, set
+		// both HighRxt and RescueRxt to the highest sequence number in
+		// the retransmitted segment.
+ s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1
+ s.sendSegment(seg)
+ s.ep.stack.Stats().TCP.FastRetransmit.Increment()
+
+ // Run SetPipe() as per RFC 6675 section 5 Step 4.4
+ s.SetPipe()
+ }
+}
+
+// retransmitTimerExpired is called when the retransmit timer expires;
+// unacknowledged segments are assumed lost and thus need to be resent.
+// It returns true if the connection is still usable, or false if the
+// connection is deemed lost.
+func (s *sender) retransmitTimerExpired() bool {
+ // Check if the timer actually expired or if it's a spurious wake due
+ // to a previously orphaned runtime timer.
+ if !s.resendTimer.checkExpiration() {
+ return true
+ }
+
+ s.ep.stack.Stats().TCP.Timeouts.Increment()
+
+ // Give up if we've waited more than a minute since the last resend.
+ if s.rto >= 60*time.Second {
+ return false
+ }
+
+ // Set new timeout. The timer will be restarted by the call to sendData
+ // below.
+ s.rto *= 2
+
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
+ //
+ // Retransmit timeouts:
+ // After a retransmit timeout, record the highest sequence number
+ // transmitted in the variable recover, and exit the fast recovery
+ // procedure if applicable.
+ s.fr.last = s.sndNxt - 1
+
+ if s.fr.active {
+ // We were attempting fast recovery but were not successful.
+ // Leave the state. We don't need to update ssthresh because it
+ // has already been updated when entered fast-recovery.
+ s.leaveFastRecovery()
+ }
+
+ s.cc.HandleRTOExpired()
+
+ // Mark the next segment to be sent as the first unacknowledged one and
+ // start sending again. Set the number of outstanding packets to 0 so
+ // that we'll be able to retransmit.
+ //
+ // We'll keep on transmitting (or retransmitting) as we get acks for
+ // the data we transmit.
+ s.outstanding = 0
+
+ // Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1
+ //
+ // In order to avoid memory deadlocks, the TCP receiver is allowed to
+ // discard data that has already been selectively acknowledged. As a
+ // result, [RFC2018] suggests that a TCP sender SHOULD expunge the SACK
+ // information gathered from a receiver upon a retransmission timeout
+ // (RTO) "since the timeout might indicate that the data receiver has
+ // reneged." Additionally, a TCP sender MUST "ignore prior SACK
+ // information in determining which data to retransmit."
+ //
+ // NOTE: We take the stricter interpretation and just expunge all
+ // information as we lack more rigorous checks to validate if the SACK
+ // information is usable after an RTO.
+ s.ep.scoreboard.Reset()
+ s.writeNext = s.writeList.Front()
+ s.sendData()
+
+ return true
+}
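+
+// For example, with the initial rto of 1 second, successive expirations
+// double the timeout to 2s, 4s, 8s, and so on; once rto has grown to 60s
+// or more, the next expiration deems the connection lost.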
+
+// pCount returns the number of packets in the segment. Due to GSO, a segment
+// can be composed of multiple packets.
+func (s *sender) pCount(seg *segment) int {
+ size := seg.data.Size()
+ if size == 0 {
+ return 1
+ }
+
+ return (size-1)/s.maxPayloadSize + 1
+}
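+
+// For example, with maxPayloadSize = 1460, a 4000-byte segment counts as
+// (4000-1)/1460 + 1 = 3 packets.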
+
+// splitSeg splits a given segment at the size specified and inserts the
+// remainder as a new segment after the current one in the write list.
+func (s *sender) splitSeg(seg *segment, size int) {
+ if seg.data.Size() <= size {
+ return
+ }
+ // Split this segment up.
+ nSeg := seg.clone()
+ nSeg.data.TrimFront(size)
+ nSeg.sequenceNumber.UpdateForward(seqnum.Size(size))
+ s.writeList.InsertAfter(seg, nSeg)
+ seg.data.CapLength(size)
+}
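+
+// For example, splitting a 3000-byte segment at size 1460 caps the original
+// segment at its first 1460 bytes and inserts a new 1540-byte segment,
+// with its sequence number advanced by 1460, right after it.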
+
+// NextSeg implements the RFC6675 NextSeg() operation. It returns segments that
+// match rule 1, 3 and 4 of the NextSeg() operation defined in RFC6675. Rule 2
+// is handled by the normal send logic.
+func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) {
+ var s3 *segment
+ var s4 *segment
+ smss := s.ep.scoreboard.SMSS()
+ // Step 1.
+ for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+ if !s.isAssignedSequenceNumber(seg) {
+ break
+ }
+ segSeq := seg.sequenceNumber
+ if seg.data.Size() > int(smss) {
+ s.splitSeg(seg, int(smss))
+ }
+ // See RFC 6675 Section 4
+ //
+ // 1. If there exists a smallest unSACKED sequence number
+		// 'S2' that meets the following 3 criteria for determining
+		// loss, the sequence range of one segment of up to SMSS
+		// octets starting with S2 MUST be returned.
+ if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) {
+ // NextSeg():
+ //
+ // (1.a) S2 is greater than HighRxt
+			// (1.b) S2 is less than the highest octet covered by
+ // any received SACK.
+ if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) {
+ // NextSeg():
+ // (1.c) IsLost(S2) returns true.
+ if s.ep.scoreboard.IsLost(segSeq) {
+ return seg, s3, s4
+ }
+ // NextSeg():
+ //
+ // (3): If the conditions for rules (1) and (2)
+ // fail, but there exists an unSACKed sequence
+ // number S3 that meets the criteria for
+ // detecting loss given in steps 1.a and 1.b
+ // above (specifically excluding (1.c)) then one
+			// segment of up to SMSS octets starting with S3
+ // SHOULD be returned.
+ if s3 == nil {
+ s3 = seg
+ }
+ }
+ // NextSeg():
+ //
+ // (4) If the conditions for (1), (2) and (3) fail,
+ // but there exists outstanding unSACKED data, we
+ // provide the opportunity for a single "rescue"
+ // retransmission per entry into loss recovery. If
+ // HighACK is greater than RescueRxt, the one
+			// segment of up to SMSS octets that MUST include
+ // the highest outstanding unSACKed sequence number
+ // SHOULD be returned.
+ if s.fr.rescueRxt.LessThan(s.sndUna - 1) {
+ if s4 != nil {
+ if s4.sequenceNumber.LessThan(segSeq) {
+ s4 = seg
+ }
+ } else {
+ s4 = seg
+ }
+ s.fr.rescueRxt = s.fr.last
+ }
+ }
+ }
+
+ return nil, s3, s4
+}
+
+// maybeSendSegment tries to send the specified segment and either coalesces
+// other segments into this one or splits the specified segment based on the
+// lower of the specified limit value or the receiver's window size specified by
+// end.
+func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (sent bool) {
+ // We abuse the flags field to determine if we have already
+ // assigned a sequence number to this segment.
+ if !s.isAssignedSequenceNumber(seg) {
+ // Merge segments if allowed.
+ if seg.data.Size() != 0 {
+ available := int(seg.sequenceNumber.Size(end))
+ if available > limit {
+ available = limit
+ }
+
+ // nextTooBig indicates that the next segment was too
+ // large to entirely fit in the current segment. It
+ // would be possible to split the next segment and merge
+ // the portion that fits, but unexpectedly splitting
+ // segments can have user visible side-effects which can
+ // break applications. For example, RFC 7766 section 8
+ // says that the length and data of a DNS response
+ // should be sent in the same TCP segment to avoid
+ // triggering bugs in poorly written DNS
+ // implementations.
+ var nextTooBig bool
+ for seg.Next() != nil && seg.Next().data.Size() != 0 {
+ if seg.data.Size()+seg.Next().data.Size() > available {
+ nextTooBig = true
+ break
+ }
+ seg.data.Append(seg.Next().data)
+
+ // Consume the segment that we just merged in.
+ s.writeList.Remove(seg.Next())
+ }
+ if !nextTooBig && seg.data.Size() < available {
+ // Segment is not full.
+ if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 {
+ // Nagle's algorithm. From Wikipedia:
+ // Nagle's algorithm works by
+ // combining a number of small
+ // outgoing messages and sending them
+ // all at once. Specifically, as long
+ // as there is a sent packet for which
+ // the sender has received no
+ // acknowledgment, the sender should
+ // keep buffering its output until it
+ // has a full packet's worth of
+ // output, thus allowing output to be
+ // sent all at once.
+ return false
+ }
+ if atomic.LoadUint32(&s.ep.cork) != 0 {
+ // Hold back the segment until full.
+ return false
+ }
+ }
+ }
+
+ // Assign flags. We don't do it above so that we can merge
+ // additional data if Nagle holds the segment.
+ seg.sequenceNumber = s.sndNxt
+ seg.flags = header.TCPFlagAck | header.TCPFlagPsh
+ }
+
+ var segEnd seqnum.Value
+ if seg.data.Size() == 0 {
+ if s.writeList.Back() != seg {
+ panic("FIN segments must be the final segment in the write list.")
+ }
+ seg.flags = header.TCPFlagAck | header.TCPFlagFin
+ segEnd = seg.sequenceNumber.Add(1)
+ } else {
+ // We're sending a non-FIN segment.
+ if seg.flags&header.TCPFlagFin != 0 {
+ panic("Netstack queues FIN segments without data.")
+ }
+
+ if !seg.sequenceNumber.LessThan(end) {
+ return false
+ }
+
+ available := int(seg.sequenceNumber.Size(end))
+ if available == 0 {
+ return false
+ }
+ if available > limit {
+ available = limit
+ }
+
+ if seg.data.Size() > available {
+ s.splitSeg(seg, available)
+ }
+
+ segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+ }
+
+ s.sendSegment(seg)
+
+ // Update sndNxt if we actually sent new data (as opposed to
+ // retransmitting some previously sent data).
+ if s.sndNxt.LessThan(segEnd) {
+ s.sndNxt = segEnd
+ }
+
+ return true
+}
+
+// handleSACKRecovery implements the loss recovery phase as described in RFC6675
+// section 5, step C.
+func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) {
+ s.SetPipe()
+ for s.outstanding < s.sndCwnd {
+ nextSeg, s3, s4 := s.NextSeg()
+ if nextSeg == nil {
+ // NextSeg():
+ //
+ // Step (2): "If no sequence number 'S2' per rule (1)
+ // exists but there exists available unsent data and the
+ // receiver's advertised window allows, the sequence
+ // range of one segment of up to SMSS octets of
+ // previously unsent data starting with sequence number
+ // HighData+1 MUST be returned."
+ for seg := s.writeNext; seg != nil; seg = seg.Next() {
+ if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) {
+ continue
+ }
+ // Step C.3 described below is handled by
+ // maybeSendSegment which increments sndNxt when
+ // a segment is transmitted.
+ //
+ // Step C.3 "If any of the data octets sent in
+ // (C.1) are above HighData, HighData must be
+ // updated to reflect the transmission of
+ // previously unsent data."
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding++
+ s.writeNext = seg.Next()
+ nextSeg = seg
+ break
+ }
+ if nextSeg != nil {
+ continue
+ }
+ }
+ rescueRtx := false
+ if nextSeg == nil && s3 != nil {
+ nextSeg = s3
+ }
+ if nextSeg == nil && s4 != nil {
+ nextSeg = s4
+ rescueRtx = true
+ }
+ if nextSeg == nil {
+ break
+ }
+ segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen())
+ if !rescueRtx && nextSeg.sequenceNumber.LessThan(s.sndNxt) {
+ // RFC 6675, Step C.2
+ //
+ // "If any of the data octets sent in (C.1) are below
+ // HighData, HighRxt MUST be set to the highest sequence
+ // number of the retransmitted segment unless NextSeg ()
+ // rule (4) was invoked for this retransmission."
+ s.fr.highRxt = segEnd - 1
+ }
+
+ // RFC 6675, Step C.4.
+ //
+ // "The estimate of the amount of data outstanding in the network
+ // must be updated by incrementing pipe by the number of octets
+ // transmitted in (C.1)."
+ s.outstanding++
+ dataSent = true
+ s.sendSegment(nextSeg)
+ }
+ return dataSent
+}
+
+// sendData sends new data segments. It is called when data becomes available or
+// when the send window opens up.
+func (s *sender) sendData() {
+ limit := s.maxPayloadSize
+ if s.gso {
+ limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize)
+ }
+ end := s.sndUna.Add(s.sndWnd)
+
+ // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
+ // "A TCP SHOULD set cwnd to no more than RW before beginning
+	// transmission if the TCP has not sent data in an interval exceeding
+	// the retransmission timeout."
+ if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto {
+ if s.sndCwnd > InitialCwnd {
+ s.sndCwnd = InitialCwnd
+ }
+ }
+
+ var dataSent bool
+
+ // RFC 6675 recovery algorithm step C 1-5.
+ if s.fr.active && s.ep.sackPermitted {
+ dataSent = s.handleSACKRecovery(s.maxPayloadSize, end)
+ } else {
+ for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
+ cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize
+ if cwndLimit < limit {
+ limit = cwndLimit
+ }
+ if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ continue
+ }
+ if sent := s.maybeSendSegment(seg, limit, end); !sent {
+ break
+ }
+ dataSent = true
+ s.outstanding++
+ s.writeNext = seg.Next()
+ }
+ }
+
+ if dataSent {
+ // We sent data, so we should stop the keepalive timer to ensure
+ // that no keepalives are sent while there is pending data.
+ s.ep.disableKeepaliveTimer()
+ }
+
+ // Enable the timer if we have pending data and it's not enabled yet.
+ if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
+ s.resendTimer.enable(s.rto)
+ }
+ // If we have no more pending data, start the keepalive timer.
+ if s.sndUna == s.sndNxt {
+ s.ep.resetKeepaliveTimer(false)
+ }
+}
+
+func (s *sender) enterFastRecovery() {
+ s.fr.active = true
+ // Save state to reflect we're now in fast recovery.
+ //
+	// See: https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
+ // We inflate the cwnd by 3 to account for the 3 packets which triggered
+ // the 3 duplicate ACKs and are now not in flight.
+ s.sndCwnd = s.sndSsthresh + 3
+ s.fr.first = s.sndUna
+ s.fr.last = s.sndNxt - 1
+ s.fr.maxCwnd = s.sndCwnd + s.outstanding
+ if s.ep.sackPermitted {
+ s.ep.stack.Stats().TCP.SACKRecovery.Increment()
+ return
+ }
+ s.ep.stack.Stats().TCP.FastRecovery.Increment()
+}
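+
+// As an illustrative example: if the congestion control algorithm set
+// ssthresh to 10 on the third duplicate ACK, enterFastRecovery sets cwnd
+// to 10 + 3 = 13, and leaveFastRecovery later deflates it back to
+// ssthresh (10).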
+
+func (s *sender) leaveFastRecovery() {
+ s.fr.active = false
+ s.fr.maxCwnd = 0
+ s.dupAckCount = 0
+
+ // Deflate cwnd. It had been artificially inflated when new dups arrived.
+ s.sndCwnd = s.sndSsthresh
+
+ s.cc.PostRecovery()
+}
+
+func (s *sender) handleFastRecovery(seg *segment) (rtx bool) {
+ ack := seg.ackNumber
+ // We are in fast recovery mode. Ignore the ack if it's out of
+ // range.
+ if !ack.InRange(s.sndUna, s.sndNxt+1) {
+ return false
+ }
+
+ // Leave fast recovery if it acknowledges all the data covered by
+ // this fast recovery session.
+ if s.fr.last.LessThan(ack) {
+ s.leaveFastRecovery()
+ return false
+ }
+
+ if s.ep.sackPermitted {
+ // When SACK is enabled we let retransmission be governed by
+ // the SACK logic.
+ return false
+ }
+
+ // Don't count this as a duplicate if it is carrying data or
+ // updating the window.
+ if seg.logicalLen() != 0 || s.sndWnd != seg.window {
+ return false
+ }
+
+ // Inflate the congestion window if we're getting duplicate acks
+ // for the packet we retransmitted.
+ if ack == s.fr.first {
+ // We received a dup, inflate the congestion window by 1 packet
+ // if we're not at the max yet. Only inflate the window if
+		// regular FastRecovery is in use; RFC6675 does not require
+ // inflating cwnd on duplicate ACKs.
+ if s.sndCwnd < s.fr.maxCwnd {
+ s.sndCwnd++
+ }
+ return false
+ }
+
+ // A partial ack was received. Retransmit this packet and
+ // remember it so that we don't retransmit it again. We don't
+ // inflate the window because we're putting the same packet back
+ // onto the wire.
+ //
+ // N.B. The retransmit timer will be reset by the caller.
+ s.fr.first = ack
+ s.dupAckCount = 0
+ return true
+}
+
+// isAssignedSequenceNumber relies on the fact that we only set flags once a
+// sequence number is assigned, and that is only done right before we send the
+// segment. As a result, any segment that has a non-zero flag has a valid
+// sequence number assigned to it.
+func (s *sender) isAssignedSequenceNumber(seg *segment) bool {
+ return seg.flags != 0
+}
+
+// SetPipe implements the SetPipe() function described in RFC6675. Netstack
+// maintains the congestion window in number of packets and not bytes, so
+// SetPipe() here measures number of outstanding packets rather than actual
+// outstanding bytes in the network.
+func (s *sender) SetPipe() {
+ // If SACK isn't permitted or it is permitted but recovery is not active
+ // then ignore pipe calculations.
+ if !s.ep.sackPermitted || !s.fr.active {
+ return
+ }
+ pipe := 0
+ smss := seqnum.Size(s.ep.scoreboard.SMSS())
+ for s1 := s.writeList.Front(); s1 != nil && s1.data.Size() != 0 && s.isAssignedSequenceNumber(s1); s1 = s1.Next() {
+ // With GSO each segment can be much larger than SMSS. So check the segment
+ // in SMSS sized ranges.
+ segEnd := s1.sequenceNumber.Add(seqnum.Size(s1.data.Size()))
+ for startSeq := s1.sequenceNumber; startSeq.LessThan(segEnd); startSeq = startSeq.Add(smss) {
+ endSeq := startSeq.Add(smss)
+ if segEnd.LessThan(endSeq) {
+ endSeq = segEnd
+ }
+ sb := header.SACKBlock{startSeq, endSeq}
+ // SetPipe():
+ //
+ // After initializing pipe to zero, the following steps are
+ // taken for each octet 'S1' in the sequence space between
+ // HighACK and HighData that has not been SACKed:
+ if !s1.sequenceNumber.LessThan(s.sndNxt) {
+ break
+ }
+ if s.ep.scoreboard.IsSACKED(sb) {
+ continue
+ }
+
+ // SetPipe():
+ //
+			// (a) If IsLost(S1) returns false, Pipe is incremented by 1.
+ //
+			// NOTE: here we mark the whole segment as lost. We do not
+			// try to test every byte in our write buffer as we maintain
+			// our pipe in terms of outstanding packets and not bytes.
+ if !s.ep.scoreboard.IsRangeLost(sb) {
+ pipe++
+ }
+ // SetPipe():
+ // (b) If S1 <= HighRxt, Pipe is incremented by 1.
+ if s1.sequenceNumber.LessThanEq(s.fr.highRxt) {
+ pipe++
+ }
+ }
+ }
+ s.outstanding = pipe
+}
+
+// checkDuplicateAck is called when an ack is received. It manages the state
+// related to duplicate acks and determines if a retransmit is needed according
+// to the rules in RFC 6582 (NewReno).
+func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) {
+ ack := seg.ackNumber
+ if s.fr.active {
+ return s.handleFastRecovery(seg)
+ }
+
+ // We're not in fast recovery yet. A segment is considered a duplicate
+ // only if it doesn't carry any data and doesn't update the send window,
+ // because if it does, it wasn't sent in response to an out-of-order
+ // segment. If SACK is enabled then we have an additional check to see
+ // if the segment carries new SACK information. If it does then it is
+ // considered a duplicate ACK as per RFC6675.
+ if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt {
+ if !s.ep.sackPermitted || !seg.hasNewSACKInfo {
+ s.dupAckCount = 0
+ return false
+ }
+ }
+
+ s.dupAckCount++
+
+ // Do not enter fast recovery until we reach nDupAckThreshold or the
+ // first unacknowledged byte is considered lost as per SACK scoreboard.
+ if s.dupAckCount < nDupAckThreshold || (s.ep.sackPermitted && !s.ep.scoreboard.IsLost(s.sndUna)) {
+ // RFC 6675 Step 3.
+ s.fr.highRxt = s.sndUna - 1
+		// Still run SetPipe() to calculate the outstanding segments.
+ s.SetPipe()
+ return false
+ }
+
+ // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
+ //
+ // We only do the check here, the incrementing of last to the highest
+ // sequence number transmitted till now is done when enterFastRecovery
+ // is invoked.
+ if !s.fr.last.LessThan(seg.ackNumber) {
+ s.dupAckCount = 0
+ return false
+ }
+ s.cc.HandleNDupAcks()
+ s.enterFastRecovery()
+ s.dupAckCount = 0
+ return true
+}
+
+// handleRcvdSegment is called when a segment is received; it is responsible for
+// updating the send-related state.
+func (s *sender) handleRcvdSegment(seg *segment) {
+ // Check if we can extract an RTT measurement from this ack.
+ if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+ s.updateRTO(time.Now().Sub(s.rttMeasureTime))
+ s.rttMeasureSeqNum = s.sndNxt
+ }
+
+ // Update Timestamp if required. See RFC7323, section-4.3.
+ if s.ep.sendTSOk && seg.parsedOptions.TS {
+ s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+ }
+
+ // Insert SACKBlock information into our scoreboard.
+ if s.ep.sackPermitted {
+ for _, sb := range seg.parsedOptions.SACKBlocks {
+ // Only insert the SACK block if the following holds
+ // true:
+ // * SACK block acks data after the ack number in the
+ // current segment.
+ // * SACK block represents a sequence
+ // between sndUna and sndNxt (i.e. data that is
+ // currently unacked and in-flight).
+ // * SACK block that has not been SACKed already.
+ //
+ // NOTE: This check specifically excludes DSACK blocks
+ // which have start/end before sndUna and are used to
+ // indicate spurious retransmissions.
+ if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) {
+ s.ep.scoreboard.Insert(sb)
+ seg.hasNewSACKInfo = true
+ }
+ }
+ s.SetPipe()
+ }
+
+ // Count the duplicates and do the fast retransmit if needed.
+ rtx := s.checkDuplicateAck(seg)
+
+ // Stash away the current window size.
+ s.sndWnd = seg.window
+
+ // Ignore ack if it doesn't acknowledge any new data.
+ ack := seg.ackNumber
+ if (ack - 1).InRange(s.sndUna, s.sndNxt) {
+ s.dupAckCount = 0
+
+		// See: https://tools.ietf.org/html/rfc1323#section-3.3.
+ // Specifically we should only update the RTO using TSEcr if the
+ // following condition holds:
+ //
+ // A TSecr value received in a segment is used to update the
+ // averaged RTT measurement only if the segment acknowledges
+ // some new data, i.e., only if it advances the left edge of
+ // the send window.
+ if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 {
+ // TSVal/Ecr values sent by Netstack are at a millisecond
+ // granularity.
+ elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond
+ s.updateRTO(elapsed)
+ }
+
+ // When an ack is received we must rearm the timer.
+ // RFC 6298 5.2
+ s.resendTimer.enable(s.rto)
+
+ // Remove all acknowledged data from the write list.
+ acked := s.sndUna.Size(ack)
+ s.sndUna = ack
+
+ ackLeft := acked
+ originalOutstanding := s.outstanding
+ for ackLeft > 0 {
+ // We use logicalLen here because we can have FIN
+ // segments (which are always at the end of list) that
+ // have no data, but do consume a sequence number.
+ seg := s.writeList.Front()
+ datalen := seg.logicalLen()
+
+ if datalen > ackLeft {
+ prevCount := s.pCount(seg)
+ seg.data.TrimFront(int(ackLeft))
+ seg.sequenceNumber.UpdateForward(ackLeft)
+ s.outstanding -= prevCount - s.pCount(seg)
+ break
+ }
+
+ if s.writeNext == seg {
+ s.writeNext = seg.Next()
+ }
+ s.writeList.Remove(seg)
+
+			// If SACK is enabled, then only reduce outstanding if
+			// the segment was not previously SACKed, as these have
+ // already been accounted for in SetPipe().
+ if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) {
+ s.outstanding -= s.pCount(seg)
+ }
+ seg.decRef()
+ ackLeft -= datalen
+ }
+
+ // Update the send buffer usage and notify potential waiters.
+ s.ep.updateSndBufferUsage(int(acked))
+
+ // Clear SACK information for all acked data.
+ s.ep.scoreboard.Delete(s.sndUna)
+
+ // If we are not in fast recovery then update the congestion
+ // window based on the number of acknowledged packets.
+ if !s.fr.active {
+ s.cc.Update(originalOutstanding - s.outstanding)
+ }
+
+ // It is possible for s.outstanding to drop below zero if we get
+ // a retransmit timeout, reset outstanding to zero but later
+		// get an ack that covers previously sent data.
+ if s.outstanding < 0 {
+ s.outstanding = 0
+ }
+
+ s.SetPipe()
+
+		// If all outstanding data was acknowledged, then disable the timer.
+ // RFC 6298 Rule 5.3
+ if s.sndUna == s.sndNxt {
+ s.outstanding = 0
+ s.resendTimer.disable()
+ }
+ }
+ // Now that we've popped all acknowledged data from the retransmit
+ // queue, retransmit if needed.
+ if rtx {
+ s.resendSegment()
+ }
+
+ // Send more data now that some of the pending data has been ack'd, or
+ // that the window opened up, or the congestion window was inflated due
+ // to a duplicate ack during fast recovery. This will also re-enable
+ // the retransmit timer if needed.
+ if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo {
+ s.sendData()
+ }
+}
+
+// sendSegment sends the specified segment.
+func (s *sender) sendSegment(seg *segment) *tcpip.Error {
+ if !seg.xmitTime.IsZero() {
+ s.ep.stack.Stats().TCP.Retransmits.Increment()
+ if s.sndCwnd < s.sndSsthresh {
+ s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment()
+ }
+ }
+ seg.xmitTime = time.Now()
+ return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber)
+}
+
+// sendSegmentFromView sends a new segment containing the given payload, flags
+// and sequence number.
+func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
+ s.lastSendTime = time.Now()
+ if seq == s.rttMeasureSeqNum {
+ s.rttMeasureTime = s.lastSendTime
+ }
+
+ rcvNxt, rcvWnd := s.ep.rcv.getSendParams()
+
+ // Remember the max sent ack.
+ s.maxSentAck = rcvNxt
+
+ // Every time a packet containing data is sent (including a
+ // retransmission), if SACK is enabled then use the conservative timer
+	// described in RFC6675 Section 4.0, otherwise follow the standard timer
+ // described in RFC6298 Section 5.2.
+ if data.Size() != 0 {
+ if s.ep.sackPermitted {
+ s.resendTimer.enable(s.rto)
+ } else {
+ if !s.resendTimer.enabled() {
+ s.resendTimer.enable(s.rto)
+ }
+ }
+ }
+
+ return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd)
+}
diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go
new file mode 100644
index 000000000..12eff8afc
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd_state.go
@@ -0,0 +1,50 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+)
+
+// +stateify savable
+type unixTime struct {
+ second int64
+ nano int64
+}
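+
+// As a round-trip example with illustrative values: a time whose
+// UnixNano() is 1000000500 is saved as unixTime{1, 1000000500}; the nano
+// field alone carries the full offset from the epoch, so the time is
+// restored with time.Unix(0, 1000000500).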
+
+// saveLastSendTime is invoked by stateify.
+func (s *sender) saveLastSendTime() unixTime {
+ return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()}
+}
+
+// loadLastSendTime is invoked by stateify.
+func (s *sender) loadLastSendTime(unix unixTime) {
+	// unix.nano is an absolute UnixNano value and already includes the
+	// seconds.
+	s.lastSendTime = time.Unix(0, unix.nano)
+}
+
+// saveRttMeasureTime is invoked by stateify.
+func (s *sender) saveRttMeasureTime() unixTime {
+ return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()}
+}
+
+// loadRttMeasureTime is invoked by stateify.
+	s.rttMeasureTime = time.Unix(0, unix.nano)
+ s.rttMeasureTime = time.Unix(unix.second, unix.nano)
+}
+
+// afterLoad is invoked by stateify.
+func (s *sender) afterLoad() {
+ s.resendTimer.init(&s.resendWaker)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go
new file mode 100755
index 000000000..029f98a11
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go
@@ -0,0 +1,173 @@
+package tcp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type segmentElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type segmentList struct {
+ head *segment
+ tail *segment
+}
+
+// Reset resets list l to the empty state.
+func (l *segmentList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *segmentList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *segmentList) Front() *segment {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *segmentList) Back() *segment {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *segmentList) PushFront(e *segment) {
+ segmentElementMapper{}.linkerFor(e).SetNext(l.head)
+ segmentElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ segmentElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *segmentList) PushBack(e *segment) {
+ segmentElementMapper{}.linkerFor(e).SetNext(nil)
+ segmentElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *segmentList) PushBackList(m *segmentList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *segmentList) InsertAfter(b, e *segment) {
+ a := segmentElementMapper{}.linkerFor(b).Next()
+ segmentElementMapper{}.linkerFor(e).SetNext(a)
+ segmentElementMapper{}.linkerFor(e).SetPrev(b)
+ segmentElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ segmentElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *segmentList) InsertBefore(a, e *segment) {
+ b := segmentElementMapper{}.linkerFor(a).Prev()
+ segmentElementMapper{}.linkerFor(e).SetNext(a)
+ segmentElementMapper{}.linkerFor(e).SetPrev(b)
+ segmentElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ segmentElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *segmentList) Remove(e *segment) {
+ prev := segmentElementMapper{}.linkerFor(e).Prev()
+ next := segmentElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ segmentElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ segmentElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type segmentEntry struct {
+ next *segment
+ prev *segment
+}
+
+// Next returns the entry that follows e in the list.
+func (e *segmentEntry) Next() *segment {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *segmentEntry) Prev() *segment {
+ return e.prev
+}
+
+// SetNext assigns 'elem' as the entry that follows e in the list.
+func (e *segmentEntry) SetNext(elem *segment) {
+ e.next = elem
+}
+
+// SetPrev assigns 'elem' as the entry that precedes e in the list.
+func (e *segmentEntry) SetPrev(elem *segment) {
+ e.prev = elem
+}
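+
+// A usage sketch (illustrative only; it assumes segA and segB are
+// *segment values, whose type embeds segmentEntry):
+//
+//	var l segmentList
+//	l.PushBack(segA)
+//	l.PushBack(segB)
+//	for e := l.Front(); e != nil; e = e.Next() {
+//		// visit e in insertion order
+//	}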
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
new file mode 100755
index 000000000..9049a99b2
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -0,0 +1,400 @@
+// automatically generated by stateify.
+
+package tcp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *SACKInfo) beforeSave() {}
+func (x *SACKInfo) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Blocks", &x.Blocks)
+ m.Save("NumBlocks", &x.NumBlocks)
+}
+
+func (x *SACKInfo) afterLoad() {}
+func (x *SACKInfo) load(m state.Map) {
+ m.Load("Blocks", &x.Blocks)
+ m.Load("NumBlocks", &x.NumBlocks)
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var lastError string = x.saveLastError()
+ m.SaveValue("lastError", lastError)
+ var state endpointState = x.saveState()
+ m.SaveValue("state", state)
+ var hardError string = x.saveHardError()
+ m.SaveValue("hardError", hardError)
+ var acceptedChan []*endpoint = x.saveAcceptedChan()
+ m.SaveValue("acceptedChan", acceptedChan)
+ m.Save("netProto", &x.netProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvBufUsed", &x.rcvBufUsed)
+ m.Save("id", &x.id)
+ m.Save("isRegistered", &x.isRegistered)
+ m.Save("v6only", &x.v6only)
+ m.Save("isConnectNotified", &x.isConnectNotified)
+ m.Save("broadcast", &x.broadcast)
+ m.Save("workerRunning", &x.workerRunning)
+ m.Save("workerCleanup", &x.workerCleanup)
+ m.Save("sendTSOk", &x.sendTSOk)
+ m.Save("recentTS", &x.recentTS)
+ m.Save("tsOffset", &x.tsOffset)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("sackPermitted", &x.sackPermitted)
+ m.Save("sack", &x.sack)
+ m.Save("reusePort", &x.reusePort)
+ m.Save("delay", &x.delay)
+ m.Save("cork", &x.cork)
+ m.Save("scoreboard", &x.scoreboard)
+ m.Save("reuseAddr", &x.reuseAddr)
+ m.Save("slowAck", &x.slowAck)
+ m.Save("segmentQueue", &x.segmentQueue)
+ m.Save("synRcvdCount", &x.synRcvdCount)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("sndBufUsed", &x.sndBufUsed)
+ m.Save("sndClosed", &x.sndClosed)
+ m.Save("sndBufInQueue", &x.sndBufInQueue)
+ m.Save("sndQueue", &x.sndQueue)
+ m.Save("cc", &x.cc)
+ m.Save("packetTooBigCount", &x.packetTooBigCount)
+ m.Save("sndMTU", &x.sndMTU)
+ m.Save("keepalive", &x.keepalive)
+ m.Save("rcv", &x.rcv)
+ m.Save("snd", &x.snd)
+ m.Save("bindAddress", &x.bindAddress)
+ m.Save("connectingAddress", &x.connectingAddress)
+ m.Save("gso", &x.gso)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.LoadWait("waiterQueue", &x.waiterQueue)
+ m.LoadWait("rcvList", &x.rcvList)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvBufUsed", &x.rcvBufUsed)
+ m.Load("id", &x.id)
+ m.Load("isRegistered", &x.isRegistered)
+ m.Load("v6only", &x.v6only)
+ m.Load("isConnectNotified", &x.isConnectNotified)
+ m.Load("broadcast", &x.broadcast)
+ m.Load("workerRunning", &x.workerRunning)
+ m.Load("workerCleanup", &x.workerCleanup)
+ m.Load("sendTSOk", &x.sendTSOk)
+ m.Load("recentTS", &x.recentTS)
+ m.Load("tsOffset", &x.tsOffset)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("sackPermitted", &x.sackPermitted)
+ m.Load("sack", &x.sack)
+ m.Load("reusePort", &x.reusePort)
+ m.Load("delay", &x.delay)
+ m.Load("cork", &x.cork)
+ m.Load("scoreboard", &x.scoreboard)
+ m.Load("reuseAddr", &x.reuseAddr)
+ m.Load("slowAck", &x.slowAck)
+ m.LoadWait("segmentQueue", &x.segmentQueue)
+ m.Load("synRcvdCount", &x.synRcvdCount)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("sndBufUsed", &x.sndBufUsed)
+ m.Load("sndClosed", &x.sndClosed)
+ m.Load("sndBufInQueue", &x.sndBufInQueue)
+ m.LoadWait("sndQueue", &x.sndQueue)
+ m.Load("cc", &x.cc)
+ m.Load("packetTooBigCount", &x.packetTooBigCount)
+ m.Load("sndMTU", &x.sndMTU)
+ m.Load("keepalive", &x.keepalive)
+ m.LoadWait("rcv", &x.rcv)
+ m.LoadWait("snd", &x.snd)
+ m.Load("bindAddress", &x.bindAddress)
+ m.Load("connectingAddress", &x.connectingAddress)
+ m.Load("gso", &x.gso)
+ m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) })
+ m.LoadValue("state", new(endpointState), func(y interface{}) { x.loadState(y.(endpointState)) })
+ m.LoadValue("hardError", new(string), func(y interface{}) { x.loadHardError(y.(string)) })
+ m.LoadValue("acceptedChan", new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *keepalive) beforeSave() {}
+func (x *keepalive) save(m state.Map) {
+ x.beforeSave()
+ m.Save("enabled", &x.enabled)
+ m.Save("idle", &x.idle)
+ m.Save("interval", &x.interval)
+ m.Save("count", &x.count)
+ m.Save("unacked", &x.unacked)
+}
+
+func (x *keepalive) afterLoad() {}
+func (x *keepalive) load(m state.Map) {
+ m.Load("enabled", &x.enabled)
+ m.Load("idle", &x.idle)
+ m.Load("interval", &x.interval)
+ m.Load("count", &x.count)
+ m.Load("unacked", &x.unacked)
+}
+
+func (x *receiver) beforeSave() {}
+func (x *receiver) save(m state.Map) {
+ x.beforeSave()
+ m.Save("ep", &x.ep)
+ m.Save("rcvNxt", &x.rcvNxt)
+ m.Save("rcvAcc", &x.rcvAcc)
+ m.Save("rcvWndScale", &x.rcvWndScale)
+ m.Save("closed", &x.closed)
+ m.Save("pendingRcvdSegments", &x.pendingRcvdSegments)
+ m.Save("pendingBufUsed", &x.pendingBufUsed)
+ m.Save("pendingBufSize", &x.pendingBufSize)
+}
+
+func (x *receiver) afterLoad() {}
+func (x *receiver) load(m state.Map) {
+ m.Load("ep", &x.ep)
+ m.Load("rcvNxt", &x.rcvNxt)
+ m.Load("rcvAcc", &x.rcvAcc)
+ m.Load("rcvWndScale", &x.rcvWndScale)
+ m.Load("closed", &x.closed)
+ m.Load("pendingRcvdSegments", &x.pendingRcvdSegments)
+ m.Load("pendingBufUsed", &x.pendingBufUsed)
+ m.Load("pendingBufSize", &x.pendingBufSize)
+}
+
+func (x *renoState) beforeSave() {}
+func (x *renoState) save(m state.Map) {
+ x.beforeSave()
+ m.Save("s", &x.s)
+}
+
+func (x *renoState) afterLoad() {}
+func (x *renoState) load(m state.Map) {
+ m.Load("s", &x.s)
+}
+
+func (x *SACKScoreboard) beforeSave() {}
+func (x *SACKScoreboard) save(m state.Map) {
+ x.beforeSave()
+ m.Save("smss", &x.smss)
+ m.Save("maxSACKED", &x.maxSACKED)
+}
+
+func (x *SACKScoreboard) afterLoad() {}
+func (x *SACKScoreboard) load(m state.Map) {
+ m.Load("smss", &x.smss)
+ m.Load("maxSACKED", &x.maxSACKED)
+}
+
+func (x *segment) beforeSave() {}
+func (x *segment) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ var options []byte = x.saveOptions()
+ m.SaveValue("options", options)
+ var rcvdTime unixTime = x.saveRcvdTime()
+ m.SaveValue("rcvdTime", rcvdTime)
+ var xmitTime unixTime = x.saveXmitTime()
+ m.SaveValue("xmitTime", xmitTime)
+ m.Save("segmentEntry", &x.segmentEntry)
+ m.Save("refCnt", &x.refCnt)
+ m.Save("viewToDeliver", &x.viewToDeliver)
+ m.Save("sequenceNumber", &x.sequenceNumber)
+ m.Save("ackNumber", &x.ackNumber)
+ m.Save("flags", &x.flags)
+ m.Save("window", &x.window)
+ m.Save("csum", &x.csum)
+ m.Save("csumValid", &x.csumValid)
+ m.Save("parsedOptions", &x.parsedOptions)
+ m.Save("hasNewSACKInfo", &x.hasNewSACKInfo)
+}
+
+func (x *segment) afterLoad() {}
+func (x *segment) load(m state.Map) {
+ m.Load("segmentEntry", &x.segmentEntry)
+ m.Load("refCnt", &x.refCnt)
+ m.Load("viewToDeliver", &x.viewToDeliver)
+ m.Load("sequenceNumber", &x.sequenceNumber)
+ m.Load("ackNumber", &x.ackNumber)
+ m.Load("flags", &x.flags)
+ m.Load("window", &x.window)
+ m.Load("csum", &x.csum)
+ m.Load("csumValid", &x.csumValid)
+ m.Load("parsedOptions", &x.parsedOptions)
+ m.Load("hasNewSACKInfo", &x.hasNewSACKInfo)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+ m.LoadValue("options", new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) })
+ m.LoadValue("rcvdTime", new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) })
+ m.LoadValue("xmitTime", new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) })
+}
+
+func (x *segmentQueue) beforeSave() {}
+func (x *segmentQueue) save(m state.Map) {
+ x.beforeSave()
+ m.Save("list", &x.list)
+ m.Save("limit", &x.limit)
+ m.Save("used", &x.used)
+}
+
+func (x *segmentQueue) afterLoad() {}
+func (x *segmentQueue) load(m state.Map) {
+ m.LoadWait("list", &x.list)
+ m.Load("limit", &x.limit)
+ m.Load("used", &x.used)
+}
+
+func (x *sender) beforeSave() {}
+func (x *sender) save(m state.Map) {
+ x.beforeSave()
+ var lastSendTime unixTime = x.saveLastSendTime()
+ m.SaveValue("lastSendTime", lastSendTime)
+ var rttMeasureTime unixTime = x.saveRttMeasureTime()
+ m.SaveValue("rttMeasureTime", rttMeasureTime)
+ m.Save("ep", &x.ep)
+ m.Save("dupAckCount", &x.dupAckCount)
+ m.Save("fr", &x.fr)
+ m.Save("sndCwnd", &x.sndCwnd)
+ m.Save("sndSsthresh", &x.sndSsthresh)
+ m.Save("sndCAAckCount", &x.sndCAAckCount)
+ m.Save("outstanding", &x.outstanding)
+ m.Save("sndWnd", &x.sndWnd)
+ m.Save("sndUna", &x.sndUna)
+ m.Save("sndNxt", &x.sndNxt)
+ m.Save("sndNxtList", &x.sndNxtList)
+ m.Save("rttMeasureSeqNum", &x.rttMeasureSeqNum)
+ m.Save("closed", &x.closed)
+ m.Save("writeNext", &x.writeNext)
+ m.Save("writeList", &x.writeList)
+ m.Save("rtt", &x.rtt)
+ m.Save("rto", &x.rto)
+ m.Save("srttInited", &x.srttInited)
+ m.Save("maxPayloadSize", &x.maxPayloadSize)
+ m.Save("gso", &x.gso)
+ m.Save("sndWndScale", &x.sndWndScale)
+ m.Save("maxSentAck", &x.maxSentAck)
+ m.Save("cc", &x.cc)
+}
+
+func (x *sender) load(m state.Map) {
+ m.Load("ep", &x.ep)
+ m.Load("dupAckCount", &x.dupAckCount)
+ m.Load("fr", &x.fr)
+ m.Load("sndCwnd", &x.sndCwnd)
+ m.Load("sndSsthresh", &x.sndSsthresh)
+ m.Load("sndCAAckCount", &x.sndCAAckCount)
+ m.Load("outstanding", &x.outstanding)
+ m.Load("sndWnd", &x.sndWnd)
+ m.Load("sndUna", &x.sndUna)
+ m.Load("sndNxt", &x.sndNxt)
+ m.Load("sndNxtList", &x.sndNxtList)
+ m.Load("rttMeasureSeqNum", &x.rttMeasureSeqNum)
+ m.Load("closed", &x.closed)
+ m.Load("writeNext", &x.writeNext)
+ m.Load("writeList", &x.writeList)
+ m.Load("rtt", &x.rtt)
+ m.Load("rto", &x.rto)
+ m.Load("srttInited", &x.srttInited)
+ m.Load("maxPayloadSize", &x.maxPayloadSize)
+ m.Load("gso", &x.gso)
+ m.Load("sndWndScale", &x.sndWndScale)
+ m.Load("maxSentAck", &x.maxSentAck)
+ m.Load("cc", &x.cc)
+ m.LoadValue("lastSendTime", new(unixTime), func(y interface{}) { x.loadLastSendTime(y.(unixTime)) })
+ m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *rtt) beforeSave() {}
+func (x *rtt) save(m state.Map) {
+ x.beforeSave()
+ m.Save("srtt", &x.srtt)
+ m.Save("rttvar", &x.rttvar)
+}
+
+func (x *rtt) afterLoad() {}
+func (x *rtt) load(m state.Map) {
+ m.Load("srtt", &x.srtt)
+ m.Load("rttvar", &x.rttvar)
+}
+
+func (x *fastRecovery) beforeSave() {}
+func (x *fastRecovery) save(m state.Map) {
+ x.beforeSave()
+ m.Save("active", &x.active)
+ m.Save("first", &x.first)
+ m.Save("last", &x.last)
+ m.Save("maxCwnd", &x.maxCwnd)
+ m.Save("highRxt", &x.highRxt)
+ m.Save("rescueRxt", &x.rescueRxt)
+}
+
+func (x *fastRecovery) afterLoad() {}
+func (x *fastRecovery) load(m state.Map) {
+ m.Load("active", &x.active)
+ m.Load("first", &x.first)
+ m.Load("last", &x.last)
+ m.Load("maxCwnd", &x.maxCwnd)
+ m.Load("highRxt", &x.highRxt)
+ m.Load("rescueRxt", &x.rescueRxt)
+}
+
+func (x *unixTime) beforeSave() {}
+func (x *unixTime) save(m state.Map) {
+ x.beforeSave()
+ m.Save("second", &x.second)
+ m.Save("nano", &x.nano)
+}
+
+func (x *unixTime) afterLoad() {}
+func (x *unixTime) load(m state.Map) {
+ m.Load("second", &x.second)
+ m.Load("nano", &x.nano)
+}
+
+func (x *segmentList) beforeSave() {}
+func (x *segmentList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *segmentList) afterLoad() {}
+func (x *segmentList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *segmentEntry) beforeSave() {}
+func (x *segmentEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *segmentEntry) afterLoad() {}
+func (x *segmentEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("tcp.SACKInfo", (*SACKInfo)(nil), state.Fns{Save: (*SACKInfo).save, Load: (*SACKInfo).load})
+ state.Register("tcp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("tcp.keepalive", (*keepalive)(nil), state.Fns{Save: (*keepalive).save, Load: (*keepalive).load})
+ state.Register("tcp.receiver", (*receiver)(nil), state.Fns{Save: (*receiver).save, Load: (*receiver).load})
+ state.Register("tcp.renoState", (*renoState)(nil), state.Fns{Save: (*renoState).save, Load: (*renoState).load})
+ state.Register("tcp.SACKScoreboard", (*SACKScoreboard)(nil), state.Fns{Save: (*SACKScoreboard).save, Load: (*SACKScoreboard).load})
+ state.Register("tcp.segment", (*segment)(nil), state.Fns{Save: (*segment).save, Load: (*segment).load})
+ state.Register("tcp.segmentQueue", (*segmentQueue)(nil), state.Fns{Save: (*segmentQueue).save, Load: (*segmentQueue).load})
+ state.Register("tcp.sender", (*sender)(nil), state.Fns{Save: (*sender).save, Load: (*sender).load})
+ state.Register("tcp.rtt", (*rtt)(nil), state.Fns{Save: (*rtt).save, Load: (*rtt).load})
+ state.Register("tcp.fastRecovery", (*fastRecovery)(nil), state.Fns{Save: (*fastRecovery).save, Load: (*fastRecovery).load})
+ state.Register("tcp.unixTime", (*unixTime)(nil), state.Fns{Save: (*unixTime).save, Load: (*unixTime).load})
+ state.Register("tcp.segmentList", (*segmentList)(nil), state.Fns{Save: (*segmentList).save, Load: (*segmentList).load})
+ state.Register("tcp.segmentEntry", (*segmentEntry)(nil), state.Fns{Save: (*segmentEntry).save, Load: (*segmentEntry).load})
+}
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
new file mode 100644
index 000000000..fc1c7cbd2
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -0,0 +1,141 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+)
+
+type timerState int
+
+const (
+ timerStateDisabled timerState = iota
+ timerStateEnabled
+ timerStateOrphaned
+)
+
+// timer is a timer implementation that reduces the interactions with the
+// runtime timer infrastructure by letting timers run (and potentially
+// eventually expire) even if they are stopped. It makes it cheaper to
+// disable/reenable timers at the expense of spurious wakes. This is useful for
+// cases when the same timer is disabled/reenabled repeatedly with relatively
+// long timeouts farther into the future.
+//
+// TCP retransmit timers benefit from this because the timeouts are long
+// (currently at least 200ms); they get disabled when acks are received, and
+// reenabled when new pending segments are sent.
+//
+// It is advantageous to avoid interacting with the runtime because it acquires
+// a global mutex and performs O(log n) operations, where n is the global number
+// of timers, whenever a timer is enabled or disabled, and may make a syscall.
+//
+// This struct is thread-compatible.
+type timer struct {
+	// state is the current state of the timer; it can be one of the
+	// following values:
+	//  disabled - the timer is disabled.
+	//  orphaned - the timer is disabled, but the runtime timer is
+	//             enabled, which means that it will eventually cause a
+	//             spurious wake (unless it gets enabled again before
+	//             then).
+	//  enabled  - the timer is enabled, but the runtime timer may be set
+	//             to an earlier expiration time due to a previous
+	//             orphaned state.
+	state timerState
+
+ // target is the expiration time of the current timer. It is only
+ // meaningful in the enabled state.
+ target time.Time
+
+ // runtimeTarget is the expiration time of the runtime timer. It is
+ // meaningful in the enabled and orphaned states.
+ runtimeTarget time.Time
+
+ // timer is the runtime timer used to wait on.
+ timer *time.Timer
+}
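+
+// A minimal usage sketch (illustrative only, not part of the original file;
+// it assumes the caller owns a sleep.Sleeper and drives the timer from a
+// single goroutine, as thread-compatibility requires):
+//
+//	var s sleep.Sleeper
+//	var w sleep.Waker
+//	s.AddWaker(&w, 0)
+//
+//	var t timer
+//	t.init(&w)
+//	t.enable(500 * time.Millisecond)
+//
+//	if _, ok := s.Fetch(true /* block */); ok {
+//		if t.checkExpiration() {
+//			// Legitimate expiration: run the timeout logic.
+//		}
+//		// Otherwise this was a spurious wake from an orphaned timer.
+//	}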
+
+// init initializes the timer. Once the timer expires, the given waker will
+// be asserted.
+func (t *timer) init(w *sleep.Waker) {
+ t.state = timerStateDisabled
+
+ // Initialize a runtime timer that will assert the waker, then
+ // immediately stop it.
+ t.timer = time.AfterFunc(time.Hour, func() {
+ w.Assert()
+ })
+ t.timer.Stop()
+}
+
+// cleanup frees all resources associated with the timer.
+func (t *timer) cleanup() {
+ t.timer.Stop()
+}
+
+// checkExpiration checks if the timer has actually expired. It should be
+// called whenever a sleeper wakes up due to the waker being asserted, and is
+// used to check if it's a spurious wake (due to a previously orphaned timer)
+// or a legitimate one.
+func (t *timer) checkExpiration() bool {
+ // Transition to fully disabled state if we're just consuming an
+ // orphaned timer.
+ if t.state == timerStateOrphaned {
+ t.state = timerStateDisabled
+ return false
+ }
+
+ // The timer is enabled, but it may have expired early. Check if that's
+ // the case, and if so, reset the runtime timer to the correct time.
+ now := time.Now()
+ if now.Before(t.target) {
+ t.runtimeTarget = t.target
+ t.timer.Reset(t.target.Sub(now))
+ return false
+ }
+
+ // The timer has actually expired, disable it for now and inform the
+ // caller.
+ t.state = timerStateDisabled
+ return true
+}
+
+// disable disables the timer, leaving it in an orphaned state if it wasn't
+// already disabled.
+func (t *timer) disable() {
+ if t.state != timerStateDisabled {
+ t.state = timerStateOrphaned
+ }
+}
+
+// enabled returns true if the timer is currently enabled, false otherwise.
+func (t *timer) enabled() bool {
+ return t.state == timerStateEnabled
+}
+
+// enable enables the timer, programming the runtime timer if necessary.
+func (t *timer) enable(d time.Duration) {
+ t.target = time.Now().Add(d)
+
+ // Check if we need to set the runtime timer.
+ if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) {
+ t.runtimeTarget = t.target
+ t.timer.Reset(d)
+ }
+
+ t.state = timerStateEnabled
+}
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
new file mode 100644
index 000000000..3d52a4f31
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -0,0 +1,1002 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "math"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// +stateify savable
+type udpPacket struct {
+ udpPacketEntry
+ senderAddress tcpip.FullAddress
+ data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+ timestamp int64
+	// views is used as the backing array for data's views when they
+	// fit, avoiding a separate allocation (see the vv.Clone call in
+	// HandlePacket).
+ views [8]buffer.View `state:"nosave"`
+}
+
+type endpointState int
+
+const (
+ stateInitial endpointState = iota
+ stateBound
+ stateConnected
+ stateClosed
+)
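+
+// State transitions as implemented below: stateInitial -> stateBound via
+// Bind, stateInitial or stateBound -> stateConnected via Connect, and any
+// state -> stateClosed via Close. bindLocked only permits binding in
+// stateInitial.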
+
+// endpoint represents a UDP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized.
+//
+// It implements tcpip.Endpoint.
+//
+// +stateify savable
+type endpoint struct {
+ // The following fields are initialized at creation time and do not
+ // change throughout the lifetime of the endpoint.
+ stack *stack.Stack `state:"manual"`
+ netProto tcpip.NetworkProtocolNumber
+ waiterQueue *waiter.Queue
+
+ // The following fields are used to manage the receive queue, and are
+ // protected by rcvMu.
+ rcvMu sync.Mutex `state:"nosave"`
+ rcvReady bool
+ rcvList udpPacketList
+ rcvBufSizeMax int `state:".(int)"`
+ rcvBufSize int
+ rcvClosed bool
+
+ // The following fields are protected by the mu mutex.
+ mu sync.RWMutex `state:"nosave"`
+ sndBufSize int
+ id stack.TransportEndpointID
+ state endpointState
+ bindNICID tcpip.NICID
+ regNICID tcpip.NICID
+ route stack.Route `state:"manual"`
+ dstPort uint16
+ v6only bool
+ multicastTTL uint8
+ multicastAddr tcpip.Address
+ multicastNICID tcpip.NICID
+ multicastLoop bool
+ reusePort bool
+ broadcast bool
+
+ // shutdownFlags represent the current shutdown state of the endpoint.
+ shutdownFlags tcpip.ShutdownFlags
+
+	// multicastMemberships that need to be removed when the endpoint is
+ // closed. Protected by the mu mutex.
+ multicastMemberships []multicastMembership
+
+ // effectiveNetProtos contains the network protocols actually in use. In
+ // most cases it will only contain "netProto", but in cases like IPv6
+ // endpoints with v6only set to false, this could include multiple
+ // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+ // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+ // address).
+ effectiveNetProtos []tcpip.NetworkProtocolNumber
+}
+
+// +stateify savable
+type multicastMembership struct {
+ nicID tcpip.NICID
+ multicastAddr tcpip.Address
+}
+
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+ return &endpoint{
+ stack: stack,
+ netProto: netProto,
+ waiterQueue: waiterQueue,
+ // RFC 1075 section 5.4 recommends a TTL of 1 for membership
+ // requests.
+ //
+ // RFC 5135 4.2.1 appears to assume that IGMP messages have a
+ // TTL of 1.
+ //
+ // RFC 5135 Appendix A defines TTL=1: A multicast source that
+ // wants its traffic to not traverse a router (e.g., leave a
+ // home network) may find it useful to send traffic with IP
+ // TTL=1.
+ //
+ // Linux defaults to TTL=1.
+ multicastTTL: 1,
+ multicastLoop: true,
+ rcvBufSizeMax: 32 * 1024,
+ sndBufSize: 32 * 1024,
+ }
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it.
+func (e *endpoint) Close() {
+ e.mu.Lock()
+ e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite
+
+ switch e.state {
+ case stateBound, stateConnected:
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+ }
+
+ for _, mem := range e.multicastMemberships {
+ e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
+ }
+ e.multicastMemberships = nil
+
+ // Close the receive list and drain it.
+ e.rcvMu.Lock()
+ e.rcvClosed = true
+ e.rcvBufSize = 0
+ for !e.rcvList.Empty() {
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ }
+ e.rcvMu.Unlock()
+
+ e.route.Release()
+
+ // Update the state.
+ e.state = stateClosed
+
+ e.mu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut)
+}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+ e.rcvMu.Lock()
+
+ if e.rcvList.Empty() {
+ err := tcpip.ErrWouldBlock
+ if e.rcvClosed {
+ err = tcpip.ErrClosedForReceive
+ }
+ e.rcvMu.Unlock()
+ return buffer.View{}, tcpip.ControlMessages{}, err
+ }
+
+ p := e.rcvList.Front()
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+
+ e.rcvMu.Unlock()
+
+ if addr != nil {
+ *addr = p.senderAddress
+ }
+
+ return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds the endpoint if it's still in the initial state. To do so, it must
+// first reacquire the mutex in exclusive mode.
+//
+// It returns retry=true if the caller should retry the preparation.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+ switch e.state {
+ case stateInitial:
+ case stateConnected:
+ return false, nil
+
+ case stateBound:
+ if to == nil {
+ return false, tcpip.ErrDestinationRequired
+ }
+ return false, nil
+ default:
+ return false, tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+	// The state changed while we released the shared lock and re-acquired
+	// it in exclusive mode. Try again.
+ if e.state != stateInitial {
+ return true, nil
+ }
+
+ // The state is still 'initial', so try to bind the endpoint.
+ if err := e.bindLocked(tcpip.FullAddress{}); err != nil {
+ return false, err
+ }
+
+ return true, nil
+}
+
+// connectRoute establishes a route to the specified interface or the
+// configured multicast interface if no interface is specified and the
+// specified address is a multicast address.
+func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto, err := e.checkV4Mapped(&addr, false)
+ if err != nil {
+ return stack.Route{}, 0, 0, err
+ }
+
+ localAddr := e.id.LocalAddress
+ if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) {
+ if nicid == 0 {
+ nicid = e.multicastNICID
+ }
+ if localAddr == "" {
+ localAddr = e.multicastAddr
+ }
+ }
+
+ // Find a route to the desired destination.
+ r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop)
+ if err != nil {
+ return stack.Route{}, 0, 0, err
+ }
+ return r, nicid, netProto, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+ // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+ if opts.More {
+ return 0, nil, tcpip.ErrInvalidOptionValue
+ }
+
+ if p.Size() > math.MaxUint16 {
+ // Payload can't possibly fit in a packet.
+ return 0, nil, tcpip.ErrMessageTooLong
+ }
+
+ to := opts.To
+
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ // If we've shutdown with SHUT_WR we are in an invalid state for sending.
+ if e.shutdownFlags&tcpip.ShutdownWrite != 0 {
+ return 0, nil, tcpip.ErrClosedForSend
+ }
+
+ // Prepare for write.
+ for {
+ retry, err := e.prepareForWrite(to)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ if !retry {
+ break
+ }
+ }
+
+ var route *stack.Route
+ var dstPort uint16
+ if to == nil {
+ route = &e.route
+ dstPort = e.dstPort
+
+ if route.IsResolutionRequired() {
+			// Promote the lock to exclusive if using a shared route,
+			// given that it may need to change in the Route.Resolve()
+			// call below.
+ e.mu.RUnlock()
+ defer e.mu.RLock()
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // Recheck state after lock was re-acquired.
+ if e.state != stateConnected {
+ return 0, nil, tcpip.ErrInvalidEndpointState
+ }
+ }
+ } else {
+ // Reject destination address if it goes through a different
+ // NIC than the endpoint was bound to.
+ nicid := to.NIC
+ if e.bindNICID != 0 {
+ if nicid != 0 && nicid != e.bindNICID {
+ return 0, nil, tcpip.ErrNoRoute
+ }
+
+ nicid = e.bindNICID
+ }
+
+ if to.Addr == header.IPv4Broadcast && !e.broadcast {
+ return 0, nil, tcpip.ErrBroadcastDisabled
+ }
+
+ r, _, _, err := e.connectRoute(nicid, *to)
+ if err != nil {
+ return 0, nil, err
+ }
+ defer r.Release()
+
+ route = &r
+ dstPort = to.Port
+ }
+
+ if route.IsResolutionRequired() {
+ if ch, err := route.Resolve(nil); err != nil {
+ if err == tcpip.ErrWouldBlock {
+ return 0, ch, tcpip.ErrNoLinkAddress
+ }
+ return 0, nil, err
+ }
+ }
+
+ v, err := p.Get(p.Size())
+ if err != nil {
+ return 0, nil, err
+ }
+
+ ttl := route.DefaultTTL()
+ if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) {
+ ttl = e.multicastTTL
+ }
+
+ if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(len(v)), nil, nil
+}
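+
+// A minimal client-side sketch (illustrative only; it assumes ep is a
+// tcpip.Endpoint obtained from Stack.NewEndpoint, dst is a valid
+// tcpip.FullAddress, and that tcpip.SlicePayload wraps the bytes):
+//
+//	v := tcpip.SlicePayload([]byte("hello"))
+//	for {
+//		_, ch, err := ep.Write(v, tcpip.WriteOptions{To: &dst})
+//		if err == tcpip.ErrNoLinkAddress {
+//			<-ch // wait for link-address resolution, then retry
+//			continue
+//		}
+//		break // success, or a permanent error to surface
+//	}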
+
+// Peek only returns data from a single datagram, so do nothing here.
+func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) {
+ return 0, tcpip.ControlMessages{}, nil
+}
+
+// SetSockOpt implements tcpip.Endpoint.SetSockOpt.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ switch v := opt.(type) {
+ case tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // We only allow this to be set when we're in the initial state.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.v6only = v != 0
+
+ case tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ e.multicastTTL = uint8(v)
+ e.mu.Unlock()
+
+ case tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ fa := tcpip.FullAddress{Addr: v.InterfaceAddr}
+ netProto, err := e.checkV4Mapped(&fa, false)
+ if err != nil {
+ return err
+ }
+ nic := v.NIC
+ addr := fa.Addr
+
+ if nic == 0 && addr == "" {
+ e.multicastAddr = ""
+ e.multicastNICID = 0
+ break
+ }
+
+ if nic != 0 {
+ if !e.stack.CheckNIC(nic) {
+ return tcpip.ErrBadLocalAddress
+ }
+ } else {
+ nic = e.stack.CheckLocalAddress(0, netProto, addr)
+ if nic == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ if e.bindNICID != 0 && e.bindNICID != nic {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.multicastNICID = nic
+ e.multicastAddr = addr
+
+ case tcpip.AddMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr == header.IPv4Any {
+ if nicID == 0 {
+ r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrUnknownDevice
+ }
+
+ memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ for _, mem := range e.multicastMemberships {
+ if mem == memToInsert {
+ return tcpip.ErrPortInUse
+ }
+ }
+
+ if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships = append(e.multicastMemberships, memToInsert)
+
+ case tcpip.RemoveMembershipOption:
+ if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) {
+ return tcpip.ErrInvalidOptionValue
+ }
+
+ nicID := v.NIC
+ if v.InterfaceAddr == header.IPv4Any {
+ if nicID == 0 {
+ r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */)
+ if err == nil {
+ nicID = r.NICID()
+ r.Release()
+ }
+ }
+ } else {
+ nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
+ }
+ if nicID == 0 {
+ return tcpip.ErrUnknownDevice
+ }
+
+ memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr}
+ memToRemoveIndex := -1
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ for i, mem := range e.multicastMemberships {
+ if mem == memToRemove {
+ memToRemoveIndex = i
+ break
+ }
+ }
+ if memToRemoveIndex == -1 {
+ return tcpip.ErrBadLocalAddress
+ }
+
+ if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
+ return err
+ }
+
+ e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1]
+ e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1]
+
+ case tcpip.MulticastLoopOption:
+ e.mu.Lock()
+ e.multicastLoop = bool(v)
+ e.mu.Unlock()
+
+ case tcpip.ReusePortOption:
+ e.mu.Lock()
+ e.reusePort = v != 0
+ e.mu.Unlock()
+
+ case tcpip.BroadcastOption:
+ e.mu.Lock()
+ e.broadcast = v != 0
+ e.mu.Unlock()
+ }
+ return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.SendBufferSizeOption:
+ e.mu.Lock()
+ *o = tcpip.SendBufferSizeOption(e.sndBufSize)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.ReceiveBufferSizeOption:
+ e.rcvMu.Lock()
+ *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.V6OnlyOption:
+ // We only recognize this option on v6 endpoints.
+ if e.netProto != header.IPv6ProtocolNumber {
+ return tcpip.ErrUnknownProtocolOption
+ }
+
+ e.mu.Lock()
+ v := e.v6only
+ e.mu.Unlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.ReceiveQueueSizeOption:
+ e.rcvMu.Lock()
+ if e.rcvList.Empty() {
+ *o = 0
+ } else {
+ p := e.rcvList.Front()
+ *o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+ }
+ e.rcvMu.Unlock()
+ return nil
+
+ case *tcpip.MulticastTTLOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastTTLOption(e.multicastTTL)
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.MulticastInterfaceOption:
+ e.mu.Lock()
+ *o = tcpip.MulticastInterfaceOption{
+ e.multicastNICID,
+ e.multicastAddr,
+ }
+ e.mu.Unlock()
+ return nil
+
+ case *tcpip.MulticastLoopOption:
+ e.mu.RLock()
+ v := e.multicastLoop
+ e.mu.RUnlock()
+
+ *o = tcpip.MulticastLoopOption(v)
+ return nil
+
+ case *tcpip.ReusePortOption:
+ e.mu.RLock()
+ v := e.reusePort
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ case *tcpip.KeepaliveEnabledOption:
+ *o = 0
+ return nil
+
+ case *tcpip.BroadcastOption:
+ e.mu.RLock()
+ v := e.broadcast
+ e.mu.RUnlock()
+
+ *o = 0
+ if v {
+ *o = 1
+ }
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// sendUDP sends a UDP datagram via the provided network endpoint and under
+// the provided identity.
+func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error {
+ // Allocate a buffer for the UDP header.
+ hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
+
+ // Initialize the header.
+ udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+
+ length := uint16(hdr.UsedLength() + data.Size())
+ udp.Encode(&header.UDPFields{
+ SrcPort: localPort,
+ DstPort: remotePort,
+ Length: length,
+ })
+
+ // Only calculate the checksum if offloading isn't supported.
+ if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 {
+ xsum := r.PseudoHeaderChecksum(ProtocolNumber, length)
+ for _, v := range data.Views() {
+ xsum = header.Checksum(v, xsum)
+ }
+ udp.SetChecksum(^udp.CalculateChecksum(xsum))
+ }
+
+ // Track count of packets sent.
+ r.Stats().UDP.PacketsSent.Increment()
+
+ return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl)
+}
+
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+ netProto := e.netProto
+ if header.IsV4MappedAddress(addr.Addr) {
+ // Fail if using a v4 mapped address on a v6only endpoint.
+ if e.v6only {
+ return 0, tcpip.ErrNoRoute
+ }
+
+ netProto = header.IPv4ProtocolNumber
+ addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+ if addr.Addr == "\x00\x00\x00\x00" {
+ addr.Addr = ""
+ }
+
+ // Fail if we are bound to an IPv6 address.
+ if !allowMismatch && len(e.id.LocalAddress) == 16 {
+ return 0, tcpip.ErrNetworkUnreachable
+ }
+ }
+
+ // Fail if we're bound to an address length different from the one we're
+ // checking.
+ if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) {
+ return 0, tcpip.ErrInvalidEndpointState
+ }
+
+ return netProto, nil
+}
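+
+// An illustrative example (not part of the original file): the v4-mapped
+// form of 192.0.2.1 is ::ffff:192.0.2.1, i.e. the raw 16-byte address
+//
+//	tcpip.Address("\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\xc0\x00\x02\x01")
+//
+// After checkV4Mapped on a non-v6only endpoint, netProto is
+// header.IPv4ProtocolNumber and addr.Addr is the 4-byte "\xc0\x00\x02\x01";
+// an all-zeroes IPv4 address is normalized to the empty (wildcard) address.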
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+ if addr.Port == 0 {
+ // We don't support connecting to port zero.
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ nicid := addr.NIC
+ var localPort uint16
+ switch e.state {
+ case stateInitial:
+ case stateBound, stateConnected:
+ localPort = e.id.LocalPort
+ if e.bindNICID == 0 {
+ break
+ }
+
+ if nicid != 0 && nicid != e.bindNICID {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ nicid = e.bindNICID
+ default:
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ r, nicid, netProto, err := e.connectRoute(nicid, addr)
+ if err != nil {
+ return err
+ }
+ defer r.Release()
+
+ id := stack.TransportEndpointID{
+ LocalAddress: r.LocalAddress,
+ LocalPort: localPort,
+ RemotePort: addr.Port,
+ RemoteAddress: r.RemoteAddress,
+ }
+
+	// Even if we're connected, this endpoint can still be used to send
+	// packets on a different network protocol, so we register both when
+	// this is an IPv6 endpoint with v6only set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv4ProtocolNumber,
+ header.IPv6ProtocolNumber,
+ }
+ }
+
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ // Remove the old registration.
+ if e.id.LocalPort != 0 {
+ e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e)
+ }
+
+ e.id = id
+ e.route = r.Clone()
+ e.dstPort = addr.Port
+ e.regNICID = nicid
+ e.effectiveNetProtos = netProtos
+
+ e.state = stateConnected
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// ConnectEndpoint is not supported.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ // A socket in the bound state can still receive multicast messages,
+ // so we need to notify waiters on shutdown.
+ if e.state != stateBound && e.state != stateConnected {
+ return tcpip.ErrNotConnected
+ }
+
+ e.shutdownFlags |= flags
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.rcvMu.Lock()
+ wasClosed := e.rcvClosed
+ e.rcvClosed = true
+ e.rcvMu.Unlock()
+
+ if !wasClosed {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+ }
+
+ return nil
+}
+
+// Listen is not supported by UDP; it just fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by UDP; it just fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, tcpip.ErrNotSupported
+}
+
+func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+ if e.id.LocalPort == 0 {
+ port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort)
+ if err != nil {
+ return id, err
+ }
+ id.LocalPort = port
+ }
+
+ err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort)
+ if err != nil {
+ e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort)
+ }
+ return id, err
+}
+
+func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error {
+	// Don't allow binding once the endpoint is no longer in the initial
+	// state.
+ if e.state != stateInitial {
+ return tcpip.ErrInvalidEndpointState
+ }
+
+ netProto, err := e.checkV4Mapped(&addr, true)
+ if err != nil {
+ return err
+ }
+
+ // Expand netProtos to include v4 and v6 if the caller is binding to a
+ // wildcard (empty) address, and this is an IPv6 endpoint with v6only
+ // set to false.
+ netProtos := []tcpip.NetworkProtocolNumber{netProto}
+ if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+ netProtos = []tcpip.NetworkProtocolNumber{
+ header.IPv6ProtocolNumber,
+ header.IPv4ProtocolNumber,
+ }
+ }
+
+ nicid := addr.NIC
+ if len(addr.Addr) != 0 {
+ // A local address was specified, verify that it's valid.
+ nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+ if nicid == 0 {
+ return tcpip.ErrBadLocalAddress
+ }
+ }
+
+ id := stack.TransportEndpointID{
+ LocalPort: addr.Port,
+ LocalAddress: addr.Addr,
+ }
+ id, err = e.registerWithStack(nicid, netProtos, id)
+ if err != nil {
+ return err
+ }
+
+ e.id = id
+ e.regNICID = nicid
+ e.effectiveNetProtos = netProtos
+
+ // Mark endpoint as bound.
+ e.state = stateBound
+
+ e.rcvMu.Lock()
+ e.rcvReady = true
+ e.rcvMu.Unlock()
+
+ return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional.
+func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error {
+ e.mu.Lock()
+ defer e.mu.Unlock()
+
+ err := e.bindLocked(addr)
+ if err != nil {
+ return err
+ }
+
+ // Save the effective NICID generated by bindLocked.
+ e.bindNICID = e.regNICID
+
+ return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.LocalAddress,
+ Port: e.id.LocalPort,
+ }, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+
+ if e.state != stateConnected {
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+ }
+
+ return tcpip.FullAddress{
+ NIC: e.regNICID,
+ Addr: e.id.RemoteAddress,
+ Port: e.id.RemotePort,
+ }, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // The endpoint is always writable.
+ result := waiter.EventOut & mask
+
+ // Determine if the endpoint is readable if requested.
+ if (mask & waiter.EventIn) != 0 {
+ e.rcvMu.Lock()
+ if !e.rcvList.Empty() || e.rcvClosed {
+ result |= waiter.EventIn
+ }
+ e.rcvMu.Unlock()
+ }
+
+ return result
+}
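+
+// A minimal blocking-read sketch (illustrative only; it assumes ep is this
+// endpoint exposed as a tcpip.Endpoint and wq is the waiter.Queue it was
+// created with):
+//
+//	we, ch := waiter.NewChannelEntry(nil)
+//	wq.EventRegister(&we, waiter.EventIn)
+//	defer wq.EventUnregister(&we)
+//	for {
+//		v, _, err := ep.Read(nil)
+//		if err == tcpip.ErrWouldBlock {
+//			<-ch // wait until Readiness reports EventIn
+//			continue
+//		}
+//		_ = v // use the datagram, or handle err
+//		break
+//	}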
+
+// HandlePacket is called by the stack when new packets arrive at this
+// transport endpoint.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) {
+ // Get the header then trim it from the view.
+ hdr := header.UDP(vv.First())
+ if int(hdr.Length()) > vv.Size() {
+ // Malformed packet.
+ e.stack.Stats().UDP.MalformedPacketsReceived.Increment()
+ return
+ }
+
+ vv.TrimFront(header.UDPMinimumSize)
+
+ e.rcvMu.Lock()
+ e.stack.Stats().UDP.PacketsReceived.Increment()
+
+ // Drop the packet if our buffer is currently full.
+ if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+ e.stack.Stats().UDP.ReceiveBufferErrors.Increment()
+ e.rcvMu.Unlock()
+ return
+ }
+
+ wasEmpty := e.rcvBufSize == 0
+
+ // Push new packet into receive list and increment the buffer size.
+ pkt := &udpPacket{
+ senderAddress: tcpip.FullAddress{
+ NIC: r.NICID(),
+ Addr: id.RemoteAddress,
+ Port: hdr.SourcePort(),
+ },
+ }
+ pkt.data = vv.Clone(pkt.views[:])
+ e.rcvList.PushBack(pkt)
+ e.rcvBufSize += vv.Size()
+
+ pkt.timestamp = e.stack.NowNanoseconds()
+
+ e.rcvMu.Unlock()
+
+ // Notify any waiters that there's data to be read now.
+ if wasEmpty {
+ e.waiterQueue.Notify(waiter.EventIn)
+ }
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) {
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
new file mode 100644
index 000000000..74e8e9fd5
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -0,0 +1,112 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves udpPacket.data field.
+func (u *udpPacket) saveData() buffer.VectorisedView {
+	// We cannot save u.data directly as u.data.views may alias u.views,
+	// which is not allowed by the state framework (in-struct pointer).
+ return u.data.Clone(nil)
+}
+
+// loadData loads udpPacket.data field.
+func (u *udpPacket) loadData(data buffer.VectorisedView) {
+ // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization
+ // here because data.views is not guaranteed to be loaded by now. Plus,
+	// data.views will be allocated anyway, so there really is little
+	// point in utilizing u.views for data.views.
+ u.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+	// Stop incoming packets from being handled (and from mutating
+	// endpoint state). The lock will be released in saveRcvBufSizeMax(),
+	// which saves e.rcvBufSizeMax and sets it to 0 so that incoming
+	// packets remain blocked during the save.
+ e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) saveRcvBufSizeMax() int {
+ max := e.rcvBufSizeMax
+ // Make sure no new packets will be handled regardless of the lock.
+ e.rcvBufSizeMax = 0
+ // Release the lock acquired in beforeSave() so regular endpoint closing
+ // logic can proceed after save.
+ e.rcvMu.Unlock()
+ return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+ e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+
+ for _, m := range e.multicastMemberships {
+ if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
+ panic(err)
+ }
+ }
+
+ if e.state != stateBound && e.state != stateConnected {
+ return
+ }
+
+ netProto := e.effectiveNetProtos[0]
+ // Connect() and bindLocked() both assert
+ //
+ // netProto == header.IPv6ProtocolNumber
+ //
+ // before creating a multi-entry effectiveNetProtos.
+ if len(e.effectiveNetProtos) > 1 {
+ netProto = header.IPv6ProtocolNumber
+ }
+
+ var err *tcpip.Error
+ if e.state == stateConnected {
+ e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop)
+ if err != nil {
+ panic(*err)
+ }
+
+ e.id.LocalAddress = e.route.LocalAddress
+ } else if len(e.id.LocalAddress) != 0 { // stateBound
+ if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
+ panic(tcpip.ErrBadLocalAddress)
+ }
+ }
+
+ // Our saved state had a port, but we don't actually have a
+ // reservation. We need to remove the port from our state, but still
+ // pass it to the reservation machinery.
+ id := e.id
+ e.id.LocalPort = 0
+ e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id)
+ if err != nil {
+ panic(*err)
+ }
+}
diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go
new file mode 100644
index 000000000..25bdd2929
--- /dev/null
+++ b/pkg/tcpip/transport/udp/forwarder.go
@@ -0,0 +1,96 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Forwarder is a session request forwarder, which allows clients to decide
+// what to do with a session request, for example: ignore it, or process it.
+//
+// The canonical way of using it is to pass the Forwarder.HandlePacket function
+// to stack.SetTransportProtocolHandler.
+type Forwarder struct {
+ handler func(*ForwarderRequest)
+
+ stack *stack.Stack
+}
+
+// NewForwarder allocates and initializes a new forwarder.
+func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder {
+ return &Forwarder{
+ stack: s,
+ handler: handler,
+ }
+}
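+
+// A minimal wiring sketch (illustrative only; it assumes s is a configured
+// *stack.Stack and handleConn is a caller-provided function):
+//
+//	fwd := NewForwarder(s, func(r *ForwarderRequest) {
+//		var wq waiter.Queue
+//		ep, err := r.CreateEndpoint(&wq)
+//		if err != nil {
+//			return // drop the session
+//		}
+//		go handleConn(ep, &wq)
+//	})
+//	s.SetTransportProtocolHandler(ProtocolNumber, fwd.HandlePacket)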
+
+// HandlePacket handles all packets.
+//
+// This function is expected to be passed as an argument to the
+// stack.SetTransportProtocolHandler function.
+func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool {
+ f.handler(&ForwarderRequest{
+ stack: f.stack,
+ route: r,
+ id: id,
+ vv: vv,
+ })
+
+ return true
+}
+
+// ForwarderRequest represents a session request received by the forwarder and
+// passed to the client. Clients may optionally create an endpoint to represent
+// it via CreateEndpoint.
+type ForwarderRequest struct {
+ stack *stack.Stack
+ route *stack.Route
+ id stack.TransportEndpointID
+ vv buffer.VectorisedView
+}
+
+// ID returns the 4-tuple (src address, src port, dst address, dst port) that
+// represents the session request.
+func (r *ForwarderRequest) ID() stack.TransportEndpointID {
+ return r.id
+}
+
+// CreateEndpoint creates a connected UDP endpoint for the session request.
+func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ ep := newEndpoint(r.stack, r.route.NetProto, queue)
+ if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ ep.id = r.id
+ ep.route = r.route.Clone()
+ ep.dstPort = r.id.RemotePort
+ ep.regNICID = r.route.NICID()
+
+ ep.state = stateConnected
+
+ ep.rcvMu.Lock()
+ ep.rcvReady = true
+ ep.rcvMu.Unlock()
+
+ ep.HandlePacket(r.route, r.id, r.vv)
+
+ return ep, nil
+}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
new file mode 100644
index 000000000..3d31dfbf1
--- /dev/null
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -0,0 +1,90 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package udp contains the implementation of the UDP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing udp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ProtocolName is the string representation of the udp protocol name.
+ ProtocolName = "udp"
+
+ // ProtocolNumber is the udp protocol number.
+ ProtocolNumber = header.UDPProtocolNumber
+)
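+
+// A minimal construction sketch (illustrative only; it assumes the
+// string-based stack.New signature in use at this revision, and that the
+// ipv4 network protocol package is linked in as well):
+//
+//	s := stack.New([]string{ipv4.ProtocolName}, []string{ProtocolName}, stack.Options{})
+//	var wq waiter.Queue
+//	ep, err := s.NewEndpoint(ProtocolNumber, ipv4.ProtocolNumber, &wq)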
+
+type protocol struct{}
+
+// Number returns the udp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+ return ProtocolNumber
+}
+
+// NewEndpoint creates a new udp endpoint.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// NewRawEndpoint creates a new raw UDP endpoint. It implements
+// stack.TransportProtocol.NewRawEndpoint.
+func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue)
+}
+
+// MinimumPacketSize returns the minimum valid udp packet size.
+func (*protocol) MinimumPacketSize() int {
+ return header.UDPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given udp
+// packet.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+ h := header.UDP(v)
+ return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool {
+ return true
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go
new file mode 100755
index 000000000..673a9373b
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_packet_list.go
@@ -0,0 +1,173 @@
+package udp
+
+// ElementMapper provides an identity mapping by default.
+//
+// This can be replaced to provide a struct that maps elements to linker
+// objects, if they are not the same. An ElementMapper is not typically
+// required if: Linker is left as is, Element is left as is, or Linker and
+// Element are the same type.
+type udpPacketElementMapper struct{}
+
+// linkerFor maps an Element to a Linker.
+//
+// This default implementation should be inlined.
+//
+//go:nosplit
+func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem }
+
+// List is an intrusive list. Entries can be added to or removed from the list
+// in O(1) time and with no additional memory allocations.
+//
+// The zero value for List is an empty list ready to use.
+//
+// To iterate over a list (where l is a List):
+// for e := l.Front(); e != nil; e = e.Next() {
+// // do something with e.
+// }
+//
+// +stateify savable
+type udpPacketList struct {
+ head *udpPacket
+ tail *udpPacket
+}
+
+// Reset resets list l to the empty state.
+func (l *udpPacketList) Reset() {
+ l.head = nil
+ l.tail = nil
+}
+
+// Empty returns true iff the list is empty.
+func (l *udpPacketList) Empty() bool {
+ return l.head == nil
+}
+
+// Front returns the first element of list l or nil.
+func (l *udpPacketList) Front() *udpPacket {
+ return l.head
+}
+
+// Back returns the last element of list l or nil.
+func (l *udpPacketList) Back() *udpPacket {
+ return l.tail
+}
+
+// PushFront inserts the element e at the front of list l.
+func (l *udpPacketList) PushFront(e *udpPacket) {
+ udpPacketElementMapper{}.linkerFor(e).SetNext(l.head)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(nil)
+
+ if l.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+
+ l.head = e
+}
+
+// PushBack inserts the element e at the back of list l.
+func (l *udpPacketList) PushBack(e *udpPacket) {
+ udpPacketElementMapper{}.linkerFor(e).SetNext(nil)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail)
+
+ if l.tail != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e)
+ } else {
+ l.head = e
+ }
+
+ l.tail = e
+}
+
+// PushBackList inserts list m at the end of list l, emptying m.
+func (l *udpPacketList) PushBackList(m *udpPacketList) {
+ if l.head == nil {
+ l.head = m.head
+ l.tail = m.tail
+ } else if m.head != nil {
+ udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head)
+ udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail)
+
+ l.tail = m.tail
+ }
+
+ m.head = nil
+ m.tail = nil
+}
+
+// InsertAfter inserts e after b.
+func (l *udpPacketList) InsertAfter(b, e *udpPacket) {
+ a := udpPacketElementMapper{}.linkerFor(b).Next()
+ udpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+
+ if a != nil {
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+ } else {
+ l.tail = e
+ }
+}
+
+// InsertBefore inserts e before a.
+func (l *udpPacketList) InsertBefore(a, e *udpPacket) {
+ b := udpPacketElementMapper{}.linkerFor(a).Prev()
+ udpPacketElementMapper{}.linkerFor(e).SetNext(a)
+ udpPacketElementMapper{}.linkerFor(e).SetPrev(b)
+ udpPacketElementMapper{}.linkerFor(a).SetPrev(e)
+
+ if b != nil {
+ udpPacketElementMapper{}.linkerFor(b).SetNext(e)
+ } else {
+ l.head = e
+ }
+}
+
+// Remove removes e from l.
+func (l *udpPacketList) Remove(e *udpPacket) {
+ prev := udpPacketElementMapper{}.linkerFor(e).Prev()
+ next := udpPacketElementMapper{}.linkerFor(e).Next()
+
+ if prev != nil {
+ udpPacketElementMapper{}.linkerFor(prev).SetNext(next)
+ } else {
+ l.head = next
+ }
+
+ if next != nil {
+ udpPacketElementMapper{}.linkerFor(next).SetPrev(prev)
+ } else {
+ l.tail = prev
+ }
+}
+
+// Entry is a default implementation of Linker. Users can add anonymous fields
+// of this type to their structs to make them automatically implement the
+// methods needed by List.
+//
+// +stateify savable
+type udpPacketEntry struct {
+ next *udpPacket
+ prev *udpPacket
+}
+
+// Next returns the entry that follows e in the list.
+func (e *udpPacketEntry) Next() *udpPacket {
+ return e.next
+}
+
+// Prev returns the entry that precedes e in the list.
+func (e *udpPacketEntry) Prev() *udpPacket {
+ return e.prev
+}
+
+// SetNext assigns 'elem' as the entry that follows e in the list.
+func (e *udpPacketEntry) SetNext(elem *udpPacket) {
+ e.next = elem
+}
+
+// SetPrev assigns 'elem' as the entry that precedes e in the list.
+func (e *udpPacketEntry) SetPrev(elem *udpPacket) {
+ e.prev = elem
+}
diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go
new file mode 100755
index 000000000..711e2feeb
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_state_autogen.go
@@ -0,0 +1,128 @@
+// automatically generated by stateify.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+func (x *udpPacket) beforeSave() {}
+func (x *udpPacket) save(m state.Map) {
+ x.beforeSave()
+ var data buffer.VectorisedView = x.saveData()
+ m.SaveValue("data", data)
+ m.Save("udpPacketEntry", &x.udpPacketEntry)
+ m.Save("senderAddress", &x.senderAddress)
+ m.Save("timestamp", &x.timestamp)
+}
+
+func (x *udpPacket) afterLoad() {}
+func (x *udpPacket) load(m state.Map) {
+ m.Load("udpPacketEntry", &x.udpPacketEntry)
+ m.Load("senderAddress", &x.senderAddress)
+ m.Load("timestamp", &x.timestamp)
+ m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) })
+}
+
+func (x *endpoint) save(m state.Map) {
+ x.beforeSave()
+ var rcvBufSizeMax int = x.saveRcvBufSizeMax()
+ m.SaveValue("rcvBufSizeMax", rcvBufSizeMax)
+ m.Save("netProto", &x.netProto)
+ m.Save("waiterQueue", &x.waiterQueue)
+ m.Save("rcvReady", &x.rcvReady)
+ m.Save("rcvList", &x.rcvList)
+ m.Save("rcvBufSize", &x.rcvBufSize)
+ m.Save("rcvClosed", &x.rcvClosed)
+ m.Save("sndBufSize", &x.sndBufSize)
+ m.Save("id", &x.id)
+ m.Save("state", &x.state)
+ m.Save("bindNICID", &x.bindNICID)
+ m.Save("regNICID", &x.regNICID)
+ m.Save("dstPort", &x.dstPort)
+ m.Save("v6only", &x.v6only)
+ m.Save("multicastTTL", &x.multicastTTL)
+ m.Save("multicastAddr", &x.multicastAddr)
+ m.Save("multicastNICID", &x.multicastNICID)
+ m.Save("multicastLoop", &x.multicastLoop)
+ m.Save("reusePort", &x.reusePort)
+ m.Save("broadcast", &x.broadcast)
+ m.Save("shutdownFlags", &x.shutdownFlags)
+ m.Save("multicastMemberships", &x.multicastMemberships)
+ m.Save("effectiveNetProtos", &x.effectiveNetProtos)
+}
+
+func (x *endpoint) load(m state.Map) {
+ m.Load("netProto", &x.netProto)
+ m.Load("waiterQueue", &x.waiterQueue)
+ m.Load("rcvReady", &x.rcvReady)
+ m.Load("rcvList", &x.rcvList)
+ m.Load("rcvBufSize", &x.rcvBufSize)
+ m.Load("rcvClosed", &x.rcvClosed)
+ m.Load("sndBufSize", &x.sndBufSize)
+ m.Load("id", &x.id)
+ m.Load("state", &x.state)
+ m.Load("bindNICID", &x.bindNICID)
+ m.Load("regNICID", &x.regNICID)
+ m.Load("dstPort", &x.dstPort)
+ m.Load("v6only", &x.v6only)
+ m.Load("multicastTTL", &x.multicastTTL)
+ m.Load("multicastAddr", &x.multicastAddr)
+ m.Load("multicastNICID", &x.multicastNICID)
+ m.Load("multicastLoop", &x.multicastLoop)
+ m.Load("reusePort", &x.reusePort)
+ m.Load("broadcast", &x.broadcast)
+ m.Load("shutdownFlags", &x.shutdownFlags)
+ m.Load("multicastMemberships", &x.multicastMemberships)
+ m.Load("effectiveNetProtos", &x.effectiveNetProtos)
+ m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) })
+ m.AfterLoad(x.afterLoad)
+}
+
+func (x *multicastMembership) beforeSave() {}
+func (x *multicastMembership) save(m state.Map) {
+ x.beforeSave()
+ m.Save("nicID", &x.nicID)
+ m.Save("multicastAddr", &x.multicastAddr)
+}
+
+func (x *multicastMembership) afterLoad() {}
+func (x *multicastMembership) load(m state.Map) {
+ m.Load("nicID", &x.nicID)
+ m.Load("multicastAddr", &x.multicastAddr)
+}
+
+func (x *udpPacketList) beforeSave() {}
+func (x *udpPacketList) save(m state.Map) {
+ x.beforeSave()
+ m.Save("head", &x.head)
+ m.Save("tail", &x.tail)
+}
+
+func (x *udpPacketList) afterLoad() {}
+func (x *udpPacketList) load(m state.Map) {
+ m.Load("head", &x.head)
+ m.Load("tail", &x.tail)
+}
+
+func (x *udpPacketEntry) beforeSave() {}
+func (x *udpPacketEntry) save(m state.Map) {
+ x.beforeSave()
+ m.Save("next", &x.next)
+ m.Save("prev", &x.prev)
+}
+
+func (x *udpPacketEntry) afterLoad() {}
+func (x *udpPacketEntry) load(m state.Map) {
+ m.Load("next", &x.next)
+ m.Load("prev", &x.prev)
+}
+
+func init() {
+ state.Register("udp.udpPacket", (*udpPacket)(nil), state.Fns{Save: (*udpPacket).save, Load: (*udpPacket).load})
+ state.Register("udp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load})
+ state.Register("udp.multicastMembership", (*multicastMembership)(nil), state.Fns{Save: (*multicastMembership).save, Load: (*multicastMembership).load})
+ state.Register("udp.udpPacketList", (*udpPacketList)(nil), state.Fns{Save: (*udpPacketList).save, Load: (*udpPacketList).load})
+ state.Register("udp.udpPacketEntry", (*udpPacketEntry)(nil), state.Fns{Save: (*udpPacketEntry).save, Load: (*udpPacketEntry).load})
+}