diff options
author | gVisor bot <gvisor-bot@google.com> | 2019-06-02 06:44:55 +0000 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-06-02 06:44:55 +0000 |
commit | ceb0d792f328d1fc0692197d8856a43c3936a571 (patch) | |
tree | 83155f302eff44a78bcc30a3a08f4efe59a79379 /pkg/tcpip/transport | |
parent | deb7ecf1e46862d54f4b102f2d163cfbcfc37f3b (diff) | |
parent | 216da0b733dbed9aad9b2ab92ac75bcb906fd7ee (diff) |
Merge 216da0b7 (automated)
Diffstat (limited to 'pkg/tcpip/transport')
35 files changed, 11074 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go new file mode 100644 index 000000000..e2b90ef10 --- /dev/null +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -0,0 +1,710 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package icmp + +import ( + "encoding/binary" + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// +stateify savable +type icmpPacket struct { + icmpPacketEntry + senderAddress tcpip.FullAddress + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + timestamp int64 + // views is used as buffer for data when its length is large + // enough to store a VectorisedView. + views [8]buffer.View `state:"nosave"` +} + +type endpointState int + +const ( + stateInitial endpointState = iota + stateBound + stateConnected + stateClosed +) + +// endpoint represents an ICMP endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. +// +// +stateify savable +type endpoint struct { + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + waiterQueue *waiter.Queue + + // The following fields are used to manage the receive queue, and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList icmpPacketList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by the mu mutex. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + id stack.TransportEndpointID + state endpointState + // bindNICID and bindAddr are set via calls to Bind(). They are used to + // reject attempts to send data or connect via a different NIC or + // address + bindNICID tcpip.NICID + bindAddr tcpip.Address + // regNICID is the default NIC to be used when callers don't specify a + // NIC. + regNICID tcpip.NICID + route stack.Route `state:"manual"` +} + +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return &endpoint{ + stack: stack, + netProto: netProto, + transProto: transProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + }, nil +} + +// Close puts the endpoint in a closed state and frees all resources +// associated with it. +func (e *endpoint) Close() { + e.mu.Lock() + e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite + switch e.state { + case stateBound, stateConnected: + e.stack.UnregisterTransportEndpoint(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.transProto, e.id, e) + } + + // Close the receive list and drain it. + e.rcvMu.Lock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } + e.rcvMu.Unlock() + + e.route.Release() + + // Update the state. + e.state = stateClosed + + e.mu.Unlock() + + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// Read reads data from the endpoint. This method does not block if +// there is no data pending. +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + e.rcvMu.Lock() + + if e.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if e.rcvClosed { + err = tcpip.ErrClosedForReceive + } + e.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + p := e.rcvList.Front() + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() + + e.rcvMu.Unlock() + + if addr != nil { + *addr = p.senderAddress + } + + return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil +} + +// prepareForWrite prepares the endpoint for sending data. In particular, it +// binds it if it's still in the initial state. To do so, it must first +// reacquire the mutex in exclusive mode. +// +// Returns true for retry if preparation should be retried. +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { + switch e.state { + case stateInitial: + case stateConnected: + return false, nil + + case stateBound: + if to == nil { + return false, tcpip.ErrDestinationRequired + } + return false, nil + default: + return false, tcpip.ErrInvalidEndpointState + } + + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // The state changed when we released the shared locked and re-acquired + // it in exclusive mode. Try again. + if e.state != stateInitial { + return true, nil + } + + // The state is still 'initial', so try to bind the endpoint. + if err := e.bindLocked(tcpip.FullAddress{}); err != nil { + return false, err + } + + return true, nil +} + +// Write writes data to the endpoint's peer. This method does not block +// if the data cannot be written. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, nil, tcpip.ErrInvalidOptionValue + } + + to := opts.To + + e.mu.RLock() + defer e.mu.RUnlock() + + // If we've shutdown with SHUT_WR we are in an invalid state for sending. + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { + return 0, nil, tcpip.ErrClosedForSend + } + + // Prepare for write. + for { + retry, err := e.prepareForWrite(to) + if err != nil { + return 0, nil, err + } + + if !retry { + break + } + } + + var route *stack.Route + if to == nil { + route = &e.route + + if route.IsResolutionRequired() { + // Promote lock to exclusive if using a shared route, + // given that it may need to change in Route.Resolve() + // call below. + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // Recheck state after lock was re-acquired. + if e.state != stateConnected { + return 0, nil, tcpip.ErrInvalidEndpointState + } + } + } else { + // Reject destination address if it goes through a different + // NIC than the endpoint was bound to. + nicid := to.NIC + if e.bindNICID != 0 { + if nicid != 0 && nicid != e.bindNICID { + return 0, nil, tcpip.ErrNoRoute + } + + nicid = e.bindNICID + } + + toCopy := *to + to = &toCopy + netProto, err := e.checkV4Mapped(to, true) + if err != nil { + return 0, nil, err + } + + // Find the enpoint. + r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto, false /* multicastLoop */) + if err != nil { + return 0, nil, err + } + defer r.Release() + + route = &r + } + + if route.IsResolutionRequired() { + if ch, err := route.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + v, err := p.Get(p.Size()) + if err != nil { + return 0, nil, err + } + + switch e.netProto { + case header.IPv4ProtocolNumber: + err = e.send4(route, v) + + case header.IPv6ProtocolNumber: + err = send6(route, e.id.LocalPort, v) + } + + if err != nil { + return 0, nil, err + } + + return uintptr(len(v)), nil, nil +} + +// Peek only returns data from a single datagram, so do nothing here. +func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// SetSockOpt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.SendBufferSizeOption: + e.mu.Lock() + *o = tcpip.SendBufferSizeOption(e.sndBufSize) + e.mu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) + e.rcvMu.Unlock() + return nil + + case *tcpip.ReceiveQueueSizeOption: + e.rcvMu.Lock() + if e.rcvList.Empty() { + *o = 0 + } else { + p := e.rcvList.Front() + *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) + } + e.rcvMu.Unlock() + return nil + + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (e *endpoint) send4(r *stack.Route, data buffer.View) *tcpip.Error { + if len(data) < header.ICMPv4EchoMinimumSize { + return tcpip.ErrInvalidEndpointState + } + + // Set the ident to the user-specified port. Sequence number should + // already be set by the user. + binary.BigEndian.PutUint16(data[header.ICMPv4MinimumSize:], e.id.LocalPort) + + hdr := buffer.NewPrependable(header.ICMPv4EchoMinimumSize + int(r.MaxHeaderLength())) + + icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4EchoMinimumSize)) + copy(icmpv4, data) + data = data[header.ICMPv4EchoMinimumSize:] + + // Linux performs these basic checks. + if icmpv4.Type() != header.ICMPv4Echo || icmpv4.Code() != 0 { + return tcpip.ErrInvalidEndpointState + } + + icmpv4.SetChecksum(0) + icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0))) + + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv4ProtocolNumber, r.DefaultTTL()) +} + +func send6(r *stack.Route, ident uint16, data buffer.View) *tcpip.Error { + if len(data) < header.ICMPv6EchoMinimumSize { + return tcpip.ErrInvalidEndpointState + } + + // Set the ident. Sequence number is provided by the user. + binary.BigEndian.PutUint16(data[header.ICMPv6MinimumSize:], ident) + + hdr := buffer.NewPrependable(header.ICMPv6EchoMinimumSize + int(r.MaxHeaderLength())) + + icmpv6 := header.ICMPv6(hdr.Prepend(header.ICMPv6EchoMinimumSize)) + copy(icmpv6, data) + data = data[header.ICMPv6EchoMinimumSize:] + + if icmpv6.Type() != header.ICMPv6EchoRequest || icmpv6.Code() != 0 { + return tcpip.ErrInvalidEndpointState + } + + icmpv6.SetChecksum(0) + icmpv6.SetChecksum(^header.Checksum(icmpv6, header.Checksum(data, 0))) + + return r.WritePacket(nil /* gso */, hdr, data.ToVectorisedView(), header.ICMPv6ProtocolNumber, r.DefaultTTL()) +} + +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.netProto + if header.IsV4MappedAddress(addr.Addr) { + return 0, tcpip.ErrNoRoute + } + + // Fail if we're bound to an address length different from the one we're + // checking. + if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { + return 0, tcpip.ErrInvalidEndpointState + } + + return netProto, nil +} + +// Connect connects the endpoint to its peer. Specifying a NIC is optional. +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + nicid := addr.NIC + localPort := uint16(0) + switch e.state { + case stateBound, stateConnected: + localPort = e.id.LocalPort + if e.bindNICID == 0 { + break + } + + if nicid != 0 && nicid != e.bindNICID { + return tcpip.ErrInvalidEndpointState + } + + nicid = e.bindNICID + default: + return tcpip.ErrInvalidEndpointState + } + + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + id := stack.TransportEndpointID{ + LocalAddress: r.LocalAddress, + LocalPort: localPort, + RemoteAddress: r.RemoteAddress, + } + + // Even if we're connected, this endpoint can still be used to send + // packets on a different network protocol, so we register both even if + // v6only is set to false and this is an ipv6 endpoint. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + + id, err = e.registerWithStack(nicid, netProtos, id) + if err != nil { + return err + } + + e.id = id + e.route = r.Clone() + e.regNICID = nicid + + e.state = stateConnected + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// ConnectEndpoint is not supported. +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// Shutdown closes the read and/or write end of the endpoint connection +// to its peer. +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + e.shutdownFlags |= flags + + if e.state != stateConnected { + return tcpip.ErrNotConnected + } + + if flags&tcpip.ShutdownRead != 0 { + e.rcvMu.Lock() + wasClosed := e.rcvClosed + e.rcvClosed = true + e.rcvMu.Unlock() + + if !wasClosed { + e.waiterQueue.Notify(waiter.EventIn) + } + } + + return nil +} + +// Listen is not supported by UDP, it just fails. +func (*endpoint) Listen(int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept is not supported by UDP, it just fails. +func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if id.LocalPort != 0 { + // The endpoint already has a local port, just attempt to + // register it. + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + return id, err + } + + // We need to find a port for the endpoint. + _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + id.LocalPort = p + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, e.transProto, id, e, false) + switch err { + case nil: + return true, nil + case tcpip.ErrPortInUse: + return false, nil + default: + return false, err + } + }) + + return id, err +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { + // Don't allow binding once endpoint is not in the initial state + // anymore. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err + } + + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + + if len(addr.Addr) != 0 { + // A local address was specified, verify that it's valid. + if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { + return tcpip.ErrBadLocalAddress + } + } + + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err = e.registerWithStack(addr.NIC, netProtos, id) + if err != nil { + return err + } + + e.id = id + e.regNICID = addr.NIC + + // Mark endpoint as bound. + e.state = stateBound + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// Bind binds the endpoint to a specific local address and port. +// Specifying a NIC is optional. +func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + err := e.bindLocked(addr) + if err != nil { + return err + } + + e.bindNICID = addr.NIC + e.bindAddr = addr.Addr + + return nil +} + +// GetLocalAddress returns the address to which the endpoint is bound. +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.LocalAddress, + Port: e.id.LocalPort, + }, nil +} + +// GetRemoteAddress returns the address to which the endpoint is connected. +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.RemoteAddress, + Port: e.id.RemotePort, + }, nil +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvMu.Unlock() + } + + return result +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { + // Only accept echo replies. + switch e.netProto { + case header.IPv4ProtocolNumber: + h := header.ICMPv4(vv.First()) + if h.Type() != header.ICMPv4EchoReply { + e.stack.Stats().DroppedPackets.Increment() + return + } + case header.IPv6ProtocolNumber: + h := header.ICMPv6(vv.First()) + if h.Type() != header.ICMPv6EchoReply { + e.stack.Stats().DroppedPackets.Increment() + return + } + } + + e.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. + if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + e.stack.Stats().DroppedPackets.Increment() + e.rcvMu.Unlock() + return + } + + wasEmpty := e.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + pkt := &icmpPacket{ + senderAddress: tcpip.FullAddress{ + NIC: r.NICID(), + Addr: id.RemoteAddress, + }, + } + + pkt.data = vv.Clone(pkt.views[:]) + + e.rcvList.PushBack(pkt) + e.rcvBufSize += pkt.data.Size() + + pkt.timestamp = e.stack.NowNanoseconds() + + e.rcvMu.Unlock() + + // Notify any waiters that there's data to be read now. + if wasEmpty { + e.waiterQueue.Notify(waiter.EventIn) + } +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +} diff --git a/pkg/tcpip/transport/icmp/endpoint_state.go b/pkg/tcpip/transport/icmp/endpoint_state.go new file mode 100644 index 000000000..332b3cd33 --- /dev/null +++ b/pkg/tcpip/transport/icmp/endpoint_state.go @@ -0,0 +1,90 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package icmp + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// saveData saves icmpPacket.data field. +func (p *icmpPacket) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, + // which is not allowed by state framework (in-struct pointer). + return p.data.Clone(nil) +} + +// loadData loads icmpPacket.data field. +func (p *icmpPacket) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing p.views for data.views. + p.data = data +} + +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after savercvBufSizeMax(), which would have + // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + e.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (e *endpoint) saveRcvBufSizeMax() int { + max := e.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + e.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + e.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (e *endpoint) loadRcvBufSizeMax(max int) { + e.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (e *endpoint) afterLoad() { + e.stack = stack.StackFromEnv + + if e.state != stateBound && e.state != stateConnected { + return + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, e.netProto, false /* multicastLoop */) + if err != nil { + panic(*err) + } + + e.id.LocalAddress = e.route.LocalAddress + } else if len(e.id.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.regNICID, e.netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + e.id, err = e.registerWithStack(e.regNICID, []tcpip.NetworkProtocolNumber{e.netProto}, e.id) + if err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/icmp/icmp_packet_list.go b/pkg/tcpip/transport/icmp/icmp_packet_list.go new file mode 100755 index 000000000..1b35e5b4a --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_packet_list.go @@ -0,0 +1,173 @@ +package icmp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type icmpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (icmpPacketElementMapper) linkerFor(elem *icmpPacket) *icmpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type icmpPacketList struct { + head *icmpPacket + tail *icmpPacket +} + +// Reset resets list l to the empty state. +func (l *icmpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *icmpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *icmpPacketList) Front() *icmpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *icmpPacketList) Back() *icmpPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *icmpPacketList) PushFront(e *icmpPacket) { + icmpPacketElementMapper{}.linkerFor(e).SetNext(l.head) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + icmpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *icmpPacketList) PushBack(e *icmpPacket) { + icmpPacketElementMapper{}.linkerFor(e).SetNext(nil) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *icmpPacketList) PushBackList(m *icmpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + icmpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + icmpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *icmpPacketList) InsertAfter(b, e *icmpPacket) { + a := icmpPacketElementMapper{}.linkerFor(b).Next() + icmpPacketElementMapper{}.linkerFor(e).SetNext(a) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(b) + icmpPacketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + icmpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *icmpPacketList) InsertBefore(a, e *icmpPacket) { + b := icmpPacketElementMapper{}.linkerFor(a).Prev() + icmpPacketElementMapper{}.linkerFor(e).SetNext(a) + icmpPacketElementMapper{}.linkerFor(e).SetPrev(b) + icmpPacketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + icmpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *icmpPacketList) Remove(e *icmpPacket) { + prev := icmpPacketElementMapper{}.linkerFor(e).Prev() + next := icmpPacketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + icmpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + icmpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type icmpPacketEntry struct { + next *icmpPacket + prev *icmpPacket +} + +// Next returns the entry that follows e in the list. +func (e *icmpPacketEntry) Next() *icmpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *icmpPacketEntry) Prev() *icmpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *icmpPacketEntry) SetNext(elem *icmpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *icmpPacketEntry) SetPrev(elem *icmpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/icmp/icmp_state_autogen.go b/pkg/tcpip/transport/icmp/icmp_state_autogen.go new file mode 100755 index 000000000..b66857348 --- /dev/null +++ b/pkg/tcpip/transport/icmp/icmp_state_autogen.go @@ -0,0 +1,98 @@ +// automatically generated by stateify. + +package icmp + +import ( + "gvisor.googlesource.com/gvisor/pkg/state" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +func (x *icmpPacket) beforeSave() {} +func (x *icmpPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("icmpPacketEntry", &x.icmpPacketEntry) + m.Save("senderAddress", &x.senderAddress) + m.Save("timestamp", &x.timestamp) +} + +func (x *icmpPacket) afterLoad() {} +func (x *icmpPacket) load(m state.Map) { + m.Load("icmpPacketEntry", &x.icmpPacketEntry) + m.Load("senderAddress", &x.senderAddress) + m.Load("timestamp", &x.timestamp) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("netProto", &x.netProto) + m.Save("transProto", &x.transProto) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("rcvReady", &x.rcvReady) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("id", &x.id) + m.Save("state", &x.state) + m.Save("bindNICID", &x.bindNICID) + m.Save("bindAddr", &x.bindAddr) + m.Save("regNICID", &x.regNICID) +} + +func (x *endpoint) load(m state.Map) { + m.Load("netProto", &x.netProto) + m.Load("transProto", &x.transProto) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("rcvReady", &x.rcvReady) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("id", &x.id) + m.Load("state", &x.state) + m.Load("bindNICID", &x.bindNICID) + m.Load("bindAddr", &x.bindAddr) + m.Load("regNICID", &x.regNICID) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *icmpPacketList) beforeSave() {} +func (x *icmpPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *icmpPacketList) afterLoad() {} +func (x *icmpPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *icmpPacketEntry) beforeSave() {} +func (x *icmpPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *icmpPacketEntry) afterLoad() {} +func (x *icmpPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("icmp.icmpPacket", (*icmpPacket)(nil), state.Fns{Save: (*icmpPacket).save, Load: (*icmpPacket).load}) + state.Register("icmp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("icmp.icmpPacketList", (*icmpPacketList)(nil), state.Fns{Save: (*icmpPacketList).save, Load: (*icmpPacketList).load}) + state.Register("icmp.icmpPacketEntry", (*icmpPacketEntry)(nil), state.Fns{Save: (*icmpPacketEntry).save, Load: (*icmpPacketEntry).load}) +} diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go new file mode 100644 index 000000000..954fde9d8 --- /dev/null +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -0,0 +1,136 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package icmp contains the implementation of the ICMP and IPv6-ICMP transport +// protocols for use in ping. To use it in the networking stack, this package +// must be added to the project, and +// activated on the stack by passing icmp.ProtocolName (or "icmp") and/or +// icmp.ProtocolName6 (or "icmp6") as one of the transport protocols when +// calling stack.New(). Then endpoints can be created by passing +// icmp.ProtocolNumber or icmp.ProtocolNumber6 as the transport protocol number +// when calling Stack.NewEndpoint(). +package icmp + +import ( + "encoding/binary" + "fmt" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // ProtocolName4 is the string representation of the icmp protocol name. + ProtocolName4 = "icmp4" + + // ProtocolNumber4 is the ICMP protocol number. + ProtocolNumber4 = header.ICMPv4ProtocolNumber + + // ProtocolName6 is the string representation of the icmp protocol name. + ProtocolName6 = "icmp6" + + // ProtocolNumber6 is the IPv6-ICMP protocol number. + ProtocolNumber6 = header.ICMPv6ProtocolNumber +) + +// protocol implements stack.TransportProtocol. +type protocol struct { + number tcpip.TransportProtocolNumber +} + +// Number returns the ICMP protocol number. +func (p *protocol) Number() tcpip.TransportProtocolNumber { + return p.number +} + +func (p *protocol) netProto() tcpip.NetworkProtocolNumber { + switch p.number { + case ProtocolNumber4: + return header.IPv4ProtocolNumber + case ProtocolNumber6: + return header.IPv6ProtocolNumber + } + panic(fmt.Sprint("unknown protocol number: ", p.number)) +} + +// NewEndpoint creates a new icmp endpoint. It implements +// stack.TransportProtocol.NewEndpoint. +func (p *protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != p.netProto() { + return nil, tcpip.ErrUnknownProtocol + } + return newEndpoint(stack, netProto, p.number, waiterQueue) +} + +// NewRawEndpoint creates a new raw icmp endpoint. It implements +// stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != p.netProto() { + return nil, tcpip.ErrUnknownProtocol + } + return raw.NewEndpoint(stack, netProto, p.number, waiterQueue) +} + +// MinimumPacketSize returns the minimum valid icmp packet size. +func (p *protocol) MinimumPacketSize() int { + switch p.number { + case ProtocolNumber4: + return header.ICMPv4EchoMinimumSize + case ProtocolNumber6: + return header.ICMPv6EchoMinimumSize + } + panic(fmt.Sprint("unknown protocol number: ", p.number)) +} + +// ParsePorts returns the source and destination ports stored in the given icmp +// packet. +func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + switch p.number { + case ProtocolNumber4: + return 0, binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize:]), nil + case ProtocolNumber6: + return 0, binary.BigEndian.Uint16(v[header.ICMPv6MinimumSize:]), nil + } + panic(fmt.Sprint("unknown protocol number: ", p.number)) +} + +// HandleUnknownDestinationPacket handles packets targeted at this protocol but +// that don't match any existing endpoint. +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { + return true +} + +// SetOption implements TransportProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements TransportProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func init() { + stack.RegisterTransportProtocolFactory(ProtocolName4, func() stack.TransportProtocol { + return &protocol{ProtocolNumber4} + }) + + stack.RegisterTransportProtocolFactory(ProtocolName6, func() stack.TransportProtocol { + return &protocol{ProtocolNumber6} + }) +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go new file mode 100644 index 000000000..1daf5823f --- /dev/null +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -0,0 +1,521 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package raw provides the implementation of raw sockets (see raw(7)). Raw +// sockets allow applications to: +// +// * manually write and inspect transport layer headers and payloads +// * receive all traffic of a given transport protcol (e.g. ICMP or UDP) +// * optionally write and inspect network layer and link layer headers for +// packets +// +// Raw sockets don't have any notion of ports, and incoming packets are +// demultiplexed solely by protocol number. Thus, a raw UDP endpoint will +// receive every UDP packet received by netstack. bind(2) and connect(2) can be +// used to filter incoming packets by source and destination. +package raw + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// +stateify savable +type packet struct { + packetEntry + // data holds the actual packet data, including any headers and + // payload. + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // views is pre-allocated space to back data. As long as the packet is + // made up of fewer than 8 buffer.Views, no extra allocation is + // necessary to store packet data. + views [8]buffer.View `state:"nosave"` + // timestampNS is the unix time at which the packet was received. + timestampNS int64 + // senderAddr is the network address of the sender. + senderAddr tcpip.FullAddress +} + +// endpoint is the raw socket implementation of tcpip.Endpoint. It is legal to +// have goroutines make concurrent calls into the endpoint. +// +// Lock order: +// endpoint.mu +// endpoint.rcvMu +// +// +stateify savable +type endpoint struct { + // The following fields are initialized at creation time and are + // immutable. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + transProto tcpip.TransportProtocolNumber + waiterQueue *waiter.Queue + + // The following fields are used to manage the receive queue and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvList packetList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by mu. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + closed bool + connected bool + bound bool + // registeredNIC is the NIC to which th endpoint is explicitly + // registered. Is set when Connect or Bind are used to specify a NIC. + registeredNIC tcpip.NICID + // boundNIC and boundAddr are set on calls to Bind(). When callers + // attempt actions that would invalidate the binding data (e.g. sending + // data via a NIC other than boundNIC), the endpoint will return an + // error. + boundNIC tcpip.NICID + boundAddr tcpip.Address + // route is the route to a remote network endpoint. It is set via + // Connect(), and is valid only when conneted is true. + route stack.Route `state:"manual"` +} + +// NewEndpoint returns a raw endpoint for the given protocols. +// TODO(b/129292371): IP_HDRINCL, IPPROTO_RAW, and AF_PACKET. +func NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + if netProto != header.IPv4ProtocolNumber { + return nil, tcpip.ErrUnknownProtocol + } + + ep := &endpoint{ + stack: stack, + netProto: netProto, + transProto: transProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } + + if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + return nil, err + } + + return ep, nil +} + +// Close implements tcpip.Endpoint.Close. +func (ep *endpoint) Close() { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return + } + + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + + ep.rcvMu.Lock() + defer ep.rcvMu.Unlock() + + // Clear the receive list. + ep.rcvClosed = true + ep.rcvBufSize = 0 + for !ep.rcvList.Empty() { + ep.rcvList.Remove(ep.rcvList.Front()) + } + + if ep.connected { + ep.route.Release() + } + + ep.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + ep.rcvMu.Lock() + + // If there's no data to read, return that read would block or that the + // endpoint is closed. + if ep.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if ep.rcvClosed { + err = tcpip.ErrClosedForReceive + } + ep.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + packet := ep.rcvList.Front() + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + + ep.rcvMu.Unlock() + + if addr != nil { + *addr = packet.senderAddr + } + + return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil +} + +// Write implements tcpip.Endpoint.Write. +func (ep *endpoint) Write(payload tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + // MSG_MORE is unimplemented. This also means that MSG_EOR is a no-op. + if opts.More { + return 0, nil, tcpip.ErrInvalidOptionValue + } + + ep.mu.RLock() + + if ep.closed { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrInvalidEndpointState + } + + // Did the user caller provide a destination? If not, use the connected + // destination. + if opts.To == nil { + // If the user doesn't specify a destination, they should have + // connected to another address. + if !ep.connected { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrDestinationRequired + } + + if ep.route.IsResolutionRequired() { + savedRoute := &ep.route + // Promote lock to exclusive if using a shared route, + // given that it may need to change in finishWrite. + ep.mu.RUnlock() + ep.mu.Lock() + + // Make sure that the route didn't change during the + // time we didn't hold the lock. + if !ep.connected || savedRoute != &ep.route { + ep.mu.Unlock() + return 0, nil, tcpip.ErrInvalidEndpointState + } + + n, ch, err := ep.finishWrite(payload, savedRoute) + ep.mu.Unlock() + return n, ch, err + } + + n, ch, err := ep.finishWrite(payload, &ep.route) + ep.mu.RUnlock() + return n, ch, err + } + + // The caller provided a destination. Reject destination address if it + // goes through a different NIC than the endpoint was bound to. + nic := opts.To.NIC + if ep.bound && nic != 0 && nic != ep.boundNIC { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrNoRoute + } + + // We don't support IPv6 yet, so this has to be an IPv4 address. + if len(opts.To.Addr) != header.IPv4AddressSize { + ep.mu.RUnlock() + return 0, nil, tcpip.ErrInvalidEndpointState + } + + // Find the route to the destination. If boundAddress is 0, + // FindRoute will choose an appropriate source address. + route, err := ep.stack.FindRoute(nic, ep.boundAddr, opts.To.Addr, ep.netProto, false) + if err != nil { + ep.mu.RUnlock() + return 0, nil, err + } + + n, ch, err := ep.finishWrite(payload, &route) + route.Release() + ep.mu.RUnlock() + return n, ch, err +} + +// finishWrite writes the payload to a route. It resolves the route if +// necessary. It's really just a helper to make defer unnecessary in Write. +func (ep *endpoint) finishWrite(payload tcpip.Payload, route *stack.Route) (uintptr, <-chan struct{}, *tcpip.Error) { + // We may need to resolve the route (match a link layer address to the + // network address). If that requires blocking (e.g. to use ARP), + // return a channel on which the caller can wait. + if route.IsResolutionRequired() { + if ch, err := route.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + payloadBytes, err := payload.Get(payload.Size()) + if err != nil { + return 0, nil, err + } + + switch ep.netProto { + case header.IPv4ProtocolNumber: + hdr := buffer.NewPrependable(len(payloadBytes) + int(route.MaxHeaderLength())) + if err := route.WritePacket(nil /* gso */, hdr, buffer.View(payloadBytes).ToVectorisedView(), ep.transProto, route.DefaultTTL()); err != nil { + return 0, nil, err + } + + default: + return 0, nil, tcpip.ErrUnknownProtocol + } + + return uintptr(len(payloadBytes)), nil, nil +} + +// Peek implements tcpip.Endpoint.Peek. +func (ep *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// Connect implements tcpip.Endpoint.Connect. +func (ep *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + ep.mu.Lock() + defer ep.mu.Unlock() + + if ep.closed { + return tcpip.ErrInvalidEndpointState + } + + // We don't support IPv6 yet. + if len(addr.Addr) != header.IPv4AddressSize { + return tcpip.ErrInvalidEndpointState + } + + nic := addr.NIC + if ep.bound { + if ep.boundNIC == 0 { + // If we're bound, but not to a specific NIC, the NIC + // in addr will be used. Nothing to do here. + } else if addr.NIC == 0 { + // If we're bound to a specific NIC, but addr doesn't + // specify a NIC, use the bound NIC. + nic = ep.boundNIC + } else if addr.NIC != ep.boundNIC { + // We're bound and addr specifies a NIC. They must be + // the same. + return tcpip.ErrInvalidEndpointState + } + } + + // Find a route to the destination. + route, err := ep.stack.FindRoute(nic, tcpip.Address(""), addr.Addr, ep.netProto, false) + if err != nil { + return err + } + defer route.Release() + + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + + // Save the route and NIC we've connected via. + ep.route = route.Clone() + ep.registeredNIC = nic + ep.connected = true + + return nil +} + +// Shutdown implements tcpip.Endpoint.Shutdown. It's a noop for raw sockets. +func (ep *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + ep.mu.Lock() + defer ep.mu.Unlock() + + if !ep.connected { + return tcpip.ErrNotConnected + } + return nil +} + +// Listen implements tcpip.Endpoint.Listen. +func (ep *endpoint) Listen(backlog int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept implements tcpip.Endpoint.Accept. +func (ep *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +// Bind implements tcpip.Endpoint.Bind. +func (ep *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + ep.mu.Lock() + defer ep.mu.Unlock() + + // Callers must provide an IPv4 address or no network address (for + // binding to a NIC, but not an address). + if len(addr.Addr) != 0 && len(addr.Addr) != 4 { + return tcpip.ErrInvalidEndpointState + } + + // If a local address was specified, verify that it's valid. + if len(addr.Addr) == header.IPv4AddressSize && ep.stack.CheckLocalAddress(addr.NIC, ep.netProto, addr.Addr) == 0 { + return tcpip.ErrBadLocalAddress + } + + // Re-register the endpoint with the appropriate NIC. + if err := ep.stack.RegisterRawTransportEndpoint(addr.NIC, ep.netProto, ep.transProto, ep); err != nil { + return err + } + ep.stack.UnregisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep) + + ep.registeredNIC = addr.NIC + ep.boundNIC = addr.NIC + ep.boundAddr = addr.Addr + ep.bound = true + + return nil +} + +// GetLocalAddress implements tcpip.Endpoint.GetLocalAddress. +func (ep *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, tcpip.ErrNotSupported +} + +// GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress. +func (ep *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + // Even a connected socket doesn't return a remote address. + return tcpip.FullAddress{}, tcpip.ErrNotConnected +} + +// Readiness implements tcpip.Endpoint.Readiness. +func (ep *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine whether the endpoint is readable. + if (mask & waiter.EventIn) != 0 { + ep.rcvMu.Lock() + if !ep.rcvList.Empty() || ep.rcvClosed { + result |= waiter.EventIn + } + ep.rcvMu.Unlock() + } + + return result +} + +// SetSockOpt implements tcpip.Endpoint.SetSockOpt. +func (ep *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + return nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.SendBufferSizeOption: + ep.mu.Lock() + *o = tcpip.SendBufferSizeOption(ep.sndBufSize) + ep.mu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + ep.rcvMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(ep.rcvBufSizeMax) + ep.rcvMu.Unlock() + return nil + + case *tcpip.ReceiveQueueSizeOption: + ep.rcvMu.Lock() + if ep.rcvList.Empty() { + *o = 0 + } else { + p := ep.rcvList.Front() + *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) + } + ep.rcvMu.Unlock() + return nil + + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// HandlePacket implements stack.RawTransportEndpoint.HandlePacket. +func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv buffer.VectorisedView) { + ep.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. + if ep.rcvClosed || ep.rcvBufSize >= ep.rcvBufSizeMax { + ep.stack.Stats().DroppedPackets.Increment() + ep.rcvMu.Unlock() + return + } + + if ep.bound { + // If bound to a NIC, only accept data for that NIC. + if ep.boundNIC != 0 && ep.boundNIC != route.NICID() { + ep.rcvMu.Unlock() + return + } + // If bound to an address, only accept data for that address. + if ep.boundAddr != "" && ep.boundAddr != route.RemoteAddress { + ep.rcvMu.Unlock() + return + } + } + + // If connected, only accept packets from the remote address we + // connected to. + if ep.connected && ep.route.RemoteAddress != route.RemoteAddress { + ep.rcvMu.Unlock() + return + } + + wasEmpty := ep.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + packet := &packet{ + senderAddr: tcpip.FullAddress{ + NIC: route.NICID(), + Addr: route.RemoteAddress, + }, + } + + combinedVV := netHeader.ToVectorisedView() + combinedVV.Append(vv) + packet.data = combinedVV.Clone(packet.views[:]) + packet.timestampNS = ep.stack.NowNanoseconds() + + ep.rcvList.PushBack(packet) + ep.rcvBufSize += packet.data.Size() + + ep.rcvMu.Unlock() + + // Notify waiters that there's data to be read. + if wasEmpty { + ep.waiterQueue.Notify(waiter.EventIn) + } +} diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go new file mode 100644 index 000000000..e8907ebb1 --- /dev/null +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -0,0 +1,88 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package raw + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// saveData saves packet.data field. +func (p *packet) saveData() buffer.VectorisedView { + // We cannot save p.data directly as p.data.views may alias to p.views, + // which is not allowed by state framework (in-struct pointer). + return p.data.Clone(nil) +} + +// loadData loads packet.data field. +func (p *packet) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the p.data = data.Clone(p.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing p.views for data.views. + p.data = data +} + +// beforeSave is invoked by stateify. +func (ep *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after saveRcvBufSizeMax(), which would have + // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + ep.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) saveRcvBufSizeMax() int { + max := ep.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + ep.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + ep.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (ep *endpoint) loadRcvBufSizeMax(max int) { + ep.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (ep *endpoint) afterLoad() { + // StackFromEnv is a stack used specifically for save/restore. + ep.stack = stack.StackFromEnv + + // If the endpoint is connected, re-connect via the save/restore stack. + if ep.connected { + var err *tcpip.Error + ep.route, err = ep.stack.FindRoute(ep.registeredNIC, ep.boundAddr, ep.route.RemoteAddress, ep.netProto, false) + if err != nil { + panic(*err) + } + } + + // If the endpoint is bound, re-bind via the save/restore stack. + if ep.bound { + if ep.stack.CheckLocalAddress(ep.registeredNIC, ep.netProto, ep.boundAddr) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + if err := ep.stack.RegisterRawTransportEndpoint(ep.registeredNIC, ep.netProto, ep.transProto, ep); err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/raw/packet_list.go b/pkg/tcpip/transport/raw/packet_list.go new file mode 100755 index 000000000..2e9074934 --- /dev/null +++ b/pkg/tcpip/transport/raw/packet_list.go @@ -0,0 +1,173 @@ +package raw + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type packetElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (packetElementMapper) linkerFor(elem *packet) *packet { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type packetList struct { + head *packet + tail *packet +} + +// Reset resets list l to the empty state. +func (l *packetList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *packetList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *packetList) Front() *packet { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *packetList) Back() *packet { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *packetList) PushFront(e *packet) { + packetElementMapper{}.linkerFor(e).SetNext(l.head) + packetElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + packetElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *packetList) PushBack(e *packet) { + packetElementMapper{}.linkerFor(e).SetNext(nil) + packetElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *packetList) PushBackList(m *packetList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + packetElementMapper{}.linkerFor(l.tail).SetNext(m.head) + packetElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *packetList) InsertAfter(b, e *packet) { + a := packetElementMapper{}.linkerFor(b).Next() + packetElementMapper{}.linkerFor(e).SetNext(a) + packetElementMapper{}.linkerFor(e).SetPrev(b) + packetElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + packetElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *packetList) InsertBefore(a, e *packet) { + b := packetElementMapper{}.linkerFor(a).Prev() + packetElementMapper{}.linkerFor(e).SetNext(a) + packetElementMapper{}.linkerFor(e).SetPrev(b) + packetElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + packetElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *packetList) Remove(e *packet) { + prev := packetElementMapper{}.linkerFor(e).Prev() + next := packetElementMapper{}.linkerFor(e).Next() + + if prev != nil { + packetElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + packetElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type packetEntry struct { + next *packet + prev *packet +} + +// Next returns the entry that follows e in the list. +func (e *packetEntry) Next() *packet { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *packetEntry) Prev() *packet { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *packetEntry) SetNext(elem *packet) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *packetEntry) SetPrev(elem *packet) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/raw/raw_state_autogen.go b/pkg/tcpip/transport/raw/raw_state_autogen.go new file mode 100755 index 000000000..3327811b4 --- /dev/null +++ b/pkg/tcpip/transport/raw/raw_state_autogen.go @@ -0,0 +1,96 @@ +// automatically generated by stateify. + +package raw + +import ( + "gvisor.googlesource.com/gvisor/pkg/state" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +func (x *packet) beforeSave() {} +func (x *packet) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("packetEntry", &x.packetEntry) + m.Save("timestampNS", &x.timestampNS) + m.Save("senderAddr", &x.senderAddr) +} + +func (x *packet) afterLoad() {} +func (x *packet) load(m state.Map) { + m.Load("packetEntry", &x.packetEntry) + m.Load("timestampNS", &x.timestampNS) + m.Load("senderAddr", &x.senderAddr) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("netProto", &x.netProto) + m.Save("transProto", &x.transProto) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("closed", &x.closed) + m.Save("connected", &x.connected) + m.Save("bound", &x.bound) + m.Save("registeredNIC", &x.registeredNIC) + m.Save("boundNIC", &x.boundNIC) + m.Save("boundAddr", &x.boundAddr) +} + +func (x *endpoint) load(m state.Map) { + m.Load("netProto", &x.netProto) + m.Load("transProto", &x.transProto) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("closed", &x.closed) + m.Load("connected", &x.connected) + m.Load("bound", &x.bound) + m.Load("registeredNIC", &x.registeredNIC) + m.Load("boundNIC", &x.boundNIC) + m.Load("boundAddr", &x.boundAddr) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *packetList) beforeSave() {} +func (x *packetList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *packetList) afterLoad() {} +func (x *packetList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *packetEntry) beforeSave() {} +func (x *packetEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *packetEntry) afterLoad() {} +func (x *packetEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("raw.packet", (*packet)(nil), state.Fns{Save: (*packet).save, Load: (*packet).load}) + state.Register("raw.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("raw.packetList", (*packetList)(nil), state.Fns{Save: (*packetList).save, Load: (*packetList).load}) + state.Register("raw.packetEntry", (*packetEntry)(nil), state.Fns{Save: (*packetEntry).save, Load: (*packetEntry).load}) +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go new file mode 100644 index 000000000..d4b860975 --- /dev/null +++ b/pkg/tcpip/transport/tcp/accept.go @@ -0,0 +1,499 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "crypto/sha1" + "encoding/binary" + "hash" + "io" + "log" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/rand" + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // tsLen is the length, in bits, of the timestamp in the SYN cookie. + tsLen = 8 + + // tsMask is a mask for timestamp values (i.e., tsLen bits). + tsMask = (1 << tsLen) - 1 + + // tsOffset is the offset, in bits, of the timestamp in the SYN cookie. + tsOffset = 24 + + // hashMask is the mask for hash values (i.e., tsOffset bits). + hashMask = (1 << tsOffset) - 1 + + // maxTSDiff is the maximum allowed difference between a received cookie + // timestamp and the current timestamp. If the difference is greater + // than maxTSDiff, the cookie is expired. + maxTSDiff = 2 +) + +var ( + // SynRcvdCountThreshold is the global maximum number of connections + // that are allowed to be in SYN-RCVD state before TCP starts using SYN + // cookies to accept connections. + // + // It is an exported variable only for testing, and should not otherwise + // be used by importers of this package. + SynRcvdCountThreshold uint64 = 1000 + + // mssTable is a slice containing the possible MSS values that we + // encode in the SYN cookie with two bits. + mssTable = []uint16{536, 1300, 1440, 1460} +) + +func encodeMSS(mss uint16) uint32 { + for i := len(mssTable) - 1; i > 0; i-- { + if mss >= mssTable[i] { + return uint32(i) + } + } + return 0 +} + +// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is +// protected by a mutex so that we can increment only when it's guaranteed not +// to go above a threshold. +var synRcvdCount struct { + sync.Mutex + value uint64 + pending sync.WaitGroup +} + +// listenContext is used by a listening endpoint to store state used while +// listening for connections. This struct is allocated by the listen goroutine +// and must not be accessed or have its methods called concurrently as they +// may mutate the stored objects. +type listenContext struct { + stack *stack.Stack + rcvWnd seqnum.Size + nonce [2][sha1.BlockSize]byte + listenEP *endpoint + + hasherMu sync.Mutex + hasher hash.Hash + v6only bool + netProto tcpip.NetworkProtocolNumber +} + +// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. +func timeStamp() uint32 { + return uint32(time.Now().Unix()>>6) & tsMask +} + +// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD +// state. It succeeds if the increment doesn't make the count go beyond the +// threshold, and fails otherwise. +func incSynRcvdCount() bool { + synRcvdCount.Lock() + + if synRcvdCount.value >= SynRcvdCountThreshold { + synRcvdCount.Unlock() + return false + } + + synRcvdCount.pending.Add(1) + synRcvdCount.value++ + + synRcvdCount.Unlock() + return true +} + +// decSynRcvdCount atomically decrements the global number of endpoints in +// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount +// succeeded. +func decSynRcvdCount() { + synRcvdCount.Lock() + + synRcvdCount.value-- + synRcvdCount.pending.Done() + synRcvdCount.Unlock() +} + +// newListenContext creates a new listen context. +func newListenContext(stack *stack.Stack, listenEP *endpoint, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { + l := &listenContext{ + stack: stack, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6only: v6only, + netProto: netProto, + listenEP: listenEP, + } + + rand.Read(l.nonce[0][:]) + rand.Read(l.nonce[1][:]) + + return l +} + +// cookieHash calculates the cookieHash for the given id, timestamp and nonce +// index. The hash is used to create and validate cookies. +func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 { + + // Initialize block with fixed-size data: local ports and v. + var payload [8]byte + binary.BigEndian.PutUint16(payload[0:], id.LocalPort) + binary.BigEndian.PutUint16(payload[2:], id.RemotePort) + binary.BigEndian.PutUint32(payload[4:], ts) + + // Feed everything to the hasher. + l.hasherMu.Lock() + l.hasher.Reset() + l.hasher.Write(payload[:]) + l.hasher.Write(l.nonce[nonceIndex][:]) + io.WriteString(l.hasher, string(id.LocalAddress)) + io.WriteString(l.hasher, string(id.RemoteAddress)) + + // Finalize the calculation of the hash and return the first 4 bytes. + h := make([]byte, 0, sha1.Size) + h = l.hasher.Sum(h) + l.hasherMu.Unlock() + + return binary.BigEndian.Uint32(h[:]) +} + +// createCookie creates a SYN cookie for the given id and incoming sequence +// number. +func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value { + ts := timeStamp() + v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) + v += (l.cookieHash(id, ts, 1) + data) & hashMask + return seqnum.Value(v) +} + +// isCookieValid checks if the supplied cookie is valid for the given id and +// sequence number. If it is, it also returns the data originally encoded in the +// cookie when createCookie was called. +func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { + ts := timeStamp() + v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) + cookieTS := v >> tsOffset + if ((ts - cookieTS) & tsMask) > maxTSDiff { + return 0, false + } + + return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true +} + +// createConnectingEndpoint creates a new endpoint in a connecting state, with +// the connection parameters given by the arguments. +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { + // Create a new endpoint. + netProto := l.netProto + if netProto == 0 { + netProto = s.route.NetProto + } + n := newEndpoint(l.stack, netProto, nil) + n.v6only = l.v6only + n.id = s.id + n.boundNICID = s.route.NICID() + n.route = s.route.Clone() + n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} + n.rcvBufSize = int(l.rcvWnd) + + n.maybeEnableTimestamp(rcvdSynOpts) + n.maybeEnableSACKPermitted(rcvdSynOpts) + + n.initGSO() + + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n, n.reusePort); err != nil { + n.Close() + return nil, err + } + + n.isRegistered = true + n.state = stateConnecting + + // Create sender and receiver. + // + // The receiver at least temporarily has a zero receive window scale, + // but the caller may change it (before starting the protocol loop). + n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) + n.rcv = newReceiver(n, irs, l.rcvWnd, 0) + + return n, nil +} + +// createEndpoint creates a new endpoint in connected state and then performs +// the TCP 3-way handshake. +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { + // Create new endpoint. + irs := s.sequenceNumber + cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) + ep, err := l.createConnectingEndpoint(s, cookie, irs, opts) + if err != nil { + return nil, err + } + + // Perform the 3-way handshake. + h := newHandshake(ep, l.rcvWnd) + + h.resetToSynRcvd(cookie, irs, opts, l.listenEP) + if err := h.execute(); err != nil { + ep.stack.Stats().TCP.FailedConnectionAttempts.Increment() + ep.Close() + return nil, err + } + + ep.state = stateConnected + + // Update the receive window scaling. We can't do it before the + // handshake because it's possible that the peer doesn't support window + // scaling. + ep.rcv.rcvWndScale = h.effectiveRcvWndScale() + + return ep, nil +} + +// deliverAccepted delivers the newly-accepted endpoint to the listener. If the +// endpoint has transitioned out of the listen state, the new endpoint is closed +// instead. +func (e *endpoint) deliverAccepted(n *endpoint) { + e.mu.RLock() + state := e.state + e.mu.RUnlock() + if state == stateListen { + e.acceptedChan <- n + e.waiterQueue.Notify(waiter.EventIn) + } else { + n.Close() + } +} + +// handleSynSegment is called in its own goroutine once the listening endpoint +// receives a SYN segment. It is responsible for completing the handshake and +// queueing the new endpoint for acceptance. +// +// A limited number of these goroutines are allowed before TCP starts using SYN +// cookies to accept connections. +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { + defer decSynRcvdCount() + defer e.decSynRcvdCount() + defer s.decRef() + + n, err := ctx.createEndpointAndPerformHandshake(s, opts) + if err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + return + } + + e.deliverAccepted(n) +} + +func (e *endpoint) incSynRcvdCount() bool { + e.mu.Lock() + log.Printf("l: %d, c: %d, e.synRcvdCount: %d", len(e.acceptedChan), cap(e.acceptedChan), e.synRcvdCount) + if l, c := len(e.acceptedChan), cap(e.acceptedChan); l == c && e.synRcvdCount >= c { + e.mu.Unlock() + return false + } + e.synRcvdCount++ + e.mu.Unlock() + return true +} + +func (e *endpoint) decSynRcvdCount() { + e.mu.Lock() + e.synRcvdCount-- + e.mu.Unlock() +} + +// handleListenSegment is called when a listening endpoint receives a segment +// and needs to handle it. +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { + switch s.flags { + case header.TCPFlagSyn: + opts := parseSynSegmentOptions(s) + if incSynRcvdCount() { + // Drop the SYN if the listen endpoint's accept queue is + // overflowing. + if e.incSynRcvdCount() { + log.Printf("processing syn packet") + s.incRef() + go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. + return + } + log.Printf("dropping syn packet") + e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() + e.stack.Stats().DroppedPackets.Increment() + return + } else { + // TODO(bhaskerh): Increment syncookie sent stat. + cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + // Send SYN with window scaling because we currently + // dont't encode this information in the cookie. + // + // Enable Timestamp option if the original syn did have + // the timestamp option specified. + synOpts := header.TCPSynOptions{ + WS: -1, + TS: opts.TS, + TSVal: tcpTimeStamp(timeStampOffset()), + TSEcr: opts.TSVal, + } + sendSynTCP(&s.route, s.id, header.TCPFlagSyn|header.TCPFlagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() + } + + case header.TCPFlagAck: + if len(e.acceptedChan) == cap(e.acceptedChan) { + // Silently drop the ack as the application can't accept + // the connection at this point. The ack will be + // retransmitted by the sender anyway and we can + // complete the connection at the time of retransmit if + // the backlog has space. + e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + e.stack.Stats().DroppedPackets.Increment() + return + } + + // Validate the cookie. + data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1) + if !ok || int(data) >= len(mssTable) { + e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() + e.stack.Stats().DroppedPackets.Increment() + return + } + e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() + // Create newly accepted endpoint and deliver it. + rcvdSynOptions := &header.TCPSynOptions{ + MSS: mssTable[data], + // Disable Window scaling as original SYN is + // lost. + WS: -1, + } + + // When syn cookies are in use we enable timestamp only + // if the ack specifies the timestamp option assuming + // that the other end did in fact negotiate the + // timestamp option in the original SYN. + if s.parsedOptions.TS { + rcvdSynOptions.TS = true + rcvdSynOptions.TSVal = s.parsedOptions.TSVal + rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr + } + + n, err := ctx.createConnectingEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) + if err != nil { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + return + } + + // clear the tsOffset for the newly created + // endpoint as the Timestamp was already + // randomly offset when the original SYN-ACK was + // sent above. + n.tsOffset = 0 + + // Switch state to connected. + n.state = stateConnected + + // Do the delivery in a separate goroutine so + // that we don't block the listen loop in case + // the application is slow to accept or stops + // accepting. + // + // NOTE: This won't result in an unbounded + // number of goroutines as we do check before + // entering here that there was at least some + // space available in the backlog. + go e.deliverAccepted(n) + } +} + +// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in +// its own goroutine and is responsible for handling connection requests. +func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { + defer func() { + // Mark endpoint as closed. This will prevent goroutines running + // handleSynSegment() from attempting to queue new connections + // to the endpoint. + e.mu.Lock() + e.state = stateClosed + + // Do cleanup if needed. + e.completeWorkerLocked() + + if e.drainDone != nil { + close(e.drainDone) + } + e.mu.Unlock() + + // Notify waiters that the endpoint is shutdown. + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + }() + + e.mu.Lock() + v6only := e.v6only + e.mu.Unlock() + + ctx := newListenContext(e.stack, e, rcvWnd, v6only, e.netProto) + + s := sleep.Sleeper{} + s.AddWaker(&e.notificationWaker, wakerForNotification) + s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) + for { + switch index, _ := s.Fetch(true); index { + case wakerForNotification: + n := e.fetchNotifications() + if n¬ifyClose != 0 { + return nil + } + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + s := e.segmentQueue.dequeue() + e.handleListenSegment(ctx, s) + s.decRef() + } + synRcvdCount.pending.Wait() + close(e.drainDone) + <-e.undrain + } + + case wakerForNewSegment: + // Process at most maxSegmentsPerWake segments. + mayRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + mayRequeue = false + break + } + + e.handleListenSegment(ctx, s) + s.decRef() + } + + // If the queue is not empty, make sure we'll wake up + // in the next iteration. + if mayRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + } + } +} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go new file mode 100644 index 000000000..2aed6f286 --- /dev/null +++ b/pkg/tcpip/transport/tcp/connect.go @@ -0,0 +1,1066 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/rand" + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// maxSegmentsPerWake is the maximum number of segments to process in the main +// protocol goroutine per wake-up. Yielding [after this number of segments are +// processed] allows other events to be processed as well (e.g., timeouts, +// resets, etc.). +const maxSegmentsPerWake = 100 + +type handshakeState int + +// The following are the possible states of the TCP connection during a 3-way +// handshake. A depiction of the states and transitions can be found in RFC 793, +// page 23. +const ( + handshakeSynSent handshakeState = iota + handshakeSynRcvd + handshakeCompleted +) + +// The following are used to set up sleepers. +const ( + wakerForNotification = iota + wakerForNewSegment + wakerForResend + wakerForResolution +) + +const ( + // Maximum space available for options. + maxOptionSize = 40 +) + +// handshake holds the state used during a TCP 3-way handshake. +type handshake struct { + ep *endpoint + listenEP *endpoint // only non nil when doing passive connects. + state handshakeState + active bool + flags uint8 + ackNum seqnum.Value + + // iss is the initial send sequence number, as defined in RFC 793. + iss seqnum.Value + + // rcvWnd is the receive window, as defined in RFC 793. + rcvWnd seqnum.Size + + // sndWnd is the send window, as defined in RFC 793. + sndWnd seqnum.Size + + // mss is the maximum segment size received from the peer. + mss uint16 + + // sndWndScale is the send window scale, as defined in RFC 1323. A + // negative value means no scaling is supported by the peer. + sndWndScale int + + // rcvWndScale is the receive window scale, as defined in RFC 1323. + rcvWndScale int +} + +func newHandshake(ep *endpoint, rcvWnd seqnum.Size) handshake { + h := handshake{ + ep: ep, + active: true, + rcvWnd: rcvWnd, + rcvWndScale: FindWndScale(rcvWnd), + } + h.resetState() + return h +} + +// FindWndScale determines the window scale to use for the given maximum window +// size. +func FindWndScale(wnd seqnum.Size) int { + if wnd < 0x10000 { + return 0 + } + + max := seqnum.Size(0xffff) + s := 0 + for wnd > max && s < header.MaxWndScale { + s++ + max <<= 1 + } + + return s +} + +// resetState resets the state of the handshake object such that it becomes +// ready for a new 3-way handshake. +func (h *handshake) resetState() { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + + h.state = handshakeSynSent + h.flags = header.TCPFlagSyn + h.ackNum = 0 + h.mss = 0 + h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) +} + +// effectiveRcvWndScale returns the effective receive window scale to be used. +// If the peer doesn't support window scaling, the effective rcv wnd scale is +// zero; otherwise it's the value calculated based on the initial rcv wnd. +func (h *handshake) effectiveRcvWndScale() uint8 { + if h.sndWndScale < 0 { + return 0 + } + return uint8(h.rcvWndScale) +} + +// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD +// state. +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions, listenEP *endpoint) { + h.active = false + h.state = handshakeSynRcvd + h.flags = header.TCPFlagSyn | header.TCPFlagAck + h.iss = iss + h.ackNum = irs + 1 + h.mss = opts.MSS + h.sndWndScale = opts.WS + h.listenEP = listenEP +} + +// checkAck checks if the ACK number, if present, of a segment received during +// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in +// response. +func (h *handshake) checkAck(s *segment) bool { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber != h.iss+1 { + // RFC 793, page 36, states that a reset must be generated when + // the connection is in any non-synchronized state and an + // incoming segment acknowledges something not yet sent. The + // connection remains in the same state. + ack := s.sequenceNumber.Add(s.logicalLen()) + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, s.ackNumber, ack, 0) + return false + } + + return true +} + +// synSentState handles a segment received when the TCP 3-way handshake is in +// the SYN-SENT state. +func (h *handshake) synSentState(s *segment) *tcpip.Error { + // RFC 793, page 37, states that in the SYN-SENT state, a reset is + // acceptable if the ack field acknowledges the SYN. + if s.flagIsSet(header.TCPFlagRst) { + if s.flagIsSet(header.TCPFlagAck) && s.ackNumber == h.iss+1 { + return tcpip.ErrConnectionRefused + } + return nil + } + + if !h.checkAck(s) { + return nil + } + + // We are in the SYN-SENT state. We only care about segments that have + // the SYN flag. + if !s.flagIsSet(header.TCPFlagSyn) { + return nil + } + + // Parse the SYN options. + rcvSynOpts := parseSynSegmentOptions(s) + + // Remember if the Timestamp option was negotiated. + h.ep.maybeEnableTimestamp(&rcvSynOpts) + + // Remember if the SACKPermitted option was negotiated. + h.ep.maybeEnableSACKPermitted(&rcvSynOpts) + + // Remember the sequence we'll ack from now on. + h.ackNum = s.sequenceNumber + 1 + h.flags |= header.TCPFlagAck + h.mss = rcvSynOpts.MSS + h.sndWndScale = rcvSynOpts.WS + + // If this is a SYN ACK response, we only need to acknowledge the SYN + // and the handshake is completed. + if s.flagIsSet(header.TCPFlagAck) { + h.state = handshakeCompleted + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) + return nil + } + + // A SYN segment was received, but no ACK in it. We acknowledge the SYN + // but resend our own SYN and wait for it to be acknowledged in the + // SYN-RCVD state. + h.state = handshakeSynRcvd + synOpts := header.TCPSynOptions{ + WS: h.rcvWndScale, + TS: rcvSynOpts.TS, + TSVal: h.ep.timestamp(), + TSEcr: h.ep.recentTS, + + // We only send SACKPermitted if the other side indicated it + // permits SACK. This is not explicitly defined in the RFC but + // this is the behaviour implemented by Linux. + SACKPermitted: rcvSynOpts.SACKPermitted, + } + sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + return nil +} + +// synRcvdState handles a segment received when the TCP 3-way handshake is in +// the SYN-RCVD state. +func (h *handshake) synRcvdState(s *segment) *tcpip.Error { + if s.flagIsSet(header.TCPFlagRst) { + // RFC 793, page 37, states that in the SYN-RCVD state, a reset + // is acceptable if the sequence number is in the window. + if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { + return tcpip.ErrConnectionRefused + } + return nil + } + + if !h.checkAck(s) { + return nil + } + + if s.flagIsSet(header.TCPFlagSyn) && s.sequenceNumber != h.ackNum-1 { + // We received two SYN segments with different sequence + // numbers, so we reset this and restart the whole + // process, except that we don't reset the timer. + ack := s.sequenceNumber.Add(s.logicalLen()) + seq := seqnum.Value(0) + if s.flagIsSet(header.TCPFlagAck) { + seq = s.ackNumber + } + h.ep.sendRaw(buffer.VectorisedView{}, header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0) + + if !h.active { + return tcpip.ErrInvalidEndpointState + } + + h.resetState() + synOpts := header.TCPSynOptions{ + WS: h.rcvWndScale, + TS: h.ep.sendTSOk, + TSVal: h.ep.timestamp(), + TSEcr: h.ep.recentTS, + SACKPermitted: h.ep.sackPermitted, + } + sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + return nil + } + + // We have previously received (and acknowledged) the peer's SYN. If the + // peer acknowledges our SYN, the handshake is completed. + if s.flagIsSet(header.TCPFlagAck) { + // listenContext is also used by a tcp.Forwarder and in that + // context we do not have a listening endpoint to check the + // backlog. So skip this check if listenEP is nil. + if h.listenEP != nil && len(h.listenEP.acceptedChan) == cap(h.listenEP.acceptedChan) { + // If there is no space in the accept queue to accept + // this endpoint then silently drop this ACK. The peer + // will anyway resend the ack and we can complete the + // connection the next time it's retransmitted. + h.ep.stack.Stats().TCP.ListenOverflowAckDrop.Increment() + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + // If the timestamp option is negotiated and the segment does + // not carry a timestamp option then the segment must be dropped + // as per https://tools.ietf.org/html/rfc7323#section-3.2. + if h.ep.sendTSOk && !s.parsedOptions.TS { + h.ep.stack.Stats().DroppedPackets.Increment() + return nil + } + + // Update timestamp if required. See RFC7323, section-4.3. + if h.ep.sendTSOk && s.parsedOptions.TS { + h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) + } + h.state = handshakeCompleted + return nil + } + + return nil +} + +func (h *handshake) handleSegment(s *segment) *tcpip.Error { + h.sndWnd = s.window + if !s.flagIsSet(header.TCPFlagSyn) && h.sndWndScale > 0 { + h.sndWnd <<= uint8(h.sndWndScale) + } + + switch h.state { + case handshakeSynRcvd: + return h.synRcvdState(s) + case handshakeSynSent: + return h.synSentState(s) + } + return nil +} + +// processSegments goes through the segment queue and processes up to +// maxSegmentsPerWake (if they're available). +func (h *handshake) processSegments() *tcpip.Error { + for i := 0; i < maxSegmentsPerWake; i++ { + s := h.ep.segmentQueue.dequeue() + if s == nil { + return nil + } + + err := h.handleSegment(s) + s.decRef() + if err != nil { + return err + } + + // We stop processing packets once the handshake is completed, + // otherwise we may process packets meant to be processed by + // the main protocol goroutine. + if h.state == handshakeCompleted { + break + } + } + + // If the queue is not empty, make sure we'll wake up in the next + // iteration. + if !h.ep.segmentQueue.empty() { + h.ep.newSegmentWaker.Assert() + } + + return nil +} + +func (h *handshake) resolveRoute() *tcpip.Error { + // Set up the wakers. + s := sleep.Sleeper{} + resolutionWaker := &sleep.Waker{} + s.AddWaker(resolutionWaker, wakerForResolution) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + defer s.Done() + + // Initial action is to resolve route. + index := wakerForResolution + for { + switch index { + case wakerForResolution: + if _, err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { + // Either success (err == nil) or failure. + return err + } + // Resolution not completed. Keep trying... + + case wakerForNotification: + n := h.ep.fetchNotifications() + if n¬ifyClose != 0 { + h.ep.route.RemoveWaker(resolutionWaker) + return tcpip.ErrAborted + } + if n¬ifyDrain != 0 { + close(h.ep.drainDone) + <-h.ep.undrain + } + } + + // Wait for notification. + index, _ = s.Fetch(true) + } +} + +// execute executes the TCP 3-way handshake. +func (h *handshake) execute() *tcpip.Error { + if h.ep.route.IsResolutionRequired() { + if err := h.resolveRoute(); err != nil { + return err + } + } + + // Initialize the resend timer. + resendWaker := sleep.Waker{} + timeOut := time.Duration(time.Second) + rt := time.AfterFunc(timeOut, func() { + resendWaker.Assert() + }) + defer rt.Stop() + + // Set up the wakers. + s := sleep.Sleeper{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + + var sackEnabled SACKEnabled + if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { + // If stack returned an error when checking for SACKEnabled + // status then just default to switching off SACK negotiation. + sackEnabled = false + } + + // Send the initial SYN segment and loop until the handshake is + // completed. + synOpts := header.TCPSynOptions{ + WS: h.rcvWndScale, + TS: true, + TSVal: h.ep.timestamp(), + TSEcr: h.ep.recentTS, + SACKPermitted: bool(sackEnabled), + } + + // Execute is also called in a listen context so we want to make sure we + // only send the TS/SACK option when we received the TS/SACK in the + // initial SYN. + if h.state == handshakeSynRcvd { + synOpts.TS = h.ep.sendTSOk + synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) + } + sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + for h.state != handshakeCompleted { + switch index, _ := s.Fetch(true); index { + case wakerForResend: + timeOut *= 2 + if timeOut > 60*time.Second { + return tcpip.ErrTimeout + } + rt.Reset(timeOut) + sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + case wakerForNotification: + n := h.ep.fetchNotifications() + if n¬ifyClose != 0 { + return tcpip.ErrAborted + } + if n¬ifyDrain != 0 { + for !h.ep.segmentQueue.empty() { + s := h.ep.segmentQueue.dequeue() + err := h.handleSegment(s) + s.decRef() + if err != nil { + return err + } + if h.state == handshakeCompleted { + return nil + } + } + close(h.ep.drainDone) + <-h.ep.undrain + } + + case wakerForNewSegment: + if err := h.processSegments(); err != nil { + return err + } + } + } + + return nil +} + +func parseSynSegmentOptions(s *segment) header.TCPSynOptions { + synOpts := header.ParseSynOptions(s.options, s.flagIsSet(header.TCPFlagAck)) + if synOpts.TS { + s.parsedOptions.TSVal = synOpts.TSVal + s.parsedOptions.TSEcr = synOpts.TSEcr + } + return synOpts +} + +var optionPool = sync.Pool{ + New: func() interface{} { + return make([]byte, maxOptionSize) + }, +} + +func getOptions() []byte { + return optionPool.Get().([]byte) +} + +func putOptions(options []byte) { + // Reslice to full capacity. + optionPool.Put(options[0:cap(options)]) +} + +func makeSynOptions(opts header.TCPSynOptions) []byte { + // Emulate linux option order. This is as follows: + // + // if md5: NOP NOP MD5SIG 18 md5sig(16) + // if mss: MSS 4 mss(2) + // if ts and sack_advertise: + // SACK 2 TIMESTAMP 2 timestamp(8) + // elif ts: NOP NOP TIMESTAMP 10 timestamp(8) + // elif sack: NOP NOP SACK 2 + // if wscale: NOP WINDOW 3 ws(1) + // if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8)) + // [for each block] start_seq(4) end_seq(4) + // if fastopen_cookie: + // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2) + // else: FASTOPEN (2 + len(cookie)) + // cookie(variable) [padding to four bytes] + // + options := getOptions() + + // Always encode the mss. + offset := header.EncodeMSSOption(uint32(opts.MSS), options) + + // Special ordering is required here. If both TS and SACK are enabled, + // then the SACK option precedes TS, with no padding. If they are + // enabled individually, then we see padding before the option. + if opts.TS && opts.SACKPermitted { + offset += header.EncodeSACKPermittedOption(options[offset:]) + offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) + } else if opts.TS { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) + } else if opts.SACKPermitted { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeSACKPermittedOption(options[offset:]) + } + + // Initialize the WS option. + if opts.WS >= 0 { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeWSOption(opts.WS, options[offset:]) + } + + // Padding to the end; note that this never apply unless we add a + // fastopen option, we always expect the offset to remain the same. + if delta := header.AddTCPOptionPadding(options, offset); delta != 0 { + panic("unexpected option encoding") + } + + return options[:offset] +} + +func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { + // The MSS in opts is automatically calculated as this function is + // called from many places and we don't want every call point being + // embedded with the MSS calculation. + if opts.MSS == 0 { + opts.MSS = uint16(r.MTU() - header.TCPMinimumSize) + } + + options := makeSynOptions(opts) + err := sendTCP(r, id, buffer.VectorisedView{}, r.DefaultTTL(), flags, seq, ack, rcvWnd, options, nil) + putOptions(options) + return err +} + +// sendTCP sends a TCP segment with the provided options via the provided +// network endpoint and under the provided identity. +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.VectorisedView, ttl uint8, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte, gso *stack.GSO) *tcpip.Error { + optLen := len(opts) + // Allocate a buffer for the TCP header. + hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) + + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + // Initialize the header. + tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) + tcp.Encode(&header.TCPFields{ + SrcPort: id.LocalPort, + DstPort: id.RemotePort, + SeqNum: uint32(seq), + AckNum: uint32(ack), + DataOffset: uint8(header.TCPMinimumSize + optLen), + Flags: flags, + WindowSize: uint16(rcvWnd), + }) + copy(tcp[header.TCPMinimumSize:], opts) + + length := uint16(hdr.UsedLength() + data.Size()) + xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) + // Only calculate the checksum if offloading isn't supported. + if gso != nil && gso.NeedsCsum { + // This is called CHECKSUM_PARTIAL in the Linux kernel. We + // calculate a checksum of the pseudo-header and save it in the + // TCP header, then the kernel calculate a checksum of the + // header and data and get the right sum of the TCP packet. + tcp.SetChecksum(xsum) + } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + xsum = header.ChecksumVV(data, xsum) + tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) + } + + r.Stats().TCP.SegmentsSent.Increment() + if (flags & header.TCPFlagRst) != 0 { + r.Stats().TCP.ResetsSent.Increment() + } + + return r.WritePacket(gso, hdr, data, ProtocolNumber, ttl) +} + +// makeOptions makes an options slice. +func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { + options := getOptions() + offset := 0 + + // N.B. the ordering here matches the ordering used by Linux internally + // and described in the raw makeOptions function. We don't include + // unnecessary cases here (post connection.) + if e.sendTSOk { + // Embed the timestamp if timestamp has been enabled. + // + // We only use the lower 32 bits of the unix time in + // milliseconds. This is similar to what Linux does where it + // uses the lower 32 bits of the jiffies value in the tsVal + // field of the timestamp option. + // + // Further, RFC7323 section-5.4 recommends millisecond + // resolution as the lowest recommended resolution for the + // timestamp clock. + // + // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:]) + } + if e.sackPermitted && len(sackBlocks) > 0 { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) + } + + // We expect the above to produce an aligned offset. + if delta := header.AddTCPOptionPadding(options, offset); delta != 0 { + panic("unexpected option encoding") + } + + return options[:offset] +} + +// sendRaw sends a TCP segment to the endpoint's peer. +func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { + var sackBlocks []header.SACKBlock + if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] + } + options := e.makeOptions(sackBlocks) + err := sendTCP(&e.route, e.id, data, e.route.DefaultTTL(), flags, seq, ack, rcvWnd, options, e.gso) + putOptions(options) + return err +} + +func (e *endpoint) handleWrite() *tcpip.Error { + // Move packets from send queue to send list. The queue is accessible + // from other goroutines and protected by the send mutex, while the send + // list is only accessible from the handler goroutine, so it needs no + // mutexes. + e.sndBufMu.Lock() + + first := e.sndQueue.Front() + if first != nil { + e.snd.writeList.PushBackList(&e.sndQueue) + e.snd.sndNxtList.UpdateForward(e.sndBufInQueue) + e.sndBufInQueue = 0 + } + + e.sndBufMu.Unlock() + + // Initialize the next segment to write if it's currently nil. + if e.snd.writeNext == nil { + e.snd.writeNext = first + } + + // Push out any new packets. + e.snd.sendData() + + return nil +} + +func (e *endpoint) handleClose() *tcpip.Error { + // Drain the send queue. + e.handleWrite() + + // Mark send side as closed. + e.snd.closed = true + + return nil +} + +// resetConnectionLocked sends a RST segment and puts the endpoint in an error +// state with the given error code. This method must only be called from the +// protocol goroutine. +func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { + e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + + e.state = stateError + e.hardError = err +} + +// completeWorkerLocked is called by the worker goroutine when it's about to +// exit. It marks the worker as completed and performs cleanup work if requested +// by Close(). +func (e *endpoint) completeWorkerLocked() { + e.workerRunning = false + if e.workerCleanup { + e.cleanupLocked() + } +} + +// handleSegments pulls segments from the queue and processes them. It returns +// no error if the protocol loop should continue, an error otherwise. +func (e *endpoint) handleSegments() *tcpip.Error { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + + // Invoke the tcp probe if installed. + if e.probe != nil { + e.probe(e.completeState()) + } + + if s.flagIsSet(header.TCPFlagRst) { + if e.rcv.acceptable(s.sequenceNumber, 0) { + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + s.decRef() + return tcpip.ErrConnectionReset + } + } else if s.flagIsSet(header.TCPFlagAck) { + // Patch the window size in the segment according to the + // send window scale. + s.window <<= e.snd.sndWndScale + + // RFC 793, page 41 states that "once in the ESTABLISHED + // state all segments must carry current acknowledgment + // information." + e.rcv.handleRcvdSegment(s) + e.snd.handleRcvdSegment(s) + } + s.decRef() + } + + // If the queue is not empty, make sure we'll wake up in the next + // iteration. + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + + // Send an ACK for all processed packets if needed. + if e.rcv.rcvNxt != e.snd.maxSentAck { + e.snd.sendAck() + } + + e.resetKeepaliveTimer(true) + + return nil +} + +// keepaliveTimerExpired is called when the keepaliveTimer fires. We send TCP +// keepalive packets periodically when the connection is idle. If we don't hear +// from the other side after a number of tries, we terminate the connection. +func (e *endpoint) keepaliveTimerExpired() *tcpip.Error { + e.keepalive.Lock() + if !e.keepalive.enabled || !e.keepalive.timer.checkExpiration() { + e.keepalive.Unlock() + return nil + } + + if e.keepalive.unacked >= e.keepalive.count { + e.keepalive.Unlock() + return tcpip.ErrConnectionReset + } + + // RFC1122 4.2.3.6: TCP keepalive is a dataless ACK with + // seg.seq = snd.nxt-1. + e.keepalive.unacked++ + e.keepalive.Unlock() + e.snd.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, e.snd.sndNxt-1) + e.resetKeepaliveTimer(false) + return nil +} + +// resetKeepaliveTimer restarts or stops the keepalive timer, depending on +// whether it is enabled for this endpoint. +func (e *endpoint) resetKeepaliveTimer(receivedData bool) { + e.keepalive.Lock() + defer e.keepalive.Unlock() + if receivedData { + e.keepalive.unacked = 0 + } + // Start the keepalive timer IFF it's enabled and there is no pending + // data to send. + if !e.keepalive.enabled || e.snd == nil || e.snd.sndUna != e.snd.sndNxt { + e.keepalive.timer.disable() + return + } + if e.keepalive.unacked > 0 { + e.keepalive.timer.enable(e.keepalive.interval) + } else { + e.keepalive.timer.enable(e.keepalive.idle) + } +} + +// disableKeepaliveTimer stops the keepalive timer. +func (e *endpoint) disableKeepaliveTimer() { + e.keepalive.Lock() + e.keepalive.timer.disable() + e.keepalive.Unlock() +} + +// protocolMainLoop is the main loop of the TCP protocol. It runs in its own +// goroutine and is responsible for sending segments and handling received +// segments. +func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { + var closeTimer *time.Timer + var closeWaker sleep.Waker + + epilogue := func() { + // e.mu is expected to be hold upon entering this section. + + if e.snd != nil { + e.snd.resendTimer.cleanup() + } + + if closeTimer != nil { + closeTimer.Stop() + } + + e.completeWorkerLocked() + + if e.drainDone != nil { + close(e.drainDone) + } + + e.mu.Unlock() + + // When the protocol loop exits we should wake up our waiters. + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) + } + + if handshake { + // This is an active connection, so we must initiate the 3-way + // handshake, and then inform potential waiters about its + // completion. + h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable())) + if err := h.execute(); err != nil { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + e.mu.Lock() + e.state = stateError + e.hardError = err + // Lock released below. + epilogue() + + return err + } + + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) + e.rcvListMu.Unlock() + } + + e.keepalive.timer.init(&e.keepalive.waker) + defer e.keepalive.timer.cleanup() + + // Tell waiters that the endpoint is connected and writable. + e.mu.Lock() + e.state = stateConnected + drained := e.drainDone != nil + e.mu.Unlock() + if drained { + close(e.drainDone) + <-e.undrain + } + + e.waiterQueue.Notify(waiter.EventOut) + + // Set up the functions that will be called when the main protocol loop + // wakes up. + funcs := []struct { + w *sleep.Waker + f func() *tcpip.Error + }{ + { + w: &e.sndWaker, + f: e.handleWrite, + }, + { + w: &e.sndCloseWaker, + f: e.handleClose, + }, + { + w: &e.newSegmentWaker, + f: e.handleSegments, + }, + { + w: &closeWaker, + f: func() *tcpip.Error { + return tcpip.ErrConnectionAborted + }, + }, + { + w: &e.snd.resendWaker, + f: func() *tcpip.Error { + if !e.snd.retransmitTimerExpired() { + return tcpip.ErrTimeout + } + return nil + }, + }, + { + w: &e.keepalive.waker, + f: e.keepaliveTimerExpired, + }, + { + w: &e.notificationWaker, + f: func() *tcpip.Error { + n := e.fetchNotifications() + if n¬ifyNonZeroReceiveWindow != 0 { + e.rcv.nonZeroWindow() + } + + if n¬ifyReceiveWindowChanged != 0 { + e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize()) + } + + if n¬ifyMTUChanged != 0 { + e.sndBufMu.Lock() + count := e.packetTooBigCount + e.packetTooBigCount = 0 + mtu := e.sndMTU + e.sndBufMu.Unlock() + + e.snd.updateMaxPayloadSize(mtu, count) + } + + if n¬ifyReset != 0 { + e.mu.Lock() + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.mu.Unlock() + } + if n¬ifyClose != 0 && closeTimer == nil { + // Reset the connection 3 seconds after + // the endpoint has been closed. + // + // The timer could fire in background + // when the endpoint is drained. That's + // OK as the loop here will not honor + // the firing until the undrain arrives. + closeTimer = time.AfterFunc(3*time.Second, func() { + closeWaker.Assert() + }) + } + + if n¬ifyKeepaliveChanged != 0 { + // The timer could fire in background + // when the endpoint is drained. That's + // OK. See above. + e.resetKeepaliveTimer(true) + } + + if n¬ifyDrain != 0 { + for !e.segmentQueue.empty() { + if err := e.handleSegments(); err != nil { + return err + } + } + if e.state != stateError { + close(e.drainDone) + <-e.undrain + } + } + + return nil + }, + }, + } + + // Initialize the sleeper based on the wakers in funcs. + s := sleep.Sleeper{} + for i := range funcs { + s.AddWaker(funcs[i].w, i) + } + + // The following assertions and notifications are needed for restored + // endpoints. Fresh newly created endpoints have empty states and should + // not invoke any. + e.segmentQueue.mu.Lock() + if !e.segmentQueue.list.Empty() { + e.newSegmentWaker.Assert() + } + e.segmentQueue.mu.Unlock() + + e.rcvListMu.Lock() + if !e.rcvList.Empty() { + e.waiterQueue.Notify(waiter.EventIn) + } + e.rcvListMu.Unlock() + + e.mu.RLock() + if e.workerCleanup { + e.notifyProtocolGoroutine(notifyClose) + } + e.mu.RUnlock() + + // Main loop. Handle segments until both send and receive ends of the + // connection have completed. + for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + e.workMu.Unlock() + v, _ := s.Fetch(true) + e.workMu.Lock() + if err := funcs[v].f(); err != nil { + e.mu.Lock() + e.resetConnectionLocked(err) + // Lock released below. + epilogue() + + return nil + } + } + + // Mark endpoint as closed. + e.mu.Lock() + if e.state != stateError { + e.state = stateClosed + } + // Lock released below. + epilogue() + + return nil +} diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go new file mode 100644 index 000000000..e618cd2b9 --- /dev/null +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -0,0 +1,233 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "math" + "time" +) + +// cubicState stores the variables related to TCP CUBIC congestion +// control algorithm state. +// +// See: https://tools.ietf.org/html/rfc8312. +type cubicState struct { + // wLastMax is the previous wMax value. + wLastMax float64 + + // wMax is the value of the congestion window at the + // time of last congestion event. + wMax float64 + + // t denotes the time when the current congestion avoidance + // was entered. + t time.Time + + // numCongestionEvents tracks the number of congestion events since last + // RTO. + numCongestionEvents int + + // c is the cubic constant as specified in RFC8312. It's fixed at 0.4 as + // per RFC. + c float64 + + // k is the time period that the above function takes to increase the + // current window size to W_max if there are no further congestion + // events and is calculated using the following equation: + // + // K = cubic_root(W_max*(1-beta_cubic)/C) (Eq. 2) + k float64 + + // beta is the CUBIC multiplication decrease factor. that is, when a + // congestion event is detected, CUBIC reduces its cwnd to + // W_cubic(0)=W_max*beta_cubic. + beta float64 + + // wC is window computed by CUBIC at time t. It's calculated using the + // formula: + // + // W_cubic(t) = C*(t-K)^3 + W_max (Eq. 1) + wC float64 + + // wEst is the window computed by CUBIC at time t+RTT i.e + // W_cubic(t+RTT). + wEst float64 + + s *sender +} + +// newCubicCC returns a partially initialized cubic state with the constants +// beta and c set and t set to current time. +func newCubicCC(s *sender) *cubicState { + return &cubicState{ + t: time.Now(), + beta: 0.7, + c: 0.4, + s: s, + } +} + +// enterCongestionAvoidance is used to initialize cubic in cases where we exit +// SlowStart without a real congestion event taking place. This can happen when +// a connection goes back to slow start due to a retransmit and we exceed the +// previously lowered ssThresh without experiencing packet loss. +// +// Refer: https://tools.ietf.org/html/rfc8312#section-4.8 +func (c *cubicState) enterCongestionAvoidance() { + // See: https://tools.ietf.org/html/rfc8312#section-4.7 & + // https://tools.ietf.org/html/rfc8312#section-4.8 + if c.numCongestionEvents == 0 { + c.k = 0 + c.t = time.Now() + c.wLastMax = c.wMax + c.wMax = float64(c.s.sndCwnd) + } +} + +// updateSlowStart will update the congestion window as per the slow-start +// algorithm used by NewReno. If after adjusting the congestion window we cross +// the ssThresh then it will return the number of packets that must be consumed +// in congestion avoidance mode. +func (c *cubicState) updateSlowStart(packetsAcked int) int { + // Don't let the congestion window cross into the congestion + // avoidance range. + newcwnd := c.s.sndCwnd + packetsAcked + enterCA := false + if newcwnd >= c.s.sndSsthresh { + newcwnd = c.s.sndSsthresh + c.s.sndCAAckCount = 0 + enterCA = true + } + + packetsAcked -= newcwnd - c.s.sndCwnd + c.s.sndCwnd = newcwnd + if enterCA { + c.enterCongestionAvoidance() + } + return packetsAcked +} + +// Update updates cubic's internal state variables. It must be called on every +// ACK received. +// Refer: https://tools.ietf.org/html/rfc8312#section-4 +func (c *cubicState) Update(packetsAcked int) { + if c.s.sndCwnd < c.s.sndSsthresh { + packetsAcked = c.updateSlowStart(packetsAcked) + if packetsAcked == 0 { + return + } + } else { + c.s.rtt.Lock() + srtt := c.s.rtt.srtt + c.s.rtt.Unlock() + c.s.sndCwnd = c.getCwnd(packetsAcked, c.s.sndCwnd, srtt) + } +} + +// cubicCwnd computes the CUBIC congestion window after t seconds from last +// congestion event. +func (c *cubicState) cubicCwnd(t float64) float64 { + return c.c*math.Pow(t, 3.0) + c.wMax +} + +// getCwnd returns the current congestion window as computed by CUBIC. +// Refer: https://tools.ietf.org/html/rfc8312#section-4 +func (c *cubicState) getCwnd(packetsAcked, sndCwnd int, srtt time.Duration) int { + elapsed := time.Since(c.t).Seconds() + + // Compute the window as per Cubic after 'elapsed' time + // since last congestion event. + c.wC = c.cubicCwnd(elapsed - c.k) + + // Compute the TCP friendly estimate of the congestion window. + c.wEst = c.wMax*c.beta + (3.0*((1.0-c.beta)/(1.0+c.beta)))*(elapsed/srtt.Seconds()) + + // Make sure in the TCP friendly region CUBIC performs at least + // as well as Reno. + if c.wC < c.wEst && float64(sndCwnd) < c.wEst { + // TCP Friendly region of cubic. + return int(c.wEst) + } + + // In Concave/Convex region of CUBIC, calculate what CUBIC window + // will be after 1 RTT and use that to grow congestion window + // for every ack. + tEst := (time.Since(c.t) + srtt).Seconds() + wtRtt := c.cubicCwnd(tEst - c.k) + // As per 4.3 for each received ACK cwnd must be incremented + // by (w_cubic(t+RTT) - cwnd/cwnd. + cwnd := float64(sndCwnd) + for i := 0; i < packetsAcked; i++ { + // Concave/Convex regions of cubic have the same formulas. + // See: https://tools.ietf.org/html/rfc8312#section-4.3 + cwnd += (wtRtt - cwnd) / cwnd + } + return int(cwnd) +} + +// HandleNDupAcks implements congestionControl.HandleNDupAcks. +func (c *cubicState) HandleNDupAcks() { + // See: https://tools.ietf.org/html/rfc8312#section-4.5 + c.numCongestionEvents++ + c.t = time.Now() + c.wLastMax = c.wMax + c.wMax = float64(c.s.sndCwnd) + + c.fastConvergence() + c.reduceSlowStartThreshold() +} + +// HandleRTOExpired implements congestionContrl.HandleRTOExpired. +func (c *cubicState) HandleRTOExpired() { + // See: https://tools.ietf.org/html/rfc8312#section-4.6 + c.t = time.Now() + c.numCongestionEvents = 0 + c.wLastMax = c.wMax + c.wMax = float64(c.s.sndCwnd) + + c.fastConvergence() + + // We lost a packet, so reduce ssthresh. + c.reduceSlowStartThreshold() + + // Reduce the congestion window to 1, i.e., enter slow-start. Per + // RFC 5681, page 7, we must use 1 regardless of the value of the + // initial congestion window. + c.s.sndCwnd = 1 +} + +// fastConvergence implements the logic for Fast Convergence algorithm as +// described in https://tools.ietf.org/html/rfc8312#section-4.6. +func (c *cubicState) fastConvergence() { + if c.wMax < c.wLastMax { + c.wLastMax = c.wMax + c.wMax = c.wMax * (1.0 + c.beta) / 2.0 + } else { + c.wLastMax = c.wMax + } + // Recompute k as wMax may have changed. + c.k = math.Cbrt(c.wMax * (1 - c.beta) / c.c) +} + +// PostRecovery implemements congestionControl.PostRecovery. +func (c *cubicState) PostRecovery() { + c.t = time.Now() +} + +// reduceSlowStartThreshold returns new SsThresh as described in +// https://tools.ietf.org/html/rfc8312#section-4.7. +func (c *cubicState) reduceSlowStartThreshold() { + c.s.sndSsthresh = int(math.Max(float64(c.s.sndCwnd)*c.beta, 2.0)) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go new file mode 100644 index 000000000..fd697402e --- /dev/null +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -0,0 +1,1741 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "fmt" + "math" + "sync" + "sync/atomic" + "time" + + "gvisor.googlesource.com/gvisor/pkg/rand" + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tmutex" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +type endpointState int + +const ( + stateInitial endpointState = iota + stateBound + stateListen + stateConnecting + stateConnected + stateClosed + stateError +) + +// Reasons for notifying the protocol goroutine. +const ( + notifyNonZeroReceiveWindow = 1 << iota + notifyReceiveWindowChanged + notifyClose + notifyMTUChanged + notifyDrain + notifyReset + notifyKeepaliveChanged +) + +// SACKInfo holds TCP SACK related information for a given endpoint. +// +// +stateify savable +type SACKInfo struct { + // Blocks is the maximum number of SACK blocks we track + // per endpoint. + Blocks [MaxSACKBlocks]header.SACKBlock + + // NumBlocks is the number of valid SACK blocks stored in the + // blocks array above. + NumBlocks int +} + +// endpoint represents a TCP endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. The protocol implementation, however, runs in a single +// goroutine. +// +// +stateify savable +type endpoint struct { + // workMu is used to arbitrate which goroutine may perform protocol + // work. Only the main protocol goroutine is expected to call Lock() on + // it, but other goroutines (e.g., send) may call TryLock() to eagerly + // perform work without having to wait for the main one to wake up. + workMu tmutex.Mutex `state:"nosave"` + + // The following fields are initialized at creation time and do not + // change throughout the lifetime of the endpoint. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue `state:"wait"` + + // lastError represents the last error that the endpoint reported; + // access to it is protected by the following mutex. + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error `state:".(string)"` + + // The following fields are used to manage the receive queue. The + // protocol goroutine adds ready-for-delivery segments to rcvList, + // which are returned by Read() calls to users. + // + // Once the peer has closed its send side, rcvClosed is set to true + // to indicate to users that no more data is coming. + // + // rcvListMu can be taken after the endpoint mu below. + rcvListMu sync.Mutex `state:"nosave"` + rcvList segmentList `state:"wait"` + rcvClosed bool + rcvBufSize int + rcvBufUsed int + + // The following fields are protected by the mutex. + mu sync.RWMutex `state:"nosave"` + id stack.TransportEndpointID + state endpointState `state:".(endpointState)"` + isPortReserved bool `state:"manual"` + isRegistered bool + boundNICID tcpip.NICID `state:"manual"` + route stack.Route `state:"manual"` + v6only bool + isConnectNotified bool + // TCP should never broadcast but Linux nevertheless supports enabling/ + // disabling SO_BROADCAST, albeit as a NOOP. + broadcast bool + + // effectiveNetProtos contains the network protocols actually in use. In + // most cases it will only contain "netProto", but in cases like IPv6 + // endpoints with v6only set to false, this could include multiple + // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., + // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped + // address). + effectiveNetProtos []tcpip.NetworkProtocolNumber `state:"manual"` + + // hardError is meaningful only when state is stateError, it stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. hardError is protected by mu. + hardError *tcpip.Error `state:".(string)"` + + // workerRunning specifies if a worker goroutine is running. + workerRunning bool + + // workerCleanup specifies if the worker goroutine must perform cleanup + // before exitting. This can only be set to true when workerRunning is + // also true, and they're both protected by the mutex. + workerCleanup bool + + // sendTSOk is used to indicate when the TS Option has been negotiated. + // When sendTSOk is true every non-RST segment should carry a TS as per + // RFC7323#section-1.1 + sendTSOk bool + + // recentTS is the timestamp that should be sent in the TSEcr field of + // the timestamp for future segments sent by the endpoint. This field is + // updated if required when a new segment is received by this endpoint. + recentTS uint32 + + // tsOffset is a randomized offset added to the value of the + // TSVal field in the timestamp option. + tsOffset uint32 + + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + + // sackPermitted is set to true if the peer sends the TCPSACKPermitted + // option in the SYN/SYN-ACK. + sackPermitted bool + + // sack holds TCP SACK related information for this endpoint. + sack SACKInfo + + // reusePort is set to true if SO_REUSEPORT is enabled. + reusePort bool + + // delay enables Nagle's algorithm. + // + // delay is a boolean (0 is false) and must be accessed atomically. + delay uint32 + + // cork holds back segments until full. + // + // cork is a boolean (0 is false) and must be accessed atomically. + cork uint32 + + // scoreboard holds TCP SACK Scoreboard information for this endpoint. + scoreboard *SACKScoreboard + + // The options below aren't implemented, but we remember the user + // settings because applications expect to be able to set/query these + // options. + reuseAddr bool + + // slowAck holds the negated state of quick ack. It is stubbed out and + // does nothing. + // + // slowAck is a boolean (0 is false) and must be accessed atomically. + slowAck uint32 + + // segmentQueue is used to hand received segments to the protocol + // goroutine. Segments are queued as long as the queue is not full, + // and dropped when it is. + segmentQueue segmentQueue `state:"wait"` + + // synRcvdCount is the number of connections for this endpoint that are + // in SYN-RCVD state. + synRcvdCount int + + // The following fields are used to manage the send buffer. When + // segments are ready to be sent, they are added to sndQueue and the + // protocol goroutine is signaled via sndWaker. + // + // When the send side is closed, the protocol goroutine is notified via + // sndCloseWaker, and sndClosed is set to true. + sndBufMu sync.Mutex `state:"nosave"` + sndBufSize int + sndBufUsed int + sndClosed bool + sndBufInQueue seqnum.Size + sndQueue segmentList `state:"wait"` + sndWaker sleep.Waker `state:"manual"` + sndCloseWaker sleep.Waker `state:"manual"` + + // cc stores the name of the Congestion Control algorithm to use for + // this endpoint. + cc CongestionControlOption + + // The following are used when a "packet too big" control packet is + // received. They are protected by sndBufMu. They are used to + // communicate to the main protocol goroutine how many such control + // messages have been received since the last notification was processed + // and what was the smallest MTU seen. + packetTooBigCount int + sndMTU int + + // newSegmentWaker is used to indicate to the protocol goroutine that + // it needs to wake up and handle new segments queued to it. + newSegmentWaker sleep.Waker `state:"manual"` + + // notificationWaker is used to indicate to the protocol goroutine that + // it needs to wake up and check for notifications. + notificationWaker sleep.Waker `state:"manual"` + + // notifyFlags is a bitmask of flags used to indicate to the protocol + // goroutine what it was notified; this is only accessed atomically. + notifyFlags uint32 `state:"nosave"` + + // keepalive manages TCP keepalive state. When the connection is idle + // (no data sent or received) for keepaliveIdle, we start sending + // keepalives every keepalive.interval. If we send keepalive.count + // without hearing a response, the connection is closed. + keepalive keepalive + + // acceptedChan is used by a listening endpoint protocol goroutine to + // send newly accepted connections to the endpoint so that they can be + // read by Accept() calls. + acceptedChan chan *endpoint `state:".([]*endpoint)"` + + // The following are only used from the protocol goroutine, and + // therefore don't need locks to protect them. + rcv *receiver `state:"wait"` + snd *sender `state:"wait"` + + // The goroutine drain completion notification channel. + drainDone chan struct{} `state:"nosave"` + + // The goroutine undrain notification channel. + undrain chan struct{} `state:"nosave"` + + // probe if not nil is invoked on every received segment. It is passed + // a copy of the current state of the endpoint. + probe stack.TCPProbeFunc `state:"nosave"` + + // The following are only used to assist the restore run to re-connect. + bindAddress tcpip.Address + connectingAddress tcpip.Address + + gso *stack.GSO +} + +// StopWork halts packet processing. Only to be used in tests. +func (e *endpoint) StopWork() { + e.workMu.Lock() +} + +// ResumeWork resumes packet processing. Only to be used in tests. +func (e *endpoint) ResumeWork() { + e.workMu.Unlock() +} + +// keepalive is a synchronization wrapper used to appease stateify. See the +// comment in endpoint, where it is used. +// +// +stateify savable +type keepalive struct { + sync.Mutex `state:"nosave"` + enabled bool + idle time.Duration + interval time.Duration + count int + unacked int + timer timer `state:"nosave"` + waker sleep.Waker `state:"nosave"` +} + +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { + e := &endpoint{ + stack: stack, + netProto: netProto, + waiterQueue: waiterQueue, + rcvBufSize: DefaultBufferSize, + sndBufSize: DefaultBufferSize, + sndMTU: int(math.MaxInt32), + reuseAddr: true, + keepalive: keepalive{ + // Linux defaults. + idle: 2 * time.Hour, + interval: 75 * time.Second, + count: 9, + }, + } + + var ss SendBufferSizeOption + if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + e.sndBufSize = ss.Default + } + + var rs ReceiveBufferSizeOption + if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + e.rcvBufSize = rs.Default + } + + var cs CongestionControlOption + if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { + e.cc = cs + } + + if p := stack.GetTCPProbe(); p != nil { + e.probe = p + } + + e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.workMu.Init() + e.workMu.Lock() + e.tsOffset = timeStampOffset() + return e +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + result := waiter.EventMask(0) + + e.mu.RLock() + defer e.mu.RUnlock() + + switch e.state { + case stateInitial, stateBound, stateConnecting: + // Ready for nothing. + + case stateClosed, stateError: + // Ready for anything. + result = mask + + case stateListen: + // Check if there's anything in the accepted channel. + if (mask & waiter.EventIn) != 0 { + if len(e.acceptedChan) > 0 { + result |= waiter.EventIn + } + } + + case stateConnected: + // Determine if the endpoint is writable if requested. + if (mask & waiter.EventOut) != 0 { + e.sndBufMu.Lock() + if e.sndClosed || e.sndBufUsed < e.sndBufSize { + result |= waiter.EventOut + } + e.sndBufMu.Unlock() + } + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvListMu.Lock() + if e.rcvBufUsed > 0 || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvListMu.Unlock() + } + } + + return result +} + +func (e *endpoint) fetchNotifications() uint32 { + return atomic.SwapUint32(&e.notifyFlags, 0) +} + +func (e *endpoint) notifyProtocolGoroutine(n uint32) { + for { + v := atomic.LoadUint32(&e.notifyFlags) + if v&n == n { + // The flags are already set. + return + } + + if atomic.CompareAndSwapUint32(&e.notifyFlags, v, v|n) { + if v == 0 { + // We are causing a transition from no flags to + // at least one flag set, so we must cause the + // protocol goroutine to wake up. + e.notificationWaker.Assert() + } + return + } + } +} + +// Close puts the endpoint in a closed state and frees all resources associated +// with it. It must be called only once and with no other concurrent calls to +// the endpoint. +func (e *endpoint) Close() { + // Issue a shutdown so that the peer knows we won't send any more data + // if we're connected, or stop accepting if we're listening. + e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + + e.mu.Lock() + + // For listening sockets, we always release ports inline so that they + // are immediately available for reuse after Close() is called. If also + // registered, we unregister as well otherwise the next user would fail + // in Listen() when trying to register. + if e.state == stateListen && e.isPortReserved { + if e.isRegistered { + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.isRegistered = false + } + + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.isPortReserved = false + } + + // Either perform the local cleanup or kick the worker to make sure it + // knows it needs to cleanup. + tcpip.AddDanglingEndpoint(e) + if !e.workerRunning { + e.cleanupLocked() + } else { + e.workerCleanup = true + e.notifyProtocolGoroutine(notifyClose) + } + + e.mu.Unlock() +} + +// cleanupLocked frees all resources associated with the endpoint. It is called +// after Close() is called and the worker goroutine (if any) is done with its +// work. +func (e *endpoint) cleanupLocked() { + // Close all endpoints that might have been accepted by TCP but not by + // the client. + if e.acceptedChan != nil { + close(e.acceptedChan) + for n := range e.acceptedChan { + n.mu.Lock() + n.resetConnectionLocked(tcpip.ErrConnectionAborted) + n.mu.Unlock() + n.Close() + } + e.acceptedChan = nil + } + e.workerCleanup = false + + if e.isRegistered { + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.isRegistered = false + } + + if e.isPortReserved { + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.isPortReserved = false + } + + e.route.Release() + tcpip.DeleteDanglingEndpoint(e) +} + +// Read reads data from the endpoint. +func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + e.mu.RLock() + // The endpoint can be read if it's connected, or if it's already closed + // but has some pending unread data. Also note that a RST being received + // would cause the state to become stateError so we should allow the + // reads to proceed before returning a ECONNRESET. + e.rcvListMu.Lock() + bufUsed := e.rcvBufUsed + if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 { + e.rcvListMu.Unlock() + he := e.hardError + e.mu.RUnlock() + if s == stateError { + return buffer.View{}, tcpip.ControlMessages{}, he + } + return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + } + + v, err := e.readLocked() + e.rcvListMu.Unlock() + + e.mu.RUnlock() + + return v, tcpip.ControlMessages{}, err +} + +func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { + if e.rcvBufUsed == 0 { + if e.rcvClosed || e.state != stateConnected { + return buffer.View{}, tcpip.ErrClosedForReceive + } + return buffer.View{}, tcpip.ErrWouldBlock + } + + s := e.rcvList.Front() + views := s.data.Views() + v := views[s.viewToDeliver] + s.viewToDeliver++ + + if s.viewToDeliver >= len(views) { + e.rcvList.Remove(s) + s.decRef() + } + + scale := e.rcv.rcvWndScale + wasZero := e.zeroReceiveWindow(scale) + e.rcvBufUsed -= len(v) + if wasZero && !e.zeroReceiveWindow(scale) { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } + + return v, nil +} + +// Write writes data to the endpoint's peer. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.mu.RLock() + defer e.mu.RUnlock() + + // The endpoint cannot be written to if it's not connected. + if e.state != stateConnected { + switch e.state { + case stateError: + return 0, nil, e.hardError + default: + return 0, nil, tcpip.ErrClosedForSend + } + } + + // Nothing to do if the buffer is empty. + if p.Size() == 0 { + return 0, nil, nil + } + + e.sndBufMu.Lock() + + // Check if the connection has already been closed for sends. + if e.sndClosed { + e.sndBufMu.Unlock() + return 0, nil, tcpip.ErrClosedForSend + } + + // Check against the limit. + avail := e.sndBufSize - e.sndBufUsed + if avail <= 0 { + e.sndBufMu.Unlock() + return 0, nil, tcpip.ErrWouldBlock + } + + v, perr := p.Get(avail) + if perr != nil { + e.sndBufMu.Unlock() + return 0, nil, perr + } + + l := len(v) + s := newSegmentFromView(&e.route, e.id, v) + + // Add data to the send queue. + e.sndBufUsed += l + e.sndBufInQueue += seqnum.Size(l) + e.sndQueue.PushBack(s) + + e.sndBufMu.Unlock() + + if e.workMu.TryLock() { + // Do the work inline. + e.handleWrite() + e.workMu.Unlock() + } else { + // Let the protocol goroutine do the work. + e.sndWaker.Assert() + } + return uintptr(l), nil, nil +} + +// Peek reads data without consuming it from the endpoint. +// +// This method does not block if there is no data pending. +func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // The endpoint can be read if it's connected, or if it's already closed + // but has some pending unread data. + if s := e.state; s != stateConnected && s != stateClosed { + if s == stateError { + return 0, tcpip.ControlMessages{}, e.hardError + } + return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState + } + + e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() + + if e.rcvBufUsed == 0 { + if e.rcvClosed || e.state != stateConnected { + return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive + } + return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + } + + // Make a copy of vec so we can modify the slide headers. + vec = append([][]byte(nil), vec...) + + var num uintptr + + for s := e.rcvList.Front(); s != nil; s = s.Next() { + views := s.data.Views() + + for i := s.viewToDeliver; i < len(views); i++ { + v := views[i] + + for len(v) > 0 { + if len(vec) == 0 { + return num, tcpip.ControlMessages{}, nil + } + if len(vec[0]) == 0 { + vec = vec[1:] + continue + } + + n := copy(vec[0], v) + v = v[n:] + vec[0] = vec[0][n:] + num += uintptr(n) + } + } + } + + return num, tcpip.ControlMessages{}, nil +} + +// zeroReceiveWindow checks if the receive window to be announced now would be +// zero, based on the amount of available buffer and the receive window scaling. +// +// It must be called with rcvListMu held. +func (e *endpoint) zeroReceiveWindow(scale uint8) bool { + if e.rcvBufUsed >= e.rcvBufSize { + return true + } + + return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.DelayOption: + if v == 0 { + atomic.StoreUint32(&e.delay, 0) + + // Handle delayed data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.delay, 1) + } + return nil + + case tcpip.CorkOption: + if v == 0 { + atomic.StoreUint32(&e.cork, 0) + + // Handle the corked data. + e.sndWaker.Assert() + } else { + atomic.StoreUint32(&e.cork, 1) + } + return nil + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.reuseAddr = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + return nil + + case tcpip.QuickAckOption: + if v == 0 { + atomic.StoreUint32(&e.slowAck, 1) + } else { + atomic.StoreUint32(&e.slowAck, 0) + } + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs ReceiveBufferSizeOption + size := int(v) + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if size < rs.Min { + size = rs.Min + } + if size > rs.Max { + size = rs.Max + } + } + + mask := uint32(notifyReceiveWindowChanged) + + e.rcvListMu.Lock() + + // Make sure the receive buffer size allows us to send a + // non-zero window size. + scale := uint8(0) + if e.rcv != nil { + scale = e.rcv.rcvWndScale + } + if size>>scale == 0 { + size = 1 << scale + } + + // Make sure 2*size doesn't overflow. + if size > math.MaxInt32/2 { + size = math.MaxInt32 / 2 + } + + wasZero := e.zeroReceiveWindow(scale) + e.rcvBufSize = size + if wasZero && !e.zeroReceiveWindow(scale) { + mask |= notifyNonZeroReceiveWindow + } + e.rcvListMu.Unlock() + + e.notifyProtocolGoroutine(mask) + return nil + + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + size := int(v) + var ss SendBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if size < ss.Min { + size = ss.Min + } + if size > ss.Max { + size = ss.Max + } + } + + e.sndBufMu.Lock() + e.sndBufSize = size + e.sndBufMu.Unlock() + return nil + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v != 0 + return nil + + case tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + e.keepalive.enabled = v != 0 + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + return nil + + case tcpip.KeepaliveIdleOption: + e.keepalive.Lock() + e.keepalive.idle = time.Duration(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + return nil + + case tcpip.KeepaliveIntervalOption: + e.keepalive.Lock() + e.keepalive.interval = time.Duration(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + return nil + + case tcpip.KeepaliveCountOption: + e.keepalive.Lock() + e.keepalive.count = int(v) + e.keepalive.Unlock() + e.notifyProtocolGoroutine(notifyKeepaliveChanged) + return nil + + case tcpip.BroadcastOption: + e.mu.Lock() + e.broadcast = v != 0 + e.mu.Unlock() + return nil + + default: + return nil + } +} + +// readyReceiveSize returns the number of bytes ready to be received. +func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // The endpoint cannot be in listen state. + if e.state == stateListen { + return 0, tcpip.ErrInvalidEndpointState + } + + e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() + + return e.rcvBufUsed, nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + e.lastErrorMu.Lock() + err := e.lastError + e.lastError = nil + e.lastErrorMu.Unlock() + return err + + case *tcpip.SendBufferSizeOption: + e.sndBufMu.Lock() + *o = tcpip.SendBufferSizeOption(e.sndBufSize) + e.sndBufMu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + e.rcvListMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize) + e.rcvListMu.Unlock() + return nil + + case *tcpip.ReceiveQueueSizeOption: + v, err := e.readyReceiveSize() + if err != nil { + return err + } + + *o = tcpip.ReceiveQueueSizeOption(v) + return nil + + case *tcpip.DelayOption: + *o = 0 + if v := atomic.LoadUint32(&e.delay); v != 0 { + *o = 1 + } + return nil + + case *tcpip.CorkOption: + *o = 0 + if v := atomic.LoadUint32(&e.cork); v != 0 { + *o = 1 + } + return nil + + case *tcpip.ReuseAddressOption: + e.mu.RLock() + v := e.reuseAddr + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.ReusePortOption: + e.mu.RLock() + v := e.reusePort + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.QuickAckOption: + *o = 1 + if v := atomic.LoadUint32(&e.slowAck); v != 0 { + *o = 0 + } + return nil + + case *tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.TCPInfoOption: + *o = tcpip.TCPInfoOption{} + e.mu.RLock() + snd := e.snd + e.mu.RUnlock() + if snd != nil { + snd.rtt.Lock() + o.RTT = snd.rtt.srtt + o.RTTVar = snd.rtt.rttvar + snd.rtt.Unlock() + } + return nil + + case *tcpip.KeepaliveEnabledOption: + e.keepalive.Lock() + v := e.keepalive.enabled + e.keepalive.Unlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.KeepaliveIdleOption: + e.keepalive.Lock() + *o = tcpip.KeepaliveIdleOption(e.keepalive.idle) + e.keepalive.Unlock() + return nil + + case *tcpip.KeepaliveIntervalOption: + e.keepalive.Lock() + *o = tcpip.KeepaliveIntervalOption(e.keepalive.interval) + e.keepalive.Unlock() + return nil + + case *tcpip.KeepaliveCountOption: + e.keepalive.Lock() + *o = tcpip.KeepaliveCountOption(e.keepalive.count) + e.keepalive.Unlock() + return nil + + case *tcpip.OutOfBandInlineOption: + // We don't currently support disabling this option. + *o = 1 + return nil + + case *tcpip.BroadcastOption: + e.mu.Lock() + v := e.broadcast + e.mu.Unlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.netProto + if header.IsV4MappedAddress(addr.Addr) { + // Fail if using a v4 mapped address on a v6only endpoint. + if e.v6only { + return 0, tcpip.ErrNoRoute + } + + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == "\x00\x00\x00\x00" { + addr.Addr = "" + } + } + + // Fail if we're bound to an address length different from the one we're + // checking. + if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { + return 0, tcpip.ErrInvalidEndpointState + } + + return netProto, nil +} + +// Connect connects the endpoint to its peer. +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + return e.connect(addr, true, true) +} + +// connect connects the endpoint to its peer. In the normal non-S/R case, the +// new connection is expected to run the main goroutine and perform handshake. +// In restore of previously connected endpoints, both ends will be passively +// created (so no new handshaking is done); for stack-accepted connections not +// yet accepted by the app, they are restored without running the main goroutine +// here. +func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (err *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + defer func() { + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + } + }() + + connectingAddr := addr.Addr + + netProto, err := e.checkV4Mapped(&addr) + if err != nil { + return err + } + + nicid := addr.NIC + switch e.state { + case stateBound: + // If we're already bound to a NIC but the caller is requesting + // that we use a different one now, we cannot proceed. + if e.boundNICID == 0 { + break + } + + if nicid != 0 && nicid != e.boundNICID { + return tcpip.ErrNoRoute + } + + nicid = e.boundNICID + + case stateInitial: + // Nothing to do. We'll eventually fill-in the gaps in the ID + // (if any) when we find a route. + + case stateConnecting: + // A connection request has already been issued but hasn't + // completed yet. + return tcpip.ErrAlreadyConnecting + + case stateConnected: + // The endpoint is already connected. If caller hasn't been notified yet, return success. + if !e.isConnectNotified { + e.isConnectNotified = true + return nil + } + // Otherwise return that it's already connected. + return tcpip.ErrAlreadyConnected + + case stateError: + return e.hardError + + default: + return tcpip.ErrInvalidEndpointState + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer r.Release() + + origID := e.id + + netProtos := []tcpip.NetworkProtocolNumber{netProto} + e.id.LocalAddress = r.LocalAddress + e.id.RemoteAddress = r.RemoteAddress + e.id.RemotePort = addr.Port + + if e.id.LocalPort != 0 { + // The endpoint is bound to a port, attempt to register it. + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e, e.reusePort) + if err != nil { + return err + } + } else { + // The endpoint doesn't have a local port yet, so try to get + // one. Make sure that it isn't one that will result in the same + // address/port for both local and remote (otherwise this + // endpoint would be trying to connect to itself). + sameAddr := e.id.LocalAddress == e.id.RemoteAddress + if _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + if sameAddr && p == e.id.RemotePort { + return false, nil + } + if !e.stack.IsPortAvailable(netProtos, ProtocolNumber, e.id.LocalAddress, p, false) { + return false, nil + } + + id := e.id + id.LocalPort = p + switch e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) { + case nil: + e.id = id + return true, nil + case tcpip.ErrPortInUse: + return false, nil + default: + return false, err + } + }); err != nil { + return err + } + } + + // Remove the port reservation. This can happen when Bind is called + // before Connect: in such a case we don't want to hold on to + // reservations anymore. + if e.isPortReserved { + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort) + e.isPortReserved = false + } + + e.isRegistered = true + e.state = stateConnecting + e.route = r.Clone() + e.boundNICID = nicid + e.effectiveNetProtos = netProtos + e.connectingAddress = connectingAddr + + e.initGSO() + + // Connect in the restore phase does not perform handshake. Restore its + // connection setting here. + if !handshake { + e.segmentQueue.mu.Lock() + for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { + for s := l.Front(); s != nil; s = s.Next() { + s.id = e.id + s.route = r.Clone() + e.sndWaker.Assert() + } + } + e.segmentQueue.mu.Unlock() + e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) + e.state = stateConnected + } + + if run { + e.workerRunning = true + e.stack.Stats().TCP.ActiveConnectionOpenings.Increment() + go e.protocolMainLoop(handshake) // S/R-SAFE: will be drained before save. + } + + return tcpip.ErrConnectStarted +} + +// ConnectEndpoint is not supported. +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// Shutdown closes the read and/or write end of the endpoint connection to its +// peer. +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + e.shutdownFlags |= flags + + switch e.state { + case stateConnected: + // Close for read. + if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { + // Mark read side as closed. + e.rcvListMu.Lock() + e.rcvClosed = true + rcvBufUsed := e.rcvBufUsed + e.rcvListMu.Unlock() + + // If we're fully closed and we have unread data we need to abort + // the connection with a RST. + if (e.shutdownFlags&tcpip.ShutdownWrite) != 0 && rcvBufUsed > 0 { + e.notifyProtocolGoroutine(notifyReset) + return nil + } + } + + // Close for write. + if (e.shutdownFlags & tcpip.ShutdownWrite) != 0 { + e.sndBufMu.Lock() + + if e.sndClosed { + // Already closed. + e.sndBufMu.Unlock() + break + } + + // Queue fin segment. + s := newSegmentFromView(&e.route, e.id, nil) + e.sndQueue.PushBack(s) + e.sndBufInQueue++ + + // Mark endpoint as closed. + e.sndClosed = true + + e.sndBufMu.Unlock() + + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + + case stateListen: + // Tell protocolListenLoop to stop. + if flags&tcpip.ShutdownRead != 0 { + e.notifyProtocolGoroutine(notifyClose) + } + + default: + return tcpip.ErrNotConnected + } + + return nil +} + +// Listen puts the endpoint in "listen" mode, which allows it to accept +// new connections. +func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + defer func() { + if err != nil && !err.IgnoreStats() { + e.stack.Stats().TCP.FailedConnectionAttempts.Increment() + } + }() + + // Allow the backlog to be adjusted if the endpoint is not shutting down. + // When the endpoint shuts down, it sets workerCleanup to true, and from + // that point onward, acceptedChan is the responsibility of the cleanup() + // method (and should not be touched anywhere else, including here). + if e.state == stateListen && !e.workerCleanup { + // Adjust the size of the channel iff we can fix existing + // pending connections into the new one. + if len(e.acceptedChan) > backlog { + return tcpip.ErrInvalidEndpointState + } + if cap(e.acceptedChan) == backlog { + return nil + } + origChan := e.acceptedChan + e.acceptedChan = make(chan *endpoint, backlog) + close(origChan) + for ep := range origChan { + e.acceptedChan <- ep + } + return nil + } + + // Endpoint must be bound before it can transition to listen mode. + if e.state != stateBound { + return tcpip.ErrInvalidEndpointState + } + + // Register the endpoint. + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e, e.reusePort); err != nil { + return err + } + + e.isRegistered = true + e.state = stateListen + if e.acceptedChan == nil { + e.acceptedChan = make(chan *endpoint, backlog) + } + e.workerRunning = true + + go e.protocolListenLoop( // S/R-SAFE: drained on save. + seqnum.Size(e.receiveBufferAvailable())) + + return nil +} + +// startAcceptedLoop sets up required state and starts a goroutine with the +// main loop for accepted connections. +func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { + e.waiterQueue = waiterQueue + e.workerRunning = true + go e.protocolMainLoop(false) // S/R-SAFE: drained on save. +} + +// Accept returns a new endpoint if a peer has established a connection +// to an endpoint previously set to listen mode. +func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // Endpoint must be in listen state before it can accept connections. + if e.state != stateListen { + return nil, nil, tcpip.ErrInvalidEndpointState + } + + // Get the new accepted endpoint. + var n *endpoint + select { + case n = <-e.acceptedChan: + default: + return nil, nil, tcpip.ErrWouldBlock + } + + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() + + return n, wq, nil +} + +// Bind binds the endpoint to a specific local port and optionally address. +func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + + // Don't allow binding once endpoint is not in the initial state + // anymore. This is because once the endpoint goes into a connected or + // listen state, it is already bound. + if e.state != stateInitial { + return tcpip.ErrAlreadyBound + } + + e.bindAddress = addr.Addr + netProto, err := e.checkV4Mapped(&addr) + if err != nil { + return err + } + + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } + } + + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port, e.reusePort) + if err != nil { + return err + } + + e.isPortReserved = true + e.effectiveNetProtos = netProtos + e.id.LocalPort = port + + // Any failures beyond this point must remove the port registration. + defer func() { + if err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.isPortReserved = false + e.effectiveNetProtos = nil + e.id.LocalPort = 0 + e.id.LocalAddress = "" + e.boundNICID = 0 + } + }() + + // If an address is specified, we must ensure that it's one of our + // local addresses. + if len(addr.Addr) != 0 { + nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nic == 0 { + return tcpip.ErrBadLocalAddress + } + + e.boundNICID = nic + e.id.LocalAddress = addr.Addr + } + + // Mark endpoint as bound. + e.state = stateBound + + return nil +} + +// GetLocalAddress returns the address to which the endpoint is bound. +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + return tcpip.FullAddress{ + Addr: e.id.LocalAddress, + Port: e.id.LocalPort, + NIC: e.boundNICID, + }, nil +} + +// GetRemoteAddress returns the address to which the endpoint is connected. +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + Addr: e.id.RemoteAddress, + Port: e.id.RemotePort, + NIC: e.boundNICID, + }, nil +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { + s := newSegment(r, id, vv) + if !s.parse() { + e.stack.Stats().MalformedRcvdPackets.Increment() + e.stack.Stats().TCP.InvalidSegmentsReceived.Increment() + s.decRef() + return + } + + if !s.csumValid { + e.stack.Stats().MalformedRcvdPackets.Increment() + e.stack.Stats().TCP.ChecksumErrors.Increment() + s.decRef() + return + } + + e.stack.Stats().TCP.ValidSegmentsReceived.Increment() + if (s.flags & header.TCPFlagRst) != 0 { + e.stack.Stats().TCP.ResetsReceived.Increment() + } + + // Send packet to worker goroutine. + if e.segmentQueue.enqueue(s) { + e.newSegmentWaker.Assert() + } else { + // The queue is full, so we drop the segment. + e.stack.Stats().DroppedPackets.Increment() + s.decRef() + } +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { + switch typ { + case stack.ControlPacketTooBig: + e.sndBufMu.Lock() + e.packetTooBigCount++ + if v := int(extra); v < e.sndMTU { + e.sndMTU = v + } + e.sndBufMu.Unlock() + + e.notifyProtocolGoroutine(notifyMTUChanged) + } +} + +// updateSndBufferUsage is called by the protocol goroutine when room opens up +// in the send buffer. The number of newly available bytes is v. +func (e *endpoint) updateSndBufferUsage(v int) { + e.sndBufMu.Lock() + notify := e.sndBufUsed >= e.sndBufSize>>1 + e.sndBufUsed -= v + // We only notify when there is half the sndBufSize available after + // a full buffer event occurs. This ensures that we don't wake up + // writers to queue just 1-2 segments and go back to sleep. + notify = notify && e.sndBufUsed < e.sndBufSize>>1 + e.sndBufMu.Unlock() + + if notify { + e.waiterQueue.Notify(waiter.EventOut) + } +} + +// readyToRead is called by the protocol goroutine when a new segment is ready +// to be read, or when the connection is closed for receiving (in which case +// s will be nil). +func (e *endpoint) readyToRead(s *segment) { + e.rcvListMu.Lock() + if s != nil { + s.incRef() + e.rcvBufUsed += s.data.Size() + e.rcvList.PushBack(s) + } else { + e.rcvClosed = true + } + e.rcvListMu.Unlock() + + e.waiterQueue.Notify(waiter.EventIn) +} + +// receiveBufferAvailable calculates how many bytes are still available in the +// receive buffer. +func (e *endpoint) receiveBufferAvailable() int { + e.rcvListMu.Lock() + size := e.rcvBufSize + used := e.rcvBufUsed + e.rcvListMu.Unlock() + + // We may use more bytes than the buffer size when the receive buffer + // shrinks. + if used >= size { + return 0 + } + + return size - used +} + +func (e *endpoint) receiveBufferSize() int { + e.rcvListMu.Lock() + size := e.rcvBufSize + e.rcvListMu.Unlock() + + return size +} + +// updateRecentTimestamp updates the recent timestamp using the algorithm +// described in https://tools.ietf.org/html/rfc7323#section-4.3 +func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { + if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + e.recentTS = tsVal + } +} + +// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if +// the SYN options indicate that timestamp option was negotiated. It also +// initializes the recentTS with the value provided in synOpts.TSval. +func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { + if synOpts.TS { + e.sendTSOk = true + e.recentTS = synOpts.TSVal + } +} + +// timestamp returns the timestamp value to be used in the TSVal field of the +// timestamp option for outgoing TCP segments for a given endpoint. +func (e *endpoint) timestamp() uint32 { + return tcpTimeStamp(e.tsOffset) +} + +// tcpTimeStamp returns a timestamp offset by the provided offset. This is +// not inlined above as it's used when SYN cookies are in use and endpoint +// is not created at the time when the SYN cookie is sent. +func tcpTimeStamp(offset uint32) uint32 { + now := time.Now() + return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset +} + +// timeStampOffset returns a randomized timestamp offset to be used when sending +// timestamp values in a timestamp option for a TCP segment. +func timeStampOffset() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + // Initialize a random tsOffset that will be added to the recentTS + // everytime the timestamp is sent when the Timestamp option is enabled. + // + // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on + // why this is required. + // + // NOTE: This is not completely to spec as normally this should be + // initialized in a manner analogous to how sequence numbers are + // randomized per connection basis. But for now this is sufficient. + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint +// if the SYN options indicate that the SACK option was negotiated and the TCP +// stack is configured to enable TCP SACK option. +func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { + var v SACKEnabled + if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { + // Stack doesn't support SACK. So just return. + return + } + if bool(v) && synOpts.SACKPermitted { + e.sackPermitted = true + } +} + +// maxOptionSize return the maximum size of TCP options. +func (e *endpoint) maxOptionSize() (size int) { + var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock + options := e.makeOptions(maxSackBlocks[:]) + size = len(options) + putOptions(options) + + return size +} + +// completeState makes a full copy of the endpoint and returns it. This is used +// before invoking the probe. The state returned may not be fully consistent if +// there are intervening syscalls when the state is being copied. +func (e *endpoint) completeState() stack.TCPEndpointState { + var s stack.TCPEndpointState + s.SegTime = time.Now() + + // Copy EndpointID. + e.mu.Lock() + s.ID = stack.TCPEndpointID(e.id) + e.mu.Unlock() + + // Copy endpoint rcv state. + e.rcvListMu.Lock() + s.RcvBufSize = e.rcvBufSize + s.RcvBufUsed = e.rcvBufUsed + s.RcvClosed = e.rcvClosed + e.rcvListMu.Unlock() + + // Endpoint TCP Option state. + s.SendTSOk = e.sendTSOk + s.RecentTS = e.recentTS + s.TSOffset = e.tsOffset + s.SACKPermitted = e.sackPermitted + s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) + copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks]) + s.SACK.ReceivedBlocks, s.SACK.MaxSACKED = e.scoreboard.Copy() + + // Copy endpoint send state. + e.sndBufMu.Lock() + s.SndBufSize = e.sndBufSize + s.SndBufUsed = e.sndBufUsed + s.SndClosed = e.sndClosed + s.SndBufInQueue = e.sndBufInQueue + s.PacketTooBigCount = e.packetTooBigCount + s.SndMTU = e.sndMTU + e.sndBufMu.Unlock() + + // Copy receiver state. + s.Receiver = stack.TCPReceiverState{ + RcvNxt: e.rcv.rcvNxt, + RcvAcc: e.rcv.rcvAcc, + RcvWndScale: e.rcv.rcvWndScale, + PendingBufUsed: e.rcv.pendingBufUsed, + PendingBufSize: e.rcv.pendingBufSize, + } + + // Copy sender state. + s.Sender = stack.TCPSenderState{ + LastSendTime: e.snd.lastSendTime, + DupAckCount: e.snd.dupAckCount, + FastRecovery: stack.TCPFastRecoveryState{ + Active: e.snd.fr.active, + First: e.snd.fr.first, + Last: e.snd.fr.last, + MaxCwnd: e.snd.fr.maxCwnd, + HighRxt: e.snd.fr.highRxt, + RescueRxt: e.snd.fr.rescueRxt, + }, + SndCwnd: e.snd.sndCwnd, + Ssthresh: e.snd.sndSsthresh, + SndCAAckCount: e.snd.sndCAAckCount, + Outstanding: e.snd.outstanding, + SndWnd: e.snd.sndWnd, + SndUna: e.snd.sndUna, + SndNxt: e.snd.sndNxt, + RTTMeasureSeqNum: e.snd.rttMeasureSeqNum, + RTTMeasureTime: e.snd.rttMeasureTime, + Closed: e.snd.closed, + RTO: e.snd.rto, + SRTTInited: e.snd.srttInited, + MaxPayloadSize: e.snd.maxPayloadSize, + SndWndScale: e.snd.sndWndScale, + MaxSentAck: e.snd.maxSentAck, + } + e.snd.rtt.Lock() + s.Sender.SRTT = e.snd.rtt.srtt + e.snd.rtt.Unlock() + + if cubic, ok := e.snd.cc.(*cubicState); ok { + s.Sender.Cubic = stack.TCPCubicState{ + WMax: cubic.wMax, + WLastMax: cubic.wLastMax, + T: cubic.t, + TimeSinceLastCongestion: time.Since(cubic.t), + C: cubic.c, + K: cubic.k, + Beta: cubic.beta, + WC: cubic.wC, + WEst: cubic.wEst, + } + } + return s +} + +func (e *endpoint) initGSO() { + if e.route.Capabilities()&stack.CapabilityGSO == 0 { + return + } + + gso := &stack.GSO{} + switch e.route.NetProto { + case header.IPv4ProtocolNumber: + gso.Type = stack.GSOTCPv4 + gso.L3HdrLen = header.IPv4MinimumSize + case header.IPv6ProtocolNumber: + gso.Type = stack.GSOTCPv6 + gso.L3HdrLen = header.IPv6MinimumSize + default: + panic(fmt.Sprintf("Unknown netProto: %v", e.netProto)) + } + gso.NeedsCsum = true + gso.CsumOffset = header.TCPChecksumOffset + gso.MaxSize = e.route.GSOMaxSize() + e.gso = gso +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go new file mode 100644 index 000000000..e8aed2875 --- /dev/null +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -0,0 +1,362 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "fmt" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +func (e *endpoint) drainSegmentLocked() { + // Drain only up to once. + if e.drainDone != nil { + return + } + + e.drainDone = make(chan struct{}) + e.undrain = make(chan struct{}) + e.mu.Unlock() + + e.notifyProtocolGoroutine(notifyDrain) + <-e.drainDone + + e.mu.Lock() +} + +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + // Stop incoming packets. + e.segmentQueue.setLimit(0) + + e.mu.Lock() + defer e.mu.Unlock() + + switch e.state { + case stateInitial, stateBound: + case stateConnected: + if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { + if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { + panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + } + e.resetConnectionLocked(tcpip.ErrConnectionAborted) + e.mu.Unlock() + e.Close() + e.mu.Lock() + } + if !e.workerRunning { + // The endpoint must be in acceptedChan or has been just + // disconnected and closed. + break + } + fallthrough + case stateListen, stateConnecting: + e.drainSegmentLocked() + if e.state != stateClosed && e.state != stateError { + if !e.workerRunning { + panic("endpoint has no worker running in listen, connecting, or connected state") + } + break + } + fallthrough + case stateError, stateClosed: + for e.state == stateError && e.workerRunning { + e.mu.Unlock() + time.Sleep(100 * time.Millisecond) + e.mu.Lock() + } + if e.workerRunning { + panic("endpoint still has worker running in closed or error state") + } + default: + panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) + } + + if e.waiterQueue != nil && !e.waiterQueue.IsEmpty() { + panic("endpoint still has waiters upon save") + } + + if e.state != stateClosed && !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) { + panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") + } +} + +// saveAcceptedChan is invoked by stateify. +func (e *endpoint) saveAcceptedChan() []*endpoint { + if e.acceptedChan == nil { + return nil + } + acceptedEndpoints := make([]*endpoint, len(e.acceptedChan), cap(e.acceptedChan)) + for i := 0; i < len(acceptedEndpoints); i++ { + select { + case ep := <-e.acceptedChan: + acceptedEndpoints[i] = ep + default: + panic("endpoint acceptedChan buffer got consumed by background context") + } + } + for i := 0; i < len(acceptedEndpoints); i++ { + select { + case e.acceptedChan <- acceptedEndpoints[i]: + default: + panic("endpoint acceptedChan buffer got populated by background context") + } + } + return acceptedEndpoints +} + +// loadAcceptedChan is invoked by stateify. +func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { + if cap(acceptedEndpoints) > 0 { + e.acceptedChan = make(chan *endpoint, cap(acceptedEndpoints)) + for _, ep := range acceptedEndpoints { + e.acceptedChan <- ep + } + } +} + +// saveState is invoked by stateify. +func (e *endpoint) saveState() endpointState { + return e.state +} + +// Endpoint loading must be done in the following ordering by their state, to +// avoid dangling connecting w/o listening peer, and to avoid conflicts in port +// reservation. +var connectedLoading sync.WaitGroup +var listenLoading sync.WaitGroup +var connectingLoading sync.WaitGroup + +// Bound endpoint loading happens last. + +// loadState is invoked by stateify. +func (e *endpoint) loadState(state endpointState) { + // This is to ensure that the loading wait groups include all applicable + // endpoints before any asynchronous calls to the Wait() methods. + switch state { + case stateConnected: + connectedLoading.Add(1) + case stateListen: + listenLoading.Add(1) + case stateConnecting: + connectingLoading.Add(1) + } + e.state = state +} + +// afterLoad is invoked by stateify. +func (e *endpoint) afterLoad() { + e.stack = stack.StackFromEnv + e.segmentQueue.setLimit(MaxUnprocessedSegments) + e.workMu.Init() + + state := e.state + switch state { + case stateInitial, stateBound, stateListen, stateConnecting, stateConnected: + var ss SendBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + } + if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + } + } + + bind := func() { + e.state = stateInitial + if len(e.bindAddress) == 0 { + e.bindAddress = e.id.LocalAddress + } + if err := e.Bind(tcpip.FullAddress{Addr: e.bindAddress, Port: e.id.LocalPort}); err != nil { + panic("endpoint binding failed: " + err.String()) + } + } + + switch state { + case stateConnected: + bind() + if len(e.connectingAddress) == 0 { + // This endpoint is accepted by netstack but not yet by + // the app. If the endpoint is IPv6 but the remote + // address is IPv4, we need to connect as IPv6 so that + // dual-stack mode can be properly activated. + if e.netProto == header.IPv6ProtocolNumber && len(e.id.RemoteAddress) != header.IPv6AddressSize { + e.connectingAddress = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + e.id.RemoteAddress + } else { + e.connectingAddress = e.id.RemoteAddress + } + } + // Reset the scoreboard to reinitialize the sack information as + // we do not restore SACK information. + e.scoreboard.Reset() + if err := e.connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}, false, e.workerRunning); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectedLoading.Done() + case stateListen: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + bind() + backlog := cap(e.acceptedChan) + if err := e.Listen(backlog); err != nil { + panic("endpoint listening failed: " + err.String()) + } + listenLoading.Done() + tcpip.AsyncLoading.Done() + }() + case stateConnecting: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + bind() + if err := e.Connect(tcpip.FullAddress{NIC: e.boundNICID, Addr: e.connectingAddress, Port: e.id.RemotePort}); err != tcpip.ErrConnectStarted { + panic("endpoint connecting failed: " + err.String()) + } + connectingLoading.Done() + tcpip.AsyncLoading.Done() + }() + case stateBound: + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + tcpip.AsyncLoading.Done() + }() + case stateClosed: + if e.isPortReserved { + tcpip.AsyncLoading.Add(1) + go func() { + connectedLoading.Wait() + listenLoading.Wait() + connectingLoading.Wait() + bind() + e.state = stateClosed + tcpip.AsyncLoading.Done() + }() + } + fallthrough + case stateError: + tcpip.DeleteDanglingEndpoint(e) + } +} + +// saveLastError is invoked by stateify. +func (e *endpoint) saveLastError() string { + if e.lastError == nil { + return "" + } + + return e.lastError.String() +} + +// loadLastError is invoked by stateify. +func (e *endpoint) loadLastError(s string) { + if s == "" { + return + } + + e.lastError = loadError(s) +} + +// saveHardError is invoked by stateify. +func (e *endpoint) saveHardError() string { + if e.hardError == nil { + return "" + } + + return e.hardError.String() +} + +// loadHardError is invoked by stateify. +func (e *endpoint) loadHardError(s string) { + if s == "" { + return + } + + e.hardError = loadError(s) +} + +var messageToError map[string]*tcpip.Error + +var populate sync.Once + +func loadError(s string) *tcpip.Error { + populate.Do(func() { + var errors = []*tcpip.Error{ + tcpip.ErrUnknownProtocol, + tcpip.ErrUnknownNICID, + tcpip.ErrUnknownDevice, + tcpip.ErrUnknownProtocolOption, + tcpip.ErrDuplicateNICID, + tcpip.ErrDuplicateAddress, + tcpip.ErrNoRoute, + tcpip.ErrBadLinkEndpoint, + tcpip.ErrAlreadyBound, + tcpip.ErrInvalidEndpointState, + tcpip.ErrAlreadyConnecting, + tcpip.ErrAlreadyConnected, + tcpip.ErrNoPortAvailable, + tcpip.ErrPortInUse, + tcpip.ErrBadLocalAddress, + tcpip.ErrClosedForSend, + tcpip.ErrClosedForReceive, + tcpip.ErrWouldBlock, + tcpip.ErrConnectionRefused, + tcpip.ErrTimeout, + tcpip.ErrAborted, + tcpip.ErrConnectStarted, + tcpip.ErrDestinationRequired, + tcpip.ErrNotSupported, + tcpip.ErrQueueSizeNotSupported, + tcpip.ErrNotConnected, + tcpip.ErrConnectionReset, + tcpip.ErrConnectionAborted, + tcpip.ErrNoSuchFile, + tcpip.ErrInvalidOptionValue, + tcpip.ErrNoLinkAddress, + tcpip.ErrBadAddress, + tcpip.ErrNetworkUnreachable, + tcpip.ErrMessageTooLong, + tcpip.ErrNoBufferSpace, + tcpip.ErrBroadcastDisabled, + tcpip.ErrNotPermitted, + } + + messageToError = make(map[string]*tcpip.Error) + for _, e := range errors { + if messageToError[e.String()] != nil { + panic("tcpip errors with duplicated message: " + e.String()) + } + messageToError[e.String()] = e + } + }) + + e, ok := messageToError[s] + if !ok { + panic("unknown error message: " + s) + } + + return e +} diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go new file mode 100644 index 000000000..c30b45c2c --- /dev/null +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -0,0 +1,171 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// Forwarder is a connection request forwarder, which allows clients to decide +// what to do with a connection request, for example: ignore it, send a RST, or +// attempt to complete the 3-way handshake. +// +// The canonical way of using it is to pass the Forwarder.HandlePacket function +// to stack.SetTransportProtocolHandler. +type Forwarder struct { + maxInFlight int + handler func(*ForwarderRequest) + + mu sync.Mutex + inFlight map[stack.TransportEndpointID]struct{} + listen *listenContext +} + +// NewForwarder allocates and initializes a new forwarder with the given +// maximum number of in-flight connection attempts. Once the maximum is reached +// new incoming connection requests will be ignored. +// +// If rcvWnd is set to zero, the default buffer size is used instead. +func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder { + if rcvWnd == 0 { + rcvWnd = DefaultBufferSize + } + return &Forwarder{ + maxInFlight: maxInFlight, + handler: handler, + inFlight: make(map[stack.TransportEndpointID]struct{}), + listen: newListenContext(s, nil /* listenEP */, seqnum.Size(rcvWnd), true, 0), + } +} + +// HandlePacket handles a packet if it is of interest to the forwarder (i.e., if +// it's a SYN packet), returning true if it's the case. Otherwise the packet +// is not handled and false is returned. +// +// This function is expected to be passed as an argument to the +// stack.SetTransportProtocolHandler function. +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { + s := newSegment(r, id, vv) + defer s.decRef() + + // We only care about well-formed SYN packets. + if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn { + return false + } + + opts := parseSynSegmentOptions(s) + + f.mu.Lock() + defer f.mu.Unlock() + + // We have an inflight request for this id, ignore this one for now. + if _, ok := f.inFlight[id]; ok { + return true + } + + // Ignore the segment if we're beyond the limit. + if len(f.inFlight) >= f.maxInFlight { + return true + } + + // Launch a new goroutine to handle the request. + f.inFlight[id] = struct{}{} + s.incRef() + go f.handler(&ForwarderRequest{ // S/R-SAFE: not used by Sentry. + forwarder: f, + segment: s, + synOptions: opts, + }) + + return true +} + +// ForwarderRequest represents a connection request received by the forwarder +// and passed to the client. Clients must eventually call Complete() on it, and +// may optionally create an endpoint to represent it via CreateEndpoint. +type ForwarderRequest struct { + mu sync.Mutex + forwarder *Forwarder + segment *segment + synOptions header.TCPSynOptions +} + +// ID returns the 4-tuple (src address, src port, dst address, dst port) that +// represents the connection request. +func (r *ForwarderRequest) ID() stack.TransportEndpointID { + return r.segment.id +} + +// Complete completes the request, and optionally sends a RST segment back to the +// sender. +func (r *ForwarderRequest) Complete(sendReset bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.segment == nil { + panic("Completing already completed forwarder request") + } + + // Remove request from the forwarder. + r.forwarder.mu.Lock() + delete(r.forwarder.inFlight, r.segment.id) + r.forwarder.mu.Unlock() + + // If the caller requested, send a reset. + if sendReset { + replyWithReset(r.segment) + } + + // Release all resources. + r.segment.decRef() + r.segment = nil + r.forwarder = nil +} + +// CreateEndpoint creates a TCP endpoint for the connection request, performing +// the 3-way handshake in the process. +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.segment == nil { + return nil, tcpip.ErrInvalidEndpointState + } + + f := r.forwarder + ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{ + MSS: r.synOptions.MSS, + WS: r.synOptions.WS, + TS: r.synOptions.TS, + TSVal: r.synOptions.TSVal, + TSEcr: r.synOptions.TSEcr, + SACKPermitted: r.synOptions.SACKPermitted, + }) + if err != nil { + return nil, err + } + + // Start the protocol goroutine. + ep.startAcceptedLoop(queue) + + return ep, nil +} diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go new file mode 100644 index 000000000..b31bcccfa --- /dev/null +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -0,0 +1,250 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tcp contains the implementation of the TCP transport protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the +// transport protocols when calling stack.New(). Then endpoints can be created +// by passing tcp.ProtocolNumber as the transport protocol number when calling +// Stack.NewEndpoint(). +package tcp + +import ( + "strings" + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // ProtocolName is the string representation of the tcp protocol name. + ProtocolName = "tcp" + + // ProtocolNumber is the tcp protocol number. + ProtocolNumber = header.TCPProtocolNumber + + // MinBufferSize is the smallest size of a receive or send buffer. + minBufferSize = 4 << 10 // 4096 bytes. + + // DefaultBufferSize is the default size of the receive and send buffers. + DefaultBufferSize = 1 << 20 // 1MB + + // MaxBufferSize is the largest size a receive and send buffer can grow to. + maxBufferSize = 4 << 20 // 4MB + + // MaxUnprocessedSegments is the maximum number of unprocessed segments + // that can be queued for a given endpoint. + MaxUnprocessedSegments = 300 +) + +// SACKEnabled option can be used to enable SACK support in the TCP +// protocol. See: https://tools.ietf.org/html/rfc2018. +type SACKEnabled bool + +// SendBufferSizeOption allows the default, min and max send buffer sizes for +// TCP endpoints to be queried or configured. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption allows the default, min and max receive buffer size +// for TCP endpoints to be queried or configured. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + +const ( + ccReno = "reno" + ccCubic = "cubic" +) + +// CongestionControlOption sets the current congestion control algorithm. +type CongestionControlOption string + +// AvailableCongestionControlOption returns the supported congestion control +// algorithms. +type AvailableCongestionControlOption string + +type protocol struct { + mu sync.Mutex + sackEnabled bool + sendBufferSize SendBufferSizeOption + recvBufferSize ReceiveBufferSizeOption + congestionControl string + availableCongestionControl []string + allowedCongestionControl []string +} + +// Number returns the tcp protocol number. +func (*protocol) Number() tcpip.TransportProtocolNumber { + return ProtocolNumber +} + +// NewEndpoint creates a new tcp endpoint. +func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, waiterQueue), nil +} + +// NewRawEndpoint creates a new raw TCP endpoint. Raw TCP sockets are currently +// unsupported. It implements stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(stack, netProto, header.TCPProtocolNumber, waiterQueue) +} + +// MinimumPacketSize returns the minimum valid tcp packet size. +func (*protocol) MinimumPacketSize() int { + return header.TCPMinimumSize +} + +// ParsePorts returns the source and destination ports stored in the given tcp +// packet. +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + h := header.TCP(v) + return h.SourcePort(), h.DestinationPort(), nil +} + +// HandleUnknownDestinationPacket handles packets targeted at this protocol but +// that don't match any existing endpoint. +// +// RFC 793, page 36, states that "If the connection does not exist (CLOSED) then +// a reset is sent in response to any incoming segment except another reset. In +// particular, SYNs addressed to a non-existent connection are rejected by this +// means." +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) bool { + s := newSegment(r, id, vv) + defer s.decRef() + + if !s.parse() || !s.csumValid { + return false + } + + // There's nothing to do if this is already a reset packet. + if s.flagIsSet(header.TCPFlagRst) { + return true + } + + replyWithReset(s) + return true +} + +// replyWithReset replies to the given segment with a reset segment. +func replyWithReset(s *segment) { + // Get the seqnum from the packet if the ack flag is set. + seq := seqnum.Value(0) + if s.flagIsSet(header.TCPFlagAck) { + seq = s.ackNumber + } + + ack := s.sequenceNumber.Add(s.logicalLen()) + + sendTCP(&s.route, s.id, buffer.VectorisedView{}, s.route.DefaultTTL(), header.TCPFlagRst|header.TCPFlagAck, seq, ack, 0, nil /* options */, nil /* gso */) +} + +// SetOption implements TransportProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case SACKEnabled: + p.mu.Lock() + p.sackEnabled = bool(v) + p.mu.Unlock() + return nil + + case SendBufferSizeOption: + if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.sendBufferSize = v + p.mu.Unlock() + return nil + + case ReceiveBufferSizeOption: + if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.recvBufferSize = v + p.mu.Unlock() + return nil + + case CongestionControlOption: + for _, c := range p.availableCongestionControl { + if string(v) == c { + p.mu.Lock() + p.congestionControl = string(v) + p.mu.Unlock() + return nil + } + } + return tcpip.ErrInvalidOptionValue + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option implements TransportProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *SACKEnabled: + p.mu.Lock() + *v = SACKEnabled(p.sackEnabled) + p.mu.Unlock() + return nil + + case *SendBufferSizeOption: + p.mu.Lock() + *v = p.sendBufferSize + p.mu.Unlock() + return nil + + case *ReceiveBufferSizeOption: + p.mu.Lock() + *v = p.recvBufferSize + p.mu.Unlock() + return nil + case *CongestionControlOption: + p.mu.Lock() + *v = CongestionControlOption(p.congestionControl) + p.mu.Unlock() + return nil + case *AvailableCongestionControlOption: + p.mu.Lock() + *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + p.mu.Unlock() + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func init() { + stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { + return &protocol{ + sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, + recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, + congestionControl: ccReno, + availableCongestionControl: []string{ccReno, ccCubic}, + } + }) +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go new file mode 100644 index 000000000..b08a0e356 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -0,0 +1,221 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "container/heap" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" +) + +// receiver holds the state necessary to receive TCP segments and turn them +// into a stream of bytes. +// +// +stateify savable +type receiver struct { + ep *endpoint + + rcvNxt seqnum.Value + + // rcvAcc is one beyond the last acceptable sequence number. That is, + // the "largest" sequence value that the receiver has announced to the + // its peer that it's willing to accept. This may be different than + // rcvNxt + rcvWnd if the receive window is reduced; in that case we + // have to reduce the window as we receive more data instead of + // shrinking it. + rcvAcc seqnum.Value + + rcvWndScale uint8 + + closed bool + + pendingRcvdSegments segmentHeap + pendingBufUsed seqnum.Size + pendingBufSize seqnum.Size +} + +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { + return &receiver{ + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWndScale: rcvWndScale, + pendingBufSize: rcvWnd, + } +} + +// acceptable checks if the segment sequence number range is acceptable +// according to the table on page 26 of RFC 793. +func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { + rcvWnd := r.rcvNxt.Size(r.rcvAcc) + if rcvWnd == 0 { + return segLen == 0 && segSeq == r.rcvNxt + } + + return segSeq.InWindow(r.rcvNxt, rcvWnd) || + seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen) +} + +// getSendParams returns the parameters needed by the sender when building +// segments to send. +func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { + // Calculate the window size based on the current buffer size. + n := r.ep.receiveBufferAvailable() + acc := r.rcvNxt.Add(seqnum.Size(n)) + if r.rcvAcc.LessThan(acc) { + r.rcvAcc = acc + } + + return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale +} + +// nonZeroWindow is called when the receive window grows from zero to nonzero; +// in such cases we may need to send an ack to indicate to our peer that it can +// resume sending data. +func (r *receiver) nonZeroWindow() { + if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 { + // We never got around to announcing a zero window size, so we + // don't need to immediately announce a nonzero one. + return + } + + // Immediately send an ack. + r.ep.snd.sendAck() +} + +// consumeSegment attempts to consume a segment that was received by r. The +// segment may have just been received or may have been received earlier but +// wasn't ready to be consumed then. +// +// Returns true if the segment was consumed, false if it cannot be consumed +// yet because of a missing segment. +func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool { + if segLen > 0 { + // If the segment doesn't include the seqnum we're expecting to + // consume now, we're missing a segment. We cannot proceed until + // we receive that segment though. + if !r.rcvNxt.InWindow(segSeq, segLen) { + return false + } + + // Trim segment to eliminate already acknowledged data. + if segSeq.LessThan(r.rcvNxt) { + diff := segSeq.Size(r.rcvNxt) + segLen -= diff + segSeq.UpdateForward(diff) + s.sequenceNumber.UpdateForward(diff) + s.data.TrimFront(int(diff)) + } + + // Move segment to ready-to-deliver list. Wakeup any waiters. + r.ep.readyToRead(s) + + } else if segSeq != r.rcvNxt { + return false + } + + // Update the segment that we're expecting to consume. + r.rcvNxt = segSeq.Add(segLen) + + // Trim SACK Blocks to remove any SACK information that covers + // sequence numbers that have been consumed. + TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + + if s.flagIsSet(header.TCPFlagFin) { + r.rcvNxt++ + + // Send ACK immediately. + r.ep.snd.sendAck() + + // Tell any readers that no more data will come. + r.closed = true + r.ep.readyToRead(nil) + + // Flush out any pending segments, except the very first one if + // it happens to be the one we're handling now because the + // caller is using it. + first := 0 + if len(r.pendingRcvdSegments) != 0 && r.pendingRcvdSegments[0] == s { + first = 1 + } + + for i := first; i < len(r.pendingRcvdSegments); i++ { + r.pendingRcvdSegments[i].decRef() + } + r.pendingRcvdSegments = r.pendingRcvdSegments[:first] + } + + return true +} + +// handleRcvdSegment handles TCP segments directed at the connection managed by +// r as they arrive. It is called by the protocol main loop. +func (r *receiver) handleRcvdSegment(s *segment) { + // We don't care about receive processing anymore if the receive side + // is closed. + if r.closed { + return + } + + segLen := seqnum.Size(s.data.Size()) + segSeq := s.sequenceNumber + + // If the sequence number range is outside the acceptable range, just + // send an ACK. This is according to RFC 793, page 37. + if !r.acceptable(segSeq, segLen) { + r.ep.snd.sendAck() + return + } + + // Defer segment processing if it can't be consumed now. + if !r.consumeSegment(s, segSeq, segLen) { + if segLen > 0 || s.flagIsSet(header.TCPFlagFin) { + // We only store the segment if it's within our buffer + // size limit. + if r.pendingBufUsed < r.pendingBufSize { + r.pendingBufUsed += s.logicalLen() + s.incRef() + heap.Push(&r.pendingRcvdSegments, s) + } + + UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) + + // Immediately send an ack so that the peer knows it may + // have to retransmit. + r.ep.snd.sendAck() + } + return + } + + // By consuming the current segment, we may have filled a gap in the + // sequence number domain that allows pending segments to be consumed + // now. So try to do it. + for !r.closed && r.pendingRcvdSegments.Len() > 0 { + s := r.pendingRcvdSegments[0] + segLen := seqnum.Size(s.data.Size()) + segSeq := s.sequenceNumber + + // Skip segment altogether if it has already been acknowledged. + if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) && + !r.consumeSegment(s, segSeq, segLen) { + break + } + + heap.Pop(&r.pendingRcvdSegments) + r.pendingBufUsed -= s.logicalLen() + s.decRef() + } +} diff --git a/pkg/tcpip/transport/tcp/reno.go b/pkg/tcpip/transport/tcp/reno.go new file mode 100644 index 000000000..f83ebc717 --- /dev/null +++ b/pkg/tcpip/transport/tcp/reno.go @@ -0,0 +1,103 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +// renoState stores the variables related to TCP New Reno congestion +// control algorithm. +// +// +stateify savable +type renoState struct { + s *sender +} + +// newRenoCC initializes the state for the NewReno congestion control algorithm. +func newRenoCC(s *sender) *renoState { + return &renoState{s: s} +} + +// updateSlowStart will update the congestion window as per the slow-start +// algorithm used by NewReno. If after adjusting the congestion window +// we cross the SSthreshold then it will return the number of packets that +// must be consumed in congestion avoidance mode. +func (r *renoState) updateSlowStart(packetsAcked int) int { + // Don't let the congestion window cross into the congestion + // avoidance range. + newcwnd := r.s.sndCwnd + packetsAcked + if newcwnd >= r.s.sndSsthresh { + newcwnd = r.s.sndSsthresh + r.s.sndCAAckCount = 0 + } + + packetsAcked -= newcwnd - r.s.sndCwnd + r.s.sndCwnd = newcwnd + return packetsAcked +} + +// updateCongestionAvoidance will update congestion window in congestion +// avoidance mode as described in RFC5681 section 3.1 +func (r *renoState) updateCongestionAvoidance(packetsAcked int) { + // Consume the packets in congestion avoidance mode. + r.s.sndCAAckCount += packetsAcked + if r.s.sndCAAckCount >= r.s.sndCwnd { + r.s.sndCwnd += r.s.sndCAAckCount / r.s.sndCwnd + r.s.sndCAAckCount = r.s.sndCAAckCount % r.s.sndCwnd + } +} + +// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, +// page 6, eq. 4. It is called when we detect congestion in the network. +func (r *renoState) reduceSlowStartThreshold() { + r.s.sndSsthresh = r.s.outstanding / 2 + if r.s.sndSsthresh < 2 { + r.s.sndSsthresh = 2 + } + +} + +// Update updates the congestion state based on the number of packets that +// were acknowledged. +// Update implements congestionControl.Update. +func (r *renoState) Update(packetsAcked int) { + if r.s.sndCwnd < r.s.sndSsthresh { + packetsAcked = r.updateSlowStart(packetsAcked) + if packetsAcked == 0 { + return + } + } + r.updateCongestionAvoidance(packetsAcked) +} + +// HandleNDupAcks implements congestionControl.HandleNDupAcks. +func (r *renoState) HandleNDupAcks() { + // A retransmit was triggered due to nDupAckThreshold + // being hit. Reduce our slow start threshold. + r.reduceSlowStartThreshold() +} + +// HandleRTOExpired implements congestionControl.HandleRTOExpired. +func (r *renoState) HandleRTOExpired() { + // We lost a packet, so reduce ssthresh. + r.reduceSlowStartThreshold() + + // Reduce the congestion window to 1, i.e., enter slow-start. Per + // RFC 5681, page 7, we must use 1 regardless of the value of the + // initial congestion window. + r.s.sndCwnd = 1 +} + +// PostRecovery implements congestionControl.PostRecovery. +func (r *renoState) PostRecovery() { + // noop. +} diff --git a/pkg/tcpip/transport/tcp/sack.go b/pkg/tcpip/transport/tcp/sack.go new file mode 100644 index 000000000..6a013d99b --- /dev/null +++ b/pkg/tcpip/transport/tcp/sack.go @@ -0,0 +1,99 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" +) + +const ( + // MaxSACKBlocks is the maximum number of SACK blocks stored + // at receiver side. + MaxSACKBlocks = 6 +) + +// UpdateSACKBlocks updates the list of SACK blocks to include the segment +// specified by segStart->segEnd. If the segment happens to be an out of order +// delivery then the first block in the sack.blocks always includes the +// segment identified by segStart->segEnd. +func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) { + newSB := header.SACKBlock{Start: segStart, End: segEnd} + if sack.NumBlocks == 0 { + sack.Blocks[0] = newSB + sack.NumBlocks = 1 + return + } + var n = 0 + for i := 0; i < sack.NumBlocks; i++ { + start, end := sack.Blocks[i].Start, sack.Blocks[i].End + if end.LessThanEq(start) || start.LessThanEq(rcvNxt) { + // Discard any invalid blocks where end is before start + // and discard any sack blocks that are before rcvNxt as + // those have already been acked. + continue + } + if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) { + // Merge this SACK block into newSB and discard this SACK + // block. + if start.LessThan(newSB.Start) { + newSB.Start = start + } + if newSB.End.LessThan(end) { + newSB.End = end + } + } else { + // Save this block. + sack.Blocks[n] = sack.Blocks[i] + n++ + } + } + if rcvNxt.LessThan(newSB.Start) { + // If this was an out of order segment then make sure that the + // first SACK block is the one that includes the segment. + // + // See the first bullet point in + // https://tools.ietf.org/html/rfc2018#section-4 + if n == MaxSACKBlocks { + // If the number of SACK blocks is equal to + // MaxSACKBlocks then discard the last SACK block. + n-- + } + for i := n - 1; i >= 0; i-- { + sack.Blocks[i+1] = sack.Blocks[i] + } + sack.Blocks[0] = newSB + n++ + } + sack.NumBlocks = n +} + +// TrimSACKBlockList updates the sack block list by removing/modifying any block +// where start is < rcvNxt. +func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) { + n := 0 + for i := 0; i < sack.NumBlocks; i++ { + if sack.Blocks[i].End.LessThanEq(rcvNxt) { + continue + } + if sack.Blocks[i].Start.LessThan(rcvNxt) { + // Shrink this SACK block. + sack.Blocks[i].Start = rcvNxt + } + sack.Blocks[n] = sack.Blocks[i] + n++ + } + sack.NumBlocks = n +} diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go new file mode 100644 index 000000000..1c5766a42 --- /dev/null +++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go @@ -0,0 +1,306 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "fmt" + "strings" + + "github.com/google/btree" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" +) + +const ( + // maxSACKBlocks is the maximum number of distinct SACKBlocks the + // scoreboard will track. Once there are 100 distinct blocks, new + // insertions will fail. + maxSACKBlocks = 100 + + // defaultBtreeDegree is set to 2 as btree.New(2) results in a 2-3-4 + // tree. + defaultBtreeDegree = 2 +) + +// SACKScoreboard stores a set of disjoint SACK ranges. +// +// +stateify savable +type SACKScoreboard struct { + // smss is defined in RFC5681 as following: + // + // The SMSS is the size of the largest segment that the sender can + // transmit. This value can be based on the maximum transmission unit + // of the network, the path MTU discovery [RFC1191, RFC4821] algorithm, + // RMSS (see next item), or other factors. The size does not include + // the TCP/IP headers and options. + smss uint16 + maxSACKED seqnum.Value + sacked seqnum.Size `state:"nosave"` + ranges *btree.BTree `state:"nosave"` +} + +// NewSACKScoreboard returns a new SACK Scoreboard. +func NewSACKScoreboard(smss uint16, iss seqnum.Value) *SACKScoreboard { + return &SACKScoreboard{ + smss: smss, + ranges: btree.New(defaultBtreeDegree), + maxSACKED: iss, + } +} + +// Reset erases all known range information from the SACK scoreboard. +func (s *SACKScoreboard) Reset() { + s.ranges = btree.New(defaultBtreeDegree) + s.sacked = 0 +} + +// Insert inserts/merges the provided SACKBlock into the scoreboard. +func (s *SACKScoreboard) Insert(r header.SACKBlock) { + if s.ranges.Len() >= maxSACKBlocks { + return + } + + // Check if we can merge the new range with a range before or after it. + var toDelete []btree.Item + if s.maxSACKED.LessThan(r.End - 1) { + s.maxSACKED = r.End - 1 + } + s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool { + if i == r { + return true + } + sacked := i.(header.SACKBlock) + // There is a hole between these two SACK blocks, so we can't + // merge anymore. + if r.End.LessThan(sacked.Start) { + return false + } + // There is some overlap at this point, merge the blocks and + // delete the other one. + // + // ----sS--------sE + // r.S---------------rE + // -------sE + if sacked.End.LessThan(r.End) { + // sacked is contained in the newly inserted range. + // Delete this block. + toDelete = append(toDelete, i) + return true + } + // sacked covers a range past end of the newly inserted + // block. + r.End = sacked.End + toDelete = append(toDelete, i) + return true + }) + + s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool { + if i == r { + return true + } + sacked := i.(header.SACKBlock) + // sA------sE + // rA----rE + if sacked.End.LessThan(r.Start) { + return false + } + // The previous range extends into the current block. Merge it + // into the newly inserted range and delete the other one. + // + // <-rA---rE----<---rE---> + // sA--------------sE + r.Start = sacked.Start + // Extend r to cover sacked if sacked extends past r. + if r.End.LessThan(sacked.End) { + r.End = sacked.End + } + toDelete = append(toDelete, i) + return true + }) + for _, i := range toDelete { + if sb := s.ranges.Delete(i); sb != nil { + sb := i.(header.SACKBlock) + s.sacked -= sb.Start.Size(sb.End) + } + } + + replaced := s.ranges.ReplaceOrInsert(r) + if replaced == nil { + s.sacked += r.Start.Size(r.End) + } +} + +// IsSACKED returns true if the a given range of sequence numbers denoted by r +// are already covered by SACK information in the scoreboard. +func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool { + if s.Empty() { + return false + } + + found := false + s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool { + sacked := i.(header.SACKBlock) + if sacked.End.LessThan(r.Start) { + return false + } + if sacked.Contains(r) { + found = true + return false + } + return true + }) + return found +} + +// Dump prints the state of the scoreboard structure. +func (s *SACKScoreboard) String() string { + var str strings.Builder + str.WriteString("SACKScoreboard: {") + s.ranges.Ascend(func(i btree.Item) bool { + str.WriteString(fmt.Sprintf("%v,", i)) + return true + }) + str.WriteString("}\n") + return str.String() +} + +// Delete removes all SACK information prior to seq. +func (s *SACKScoreboard) Delete(seq seqnum.Value) { + if s.Empty() { + return + } + toDelete := []btree.Item{} + toInsert := []btree.Item{} + r := header.SACKBlock{seq, seq.Add(1)} + s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool { + if i == r { + return true + } + sb := i.(header.SACKBlock) + toDelete = append(toDelete, i) + if sb.End.LessThanEq(seq) { + s.sacked -= sb.Start.Size(sb.End) + } else { + newSB := header.SACKBlock{seq, sb.End} + toInsert = append(toInsert, newSB) + s.sacked -= sb.Start.Size(seq) + } + return true + }) + for _, sb := range toDelete { + s.ranges.Delete(sb) + } + for _, sb := range toInsert { + s.ranges.ReplaceOrInsert(sb) + } +} + +// Copy provides a copy of the SACK scoreboard. +func (s *SACKScoreboard) Copy() (sackBlocks []header.SACKBlock, maxSACKED seqnum.Value) { + s.ranges.Ascend(func(i btree.Item) bool { + sackBlocks = append(sackBlocks, i.(header.SACKBlock)) + return true + }) + return sackBlocks, s.maxSACKED +} + +// IsRangeLost implements the IsLost(SeqNum) operation defined in RFC 6675 +// section 4 but operates on a range of sequence numbers and returns true if +// there are at least nDupAckThreshold SACK blocks greater than the range being +// checked or if at least (nDupAckThreshold-1)*s.smss bytes have been SACKED +// with sequence numbers greater than the block being checked. +func (s *SACKScoreboard) IsRangeLost(r header.SACKBlock) bool { + if s.Empty() { + return false + } + nDupSACK := 0 + nDupSACKBytes := seqnum.Size(0) + isLost := false + + // We need to check if the immediate lower (if any) sacked + // range contains or partially overlaps with r. + searchMore := true + s.ranges.DescendLessOrEqual(r, func(i btree.Item) bool { + sacked := i.(header.SACKBlock) + if sacked.Contains(r) { + searchMore = false + return false + } + if sacked.End.LessThanEq(r.Start) { + // all sequence numbers covered by sacked are below + // r so we continue searching. + return false + } + // There is a partial overlap. In this case we r.Start is + // between sacked.Start & sacked.End and r.End extends beyond + // sacked.End. + // Move r.Start to sacked.End and continuing searching blocks + // above r.Start. + r.Start = sacked.End + return false + }) + + if !searchMore { + return isLost + } + + s.ranges.AscendGreaterOrEqual(r, func(i btree.Item) bool { + sacked := i.(header.SACKBlock) + if sacked.Contains(r) { + return false + } + nDupSACKBytes += sacked.Start.Size(sacked.End) + nDupSACK++ + if nDupSACK >= nDupAckThreshold || nDupSACKBytes >= seqnum.Size((nDupAckThreshold-1)*s.smss) { + isLost = true + return false + } + return true + }) + return isLost +} + +// IsLost implements the IsLost(SeqNum) operation defined in RFC3517 section +// 4. +// +// This routine returns whether the given sequence number is considered to be +// lost. The routine returns true when either nDupAckThreshold discontiguous +// SACKed sequences have arrived above 'SeqNum' or (nDupAckThreshold * SMSS) +// bytes with sequence numbers greater than 'SeqNum' have been SACKed. +// Otherwise, the routine returns false. +func (s *SACKScoreboard) IsLost(seq seqnum.Value) bool { + return s.IsRangeLost(header.SACKBlock{seq, seq.Add(1)}) +} + +// Empty returns true if the SACK scoreboard has no entries, false otherwise. +func (s *SACKScoreboard) Empty() bool { + return s.ranges.Len() == 0 +} + +// Sacked returns the current number of bytes held in the SACK scoreboard. +func (s *SACKScoreboard) Sacked() seqnum.Size { + return s.sacked +} + +// MaxSACKED returns the highest sequence number ever inserted in the SACK +// scoreboard. +func (s *SACKScoreboard) MaxSACKED() seqnum.Value { + return s.maxSACKED +} + +// SMSS returns the sender's MSS as held by the SACK scoreboard. +func (s *SACKScoreboard) SMSS() uint16 { + return s.smss +} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go new file mode 100644 index 000000000..450d9fbc1 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment.go @@ -0,0 +1,186 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "sync/atomic" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// segment represents a TCP segment. It holds the payload and parsed TCP segment +// information, and can be added to intrusive lists. +// segment is mostly immutable, the only field allowed to change is viewToDeliver. +// +// +stateify savable +type segment struct { + segmentEntry + refCnt int32 + id stack.TransportEndpointID `state:"manual"` + route stack.Route `state:"manual"` + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // views is used as buffer for data when its length is large + // enough to store a VectorisedView. + views [8]buffer.View `state:"nosave"` + // viewToDeliver keeps track of the next View that should be + // delivered by the Read endpoint. + viewToDeliver int + sequenceNumber seqnum.Value + ackNumber seqnum.Value + flags uint8 + window seqnum.Size + // csum is only populated for received segments. + csum uint16 + // csumValid is true if the csum in the received segment is valid. + csumValid bool + + // parsedOptions stores the parsed values from the options in the segment. + parsedOptions header.TCPOptions + options []byte `state:".([]byte)"` + hasNewSACKInfo bool + rcvdTime time.Time `state:".(unixTime)"` + // xmitTime is the last transmit time of this segment. A zero value + // indicates that the segment has yet to be transmitted. + xmitTime time.Time `state:".(unixTime)"` +} + +func newSegment(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) *segment { + s := &segment{ + refCnt: 1, + id: id, + route: r.Clone(), + } + s.data = vv.Clone(s.views[:]) + s.rcvdTime = time.Now() + return s +} + +func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { + s := &segment{ + refCnt: 1, + id: id, + route: r.Clone(), + } + s.views[0] = v + s.data = buffer.NewVectorisedView(len(v), s.views[:1]) + s.rcvdTime = time.Now() + return s +} + +func (s *segment) clone() *segment { + t := &segment{ + refCnt: 1, + id: s.id, + sequenceNumber: s.sequenceNumber, + ackNumber: s.ackNumber, + flags: s.flags, + window: s.window, + route: s.route.Clone(), + viewToDeliver: s.viewToDeliver, + rcvdTime: s.rcvdTime, + } + t.data = s.data.Clone(t.views[:]) + return t +} + +func (s *segment) flagIsSet(flag uint8) bool { + return (s.flags & flag) != 0 +} + +func (s *segment) decRef() { + if atomic.AddInt32(&s.refCnt, -1) == 0 { + s.route.Release() + } +} + +func (s *segment) incRef() { + atomic.AddInt32(&s.refCnt, 1) +} + +// logicalLen is the segment length in the sequence number space. It's defined +// as the data length plus one for each of the SYN and FIN bits set. +func (s *segment) logicalLen() seqnum.Size { + l := seqnum.Size(s.data.Size()) + if s.flagIsSet(header.TCPFlagSyn) { + l++ + } + if s.flagIsSet(header.TCPFlagFin) { + l++ + } + return l +} + +// parse populates the sequence & ack numbers, flags, and window fields of the +// segment from the TCP header stored in the data. It then updates the view to +// skip the header. +// +// Returns boolean indicating if the parsing was successful. +// +// If checksum verification is not offloaded then parse also verifies the +// TCP checksum and stores the checksum and result of checksum verification in +// the csum and csumValid fields of the segment. +func (s *segment) parse() bool { + h := header.TCP(s.data.First()) + + // h is the header followed by the payload. We check that the offset to + // the data respects the following constraints: + // 1. That it's at least the minimum header size; if we don't do this + // then part of the header would be delivered to user. + // 2. That the header fits within the buffer; if we don't do this, we + // would panic when we tried to access data beyond the buffer. + // + // N.B. The segment has already been validated as having at least the + // minimum TCP size before reaching here, so it's safe to read the + // fields. + offset := int(h.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(h) { + return false + } + + s.options = []byte(h[header.TCPMinimumSize:offset]) + s.parsedOptions = header.ParseTCPOptions(s.options) + + // Query the link capabilities to decide if checksum validation is + // required. + verifyChecksum := true + if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { + s.csumValid = true + verifyChecksum = false + s.data.TrimFront(offset) + } + if verifyChecksum { + s.csum = h.Checksum() + xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size())) + xsum = h.CalculateChecksum(xsum) + s.data.TrimFront(offset) + xsum = header.ChecksumVV(s.data, xsum) + s.csumValid = xsum == 0xffff + } + + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() + s.window = seqnum.Size(h.WindowSize()) + return true +} + +// sackBlock returns a header.SACKBlock that represents this segment. +func (s *segment) sackBlock() header.SACKBlock { + return header.SACKBlock{s.sequenceNumber, s.sequenceNumber.Add(s.logicalLen())} +} diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go new file mode 100644 index 000000000..9fd061d7d --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_heap.go @@ -0,0 +1,46 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +type segmentHeap []*segment + +// Len returns the length of h. +func (h segmentHeap) Len() int { + return len(h) +} + +// Less determines whether the i-th element of h is less than the j-th element. +func (h segmentHeap) Less(i, j int) bool { + return h[i].sequenceNumber.LessThan(h[j].sequenceNumber) +} + +// Swap swaps the i-th and j-th elements of h. +func (h segmentHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +// Push adds x as the last element of h. +func (h *segmentHeap) Push(x interface{}) { + *h = append(*h, x.(*segment)) +} + +// Pop removes the last element of h and returns it. +func (h *segmentHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[:n-1] + return x +} diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go new file mode 100644 index 000000000..e0759225e --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -0,0 +1,79 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "sync" +) + +// segmentQueue is a bounded, thread-safe queue of TCP segments. +// +// +stateify savable +type segmentQueue struct { + mu sync.Mutex `state:"nosave"` + list segmentList `state:"wait"` + limit int + used int +} + +// empty determines if the queue is empty. +func (q *segmentQueue) empty() bool { + q.mu.Lock() + r := q.used == 0 + q.mu.Unlock() + + return r +} + +// setLimit updates the limit. No segments are immediately dropped in case the +// queue becomes full due to the new limit. +func (q *segmentQueue) setLimit(limit int) { + q.mu.Lock() + q.limit = limit + q.mu.Unlock() +} + +// enqueue adds the given segment to the queue. +// +// Returns true when the segment is successfully added to the queue, in which +// case ownership of the reference is transferred to the queue. And returns +// false if the queue is full, in which case ownership is retained by the +// caller. +func (q *segmentQueue) enqueue(s *segment) bool { + q.mu.Lock() + r := q.used < q.limit + if r { + q.list.PushBack(s) + q.used++ + } + q.mu.Unlock() + + return r +} + +// dequeue removes and returns the next segment from queue, if one exists. +// Ownership is transferred to the caller, who is responsible for decrementing +// the ref count when done. +func (q *segmentQueue) dequeue() *segment { + q.mu.Lock() + s := q.list.Front() + if s != nil { + q.list.Remove(s) + q.used-- + } + q.mu.Unlock() + + return s +} diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go new file mode 100644 index 000000000..dd7e14aa6 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_state.go @@ -0,0 +1,82 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +// saveData is invoked by stateify. +func (s *segment) saveData() buffer.VectorisedView { + // We cannot save s.data directly as s.data.views may alias to s.views, + // which is not allowed by state framework (in-struct pointer). + v := make([]buffer.View, len(s.data.Views())) + // For views already delivered, we cannot save them directly as they may + // have already been sliced and saved elsewhere (e.g., readViews). + for i := 0; i < s.viewToDeliver; i++ { + v[i] = append([]byte(nil), s.data.Views()[i]...) + } + for i := s.viewToDeliver; i < len(v); i++ { + v[i] = s.data.Views()[i] + } + return buffer.NewVectorisedView(s.data.Size(), v) +} + +// loadData is invoked by stateify. +func (s *segment) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the s.data = data.Clone(s.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing s.views for data.views. + s.data = data +} + +// saveOptions is invoked by stateify. +func (s *segment) saveOptions() []byte { + // We cannot save s.options directly as it may point to s.data's trimmed + // tail, which is not allowed by state framework (in-struct pointer). + b := make([]byte, 0, cap(s.options)) + return append(b, s.options...) +} + +// loadOptions is invoked by stateify. +func (s *segment) loadOptions(options []byte) { + // NOTE: We cannot point s.options back into s.data's trimmed tail. But + // it is OK as they do not need to aliased. Plus, options is already + // allocated so there is no cost here. + s.options = options +} + +// saveRcvdTime is invoked by stateify. +func (s *segment) saveRcvdTime() unixTime { + return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()} +} + +// loadRcvdTime is invoked by stateify. +func (s *segment) loadRcvdTime(unix unixTime) { + s.rcvdTime = time.Unix(unix.second, unix.nano) +} + +// saveXmitTime is invoked by stateify. +func (s *segment) saveXmitTime() unixTime { + return unixTime{s.rcvdTime.Unix(), s.rcvdTime.UnixNano()} +} + +// loadXmitTime is invoked by stateify. +func (s *segment) loadXmitTime(unix unixTime) { + s.rcvdTime = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go new file mode 100644 index 000000000..afc1d0a55 --- /dev/null +++ b/pkg/tcpip/transport/tcp/snd.go @@ -0,0 +1,1180 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "math" + "sync" + "sync/atomic" + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" +) + +const ( + // minRTO is the minimum allowed value for the retransmit timeout. + minRTO = 200 * time.Millisecond + + // InitialCwnd is the initial congestion window. + InitialCwnd = 10 + + // nDupAckThreshold is the number of duplicate ACK's required + // before fast-retransmit is entered. + nDupAckThreshold = 3 +) + +// congestionControl is an interface that must be implemented by any supported +// congestion control algorithm. +type congestionControl interface { + // HandleNDupAcks is invoked when sender.dupAckCount >= nDupAckThreshold + // just before entering fast retransmit. + HandleNDupAcks() + + // HandleRTOExpired is invoked when the retransmit timer expires. + HandleRTOExpired() + + // Update is invoked when processing inbound acks. It's passed the + // number of packet's that were acked by the most recent cumulative + // acknowledgement. + Update(packetsAcked int) + + // PostRecovery is invoked when the sender is exiting a fast retransmit/ + // recovery phase. This provides congestion control algorithms a way + // to adjust their state when exiting recovery. + PostRecovery() +} + +// sender holds the state necessary to send TCP segments. +// +// +stateify savable +type sender struct { + ep *endpoint + + // lastSendTime is the timestamp when the last packet was sent. + lastSendTime time.Time `state:".(unixTime)"` + + // dupAckCount is the number of duplicated acks received. It is used for + // fast retransmit. + dupAckCount int + + // fr holds state related to fast recovery. + fr fastRecovery + + // sndCwnd is the congestion window, in packets. + sndCwnd int + + // sndSsthresh is the threshold between slow start and congestion + // avoidance. + sndSsthresh int + + // sndCAAckCount is the number of packets acknowledged during congestion + // avoidance. When enough packets have been ack'd (typically cwnd + // packets), the congestion window is incremented by one. + sndCAAckCount int + + // outstanding is the number of outstanding packets, that is, packets + // that have been sent but not yet acknowledged. + outstanding int + + // sndWnd is the send window size. + sndWnd seqnum.Size + + // sndUna is the next unacknowledged sequence number. + sndUna seqnum.Value + + // sndNxt is the sequence number of the next segment to be sent. + sndNxt seqnum.Value + + // sndNxtList is the sequence number of the next segment to be added to + // the send list. + sndNxtList seqnum.Value + + // rttMeasureSeqNum is the sequence number being used for the latest RTT + // measurement. + rttMeasureSeqNum seqnum.Value + + // rttMeasureTime is the time when the rttMeasureSeqNum was sent. + rttMeasureTime time.Time `state:".(unixTime)"` + + closed bool + writeNext *segment + writeList segmentList + resendTimer timer `state:"nosave"` + resendWaker sleep.Waker `state:"nosave"` + + // rtt.srtt, rtt.rttvar, and rto are the "smoothed round-trip time", + // "round-trip time variation" and "retransmit timeout", as defined in + // section 2 of RFC 6298. + rtt rtt + rto time.Duration + srttInited bool + + // maxPayloadSize is the maximum size of the payload of a given segment. + // It is initialized on demand. + maxPayloadSize int + + // gso is set if generic segmentation offload is enabled. + gso bool + + // sndWndScale is the number of bits to shift left when reading the send + // window size from a segment. + sndWndScale uint8 + + // maxSentAck is the maxium acknowledgement actually sent. + maxSentAck seqnum.Value + + // cc is the congestion control algorithm in use for this sender. + cc congestionControl +} + +// rtt is a synchronization wrapper used to appease stateify. See the comment +// in sender, where it is used. +// +// +stateify savable +type rtt struct { + sync.Mutex `state:"nosave"` + + srtt time.Duration + rttvar time.Duration +} + +// fastRecovery holds information related to fast recovery from a packet loss. +// +// +stateify savable +type fastRecovery struct { + // active whether the endpoint is in fast recovery. The following fields + // are only meaningful when active is true. + active bool + + // first and last represent the inclusive sequence number range being + // recovered. + first seqnum.Value + last seqnum.Value + + // maxCwnd is the maximum value the congestion window may be inflated to + // due to duplicate acks. This exists to avoid attacks where the + // receiver intentionally sends duplicate acks to artificially inflate + // the sender's cwnd. + maxCwnd int + + // highRxt is the highest sequence number which has been retransmitted + // during the current loss recovery phase. + // See: RFC 6675 Section 2 for details. + highRxt seqnum.Value + + // rescueRxt is the highest sequence number which has been + // optimistically retransmitted to prevent stalling of the ACK clock + // when there is loss at the end of the window and no new data is + // available for transmission. + // See: RFC 6675 Section 2 for details. + rescueRxt seqnum.Value +} + +func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { + // The sender MUST reduce the TCP data length to account for any IP or + // TCP options that it is including in the packets that it sends. + // See: https://tools.ietf.org/html/rfc6691#section-2 + maxPayloadSize := int(mss) - ep.maxOptionSize() + + s := &sender{ + ep: ep, + sndCwnd: InitialCwnd, + sndSsthresh: math.MaxInt64, + sndWnd: sndWnd, + sndUna: iss + 1, + sndNxt: iss + 1, + sndNxtList: iss + 1, + rto: 1 * time.Second, + rttMeasureSeqNum: iss + 1, + lastSendTime: time.Now(), + maxPayloadSize: maxPayloadSize, + maxSentAck: irs + 1, + fr: fastRecovery{ + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. + last: iss, + highRxt: iss, + rescueRxt: iss, + }, + gso: ep.gso != nil, + } + + if s.gso { + s.ep.gso.MSS = uint16(maxPayloadSize) + } + + s.cc = s.initCongestionControl(ep.cc) + + // A negative sndWndScale means that no scaling is in use, otherwise we + // store the scaling value. + if sndWndScale > 0 { + s.sndWndScale = uint8(sndWndScale) + } + + s.resendTimer.init(&s.resendWaker) + + s.updateMaxPayloadSize(int(ep.route.MTU()), 0) + + // Initialize SACK Scoreboard after updating max payload size as we use + // the maxPayloadSize as the smss when determining if a segment is lost + // etc. + s.ep.scoreboard = NewSACKScoreboard(uint16(s.maxPayloadSize), iss) + + return s +} + +func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl { + switch congestionControlName { + case ccCubic: + return newCubicCC(s) + case ccReno: + fallthrough + default: + return newRenoCC(s) + } +} + +// updateMaxPayloadSize updates the maximum payload size based on the given +// MTU. If this is in response to "packet too big" control packets (indicated +// by the count argument), it also reduces the number of outstanding packets and +// attempts to retransmit the first packet above the MTU size. +func (s *sender) updateMaxPayloadSize(mtu, count int) { + m := mtu - header.TCPMinimumSize + + m -= s.ep.maxOptionSize() + + // We don't adjust up for now. + if m >= s.maxPayloadSize { + return + } + + // Make sure we can transmit at least one byte. + if m <= 0 { + m = 1 + } + + s.maxPayloadSize = m + if s.gso { + s.ep.gso.MSS = uint16(m) + } + + if count == 0 { + // updateMaxPayloadSize is also called when the sender is created. + // and there is no data to send in such cases. Return immediately. + return + } + + // Update the scoreboard's smss to reflect the new lowered + // maxPayloadSize. + s.ep.scoreboard.smss = uint16(m) + + s.outstanding -= count + if s.outstanding < 0 { + s.outstanding = 0 + } + + // Rewind writeNext to the first segment exceeding the MTU. Do nothing + // if it is already before such a packet. + for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { + if seg == s.writeNext { + // We got to writeNext before we could find a segment + // exceeding the MTU. + break + } + + if seg.data.Size() > m { + // We found a segment exceeding the MTU. Rewind + // writeNext and try to retransmit it. + s.writeNext = seg + break + } + } + + // Since we likely reduced the number of outstanding packets, we may be + // ready to send some more. + s.sendData() +} + +// sendAck sends an ACK segment. +func (s *sender) sendAck() { + s.sendSegmentFromView(buffer.VectorisedView{}, header.TCPFlagAck, s.sndNxt) +} + +// updateRTO updates the retransmit timeout when a new roud-trip time is +// available. This is done in accordance with section 2 of RFC 6298. +func (s *sender) updateRTO(rtt time.Duration) { + s.rtt.Lock() + if !s.srttInited { + s.rtt.rttvar = rtt / 2 + s.rtt.srtt = rtt + s.srttInited = true + } else { + diff := s.rtt.srtt - rtt + if diff < 0 { + diff = -diff + } + // Use RFC6298 standard algorithm to update rttvar and srtt when + // no timestamps are available. + if !s.ep.sendTSOk { + s.rtt.rttvar = (3*s.rtt.rttvar + diff) / 4 + s.rtt.srtt = (7*s.rtt.srtt + rtt) / 8 + } else { + // When we are taking RTT measurements of every ACK then + // we need to use a modified method as specified in + // https://tools.ietf.org/html/rfc7323#appendix-G + if s.outstanding == 0 { + s.rtt.Unlock() + return + } + // Netstack measures congestion window/inflight all in + // terms of packets and not bytes. This is similar to + // how linux also does cwnd and inflight. In practice + // this approximation works as expected. + expectedSamples := math.Ceil(float64(s.outstanding) / 2) + + // alpha & beta values are the original values as recommended in + // https://tools.ietf.org/html/rfc6298#section-2.3. + const alpha = 0.125 + const beta = 0.25 + + alphaPrime := alpha / expectedSamples + betaPrime := beta / expectedSamples + rttVar := (1-betaPrime)*s.rtt.rttvar.Seconds() + betaPrime*diff.Seconds() + srtt := (1-alphaPrime)*s.rtt.srtt.Seconds() + alphaPrime*rtt.Seconds() + s.rtt.rttvar = time.Duration(rttVar * float64(time.Second)) + s.rtt.srtt = time.Duration(srtt * float64(time.Second)) + } + } + + s.rto = s.rtt.srtt + 4*s.rtt.rttvar + s.rtt.Unlock() + if s.rto < minRTO { + s.rto = minRTO + } +} + +// resendSegment resends the first unacknowledged segment. +func (s *sender) resendSegment() { + // Don't use any segments we already sent to measure RTT as they may + // have been affected by packets being lost. + s.rttMeasureSeqNum = s.sndNxt + + // Resend the segment. + if seg := s.writeList.Front(); seg != nil { + if seg.data.Size() > s.maxPayloadSize { + s.splitSeg(seg, s.maxPayloadSize) + } + + // See: RFC 6675 section 5 Step 4.3 + // + // To prevent retransmission, set both the HighRXT and RescueRXT + // to the highest sequence number in the retransmitted segment. + s.fr.highRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.fr.rescueRxt = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) - 1 + s.sendSegment(seg) + s.ep.stack.Stats().TCP.FastRetransmit.Increment() + + // Run SetPipe() as per RFC 6675 section 5 Step 4.4 + s.SetPipe() + } +} + +// retransmitTimerExpired is called when the retransmit timer expires, and +// unacknowledged segments are assumed lost, and thus need to be resent. +// Returns true if the connection is still usable, or false if the connection +// is deemed lost. +func (s *sender) retransmitTimerExpired() bool { + // Check if the timer actually expired or if it's a spurious wake due + // to a previously orphaned runtime timer. + if !s.resendTimer.checkExpiration() { + return true + } + + s.ep.stack.Stats().TCP.Timeouts.Increment() + + // Give up if we've waited more than a minute since the last resend. + if s.rto >= 60*time.Second { + return false + } + + // Set new timeout. The timer will be restarted by the call to sendData + // below. + s.rto *= 2 + + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. + // + // Retransmit timeouts: + // After a retransmit timeout, record the highest sequence number + // transmitted in the variable recover, and exit the fast recovery + // procedure if applicable. + s.fr.last = s.sndNxt - 1 + + if s.fr.active { + // We were attempting fast recovery but were not successful. + // Leave the state. We don't need to update ssthresh because it + // has already been updated when entered fast-recovery. + s.leaveFastRecovery() + } + + s.cc.HandleRTOExpired() + + // Mark the next segment to be sent as the first unacknowledged one and + // start sending again. Set the number of outstanding packets to 0 so + // that we'll be able to retransmit. + // + // We'll keep on transmitting (or retransmitting) as we get acks for + // the data we transmit. + s.outstanding = 0 + + // Expunge all SACK information as per https://tools.ietf.org/html/rfc6675#section-5.1 + // + // In order to avoid memory deadlocks, the TCP receiver is allowed to + // discard data that has already been selectively acknowledged. As a + // result, [RFC2018] suggests that a TCP sender SHOULD expunge the SACK + // information gathered from a receiver upon a retransmission timeout + // (RTO) "since the timeout might indicate that the data receiver has + // reneged." Additionally, a TCP sender MUST "ignore prior SACK + // information in determining which data to retransmit." + // + // NOTE: We take the stricter interpretation and just expunge all + // information as we lack more rigorous checks to validate if the SACK + // information is usable after an RTO. + s.ep.scoreboard.Reset() + s.writeNext = s.writeList.Front() + s.sendData() + + return true +} + +// pCount returns the number of packets in the segment. Due to GSO, a segment +// can be composed of multiple packets. +func (s *sender) pCount(seg *segment) int { + size := seg.data.Size() + if size == 0 { + return 1 + } + + return (size-1)/s.maxPayloadSize + 1 +} + +// splitSeg splits a given segment at the size specified and inserts the +// remainder as a new segment after the current one in the write list. +func (s *sender) splitSeg(seg *segment, size int) { + if seg.data.Size() <= size { + return + } + // Split this segment up. + nSeg := seg.clone() + nSeg.data.TrimFront(size) + nSeg.sequenceNumber.UpdateForward(seqnum.Size(size)) + s.writeList.InsertAfter(seg, nSeg) + seg.data.CapLength(size) +} + +// NextSeg implements the RFC6675 NextSeg() operation. It returns segments that +// match rule 1, 3 and 4 of the NextSeg() operation defined in RFC6675. Rule 2 +// is handled by the normal send logic. +func (s *sender) NextSeg() (nextSeg1, nextSeg3, nextSeg4 *segment) { + var s3 *segment + var s4 *segment + smss := s.ep.scoreboard.SMSS() + // Step 1. + for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { + if !s.isAssignedSequenceNumber(seg) { + break + } + segSeq := seg.sequenceNumber + if seg.data.Size() > int(smss) { + s.splitSeg(seg, int(smss)) + } + // See RFC 6675 Section 4 + // + // 1. If there exists a smallest unSACKED sequence number + // 'S2' that meets the following 3 criteria for determinig + // loss, the sequence range of one segment of up to SMSS + // octects starting with S2 MUST be returned. + if !s.ep.scoreboard.IsSACKED(header.SACKBlock{segSeq, segSeq.Add(1)}) { + // NextSeg(): + // + // (1.a) S2 is greater than HighRxt + // (1.b) S2 is less than highest octect covered by + // any received SACK. + if s.fr.highRxt.LessThan(segSeq) && segSeq.LessThan(s.ep.scoreboard.maxSACKED) { + // NextSeg(): + // (1.c) IsLost(S2) returns true. + if s.ep.scoreboard.IsLost(segSeq) { + return seg, s3, s4 + } + // NextSeg(): + // + // (3): If the conditions for rules (1) and (2) + // fail, but there exists an unSACKed sequence + // number S3 that meets the criteria for + // detecting loss given in steps 1.a and 1.b + // above (specifically excluding (1.c)) then one + // segment of upto SMSS octets starting with S3 + // SHOULD be returned. + if s3 == nil { + s3 = seg + } + } + // NextSeg(): + // + // (4) If the conditions for (1), (2) and (3) fail, + // but there exists outstanding unSACKED data, we + // provide the opportunity for a single "rescue" + // retransmission per entry into loss recovery. If + // HighACK is greater than RescueRxt, the one + // segment of upto SMSS octects that MUST include + // the highest outstanding unSACKed sequence number + // SHOULD be returned. + if s.fr.rescueRxt.LessThan(s.sndUna - 1) { + if s4 != nil { + if s4.sequenceNumber.LessThan(segSeq) { + s4 = seg + } + } else { + s4 = seg + } + s.fr.rescueRxt = s.fr.last + } + } + } + + return nil, s3, s4 +} + +// maybeSendSegment tries to send the specified segment and either coalesces +// other segments into this one or splits the specified segment based on the +// lower of the specified limit value or the receivers window size specified by +// end. +func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (sent bool) { + // We abuse the flags field to determine if we have already + // assigned a sequence number to this segment. + if !s.isAssignedSequenceNumber(seg) { + // Merge segments if allowed. + if seg.data.Size() != 0 { + available := int(seg.sequenceNumber.Size(end)) + if available > limit { + available = limit + } + + // nextTooBig indicates that the next segment was too + // large to entirely fit in the current segment. It + // would be possible to split the next segment and merge + // the portion that fits, but unexpectedly splitting + // segments can have user visible side-effects which can + // break applications. For example, RFC 7766 section 8 + // says that the length and data of a DNS response + // should be sent in the same TCP segment to avoid + // triggering bugs in poorly written DNS + // implementations. + var nextTooBig bool + for seg.Next() != nil && seg.Next().data.Size() != 0 { + if seg.data.Size()+seg.Next().data.Size() > available { + nextTooBig = true + break + } + seg.data.Append(seg.Next().data) + + // Consume the segment that we just merged in. + s.writeList.Remove(seg.Next()) + } + if !nextTooBig && seg.data.Size() < available { + // Segment is not full. + if s.outstanding > 0 && atomic.LoadUint32(&s.ep.delay) != 0 { + // Nagle's algorithm. From Wikipedia: + // Nagle's algorithm works by + // combining a number of small + // outgoing messages and sending them + // all at once. Specifically, as long + // as there is a sent packet for which + // the sender has received no + // acknowledgment, the sender should + // keep buffering its output until it + // has a full packet's worth of + // output, thus allowing output to be + // sent all at once. + return false + } + if atomic.LoadUint32(&s.ep.cork) != 0 { + // Hold back the segment until full. + return false + } + } + } + + // Assign flags. We don't do it above so that we can merge + // additional data if Nagle holds the segment. + seg.sequenceNumber = s.sndNxt + seg.flags = header.TCPFlagAck | header.TCPFlagPsh + } + + var segEnd seqnum.Value + if seg.data.Size() == 0 { + if s.writeList.Back() != seg { + panic("FIN segments must be the final segment in the write list.") + } + seg.flags = header.TCPFlagAck | header.TCPFlagFin + segEnd = seg.sequenceNumber.Add(1) + } else { + // We're sending a non-FIN segment. + if seg.flags&header.TCPFlagFin != 0 { + panic("Netstack queues FIN segments without data.") + } + + if !seg.sequenceNumber.LessThan(end) { + return false + } + + available := int(seg.sequenceNumber.Size(end)) + if available == 0 { + return false + } + if available > limit { + available = limit + } + + if seg.data.Size() > available { + s.splitSeg(seg, available) + } + + segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + } + + s.sendSegment(seg) + + // Update sndNxt if we actually sent new data (as opposed to + // retransmitting some previously sent data). + if s.sndNxt.LessThan(segEnd) { + s.sndNxt = segEnd + } + + return true +} + +// handleSACKRecovery implements the loss recovery phase as described in RFC6675 +// section 5, step C. +func (s *sender) handleSACKRecovery(limit int, end seqnum.Value) (dataSent bool) { + s.SetPipe() + for s.outstanding < s.sndCwnd { + nextSeg, s3, s4 := s.NextSeg() + if nextSeg == nil { + // NextSeg(): + // + // Step (2): "If no sequence number 'S2' per rule (1) + // exists but there exists available unsent data and the + // receiver's advertised window allows, the sequence + // range of one segment of up to SMSS octets of + // previously unsent data starting with sequence number + // HighData+1 MUST be returned." + for seg := s.writeNext; seg != nil; seg = seg.Next() { + if s.isAssignedSequenceNumber(seg) && seg.sequenceNumber.LessThan(s.sndNxt) { + continue + } + // Step C.3 described below is handled by + // maybeSendSegment which increments sndNxt when + // a segment is transmitted. + // + // Step C.3 "If any of the data octets sent in + // (C.1) are above HighData, HighData must be + // updated to reflect the transmission of + // previously unsent data." + if sent := s.maybeSendSegment(seg, limit, end); !sent { + break + } + dataSent = true + s.outstanding++ + s.writeNext = seg.Next() + nextSeg = seg + break + } + if nextSeg != nil { + continue + } + } + rescueRtx := false + if nextSeg == nil && s3 != nil { + nextSeg = s3 + } + if nextSeg == nil && s4 != nil { + nextSeg = s4 + rescueRtx = true + } + if nextSeg == nil { + break + } + segEnd := nextSeg.sequenceNumber.Add(nextSeg.logicalLen()) + if !rescueRtx && nextSeg.sequenceNumber.LessThan(s.sndNxt) { + // RFC 6675, Step C.2 + // + // "If any of the data octets sent in (C.1) are below + // HighData, HighRxt MUST be set to the highest sequence + // number of the retransmitted segment unless NextSeg () + // rule (4) was invoked for this retransmission." + s.fr.highRxt = segEnd - 1 + } + + // RFC 6675, Step C.4. + // + // "The estimate of the amount of data outstanding in the network + // must be updated by incrementing pipe by the number of octets + // transmitted in (C.1)." + s.outstanding++ + dataSent = true + s.sendSegment(nextSeg) + } + return dataSent +} + +// sendData sends new data segments. It is called when data becomes available or +// when the send window opens up. +func (s *sender) sendData() { + limit := s.maxPayloadSize + if s.gso { + limit = int(s.ep.gso.MaxSize - header.TCPHeaderMaximumSize) + } + end := s.sndUna.Add(s.sndWnd) + + // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. + // "A TCP SHOULD set cwnd to no more than RW before beginning + // transmission if the TCP has not sent data in the interval exceeding + // the retrasmission timeout." + if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto { + if s.sndCwnd > InitialCwnd { + s.sndCwnd = InitialCwnd + } + } + + var dataSent bool + + // RFC 6675 recovery algorithm step C 1-5. + if s.fr.active && s.ep.sackPermitted { + dataSent = s.handleSACKRecovery(s.maxPayloadSize, end) + } else { + for seg := s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + cwndLimit := (s.sndCwnd - s.outstanding) * s.maxPayloadSize + if cwndLimit < limit { + limit = cwndLimit + } + if s.isAssignedSequenceNumber(seg) && s.ep.sackPermitted && s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + continue + } + if sent := s.maybeSendSegment(seg, limit, end); !sent { + break + } + dataSent = true + s.outstanding++ + s.writeNext = seg.Next() + } + } + + if dataSent { + // We sent data, so we should stop the keepalive timer to ensure + // that no keepalives are sent while there is pending data. + s.ep.disableKeepaliveTimer() + } + + // Enable the timer if we have pending data and it's not enabled yet. + if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { + s.resendTimer.enable(s.rto) + } + // If we have no more pending data, start the keepalive timer. + if s.sndUna == s.sndNxt { + s.ep.resetKeepaliveTimer(false) + } +} + +func (s *sender) enterFastRecovery() { + s.fr.active = true + // Save state to reflect we're now in fast recovery. + // + // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. + // We inflate the cwnd by 3 to account for the 3 packets which triggered + // the 3 duplicate ACKs and are now not in flight. + s.sndCwnd = s.sndSsthresh + 3 + s.fr.first = s.sndUna + s.fr.last = s.sndNxt - 1 + s.fr.maxCwnd = s.sndCwnd + s.outstanding + if s.ep.sackPermitted { + s.ep.stack.Stats().TCP.SACKRecovery.Increment() + return + } + s.ep.stack.Stats().TCP.FastRecovery.Increment() +} + +func (s *sender) leaveFastRecovery() { + s.fr.active = false + s.fr.maxCwnd = 0 + s.dupAckCount = 0 + + // Deflate cwnd. It had been artificially inflated when new dups arrived. + s.sndCwnd = s.sndSsthresh + + s.cc.PostRecovery() +} + +func (s *sender) handleFastRecovery(seg *segment) (rtx bool) { + ack := seg.ackNumber + // We are in fast recovery mode. Ignore the ack if it's out of + // range. + if !ack.InRange(s.sndUna, s.sndNxt+1) { + return false + } + + // Leave fast recovery if it acknowledges all the data covered by + // this fast recovery session. + if s.fr.last.LessThan(ack) { + s.leaveFastRecovery() + return false + } + + if s.ep.sackPermitted { + // When SACK is enabled we let retransmission be governed by + // the SACK logic. + return false + } + + // Don't count this as a duplicate if it is carrying data or + // updating the window. + if seg.logicalLen() != 0 || s.sndWnd != seg.window { + return false + } + + // Inflate the congestion window if we're getting duplicate acks + // for the packet we retransmitted. + if ack == s.fr.first { + // We received a dup, inflate the congestion window by 1 packet + // if we're not at the max yet. Only inflate the window if + // regular FastRecovery is in use, RFC6675 does not require + // inflating cwnd on duplicate ACKs. + if s.sndCwnd < s.fr.maxCwnd { + s.sndCwnd++ + } + return false + } + + // A partial ack was received. Retransmit this packet and + // remember it so that we don't retransmit it again. We don't + // inflate the window because we're putting the same packet back + // onto the wire. + // + // N.B. The retransmit timer will be reset by the caller. + s.fr.first = ack + s.dupAckCount = 0 + return true +} + +// isAssignedSequenceNumber relies on the fact that we only set flags once a +// sequencenumber is assigned and that is only done right before we send the +// segment. As a result any segment that has a non-zero flag has a valid +// sequence number assigned to it. +func (s *sender) isAssignedSequenceNumber(seg *segment) bool { + return seg.flags != 0 +} + +// SetPipe implements the SetPipe() function described in RFC6675. Netstack +// maintains the congestion window in number of packets and not bytes, so +// SetPipe() here measures number of outstanding packets rather than actual +// outstanding bytes in the network. +func (s *sender) SetPipe() { + // If SACK isn't permitted or it is permitted but recovery is not active + // then ignore pipe calculations. + if !s.ep.sackPermitted || !s.fr.active { + return + } + pipe := 0 + smss := seqnum.Size(s.ep.scoreboard.SMSS()) + for s1 := s.writeList.Front(); s1 != nil && s1.data.Size() != 0 && s.isAssignedSequenceNumber(s1); s1 = s1.Next() { + // With GSO each segment can be much larger than SMSS. So check the segment + // in SMSS sized ranges. + segEnd := s1.sequenceNumber.Add(seqnum.Size(s1.data.Size())) + for startSeq := s1.sequenceNumber; startSeq.LessThan(segEnd); startSeq = startSeq.Add(smss) { + endSeq := startSeq.Add(smss) + if segEnd.LessThan(endSeq) { + endSeq = segEnd + } + sb := header.SACKBlock{startSeq, endSeq} + // SetPipe(): + // + // After initializing pipe to zero, the following steps are + // taken for each octet 'S1' in the sequence space between + // HighACK and HighData that has not been SACKed: + if !s1.sequenceNumber.LessThan(s.sndNxt) { + break + } + if s.ep.scoreboard.IsSACKED(sb) { + continue + } + + // SetPipe(): + // + // (a) If IsLost(S1) returns false, Pipe is incremened by 1. + // + // NOTE: here we mark the whole segment as lost. We do not try + // and test every byte in our write buffer as we maintain our + // pipe in terms of oustanding packets and not bytes. + if !s.ep.scoreboard.IsRangeLost(sb) { + pipe++ + } + // SetPipe(): + // (b) If S1 <= HighRxt, Pipe is incremented by 1. + if s1.sequenceNumber.LessThanEq(s.fr.highRxt) { + pipe++ + } + } + } + s.outstanding = pipe +} + +// checkDuplicateAck is called when an ack is received. It manages the state +// related to duplicate acks and determines if a retransmit is needed according +// to the rules in RFC 6582 (NewReno). +func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { + ack := seg.ackNumber + if s.fr.active { + return s.handleFastRecovery(seg) + } + + // We're not in fast recovery yet. A segment is considered a duplicate + // only if it doesn't carry any data and doesn't update the send window, + // because if it does, it wasn't sent in response to an out-of-order + // segment. If SACK is enabled then we have an additional check to see + // if the segment carries new SACK information. If it does then it is + // considered a duplicate ACK as per RFC6675. + if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt { + if !s.ep.sackPermitted || !seg.hasNewSACKInfo { + s.dupAckCount = 0 + return false + } + } + + s.dupAckCount++ + + // Do not enter fast recovery until we reach nDupAckThreshold or the + // first unacknowledged byte is considered lost as per SACK scoreboard. + if s.dupAckCount < nDupAckThreshold || (s.ep.sackPermitted && !s.ep.scoreboard.IsLost(s.sndUna)) { + // RFC 6675 Step 3. + s.fr.highRxt = s.sndUna - 1 + // Do run SetPipe() to calculate the outstanding segments. + s.SetPipe() + return false + } + + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2 + // + // We only do the check here, the incrementing of last to the highest + // sequence number transmitted till now is done when enterFastRecovery + // is invoked. + if !s.fr.last.LessThan(seg.ackNumber) { + s.dupAckCount = 0 + return false + } + s.cc.HandleNDupAcks() + s.enterFastRecovery() + s.dupAckCount = 0 + return true +} + +// handleRcvdSegment is called when a segment is received; it is responsible for +// updating the send-related state. +func (s *sender) handleRcvdSegment(seg *segment) { + // Check if we can extract an RTT measurement from this ack. + if !seg.parsedOptions.TS && s.rttMeasureSeqNum.LessThan(seg.ackNumber) { + s.updateRTO(time.Now().Sub(s.rttMeasureTime)) + s.rttMeasureSeqNum = s.sndNxt + } + + // Update Timestamp if required. See RFC7323, section-4.3. + if s.ep.sendTSOk && seg.parsedOptions.TS { + s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber) + } + + // Insert SACKBlock information into our scoreboard. + if s.ep.sackPermitted { + for _, sb := range seg.parsedOptions.SACKBlocks { + // Only insert the SACK block if the following holds + // true: + // * SACK block acks data after the ack number in the + // current segment. + // * SACK block represents a sequence + // between sndUna and sndNxt (i.e. data that is + // currently unacked and in-flight). + // * SACK block that has not been SACKed already. + // + // NOTE: This check specifically excludes DSACK blocks + // which have start/end before sndUna and are used to + // indicate spurious retransmissions. + if seg.ackNumber.LessThan(sb.Start) && s.sndUna.LessThan(sb.Start) && sb.End.LessThanEq(s.sndNxt) && !s.ep.scoreboard.IsSACKED(sb) { + s.ep.scoreboard.Insert(sb) + seg.hasNewSACKInfo = true + } + } + s.SetPipe() + } + + // Count the duplicates and do the fast retransmit if needed. + rtx := s.checkDuplicateAck(seg) + + // Stash away the current window size. + s.sndWnd = seg.window + + // Ignore ack if it doesn't acknowledge any new data. + ack := seg.ackNumber + if (ack - 1).InRange(s.sndUna, s.sndNxt) { + s.dupAckCount = 0 + + // See : https://tools.ietf.org/html/rfc1323#section-3.3. + // Specifically we should only update the RTO using TSEcr if the + // following condition holds: + // + // A TSecr value received in a segment is used to update the + // averaged RTT measurement only if the segment acknowledges + // some new data, i.e., only if it advances the left edge of + // the send window. + if s.ep.sendTSOk && seg.parsedOptions.TSEcr != 0 { + // TSVal/Ecr values sent by Netstack are at a millisecond + // granularity. + elapsed := time.Duration(s.ep.timestamp()-seg.parsedOptions.TSEcr) * time.Millisecond + s.updateRTO(elapsed) + } + + // When an ack is received we must rearm the timer. + // RFC 6298 5.2 + s.resendTimer.enable(s.rto) + + // Remove all acknowledged data from the write list. + acked := s.sndUna.Size(ack) + s.sndUna = ack + + ackLeft := acked + originalOutstanding := s.outstanding + for ackLeft > 0 { + // We use logicalLen here because we can have FIN + // segments (which are always at the end of list) that + // have no data, but do consume a sequence number. + seg := s.writeList.Front() + datalen := seg.logicalLen() + + if datalen > ackLeft { + prevCount := s.pCount(seg) + seg.data.TrimFront(int(ackLeft)) + seg.sequenceNumber.UpdateForward(ackLeft) + s.outstanding -= prevCount - s.pCount(seg) + break + } + + if s.writeNext == seg { + s.writeNext = seg.Next() + } + s.writeList.Remove(seg) + + // if SACK is enabled then Only reduce outstanding if + // the segment was not previously SACKED as these have + // already been accounted for in SetPipe(). + if !s.ep.sackPermitted || !s.ep.scoreboard.IsSACKED(seg.sackBlock()) { + s.outstanding -= s.pCount(seg) + } + seg.decRef() + ackLeft -= datalen + } + + // Update the send buffer usage and notify potential waiters. + s.ep.updateSndBufferUsage(int(acked)) + + // Clear SACK information for all acked data. + s.ep.scoreboard.Delete(s.sndUna) + + // If we are not in fast recovery then update the congestion + // window based on the number of acknowledged packets. + if !s.fr.active { + s.cc.Update(originalOutstanding - s.outstanding) + } + + // It is possible for s.outstanding to drop below zero if we get + // a retransmit timeout, reset outstanding to zero but later + // get an ack that cover previously sent data. + if s.outstanding < 0 { + s.outstanding = 0 + } + + s.SetPipe() + + // If all outstanding data was acknowledged the disable the timer. + // RFC 6298 Rule 5.3 + if s.sndUna == s.sndNxt { + s.outstanding = 0 + s.resendTimer.disable() + } + } + // Now that we've popped all acknowledged data from the retransmit + // queue, retransmit if needed. + if rtx { + s.resendSegment() + } + + // Send more data now that some of the pending data has been ack'd, or + // that the window opened up, or the congestion window was inflated due + // to a duplicate ack during fast recovery. This will also re-enable + // the retransmit timer if needed. + if !s.ep.sackPermitted || s.fr.active || s.dupAckCount == 0 || seg.hasNewSACKInfo { + s.sendData() + } +} + +// sendSegment sends the specified segment. +func (s *sender) sendSegment(seg *segment) *tcpip.Error { + if !seg.xmitTime.IsZero() { + s.ep.stack.Stats().TCP.Retransmits.Increment() + if s.sndCwnd < s.sndSsthresh { + s.ep.stack.Stats().TCP.SlowStartRetransmits.Increment() + } + } + seg.xmitTime = time.Now() + return s.sendSegmentFromView(seg.data, seg.flags, seg.sequenceNumber) +} + +// sendSegmentFromView sends a new segment containing the given payload, flags +// and sequence number. +func (s *sender) sendSegmentFromView(data buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error { + s.lastSendTime = time.Now() + if seq == s.rttMeasureSeqNum { + s.rttMeasureTime = s.lastSendTime + } + + rcvNxt, rcvWnd := s.ep.rcv.getSendParams() + + // Remember the max sent ack. + s.maxSentAck = rcvNxt + + // Every time a packet containing data is sent (including a + // retransmission), if SACK is enabled then use the conservative timer + // described in RFC6675 Section 4.0, otherwise follow the standard time + // described in RFC6298 Section 5.2. + if data.Size() != 0 { + if s.ep.sackPermitted { + s.resendTimer.enable(s.rto) + } else { + if !s.resendTimer.enabled() { + s.resendTimer.enable(s.rto) + } + } + } + + return s.ep.sendRaw(data, flags, seq, rcvNxt, rcvWnd) +} diff --git a/pkg/tcpip/transport/tcp/snd_state.go b/pkg/tcpip/transport/tcp/snd_state.go new file mode 100644 index 000000000..12eff8afc --- /dev/null +++ b/pkg/tcpip/transport/tcp/snd_state.go @@ -0,0 +1,50 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// +stateify savable +type unixTime struct { + second int64 + nano int64 +} + +// saveLastSendTime is invoked by stateify. +func (s *sender) saveLastSendTime() unixTime { + return unixTime{s.lastSendTime.Unix(), s.lastSendTime.UnixNano()} +} + +// loadLastSendTime is invoked by stateify. +func (s *sender) loadLastSendTime(unix unixTime) { + s.lastSendTime = time.Unix(unix.second, unix.nano) +} + +// saveRttMeasureTime is invoked by stateify. +func (s *sender) saveRttMeasureTime() unixTime { + return unixTime{s.rttMeasureTime.Unix(), s.rttMeasureTime.UnixNano()} +} + +// loadRttMeasureTime is invoked by stateify. +func (s *sender) loadRttMeasureTime(unix unixTime) { + s.rttMeasureTime = time.Unix(unix.second, unix.nano) +} + +// afterLoad is invoked by stateify. +func (s *sender) afterLoad() { + s.resendTimer.init(&s.resendWaker) +} diff --git a/pkg/tcpip/transport/tcp/tcp_segment_list.go b/pkg/tcpip/transport/tcp/tcp_segment_list.go new file mode 100755 index 000000000..029f98a11 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_segment_list.go @@ -0,0 +1,173 @@ +package tcp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type segmentElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (segmentElementMapper) linkerFor(elem *segment) *segment { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type segmentList struct { + head *segment + tail *segment +} + +// Reset resets list l to the empty state. +func (l *segmentList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *segmentList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *segmentList) Front() *segment { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *segmentList) Back() *segment { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *segmentList) PushFront(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(l.head) + segmentElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + segmentElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *segmentList) PushBack(e *segment) { + segmentElementMapper{}.linkerFor(e).SetNext(nil) + segmentElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *segmentList) PushBackList(m *segmentList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + segmentElementMapper{}.linkerFor(l.tail).SetNext(m.head) + segmentElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *segmentList) InsertAfter(b, e *segment) { + a := segmentElementMapper{}.linkerFor(b).Next() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + segmentElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *segmentList) InsertBefore(a, e *segment) { + b := segmentElementMapper{}.linkerFor(a).Prev() + segmentElementMapper{}.linkerFor(e).SetNext(a) + segmentElementMapper{}.linkerFor(e).SetPrev(b) + segmentElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + segmentElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *segmentList) Remove(e *segment) { + prev := segmentElementMapper{}.linkerFor(e).Prev() + next := segmentElementMapper{}.linkerFor(e).Next() + + if prev != nil { + segmentElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + segmentElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type segmentEntry struct { + next *segment + prev *segment +} + +// Next returns the entry that follows e in the list. +func (e *segmentEntry) Next() *segment { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *segmentEntry) Prev() *segment { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *segmentEntry) SetNext(elem *segment) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *segmentEntry) SetPrev(elem *segment) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go new file mode 100755 index 000000000..9049a99b2 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -0,0 +1,400 @@ +// automatically generated by stateify. + +package tcp + +import ( + "gvisor.googlesource.com/gvisor/pkg/state" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +func (x *SACKInfo) beforeSave() {} +func (x *SACKInfo) save(m state.Map) { + x.beforeSave() + m.Save("Blocks", &x.Blocks) + m.Save("NumBlocks", &x.NumBlocks) +} + +func (x *SACKInfo) afterLoad() {} +func (x *SACKInfo) load(m state.Map) { + m.Load("Blocks", &x.Blocks) + m.Load("NumBlocks", &x.NumBlocks) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var lastError string = x.saveLastError() + m.SaveValue("lastError", lastError) + var state endpointState = x.saveState() + m.SaveValue("state", state) + var hardError string = x.saveHardError() + m.SaveValue("hardError", hardError) + var acceptedChan []*endpoint = x.saveAcceptedChan() + m.SaveValue("acceptedChan", acceptedChan) + m.Save("netProto", &x.netProto) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("rcvList", &x.rcvList) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvBufUsed", &x.rcvBufUsed) + m.Save("id", &x.id) + m.Save("isRegistered", &x.isRegistered) + m.Save("v6only", &x.v6only) + m.Save("isConnectNotified", &x.isConnectNotified) + m.Save("broadcast", &x.broadcast) + m.Save("workerRunning", &x.workerRunning) + m.Save("workerCleanup", &x.workerCleanup) + m.Save("sendTSOk", &x.sendTSOk) + m.Save("recentTS", &x.recentTS) + m.Save("tsOffset", &x.tsOffset) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("sackPermitted", &x.sackPermitted) + m.Save("sack", &x.sack) + m.Save("reusePort", &x.reusePort) + m.Save("delay", &x.delay) + m.Save("cork", &x.cork) + m.Save("scoreboard", &x.scoreboard) + m.Save("reuseAddr", &x.reuseAddr) + m.Save("slowAck", &x.slowAck) + m.Save("segmentQueue", &x.segmentQueue) + m.Save("synRcvdCount", &x.synRcvdCount) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("sndBufUsed", &x.sndBufUsed) + m.Save("sndClosed", &x.sndClosed) + m.Save("sndBufInQueue", &x.sndBufInQueue) + m.Save("sndQueue", &x.sndQueue) + m.Save("cc", &x.cc) + m.Save("packetTooBigCount", &x.packetTooBigCount) + m.Save("sndMTU", &x.sndMTU) + m.Save("keepalive", &x.keepalive) + m.Save("rcv", &x.rcv) + m.Save("snd", &x.snd) + m.Save("bindAddress", &x.bindAddress) + m.Save("connectingAddress", &x.connectingAddress) + m.Save("gso", &x.gso) +} + +func (x *endpoint) load(m state.Map) { + m.Load("netProto", &x.netProto) + m.LoadWait("waiterQueue", &x.waiterQueue) + m.LoadWait("rcvList", &x.rcvList) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvBufUsed", &x.rcvBufUsed) + m.Load("id", &x.id) + m.Load("isRegistered", &x.isRegistered) + m.Load("v6only", &x.v6only) + m.Load("isConnectNotified", &x.isConnectNotified) + m.Load("broadcast", &x.broadcast) + m.Load("workerRunning", &x.workerRunning) + m.Load("workerCleanup", &x.workerCleanup) + m.Load("sendTSOk", &x.sendTSOk) + m.Load("recentTS", &x.recentTS) + m.Load("tsOffset", &x.tsOffset) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("sackPermitted", &x.sackPermitted) + m.Load("sack", &x.sack) + m.Load("reusePort", &x.reusePort) + m.Load("delay", &x.delay) + m.Load("cork", &x.cork) + m.Load("scoreboard", &x.scoreboard) + m.Load("reuseAddr", &x.reuseAddr) + m.Load("slowAck", &x.slowAck) + m.LoadWait("segmentQueue", &x.segmentQueue) + m.Load("synRcvdCount", &x.synRcvdCount) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("sndBufUsed", &x.sndBufUsed) + m.Load("sndClosed", &x.sndClosed) + m.Load("sndBufInQueue", &x.sndBufInQueue) + m.LoadWait("sndQueue", &x.sndQueue) + m.Load("cc", &x.cc) + m.Load("packetTooBigCount", &x.packetTooBigCount) + m.Load("sndMTU", &x.sndMTU) + m.Load("keepalive", &x.keepalive) + m.LoadWait("rcv", &x.rcv) + m.LoadWait("snd", &x.snd) + m.Load("bindAddress", &x.bindAddress) + m.Load("connectingAddress", &x.connectingAddress) + m.Load("gso", &x.gso) + m.LoadValue("lastError", new(string), func(y interface{}) { x.loadLastError(y.(string)) }) + m.LoadValue("state", new(endpointState), func(y interface{}) { x.loadState(y.(endpointState)) }) + m.LoadValue("hardError", new(string), func(y interface{}) { x.loadHardError(y.(string)) }) + m.LoadValue("acceptedChan", new([]*endpoint), func(y interface{}) { x.loadAcceptedChan(y.([]*endpoint)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *keepalive) beforeSave() {} +func (x *keepalive) save(m state.Map) { + x.beforeSave() + m.Save("enabled", &x.enabled) + m.Save("idle", &x.idle) + m.Save("interval", &x.interval) + m.Save("count", &x.count) + m.Save("unacked", &x.unacked) +} + +func (x *keepalive) afterLoad() {} +func (x *keepalive) load(m state.Map) { + m.Load("enabled", &x.enabled) + m.Load("idle", &x.idle) + m.Load("interval", &x.interval) + m.Load("count", &x.count) + m.Load("unacked", &x.unacked) +} + +func (x *receiver) beforeSave() {} +func (x *receiver) save(m state.Map) { + x.beforeSave() + m.Save("ep", &x.ep) + m.Save("rcvNxt", &x.rcvNxt) + m.Save("rcvAcc", &x.rcvAcc) + m.Save("rcvWndScale", &x.rcvWndScale) + m.Save("closed", &x.closed) + m.Save("pendingRcvdSegments", &x.pendingRcvdSegments) + m.Save("pendingBufUsed", &x.pendingBufUsed) + m.Save("pendingBufSize", &x.pendingBufSize) +} + +func (x *receiver) afterLoad() {} +func (x *receiver) load(m state.Map) { + m.Load("ep", &x.ep) + m.Load("rcvNxt", &x.rcvNxt) + m.Load("rcvAcc", &x.rcvAcc) + m.Load("rcvWndScale", &x.rcvWndScale) + m.Load("closed", &x.closed) + m.Load("pendingRcvdSegments", &x.pendingRcvdSegments) + m.Load("pendingBufUsed", &x.pendingBufUsed) + m.Load("pendingBufSize", &x.pendingBufSize) +} + +func (x *renoState) beforeSave() {} +func (x *renoState) save(m state.Map) { + x.beforeSave() + m.Save("s", &x.s) +} + +func (x *renoState) afterLoad() {} +func (x *renoState) load(m state.Map) { + m.Load("s", &x.s) +} + +func (x *SACKScoreboard) beforeSave() {} +func (x *SACKScoreboard) save(m state.Map) { + x.beforeSave() + m.Save("smss", &x.smss) + m.Save("maxSACKED", &x.maxSACKED) +} + +func (x *SACKScoreboard) afterLoad() {} +func (x *SACKScoreboard) load(m state.Map) { + m.Load("smss", &x.smss) + m.Load("maxSACKED", &x.maxSACKED) +} + +func (x *segment) beforeSave() {} +func (x *segment) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + var options []byte = x.saveOptions() + m.SaveValue("options", options) + var rcvdTime unixTime = x.saveRcvdTime() + m.SaveValue("rcvdTime", rcvdTime) + var xmitTime unixTime = x.saveXmitTime() + m.SaveValue("xmitTime", xmitTime) + m.Save("segmentEntry", &x.segmentEntry) + m.Save("refCnt", &x.refCnt) + m.Save("viewToDeliver", &x.viewToDeliver) + m.Save("sequenceNumber", &x.sequenceNumber) + m.Save("ackNumber", &x.ackNumber) + m.Save("flags", &x.flags) + m.Save("window", &x.window) + m.Save("csum", &x.csum) + m.Save("csumValid", &x.csumValid) + m.Save("parsedOptions", &x.parsedOptions) + m.Save("hasNewSACKInfo", &x.hasNewSACKInfo) +} + +func (x *segment) afterLoad() {} +func (x *segment) load(m state.Map) { + m.Load("segmentEntry", &x.segmentEntry) + m.Load("refCnt", &x.refCnt) + m.Load("viewToDeliver", &x.viewToDeliver) + m.Load("sequenceNumber", &x.sequenceNumber) + m.Load("ackNumber", &x.ackNumber) + m.Load("flags", &x.flags) + m.Load("window", &x.window) + m.Load("csum", &x.csum) + m.Load("csumValid", &x.csumValid) + m.Load("parsedOptions", &x.parsedOptions) + m.Load("hasNewSACKInfo", &x.hasNewSACKInfo) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) + m.LoadValue("options", new([]byte), func(y interface{}) { x.loadOptions(y.([]byte)) }) + m.LoadValue("rcvdTime", new(unixTime), func(y interface{}) { x.loadRcvdTime(y.(unixTime)) }) + m.LoadValue("xmitTime", new(unixTime), func(y interface{}) { x.loadXmitTime(y.(unixTime)) }) +} + +func (x *segmentQueue) beforeSave() {} +func (x *segmentQueue) save(m state.Map) { + x.beforeSave() + m.Save("list", &x.list) + m.Save("limit", &x.limit) + m.Save("used", &x.used) +} + +func (x *segmentQueue) afterLoad() {} +func (x *segmentQueue) load(m state.Map) { + m.LoadWait("list", &x.list) + m.Load("limit", &x.limit) + m.Load("used", &x.used) +} + +func (x *sender) beforeSave() {} +func (x *sender) save(m state.Map) { + x.beforeSave() + var lastSendTime unixTime = x.saveLastSendTime() + m.SaveValue("lastSendTime", lastSendTime) + var rttMeasureTime unixTime = x.saveRttMeasureTime() + m.SaveValue("rttMeasureTime", rttMeasureTime) + m.Save("ep", &x.ep) + m.Save("dupAckCount", &x.dupAckCount) + m.Save("fr", &x.fr) + m.Save("sndCwnd", &x.sndCwnd) + m.Save("sndSsthresh", &x.sndSsthresh) + m.Save("sndCAAckCount", &x.sndCAAckCount) + m.Save("outstanding", &x.outstanding) + m.Save("sndWnd", &x.sndWnd) + m.Save("sndUna", &x.sndUna) + m.Save("sndNxt", &x.sndNxt) + m.Save("sndNxtList", &x.sndNxtList) + m.Save("rttMeasureSeqNum", &x.rttMeasureSeqNum) + m.Save("closed", &x.closed) + m.Save("writeNext", &x.writeNext) + m.Save("writeList", &x.writeList) + m.Save("rtt", &x.rtt) + m.Save("rto", &x.rto) + m.Save("srttInited", &x.srttInited) + m.Save("maxPayloadSize", &x.maxPayloadSize) + m.Save("gso", &x.gso) + m.Save("sndWndScale", &x.sndWndScale) + m.Save("maxSentAck", &x.maxSentAck) + m.Save("cc", &x.cc) +} + +func (x *sender) load(m state.Map) { + m.Load("ep", &x.ep) + m.Load("dupAckCount", &x.dupAckCount) + m.Load("fr", &x.fr) + m.Load("sndCwnd", &x.sndCwnd) + m.Load("sndSsthresh", &x.sndSsthresh) + m.Load("sndCAAckCount", &x.sndCAAckCount) + m.Load("outstanding", &x.outstanding) + m.Load("sndWnd", &x.sndWnd) + m.Load("sndUna", &x.sndUna) + m.Load("sndNxt", &x.sndNxt) + m.Load("sndNxtList", &x.sndNxtList) + m.Load("rttMeasureSeqNum", &x.rttMeasureSeqNum) + m.Load("closed", &x.closed) + m.Load("writeNext", &x.writeNext) + m.Load("writeList", &x.writeList) + m.Load("rtt", &x.rtt) + m.Load("rto", &x.rto) + m.Load("srttInited", &x.srttInited) + m.Load("maxPayloadSize", &x.maxPayloadSize) + m.Load("gso", &x.gso) + m.Load("sndWndScale", &x.sndWndScale) + m.Load("maxSentAck", &x.maxSentAck) + m.Load("cc", &x.cc) + m.LoadValue("lastSendTime", new(unixTime), func(y interface{}) { x.loadLastSendTime(y.(unixTime)) }) + m.LoadValue("rttMeasureTime", new(unixTime), func(y interface{}) { x.loadRttMeasureTime(y.(unixTime)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *rtt) beforeSave() {} +func (x *rtt) save(m state.Map) { + x.beforeSave() + m.Save("srtt", &x.srtt) + m.Save("rttvar", &x.rttvar) +} + +func (x *rtt) afterLoad() {} +func (x *rtt) load(m state.Map) { + m.Load("srtt", &x.srtt) + m.Load("rttvar", &x.rttvar) +} + +func (x *fastRecovery) beforeSave() {} +func (x *fastRecovery) save(m state.Map) { + x.beforeSave() + m.Save("active", &x.active) + m.Save("first", &x.first) + m.Save("last", &x.last) + m.Save("maxCwnd", &x.maxCwnd) + m.Save("highRxt", &x.highRxt) + m.Save("rescueRxt", &x.rescueRxt) +} + +func (x *fastRecovery) afterLoad() {} +func (x *fastRecovery) load(m state.Map) { + m.Load("active", &x.active) + m.Load("first", &x.first) + m.Load("last", &x.last) + m.Load("maxCwnd", &x.maxCwnd) + m.Load("highRxt", &x.highRxt) + m.Load("rescueRxt", &x.rescueRxt) +} + +func (x *unixTime) beforeSave() {} +func (x *unixTime) save(m state.Map) { + x.beforeSave() + m.Save("second", &x.second) + m.Save("nano", &x.nano) +} + +func (x *unixTime) afterLoad() {} +func (x *unixTime) load(m state.Map) { + m.Load("second", &x.second) + m.Load("nano", &x.nano) +} + +func (x *segmentList) beforeSave() {} +func (x *segmentList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *segmentList) afterLoad() {} +func (x *segmentList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *segmentEntry) beforeSave() {} +func (x *segmentEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *segmentEntry) afterLoad() {} +func (x *segmentEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("tcp.SACKInfo", (*SACKInfo)(nil), state.Fns{Save: (*SACKInfo).save, Load: (*SACKInfo).load}) + state.Register("tcp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("tcp.keepalive", (*keepalive)(nil), state.Fns{Save: (*keepalive).save, Load: (*keepalive).load}) + state.Register("tcp.receiver", (*receiver)(nil), state.Fns{Save: (*receiver).save, Load: (*receiver).load}) + state.Register("tcp.renoState", (*renoState)(nil), state.Fns{Save: (*renoState).save, Load: (*renoState).load}) + state.Register("tcp.SACKScoreboard", (*SACKScoreboard)(nil), state.Fns{Save: (*SACKScoreboard).save, Load: (*SACKScoreboard).load}) + state.Register("tcp.segment", (*segment)(nil), state.Fns{Save: (*segment).save, Load: (*segment).load}) + state.Register("tcp.segmentQueue", (*segmentQueue)(nil), state.Fns{Save: (*segmentQueue).save, Load: (*segmentQueue).load}) + state.Register("tcp.sender", (*sender)(nil), state.Fns{Save: (*sender).save, Load: (*sender).load}) + state.Register("tcp.rtt", (*rtt)(nil), state.Fns{Save: (*rtt).save, Load: (*rtt).load}) + state.Register("tcp.fastRecovery", (*fastRecovery)(nil), state.Fns{Save: (*fastRecovery).save, Load: (*fastRecovery).load}) + state.Register("tcp.unixTime", (*unixTime)(nil), state.Fns{Save: (*unixTime).save, Load: (*unixTime).load}) + state.Register("tcp.segmentList", (*segmentList)(nil), state.Fns{Save: (*segmentList).save, Load: (*segmentList).load}) + state.Register("tcp.segmentEntry", (*segmentEntry)(nil), state.Fns{Save: (*segmentEntry).save, Load: (*segmentEntry).load}) +} diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go new file mode 100644 index 000000000..fc1c7cbd2 --- /dev/null +++ b/pkg/tcpip/transport/tcp/timer.go @@ -0,0 +1,141 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" +) + +type timerState int + +const ( + timerStateDisabled timerState = iota + timerStateEnabled + timerStateOrphaned +) + +// timer is a timer implementation that reduces the interactions with the +// runtime timer infrastructure by letting timers run (and potentially +// eventually expire) even if they are stopped. It makes it cheaper to +// disable/reenable timers at the expense of spurious wakes. This is useful for +// cases when the same timer is disabled/reenabled repeatedly with relatively +// long timeouts farther into the future. +// +// TCP retransmit timers benefit from this because they the timeouts are long +// (currently at least 200ms), and get disabled when acks are received, and +// reenabled when new pending segments are sent. +// +// It is advantageous to avoid interacting with the runtime because it acquires +// a global mutex and performs O(log n) operations, where n is the global number +// of timers, whenever a timer is enabled or disabled, and may make a syscall. +// +// This struct is thread-compatible. +type timer struct { + // state is the current state of the timer, it can be one of the + // following values: + // disabled - the timer is disabled. + // orphaned - the timer is disabled, but the runtime timer is + // enabled, which means that it will evetually cause a + // spurious wake (unless it gets enabled again before + // then). + // enabled - the timer is enabled, but the runtime timer may be set + // to an earlier expiration time due to a previous + // orphaned state. + state timerState + + // target is the expiration time of the current timer. It is only + // meaningful in the enabled state. + target time.Time + + // runtimeTarget is the expiration time of the runtime timer. It is + // meaningful in the enabled and orphaned states. + runtimeTarget time.Time + + // timer is the runtime timer used to wait on. + timer *time.Timer +} + +// init initializes the timer. Once it expires, it the given waker will be +// asserted. +func (t *timer) init(w *sleep.Waker) { + t.state = timerStateDisabled + + // Initialize a runtime timer that will assert the waker, then + // immediately stop it. + t.timer = time.AfterFunc(time.Hour, func() { + w.Assert() + }) + t.timer.Stop() +} + +// cleanup frees all resources associated with the timer. +func (t *timer) cleanup() { + t.timer.Stop() +} + +// checkExpiration checks if the given timer has actually expired, it should be +// called whenever a sleeper wakes up due to the waker being asserted, and is +// used to check if it's a supurious wake (due to a previously orphaned timer) +// or a legitimate one. +func (t *timer) checkExpiration() bool { + // Transition to fully disabled state if we're just consuming an + // orphaned timer. + if t.state == timerStateOrphaned { + t.state = timerStateDisabled + return false + } + + // The timer is enabled, but it may have expired early. Check if that's + // the case, and if so, reset the runtime timer to the correct time. + now := time.Now() + if now.Before(t.target) { + t.runtimeTarget = t.target + t.timer.Reset(t.target.Sub(now)) + return false + } + + // The timer has actually expired, disable it for now and inform the + // caller. + t.state = timerStateDisabled + return true +} + +// disable disables the timer, leaving it in an orphaned state if it wasn't +// already disabled. +func (t *timer) disable() { + if t.state != timerStateDisabled { + t.state = timerStateOrphaned + } +} + +// enabled returns true if the timer is currently enabled, false otherwise. +func (t *timer) enabled() bool { + return t.state == timerStateEnabled +} + +// enable enables the timer, programming the runtime timer if necessary. +func (t *timer) enable(d time.Duration) { + t.target = time.Now().Add(d) + + // Check if we need to set the runtime timer. + if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) { + t.runtimeTarget = t.target + t.timer.Reset(d) + } + + t.state = timerStateEnabled +} diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go new file mode 100644 index 000000000..3d52a4f31 --- /dev/null +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -0,0 +1,1002 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp + +import ( + "math" + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// +stateify savable +type udpPacket struct { + udpPacketEntry + senderAddress tcpip.FullAddress + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + timestamp int64 + // views is used as buffer for data when its length is large + // enough to store a VectorisedView. + views [8]buffer.View `state:"nosave"` +} + +type endpointState int + +const ( + stateInitial endpointState = iota + stateBound + stateConnected + stateClosed +) + +// endpoint represents a UDP endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. +// +// It implements tcpip.Endpoint. +// +// +stateify savable +type endpoint struct { + // The following fields are initialized at creation time and do not + // change throughout the lifetime of the endpoint. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue + + // The following fields are used to manage the receive queue, and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList udpPacketList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by the mu mutex. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + id stack.TransportEndpointID + state endpointState + bindNICID tcpip.NICID + regNICID tcpip.NICID + route stack.Route `state:"manual"` + dstPort uint16 + v6only bool + multicastTTL uint8 + multicastAddr tcpip.Address + multicastNICID tcpip.NICID + multicastLoop bool + reusePort bool + broadcast bool + + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + + // multicastMemberships that need to be remvoed when the endpoint is + // closed. Protected by the mu mutex. + multicastMemberships []multicastMembership + + // effectiveNetProtos contains the network protocols actually in use. In + // most cases it will only contain "netProto", but in cases like IPv6 + // endpoints with v6only set to false, this could include multiple + // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., + // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped + // address). + effectiveNetProtos []tcpip.NetworkProtocolNumber +} + +// +stateify savable +type multicastMembership struct { + nicID tcpip.NICID + multicastAddr tcpip.Address +} + +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { + return &endpoint{ + stack: stack, + netProto: netProto, + waiterQueue: waiterQueue, + // RFC 1075 section 5.4 recommends a TTL of 1 for membership + // requests. + // + // RFC 5135 4.2.1 appears to assume that IGMP messages have a + // TTL of 1. + // + // RFC 5135 Appendix A defines TTL=1: A multicast source that + // wants its traffic to not traverse a router (e.g., leave a + // home network) may find it useful to send traffic with IP + // TTL=1. + // + // Linux defaults to TTL=1. + multicastTTL: 1, + multicastLoop: true, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } +} + +// Close puts the endpoint in a closed state and frees all resources +// associated with it. +func (e *endpoint) Close() { + e.mu.Lock() + e.shutdownFlags = tcpip.ShutdownRead | tcpip.ShutdownWrite + + switch e.state { + case stateBound, stateConnected: + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + } + + for _, mem := range e.multicastMemberships { + e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr) + } + e.multicastMemberships = nil + + // Close the receive list and drain it. + e.rcvMu.Lock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } + e.rcvMu.Unlock() + + e.route.Release() + + // Update the state. + e.state = stateClosed + + e.mu.Unlock() + + e.waiterQueue.Notify(waiter.EventHUp | waiter.EventErr | waiter.EventIn | waiter.EventOut) +} + +// Read reads data from the endpoint. This method does not block if +// there is no data pending. +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { + e.rcvMu.Lock() + + if e.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if e.rcvClosed { + err = tcpip.ErrClosedForReceive + } + e.rcvMu.Unlock() + return buffer.View{}, tcpip.ControlMessages{}, err + } + + p := e.rcvList.Front() + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() + + e.rcvMu.Unlock() + + if addr != nil { + *addr = p.senderAddress + } + + return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil +} + +// prepareForWrite prepares the endpoint for sending data. In particular, it +// binds it if it's still in the initial state. To do so, it must first +// reacquire the mutex in exclusive mode. +// +// Returns true for retry if preparation should be retried. +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { + switch e.state { + case stateInitial: + case stateConnected: + return false, nil + + case stateBound: + if to == nil { + return false, tcpip.ErrDestinationRequired + } + return false, nil + default: + return false, tcpip.ErrInvalidEndpointState + } + + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // The state changed when we released the shared locked and re-acquired + // it in exclusive mode. Try again. + if e.state != stateInitial { + return true, nil + } + + // The state is still 'initial', so try to bind the endpoint. + if err := e.bindLocked(tcpip.FullAddress{}); err != nil { + return false, err + } + + return true, nil +} + +// connectRoute establishes a route to the specified interface or the +// configured multicast interface if no interface is specified and the +// specified address is a multicast address. +func (e *endpoint) connectRoute(nicid tcpip.NICID, addr tcpip.FullAddress) (stack.Route, tcpip.NICID, tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return stack.Route{}, 0, 0, err + } + + localAddr := e.id.LocalAddress + if header.IsV4MulticastAddress(addr.Addr) || header.IsV6MulticastAddress(addr.Addr) { + if nicid == 0 { + nicid = e.multicastNICID + } + if localAddr == "" { + localAddr = e.multicastAddr + } + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicid, localAddr, addr.Addr, netProto, e.multicastLoop) + if err != nil { + return stack.Route{}, 0, 0, err + } + return r, nicid, netProto, nil +} + +// Write writes data to the endpoint's peer. This method does not block +// if the data cannot be written. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) { + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, nil, tcpip.ErrInvalidOptionValue + } + + if p.Size() > math.MaxUint16 { + // Payload can't possibly fit in a packet. + return 0, nil, tcpip.ErrMessageTooLong + } + + to := opts.To + + e.mu.RLock() + defer e.mu.RUnlock() + + // If we've shutdown with SHUT_WR we are in an invalid state for sending. + if e.shutdownFlags&tcpip.ShutdownWrite != 0 { + return 0, nil, tcpip.ErrClosedForSend + } + + // Prepare for write. + for { + retry, err := e.prepareForWrite(to) + if err != nil { + return 0, nil, err + } + + if !retry { + break + } + } + + var route *stack.Route + var dstPort uint16 + if to == nil { + route = &e.route + dstPort = e.dstPort + + if route.IsResolutionRequired() { + // Promote lock to exclusive if using a shared route, given that it may need to + // change in Route.Resolve() call below. + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // Recheck state after lock was re-acquired. + if e.state != stateConnected { + return 0, nil, tcpip.ErrInvalidEndpointState + } + } + } else { + // Reject destination address if it goes through a different + // NIC than the endpoint was bound to. + nicid := to.NIC + if e.bindNICID != 0 { + if nicid != 0 && nicid != e.bindNICID { + return 0, nil, tcpip.ErrNoRoute + } + + nicid = e.bindNICID + } + + if to.Addr == header.IPv4Broadcast && !e.broadcast { + return 0, nil, tcpip.ErrBroadcastDisabled + } + + r, _, _, err := e.connectRoute(nicid, *to) + if err != nil { + return 0, nil, err + } + defer r.Release() + + route = &r + dstPort = to.Port + } + + if route.IsResolutionRequired() { + if ch, err := route.Resolve(nil); err != nil { + if err == tcpip.ErrWouldBlock { + return 0, ch, tcpip.ErrNoLinkAddress + } + return 0, nil, err + } + } + + v, err := p.Get(p.Size()) + if err != nil { + return 0, nil, err + } + + ttl := route.DefaultTTL() + if header.IsV4MulticastAddress(route.RemoteAddress) || header.IsV6MulticastAddress(route.RemoteAddress) { + ttl = e.multicastTTL + } + + if err := sendUDP(route, buffer.View(v).ToVectorisedView(), e.id.LocalPort, dstPort, ttl); err != nil { + return 0, nil, err + } + return uintptr(len(v)), nil, nil +} + +// Peek only returns data from a single datagram, so do nothing here. +func (e *endpoint) Peek([][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Error) { + return 0, tcpip.ControlMessages{}, nil +} + +// SetSockOpt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v != 0 + + case tcpip.MulticastTTLOption: + e.mu.Lock() + e.multicastTTL = uint8(v) + e.mu.Unlock() + + case tcpip.MulticastInterfaceOption: + e.mu.Lock() + defer e.mu.Unlock() + + fa := tcpip.FullAddress{Addr: v.InterfaceAddr} + netProto, err := e.checkV4Mapped(&fa, false) + if err != nil { + return err + } + nic := v.NIC + addr := fa.Addr + + if nic == 0 && addr == "" { + e.multicastAddr = "" + e.multicastNICID = 0 + break + } + + if nic != 0 { + if !e.stack.CheckNIC(nic) { + return tcpip.ErrBadLocalAddress + } + } else { + nic = e.stack.CheckLocalAddress(0, netProto, addr) + if nic == 0 { + return tcpip.ErrBadLocalAddress + } + } + + if e.bindNICID != 0 && e.bindNICID != nic { + return tcpip.ErrInvalidEndpointState + } + + e.multicastNICID = nic + e.multicastAddr = addr + + case tcpip.AddMembershipOption: + if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { + return tcpip.ErrInvalidOptionValue + } + + nicID := v.NIC + if v.InterfaceAddr == header.IPv4Any { + if nicID == 0 { + r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err == nil { + nicID = r.NICID() + r.Release() + } + } + } else { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return tcpip.ErrUnknownDevice + } + + memToInsert := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + + e.mu.Lock() + defer e.mu.Unlock() + + for _, mem := range e.multicastMemberships { + if mem == memToInsert { + return tcpip.ErrPortInUse + } + } + + if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships = append(e.multicastMemberships, memToInsert) + + case tcpip.RemoveMembershipOption: + if !header.IsV4MulticastAddress(v.MulticastAddr) && !header.IsV6MulticastAddress(v.MulticastAddr) { + return tcpip.ErrInvalidOptionValue + } + + nicID := v.NIC + if v.InterfaceAddr == header.IPv4Any { + if nicID == 0 { + r, err := e.stack.FindRoute(0, "", v.MulticastAddr, header.IPv4ProtocolNumber, false /* multicastLoop */) + if err == nil { + nicID = r.NICID() + r.Release() + } + } + } else { + nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr) + } + if nicID == 0 { + return tcpip.ErrUnknownDevice + } + + memToRemove := multicastMembership{nicID: nicID, multicastAddr: v.MulticastAddr} + memToRemoveIndex := -1 + + e.mu.Lock() + defer e.mu.Unlock() + + for i, mem := range e.multicastMemberships { + if mem == memToRemove { + memToRemoveIndex = i + break + } + } + if memToRemoveIndex == -1 { + return tcpip.ErrBadLocalAddress + } + + if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil { + return err + } + + e.multicastMemberships[memToRemoveIndex] = e.multicastMemberships[len(e.multicastMemberships)-1] + e.multicastMemberships = e.multicastMemberships[:len(e.multicastMemberships)-1] + + case tcpip.MulticastLoopOption: + e.mu.Lock() + e.multicastLoop = bool(v) + e.mu.Unlock() + + case tcpip.ReusePortOption: + e.mu.Lock() + e.reusePort = v != 0 + e.mu.Unlock() + + case tcpip.BroadcastOption: + e.mu.Lock() + e.broadcast = v != 0 + e.mu.Unlock() + + return nil + } + return nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.SendBufferSizeOption: + e.mu.Lock() + *o = tcpip.SendBufferSizeOption(e.sndBufSize) + e.mu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) + e.rcvMu.Unlock() + return nil + + case *tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.ReceiveQueueSizeOption: + e.rcvMu.Lock() + if e.rcvList.Empty() { + *o = 0 + } else { + p := e.rcvList.Front() + *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) + } + e.rcvMu.Unlock() + return nil + + case *tcpip.MulticastTTLOption: + e.mu.Lock() + *o = tcpip.MulticastTTLOption(e.multicastTTL) + e.mu.Unlock() + return nil + + case *tcpip.MulticastInterfaceOption: + e.mu.Lock() + *o = tcpip.MulticastInterfaceOption{ + e.multicastNICID, + e.multicastAddr, + } + e.mu.Unlock() + return nil + + case *tcpip.MulticastLoopOption: + e.mu.RLock() + v := e.multicastLoop + e.mu.RUnlock() + + *o = tcpip.MulticastLoopOption(v) + return nil + + case *tcpip.ReusePortOption: + e.mu.RLock() + v := e.reusePort + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.KeepaliveEnabledOption: + *o = 0 + return nil + + case *tcpip.BroadcastOption: + e.mu.RLock() + v := e.broadcast + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// sendUDP sends a UDP segment via the provided network endpoint and under the +// provided identity. +func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort uint16, ttl uint8) *tcpip.Error { + // Allocate a buffer for the UDP header. + hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) + + // Initialize the header. + udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + + length := uint16(hdr.UsedLength() + data.Size()) + udp.Encode(&header.UDPFields{ + SrcPort: localPort, + DstPort: remotePort, + Length: length, + }) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) + for _, v := range data.Views() { + xsum = header.Checksum(v, xsum) + } + udp.SetChecksum(^udp.CalculateChecksum(xsum)) + } + + // Track count of packets sent. + r.Stats().UDP.PacketsSent.Increment() + + return r.WritePacket(nil /* gso */, hdr, data, ProtocolNumber, ttl) +} + +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.netProto + if header.IsV4MappedAddress(addr.Addr) { + // Fail if using a v4 mapped address on a v6only endpoint. + if e.v6only { + return 0, tcpip.ErrNoRoute + } + + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == "\x00\x00\x00\x00" { + addr.Addr = "" + } + + // Fail if we are bound to an IPv6 address. + if !allowMismatch && len(e.id.LocalAddress) == 16 { + return 0, tcpip.ErrNetworkUnreachable + } + } + + // Fail if we're bound to an address length different from the one we're + // checking. + if l := len(e.id.LocalAddress); l != 0 && l != len(addr.Addr) { + return 0, tcpip.ErrInvalidEndpointState + } + + return netProto, nil +} + +// Connect connects the endpoint to its peer. Specifying a NIC is optional. +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + if addr.Port == 0 { + // We don't support connecting to port zero. + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + nicid := addr.NIC + var localPort uint16 + switch e.state { + case stateInitial: + case stateBound, stateConnected: + localPort = e.id.LocalPort + if e.bindNICID == 0 { + break + } + + if nicid != 0 && nicid != e.bindNICID { + return tcpip.ErrInvalidEndpointState + } + + nicid = e.bindNICID + default: + return tcpip.ErrInvalidEndpointState + } + + r, nicid, netProto, err := e.connectRoute(nicid, addr) + if err != nil { + return err + } + defer r.Release() + + id := stack.TransportEndpointID{ + LocalAddress: r.LocalAddress, + LocalPort: localPort, + RemotePort: addr.Port, + RemoteAddress: r.RemoteAddress, + } + + // Even if we're connected, this endpoint can still be used to send + // packets on a different network protocol, so we register both even if + // v6only is set to false and this is an ipv6 endpoint. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv4ProtocolNumber, + header.IPv6ProtocolNumber, + } + } + + id, err = e.registerWithStack(nicid, netProtos, id) + if err != nil { + return err + } + + // Remove the old registration. + if e.id.LocalPort != 0 { + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) + } + + e.id = id + e.route = r.Clone() + e.dstPort = addr.Port + e.regNICID = nicid + e.effectiveNetProtos = netProtos + + e.state = stateConnected + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// ConnectEndpoint is not supported. +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// Shutdown closes the read and/or write end of the endpoint connection +// to its peer. +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // A socket in the bound state can still receive multicast messages, + // so we need to notify waiters on shutdown. + if e.state != stateBound && e.state != stateConnected { + return tcpip.ErrNotConnected + } + + e.shutdownFlags |= flags + + if flags&tcpip.ShutdownRead != 0 { + e.rcvMu.Lock() + wasClosed := e.rcvClosed + e.rcvClosed = true + e.rcvMu.Unlock() + + if !wasClosed { + e.waiterQueue.Notify(waiter.EventIn) + } + } + + return nil +} + +// Listen is not supported by UDP, it just fails. +func (*endpoint) Listen(int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept is not supported by UDP, it just fails. +func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if e.id.LocalPort == 0 { + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort, e.reusePort) + if err != nil { + return id, err + } + id.LocalPort = port + } + + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e, e.reusePort) + if err != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, id.LocalAddress, id.LocalPort) + } + return id, err +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress) *tcpip.Error { + // Don't allow binding once endpoint is not in the initial state + // anymore. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + netProto, err := e.checkV4Mapped(&addr, true) + if err != nil { + return err + } + + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } + } + + nicid := addr.NIC + if len(addr.Addr) != 0 { + // A local address was specified, verify that it's valid. + nicid = e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nicid == 0 { + return tcpip.ErrBadLocalAddress + } + } + + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err = e.registerWithStack(nicid, netProtos, id) + if err != nil { + return err + } + + e.id = id + e.regNICID = nicid + e.effectiveNetProtos = netProtos + + // Mark endpoint as bound. + e.state = stateBound + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// Bind binds the endpoint to a specific local address and port. +// Specifying a NIC is optional. +func (e *endpoint) Bind(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + err := e.bindLocked(addr) + if err != nil { + return err + } + + // Save the effective NICID generated by bindLocked. + e.bindNICID = e.regNICID + + return nil +} + +// GetLocalAddress returns the address to which the endpoint is bound. +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.LocalAddress, + Port: e.id.LocalPort, + }, nil +} + +// GetRemoteAddress returns the address to which the endpoint is connected. +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.RemoteAddress, + Port: e.id.RemotePort, + }, nil +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvMu.Unlock() + } + + return result +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv buffer.VectorisedView) { + // Get the header then trim it from the view. + hdr := header.UDP(vv.First()) + if int(hdr.Length()) > vv.Size() { + // Malformed packet. + e.stack.Stats().UDP.MalformedPacketsReceived.Increment() + return + } + + vv.TrimFront(header.UDPMinimumSize) + + e.rcvMu.Lock() + e.stack.Stats().UDP.PacketsReceived.Increment() + + // Drop the packet if our buffer is currently full. + if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + e.stack.Stats().UDP.ReceiveBufferErrors.Increment() + e.rcvMu.Unlock() + return + } + + wasEmpty := e.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + pkt := &udpPacket{ + senderAddress: tcpip.FullAddress{ + NIC: r.NICID(), + Addr: id.RemoteAddress, + Port: hdr.SourcePort(), + }, + } + pkt.data = vv.Clone(pkt.views[:]) + e.rcvList.PushBack(pkt) + e.rcvBufSize += vv.Size() + + pkt.timestamp = e.stack.NowNanoseconds() + + e.rcvMu.Unlock() + + // Notify any waiters that there's data to be read now. + if wasEmpty { + e.waiterQueue.Notify(waiter.EventIn) + } +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go new file mode 100644 index 000000000..74e8e9fd5 --- /dev/null +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -0,0 +1,112 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// saveData saves udpPacket.data field. +func (u *udpPacket) saveData() buffer.VectorisedView { + // We cannot save u.data directly as u.data.views may alias to u.views, + // which is not allowed by state framework (in-struct pointer). + return u.data.Clone(nil) +} + +// loadData loads udpPacket.data field. +func (u *udpPacket) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing u.views for data.views. + u.data = data +} + +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after savercvBufSizeMax(), which would have + // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + e.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (e *endpoint) saveRcvBufSizeMax() int { + max := e.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + e.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + e.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (e *endpoint) loadRcvBufSizeMax(max int) { + e.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (e *endpoint) afterLoad() { + e.stack = stack.StackFromEnv + + for _, m := range e.multicastMemberships { + if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil { + panic(err) + } + } + + if e.state != stateBound && e.state != stateConnected { + return + } + + netProto := e.effectiveNetProtos[0] + // Connect() and bindLocked() both assert + // + // netProto == header.IPv6ProtocolNumber + // + // before creating a multi-entry effectiveNetProtos. + if len(e.effectiveNetProtos) > 1 { + netProto = header.IPv6ProtocolNumber + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.id.LocalAddress, e.id.RemoteAddress, netProto, e.multicastLoop) + if err != nil { + panic(*err) + } + + e.id.LocalAddress = e.route.LocalAddress + } else if len(e.id.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + // Our saved state had a port, but we don't actually have a + // reservation. We need to remove the port from our state, but still + // pass it to the reservation machinery. + id := e.id + e.id.LocalPort = 0 + e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, id) + if err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go new file mode 100644 index 000000000..25bdd2929 --- /dev/null +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -0,0 +1,96 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package udp + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// Forwarder is a session request forwarder, which allows clients to decide +// what to do with a session request, for example: ignore it, or process it. +// +// The canonical way of using it is to pass the Forwarder.HandlePacket function +// to stack.SetTransportProtocolHandler. +type Forwarder struct { + handler func(*ForwarderRequest) + + stack *stack.Stack +} + +// NewForwarder allocates and initializes a new forwarder. +func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { + return &Forwarder{ + stack: s, + handler: handler, + } +} + +// HandlePacket handles all packets. +// +// This function is expected to be passed as an argument to the +// stack.SetTransportProtocolHandler function. +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool { + f.handler(&ForwarderRequest{ + stack: f.stack, + route: r, + id: id, + vv: vv, + }) + + return true +} + +// ForwarderRequest represents a session request received by the forwarder and +// passed to the client. Clients may optionally create an endpoint to represent +// it via CreateEndpoint. +type ForwarderRequest struct { + stack *stack.Stack + route *stack.Route + id stack.TransportEndpointID + vv buffer.VectorisedView +} + +// ID returns the 4-tuple (src address, src port, dst address, dst port) that +// represents the session request. +func (r *ForwarderRequest) ID() stack.TransportEndpointID { + return r.id +} + +// CreateEndpoint creates a connected UDP endpoint for the session request. +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := newEndpoint(r.stack, r.route.NetProto, queue) + if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.reusePort); err != nil { + ep.Close() + return nil, err + } + + ep.id = r.id + ep.route = r.route.Clone() + ep.dstPort = r.id.RemotePort + ep.regNICID = r.route.NICID() + + ep.state = stateConnected + + ep.rcvMu.Lock() + ep.rcvReady = true + ep.rcvMu.Unlock() + + ep.HandlePacket(r.route, r.id, r.vv) + + return ep, nil +} diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go new file mode 100644 index 000000000..3d31dfbf1 --- /dev/null +++ b/pkg/tcpip/transport/udp/protocol.go @@ -0,0 +1,90 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package udp contains the implementation of the UDP transport protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing udp.ProtocolName (or "udp") as one of the +// transport protocols when calling stack.New(). Then endpoints can be created +// by passing udp.ProtocolNumber as the transport protocol number when calling +// Stack.NewEndpoint(). +package udp + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/raw" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // ProtocolName is the string representation of the udp protocol name. + ProtocolName = "udp" + + // ProtocolNumber is the udp protocol number. + ProtocolNumber = header.UDPProtocolNumber +) + +type protocol struct{} + +// Number returns the udp protocol number. +func (*protocol) Number() tcpip.TransportProtocolNumber { + return ProtocolNumber +} + +// NewEndpoint creates a new udp endpoint. +func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, waiterQueue), nil +} + +// NewRawEndpoint creates a new raw UDP endpoint. It implements +// stack.TransportProtocol.NewRawEndpoint. +func (p *protocol) NewRawEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return raw.NewEndpoint(stack, netProto, header.UDPProtocolNumber, waiterQueue) +} + +// MinimumPacketSize returns the minimum valid udp packet size. +func (*protocol) MinimumPacketSize() int { + return header.UDPMinimumSize +} + +// ParsePorts returns the source and destination ports stored in the given udp +// packet. +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + h := header.UDP(v) + return h.SourcePort(), h.DestinationPort(), nil +} + +// HandleUnknownDestinationPacket handles packets targeted at this protocol but +// that don't match any existing endpoint. +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { + return true +} + +// SetOption implements TransportProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements TransportProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func init() { + stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { + return &protocol{} + }) +} diff --git a/pkg/tcpip/transport/udp/udp_packet_list.go b/pkg/tcpip/transport/udp/udp_packet_list.go new file mode 100755 index 000000000..673a9373b --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_packet_list.go @@ -0,0 +1,173 @@ +package udp + +// ElementMapper provides an identity mapping by default. +// +// This can be replaced to provide a struct that maps elements to linker +// objects, if they are not the same. An ElementMapper is not typically +// required if: Linker is left as is, Element is left as is, or Linker and +// Element are the same type. +type udpPacketElementMapper struct{} + +// linkerFor maps an Element to a Linker. +// +// This default implementation should be inlined. +// +//go:nosplit +func (udpPacketElementMapper) linkerFor(elem *udpPacket) *udpPacket { return elem } + +// List is an intrusive list. Entries can be added to or removed from the list +// in O(1) time and with no additional memory allocations. +// +// The zero value for List is an empty list ready to use. +// +// To iterate over a list (where l is a List): +// for e := l.Front(); e != nil; e = e.Next() { +// // do something with e. +// } +// +// +stateify savable +type udpPacketList struct { + head *udpPacket + tail *udpPacket +} + +// Reset resets list l to the empty state. +func (l *udpPacketList) Reset() { + l.head = nil + l.tail = nil +} + +// Empty returns true iff the list is empty. +func (l *udpPacketList) Empty() bool { + return l.head == nil +} + +// Front returns the first element of list l or nil. +func (l *udpPacketList) Front() *udpPacket { + return l.head +} + +// Back returns the last element of list l or nil. +func (l *udpPacketList) Back() *udpPacket { + return l.tail +} + +// PushFront inserts the element e at the front of list l. +func (l *udpPacketList) PushFront(e *udpPacket) { + udpPacketElementMapper{}.linkerFor(e).SetNext(l.head) + udpPacketElementMapper{}.linkerFor(e).SetPrev(nil) + + if l.head != nil { + udpPacketElementMapper{}.linkerFor(l.head).SetPrev(e) + } else { + l.tail = e + } + + l.head = e +} + +// PushBack inserts the element e at the back of list l. +func (l *udpPacketList) PushBack(e *udpPacket) { + udpPacketElementMapper{}.linkerFor(e).SetNext(nil) + udpPacketElementMapper{}.linkerFor(e).SetPrev(l.tail) + + if l.tail != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(e) + } else { + l.head = e + } + + l.tail = e +} + +// PushBackList inserts list m at the end of list l, emptying m. +func (l *udpPacketList) PushBackList(m *udpPacketList) { + if l.head == nil { + l.head = m.head + l.tail = m.tail + } else if m.head != nil { + udpPacketElementMapper{}.linkerFor(l.tail).SetNext(m.head) + udpPacketElementMapper{}.linkerFor(m.head).SetPrev(l.tail) + + l.tail = m.tail + } + + m.head = nil + m.tail = nil +} + +// InsertAfter inserts e after b. +func (l *udpPacketList) InsertAfter(b, e *udpPacket) { + a := udpPacketElementMapper{}.linkerFor(b).Next() + udpPacketElementMapper{}.linkerFor(e).SetNext(a) + udpPacketElementMapper{}.linkerFor(e).SetPrev(b) + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + + if a != nil { + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + } else { + l.tail = e + } +} + +// InsertBefore inserts e before a. +func (l *udpPacketList) InsertBefore(a, e *udpPacket) { + b := udpPacketElementMapper{}.linkerFor(a).Prev() + udpPacketElementMapper{}.linkerFor(e).SetNext(a) + udpPacketElementMapper{}.linkerFor(e).SetPrev(b) + udpPacketElementMapper{}.linkerFor(a).SetPrev(e) + + if b != nil { + udpPacketElementMapper{}.linkerFor(b).SetNext(e) + } else { + l.head = e + } +} + +// Remove removes e from l. +func (l *udpPacketList) Remove(e *udpPacket) { + prev := udpPacketElementMapper{}.linkerFor(e).Prev() + next := udpPacketElementMapper{}.linkerFor(e).Next() + + if prev != nil { + udpPacketElementMapper{}.linkerFor(prev).SetNext(next) + } else { + l.head = next + } + + if next != nil { + udpPacketElementMapper{}.linkerFor(next).SetPrev(prev) + } else { + l.tail = prev + } +} + +// Entry is a default implementation of Linker. Users can add anonymous fields +// of this type to their structs to make them automatically implement the +// methods needed by List. +// +// +stateify savable +type udpPacketEntry struct { + next *udpPacket + prev *udpPacket +} + +// Next returns the entry that follows e in the list. +func (e *udpPacketEntry) Next() *udpPacket { + return e.next +} + +// Prev returns the entry that precedes e in the list. +func (e *udpPacketEntry) Prev() *udpPacket { + return e.prev +} + +// SetNext assigns 'entry' as the entry that follows e in the list. +func (e *udpPacketEntry) SetNext(elem *udpPacket) { + e.next = elem +} + +// SetPrev assigns 'entry' as the entry that precedes e in the list. +func (e *udpPacketEntry) SetPrev(elem *udpPacket) { + e.prev = elem +} diff --git a/pkg/tcpip/transport/udp/udp_state_autogen.go b/pkg/tcpip/transport/udp/udp_state_autogen.go new file mode 100755 index 000000000..711e2feeb --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_state_autogen.go @@ -0,0 +1,128 @@ +// automatically generated by stateify. + +package udp + +import ( + "gvisor.googlesource.com/gvisor/pkg/state" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +func (x *udpPacket) beforeSave() {} +func (x *udpPacket) save(m state.Map) { + x.beforeSave() + var data buffer.VectorisedView = x.saveData() + m.SaveValue("data", data) + m.Save("udpPacketEntry", &x.udpPacketEntry) + m.Save("senderAddress", &x.senderAddress) + m.Save("timestamp", &x.timestamp) +} + +func (x *udpPacket) afterLoad() {} +func (x *udpPacket) load(m state.Map) { + m.Load("udpPacketEntry", &x.udpPacketEntry) + m.Load("senderAddress", &x.senderAddress) + m.Load("timestamp", &x.timestamp) + m.LoadValue("data", new(buffer.VectorisedView), func(y interface{}) { x.loadData(y.(buffer.VectorisedView)) }) +} + +func (x *endpoint) save(m state.Map) { + x.beforeSave() + var rcvBufSizeMax int = x.saveRcvBufSizeMax() + m.SaveValue("rcvBufSizeMax", rcvBufSizeMax) + m.Save("netProto", &x.netProto) + m.Save("waiterQueue", &x.waiterQueue) + m.Save("rcvReady", &x.rcvReady) + m.Save("rcvList", &x.rcvList) + m.Save("rcvBufSize", &x.rcvBufSize) + m.Save("rcvClosed", &x.rcvClosed) + m.Save("sndBufSize", &x.sndBufSize) + m.Save("id", &x.id) + m.Save("state", &x.state) + m.Save("bindNICID", &x.bindNICID) + m.Save("regNICID", &x.regNICID) + m.Save("dstPort", &x.dstPort) + m.Save("v6only", &x.v6only) + m.Save("multicastTTL", &x.multicastTTL) + m.Save("multicastAddr", &x.multicastAddr) + m.Save("multicastNICID", &x.multicastNICID) + m.Save("multicastLoop", &x.multicastLoop) + m.Save("reusePort", &x.reusePort) + m.Save("broadcast", &x.broadcast) + m.Save("shutdownFlags", &x.shutdownFlags) + m.Save("multicastMemberships", &x.multicastMemberships) + m.Save("effectiveNetProtos", &x.effectiveNetProtos) +} + +func (x *endpoint) load(m state.Map) { + m.Load("netProto", &x.netProto) + m.Load("waiterQueue", &x.waiterQueue) + m.Load("rcvReady", &x.rcvReady) + m.Load("rcvList", &x.rcvList) + m.Load("rcvBufSize", &x.rcvBufSize) + m.Load("rcvClosed", &x.rcvClosed) + m.Load("sndBufSize", &x.sndBufSize) + m.Load("id", &x.id) + m.Load("state", &x.state) + m.Load("bindNICID", &x.bindNICID) + m.Load("regNICID", &x.regNICID) + m.Load("dstPort", &x.dstPort) + m.Load("v6only", &x.v6only) + m.Load("multicastTTL", &x.multicastTTL) + m.Load("multicastAddr", &x.multicastAddr) + m.Load("multicastNICID", &x.multicastNICID) + m.Load("multicastLoop", &x.multicastLoop) + m.Load("reusePort", &x.reusePort) + m.Load("broadcast", &x.broadcast) + m.Load("shutdownFlags", &x.shutdownFlags) + m.Load("multicastMemberships", &x.multicastMemberships) + m.Load("effectiveNetProtos", &x.effectiveNetProtos) + m.LoadValue("rcvBufSizeMax", new(int), func(y interface{}) { x.loadRcvBufSizeMax(y.(int)) }) + m.AfterLoad(x.afterLoad) +} + +func (x *multicastMembership) beforeSave() {} +func (x *multicastMembership) save(m state.Map) { + x.beforeSave() + m.Save("nicID", &x.nicID) + m.Save("multicastAddr", &x.multicastAddr) +} + +func (x *multicastMembership) afterLoad() {} +func (x *multicastMembership) load(m state.Map) { + m.Load("nicID", &x.nicID) + m.Load("multicastAddr", &x.multicastAddr) +} + +func (x *udpPacketList) beforeSave() {} +func (x *udpPacketList) save(m state.Map) { + x.beforeSave() + m.Save("head", &x.head) + m.Save("tail", &x.tail) +} + +func (x *udpPacketList) afterLoad() {} +func (x *udpPacketList) load(m state.Map) { + m.Load("head", &x.head) + m.Load("tail", &x.tail) +} + +func (x *udpPacketEntry) beforeSave() {} +func (x *udpPacketEntry) save(m state.Map) { + x.beforeSave() + m.Save("next", &x.next) + m.Save("prev", &x.prev) +} + +func (x *udpPacketEntry) afterLoad() {} +func (x *udpPacketEntry) load(m state.Map) { + m.Load("next", &x.next) + m.Load("prev", &x.prev) +} + +func init() { + state.Register("udp.udpPacket", (*udpPacket)(nil), state.Fns{Save: (*udpPacket).save, Load: (*udpPacket).load}) + state.Register("udp.endpoint", (*endpoint)(nil), state.Fns{Save: (*endpoint).save, Load: (*endpoint).load}) + state.Register("udp.multicastMembership", (*multicastMembership)(nil), state.Fns{Save: (*multicastMembership).save, Load: (*multicastMembership).load}) + state.Register("udp.udpPacketList", (*udpPacketList)(nil), state.Fns{Save: (*udpPacketList).save, Load: (*udpPacketList).load}) + state.Register("udp.udpPacketEntry", (*udpPacketEntry)(nil), state.Fns{Save: (*udpPacketEntry).save, Load: (*udpPacketEntry).load}) +} |