diff options
Diffstat (limited to 'pkg/tcpip/transport')
35 files changed, 13739 insertions, 0 deletions
diff --git a/pkg/tcpip/transport/queue/BUILD b/pkg/tcpip/transport/queue/BUILD new file mode 100644 index 000000000..162af574c --- /dev/null +++ b/pkg/tcpip/transport/queue/BUILD @@ -0,0 +1,29 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools/go_stateify:defs.bzl", "go_stateify") + +go_stateify( + name = "queue_state", + srcs = [ + "queue.go", + ], + out = "queue_state.go", + package = "queue", +) + +go_library( + name = "queue", + srcs = [ + "queue.go", + "queue_state.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/queue", + visibility = ["//:sandbox"], + deps = [ + "//pkg/ilist", + "//pkg/state", + "//pkg/tcpip", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/queue/queue.go b/pkg/tcpip/transport/queue/queue.go new file mode 100644 index 000000000..2d2918504 --- /dev/null +++ b/pkg/tcpip/transport/queue/queue.go @@ -0,0 +1,166 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package queue provides the implementation of buffer queue +// and interface of queue entry with Length method. +package queue + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/ilist" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// Entry implements Linker interface and has both Length and Release methods. +type Entry interface { + ilist.Linker + Length() int64 + Release() + Peek() Entry +} + +// Queue is a buffer queue. +type Queue struct { + ReaderQueue *waiter.Queue + WriterQueue *waiter.Queue + + mu sync.Mutex `state:"nosave"` + closed bool + used int64 + limit int64 + dataList ilist.List +} + +// New allocates and initializes a new queue. +func New(ReaderQueue *waiter.Queue, WriterQueue *waiter.Queue, limit int64) *Queue { + return &Queue{ReaderQueue: ReaderQueue, WriterQueue: WriterQueue, limit: limit} +} + +// Close closes q for reading and writing. It is immediately not writable and +// will become unreadble will no more data is pending. +// +// Both the read and write queues must be notified after closing: +// q.ReaderQueue.Notify(waiter.EventIn) +// q.WriterQueue.Notify(waiter.EventOut) +func (q *Queue) Close() { + q.mu.Lock() + q.closed = true + q.mu.Unlock() +} + +// Reset empties the queue and Releases all of the Entries. +// +// Both the read and write queues must be notified after resetting: +// q.ReaderQueue.Notify(waiter.EventIn) +// q.WriterQueue.Notify(waiter.EventOut) +func (q *Queue) Reset() { + q.mu.Lock() + for cur := q.dataList.Front(); cur != nil; cur = cur.Next() { + cur.(Entry).Release() + } + q.dataList.Reset() + q.used = 0 + q.mu.Unlock() +} + +// IsReadable determines if q is currently readable. +func (q *Queue) IsReadable() bool { + q.mu.Lock() + defer q.mu.Unlock() + + return q.closed || q.dataList.Front() != nil +} + +// IsWritable determines if q is currently writable. +func (q *Queue) IsWritable() bool { + q.mu.Lock() + defer q.mu.Unlock() + + return q.closed || q.used < q.limit +} + +// Enqueue adds an entry to the data queue if room is available. +// +// If notify is true, ReaderQueue.Notify must be called: +// q.ReaderQueue.Notify(waiter.EventIn) +func (q *Queue) Enqueue(e Entry) (notify bool, err *tcpip.Error) { + q.mu.Lock() + + if q.closed { + q.mu.Unlock() + return false, tcpip.ErrClosedForSend + } + + if q.used >= q.limit { + q.mu.Unlock() + return false, tcpip.ErrWouldBlock + } + + notify = q.dataList.Front() == nil + q.used += e.Length() + q.dataList.PushBack(e) + + q.mu.Unlock() + + return notify, nil +} + +// Dequeue removes the first entry in the data queue, if one exists. +// +// If notify is true, WriterQueue.Notify must be called: +// q.WriterQueue.Notify(waiter.EventOut) +func (q *Queue) Dequeue() (e Entry, notify bool, err *tcpip.Error) { + q.mu.Lock() + + if q.dataList.Front() == nil { + err := tcpip.ErrWouldBlock + if q.closed { + err = tcpip.ErrClosedForReceive + } + q.mu.Unlock() + + return nil, false, err + } + + notify = q.used >= q.limit + + e = q.dataList.Front().(Entry) + q.dataList.Remove(e) + q.used -= e.Length() + + notify = notify && q.used < q.limit + + q.mu.Unlock() + + return e, notify, nil +} + +// Peek returns the first entry in the data queue, if one exists. +func (q *Queue) Peek() (Entry, *tcpip.Error) { + q.mu.Lock() + defer q.mu.Unlock() + + if q.dataList.Front() == nil { + err := tcpip.ErrWouldBlock + if q.closed { + err = tcpip.ErrClosedForReceive + } + return nil, err + } + + return q.dataList.Front().(Entry).Peek(), nil +} + +// QueuedSize returns the number of bytes currently in the queue, that is, the +// number of readable bytes. +func (q *Queue) QueuedSize() int64 { + return q.used +} + +// MaxQueueSize returns the maximum number of bytes storable in the queue. +func (q *Queue) MaxQueueSize() int64 { + return q.limit +} diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD new file mode 100644 index 000000000..d0eb8b8bd --- /dev/null +++ b/pkg/tcpip/transport/tcp/BUILD @@ -0,0 +1,97 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_stateify") + +go_stateify( + name = "tcp_state", + srcs = [ + "endpoint.go", + "endpoint_state.go", + "rcv.go", + "segment_heap.go", + "snd.go", + "tcp_segment_list.go", + ], + out = "tcp_state.go", + package = "tcp", +) + +go_template_instance( + name = "tcp_segment_list", + out = "tcp_segment_list.go", + package = "tcp", + prefix = "segment", + template = "//pkg/ilist:generic_list", + types = { + "Linker": "*segment", + }, +) + +go_library( + name = "tcp", + srcs = [ + "accept.go", + "connect.go", + "endpoint.go", + "endpoint_state.go", + "forwarder.go", + "protocol.go", + "rcv.go", + "sack.go", + "segment.go", + "segment_heap.go", + "segment_queue.go", + "snd.go", + "tcp_segment_list.go", + "tcp_state.go", + "timer.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sleep", + "//pkg/state", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + "//pkg/tcpip/stack", + "//pkg/tmutex", + "//pkg/waiter", + ], +) + +go_test( + name = "tcp_test", + size = "small", + srcs = [ + "dual_stack_test.go", + "tcp_sack_test.go", + "tcp_test.go", + "tcp_timestamp_test.go", + ], + deps = [ + ":tcp", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/loopback", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/seqnum", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp/testing/context", + "//pkg/waiter", + ], +) + +filegroup( + name = "autogen", + srcs = [ + "tcp_segment_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go new file mode 100644 index 000000000..9a5b13066 --- /dev/null +++ b/pkg/tcpip/transport/tcp/accept.go @@ -0,0 +1,407 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "crypto/rand" + "crypto/sha1" + "encoding/binary" + "hash" + "io" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // tsLen is the length, in bits, of the timestamp in the SYN cookie. + tsLen = 8 + + // tsMask is a mask for timestamp values (i.e., tsLen bits). + tsMask = (1 << tsLen) - 1 + + // tsOffset is the offset, in bits, of the timestamp in the SYN cookie. + tsOffset = 24 + + // hashMask is the mask for hash values (i.e., tsOffset bits). + hashMask = (1 << tsOffset) - 1 + + // maxTSDiff is the maximum allowed difference between a received cookie + // timestamp and the current timestamp. If the difference is greater + // than maxTSDiff, the cookie is expired. + maxTSDiff = 2 +) + +var ( + // SynRcvdCountThreshold is the global maximum number of connections + // that are allowed to be in SYN-RCVD state before TCP starts using SYN + // cookies to accept connections. + // + // It is an exported variable only for testing, and should not otherwise + // be used by importers of this package. + SynRcvdCountThreshold uint64 = 1000 + + // mssTable is a slice containing the possible MSS values that we + // encode in the SYN cookie with two bits. + mssTable = []uint16{536, 1300, 1440, 1460} +) + +func encodeMSS(mss uint16) uint32 { + for i := len(mssTable) - 1; i > 0; i-- { + if mss >= mssTable[i] { + return uint32(i) + } + } + return 0 +} + +// syncRcvdCount is the number of endpoints in the SYN-RCVD state. The value is +// protected by a mutex so that we can increment only when it's guaranteed not +// to go above a threshold. +var synRcvdCount struct { + sync.Mutex + value uint64 +} + +// listenContext is used by a listening endpoint to store state used while +// listening for connections. This struct is allocated by the listen goroutine +// and must not be accessed or have its methods called concurrently as they +// may mutate the stored objects. +type listenContext struct { + stack *stack.Stack + rcvWnd seqnum.Size + nonce [2][sha1.BlockSize]byte + + hasherMu sync.Mutex + hasher hash.Hash + v6only bool + netProto tcpip.NetworkProtocolNumber +} + +// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds. +func timeStamp() uint32 { + return uint32(time.Now().Unix()>>6) & tsMask +} + +// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD +// state. It succeeds if the increment doesn't make the count go beyond the +// threshold, and fails otherwise. +func incSynRcvdCount() bool { + synRcvdCount.Lock() + defer synRcvdCount.Unlock() + + if synRcvdCount.value >= SynRcvdCountThreshold { + return false + } + + synRcvdCount.value++ + + return true +} + +// decSynRcvdCount atomically decrements the global number of endpoints in +// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount +// succeeded. +func decSynRcvdCount() { + synRcvdCount.Lock() + defer synRcvdCount.Unlock() + + synRcvdCount.value-- +} + +// newListenContext creates a new listen context. +func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext { + l := &listenContext{ + stack: stack, + rcvWnd: rcvWnd, + hasher: sha1.New(), + v6only: v6only, + netProto: netProto, + } + + rand.Read(l.nonce[0][:]) + rand.Read(l.nonce[1][:]) + + return l +} + +// cookieHash calculates the cookieHash for the given id, timestamp and nonce +// index. The hash is used to create and validate cookies. +func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 { + + // Initialize block with fixed-size data: local ports and v. + var payload [8]byte + binary.BigEndian.PutUint16(payload[0:], id.LocalPort) + binary.BigEndian.PutUint16(payload[2:], id.RemotePort) + binary.BigEndian.PutUint32(payload[4:], ts) + + // Feed everything to the hasher. + l.hasherMu.Lock() + l.hasher.Reset() + l.hasher.Write(payload[:]) + l.hasher.Write(l.nonce[nonceIndex][:]) + io.WriteString(l.hasher, string(id.LocalAddress)) + io.WriteString(l.hasher, string(id.RemoteAddress)) + + // Finalize the calculation of the hash and return the first 4 bytes. + h := make([]byte, 0, sha1.Size) + h = l.hasher.Sum(h) + l.hasherMu.Unlock() + + return binary.BigEndian.Uint32(h[:]) +} + +// createCookie creates a SYN cookie for the given id and incoming sequence +// number. +func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value { + ts := timeStamp() + v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset) + v += (l.cookieHash(id, ts, 1) + data) & hashMask + return seqnum.Value(v) +} + +// isCookieValid checks if the supplied cookie is valid for the given id and +// sequence number. If it is, it also returns the data originally encoded in the +// cookie when createCookie was called. +func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) { + ts := timeStamp() + v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq) + cookieTS := v >> tsOffset + if ((ts - cookieTS) & tsMask) > maxTSDiff { + return 0, false + } + + return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true +} + +// createConnectedEndpoint creates a new connected endpoint, with the connection +// parameters given by the arguments. +func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { + // Create a new endpoint. + netProto := l.netProto + if netProto == 0 { + netProto = s.route.NetProto + } + n := newEndpoint(l.stack, netProto, nil) + n.v6only = l.v6only + n.id = s.id + n.boundNICID = s.route.NICID() + n.route = s.route.Clone() + n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} + n.rcvBufSize = int(l.rcvWnd) + + n.maybeEnableTimestamp(rcvdSynOpts) + n.maybeEnableSACKPermitted(rcvdSynOpts) + + // Register new endpoint so that packets are routed to it. + if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil { + n.Close() + return nil, err + } + + n.isRegistered = true + n.state = stateConnected + + // Create sender and receiver. + // + // The receiver at least temporarily has a zero receive window scale, + // but the caller may change it (before starting the protocol loop). + n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS) + n.rcv = newReceiver(n, irs, l.rcvWnd, 0) + + return n, nil +} + +// createEndpoint creates a new endpoint in connected state and then performs +// the TCP 3-way handshake. +func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) { + // Create new endpoint. + irs := s.sequenceNumber + cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS)) + ep, err := l.createConnectedEndpoint(s, cookie, irs, opts) + if err != nil { + return nil, err + } + + // Perform the 3-way handshake. + h, err := newHandshake(ep, l.rcvWnd) + if err != nil { + ep.Close() + return nil, err + } + + h.resetToSynRcvd(cookie, irs, opts) + if err := h.execute(); err != nil { + ep.Close() + return nil, err + } + + // Update the receive window scaling. We can't do it before the + // handshake because it's possible that the peer doesn't support window + // scaling. + ep.rcv.rcvWndScale = h.effectiveRcvWndScale() + + return ep, nil +} + +// deliverAccepted delivers the newly-accepted endpoint to the listener. If the +// endpoint has transitioned out of the listen state, the new endpoint is closed +// instead. +func (e *endpoint) deliverAccepted(n *endpoint) { + e.mu.RLock() + if e.state == stateListen { + e.acceptedChan <- n + e.waiterQueue.Notify(waiter.EventIn) + } else { + n.Close() + } + e.mu.RUnlock() +} + +// handleSynSegment is called in its own goroutine once the listening endpoint +// receives a SYN segment. It is responsible for completing the handshake and +// queueing the new endpoint for acceptance. +// +// A limited number of these goroutines are allowed before TCP starts using SYN +// cookies to accept connections. +func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { + defer decSynRcvdCount() + defer s.decRef() + + n, err := ctx.createEndpointAndPerformHandshake(s, opts) + if err != nil { + return + } + + e.deliverAccepted(n) +} + +// handleListenSegment is called when a listening endpoint receives a segment +// and needs to handle it. +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { + switch s.flags { + case flagSyn: + opts := parseSynSegmentOptions(s) + if incSynRcvdCount() { + s.incRef() + go e.handleSynSegment(ctx, s, &opts) // S/R-FIXME + } else { + cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + // Send SYN with window scaling because we currently + // dont't encode this information in the cookie. + // + // Enable Timestamp option if the original syn did have + // the timestamp option specified. + synOpts := header.TCPSynOptions{ + WS: -1, + TS: opts.TS, + TSVal: tcpTimeStamp(timeStampOffset()), + TSEcr: opts.TSVal, + } + sendSynTCP(&s.route, s.id, flagSyn|flagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts) + } + + case flagAck: + if data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1); ok && int(data) < len(mssTable) { + // Create newly accepted endpoint and deliver it. + rcvdSynOptions := &header.TCPSynOptions{ + MSS: mssTable[data], + // Disable Window scaling as original SYN is + // lost. + WS: -1, + } + // When syn cookies are in use we enable timestamp only + // if the ack specifies the timestamp option assuming + // that the other end did in fact negotiate the + // timestamp option in the original SYN. + if s.parsedOptions.TS { + rcvdSynOptions.TS = true + rcvdSynOptions.TSVal = s.parsedOptions.TSVal + rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr + } + n, err := ctx.createConnectedEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions) + if err == nil { + // clear the tsOffset for the newly created + // endpoint as the Timestamp was already + // randomly offset when the original SYN-ACK was + // sent above. + n.tsOffset = 0 + e.deliverAccepted(n) + } + } + } +} + +// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in +// its own goroutine and is responsible for handling connection requests. +func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { + defer func() { + // Mark endpoint as closed. This will prevent goroutines running + // handleSynSegment() from attempting to queue new connections + // to the endpoint. + e.mu.Lock() + e.state = stateClosed + e.mu.Unlock() + + // Notify waiters that the endpoint is shutdown. + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + + // Do cleanup if needed. + e.completeWorker() + }() + + e.mu.Lock() + v6only := e.v6only + e.mu.Unlock() + + ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto) + + s := sleep.Sleeper{} + s.AddWaker(&e.notificationWaker, wakerForNotification) + s.AddWaker(&e.newSegmentWaker, wakerForNewSegment) + for { + switch index, _ := s.Fetch(true); index { + case wakerForNotification: + n := e.fetchNotifications() + if n¬ifyClose != 0 { + return nil + } + if n¬ifyDrain != 0 { + for s := e.segmentQueue.dequeue(); s != nil; s = e.segmentQueue.dequeue() { + e.handleListenSegment(ctx, s) + s.decRef() + } + e.drainDone <- struct{}{} + return nil + } + + case wakerForNewSegment: + // Process at most maxSegmentsPerWake segments. + mayRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + mayRequeue = false + break + } + + e.handleListenSegment(ctx, s) + s.decRef() + } + + // If the queue is not empty, make sure we'll wake up + // in the next iteration. + if mayRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + } + } +} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go new file mode 100644 index 000000000..4d20f4d3f --- /dev/null +++ b/pkg/tcpip/transport/tcp/connect.go @@ -0,0 +1,953 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "crypto/rand" + "sync" + "sync/atomic" + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// maxSegmentsPerWake is the maximum number of segments to process in the main +// protocol goroutine per wake-up. Yielding [after this number of segments are +// processed] allows other events to be processed as well (e.g., timeouts, +// resets, etc.). +const maxSegmentsPerWake = 100 + +type handshakeState int + +// The following are the possible states of the TCP connection during a 3-way +// handshake. A depiction of the states and transitions can be found in RFC 793, +// page 23. +const ( + handshakeSynSent handshakeState = iota + handshakeSynRcvd + handshakeCompleted +) + +// The following are used to set up sleepers. +const ( + wakerForNotification = iota + wakerForNewSegment + wakerForResend + wakerForResolution +) + +const ( + // Maximum space available for options. + maxOptionSize = 40 +) + +// handshake holds the state used during a TCP 3-way handshake. +type handshake struct { + ep *endpoint + state handshakeState + active bool + flags uint8 + ackNum seqnum.Value + + // iss is the initial send sequence number, as defined in RFC 793. + iss seqnum.Value + + // rcvWnd is the receive window, as defined in RFC 793. + rcvWnd seqnum.Size + + // sndWnd is the send window, as defined in RFC 793. + sndWnd seqnum.Size + + // mss is the maximum segment size received from the peer. + mss uint16 + + // sndWndScale is the send window scale, as defined in RFC 1323. A + // negative value means no scaling is supported by the peer. + sndWndScale int + + // rcvWndScale is the receive window scale, as defined in RFC 1323. + rcvWndScale int +} + +func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) { + h := handshake{ + ep: ep, + active: true, + rcvWnd: rcvWnd, + rcvWndScale: FindWndScale(rcvWnd), + } + if err := h.resetState(); err != nil { + return handshake{}, err + } + + return h, nil +} + +// FindWndScale determines the window scale to use for the given maximum window +// size. +func FindWndScale(wnd seqnum.Size) int { + if wnd < 0x10000 { + return 0 + } + + max := seqnum.Size(0xffff) + s := 0 + for wnd > max && s < header.MaxWndScale { + s++ + max <<= 1 + } + + return s +} + +// resetState resets the state of the handshake object such that it becomes +// ready for a new 3-way handshake. +func (h *handshake) resetState() *tcpip.Error { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + + h.state = handshakeSynSent + h.flags = flagSyn + h.ackNum = 0 + h.mss = 0 + h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24) + + return nil +} + +// effectiveRcvWndScale returns the effective receive window scale to be used. +// If the peer doesn't support window scaling, the effective rcv wnd scale is +// zero; otherwise it's the value calculated based on the initial rcv wnd. +func (h *handshake) effectiveRcvWndScale() uint8 { + if h.sndWndScale < 0 { + return 0 + } + return uint8(h.rcvWndScale) +} + +// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD +// state. +func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) { + h.active = false + h.state = handshakeSynRcvd + h.flags = flagSyn | flagAck + h.iss = iss + h.ackNum = irs + 1 + h.mss = opts.MSS + h.sndWndScale = opts.WS +} + +// checkAck checks if the ACK number, if present, of a segment received during +// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in +// response. +func (h *handshake) checkAck(s *segment) bool { + if s.flagIsSet(flagAck) && s.ackNumber != h.iss+1 { + // RFC 793, page 36, states that a reset must be generated when + // the connection is in any non-synchronized state and an + // incoming segment acknowledges something not yet sent. The + // connection remains in the same state. + ack := s.sequenceNumber.Add(s.logicalLen()) + h.ep.sendRaw(nil, flagRst|flagAck, s.ackNumber, ack, 0) + return false + } + + return true +} + +// synSentState handles a segment received when the TCP 3-way handshake is in +// the SYN-SENT state. +func (h *handshake) synSentState(s *segment) *tcpip.Error { + // RFC 793, page 37, states that in the SYN-SENT state, a reset is + // acceptable if the ack field acknowledges the SYN. + if s.flagIsSet(flagRst) { + if s.flagIsSet(flagAck) && s.ackNumber == h.iss+1 { + return tcpip.ErrConnectionRefused + } + return nil + } + + if !h.checkAck(s) { + return nil + } + + // We are in the SYN-SENT state. We only care about segments that have + // the SYN flag. + if !s.flagIsSet(flagSyn) { + return nil + } + + // Parse the SYN options. + rcvSynOpts := parseSynSegmentOptions(s) + + // Remember if the Timestamp option was negotiated. + h.ep.maybeEnableTimestamp(&rcvSynOpts) + + // Remember if the SACKPermitted option was negotiated. + h.ep.maybeEnableSACKPermitted(&rcvSynOpts) + + // Remember the sequence we'll ack from now on. + h.ackNum = s.sequenceNumber + 1 + h.flags |= flagAck + h.mss = rcvSynOpts.MSS + h.sndWndScale = rcvSynOpts.WS + + // If this is a SYN ACK response, we only need to acknowledge the SYN + // and the handshake is completed. + if s.flagIsSet(flagAck) { + h.state = handshakeCompleted + h.ep.sendRaw(nil, flagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale()) + return nil + } + + // A SYN segment was received, but no ACK in it. We acknowledge the SYN + // but resend our own SYN and wait for it to be acknowledged in the + // SYN-RCVD state. + h.state = handshakeSynRcvd + synOpts := header.TCPSynOptions{ + WS: h.rcvWndScale, + TS: rcvSynOpts.TS, + TSVal: h.ep.timestamp(), + TSEcr: h.ep.recentTS, + + // We only send SACKPermitted if the other side indicated it + // permits SACK. This is not explicitly defined in the RFC but + // this is the behaviour implemented by Linux. + SACKPermitted: rcvSynOpts.SACKPermitted, + } + sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + return nil +} + +// synRcvdState handles a segment received when the TCP 3-way handshake is in +// the SYN-RCVD state. +func (h *handshake) synRcvdState(s *segment) *tcpip.Error { + if s.flagIsSet(flagRst) { + // RFC 793, page 37, states that in the SYN-RCVD state, a reset + // is acceptable if the sequence number is in the window. + if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) { + return tcpip.ErrConnectionRefused + } + return nil + } + + if !h.checkAck(s) { + return nil + } + + if s.flagIsSet(flagSyn) && s.sequenceNumber != h.ackNum-1 { + // We received two SYN segments with different sequence + // numbers, so we reset this and restart the whole + // process, except that we don't reset the timer. + ack := s.sequenceNumber.Add(s.logicalLen()) + seq := seqnum.Value(0) + if s.flagIsSet(flagAck) { + seq = s.ackNumber + } + h.ep.sendRaw(nil, flagRst|flagAck, seq, ack, 0) + + if !h.active { + return tcpip.ErrInvalidEndpointState + } + + if err := h.resetState(); err != nil { + return err + } + synOpts := header.TCPSynOptions{ + WS: h.rcvWndScale, + TS: h.ep.sendTSOk, + TSVal: h.ep.timestamp(), + TSEcr: h.ep.recentTS, + SACKPermitted: h.ep.sackPermitted, + } + sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + return nil + } + + // We have previously received (and acknowledged) the peer's SYN. If the + // peer acknowledges our SYN, the handshake is completed. + if s.flagIsSet(flagAck) { + + // If the timestamp option is negotiated and the segment does + // not carry a timestamp option then the segment must be dropped + // as per https://tools.ietf.org/html/rfc7323#section-3.2. + if h.ep.sendTSOk && !s.parsedOptions.TS { + atomic.AddUint64(&h.ep.stack.MutableStats().DroppedPackets, 1) + return nil + } + + // Update timestamp if required. See RFC7323, section-4.3. + h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber) + + h.state = handshakeCompleted + return nil + } + + return nil +} + +// processSegments goes through the segment queue and processes up to +// maxSegmentsPerWake (if they're available). +func (h *handshake) processSegments() *tcpip.Error { + for i := 0; i < maxSegmentsPerWake; i++ { + s := h.ep.segmentQueue.dequeue() + if s == nil { + return nil + } + + h.sndWnd = s.window + if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 { + h.sndWnd <<= uint8(h.sndWndScale) + } + + var err *tcpip.Error + switch h.state { + case handshakeSynRcvd: + err = h.synRcvdState(s) + case handshakeSynSent: + err = h.synSentState(s) + } + s.decRef() + if err != nil { + return err + } + + // We stop processing packets once the handshake is completed, + // otherwise we may process packets meant to be processed by + // the main protocol goroutine. + if h.state == handshakeCompleted { + break + } + } + + // If the queue is not empty, make sure we'll wake up in the next + // iteration. + if !h.ep.segmentQueue.empty() { + h.ep.newSegmentWaker.Assert() + } + + return nil +} + +func (h *handshake) resolveRoute() *tcpip.Error { + // Set up the wakers. + s := sleep.Sleeper{} + resolutionWaker := &sleep.Waker{} + s.AddWaker(resolutionWaker, wakerForResolution) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + defer s.Done() + + // Initial action is to resolve route. + index := wakerForResolution + for { + switch index { + case wakerForResolution: + if err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock { + // Either success (err == nil) or failure. + return err + } + // Resolution not completed. Keep trying... + + case wakerForNotification: + n := h.ep.fetchNotifications() + if n¬ifyClose != 0 { + h.ep.route.RemoveWaker(resolutionWaker) + return tcpip.ErrAborted + } + } + + // Wait for notification. + index, _ = s.Fetch(true) + } +} + +// execute executes the TCP 3-way handshake. +func (h *handshake) execute() *tcpip.Error { + if h.ep.route.IsResolutionRequired() { + if err := h.resolveRoute(); err != nil { + return err + } + } + + // Initialize the resend timer. + resendWaker := sleep.Waker{} + timeOut := time.Duration(time.Second) + rt := time.AfterFunc(timeOut, func() { + resendWaker.Assert() + }) + defer rt.Stop() + + // Set up the wakers. + s := sleep.Sleeper{} + s.AddWaker(&resendWaker, wakerForResend) + s.AddWaker(&h.ep.notificationWaker, wakerForNotification) + s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment) + defer s.Done() + + var sackEnabled SACKEnabled + if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil { + // If stack returned an error when checking for SACKEnabled + // status then just default to switching off SACK negotiation. + sackEnabled = false + } + + // Send the initial SYN segment and loop until the handshake is + // completed. + synOpts := header.TCPSynOptions{ + WS: h.rcvWndScale, + TS: true, + TSVal: h.ep.timestamp(), + TSEcr: h.ep.recentTS, + SACKPermitted: bool(sackEnabled), + } + + // Execute is also called in a listen context so we want to make sure we + // only send the TS/SACK option when we received the TS/SACK in the + // initial SYN. + if h.state == handshakeSynRcvd { + synOpts.TS = h.ep.sendTSOk + synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled) + } + sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + for h.state != handshakeCompleted { + switch index, _ := s.Fetch(true); index { + case wakerForResend: + timeOut *= 2 + if timeOut > 60*time.Second { + return tcpip.ErrTimeout + } + rt.Reset(timeOut) + sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts) + + case wakerForNotification: + n := h.ep.fetchNotifications() + if n¬ifyClose != 0 { + return tcpip.ErrAborted + } + + case wakerForNewSegment: + if err := h.processSegments(); err != nil { + return err + } + } + } + + return nil +} + +func parseSynSegmentOptions(s *segment) header.TCPSynOptions { + synOpts := header.ParseSynOptions(s.options, s.flagIsSet(flagAck)) + if synOpts.TS { + s.parsedOptions.TSVal = synOpts.TSVal + s.parsedOptions.TSEcr = synOpts.TSEcr + } + return synOpts +} + +var optionPool = sync.Pool{ + New: func() interface{} { + return make([]byte, maxOptionSize) + }, +} + +func getOptions() []byte { + return optionPool.Get().([]byte) +} + +func putOptions(options []byte) { + // Reslice to full capacity. + optionPool.Put(options[0:cap(options)]) +} + +func makeSynOptions(opts header.TCPSynOptions) []byte { + // Emulate linux option order. This is as follows: + // + // if md5: NOP NOP MD5SIG 18 md5sig(16) + // if mss: MSS 4 mss(2) + // if ts and sack_advertise: + // SACK 2 TIMESTAMP 2 timestamp(8) + // elif ts: NOP NOP TIMESTAMP 10 timestamp(8) + // elif sack: NOP NOP SACK 2 + // if wscale: NOP WINDOW 3 ws(1) + // if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8)) + // [for each block] start_seq(4) end_seq(4) + // if fastopen_cookie: + // if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2) + // else: FASTOPEN (2 + len(cookie)) + // cookie(variable) [padding to four bytes] + // + options := getOptions() + + // Always encode the mss. + offset := header.EncodeMSSOption(uint32(opts.MSS), options) + + // Special ordering is required here. If both TS and SACK are enabled, + // then the SACK option precedes TS, with no padding. If they are + // enabled individually, then we see padding before the option. + if opts.TS && opts.SACKPermitted { + offset += header.EncodeSACKPermittedOption(options[offset:]) + offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) + } else if opts.TS { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:]) + } else if opts.SACKPermitted { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeSACKPermittedOption(options[offset:]) + } + + // Initialize the WS option. + if opts.WS >= 0 { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeWSOption(opts.WS, options[offset:]) + } + + // Padding to the end; note that this never apply unless we add a + // fastopen option, we always expect the offset to remain the same. + if delta := header.AddTCPOptionPadding(options, offset); delta != 0 { + panic("unexpected option encoding") + } + + return options[:offset] +} + +func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error { + // The MSS in opts is automatically calculated as this function is + // called from many places and we don't want every call point being + // embedded with the MSS calculation. + if opts.MSS == 0 { + opts.MSS = uint16(r.MTU() - header.TCPMinimumSize) + } + + options := makeSynOptions(opts) + err := sendTCPWithOptions(r, id, nil, flags, seq, ack, rcvWnd, options) + putOptions(options) + return err +} + +// sendTCPWithOptions sends a TCP segment with the provided options via the +// provided network endpoint and under the provided identity. +func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error { + optLen := len(opts) + // Allocate a buffer for the TCP header. + hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen) + + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + // Initialize the header. + tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen)) + tcp.Encode(&header.TCPFields{ + SrcPort: id.LocalPort, + DstPort: id.RemotePort, + SeqNum: uint32(seq), + AckNum: uint32(ack), + DataOffset: uint8(header.TCPMinimumSize + optLen), + Flags: flags, + WindowSize: uint16(rcvWnd), + }) + copy(tcp[header.TCPMinimumSize:], opts) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { + length := uint16(hdr.UsedLength()) + xsum := r.PseudoHeaderChecksum(ProtocolNumber) + if data != nil { + length += uint16(len(data)) + xsum = header.Checksum(data, xsum) + } + + tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) + } + + return r.WritePacket(&hdr, data, ProtocolNumber) +} + +// sendTCP sends a TCP segment via the provided network endpoint and under the +// provided identity. +func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { + // Allocate a buffer for the TCP header. + hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength())) + + if rcvWnd > 0xffff { + rcvWnd = 0xffff + } + + // Initialize the header. + tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize)) + tcp.Encode(&header.TCPFields{ + SrcPort: id.LocalPort, + DstPort: id.RemotePort, + SeqNum: uint32(seq), + AckNum: uint32(ack), + DataOffset: header.TCPMinimumSize, + Flags: flags, + WindowSize: uint16(rcvWnd), + }) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { + length := uint16(hdr.UsedLength()) + xsum := r.PseudoHeaderChecksum(ProtocolNumber) + if data != nil { + length += uint16(len(data)) + xsum = header.Checksum(data, xsum) + } + + tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length)) + } + + return r.WritePacket(&hdr, data, ProtocolNumber) +} + +// makeOptions makes an options slice. +func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { + options := getOptions() + offset := 0 + + // N.B. the ordering here matches the ordering used by Linux internally + // and described in the raw makeOptions function. We don't include + // unnecessary cases here (post connection.) + if e.sendTSOk { + // Embed the timestamp if timestamp has been enabled. + // + // We only use the lower 32 bits of the unix time in + // milliseconds. This is similar to what Linux does where it + // uses the lower 32 bits of the jiffies value in the tsVal + // field of the timestamp option. + // + // Further, RFC7323 section-5.4 recommends millisecond + // resolution as the lowest recommended resolution for the + // timestamp clock. + // + // Ref: https://tools.ietf.org/html/rfc7323#section-5.4. + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:]) + } + if e.sackPermitted && len(sackBlocks) > 0 { + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeNOP(options[offset:]) + offset += header.EncodeSACKBlocks(sackBlocks, options[offset:]) + } + + // We expect the above to produce an aligned offset. + if delta := header.AddTCPOptionPadding(options, offset); delta != 0 { + panic("unexpected option encoding") + } + + return options[:offset] +} + +// sendRaw sends a TCP segment to the endpoint's peer. +func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { + var sackBlocks []header.SACKBlock + if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&flagAck != 0) { + sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] + } + options := e.makeOptions(sackBlocks) + if len(options) > 0 { + err := sendTCPWithOptions(&e.route, e.id, data, flags, seq, ack, rcvWnd, options) + putOptions(options) + return err + } + err := sendTCP(&e.route, e.id, data, flags, seq, ack, rcvWnd) + putOptions(options) + return err +} + +func (e *endpoint) handleWrite() bool { + // Move packets from send queue to send list. The queue is accessible + // from other goroutines and protected by the send mutex, while the send + // list is only accessible from the handler goroutine, so it needs no + // mutexes. + e.sndBufMu.Lock() + + first := e.sndQueue.Front() + if first != nil { + e.snd.writeList.PushBackList(&e.sndQueue) + e.snd.sndNxtList.UpdateForward(e.sndBufInQueue) + e.sndBufInQueue = 0 + } + + e.sndBufMu.Unlock() + + // Initialize the next segment to write if it's currently nil. + if e.snd.writeNext == nil { + e.snd.writeNext = first + } + + // Push out any new packets. + e.snd.sendData() + + return true +} + +func (e *endpoint) handleClose() bool { + // Drain the send queue. + e.handleWrite() + + // Mark send side as closed. + e.snd.closed = true + + return true +} + +// resetConnection sends a RST segment and puts the endpoint in an error state +// with the given error code. +// This method must only be called from the protocol goroutine. +func (e *endpoint) resetConnection(err *tcpip.Error) { + e.sendRaw(nil, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) + + e.mu.Lock() + e.state = stateError + e.hardError = err + e.mu.Unlock() +} + +// completeWorker is called by the worker goroutine when it's about to exit. It +// marks the worker as completed and performs cleanup work if requested by +// Close(). +func (e *endpoint) completeWorker() { + e.mu.Lock() + defer e.mu.Unlock() + + e.workerRunning = false + if e.workerCleanup { + e.cleanup() + } +} + +// handleSegments pulls segments from the queue and processes them. It returns +// true if the protocol loop should continue, false otherwise. +func (e *endpoint) handleSegments() bool { + checkRequeue := true + for i := 0; i < maxSegmentsPerWake; i++ { + s := e.segmentQueue.dequeue() + if s == nil { + checkRequeue = false + break + } + + // Invoke the tcp probe if installed. + if e.probe != nil { + e.probe(e.completeState()) + } + + if s.flagIsSet(flagRst) { + if e.rcv.acceptable(s.sequenceNumber, 0) { + // RFC 793, page 37 states that "in all states + // except SYN-SENT, all reset (RST) segments are + // validated by checking their SEQ-fields." So + // we only process it if it's acceptable. + s.decRef() + e.mu.Lock() + e.state = stateError + e.hardError = tcpip.ErrConnectionReset + e.mu.Unlock() + return false + } + } else if s.flagIsSet(flagAck) { + // Patch the window size in the segment according to the + // send window scale. + s.window <<= e.snd.sndWndScale + + // If the timestamp option is negotiated and the segment + // does not carry a timestamp option then the segment + // must be dropped as per + // https://tools.ietf.org/html/rfc7323#section-3.2. + if e.sendTSOk && !s.parsedOptions.TS { + atomic.AddUint64(&e.stack.MutableStats().DroppedPackets, 1) + s.decRef() + continue + } + + // RFC 793, page 41 states that "once in the ESTABLISHED + // state all segments must carry current acknowledgment + // information." + e.rcv.handleRcvdSegment(s) + e.snd.handleRcvdSegment(s) + } + s.decRef() + } + + // If the queue is not empty, make sure we'll wake up in the next + // iteration. + if checkRequeue && !e.segmentQueue.empty() { + e.newSegmentWaker.Assert() + } + + // Send an ACK for all processed packets if needed. + if e.rcv.rcvNxt != e.snd.maxSentAck { + e.snd.sendAck() + } + + return true +} + +// protocolMainLoop is the main loop of the TCP protocol. It runs in its own +// goroutine and is responsible for sending segments and handling received +// segments. +func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error { + var closeTimer *time.Timer + var closeWaker sleep.Waker + + defer func() { + e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut) + e.completeWorker() + + if e.snd != nil { + e.snd.resendTimer.cleanup() + } + + if closeTimer != nil { + closeTimer.Stop() + } + }() + + if !passive { + // This is an active connection, so we must initiate the 3-way + // handshake, and then inform potential waiters about its + // completion. + h, err := newHandshake(e, seqnum.Size(e.receiveBufferAvailable())) + if err == nil { + err = h.execute() + } + if err != nil { + e.lastErrorMu.Lock() + e.lastError = err + e.lastErrorMu.Unlock() + + e.mu.Lock() + e.state = stateError + e.hardError = err + e.mu.Unlock() + + return err + } + + // Transfer handshake state to TCP connection. We disable + // receive window scaling if the peer doesn't support it + // (indicated by a negative send window scale). + e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale) + + e.rcvListMu.Lock() + e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale()) + e.rcvListMu.Unlock() + } + + // Tell waiters that the endpoint is connected and writable. + e.mu.Lock() + e.state = stateConnected + e.mu.Unlock() + + e.waiterQueue.Notify(waiter.EventOut) + + // When the protocol loop exits we should wake up our waiters with EventHUp. + defer e.waiterQueue.Notify(waiter.EventHUp) + + // Set up the functions that will be called when the main protocol loop + // wakes up. + funcs := []struct { + w *sleep.Waker + f func() bool + }{ + { + w: &e.sndWaker, + f: e.handleWrite, + }, + { + w: &e.sndCloseWaker, + f: e.handleClose, + }, + { + w: &e.newSegmentWaker, + f: e.handleSegments, + }, + { + w: &closeWaker, + f: func() bool { + e.resetConnection(tcpip.ErrConnectionAborted) + return false + }, + }, + { + w: &e.snd.resendWaker, + f: func() bool { + if !e.snd.retransmitTimerExpired() { + e.resetConnection(tcpip.ErrTimeout) + return false + } + return true + }, + }, + { + w: &e.notificationWaker, + f: func() bool { + n := e.fetchNotifications() + if n¬ifyNonZeroReceiveWindow != 0 { + e.rcv.nonZeroWindow() + } + + if n¬ifyReceiveWindowChanged != 0 { + e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize()) + } + + if n¬ifyMTUChanged != 0 { + e.sndBufMu.Lock() + count := e.packetTooBigCount + e.packetTooBigCount = 0 + mtu := e.sndMTU + e.sndBufMu.Unlock() + + e.snd.updateMaxPayloadSize(mtu, count) + } + + if n¬ifyClose != 0 && closeTimer == nil { + // Reset the connection 3 seconds after the + // endpoint has been closed. + closeTimer = time.AfterFunc(3*time.Second, func() { + closeWaker.Assert() + }) + } + return true + }, + }, + } + + // Initialize the sleeper based on the wakers in funcs. + s := sleep.Sleeper{} + for i := range funcs { + s.AddWaker(funcs[i].w, i) + } + + // Main loop. Handle segments until both send and receive ends of the + // connection have completed. + for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList { + e.workMu.Unlock() + v, _ := s.Fetch(true) + e.workMu.Lock() + if !funcs[v].f() { + return nil + } + } + + // Mark endpoint as closed. + e.mu.Lock() + e.state = stateClosed + e.mu.Unlock() + + return nil +} diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go new file mode 100644 index 000000000..a89af4559 --- /dev/null +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -0,0 +1,550 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp_test + +import ( + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/checker" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +func TestV4MappedConnectOnV6Only(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Start connection attempt, it must fail. + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) + if err != tcpip.ErrNoRoute { + t.Fatalf("Unexpected return value from Connect: %v", err) + } +} + +func testV4Connect(t *testing.T, c *context.Context) { + // Start connection attempt. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventOut) + defer c.WQ.EventUnregister(&we) + + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort}) + if err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: tcp.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK packet. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + ), + ) + + // Wait for connection to be established. + select { + case <-ch: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } +} + +func TestV4MappedConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Test the connection request. + testV4Connect(t, c) +} + +func TestV4ConnectWhenBoundToWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV4Connect(t, c) +} + +func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to v4 mapped wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV4Connect(t, c) +} + +func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to v4 mapped address. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV4Connect(t, c) +} + +func testV6Connect(t *testing.T, c *context.Context) { + // Start connection attempt to IPv6 address. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventOut) + defer c.WQ.EventUnregister(&we) + + err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort}) + if err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetV6Packet() + checker.IPv6(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + tcp := header.TCP(header.IPv6(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + iss := seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: tcp.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK packet. + checker.IPv6(t, c.GetV6Packet(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + ), + ) + + // Wait for connection to be established. + select { + case <-ch: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } +} + +func TestV6Connect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV6ConnectV6Only(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV6ConnectWhenBoundToWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to local address. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + +func TestV4RefuseOnV6Only(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(true) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the RST reply. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.AckNum(uint32(irs)+1), + ), + ) +} + +func TestV6RefuseOnBoundToV4Mapped(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind and listen. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the RST reply. + checker.IPv6(t, c.GetV6Packet(), + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck), + checker.AckNum(uint32(irs)+1), + ), + ) +} + +func testV4Accept(t *testing.T, c *context.Context) { + // Start listening. + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + iss := seqnum.Value(tcp.SequenceNumber()) + checker.IPv4(t, b, + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1), + ), + ) + + // Send ACK. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + nep, _, err := c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + nep, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Make sure we get the same error when calling the original ep and the + // new one. This validates that v4-mapped endpoints are still able to + // query the V6Only flag, whereas pure v4 endpoints are not. + var v tcpip.V6OnlyOption + expected := c.EP.GetSockOpt(&v) + if err := nep.GetSockOpt(&v); err != expected { + t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected) + } + + // Check the peer address. + addr, err := nep.GetRemoteAddress() + if err != nil { + t.Fatalf("GetRemoteAddress failed failed: %v", err) + } + + if addr.Addr != context.TestAddr { + t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr) + } +} + +func TestV4AcceptOnV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Accept(t, c) +} + +func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind to v4 mapped wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Accept(t, c) +} + +func TestV4AcceptOnBoundToV4Mapped(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind and listen. + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Accept(t, c) +} + +func TestV6AcceptOnV6(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateV6Endpoint(false) + + // Bind and listen. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := c.EP.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Send a SYN request. + irs := seqnum.Value(789) + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + RcvWnd: 30000, + }) + + // Receive the SYN-ACK reply. + b := c.GetV6Packet() + tcp := header.TCP(header.IPv6(b).Payload()) + iss := seqnum.Value(tcp.SequenceNumber()) + checker.IPv6(t, b, + checker.TCP( + checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.AckNum(uint32(irs)+1), + ), + ) + + // Send ACK. + c.SendV6Packet(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagAck, + SeqNum: irs + 1, + AckNum: iss + 1, + RcvWnd: 30000, + }) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + nep, _, err := c.EP.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + nep, _, err = c.EP.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Make sure we can still query the v6 only status of the new endpoint, + // that is, that it is in fact a v6 socket. + var v tcpip.V6OnlyOption + if err := nep.GetSockOpt(&v); err != nil { + t.Fatalf("GetSockOpt failed failed: %v", err) + } + + // Check the peer address. + addr, err := nep.GetRemoteAddress() + if err != nil { + t.Fatalf("GetRemoteAddress failed failed: %v", err) + } + + if addr.Addr != context.TestV6Addr { + t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr) + } +} + +func TestV4AcceptOnV4(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Accept(t, c) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go new file mode 100644 index 000000000..5d62589d8 --- /dev/null +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -0,0 +1,1371 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "crypto/rand" + "math" + "sync" + "sync/atomic" + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tmutex" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +type endpointState int + +const ( + stateInitial endpointState = iota + stateBound + stateListen + stateConnecting + stateConnected + stateClosed + stateError +) + +// Reasons for notifying the protocol goroutine. +const ( + notifyNonZeroReceiveWindow = 1 << iota + notifyReceiveWindowChanged + notifyClose + notifyMTUChanged + notifyDrain +) + +// SACKInfo holds TCP SACK related information for a given endpoint. +type SACKInfo struct { + // Blocks is the maximum number of SACK blocks we track + // per endpoint. + Blocks [MaxSACKBlocks]header.SACKBlock + + // NumBlocks is the number of valid SACK blocks stored in the + // blocks array above. + NumBlocks int +} + +// endpoint represents a TCP endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. The protocol implementation, however, runs in a single +// goroutine. +type endpoint struct { + // workMu is used to arbitrate which goroutine may perform protocol + // work. Only the main protocol goroutine is expected to call Lock() on + // it, but other goroutines (e.g., send) may call TryLock() to eagerly + // perform work without having to wait for the main one to wake up. + workMu tmutex.Mutex `state:"nosave"` + + // The following fields are initialized at creation time and do not + // change throughout the lifetime of the endpoint. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue + + // lastError represents the last error that the endpoint reported; + // access to it is protected by the following mutex. + lastErrorMu sync.Mutex `state:"nosave"` + lastError *tcpip.Error + + // The following fields are used to manage the receive queue. The + // protocol goroutine adds ready-for-delivery segments to rcvList, + // which are returned by Read() calls to users. + // + // Once the peer has closed its send side, rcvClosed is set to true + // to indicate to users that no more data is coming. + rcvListMu sync.Mutex `state:"nosave"` + rcvList segmentList + rcvClosed bool + rcvBufSize int + rcvBufUsed int + + // The following fields are protected by the mutex. + mu sync.RWMutex `state:"nosave"` + id stack.TransportEndpointID + state endpointState + isPortReserved bool + isRegistered bool + boundNICID tcpip.NICID + route stack.Route `state:"manual"` + v6only bool + isConnectNotified bool + + // effectiveNetProtos contains the network protocols actually in use. In + // most cases it will only contain "netProto", but in cases like IPv6 + // endpoints with v6only set to false, this could include multiple + // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., + // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped + // address). + effectiveNetProtos []tcpip.NetworkProtocolNumber + + // hardError is meaningful only when state is stateError, it stores the + // error to be returned when read/write syscalls are called and the + // endpoint is in this state. + hardError *tcpip.Error + + // workerRunning specifies if a worker goroutine is running. + workerRunning bool + + // workerCleanup specifies if the worker goroutine must perform cleanup + // before exitting. This can only be set to true when workerRunning is + // also true, and they're both protected by the mutex. + workerCleanup bool + + // sendTSOk is used to indicate when the TS Option has been negotiated. + // When sendTSOk is true every non-RST segment should carry a TS as per + // RFC7323#section-1.1 + sendTSOk bool + + // recentTS is the timestamp that should be sent in the TSEcr field of + // the timestamp for future segments sent by the endpoint. This field is + // updated if required when a new segment is received by this endpoint. + recentTS uint32 + + // tsOffset is a randomized offset added to the value of the + // TSVal field in the timestamp option. + tsOffset uint32 + + // shutdownFlags represent the current shutdown state of the endpoint. + shutdownFlags tcpip.ShutdownFlags + + // sackPermitted is set to true if the peer sends the TCPSACKPermitted + // option in the SYN/SYN-ACK. + sackPermitted bool + + // sack holds TCP SACK related information for this endpoint. + sack SACKInfo + + // The options below aren't implemented, but we remember the user + // settings because applications expect to be able to set/query these + // options. + noDelay bool + reuseAddr bool + + // segmentQueue is used to hand received segments to the protocol + // goroutine. Segments are queued as long as the queue is not full, + // and dropped when it is. + segmentQueue segmentQueue `state:"zerovalue"` + + // The following fields are used to manage the send buffer. When + // segments are ready to be sent, they are added to sndQueue and the + // protocol goroutine is signaled via sndWaker. + // + // When the send side is closed, the protocol goroutine is notified via + // sndCloseWaker, and sndClosed is set to true. + sndBufMu sync.Mutex `state:"nosave"` + sndBufSize int + sndBufUsed int + sndClosed bool + sndBufInQueue seqnum.Size + sndQueue segmentList + sndWaker sleep.Waker `state:"manual"` + sndCloseWaker sleep.Waker `state:"manual"` + + // The following are used when a "packet too big" control packet is + // received. They are protected by sndBufMu. They are used to + // communicate to the main protocol goroutine how many such control + // messages have been received since the last notification was processed + // and what was the smallest MTU seen. + packetTooBigCount int + sndMTU int + + // newSegmentWaker is used to indicate to the protocol goroutine that + // it needs to wake up and handle new segments queued to it. + newSegmentWaker sleep.Waker `state:"manual"` + + // notificationWaker is used to indicate to the protocol goroutine that + // it needs to wake up and check for notifications. + notificationWaker sleep.Waker `state:"manual"` + + // notifyFlags is a bitmask of flags used to indicate to the protocol + // goroutine what it was notified; this is only accessed atomically. + notifyFlags uint32 `state:"zerovalue"` + + // acceptedChan is used by a listening endpoint protocol goroutine to + // send newly accepted connections to the endpoint so that they can be + // read by Accept() calls. + acceptedChan chan *endpoint `state:".(endpointChan)"` + + // The following are only used from the protocol goroutine, and + // therefore don't need locks to protect them. + rcv *receiver + snd *sender + + // The goroutine drain completion notification channel. + drainDone chan struct{} `state:"nosave"` + + // probe if not nil is invoked on every received segment. It is passed + // a copy of the current state of the endpoint. + probe stack.TCPProbeFunc `state:"nosave"` +} + +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { + e := &endpoint{ + stack: stack, + netProto: netProto, + waiterQueue: waiterQueue, + rcvBufSize: DefaultBufferSize, + sndBufSize: DefaultBufferSize, + sndMTU: int(math.MaxInt32), + noDelay: false, + reuseAddr: true, + } + + var ss SendBufferSizeOption + if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + e.sndBufSize = ss.Default + } + + var rs ReceiveBufferSizeOption + if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + e.rcvBufSize = rs.Default + } + + if p := stack.GetTCPProbe(); p != nil { + e.probe = p + } + + e.segmentQueue.setLimit(2 * e.rcvBufSize) + e.workMu.Init() + e.workMu.Lock() + e.tsOffset = timeStampOffset() + return e +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + result := waiter.EventMask(0) + + e.mu.RLock() + defer e.mu.RUnlock() + + switch e.state { + case stateInitial, stateBound, stateConnecting: + // Ready for nothing. + + case stateClosed, stateError: + // Ready for anything. + result = mask + + case stateListen: + // Check if there's anything in the accepted channel. + if (mask & waiter.EventIn) != 0 { + if len(e.acceptedChan) > 0 { + result |= waiter.EventIn + } + } + + case stateConnected: + // Determine if the endpoint is writable if requested. + if (mask & waiter.EventOut) != 0 { + e.sndBufMu.Lock() + if e.sndClosed || e.sndBufUsed < e.sndBufSize { + result |= waiter.EventOut + } + e.sndBufMu.Unlock() + } + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvListMu.Lock() + if e.rcvBufUsed > 0 || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvListMu.Unlock() + } + } + + return result +} + +func (e *endpoint) fetchNotifications() uint32 { + return atomic.SwapUint32(&e.notifyFlags, 0) +} + +func (e *endpoint) notifyProtocolGoroutine(n uint32) { + for { + v := atomic.LoadUint32(&e.notifyFlags) + if v&n == n { + // The flags are already set. + return + } + + if atomic.CompareAndSwapUint32(&e.notifyFlags, v, v|n) { + if v == 0 { + // We are causing a transition from no flags to + // at least one flag set, so we must cause the + // protocol goroutine to wake up. + e.notificationWaker.Assert() + } + return + } + } +} + +// Close puts the endpoint in a closed state and frees all resources associated +// with it. It must be called only once and with no other concurrent calls to +// the endpoint. +func (e *endpoint) Close() { + // Issue a shutdown so that the peer knows we won't send any more data + // if we're connected, or stop accepting if we're listening. + e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead) + + // While we hold the lock, determine if the cleanup should happen + // inline or if we should tell the worker (if any) to do the cleanup. + e.mu.Lock() + worker := e.workerRunning + if worker { + e.workerCleanup = true + } + + // We always release ports inline so that they are immediately available + // for reuse after Close() is called. If also registered, it means this + // is a listening socket, so we must unregister as well otherwise the + // next user would fail in Listen() when trying to register. + if e.isPortReserved { + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort) + e.isPortReserved = false + + if e.isRegistered { + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + e.isRegistered = false + } + } + + e.mu.Unlock() + + // Now that we don't hold the lock anymore, either perform the local + // cleanup or kick the worker to make sure it knows it needs to cleanup. + if !worker { + e.cleanup() + } else { + e.notifyProtocolGoroutine(notifyClose) + } +} + +// cleanup frees all resources associated with the endpoint. It is called after +// Close() is called and the worker goroutine (if any) is done with its work. +func (e *endpoint) cleanup() { + // Close all endpoints that might have been accepted by TCP but not by + // the client. + if e.acceptedChan != nil { + close(e.acceptedChan) + for n := range e.acceptedChan { + n.resetConnection(tcpip.ErrConnectionAborted) + n.Close() + } + } + + if e.isRegistered { + e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + } + + e.route.Release() +} + +// Read reads data from the endpoint. +func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, *tcpip.Error) { + e.mu.RLock() + // The endpoint can be read if it's connected, or if it's already closed + // but has some pending unread data. Also note that a RST being received + // would cause the state to become stateError so we should allow the + // reads to proceed before returning a ECONNRESET. + if s := e.state; s != stateConnected && s != stateClosed && e.rcvBufUsed == 0 { + e.mu.RUnlock() + if s == stateError { + return buffer.View{}, e.hardError + } + return buffer.View{}, tcpip.ErrInvalidEndpointState + } + + e.rcvListMu.Lock() + v, err := e.readLocked() + e.rcvListMu.Unlock() + + e.mu.RUnlock() + + return v, err +} + +func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { + if e.rcvBufUsed == 0 { + if e.rcvClosed || e.state != stateConnected { + return buffer.View{}, tcpip.ErrClosedForReceive + } + return buffer.View{}, tcpip.ErrWouldBlock + } + + s := e.rcvList.Front() + views := s.data.Views() + v := views[s.viewToDeliver] + s.viewToDeliver++ + + if s.viewToDeliver >= len(views) { + e.rcvList.Remove(s) + s.decRef() + } + + scale := e.rcv.rcvWndScale + wasZero := e.zeroReceiveWindow(scale) + e.rcvBufUsed -= len(v) + if wasZero && !e.zeroReceiveWindow(scale) { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } + + return v, nil +} + +// Write writes data to the endpoint's peer. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) { + // Linux completely ignores any address passed to sendto(2) for TCP sockets + // (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More + // and opts.EndOfRecord are also ignored. + + e.mu.RLock() + defer e.mu.RUnlock() + + // The endpoint cannot be written to if it's not connected. + if e.state != stateConnected { + switch e.state { + case stateError: + return 0, e.hardError + default: + return 0, tcpip.ErrClosedForSend + } + } + + // Nothing to do if the buffer is empty. + if p.Size() == 0 { + return 0, nil + } + + e.sndBufMu.Lock() + + // Check if the connection has already been closed for sends. + if e.sndClosed { + e.sndBufMu.Unlock() + return 0, tcpip.ErrClosedForSend + } + + // Check against the limit. + avail := e.sndBufSize - e.sndBufUsed + if avail <= 0 { + e.sndBufMu.Unlock() + return 0, tcpip.ErrWouldBlock + } + + v, perr := p.Get(avail) + if perr != nil { + e.sndBufMu.Unlock() + return 0, perr + } + + var err *tcpip.Error + if p.Size() > avail { + err = tcpip.ErrWouldBlock + } + l := len(v) + s := newSegmentFromView(&e.route, e.id, v) + + // Add data to the send queue. + e.sndBufUsed += l + e.sndBufInQueue += seqnum.Size(l) + e.sndQueue.PushBack(s) + + e.sndBufMu.Unlock() + + if e.workMu.TryLock() { + // Do the work inline. + e.handleWrite() + e.workMu.Unlock() + } else { + // Let the protocol goroutine do the work. + e.sndWaker.Assert() + } + return uintptr(l), err +} + +// Peek reads data without consuming it from the endpoint. +// +// This method does not block if there is no data pending. +func (e *endpoint) Peek(vec [][]byte) (uintptr, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // The endpoint can be read if it's connected, or if it's already closed + // but has some pending unread data. + if s := e.state; s != stateConnected && s != stateClosed { + if s == stateError { + return 0, e.hardError + } + return 0, tcpip.ErrInvalidEndpointState + } + + e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() + + if e.rcvBufUsed == 0 { + if e.rcvClosed || e.state != stateConnected { + return 0, tcpip.ErrClosedForReceive + } + return 0, tcpip.ErrWouldBlock + } + + // Make a copy of vec so we can modify the slide headers. + vec = append([][]byte(nil), vec...) + + var num uintptr + + for s := e.rcvList.Front(); s != nil; s = s.Next() { + views := s.data.Views() + + for i := s.viewToDeliver; i < len(views); i++ { + v := views[i] + + for len(v) > 0 { + if len(vec) == 0 { + return num, nil + } + if len(vec[0]) == 0 { + vec = vec[1:] + continue + } + + n := copy(vec[0], v) + v = v[n:] + vec[0] = vec[0][n:] + num += uintptr(n) + } + } + } + + return num, nil +} + +// zeroReceiveWindow checks if the receive window to be announced now would be +// zero, based on the amount of available buffer and the receive window scaling. +// +// It must be called with rcvListMu held. +func (e *endpoint) zeroReceiveWindow(scale uint8) bool { + if e.rcvBufUsed >= e.rcvBufSize { + return true + } + + return ((e.rcvBufSize - e.rcvBufUsed) >> scale) == 0 +} + +// SetSockOpt sets a socket option. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.NoDelayOption: + e.mu.Lock() + e.noDelay = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReuseAddressOption: + e.mu.Lock() + e.reuseAddr = v != 0 + e.mu.Unlock() + return nil + + case tcpip.ReceiveBufferSizeOption: + // Make sure the receive buffer size is within the min and max + // allowed. + var rs ReceiveBufferSizeOption + size := int(v) + if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil { + if size < rs.Min { + size = rs.Min + } + if size > rs.Max { + size = rs.Max + } + } + + mask := uint32(notifyReceiveWindowChanged) + + e.rcvListMu.Lock() + + // Make sure the receive buffer size allows us to send a + // non-zero window size. + scale := uint8(0) + if e.rcv != nil { + scale = e.rcv.rcvWndScale + } + if size>>scale == 0 { + size = 1 << scale + } + + // Make sure 2*size doesn't overflow. + if size > math.MaxInt32/2 { + size = math.MaxInt32 / 2 + } + + wasZero := e.zeroReceiveWindow(scale) + e.rcvBufSize = size + if wasZero && !e.zeroReceiveWindow(scale) { + mask |= notifyNonZeroReceiveWindow + } + e.rcvListMu.Unlock() + + e.segmentQueue.setLimit(2 * size) + + e.notifyProtocolGoroutine(mask) + return nil + + case tcpip.SendBufferSizeOption: + // Make sure the send buffer size is within the min and max + // allowed. + size := int(v) + var ss SendBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if size < ss.Min { + size = ss.Min + } + if size > ss.Max { + size = ss.Max + } + } + + e.sndBufMu.Lock() + e.sndBufSize = size + e.sndBufMu.Unlock() + + return nil + + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v != 0 + } + + return nil +} + +// readyReceiveSize returns the number of bytes ready to be received. +func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // The endpoint cannot be in listen state. + if e.state == stateListen { + return 0, tcpip.ErrInvalidEndpointState + } + + e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() + + return e.rcvBufUsed, nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + e.lastErrorMu.Lock() + err := e.lastError + e.lastError = nil + e.lastErrorMu.Unlock() + return err + + case *tcpip.SendBufferSizeOption: + e.sndBufMu.Lock() + *o = tcpip.SendBufferSizeOption(e.sndBufSize) + e.sndBufMu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + e.rcvListMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize) + e.rcvListMu.Unlock() + return nil + + case *tcpip.ReceiveQueueSizeOption: + v, err := e.readyReceiveSize() + if err != nil { + return err + } + + *o = tcpip.ReceiveQueueSizeOption(v) + return nil + + case *tcpip.NoDelayOption: + e.mu.RLock() + v := e.noDelay + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.ReuseAddressOption: + e.mu.RLock() + v := e.reuseAddr + e.mu.RUnlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.TCPInfoOption: + *o = tcpip.TCPInfoOption{} + return nil + } + + return tcpip.ErrUnknownProtocolOption +} + +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.netProto + if header.IsV4MappedAddress(addr.Addr) { + // Fail if using a v4 mapped address on a v6only endpoint. + if e.v6only { + return 0, tcpip.ErrNoRoute + } + + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == "\x00\x00\x00\x00" { + addr.Addr = "" + } + } + + // Fail if we're bound to an address length different from the one we're + // checking. + if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) { + return 0, tcpip.ErrInvalidEndpointState + } + + return netProto, nil +} + +// Connect connects the endpoint to its peer. +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + netProto, err := e.checkV4Mapped(&addr) + if err != nil { + return err + } + + nicid := addr.NIC + switch e.state { + case stateBound: + // If we're already bound to a NIC but the caller is requesting + // that we use a different one now, we cannot proceed. + if e.boundNICID == 0 { + break + } + + if nicid != 0 && nicid != e.boundNICID { + return tcpip.ErrNoRoute + } + + nicid = e.boundNICID + + case stateInitial: + // Nothing to do. We'll eventually fill-in the gaps in the ID + // (if any) when we find a route. + + case stateConnecting: + // A connection request has already been issued but hasn't + // completed yet. + return tcpip.ErrAlreadyConnecting + + case stateConnected: + // The endpoint is already connected. If caller hasn't been notified yet, return success. + if !e.isConnectNotified { + e.isConnectNotified = true + return nil + } + // Otherwise return that it's already connected. + return tcpip.ErrAlreadyConnected + + case stateError: + return e.hardError + + default: + return tcpip.ErrInvalidEndpointState + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto) + if err != nil { + return err + } + defer r.Release() + + origID := e.id + + netProtos := []tcpip.NetworkProtocolNumber{netProto} + e.id.LocalAddress = r.LocalAddress + e.id.RemoteAddress = r.RemoteAddress + e.id.RemotePort = addr.Port + + if e.id.LocalPort != 0 { + // The endpoint is bound to a port, attempt to register it. + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e) + if err != nil { + return err + } + } else { + // The endpoint doesn't have a local port yet, so try to get + // one. Make sure that it isn't one that will result in the same + // address/port for both local and remote (otherwise this + // endpoint would be trying to connect to itself). + sameAddr := e.id.LocalAddress == e.id.RemoteAddress + _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + if sameAddr && p == e.id.RemotePort { + return false, nil + } + + e.id.LocalPort = p + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e) + switch err { + case nil: + return true, nil + case tcpip.ErrPortInUse: + return false, nil + default: + return false, err + } + }) + if err != nil { + return err + } + } + + // Remove the port reservation. This can happen when Bind is called + // before Connect: in such a case we don't want to hold on to + // reservations anymore. + if e.isPortReserved { + e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort) + e.isPortReserved = false + } + + e.isRegistered = true + e.state = stateConnecting + e.route = r.Clone() + e.boundNICID = nicid + e.effectiveNetProtos = netProtos + e.workerRunning = true + + go e.protocolMainLoop(false) // S/R-FIXME + + return tcpip.ErrConnectStarted +} + +// ConnectEndpoint is not supported. +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// Shutdown closes the read and/or write end of the endpoint connection to its +// peer. +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + e.shutdownFlags |= flags + + switch e.state { + case stateConnected: + // Close for write. + if (flags & tcpip.ShutdownWrite) != 0 { + e.sndBufMu.Lock() + + if e.sndClosed { + // Already closed. + e.sndBufMu.Unlock() + break + } + + // Queue fin segment. + s := newSegmentFromView(&e.route, e.id, nil) + e.sndQueue.PushBack(s) + e.sndBufInQueue++ + + // Mark endpoint as closed. + e.sndClosed = true + + e.sndBufMu.Unlock() + + // Tell protocol goroutine to close. + e.sndCloseWaker.Assert() + } + + case stateListen: + // Tell protocolListenLoop to stop. + if flags&tcpip.ShutdownRead != 0 { + e.notifyProtocolGoroutine(notifyClose) + } + + default: + return tcpip.ErrInvalidEndpointState + } + + return nil +} + +// Listen puts the endpoint in "listen" mode, which allows it to accept +// new connections. +func (e *endpoint) Listen(backlog int) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + // Allow the backlog to be adjusted if the endpoint is not shutting down. + // When the endpoint shuts down, it sets workerCleanup to true, and from + // that point onward, acceptedChan is the responsibility of the cleanup() + // method (and should not be touched anywhere else, including here). + if e.state == stateListen && !e.workerCleanup { + // Adjust the size of the channel iff we can fix existing + // pending connections into the new one. + if len(e.acceptedChan) > backlog { + return tcpip.ErrInvalidEndpointState + } + origChan := e.acceptedChan + e.acceptedChan = make(chan *endpoint, backlog) + close(origChan) + for ep := range origChan { + e.acceptedChan <- ep + } + return nil + } + + // Endpoint must be bound before it can transition to listen mode. + if e.state != stateBound { + return tcpip.ErrInvalidEndpointState + } + + // Register the endpoint. + if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil { + return err + } + + e.isRegistered = true + e.state = stateListen + if e.acceptedChan == nil { + e.acceptedChan = make(chan *endpoint, backlog) + } + e.workerRunning = true + + go e.protocolListenLoop( // S/R-SAFE: drained on save. + seqnum.Size(e.receiveBufferAvailable())) + + return nil +} + +// startAcceptedLoop sets up required state and starts a goroutine with the +// main loop for accepted connections. +func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) { + e.waiterQueue = waiterQueue + e.workerRunning = true + go e.protocolMainLoop(true) // S/R-FIXME +} + +// Accept returns a new endpoint if a peer has established a connection +// to an endpoint previously set to listen mode. +func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + // Endpoint must be in listen state before it can accept connections. + if e.state != stateListen { + return nil, nil, tcpip.ErrInvalidEndpointState + } + + // Get the new accepted endpoint. + var n *endpoint + select { + case n = <-e.acceptedChan: + default: + return nil, nil, tcpip.ErrWouldBlock + } + + // Start the protocol goroutine. + wq := &waiter.Queue{} + n.startAcceptedLoop(wq) + + return n, wq, nil +} + +// Bind binds the endpoint to a specific local port and optionally address. +func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (retErr *tcpip.Error) { + e.mu.Lock() + defer e.mu.Unlock() + + // Don't allow binding once endpoint is not in the initial state + // anymore. This is because once the endpoint goes into a connected or + // listen state, it is already bound. + if e.state != stateInitial { + return tcpip.ErrAlreadyBound + } + + netProto, err := e.checkV4Mapped(&addr) + if err != nil { + return err + } + + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } + } + + // Reserve the port. + port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port) + if err != nil { + return err + } + + e.isPortReserved = true + e.effectiveNetProtos = netProtos + e.id.LocalPort = port + + // Any failures beyond this point must remove the port registration. + defer func() { + if retErr != nil { + e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port) + e.isPortReserved = false + e.effectiveNetProtos = nil + e.id.LocalPort = 0 + e.id.LocalAddress = "" + e.boundNICID = 0 + } + }() + + // If an address is specified, we must ensure that it's one of our + // local addresses. + if len(addr.Addr) != 0 { + nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) + if nic == 0 { + return tcpip.ErrBadLocalAddress + } + + e.boundNICID = nic + e.id.LocalAddress = addr.Addr + } + + // Check the commit function. + if commit != nil { + if err := commit(); err != nil { + // The defer takes care of unwind. + return err + } + } + + // Mark endpoint as bound. + e.state = stateBound + + return nil +} + +// GetLocalAddress returns the address to which the endpoint is bound. +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + return tcpip.FullAddress{ + Addr: e.id.LocalAddress, + Port: e.id.LocalPort, + NIC: e.boundNICID, + }, nil +} + +// GetRemoteAddress returns the address to which the endpoint is connected. +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + Addr: e.id.RemoteAddress, + Port: e.id.RemotePort, + NIC: e.boundNICID, + }, nil +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) { + s := newSegment(r, id, vv) + if !s.parse() { + atomic.AddUint64(&e.stack.MutableStats().MalformedRcvdPackets, 1) + s.decRef() + return + } + + // Send packet to worker goroutine. + if e.segmentQueue.enqueue(s) { + e.newSegmentWaker.Assert() + } else { + // The queue is full, so we drop the segment. + atomic.AddUint64(&e.stack.MutableStats().DroppedPackets, 1) + s.decRef() + } +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { + switch typ { + case stack.ControlPacketTooBig: + e.sndBufMu.Lock() + e.packetTooBigCount++ + if v := int(extra); v < e.sndMTU { + e.sndMTU = v + } + e.sndBufMu.Unlock() + + e.notifyProtocolGoroutine(notifyMTUChanged) + } +} + +// updateSndBufferUsage is called by the protocol goroutine when room opens up +// in the send buffer. The number of newly available bytes is v. +func (e *endpoint) updateSndBufferUsage(v int) { + e.sndBufMu.Lock() + notify := e.sndBufUsed >= e.sndBufSize>>1 + e.sndBufUsed -= v + // We only notify when there is half the sndBufSize available after + // a full buffer event occurs. This ensures that we don't wake up + // writers to queue just 1-2 segments and go back to sleep. + notify = notify && e.sndBufUsed < e.sndBufSize>>1 + e.sndBufMu.Unlock() + + if notify { + e.waiterQueue.Notify(waiter.EventOut) + } +} + +// readyToRead is called by the protocol goroutine when a new segment is ready +// to be read, or when the connection is closed for receiving (in which case +// s will be nil). +func (e *endpoint) readyToRead(s *segment) { + e.rcvListMu.Lock() + if s != nil { + s.incRef() + e.rcvBufUsed += s.data.Size() + e.rcvList.PushBack(s) + } else { + e.rcvClosed = true + } + e.rcvListMu.Unlock() + + e.waiterQueue.Notify(waiter.EventIn) +} + +// receiveBufferAvailable calculates how many bytes are still available in the +// receive buffer. +func (e *endpoint) receiveBufferAvailable() int { + e.rcvListMu.Lock() + size := e.rcvBufSize + used := e.rcvBufUsed + e.rcvListMu.Unlock() + + // We may use more bytes than the buffer size when the receive buffer + // shrinks. + if used >= size { + return 0 + } + + return size - used +} + +func (e *endpoint) receiveBufferSize() int { + e.rcvListMu.Lock() + size := e.rcvBufSize + e.rcvListMu.Unlock() + + return size +} + +// updateRecentTimestamp updates the recent timestamp using the algorithm +// described in https://tools.ietf.org/html/rfc7323#section-4.3 +func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) { + if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) { + e.recentTS = tsVal + } +} + +// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if +// the SYN options indicate that timestamp option was negotiated. It also +// initializes the recentTS with the value provided in synOpts.TSval. +func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) { + if synOpts.TS { + e.sendTSOk = true + e.recentTS = synOpts.TSVal + } +} + +// timestamp returns the timestamp value to be used in the TSVal field of the +// timestamp option for outgoing TCP segments for a given endpoint. +func (e *endpoint) timestamp() uint32 { + return tcpTimeStamp(e.tsOffset) +} + +// tcpTimeStamp returns a timestamp offset by the provided offset. This is +// not inlined above as it's used when SYN cookies are in use and endpoint +// is not created at the time when the SYN cookie is sent. +func tcpTimeStamp(offset uint32) uint32 { + now := time.Now() + return uint32(now.Unix()*1000+int64(now.Nanosecond()/1e6)) + offset +} + +// timeStampOffset returns a randomized timestamp offset to be used when sending +// timestamp values in a timestamp option for a TCP segment. +func timeStampOffset() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + // Initialize a random tsOffset that will be added to the recentTS + // everytime the timestamp is sent when the Timestamp option is enabled. + // + // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on + // why this is required. + // + // NOTE: This is not completely to spec as normally this should be + // initialized in a manner analogous to how sequence numbers are + // randomized per connection basis. But for now this is sufficient. + return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24 +} + +// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint +// if the SYN options indicate that the SACK option was negotiated and the TCP +// stack is configured to enable TCP SACK option. +func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) { + var v SACKEnabled + if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil { + // Stack doesn't support SACK. So just return. + return + } + if bool(v) && synOpts.SACKPermitted { + e.sackPermitted = true + } +} + +// completeState makes a full copy of the endpoint and returns it. This is used +// before invoking the probe. The state returned may not be fully consistent if +// there are intervening syscalls when the state is being copied. +func (e *endpoint) completeState() stack.TCPEndpointState { + var s stack.TCPEndpointState + s.SegTime = time.Now() + + // Copy EndpointID. + e.mu.Lock() + s.ID = stack.TCPEndpointID(e.id) + e.mu.Unlock() + + // Copy endpoint rcv state. + e.rcvListMu.Lock() + s.RcvBufSize = e.rcvBufSize + s.RcvBufUsed = e.rcvBufUsed + s.RcvClosed = e.rcvClosed + e.rcvListMu.Unlock() + + // Endpoint TCP Option state. + s.SendTSOk = e.sendTSOk + s.RecentTS = e.recentTS + s.TSOffset = e.tsOffset + s.SACKPermitted = e.sackPermitted + s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks) + copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks]) + + // Copy endpoint send state. + e.sndBufMu.Lock() + s.SndBufSize = e.sndBufSize + s.SndBufUsed = e.sndBufUsed + s.SndClosed = e.sndClosed + s.SndBufInQueue = e.sndBufInQueue + s.PacketTooBigCount = e.packetTooBigCount + s.SndMTU = e.sndMTU + e.sndBufMu.Unlock() + + // Copy receiver state. + s.Receiver = stack.TCPReceiverState{ + RcvNxt: e.rcv.rcvNxt, + RcvAcc: e.rcv.rcvAcc, + RcvWndScale: e.rcv.rcvWndScale, + PendingBufUsed: e.rcv.pendingBufUsed, + PendingBufSize: e.rcv.pendingBufSize, + } + + // Copy sender state. + s.Sender = stack.TCPSenderState{ + LastSendTime: e.snd.lastSendTime, + DupAckCount: e.snd.dupAckCount, + FastRecovery: stack.TCPFastRecoveryState{ + Active: e.snd.fr.active, + First: e.snd.fr.first, + Last: e.snd.fr.last, + MaxCwnd: e.snd.fr.maxCwnd, + }, + SndCwnd: e.snd.sndCwnd, + Ssthresh: e.snd.sndSsthresh, + SndCAAckCount: e.snd.sndCAAckCount, + Outstanding: e.snd.outstanding, + SndWnd: e.snd.sndWnd, + SndUna: e.snd.sndUna, + SndNxt: e.snd.sndNxt, + RTTMeasureSeqNum: e.snd.rttMeasureSeqNum, + RTTMeasureTime: e.snd.rttMeasureTime, + Closed: e.snd.closed, + SRTT: e.snd.srtt, + RTO: e.snd.rto, + SRTTInited: e.snd.srttInited, + MaxPayloadSize: e.snd.maxPayloadSize, + SndWndScale: e.snd.sndWndScale, + MaxSentAck: e.snd.maxSentAck, + } + return s +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go new file mode 100644 index 000000000..dbb70ff21 --- /dev/null +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -0,0 +1,128 @@ +// Copyright 2017 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "fmt" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// ErrSaveRejection indicates a failed save due to unsupported tcp endpoint +// state. +type ErrSaveRejection struct { + Err error +} + +// Error returns a sensible description of the save rejection error. +func (e ErrSaveRejection) Error() string { + return "save rejected due to unsupported endpoint state: " + e.Err.Error() +} + +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + // Stop incoming packets. + e.segmentQueue.setLimit(0) + + e.mu.RLock() + defer e.mu.RUnlock() + + switch e.state { + case stateInitial: + case stateBound: + case stateListen: + if !e.segmentQueue.empty() { + e.mu.RUnlock() + e.drainDone = make(chan struct{}, 1) + e.notificationWaker.Assert() + <-e.drainDone + e.mu.RLock() + } + case stateConnecting: + panic(ErrSaveRejection{fmt.Errorf("endpoint in connecting state upon save: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + case stateConnected: + // FIXME + panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) + case stateClosed: + case stateError: + default: + panic(fmt.Sprintf("endpoint in unknown state %v", e.state)) + } +} + +// afterLoad is invoked by stateify. +func (e *endpoint) afterLoad() { + e.stack = stack.StackFromEnv + + if e.state == stateListen { + e.state = stateBound + backlog := cap(e.acceptedChan) + e.acceptedChan = nil + defer func() { + if err := e.Listen(backlog); err != nil { + panic("endpoint listening failed: " + err.String()) + } + }() + } + + if e.state == stateBound { + e.state = stateInitial + defer func() { + if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil { + panic("endpoint binding failed: " + err.String()) + } + }() + } + + if e.state == stateInitial { + var ss SendBufferSizeOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { + if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max)) + } + if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max { + panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max)) + } + } + } + + e.segmentQueue.setLimit(2 * e.rcvBufSize) + e.workMu.Init() +} + +// saveAcceptedChan is invoked by stateify. +func (e *endpoint) saveAcceptedChan() endpointChan { + if e.acceptedChan == nil { + return endpointChan{} + } + close(e.acceptedChan) + buffer := make([]*endpoint, 0, len(e.acceptedChan)) + for ep := range e.acceptedChan { + buffer = append(buffer, ep) + } + if len(buffer) != cap(buffer) { + panic("endpoint.acceptedChan buffer got consumed by background context") + } + c := cap(e.acceptedChan) + e.acceptedChan = nil + return endpointChan{buffer: buffer, cap: c} +} + +// loadAcceptedChan is invoked by stateify. +func (e *endpoint) loadAcceptedChan(c endpointChan) { + if c.cap == 0 { + return + } + e.acceptedChan = make(chan *endpoint, c.cap) + for _, ep := range c.buffer { + e.acceptedChan <- ep + } +} + +type endpointChan struct { + buffer []*endpoint + cap int +} diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go new file mode 100644 index 000000000..657ac524f --- /dev/null +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -0,0 +1,161 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// Forwarder is a connection request forwarder, which allows clients to decide +// what to do with a connection request, for example: ignore it, send a RST, or +// attempt to complete the 3-way handshake. +// +// The canonical way of using it is to pass the Forwarder.HandlePacket function +// to stack.SetTransportProtocolHandler. +type Forwarder struct { + maxInFlight int + handler func(*ForwarderRequest) + + mu sync.Mutex + inFlight map[stack.TransportEndpointID]struct{} + listen *listenContext +} + +// NewForwarder allocates and initializes a new forwarder with the given +// maximum number of in-flight connection attempts. Once the maximum is reached +// new incoming connection requests will be ignored. +// +// If rcvWnd is set to zero, the default buffer size is used instead. +func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder { + if rcvWnd == 0 { + rcvWnd = DefaultBufferSize + } + return &Forwarder{ + maxInFlight: maxInFlight, + handler: handler, + inFlight: make(map[stack.TransportEndpointID]struct{}), + listen: newListenContext(s, seqnum.Size(rcvWnd), true, 0), + } +} + +// HandlePacket handles a packet if it is of interest to the forwarder (i.e., if +// it's a SYN packet), returning true if it's the case. Otherwise the packet +// is not handled and false is returned. +// +// This function is expected to be passed as an argument to the +// stack.SetTransportProtocolHandler function. +func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) bool { + s := newSegment(r, id, vv) + defer s.decRef() + + // We only care about well-formed SYN packets. + if !s.parse() || s.flags != flagSyn { + return false + } + + opts := parseSynSegmentOptions(s) + + f.mu.Lock() + defer f.mu.Unlock() + + // We have an inflight request for this id, ignore this one for now. + if _, ok := f.inFlight[id]; ok { + return true + } + + // Ignore the segment if we're beyond the limit. + if len(f.inFlight) >= f.maxInFlight { + return true + } + + // Launch a new goroutine to handle the request. + f.inFlight[id] = struct{}{} + s.incRef() + go f.handler(&ForwarderRequest{ // S/R-FIXME + forwarder: f, + segment: s, + synOptions: opts, + }) + + return true +} + +// ForwarderRequest represents a connection request received by the forwarder +// and passed to the client. Clients must eventually call Complete() on it, and +// may optionally create an endpoint to represent it via CreateEndpoint. +type ForwarderRequest struct { + mu sync.Mutex + forwarder *Forwarder + segment *segment + synOptions header.TCPSynOptions +} + +// ID returns the 4-tuple (src address, src port, dst address, dst port) that +// represents the connection request. +func (r *ForwarderRequest) ID() stack.TransportEndpointID { + return r.segment.id +} + +// Complete completes the request, and optionally sends a RST segment back to the +// sender. +func (r *ForwarderRequest) Complete(sendReset bool) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.segment == nil { + panic("Completing already completed forwarder request") + } + + // Remove request from the forwarder. + r.forwarder.mu.Lock() + delete(r.forwarder.inFlight, r.segment.id) + r.forwarder.mu.Unlock() + + // If the caller requested, send a reset. + if sendReset { + replyWithReset(r.segment) + } + + // Release all resources. + r.segment.decRef() + r.segment = nil + r.forwarder = nil +} + +// CreateEndpoint creates a TCP endpoint for the connection request, performing +// the 3-way handshake in the process. +func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + r.mu.Lock() + defer r.mu.Unlock() + + if r.segment == nil { + return nil, tcpip.ErrInvalidEndpointState + } + + f := r.forwarder + ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{ + MSS: r.synOptions.MSS, + WS: r.synOptions.WS, + TS: r.synOptions.TS, + TSVal: r.synOptions.TSVal, + TSEcr: r.synOptions.TSEcr, + SACKPermitted: r.synOptions.SACKPermitted, + }) + if err != nil { + return nil, err + } + + // Start the protocol goroutine. + ep.startAcceptedLoop(queue) + + return ep, nil +} diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go new file mode 100644 index 000000000..d81a1dd9b --- /dev/null +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -0,0 +1,192 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tcp contains the implementation of the TCP transport protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the +// transport protocols when calling stack.New(). Then endpoints can be created +// by passing tcp.ProtocolNumber as the transport protocol number when calling +// Stack.NewEndpoint(). +package tcp + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // ProtocolName is the string representation of the tcp protocol name. + ProtocolName = "tcp" + + // ProtocolNumber is the tcp protocol number. + ProtocolNumber = header.TCPProtocolNumber + + // MinBufferSize is the smallest size of a receive or send buffer. + minBufferSize = 4 << 10 // 4096 bytes. + + // DefaultBufferSize is the default size of the receive and send buffers. + DefaultBufferSize = 1 << 20 // 1MB + + // MaxBufferSize is the largest size a receive and send buffer can grow to. + maxBufferSize = 4 << 20 // 4MB +) + +// SACKEnabled option can be used to enable SACK support in the TCP +// protocol. See: https://tools.ietf.org/html/rfc2018. +type SACKEnabled bool + +// SendBufferSizeOption allows the default, min and max send buffer sizes for +// TCP endpoints to be queried or configured. +type SendBufferSizeOption struct { + Min int + Default int + Max int +} + +// ReceiveBufferSizeOption allows the default, min and max receive buffer size +// for TCP endpoints to be queried or configured. +type ReceiveBufferSizeOption struct { + Min int + Default int + Max int +} + +type protocol struct { + mu sync.Mutex + sackEnabled bool + sendBufferSize SendBufferSizeOption + recvBufferSize ReceiveBufferSizeOption +} + +// Number returns the tcp protocol number. +func (*protocol) Number() tcpip.TransportProtocolNumber { + return ProtocolNumber +} + +// NewEndpoint creates a new tcp endpoint. +func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, waiterQueue), nil +} + +// MinimumPacketSize returns the minimum valid tcp packet size. +func (*protocol) MinimumPacketSize() int { + return header.TCPMinimumSize +} + +// ParsePorts returns the source and destination ports stored in the given tcp +// packet. +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + h := header.TCP(v) + return h.SourcePort(), h.DestinationPort(), nil +} + +// HandleUnknownDestinationPacket handles packets targeted at this protocol but +// that don't match any existing endpoint. +// +// RFC 793, page 36, states that "If the connection does not exist (CLOSED) then +// a reset is sent in response to any incoming segment except another reset. In +// particular, SYNs addressed to a non-existent connection are rejected by this +// means." +func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) bool { + s := newSegment(r, id, vv) + defer s.decRef() + + if !s.parse() { + return false + } + + // There's nothing to do if this is already a reset packet. + if s.flagIsSet(flagRst) { + return true + } + + replyWithReset(s) + return true +} + +// replyWithReset replies to the given segment with a reset segment. +func replyWithReset(s *segment) { + // Get the seqnum from the packet if the ack flag is set. + seq := seqnum.Value(0) + if s.flagIsSet(flagAck) { + seq = s.ackNumber + } + + ack := s.sequenceNumber.Add(s.logicalLen()) + + sendTCP(&s.route, s.id, nil, flagRst|flagAck, seq, ack, 0) +} + +// SetOption implements TransportProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case SACKEnabled: + p.mu.Lock() + p.sackEnabled = bool(v) + p.mu.Unlock() + return nil + + case SendBufferSizeOption: + if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.sendBufferSize = v + p.mu.Unlock() + return nil + + case ReceiveBufferSizeOption: + if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max { + return tcpip.ErrInvalidOptionValue + } + p.mu.Lock() + p.recvBufferSize = v + p.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +// Option implements TransportProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *SACKEnabled: + p.mu.Lock() + *v = SACKEnabled(p.sackEnabled) + p.mu.Unlock() + return nil + + case *SendBufferSizeOption: + p.mu.Lock() + *v = p.sendBufferSize + p.mu.Unlock() + return nil + + case *ReceiveBufferSizeOption: + p.mu.Lock() + *v = p.recvBufferSize + p.mu.Unlock() + return nil + + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func init() { + stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { + return &protocol{ + sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, + recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize}, + } + }) +} diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go new file mode 100644 index 000000000..574602105 --- /dev/null +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -0,0 +1,208 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "container/heap" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" +) + +// receiver holds the state necessary to receive TCP segments and turn them +// into a stream of bytes. +type receiver struct { + ep *endpoint + + rcvNxt seqnum.Value + + // rcvAcc is one beyond the last acceptable sequence number. That is, + // the "largest" sequence value that the receiver has announced to the + // its peer that it's willing to accept. This may be different than + // rcvNxt + rcvWnd if the receive window is reduced; in that case we + // have to reduce the window as we receive more data instead of + // shrinking it. + rcvAcc seqnum.Value + + rcvWndScale uint8 + + closed bool + + pendingRcvdSegments segmentHeap + pendingBufUsed seqnum.Size + pendingBufSize seqnum.Size +} + +func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver { + return &receiver{ + ep: ep, + rcvNxt: irs + 1, + rcvAcc: irs.Add(rcvWnd + 1), + rcvWndScale: rcvWndScale, + pendingBufSize: rcvWnd, + } +} + +// acceptable checks if the segment sequence number range is acceptable +// according to the table on page 26 of RFC 793. +func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { + rcvWnd := r.rcvNxt.Size(r.rcvAcc) + if rcvWnd == 0 { + return segLen == 0 && segSeq == r.rcvNxt + } + + return segSeq.InWindow(r.rcvNxt, rcvWnd) || + seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen) +} + +// getSendParams returns the parameters needed by the sender when building +// segments to send. +func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) { + // Calculate the window size based on the current buffer size. + n := r.ep.receiveBufferAvailable() + acc := r.rcvNxt.Add(seqnum.Size(n)) + if r.rcvAcc.LessThan(acc) { + r.rcvAcc = acc + } + + return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale +} + +// nonZeroWindow is called when the receive window grows from zero to nonzero; +// in such cases we may need to send an ack to indicate to our peer that it can +// resume sending data. +func (r *receiver) nonZeroWindow() { + if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 { + // We never got around to announcing a zero window size, so we + // don't need to immediately announce a nonzero one. + return + } + + // Immediately send an ack. + r.ep.snd.sendAck() +} + +// consumeSegment attempts to consume a segment that was received by r. The +// segment may have just been received or may have been received earlier but +// wasn't ready to be consumed then. +// +// Returns true if the segment was consumed, false if it cannot be consumed +// yet because of a missing segment. +func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool { + if segLen > 0 { + // If the segment doesn't include the seqnum we're expecting to + // consume now, we're missing a segment. We cannot proceed until + // we receive that segment though. + if !r.rcvNxt.InWindow(segSeq, segLen) { + return false + } + + // Trim segment to eliminate already acknowledged data. + if segSeq.LessThan(r.rcvNxt) { + diff := segSeq.Size(r.rcvNxt) + segLen -= diff + segSeq.UpdateForward(diff) + s.sequenceNumber.UpdateForward(diff) + s.data.TrimFront(int(diff)) + } + + // Move segment to ready-to-deliver list. Wakeup any waiters. + r.ep.readyToRead(s) + + } else if segSeq != r.rcvNxt { + return false + } + + // Update the segment that we're expecting to consume. + r.rcvNxt = segSeq.Add(segLen) + + // Trim SACK Blocks to remove any SACK information that covers + // sequence numbers that have been consumed. + TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + + if s.flagIsSet(flagFin) { + r.rcvNxt++ + + // Send ACK immediately. + r.ep.snd.sendAck() + + // Tell any readers that no more data will come. + r.closed = true + r.ep.readyToRead(nil) + + // Flush out any pending segments, except the very first one if + // it happens to be the one we're handling now because the + // caller is using it. + first := 0 + if len(r.pendingRcvdSegments) != 0 && r.pendingRcvdSegments[0] == s { + first = 1 + } + + for i := first; i < len(r.pendingRcvdSegments); i++ { + r.pendingRcvdSegments[i].decRef() + } + r.pendingRcvdSegments = r.pendingRcvdSegments[:first] + } + + return true +} + +// handleRcvdSegment handles TCP segments directed at the connection managed by +// r as they arrive. It is called by the protocol main loop. +func (r *receiver) handleRcvdSegment(s *segment) { + // We don't care about receive processing anymore if the receive side + // is closed. + if r.closed { + return + } + + segLen := seqnum.Size(s.data.Size()) + segSeq := s.sequenceNumber + + // If the sequence number range is outside the acceptable range, just + // send an ACK. This is according to RFC 793, page 37. + if !r.acceptable(segSeq, segLen) { + r.ep.snd.sendAck() + return + } + + // Defer segment processing if it can't be consumed now. + if !r.consumeSegment(s, segSeq, segLen) { + if segLen > 0 || s.flagIsSet(flagFin) { + // We only store the segment if it's within our buffer + // size limit. + if r.pendingBufUsed < r.pendingBufSize { + r.pendingBufUsed += s.logicalLen() + s.incRef() + heap.Push(&r.pendingRcvdSegments, s) + } + + UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt) + + // Immediately send an ack so that the peer knows it may + // have to retransmit. + r.ep.snd.sendAck() + } + return + } + + // By consuming the current segment, we may have filled a gap in the + // sequence number domain that allows pending segments to be consumed + // now. So try to do it. + for !r.closed && r.pendingRcvdSegments.Len() > 0 { + s := r.pendingRcvdSegments[0] + segLen := seqnum.Size(s.data.Size()) + segSeq := s.sequenceNumber + + // Skip segment altogether if it has already been acknowledged. + if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) && + !r.consumeSegment(s, segSeq, segLen) { + break + } + + heap.Pop(&r.pendingRcvdSegments) + r.pendingBufUsed -= s.logicalLen() + s.decRef() + } +} diff --git a/pkg/tcpip/transport/tcp/sack.go b/pkg/tcpip/transport/tcp/sack.go new file mode 100644 index 000000000..0b66305a5 --- /dev/null +++ b/pkg/tcpip/transport/tcp/sack.go @@ -0,0 +1,85 @@ +package tcp + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" +) + +const ( + // MaxSACKBlocks is the maximum number of SACK blocks stored + // at receiver side. + MaxSACKBlocks = 6 +) + +// UpdateSACKBlocks updates the list of SACK blocks to include the segment +// specified by segStart->segEnd. If the segment happens to be an out of order +// delivery then the first block in the sack.blocks always includes the +// segment identified by segStart->segEnd. +func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) { + newSB := header.SACKBlock{Start: segStart, End: segEnd} + if sack.NumBlocks == 0 { + sack.Blocks[0] = newSB + sack.NumBlocks = 1 + return + } + var n = 0 + for i := 0; i < sack.NumBlocks; i++ { + start, end := sack.Blocks[i].Start, sack.Blocks[i].End + if end.LessThanEq(start) || start.LessThanEq(rcvNxt) { + // Discard any invalid blocks where end is before start + // and discard any sack blocks that are before rcvNxt as + // those have already been acked. + continue + } + if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) { + // Merge this SACK block into newSB and discard this SACK + // block. + if start.LessThan(newSB.Start) { + newSB.Start = start + } + if newSB.End.LessThan(end) { + newSB.End = end + } + } else { + // Save this block. + sack.Blocks[n] = sack.Blocks[i] + n++ + } + } + if rcvNxt.LessThan(newSB.Start) { + // If this was an out of order segment then make sure that the + // first SACK block is the one that includes the segment. + // + // See the first bullet point in + // https://tools.ietf.org/html/rfc2018#section-4 + if n == MaxSACKBlocks { + // If the number of SACK blocks is equal to + // MaxSACKBlocks then discard the last SACK block. + n-- + } + for i := n - 1; i >= 0; i-- { + sack.Blocks[i+1] = sack.Blocks[i] + } + sack.Blocks[0] = newSB + n++ + } + sack.NumBlocks = n +} + +// TrimSACKBlockList updates the sack block list by removing/modifying any block +// where start is < rcvNxt. +func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) { + n := 0 + for i := 0; i < sack.NumBlocks; i++ { + if sack.Blocks[i].End.LessThanEq(rcvNxt) { + continue + } + if sack.Blocks[i].Start.LessThan(rcvNxt) { + // Shrink this SACK block. + sack.Blocks[i].Start = rcvNxt + } + sack.Blocks[n] = sack.Blocks[i] + n++ + } + sack.NumBlocks = n +} diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go new file mode 100644 index 000000000..c742fc394 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment.go @@ -0,0 +1,145 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "sync/atomic" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// Flags that may be set in a TCP segment. +const ( + flagFin = 1 << iota + flagSyn + flagRst + flagPsh + flagAck + flagUrg +) + +// segment represents a TCP segment. It holds the payload and parsed TCP segment +// information, and can be added to intrusive lists. +// segment is mostly immutable, the only field allowed to change is viewToDeliver. +type segment struct { + segmentEntry + refCnt int32 + id stack.TransportEndpointID + route stack.Route + data buffer.VectorisedView + // views is used as buffer for data when its length is large + // enough to store a VectorisedView. + views [8]buffer.View + // viewToDeliver keeps track of the next View that should be + // delivered by the Read endpoint. + viewToDeliver int + sequenceNumber seqnum.Value + ackNumber seqnum.Value + flags uint8 + window seqnum.Size + + // parsedOptions stores the parsed values from the options in the segment. + parsedOptions header.TCPOptions + options []byte +} + +func newSegment(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) *segment { + s := &segment{ + refCnt: 1, + id: id, + route: r.Clone(), + } + s.data = vv.Clone(s.views[:]) + return s +} + +func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { + s := &segment{ + refCnt: 1, + id: id, + route: r.Clone(), + } + s.views[0] = v + s.data = buffer.NewVectorisedView(len(v), s.views[:1]) + return s +} + +func (s *segment) clone() *segment { + t := &segment{ + refCnt: 1, + id: s.id, + sequenceNumber: s.sequenceNumber, + ackNumber: s.ackNumber, + flags: s.flags, + window: s.window, + route: s.route.Clone(), + viewToDeliver: s.viewToDeliver, + } + t.data = s.data.Clone(t.views[:]) + return t +} + +func (s *segment) flagIsSet(flag uint8) bool { + return (s.flags & flag) != 0 +} + +func (s *segment) decRef() { + if atomic.AddInt32(&s.refCnt, -1) == 0 { + s.route.Release() + } +} + +func (s *segment) incRef() { + atomic.AddInt32(&s.refCnt, 1) +} + +// logicalLen is the segment length in the sequence number space. It's defined +// as the data length plus one for each of the SYN and FIN bits set. +func (s *segment) logicalLen() seqnum.Size { + l := seqnum.Size(s.data.Size()) + if s.flagIsSet(flagSyn) { + l++ + } + if s.flagIsSet(flagFin) { + l++ + } + return l +} + +// parse populates the sequence & ack numbers, flags, and window fields of the +// segment from the TCP header stored in the data. It then updates the view to +// skip the data. Returns boolean indicating if the parsing was successful. +func (s *segment) parse() bool { + h := header.TCP(s.data.First()) + + // h is the header followed by the payload. We check that the offset to + // the data respects the following constraints: + // 1. That it's at least the minimum header size; if we don't do this + // then part of the header would be delivered to user. + // 2. That the header fits within the buffer; if we don't do this, we + // would panic when we tried to access data beyond the buffer. + // + // N.B. The segment has already been validated as having at least the + // minimum TCP size before reaching here, so it's safe to read the + // fields. + offset := int(h.DataOffset()) + if offset < header.TCPMinimumSize || offset > len(h) { + return false + } + + s.options = []byte(h[header.TCPMinimumSize:offset]) + s.parsedOptions = header.ParseTCPOptions(s.options) + s.data.TrimFront(offset) + + s.sequenceNumber = seqnum.Value(h.SequenceNumber()) + s.ackNumber = seqnum.Value(h.AckNumber()) + s.flags = h.Flags() + s.window = seqnum.Size(h.WindowSize()) + + return true +} diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go new file mode 100644 index 000000000..137ddbdd2 --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_heap.go @@ -0,0 +1,36 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +type segmentHeap []*segment + +// Len returns the length of h. +func (h segmentHeap) Len() int { + return len(h) +} + +// Less determines whether the i-th element of h is less than the j-th element. +func (h segmentHeap) Less(i, j int) bool { + return h[i].sequenceNumber.LessThan(h[j].sequenceNumber) +} + +// Swap swaps the i-th and j-th elements of h. +func (h segmentHeap) Swap(i, j int) { + h[i], h[j] = h[j], h[i] +} + +// Push adds x as the last element of h. +func (h *segmentHeap) Push(x interface{}) { + *h = append(*h, x.(*segment)) +} + +// Pop removes the last element of h and returns it. +func (h *segmentHeap) Pop() interface{} { + old := *h + n := len(old) + x := old[n-1] + *h = old[:n-1] + return x +} diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go new file mode 100644 index 000000000..c4a7f7d5b --- /dev/null +++ b/pkg/tcpip/transport/tcp/segment_queue.go @@ -0,0 +1,69 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" +) + +// segmentQueue is a bounded, thread-safe queue of TCP segments. +type segmentQueue struct { + mu sync.Mutex + list segmentList + limit int + used int +} + +// empty determines if the queue is empty. +func (q *segmentQueue) empty() bool { + q.mu.Lock() + r := q.used == 0 + q.mu.Unlock() + + return r +} + +// setLimit updates the limit. No segments are immediately dropped in case the +// queue becomes full due to the new limit. +func (q *segmentQueue) setLimit(limit int) { + q.mu.Lock() + q.limit = limit + q.mu.Unlock() +} + +// enqueue adds the given segment to the queue. +// +// Returns true when the segment is successfully added to the queue, in which +// case ownership of the reference is transferred to the queue. And returns +// false if the queue is full, in which case ownership is retained by the +// caller. +func (q *segmentQueue) enqueue(s *segment) bool { + q.mu.Lock() + r := q.used < q.limit + if r { + q.list.PushBack(s) + q.used += s.data.Size() + header.TCPMinimumSize + } + q.mu.Unlock() + + return r +} + +// dequeue removes and returns the next segment from queue, if one exists. +// Ownership is transferred to the caller, who is responsible for decrementing +// the ref count when done. +func (q *segmentQueue) dequeue() *segment { + q.mu.Lock() + s := q.list.Front() + if s != nil { + q.list.Remove(s) + q.used -= s.data.Size() + header.TCPMinimumSize + } + q.mu.Unlock() + + return s +} diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go new file mode 100644 index 000000000..ad94aecd8 --- /dev/null +++ b/pkg/tcpip/transport/tcp/snd.go @@ -0,0 +1,628 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "math" + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" +) + +const ( + // minRTO is the minimum allowed value for the retransmit timeout. + minRTO = 200 * time.Millisecond + + // InitialCwnd is the initial congestion window. + InitialCwnd = 10 +) + +// sender holds the state necessary to send TCP segments. +type sender struct { + ep *endpoint + + // lastSendTime is the timestamp when the last packet was sent. + lastSendTime time.Time + + // dupAckCount is the number of duplicated acks received. It is used for + // fast retransmit. + dupAckCount int + + // fr holds state related to fast recovery. + fr fastRecovery + + // sndCwnd is the congestion window, in packets. + sndCwnd int + + // sndSsthresh is the threshold between slow start and congestion + // avoidance. + sndSsthresh int + + // sndCAAckCount is the number of packets acknowledged during congestion + // avoidance. When enough packets have been ack'd (typically cwnd + // packets), the congestion window is incremented by one. + sndCAAckCount int + + // outstanding is the number of outstanding packets, that is, packets + // that have been sent but not yet acknowledged. + outstanding int + + // sndWnd is the send window size. + sndWnd seqnum.Size + + // sndUna is the next unacknowledged sequence number. + sndUna seqnum.Value + + // sndNxt is the sequence number of the next segment to be sent. + sndNxt seqnum.Value + + // sndNxtList is the sequence number of the next segment to be added to + // the send list. + sndNxtList seqnum.Value + + // rttMeasureSeqNum is the sequence number being used for the latest RTT + // measurement. + rttMeasureSeqNum seqnum.Value + + // rttMeasureTime is the time when the rttMeasureSeqNum was sent. + rttMeasureTime time.Time + + closed bool + writeNext *segment + writeList segmentList + resendTimer timer `state:"nosave"` + resendWaker sleep.Waker `state:"nosave"` + + // srtt, rttvar & rto are the "smoothed round-trip time", "round-trip + // time variation" and "retransmit timeout", as defined in section 2 of + // RFC 6298. + srtt time.Duration + rttvar time.Duration + rto time.Duration + srttInited bool + + // maxPayloadSize is the maximum size of the payload of a given segment. + // It is initialized on demand. + maxPayloadSize int + + // sndWndScale is the number of bits to shift left when reading the send + // window size from a segment. + sndWndScale uint8 + + // maxSentAck is the maxium acknowledgement actually sent. + maxSentAck seqnum.Value +} + +// fastRecovery holds information related to fast recovery from a packet loss. +type fastRecovery struct { + // active whether the endpoint is in fast recovery. The following fields + // are only meaningful when active is true. + active bool + + // first and last represent the inclusive sequence number range being + // recovered. + first seqnum.Value + last seqnum.Value + + // maxCwnd is the maximum value the congestion window may be inflated to + // due to duplicate acks. This exists to avoid attacks where the + // receiver intentionally sends duplicate acks to artificially inflate + // the sender's cwnd. + maxCwnd int +} + +func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender { + s := &sender{ + ep: ep, + sndCwnd: InitialCwnd, + sndSsthresh: math.MaxInt64, + sndWnd: sndWnd, + sndUna: iss + 1, + sndNxt: iss + 1, + sndNxtList: iss + 1, + rto: 1 * time.Second, + rttMeasureSeqNum: iss + 1, + lastSendTime: time.Now(), + maxPayloadSize: int(mss), + maxSentAck: irs + 1, + fr: fastRecovery{ + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1. + last: iss, + }, + } + + // A negative sndWndScale means that no scaling is in use, otherwise we + // store the scaling value. + if sndWndScale > 0 { + s.sndWndScale = uint8(sndWndScale) + } + + s.updateMaxPayloadSize(int(ep.route.MTU()), 0) + + s.resendTimer.init(&s.resendWaker) + + return s +} + +// updateMaxPayloadSize updates the maximum payload size based on the given +// MTU. If this is in response to "packet too big" control packets (indicated +// by the count argument), it also reduces the number of oustanding packets and +// attempts to retransmit the first packet above the MTU size. +func (s *sender) updateMaxPayloadSize(mtu, count int) { + m := mtu - header.TCPMinimumSize + + // Calculate the maximum option size. + var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock + options := s.ep.makeOptions(maxSackBlocks[:]) + m -= len(options) + putOptions(options) + + // We don't adjust up for now. + if m >= s.maxPayloadSize { + return + } + + // Make sure we can transmit at least one byte. + if m <= 0 { + m = 1 + } + + s.maxPayloadSize = m + + s.outstanding -= count + if s.outstanding < 0 { + s.outstanding = 0 + } + + // Rewind writeNext to the first segment exceeding the MTU. Do nothing + // if it is already before such a packet. + for seg := s.writeList.Front(); seg != nil; seg = seg.Next() { + if seg == s.writeNext { + // We got to writeNext before we could find a segment + // exceeding the MTU. + break + } + + if seg.data.Size() > m { + // We found a segment exceeding the MTU. Rewind + // writeNext and try to retransmit it. + s.writeNext = seg + break + } + } + + // Since we likely reduced the number of outstanding packets, we may be + // ready to send some more. + s.sendData() +} + +// sendAck sends an ACK segment. +func (s *sender) sendAck() { + s.sendSegment(nil, flagAck, s.sndNxt) +} + +// updateRTO updates the retransmit timeout when a new roud-trip time is +// available. This is done in accordance with section 2 of RFC 6298. +func (s *sender) updateRTO(rtt time.Duration) { + if !s.srttInited { + s.rttvar = rtt / 2 + s.srtt = rtt + s.srttInited = true + } else { + diff := s.srtt - rtt + if diff < 0 { + diff = -diff + } + s.rttvar = (3*s.rttvar + diff) / 4 + s.srtt = (7*s.srtt + rtt) / 8 + } + + s.rto = s.srtt + 4*s.rttvar + if s.rto < minRTO { + s.rto = minRTO + } +} + +// resendSegment resends the first unacknowledged segment. +func (s *sender) resendSegment() { + // Don't use any segments we already sent to measure RTT as they may + // have been affected by packets being lost. + s.rttMeasureSeqNum = s.sndNxt + + // Resend the segment. + if seg := s.writeList.Front(); seg != nil { + s.sendSegment(&seg.data, seg.flags, seg.sequenceNumber) + } +} + +// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681, +// page 6, eq. 4. It is called when we detect congestion in the network. +func (s *sender) reduceSlowStartThreshold() { + s.sndSsthresh = s.outstanding / 2 + if s.sndSsthresh < 2 { + s.sndSsthresh = 2 + } +} + +// retransmitTimerExpired is called when the retransmit timer expires, and +// unacknowledged segments are assumed lost, and thus need to be resent. +// Returns true if the connection is still usable, or false if the connection +// is deemed lost. +func (s *sender) retransmitTimerExpired() bool { + // Check if the timer actually expired or if it's a spurious wake due + // to a previously orphaned runtime timer. + if !s.resendTimer.checkExpiration() { + return true + } + + // Give up if we've waited more than a minute since the last resend. + if s.rto >= 60*time.Second { + return false + } + + // Set new timeout. The timer will be restarted by the call to sendData + // below. + s.rto *= 2 + + if s.fr.active { + // We were attempting fast recovery but were not successful. + // Leave the state. We don't need to update ssthresh because it + // has already been updated when entered fast-recovery. + s.leaveFastRecovery() + } + + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4. + // We store the highest sequence number transmitted in cases where + // we were not in fast recovery. + s.fr.last = s.sndNxt - 1 + + // We lost a packet, so reduce ssthresh. + s.reduceSlowStartThreshold() + + // Reduce the congestion window to 1, i.e., enter slow-start. Per + // RFC 5681, page 7, we must use 1 regardless of the value of the + // initial congestion window. + s.sndCwnd = 1 + + // Mark the next segment to be sent as the first unacknowledged one and + // start sending again. Set the number of outstanding packets to 0 so + // that we'll be able to retransmit. + // + // We'll keep on transmitting (or retransmitting) as we get acks for + // the data we transmit. + s.outstanding = 0 + s.writeNext = s.writeList.Front() + s.sendData() + + return true +} + +// sendData sends new data segments. It is called when data becomes available or +// when the send window opens up. +func (s *sender) sendData() { + limit := s.maxPayloadSize + + // Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10. + // "A TCP SHOULD set cwnd to no more than RW before beginning + // transmission if the TCP has not sent data in the interval exceeding + // the retrasmission timeout." + if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto { + if s.sndCwnd > InitialCwnd { + s.sndCwnd = InitialCwnd + } + } + + // TODO: We currently don't merge multiple send buffers + // into one segment if they happen to fit. We should do that + // eventually. + var seg *segment + end := s.sndUna.Add(s.sndWnd) + for seg = s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() { + // We abuse the flags field to determine if we have already + // assigned a sequence number to this segment. + if seg.flags == 0 { + seg.sequenceNumber = s.sndNxt + seg.flags = flagAck | flagPsh + } + + var segEnd seqnum.Value + if seg.data.Size() == 0 { + seg.flags = flagAck + + s.ep.rcvListMu.Lock() + rcvBufUsed := s.ep.rcvBufUsed + s.ep.rcvListMu.Unlock() + + s.ep.mu.Lock() + // We're sending a FIN by default + fl := flagFin + if (s.ep.shutdownFlags&tcpip.ShutdownRead) != 0 && rcvBufUsed > 0 { + // If there is unread data we must send a RST. + // For more information see RFC 2525 section 2.17. + fl = flagRst + } + s.ep.mu.Unlock() + seg.flags |= uint8(fl) + + segEnd = seg.sequenceNumber.Add(1) + } else { + // We're sending a non-FIN segment. + if !seg.sequenceNumber.LessThan(end) { + break + } + + available := int(seg.sequenceNumber.Size(end)) + if available > limit { + available = limit + } + + if seg.data.Size() > available { + // Split this segment up. + nSeg := seg.clone() + nSeg.data.TrimFront(available) + nSeg.sequenceNumber.UpdateForward(seqnum.Size(available)) + s.writeList.InsertAfter(seg, nSeg) + seg.data.CapLength(available) + } + + s.outstanding++ + segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size())) + } + + s.sendSegment(&seg.data, seg.flags, seg.sequenceNumber) + + // Update sndNxt if we actually sent new data (as opposed to + // retransmitting some previously sent data). + if s.sndNxt.LessThan(segEnd) { + s.sndNxt = segEnd + } + } + + // Remember the next segment we'll write. + s.writeNext = seg + + // Enable the timer if we have pending data and it's not enabled yet. + if !s.resendTimer.enabled() && s.sndUna != s.sndNxt { + s.resendTimer.enable(s.rto) + } +} + +func (s *sender) enterFastRecovery() { + // Save state to reflect we're now in fast recovery. + s.reduceSlowStartThreshold() + // Save state to reflect we're now in fast recovery. + // See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3. + // We inflat the cwnd by 3 to account for the 3 packets which triggered + // the 3 duplicate ACKs and are now not in flight. + s.sndCwnd = s.sndSsthresh + 3 + s.fr.first = s.sndUna + s.fr.last = s.sndNxt - 1 + s.fr.maxCwnd = s.sndCwnd + s.outstanding + s.fr.active = true +} + +func (s *sender) leaveFastRecovery() { + s.fr.active = false + s.fr.first = 0 + s.fr.last = s.sndNxt - 1 + s.fr.maxCwnd = 0 + s.dupAckCount = 0 + + // Deflate cwnd. It had been artificially inflated when new dups arrived. + s.sndCwnd = s.sndSsthresh +} + +// checkDuplicateAck is called when an ack is received. It manages the state +// related to duplicate acks and determines if a retransmit is needed according +// to the rules in RFC 6582 (NewReno). +func (s *sender) checkDuplicateAck(seg *segment) bool { + ack := seg.ackNumber + if s.fr.active { + // We are in fast recovery mode. Ignore the ack if it's out of + // range. + if !ack.InRange(s.sndUna, s.sndNxt+1) { + return false + } + + // Leave fast recovery if it acknowledges all the data covered by + // this fast recovery session. + if s.fr.last.LessThan(ack) { + s.leaveFastRecovery() + return false + } + + // Don't count this as a duplicate if it is carrying data or + // updating the window. + if seg.logicalLen() != 0 || s.sndWnd != seg.window { + return false + } + + // Inflate the congestion window if we're getting duplicate acks + // for the packet we retransmitted. + if ack == s.fr.first { + // We received a dup, inflate the congestion window by 1 + // packet if we're not at the max yet. + if s.sndCwnd < s.fr.maxCwnd { + s.sndCwnd++ + } + return false + } + + // A partial ack was received. Retransmit this packet and + // remember it so that we don't retransmit it again. We don't + // inflate the window because we're putting the same packet back + // onto the wire. + // + // N.B. The retransmit timer will be reset by the caller. + s.fr.first = ack + return true + } + + // We're not in fast recovery yet. A segment is considered a duplicate + // only if it doesn't carry any data and doesn't update the send window, + // because if it does, it wasn't sent in response to an out-of-order + // segment. + if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt { + s.dupAckCount = 0 + return false + } + + // Enter fast recovery when we reach 3 dups. + s.dupAckCount++ + if s.dupAckCount != 3 { + return false + } + + // See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2 + // + // We only do the check here, the incrementing of last to the highest + // sequence number transmitted till now is done when enterFastRecovery + // is invoked. + if !s.fr.last.LessThan(seg.ackNumber) { + s.dupAckCount = 0 + return false + } + s.enterFastRecovery() + s.dupAckCount = 0 + return true +} + +// updateCwnd updates the congestion window based on the number of packets that +// were acknowledged. +func (s *sender) updateCwnd(packetsAcked int) { + if s.sndCwnd < s.sndSsthresh { + // Don't let the congestion window cross into the congestion + // avoidance range. + newcwnd := s.sndCwnd + packetsAcked + if newcwnd >= s.sndSsthresh { + newcwnd = s.sndSsthresh + s.sndCAAckCount = 0 + } + + packetsAcked -= newcwnd - s.sndCwnd + s.sndCwnd = newcwnd + if packetsAcked == 0 { + // We've consumed all ack'd packets. + return + } + } + + // Consume the packets in congestion avoidance mode. + s.sndCAAckCount += packetsAcked + if s.sndCAAckCount >= s.sndCwnd { + s.sndCwnd += s.sndCAAckCount / s.sndCwnd + s.sndCAAckCount = s.sndCAAckCount % s.sndCwnd + } +} + +// handleRcvdSegment is called when a segment is received; it is responsible for +// updating the send-related state. +func (s *sender) handleRcvdSegment(seg *segment) { + // Check if we can extract an RTT measurement from this ack. + if s.rttMeasureSeqNum.LessThan(seg.ackNumber) { + s.updateRTO(time.Now().Sub(s.rttMeasureTime)) + s.rttMeasureSeqNum = s.sndNxt + } + + // Update Timestamp if required. See RFC7323, section-4.3. + s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber) + + // Count the duplicates and do the fast retransmit if needed. + rtx := s.checkDuplicateAck(seg) + + // Stash away the current window size. + s.sndWnd = seg.window + + // Ignore ack if it doesn't acknowledge any new data. + ack := seg.ackNumber + if (ack - 1).InRange(s.sndUna, s.sndNxt) { + // When an ack is received we must reset the timer. We stop it + // here and it will be restarted later if needed. + s.resendTimer.disable() + + // Remove all acknowledged data from the write list. + acked := s.sndUna.Size(ack) + s.sndUna = ack + + ackLeft := acked + originalOutstanding := s.outstanding + for ackLeft > 0 { + // We use logicalLen here because we can have FIN + // segments (which are always at the end of list) that + // have no data, but do consume a sequence number. + seg := s.writeList.Front() + datalen := seg.logicalLen() + + if datalen > ackLeft { + seg.data.TrimFront(int(ackLeft)) + break + } + + if s.writeNext == seg { + s.writeNext = seg.Next() + } + s.writeList.Remove(seg) + s.outstanding-- + seg.decRef() + ackLeft -= datalen + } + + // Update the send buffer usage and notify potential waiters. + s.ep.updateSndBufferUsage(int(acked)) + + // If we are not in fast recovery then update the congestion + // window based on the number of acknowledged packets. + if !s.fr.active { + s.updateCwnd(originalOutstanding - s.outstanding) + } + + // It is possible for s.outstanding to drop below zero if we get + // a retransmit timeout, reset outstanding to zero but later + // get an ack that cover previously sent data. + if s.outstanding < 0 { + s.outstanding = 0 + } + } + + // Now that we've popped all acknowledged data from the retransmit + // queue, retransmit if needed. + if rtx { + s.resendSegment() + } + + // Send more data now that some of the pending data has been ack'd, or + // that the window opened up, or the congestion window was inflated due + // to a duplicate ack during fast recovery. This will also re-enable + // the retransmit timer if needed. + s.sendData() +} + +// sendSegment sends a new segment containing the given payload, flags and +// sequence number. +func (s *sender) sendSegment(data *buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error { + s.lastSendTime = time.Now() + if seq == s.rttMeasureSeqNum { + s.rttMeasureTime = s.lastSendTime + } + + rcvNxt, rcvWnd := s.ep.rcv.getSendParams() + + // Remember the max sent ack. + s.maxSentAck = rcvNxt + + if data == nil { + return s.ep.sendRaw(nil, flags, seq, rcvNxt, rcvWnd) + } + + if len(data.Views()) > 1 { + panic("send path does not support views with multiple buffers") + } + + return s.ep.sendRaw(data.First(), flags, seq, rcvNxt, rcvWnd) +} diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go new file mode 100644 index 000000000..19de96dcb --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go @@ -0,0 +1,336 @@ +package tcp_test + +import ( + "fmt" + "reflect" + "testing" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context" +) + +// createConnectWithSACKPermittedOption creates and connects c.ep with the +// SACKPermitted option enabled if the stack in the context has the SACK support +// enabled. +func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint { + return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()}) +} + +func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) { + t.Helper() + if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil { + t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err) + } +} + +// TestSackPermittedConnect establishes a connection with the SACK option +// enabled. +func TestSackPermittedConnect(t *testing.T) { + for _, sackEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + setStackSACKPermitted(t, c, sackEnabled) + rep := createConnectedWithSACKPermittedOption(c) + data := []byte{1, 2, 3} + + rep.SendPacket(data, nil) + savedSeqNum := rep.NextSeqNum + rep.VerifyACKNoSACK() + + // Make an out of order packet and send it. + rep.NextSeqNum += 3 + sackBlocks := []header.SACKBlock{ + {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, + } + rep.SendPacket(data, nil) + + // Restore the saved sequence number so that the + // VerifyXXX calls use the right sequence number for + // checking ACK numbers. + rep.NextSeqNum = savedSeqNum + if sackEnabled { + rep.VerifyACKHasSACK(sackBlocks) + } else { + rep.VerifyACKNoSACK() + } + + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative ACK for all 9 + // bytes sent and no SACK blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned in the ACK. + rep.VerifyACKNoSACK() + }) + } +} + +// TestSackDisabledConnect establishes a connection with the SACK option +// disabled and verifies that no SACKs are sent for out of order segments. +func TestSackDisabledConnect(t *testing.T) { + for _, sackEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + setStackSACKPermitted(t, c, sackEnabled) + + rep := c.CreateConnectedWithOptions(header.TCPSynOptions{}) + + data := []byte{1, 2, 3} + + rep.SendPacket(data, nil) + savedSeqNum := rep.NextSeqNum + rep.VerifyACKNoSACK() + + // Make an out of order packet and send it. + rep.NextSeqNum += 3 + rep.SendPacket(data, nil) + + // The ACK should contain the older sequence number and + // no SACK blocks. + rep.NextSeqNum = savedSeqNum + rep.VerifyACKNoSACK() + + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative ACK for all 9 + // bytes sent and no SACK blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned in the ACK. + rep.VerifyACKNoSACK() + }) + } +} + +// TestSackPermittedAccept accepts and establishes a connection with the +// SACKPermitted option enabled if the connection request specifies the +// SACKPermitted option. In case of SYN cookies SACK should be disabled as we +// don't encode the SACK information in the cookie. +func TestSackPermittedAccept(t *testing.T) { + type testCase struct { + cookieEnabled bool + sackPermitted bool + wndScale int + wndSize uint16 + } + + testCases := []testCase{ + // When cookie is used window scaling is disabled. + {true, false, -1, 0xffff}, // When cookie is used window scaling is disabled. + {false, true, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). + } + savedSynCountThreshold := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = savedSynCountThreshold + }() + for _, tc := range testCases { + t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { + if tc.cookieEnabled { + tcp.SynRcvdCountThreshold = 0 + } else { + tcp.SynRcvdCountThreshold = savedSynCountThreshold + } + for _, sackEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + setStackSACKPermitted(t, c, sackEnabled) + + rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted}) + // Now verify no SACK blocks are + // received when sack is disabled. + data := []byte{1, 2, 3} + rep.SendPacket(data, nil) + rep.VerifyACKNoSACK() + + savedSeqNum := rep.NextSeqNum + + // Make an out of order packet and send + // it. + rep.NextSeqNum += 3 + sackBlocks := []header.SACKBlock{ + {rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))}, + } + rep.SendPacket(data, nil) + + // The ACK should contain the older + // sequence number. + rep.NextSeqNum = savedSeqNum + if sackEnabled && tc.sackPermitted { + rep.VerifyACKHasSACK(sackBlocks) + } else { + rep.VerifyACKNoSACK() + } + + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative + // ACK for all 9 bytes sent and no SACK + // blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned + // in the ACK. + rep.VerifyACKNoSACK() + }) + } + }) + } +} + +// TestSackDisabledAccept accepts and establishes a connection with +// the SACKPermitted option disabled and verifies that no SACKs are +// sent for out of order packets. +func TestSackDisabledAccept(t *testing.T) { + type testCase struct { + cookieEnabled bool + wndScale int + wndSize uint16 + } + + testCases := []testCase{ + // When cookie is used window scaling is disabled. + {true, -1, 0xffff}, // When cookie is used window scaling is disabled. + {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). + } + savedSynCountThreshold := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = savedSynCountThreshold + }() + for _, tc := range testCases { + t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) { + if tc.cookieEnabled { + tcp.SynRcvdCountThreshold = 0 + } else { + tcp.SynRcvdCountThreshold = savedSynCountThreshold + } + for _, sackEnabled := range []bool{false, true} { + t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + setStackSACKPermitted(t, c, sackEnabled) + + rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Now verify no SACK blocks are + // received when sack is disabled. + data := []byte{1, 2, 3} + rep.SendPacket(data, nil) + rep.VerifyACKNoSACK() + savedSeqNum := rep.NextSeqNum + + // Make an out of order packet and send + // it. + rep.NextSeqNum += 3 + rep.SendPacket(data, nil) + + // The ACK should contain the older + // sequence number and no SACK blocks. + rep.NextSeqNum = savedSeqNum + rep.VerifyACKNoSACK() + + // Send the missing segment. + rep.SendPacket(data, nil) + // The ACK should contain the cumulative + // ACK for all 9 bytes sent and no SACK + // blocks. + rep.NextSeqNum += 3 + // Check that no SACK block is returned + // in the ACK. + rep.VerifyACKNoSACK() + }) + } + }) + } +} + +func TestUpdateSACKBlocks(t *testing.T) { + testCases := []struct { + segStart seqnum.Value + segEnd seqnum.Value + rcvNxt seqnum.Value + sackBlocks []header.SACKBlock + updated []header.SACKBlock + }{ + // Trivial cases where current SACK block list is empty and we + // have an out of order delivery. + {10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}}, + {10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}}, + {10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}}, + + // Cases where current SACK block list is not empty and we have + // an out of order delivery. Tests that the updated SACK block + // list has the first block as the one that contains the new + // SACK block representing the segment that was just delivered. + {10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}}, + {24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}}, + {24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}}, + + // Ensure that we only retain header.MaxSACKBlocks and drop the + // oldest one if adding a new block exceeds + // header.MaxSACKBlocks. + {24, 30, 9, + []header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}}, + []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}}, + + // Cases where segment extends an existing SACK block. + {10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}}, + {15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}}, + {15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}}, + {11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}}, + {10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, + {10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}}, + {15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}}, + {15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}}, + {11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}}, + + // Cases where segment contains rcvNxt. + {10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}}, + } + + for _, tc := range testCases { + var sack tcp.SACKInfo + copy(sack.Blocks[:], tc.sackBlocks) + sack.NumBlocks = len(tc.sackBlocks) + tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt) + if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) { + t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want) + } + + } +} + +func TestTrimSackBlockList(t *testing.T) { + testCases := []struct { + rcvNxt seqnum.Value + sackBlocks []header.SACKBlock + trimmed []header.SACKBlock + }{ + // Simple cases where we trim whole entries. + {2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}}, + {21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}}, + {31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}}, + {40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, + // Cases where we need to update a block. + {12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}}, + {23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}}, + {33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}}, + {41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}}, + } + for _, tc := range testCases { + var sack tcp.SACKInfo + copy(sack.Blocks[:], tc.sackBlocks) + sack.NumBlocks = len(tc.sackBlocks) + tcp.TrimSACKBlockList(&sack, tc.rcvNxt) + if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) { + t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want) + } + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go new file mode 100644 index 000000000..118d861ba --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -0,0 +1,2759 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp_test + +import ( + "bytes" + "fmt" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/checker" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // defaultMTU is the MTU, in bytes, used throughout the tests, except + // where another value is explicitly used. It is chosen to match the MTU + // of loopback interfaces on linux systems. + defaultMTU = 65535 + + // defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an + // IPv4 endpoint when the MTU is set to defaultMTU in the test. + defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize +) + +func TestGiveUpConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Register for notification, then start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventOut) + defer wq.EventUnregister(&waitEntry) + + if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Close the connection, wait for completion. + ep.Close() + + // Wait for ep to become writable. + <-notifyCh + if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { + t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted) + } +} + +func TestActiveHandshake(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) +} + +func TestNonBlockingClose(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + ep := c.EP + c.EP = nil + + // Close the endpoint and measure how long it takes. + t0 := time.Now() + ep.Close() + if diff := time.Now().Sub(t0); diff > 3*time.Second { + t.Fatalf("Took too long to close: %v", diff) + } +} + +func TestConnectResetAfterClose(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + ep := c.EP + c.EP = nil + + // Close the endpoint, make sure we get a FIN segment, then acknowledge + // to complete closure of sender, but don't send our own FIN. + ep.Close() + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for the ep to give up waiting for a FIN, and send a RST. + time.Sleep(3 * time.Second) + for { + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + if tcp.Flags() == header.TCPFlagAck|header.TCPFlagFin { + // This is a retransmit of the FIN, ignore it. + continue + } + + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + ), + ) + break + } +} + +func TestSimpleReceive(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("Unexpected error from Read: %v", err) + } + + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Receive data. + v, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + + if bytes.Compare(data, v) != 0 { + t.Fatalf("Data is different: expected %v, got %v", data, v) + } + + // Check that ACK is received. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestOutOfOrderReceive(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("Unexpected error from Read: %v", err) + } + + // Send second half of data first, with seqnum 3 ahead of expected. + data := []byte{1, 2, 3, 4, 5, 6} + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 793, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Check that we get an ACK specifying which seqnum is expected. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Wait 200ms and check that no data has been received. + time.Sleep(200 * time.Millisecond) + if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("Unexpected error from Read: %v", err) + } + + // Send the first 3 bytes now. + c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive data. + read := make([]byte, 0, 6) + for len(read) < len(data) { + v, err := c.EP.Read(nil) + if err != nil { + if err == tcpip.ErrWouldBlock { + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + continue + } + t.Fatalf("Unexpected error from Read: %v", err) + } + + read = append(read, v...) + } + + // Check that we received the data in proper order. + if bytes.Compare(data, read) != 0 { + t.Fatalf("Data is different: expected %v, got %v", data, read) + } + + // Check that the whole data is acknowledged. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestOutOfOrderFlood(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create a new connection with initial window size of 10. + opt := tcpip.ReceiveBufferSizeOption(10) + c.CreateConnected(789, 30000, &opt) + + if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("Unexpected error from Read: %v", err) + } + + // Send 100 packets before the actual one that is expected. + data := []byte{1, 2, 3, 4, 5, 6} + for i := 0; i < 100; i++ { + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 796, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Send packet with seqnum 793. It must be discarded because the + // out-of-order buffer was filled by the previous packets. + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 793, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Now send the expected packet, seqnum 790. + c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Check that only packet 790 is acknowledged. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(793), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestRstOnCloseWithUnreadData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("Unexpected error from Read: %v", err) + } + + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(3 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that ACK is received, this happens regardless of the read. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Now that we know we have unread data, let's just close the connection + // and verify that netstack sends an RST rather than a FIN. + c.EP.Close() + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst), + )) +} + +func TestFullWindowReceive(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + opt := tcpip.ReceiveBufferSizeOption(10) + c.CreateConnected(789, 30000, &opt) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + _, err := c.EP.Read(nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("Unexpected error from Read: %v", err) + } + + // Fill up the window. + data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that data is acknowledged, and window goes to zero. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(0), + ), + ) + + // Receive data and check it. + v, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + + if bytes.Compare(data, v) != 0 { + t.Fatalf("Data is different: expected %v, got %v", data, v) + } + + // Check that we get an ACK for the newly non-zero window. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(10), + ), + ) +} + +func TestNoWindowShrinking(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Start off with a window size of 10, then shrink it to 5. + opt := tcpip.ReceiveBufferSizeOption(10) + c.CreateConnected(789, 30000, &opt) + + opt = 5 + if err := c.EP.SetSockOpt(opt); err != nil { + t.Fatalf("SetSockOpt failed: %v", err) + } + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + _, err := c.EP.Read(nil) + if err != tcpip.ErrWouldBlock { + t.Fatalf("Unexpected error from Read: %v", err) + } + + // Send 3 bytes, check that the peer acknowledges them. + data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} + c.SendPacket(data[:3], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that data is acknowledged, and that window doesn't go to zero + // just yet because it was previously set to 10. It must go to 7 now. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(793), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(7), + ), + ) + + // Send 7 more bytes, check that the window fills up. + c.SendPacket(data[3:], &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 793, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + select { + case <-ch: + case <-time.After(5 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(0), + ), + ) + + // Receive data and check it. + read := make([]byte, 0, 10) + for len(read) < len(data) { + v, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + + read = append(read, v...) + } + + if bytes.Compare(data, read) != 0 { + t.Fatalf("Data is different: expected %v, got %v", data, read) + } + + // Check that we get an ACK for the newly non-zero window, which is the + // new size. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+len(data))), + checker.TCPFlags(header.TCPFlagAck), + checker.Window(5), + ), + ) +} + +func TestSimpleSend(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(data, p) != 0 { + t.Fatalf("Data is different: expected %v, got %v", data, p) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), + RcvWnd: 30000, + }) +} + +func TestZeroWindowSend(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 0, nil) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Since the window is currently zero, check that no packet is received. + c.CheckNoPacket("Packet received when window is zero") + + // Open up the window. Data should be received now. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Check that data is received. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(data, p) != 0 { + t.Fatalf("Data is different: expected %v, got %v", data, p) + } + + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(len(data))), + RcvWnd: 30000, + }) +} + +func TestScaledWindowConnect(t *testing.T) { + // This test ensures that window scaling is used when the peer + // does advertise it and connection is established with Connect(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set the window size greater than the maximum non-scaled window. + opt := tcpip.ReceiveBufferSizeOption(65535 * 3) + c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Check that data is received, and that advertised window is 0xbfff, + // that is, that it is scaled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xbfff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestNonScaledWindowConnect(t *testing.T) { + // This test ensures that window scaling is not used when the peer + // doesn't advertise it and connection is established with Connect(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set the window size greater than the maximum non-scaled window. + opt := tcpip.ReceiveBufferSizeOption(65535 * 3) + c.CreateConnected(789, 30000, &opt) + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Check that data is received, and that advertised window is 0xffff, + // that is, that it's not scaled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xffff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestScaledWindowAccept(t *testing.T) { + // This test ensures that window scaling is used when the peer + // does advertise it and connection is established with Accept(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + // Set the window size greater than the maximum non-scaled window. + if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + t.Fatalf("SetSockOpt failed failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Do 3-way handshake. + c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Check that data is received, and that advertised window is 0xbfff, + // that is, that it is scaled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xbfff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestNonScaledWindowAccept(t *testing.T) { + // This test ensures that window scaling is not used when the peer + // doesn't advertise it and connection is established with Accept(). + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + // Set the window size greater than the maximum non-scaled window. + if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil { + t.Fatalf("SetSockOpt failed failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Do 3-way handshake. + c.PassiveConnect(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Check that data is received, and that advertised window is 0xffff, + // that is, that it's not scaled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(0xffff), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) +} + +func TestZeroScaledWindowReceive(t *testing.T) { + // This test ensures that the endpoint sends a non-zero window size + // advertisement when the scaled window transitions from 0 to non-zero, + // but the actual window (not scaled) hasn't gotten to zero. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Set the window size such that a window scale of 4 will be used. + const wnd = 65535 * 10 + const ws = uint32(4) + opt := tcpip.ReceiveBufferSizeOption(wnd) + c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{ + header.TCPOptionWS, 3, 0, header.TCPOptionNOP, + }) + + // Write chunks of 50000 bytes. + remain := wnd + sent := 0 + data := make([]byte, 50000) + for remain > len(data) { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(remain>>ws)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Make the window non-zero, but the scaled window zero. + if remain >= 16 { + data = data[:remain-15] + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + sent), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + sent += len(data) + remain -= len(data) + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(0), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + } + + // Read some data. An ack should be sent in response to that. + v, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(790+sent)), + checker.Window(uint16(len(v)>>ws)), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { + payloadMultiplier := 10 + dataLen := payloadMultiplier * maxPayload + data := make([]byte, dataLen) + for i := range data { + data[i] = byte(i) + } + + view := buffer.NewView(len(data)) + copy(view, data) + + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Check that data is received in chunks. + bytesReceived := 0 + numPackets := 0 + for bytesReceived != dataLen { + b := c.GetPacket() + numPackets++ + tcp := header.TCP(header.IPv4(b).Payload()) + payloadLen := len(tcp.Payload()) + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + pdata := data[bytesReceived : bytesReceived+payloadLen] + if p := tcp.Payload(); bytes.Compare(pdata, p) != 0 { + t.Fatalf("Data is different: expected %v, got %v", pdata, p) + } + bytesReceived += payloadLen + var options []byte + if c.TimeStampEnabled { + // If timestamp option is enabled, echo back the timestamp and increment + // the TSEcr value included in the packet and send that back as the TSVal. + parsedOpts := tcp.ParsedOptions() + tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) + options = tsOpt[:] + } + // Acknowledge the data. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), + RcvWnd: 30000, + TCPOpts: options, + }) + } + if numPackets == 1 { + t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet") + } +} + +func TestSendGreaterThanMTU(t *testing.T) { + const maxPayload = 100 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + testBrokenUpWrite(t, c, maxPayload) +} + +func TestActiveSendMSSLessThanMTU(t *testing.T) { + const maxPayload = 100 + c := context.New(t, 65535) + defer c.Cleanup() + + c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + }) + testBrokenUpWrite(t, c, maxPayload) +} + +func TestPassiveSendMSSLessThanMTU(t *testing.T) { + const maxPayload = 100 + const mtu = 1200 + c := context.New(t, mtu) + defer c.Cleanup() + + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + // Set the buffer size to a deterministic size so that we can check the + // window scaling option. + const rcvBufferSize = 0x20000 + const wndScale = 2 + if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + t.Fatalf("SetSockOpt failed failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Do 3-way handshake. + c.PassiveConnect(maxPayload, wndScale, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Check that data gets properly segmented. + testBrokenUpWrite(t, c, maxPayload) +} + +func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) { + const maxPayload = 536 + const mtu = 2000 + c := context.New(t, mtu) + defer c.Cleanup() + + // Set the SynRcvd threshold to zero to force a syn cookie based accept + // to happen. + saved := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = saved + }() + tcp.SynRcvdCountThreshold = 0 + + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Do 3-way handshake. + c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + + // Check that data gets properly segmented. + testBrokenUpWrite(t, c, maxPayload) +} + +func TestForwarderSendMSSLessThanMTU(t *testing.T) { + const maxPayload = 100 + const mtu = 1200 + c := context.New(t, mtu) + defer c.Cleanup() + + s := c.Stack() + ch := make(chan *tcpip.Error, 1) + f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) { + var err *tcpip.Error + c.EP, err = r.CreateEndpoint(&c.WQ) + ch <- err + }) + s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket) + + // Do 3-way handshake. + c.PassiveConnect(maxPayload, 1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize}) + + // Wait for connection to be available. + select { + case err := <-ch: + if err != nil { + t.Fatalf("Error creating endpoint: %v", err) + } + case <-time.After(2 * time.Second): + t.Fatalf("Timed out waiting for connection") + } + + // Check that data gets properly segmented. + testBrokenUpWrite(t, c, maxPayload) +} + +func TestSynOptionsOnActiveConnect(t *testing.T) { + const mtu = 1400 + c := context.New(t, mtu) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + // Set the buffer size to a deterministic size so that we can check the + // window scaling option. + const rcvBufferSize = 0x20000 + const wndScale = 2 + if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil { + t.Fatalf("SetSockOpt failed failed: %v", err) + } + + // Start connection attempt. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventOut) + defer c.WQ.EventUnregister(&we) + + err = c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, WS: wndScale}), + ), + ) + + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + // Wait for retransmit. + time.Sleep(1 * time.Second) + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.SrcPort(tcp.SourcePort()), + checker.SeqNum(tcp.SequenceNumber()), + checker.TCPSynOptions(header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, WS: wndScale}), + ), + ) + + // Send SYN-ACK. + iss := seqnum.Value(789) + c.SendPacket(nil, &context.Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: tcp.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + // Receive ACK packet. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + ), + ) + + // Wait for connection to be established. + select { + case <-ch: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for connection") + } +} + +func TestCloseListener(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create listener. + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Close the listener and measure how long it takes. + t0 := time.Now() + ep.Close() + if diff := time.Now().Sub(t0); diff > 3*time.Second { + t.Fatalf("Took too long to close: %v", diff) + } +} + +func TestReceiveOnResetConnection(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Send RST segment. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: 790, + RcvWnd: 30000, + }) + + // Try to read. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + +loop: + for { + switch _, err := c.EP.Read(nil); err { + case nil: + t.Fatalf("Unexpected success.") + case tcpip.ErrWouldBlock: + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for reset to arrive") + } + case tcpip.ErrConnectionReset: + break loop + default: + t.Fatalf("Unexpected error: want %v, got %v", tcpip.ErrConnectionReset, err) + } + } +} + +func TestSendOnResetConnection(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Send RST segment. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: 790, + RcvWnd: 30000, + }) + + // Wait for the RST to be received. + time.Sleep(1 * time.Second) + + // Try to write. + view := buffer.NewView(10) + _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) + if err != tcpip.ErrConnectionReset { + t.Fatalf("Unexpected error from Write: want %v, got %v", tcpip.ErrConnectionReset, err) + } +} + +func TestFinImmediately(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Shutdown immediately, check that we get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Unexpected error from Shutdown: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinRetransmit(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Shutdown immediately, check that we get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Unexpected error from Shutdown: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + // Don't acknowledge yet. We should get a retransmit of the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithNoPendingData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Write something out, and have it acknowledged. + view := buffer.NewView(10) + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Shutdown, check that we get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Unexpected error from Shutdown: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Ack and send FIN as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Check that the stack acks the FIN. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithPendingDataCwndFull(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Write enough segments to fill the congestion window before ACK'ing + // any of them. + view := buffer.NewView(10) + for i := tcp.InitialCwnd; i > 0; i-- { + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + } + + next := uint32(c.IRS) + 1 + for i := tcp.InitialCwnd; i > 0; i-- { + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + } + + // Shutdown the connection, check that the FIN segment isn't sent + // because the congestion window doesn't allow it. Wait until a + // retransmit is received. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Unexpected error from Shutdown: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Send the ACK that will allow the FIN to be sent as well. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Send a FIN that acknowledges everything. Get an ACK back. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithPendingData(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Write something out, and acknowledge it to get cwnd to 2. + view := buffer.NewView(10) + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Write new data, but don't acknowledge it. + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + // Shutdown the connection, check that we do get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Unexpected error from Shutdown: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Send a FIN that acknowledges everything. Get an ACK back. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck), + ), + ) +} + +func TestFinWithPartialAck(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Write something out, and acknowledge it to get cwnd to 2. Also send + // FIN from the test side. + view := buffer.NewView(10) + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + next := uint32(c.IRS) + 1 + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) + + // Check that we get an ACK for the fin. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Write new data, but don't acknowledge it. + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(len(view)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + next += uint32(len(view)) + + // Shutdown the connection, check that we do get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Unexpected error from Shutdown: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(next), + checker.AckNum(791), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + next++ + + // Send an ACK for the data, but not for the FIN yet. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 791, + AckNum: seqnum.Value(next - 1), + RcvWnd: 30000, + }) + + // Check that we don't get a retransmit of the FIN. + c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond) + + // Ack the FIN. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 791, + AckNum: seqnum.Value(next), + RcvWnd: 30000, + }) +} + +func TestExponentialIncreaseDuringSlowStart(t *testing.T) { + maxPayload := 10 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + const iterations = 7 + data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + for i := range data { + data[i] = byte(i) + } + + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + + // Double the number of expected packets for the next iteration. + expected *= 2 + } +} + +func TestCongestionAvoidance(t *testing.T) { + maxPayload := 10 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + const iterations = 7 + data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + for i := range data { + data[i] = byte(i) + } + + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + } + + // Don't acknowledge the first packet of the last packet train. Let's + // wait for them to time out, which will trigger a restart of slow + // start, and initialization of ssthresh to cwnd/2. + rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + + // This part is tricky: when the timeout happened, we had "expected" + // packets pending, cwnd reset to 1, and ssthresh set to expected/2. + // By acknowledging "expected" packets, the slow-start part will + // increase cwnd to expected/2 (which "consumes" expected/2-1 of the + // acknowledgements), then the congestion avoidance part will consume + // an extra expected/2 acks to take cwnd to expected/2 + 1. One ack + // remains in the "ack count" (which will cause cwnd to be incremented + // once it reaches cwnd acks). + // + // So we're straight into congestion avoidance with cwnd set to + // expected/2 + 1. + // + // Check that packets trains of cwnd packets are sent, and that cwnd is + // incremented by 1 after we acknowledge each packet. + expected = expected/2 + 1 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + + // In cogestion avoidance, the packets trains increase by 1 in + // each iteration. + expected++ + } +} + +func TestFastRecovery(t *testing.T) { + maxPayload := 10 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + const iterations = 7 + data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1))) + for i := range data { + data[i] = byte(i) + } + + // Write all the data in one shot. Packets will only be written at the + // MTU size though. + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + } + + // Send 3 duplicate acks. This should force an immediate retransmit of + // the pending packet and put the sender into fast recovery. + rtxOffset := bytesRead - maxPayload*expected + for i := 0; i < 3; i++ { + c.SendAck(790, rtxOffset) + } + + // Receive the retransmitted packet. + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Now send 7 mode duplicate acks. Each of these should cause a window + // inflation by 1 and cause the sender to send an extra packet. + for i := 0; i < 7; i++ { + c.SendAck(790, rtxOffset) + } + + recover := bytesRead + + // Ensure no new packets arrive. + c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.", + 50*time.Millisecond) + + // Acknowledge half of the pending data. + rtxOffset = bytesRead - expected*maxPayload/2 + c.SendAck(790, rtxOffset) + + // Receive the retransmit due to partial ack. + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Receive the 10 extra packets that should have been released due to + // the congestion window inflation in recovery. + for i := 0; i < 10; i++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // A partial ACK during recovery should reduce congestion window by the + // number acked. Since we had "expected" packets outstanding before sending + // partial ack and we acked expected/2 , the cwnd and outstanding should + // be expected/2 + 7. Which means the sender should not send any more packets + // till we ack this one. + c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.", + 50*time.Millisecond) + + // Acknowledge all pending data to recover point. + c.SendAck(790, recover) + + // At this point, the cwnd should reset to expected/2 and there are 10 + // packets outstanding. + // + // NOTE: Technically netstack is incorrect in that we adjust the cwnd on + // the same segment that takes us out of recovery. But because of that + // the actual cwnd at exit of recovery will be expected/2 + 1 as we + // acked a cwnd worth of packets which will increase the cwnd further by + // 1 in congestion avoidance. + // + // Now in the first iteration since there are 10 packets outstanding. + // We would expect to get expected/2 +1 - 10 packets. But subsequent + // iterations will send us expected/2 + 1 + 1 (per iteration). + expected = expected/2 + 1 - 10 + for i := 0; i < iterations; i++ { + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond) + + // Acknowledge all the data received so far. + c.SendAck(790, bytesRead) + + // In cogestion avoidance, the packets trains increase by 1 in + // each iteration. + if i == 0 { + // After the first iteration we expect to get the full + // congestion window worth of packets in every + // iteration. + expected += 10 + } + expected++ + } +} + +func TestRetransmit(t *testing.T) { + maxPayload := 10 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + const iterations = 7 + data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1))) + for i := range data { + data[i] = byte(i) + } + + // Write all the data in two shots. Packets will only be written at the + // MTU size though. + half := data[:len(data)/2] + if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + half = data[len(data)/2:] + if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Do slow start for a few iterations. + expected := tcp.InitialCwnd + bytesRead := 0 + for i := 0; i < iterations; i++ { + expected = tcp.InitialCwnd << uint(i) + if i > 0 { + // Acknowledge all the data received so far if not on + // first iteration. + c.SendAck(790, bytesRead) + } + + // Read all packets expected on this iteration. Don't + // acknowledge any of them just yet, so that we can measure the + // congestion window. + for j := 0; j < expected; j++ { + c.ReceiveAndCheckPacket(data, bytesRead, maxPayload) + bytesRead += maxPayload + } + + // Check we don't receive any more packets on this iteration. + // The timeout can't be too high or we'll trigger a timeout. + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) + } + + // Wait for a timeout and retransmit. + rtxOffset := bytesRead - maxPayload*expected + c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload) + + // Acknowledge half of the pending data. + rtxOffset = bytesRead - expected*maxPayload/2 + c.SendAck(790, rtxOffset) + + // Receive the remaining data, making sure that acknowledged data is not + // retransmitted. + for offset := rtxOffset; offset < len(data); offset += maxPayload { + c.ReceiveAndCheckPacket(data, offset, maxPayload) + c.SendAck(790, offset+maxPayload) + } + + c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond) +} + +func TestUpdateListenBacklog(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // Create listener. + var wq waiter.Queue + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + + // Update the backlog with another Listen() on the same endpoint. + if err := ep.Listen(20); err != nil { + t.Fatalf("Listen failed to update backlog: %v", err) + } + + ep.Close() +} + +func scaledSendWindow(t *testing.T, scale uint8) { + // This test ensures that the endpoint is using the right scaling by + // sending a buffer that is larger than the window size, and ensuring + // that the endpoint doesn't send more than allowed. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize + c.CreateConnectedWithRawOptions(789, 0, nil, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + header.TCPOptionWS, 3, scale, header.TCPOptionNOP, + }) + + // Open up the window with a scaled value. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 1, + }) + + // Send some data. Check that it's capped by the window size. + view := buffer.NewView(65535) + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Check that only data that fits in the scaled window is sent. + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen((1<<scale)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + // Reset the connection to free resources. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagRst, + SeqNum: 790, + }) +} + +func TestScaledSendWindow(t *testing.T) { + for scale := uint8(0); scale <= 14; scale++ { + scaledSendWindow(t, scale) + } +} + +func TestReceivedSegmentQueuing(t *testing.T) { + // This test sends 200 segments containing a few bytes each to an + // endpoint and checks that they're all received and acknowledged by + // the endpoint, that is, that none of the segments are dropped by + // internal queues. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + // Send 200 segments. + data := []byte{1, 2, 3} + for i := 0; i < 200; i++ { + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(790 + i*len(data)), + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + } + + // Receive ACKs for all segments. + last := seqnum.Value(790 + 200*len(data)) + for { + b := c.GetPacket() + checker.IPv4(t, b, + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + tcp := header.TCP(header.IPv4(b).Payload()) + ack := seqnum.Value(tcp.AckNumber()) + if ack == last { + break + } + + if last.LessThan(ack) { + t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last) + } + } +} + +func TestReadAfterClosedState(t *testing.T) { + // This test ensures that calling Read() or Peek() after the endpoint + // has transitioned to closedState still works if there is pending + // data. To transition to stateClosed without calling Close(), we must + // shutdown the send path and the peer must send its own FIN. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnected(789, 30000, nil) + + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { + t.Fatalf("Unexpected error from Read: %v", err) + } + + // Shutdown immediately for write, check that we get a FIN. + if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { + t.Fatalf("Unexpected error from Shutdown: %v", err) + } + + checker.IPv4(t, c.GetPacket(), + checker.PayloadLen(header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin), + ), + ) + + // Send some data and acknowledge the FIN. + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck | header.TCPFlagFin, + SeqNum: 790, + AckNum: c.IRS.Add(2), + RcvWnd: 30000, + }) + + // Check that ACK is received. + checker.IPv4(t, c.GetPacket(), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+2), + checker.AckNum(uint32(791+len(data))), + checker.TCPFlags(header.TCPFlagAck), + ), + ) + + // Give the stack the chance to transition to closed state. + time.Sleep(1 * time.Second) + + // Wait for receive to be notified. + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Check that peek works. + peekBuf := make([]byte, 10) + n, err := c.EP.Peek([][]byte{peekBuf}) + if err != nil { + t.Fatalf("Unexpected error from Peek: %v", err) + } + + peekBuf = peekBuf[:n] + if bytes.Compare(data, peekBuf) != 0 { + t.Fatalf("Data is different: expected %v, got %v", data, peekBuf) + } + + // Receive data. + v, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + + if bytes.Compare(data, v) != 0 { + t.Fatalf("Data is different: expected %v, got %v", data, v) + } + + // Now that we drained the queue, check that functions fail with the + // right error code. + if _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { + t.Fatalf("Unexpected return from Read: got %v, want %v", err, tcpip.ErrClosedForReceive) + } + + if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { + t.Fatalf("Unexpected return from Peek: got %v, want %v", err, tcpip.ErrClosedForReceive) + } +} + +func TestReusePort(t *testing.T) { + // This test ensures that ports are immediately available for reuse + // after Close on the endpoints using them returns. + c := context.New(t, defaultMTU) + defer c.Cleanup() + + // First case, just an endpoint that was bound. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + c.EP.Close() + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + c.EP.Close() + + // Second case, an endpoint that was bound and is connecting.. + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + err = c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}) + if err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + c.EP.Close() + + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + c.EP.Close() + + // Third case, an endpoint that was bound and is listening. + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + err = c.EP.Listen(10) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } + c.EP.Close() + + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + err = c.EP.Listen(10) + if err != nil { + t.Fatalf("Listen failed: %v", err) + } +} + +func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { + t.Helper() + + var s tcpip.ReceiveBufferSizeOption + if err := ep.GetSockOpt(&s); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + + if int(s) != v { + t.Fatalf("Bad receive buffer size: want=%v, got=%v", v, s) + } +} + +func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) { + t.Helper() + + var s tcpip.SendBufferSizeOption + if err := ep.GetSockOpt(&s); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + + if int(s) != v { + t.Fatalf("Bad send buffer size: want=%v, got=%v", v, s) + } +} + +func TestDefaultBufferSizes(t *testing.T) { + s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}) + + // Check the default values. + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + defer func() { + if ep != nil { + ep.Close() + } + }() + + checkSendBufferSize(t, ep, tcp.DefaultBufferSize) + checkRecvBufferSize(t, ep, tcp.DefaultBufferSize) + + // Change the default send buffer size. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %v", err) + } + + ep.Close() + ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + + checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2) + checkRecvBufferSize(t, ep, tcp.DefaultBufferSize) + + // Change the default receive buffer size. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %v", err) + } + + ep.Close() + ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + + checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2) + checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*3) +} + +func TestMinMaxBufferSizes(t *testing.T) { + s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}) + + // Check the default values. + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{}) + if err != nil { + t.Fatalf("NewEndpoint failed; %v", err) + } + defer ep.Close() + + // Change the min/max values for send/receive + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %v", err) + } + + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %v", err) + } + + // Set values below the min. + if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + + checkRecvBufferSize(t, ep, 200) + + if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + + checkSendBufferSize(t, ep, 300) + + // Set values above the max. + if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultBufferSize*20)); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + + checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*20) + + if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultBufferSize*30)); err != nil { + t.Fatalf("GetSockOpt failed: %v", err) + } + + checkSendBufferSize(t, ep, tcp.DefaultBufferSize*30) +} + +func TestSelfConnect(t *testing.T) { + // This test ensures that intentional self-connects work. In particular, + // it checks that if an endpoint binds to say 127.0.0.1:1000 then + // connects to 127.0.0.1:1000, then it will be connected to itself, and + // is able to send and receive data through the same endpoint. + s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName}) + + id := loopback.New() + if testing.Verbose() { + id = sniffer.New(id) + } + + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, context.StackAddr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, + }) + + var wq waiter.Queue + ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Register for notification, then start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + wq.EventRegister(&waitEntry, waiter.EventOut) + defer wq.EventUnregister(&waitEntry) + + err = ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}) + if err != tcpip.ErrConnectStarted { + t.Fatalf("Unexpected return value from Connect: %v", err) + } + + <-notifyCh + err = ep.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + t.Fatalf("Connect failed: %v", err) + } + + // Write something. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + if _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + // Read back what was written. + wq.EventUnregister(&waitEntry) + wq.EventRegister(&waitEntry, waiter.EventIn) + rd, err := ep.Read(nil) + if err != nil { + if err != tcpip.ErrWouldBlock { + t.Fatalf("Read failed: %v", err) + } + <-notifyCh + rd, err = ep.Read(nil) + if err != nil { + t.Fatalf("Read failed: %v", err) + } + } + + if bytes.Compare(data, rd) != 0 { + t.Fatalf("Data is different: want=%v, got=%v", data, rd) + } +} + +func TestPathMTUDiscovery(t *testing.T) { + // This test verifies the stack retransmits packets after it receives an + // ICMP packet indicating that the path MTU has been exceeded. + c := context.New(t, 1500) + defer c.Cleanup() + + // Create new connection with MSS of 1460. + const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize + c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{ + header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256), + }) + + // Send 3200 bytes of data. + const writeSize = 3200 + data := buffer.NewView(writeSize) + for i := range data { + data[i] = byte(i) + } + + if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Write failed: %v", err) + } + + receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte { + var ret []byte + for i, size := range sizes { + p := c.GetPacket() + if i == which { + ret = p + } + checker.IPv4(t, p, + checker.PayloadLen(size+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(seqNum), + checker.AckNum(790), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + seqNum += uint32(size) + } + return ret + } + + // Receive three packets. + sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload} + first := receivePackets(c, sizes, 0, uint32(c.IRS)+1) + + // Send "packet too big" messages back to netstack. + const newMTU = 1200 + const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize + mtu := []byte{0, 0, newMTU / 256, newMTU % 256} + c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU) + + // See retransmitted packets. None exceeding the new max. + sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload} + receivePackets(c, sizes, -1, uint32(c.IRS)+1) +} + +func TestTCPEndpointProbe(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + invoked := make(chan struct{}) + c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) { + // Validate that the endpoint ID is what we expect. + // + // We don't do an extensive validation of every field but a + // basic sanity test. + if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want { + t.Fatalf("unexpected LocalAddress got: %d, want: %d", got, want) + } + if got, want := state.ID.LocalPort, c.Port; got != want { + t.Fatalf("unexpected LocalPort got: %d, want: %d", got, want) + } + if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want { + t.Fatalf("unexpected RemoteAddress got: %d, want: %d", got, want) + } + if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want { + t.Fatalf("unexpected RemotePort got: %d, want: %d", got, want) + } + + invoked <- struct{}{} + }) + + c.CreateConnected(789, 30000, nil) + + data := []byte{1, 2, 3} + c.SendPacket(data, &context.Headers{ + SrcPort: context.TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: 790, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + }) + + select { + case <-invoked: + case <-time.After(100 * time.Millisecond): + t.Fatalf("TCP Probe function was not called") + } +} diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go new file mode 100644 index 000000000..ae1b578c5 --- /dev/null +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -0,0 +1,302 @@ +package tcp_test + +import ( + "bytes" + "math/rand" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/checker" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// createConnectedWithTimestampOption creates and connects c.ep with the +// timestamp option enabled. +func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint { + return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1}) +} + +// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on +// an active connect and sets the TS Echo Reply fields correctly when the +// SYN-ACK also indicates support for the TS option and provides a TSVal. +func TestTimeStampEnabledConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + rep := createConnectedWithTimestampOption(c) + + // Register for read and validate that we have data to read. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + // The following tests ensure that TS option once enabled behaves + // correctly as described in + // https://tools.ietf.org/html/rfc7323#section-4.3. + // + // We are not testing delayed ACKs here, but we do test out of order + // packet delivery and filling the sequence number hole created due to + // the out of order packet. + // + // The test also verifies that the sequence numbers and timestamps are + // as expected. + data := []byte{1, 2, 3} + + // First we increment tsVal by a small amount. + tsVal := rep.TSVal + 100 + rep.SendPacketWithTS(data, tsVal) + rep.VerifyACKWithTS(tsVal) + + // Next we send an out of order packet. + rep.NextSeqNum += 3 + tsVal += 200 + rep.SendPacketWithTS(data, tsVal) + + // The ACK should contain the original sequenceNumber and an older TS. + rep.NextSeqNum -= 6 + rep.VerifyACKWithTS(tsVal - 200) + + // Next we fill the hole and the returned ACK should contain the + // cumulative sequence number acking all data sent till now and have the + // latest timestamp sent below in its TSEcr field. + tsVal -= 100 + rep.SendPacketWithTS(data, tsVal) + rep.NextSeqNum += 3 + rep.VerifyACKWithTS(tsVal) + + // Increment tsVal by a large value that doesn't result in a wrap around. + tsVal += 0x7fffffff + rep.SendPacketWithTS(data, tsVal) + rep.VerifyACKWithTS(tsVal) + + // Increment tsVal again by a large value which should cause the + // timestamp value to wrap around. The returned ACK should contain the + // wrapped around timestamp in its tsEcr field and not the tsVal from + // the previous packet sent above. + tsVal += 0x7fffffff + rep.SendPacketWithTS(data, tsVal) + rep.VerifyACKWithTS(tsVal) + + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // There should be 5 views to read and each of them should + // contain the same data. + for i := 0; i < 5; i++ { + got, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + if want := data; bytes.Compare(got, want) != 0 { + t.Fatalf("Data is different: got: %v, want: %v", got, want) + } + } +} + +// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an +// active connect but if the SYN-ACK doesn't specify the TS option then +// timestamp option is not enabled and future packets do not contain a +// timestamp. +func TestTimeStampDisabledConnect(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + c.CreateConnectedWithOptions(header.TCPSynOptions{}) +} + +func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { + savedSynCountThreshold := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = savedSynCountThreshold + }() + + if cookieEnabled { + tcp.SynRcvdCountThreshold = 0 + } + c := context.New(t, defaultMTU) + defer c.Cleanup() + + t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) + tsVal := rand.Uint32() + c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal}) + + // Now send some data and validate that timestamp is echoed correctly in the ACK. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Check that data is received and that the timestamp option TSEcr field + // matches the expected value. + b := c.GetPacket() + checker.IPv4(t, b, + // Add 12 bytes for the timestamp option + 2 NOPs to align at 4 + // byte boundary. + checker.PayloadLen(len(data)+header.TCPMinimumSize+12), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(wndSize), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPTimestampChecker(true, 0, tsVal+1), + ), + ) +} + +// TestTimeStampEnabledAccept tests that if the SYN on a passive connect +// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK +// and echoes the tsVal field of the original SYN in the tcEcr field of the +// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify +// that Timestamp option is enabled in both cases if requested in the original +// SYN. +func TestTimeStampEnabledAccept(t *testing.T) { + testCases := []struct { + cookieEnabled bool + wndScale int + wndSize uint16 + }{ + {true, -1, 0xffff}, // When cookie is used window scaling is disabled. + {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). + } + for _, tc := range testCases { + timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) + } +} + +func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) { + savedSynCountThreshold := tcp.SynRcvdCountThreshold + defer func() { + tcp.SynRcvdCountThreshold = savedSynCountThreshold + }() + if cookieEnabled { + tcp.SynRcvdCountThreshold = 0 + } + + c := context.New(t, defaultMTU) + defer c.Cleanup() + + t.Logf("Test w/ CookieEnabled = %v", cookieEnabled) + c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Now send some data with the accepted connection endpoint and validate + // that no timestamp option is sent in the TCP segment. + data := []byte{1, 2, 3} + view := buffer.NewView(len(data)) + copy(view, data) + + if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil { + t.Fatalf("Unexpected error from Write: %v", err) + } + + // Check that data is received and that the timestamp option is disabled + // when SYN cookies are enabled/disabled. + b := c.GetPacket() + checker.IPv4(t, b, + checker.PayloadLen(len(data)+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(context.TestPort), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(790), + checker.Window(wndSize), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + checker.TCPTimestampChecker(false, 0, 0), + ), + ) +} + +// TestTimeStampDisabledAccept tests that Timestamp option is not used when the +// peer doesn't advertise it and connection is established with Accept(). +func TestTimeStampDisabledAccept(t *testing.T) { + testCases := []struct { + cookieEnabled bool + wndScale int + wndSize uint16 + }{ + {true, -1, 0xffff}, // When cookie is used window scaling is disabled. + {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default). + } + for _, tc := range testCases { + timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize) + } +} + +func TestSendGreaterThanMTUWithOptions(t *testing.T) { + const maxPayload = 100 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + createConnectedWithTimestampOption(c) + testBrokenUpWrite(t, c, maxPayload) +} + +func TestSegmentDropWhenTimestampMissing(t *testing.T) { + const maxPayload = 100 + c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload)) + defer c.Cleanup() + + rep := createConnectedWithTimestampOption(c) + + // Register for read. + we, ch := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&we, waiter.EventIn) + defer c.WQ.EventUnregister(&we) + + stk := c.Stack() + droppedPackets := stk.Stats().DroppedPackets + data := []byte{1, 2, 3} + // Save the sequence number as we will reset it later down + // in the test. + savedSeqNum := rep.NextSeqNum + rep.SendPacket(data, nil) + + select { + case <-ch: + t.Fatalf("Got data to read when we expect packet to be dropped") + case <-time.After(1 * time.Second): + // We expect that no data will be available to read. + } + + // Assert that DroppedPackets was incremented by 1. + if got, want := stk.Stats().DroppedPackets, droppedPackets+1; got != want { + t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) + } + + droppedPackets = stk.Stats().DroppedPackets + // Reset the sequence number so that the other endpoint accepts + // this segment and does not treat it like an out of order delivery. + rep.NextSeqNum = savedSeqNum + // Now send a packet with timestamp and we should get the data. + rep.SendPacketWithTS(data, rep.TSVal+1) + + select { + case <-ch: + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for data to arrive") + } + + // Assert that DroppedPackets was not incremented by 1. + if got, want := stk.Stats().DroppedPackets, droppedPackets; got != want { + t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want) + } + + // Issue a read and we should data. + got, err := c.EP.Read(nil) + if err != nil { + t.Fatalf("Unexpected error from Read: %v", err) + } + if want := data; bytes.Compare(got, want) != 0 { + t.Fatalf("Data is different: got: %v, want: %v", got, want) + } +} diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD new file mode 100644 index 000000000..40850c3e7 --- /dev/null +++ b/pkg/tcpip/transport/tcp/testing/context/BUILD @@ -0,0 +1,27 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "context", + testonly = 1, + srcs = ["context.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context", + visibility = [ + "//:sandbox", + ], + deps = [ + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/seqnum", + "//pkg/tcpip/stack", + "//pkg/tcpip/transport/tcp", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go new file mode 100644 index 000000000..6a402d150 --- /dev/null +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -0,0 +1,900 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package context provides a test context for use in tcp tests. It also +// provides helper methods to assert/check certain behaviours. +package context + +import ( + "bytes" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/checker" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // StackAddr is the IPv4 address assigned to the stack. + StackAddr = "\x0a\x00\x00\x01" + + // StackPort is used as the listening port in tests for passive + // connects. + StackPort = 1234 + + // TestAddr is the source address for packets sent to the stack via the + // link layer endpoint. + TestAddr = "\x0a\x00\x00\x02" + + // TestPort is the TCP port used for packets sent to the stack + // via the link layer endpoint. + TestPort = 4096 + + // StackV6Addr is the IPv6 address assigned to the stack. + StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + + // TestV6Addr is the source address for packets sent to the stack via + // the link layer endpoint. + TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + // StackV4MappedAddr is StackAddr as a mapped v6 address. + StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr + + // TestV4MappedAddr is TestAddr as a mapped v6 address. + TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr + + // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0. + V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + + // testInitialSequenceNumber is the initial sequence number sent in packets that + // are sent in response to a SYN or in the initial SYN sent to the stack. + testInitialSequenceNumber = 789 +) + +// defaultWindowScale value specified here depends on the tcp.DefaultBufferSize +// constant defined in the tcp/endpoint.go because the tcp.DefaultBufferSize is +// used in tcp.newHandshake to determine the window scale to use when sending a +// SYN/SYN-ACK. +var defaultWindowScale = tcp.FindWndScale(tcp.DefaultBufferSize) + +// Headers is used to represent the TCP header fields when building a +// new packet. +type Headers struct { + // SrcPort holds the src port value to be used in the packet. + SrcPort uint16 + + // DstPort holds the destination port value to be used in the packet. + DstPort uint16 + + // SeqNum is the value of the sequence number field in the TCP header. + SeqNum seqnum.Value + + // AckNum represents the acknowledgement number field in the TCP header. + AckNum seqnum.Value + + // Flags are the TCP flags in the TCP header. + Flags int + + // RcvWnd is the window to be advertised in the ReceiveWindow field of + // the TCP header. + RcvWnd seqnum.Size + + // TCPOpts holds the options to be sent in the option field of the TCP + // header. + TCPOpts []byte +} + +// Context provides an initialized Network stack and a link layer endpoint +// for use in TCP tests. +type Context struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack + + // IRS holds the initial sequence number in the SYN sent by endpoint in + // case of an active connect or the sequence number sent by the endpoint + // in the SYN-ACK sent in response to a SYN when listening in passive + // mode. + IRS seqnum.Value + + // Port holds the port bound by EP below in case of an active connect or + // the listening port number in case of a passive connect. + Port uint16 + + // EP is the test endpoint in the stack owned by this context. This endpoint + // is used in various tests to either initiate an active connect or is used + // as a passive listening endpoint to accept inbound connections. + EP tcpip.Endpoint + + // Wq is the wait queue associated with EP and is used to block for events + // on EP. + WQ waiter.Queue + + // TimeStampEnabled is true if ep is connected with the timestamp option + // enabled. + TimeStampEnabled bool +} + +// New allocates and initializes a test context containing a new +// stack and a link-layer endpoint. +func New(t *testing.T, mtu uint32) *Context { + s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName}) + + // Allow minimum send/receive buffer sizes to be 1 during tests. + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %v", err) + } + + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil { + t.Fatalf("SetTransportProtocolOption failed: %v", err) + } + + // Some of the congestion control tests send up to 640 packets, we so + // set the channel size to 1000. + id, linkEP := channel.New(1000, mtu, "") + if testing.Verbose() { + id = sniffer.New(id) + } + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, + { + Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, + }) + + return &Context{ + t: t, + s: s, + linkEP: linkEP, + } +} + +// Cleanup closes the context endpoint if required. +func (c *Context) Cleanup() { + if c.EP != nil { + c.EP.Close() + } +} + +// Stack returns a reference to the stack in the Context. +func (c *Context) Stack() *stack.Stack { + return c.s +} + +// CheckNoPacketTimeout verifies that no packet is received during the time +// specified by wait. +func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) { + select { + case <-c.linkEP.C: + c.t.Fatalf(errMsg) + + case <-time.After(wait): + } +} + +// CheckNoPacket verifies that no packet is received for 1 second. +func (c *Context) CheckNoPacket(errMsg string) { + c.CheckNoPacketTimeout(errMsg, 1*time.Second) +} + +// GetPacket reads a packet from the link layer endpoint and verifies +// that it is an IPv4 packet with the expected source and destination +// addresses. It will fail with an error if no packet is received for +// 2 seconds. +func (c *Context) GetPacket() []byte { + select { + case p := <-c.linkEP.C: + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } + b := make([]byte, len(p.Header)+len(p.Payload)) + copy(b, p.Header) + copy(b[len(p.Header):], p.Payload) + + checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr)) + return b + + case <-time.After(2 * time.Second): + c.t.Fatalf("Packet wasn't written out") + } + + return nil +} + +// SendICMPPacket builds and sends an ICMPv4 packet via the link layer endpoint. +func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) { + // Allocate a buffer data and headers. + buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(p1) + len(p2)) + if len(buf) > maxTotalSize { + buf = buf[:maxTotalSize] + } + + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(header.ICMPv4ProtocolNumber), + SrcAddr: TestAddr, + DstAddr: StackAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + icmp := header.ICMPv4(buf[header.IPv4MinimumSize:]) + icmp.SetType(typ) + icmp.SetCode(code) + + copy(icmp[header.ICMPv4MinimumSize:], p1) + copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2) + + // Inject packet. + var views [1]buffer.View + vv := buf.ToVectorisedView(views) + c.linkEP.Inject(ipv4.ProtocolNumber, &vv) +} + +// SendPacket builds and sends a TCP segment(with the provided payload & TCP +// headers) in an IPv4 packet via the link layer endpoint. +func (c *Context) SendPacket(payload []byte, h *Headers) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(tcp.ProtocolNumber), + SrcAddr: TestAddr, + DstAddr: StackAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the TCP header. + t := header.TCP(buf[header.IPv4MinimumSize:]) + t.Encode(&header.TCPFields{ + SrcPort: h.SrcPort, + DstPort: h.DstPort, + SeqNum: uint32(h.SeqNum), + AckNum: uint32(h.AckNum), + DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)), + Flags: uint8(h.Flags), + WindowSize: uint16(h.RcvWnd), + }) + + // Calculate the TCP pseudo-header checksum. + xsum := header.Checksum([]byte(TestAddr), 0) + xsum = header.Checksum([]byte(StackAddr), xsum) + xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum) + + // Calculate the TCP checksum and set it. + length := uint16(header.TCPMinimumSize + len(h.TCPOpts) + len(payload)) + xsum = header.Checksum(payload, xsum) + t.SetChecksum(^t.CalculateChecksum(xsum, length)) + + // Inject packet. + var views [1]buffer.View + vv := buf.ToVectorisedView(views) + c.linkEP.Inject(ipv4.ProtocolNumber, &vv) +} + +// SendAck sends an ACK packet. +func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) { + c.SendPacket(nil, &Headers{ + SrcPort: TestPort, + DstPort: c.Port, + Flags: header.TCPFlagAck, + SeqNum: seqnum.Value(testInitialSequenceNumber).Add(1), + AckNum: c.IRS.Add(1 + seqnum.Size(bytesReceived)), + RcvWnd: 30000, + }) +} + +// ReceiveAndCheckPacket reads a packet from the link layer endpoint and +// verifies that the packet packet payload of packet matches the slice +// of data indicated by offset & size. +func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) { + b := c.GetPacket() + checker.IPv4(c.t, b, + checker.PayloadLen(size+header.TCPMinimumSize), + checker.TCP( + checker.DstPort(TestPort), + checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))), + checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))), + checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)), + ), + ) + + pdata := data[offset:][:size] + if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; bytes.Compare(pdata, p) != 0 { + c.t.Fatalf("Data is different: expected %v, got %v", pdata, p) + } +} + +// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only +// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6 +// only endpoint instead of a default dual stack socket. +func (c *Context) CreateV6Endpoint(v6only bool) { + var err *tcpip.Error + c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v6only { + v = 1 + } + if err := c.EP.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } +} + +// GetV6Packet reads a single packet from the link layer endpoint of the context +// and asserts that it is an IPv6 Packet with the expected src/dest addresses. +func (c *Context) GetV6Packet() []byte { + select { + case p := <-c.linkEP.C: + if p.Proto != ipv6.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) + } + b := make([]byte, len(p.Header)+len(p.Payload)) + copy(b, p.Header) + copy(b[len(p.Header):], p.Payload) + + checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr)) + return b + + case <-time.After(2 * time.Second): + c.t.Fatalf("Packet wasn't written out") + } + + return nil +} + +// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of +// the context. +func (c *Context) SendV6Packet(payload []byte, h *Headers) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.TCPMinimumSize + len(payload)), + NextHeader: uint8(tcp.ProtocolNumber), + HopLimit: 65, + SrcAddr: TestV6Addr, + DstAddr: StackV6Addr, + }) + + // Initialize the TCP header. + t := header.TCP(buf[header.IPv6MinimumSize:]) + t.Encode(&header.TCPFields{ + SrcPort: h.SrcPort, + DstPort: h.DstPort, + SeqNum: uint32(h.SeqNum), + AckNum: uint32(h.AckNum), + DataOffset: header.TCPMinimumSize, + Flags: uint8(h.Flags), + WindowSize: uint16(h.RcvWnd), + }) + + // Calculate the TCP pseudo-header checksum. + xsum := header.Checksum([]byte(TestV6Addr), 0) + xsum = header.Checksum([]byte(StackV6Addr), xsum) + xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum) + + // Calculate the TCP checksum and set it. + length := uint16(header.TCPMinimumSize + len(payload)) + xsum = header.Checksum(payload, xsum) + t.SetChecksum(^t.CalculateChecksum(xsum, length)) + + // Inject packet. + var views [1]buffer.View + vv := buf.ToVectorisedView(views) + c.linkEP.Inject(ipv6.ProtocolNumber, &vv) +} + +// CreateConnected creates a connected TCP endpoint. +func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) { + c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) +} + +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + if epRcvBuf != nil { + if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } + } + + // Start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) + if err != tcpip.ErrConnectStarted { + c.t.Fatalf("Unexpected return value from Connect: %v", err) + } + + // Receive SYN packet. + b := c.GetPacket() + checker.IPv4(c.t, b, + checker.TCP( + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagSyn), + ), + ) + + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + c.SendPacket(nil, &Headers{ + SrcPort: tcp.DestinationPort(), + DstPort: tcp.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: rcvWnd, + TCPOpts: options, + }) + + // Receive ACK packet. + checker.IPv4(c.t, c.GetPacket(), + checker.TCP( + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS)+1), + checker.AckNum(uint32(iss)+1), + ), + ) + + // Wait for connection to be established. + select { + case <-notifyCh: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + c.t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for connection") + } + + c.Port = tcp.SourcePort() +} + +// RawEndpoint is just a small wrapper around a TCP endpoint's state to make +// sending data and ACK packets easy while being able to manipulate the sequence +// numbers and timestamp values as needed. +type RawEndpoint struct { + C *Context + SrcPort uint16 + DstPort uint16 + Flags int + NextSeqNum seqnum.Value + AckNum seqnum.Value + WndSize seqnum.Size + RecentTS uint32 // Stores the latest timestamp to echo back. + TSVal uint32 // TSVal stores the last timestamp sent by this endpoint. + + // SackPermitted is true if SACKPermitted option was negotiated for this endpoint. + SACKPermitted bool +} + +// SendPacketWithTS embeds the provided tsVal in the Timestamp option +// for the packet to be sent out. +func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) { + r.TSVal = tsVal + tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:]) + r.SendPacket(payload, tsOpt[:]) +} + +// SendPacket is a small wrapper function to build and send packets. +func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) { + packetHeaders := &Headers{ + SrcPort: r.SrcPort, + DstPort: r.DstPort, + Flags: r.Flags, + SeqNum: r.NextSeqNum, + AckNum: r.AckNum, + RcvWnd: r.WndSize, + TCPOpts: opts, + } + r.C.SendPacket(payload, packetHeaders) + r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload))) +} + +// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided +// tsVal. +func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) { + // Read ACK and verify that tsEcr of ACK packet is [1,2,3,4] + ackPacket := r.C.GetPacket() + checker.IPv4(r.C.t, ackPacket, + checker.TCP( + checker.DstPort(r.SrcPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(r.AckNum)), + checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPTimestampChecker(true, 0, tsVal), + ), + ) + // Store the parsed TSVal from the ack as recentTS. + tcpSeg := header.TCP(header.IPv4(ackPacket).Payload()) + opts := tcpSeg.ParsedOptions() + r.RecentTS = opts.TSVal +} + +// VerifyACKNoSACK verifies that the ACK does not contain a SACK block. +func (r *RawEndpoint) VerifyACKNoSACK() { + r.VerifyACKHasSACK(nil) +} + +// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks. +func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) { + // Read ACK and verify that the TCP options in the segment do + // not contain a SACK block. + ackPacket := r.C.GetPacket() + checker.IPv4(r.C.t, ackPacket, + checker.TCP( + checker.DstPort(r.SrcPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(r.AckNum)), + checker.AckNum(uint32(r.NextSeqNum)), + checker.TCPSACKBlockChecker(sackBlocks), + ), + ) +} + +// CreateConnectedWithOptions creates and connects c.ep with the specified TCP +// options enabled and returns a RawEndpoint which represents the other end of +// the connection. +// +// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK +// does not carry an option that was not requested. +func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint { + var err *tcpip.Error + c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) + } + + // Start connection attempt. + waitEntry, notifyCh := waiter.NewChannelEntry(nil) + c.WQ.EventRegister(&waitEntry, waiter.EventOut) + defer c.WQ.EventUnregister(&waitEntry) + + testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort} + err = c.EP.Connect(testFullAddr) + if err != tcpip.ErrConnectStarted { + c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err) + } + // Receive SYN packet. + b := c.GetPacket() + // Validate that the syn has the timestamp option and a valid + // TS value. + checker.IPv4(c.t, b, + checker.TCP( + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagSyn), + checker.TCPSynOptions(header.TCPSynOptions{ + MSS: uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize), + TS: true, + WS: defaultWindowScale, + SACKPermitted: c.SACKEnabled(), + }), + ), + ) + tcpSeg := header.TCP(header.IPv4(b).Payload()) + synOptions := header.ParseSynOptions(tcpSeg.Options(), false) + + // Build options w/ tsVal to be sent in the SYN-ACK. + synAckOptions := make([]byte, 40) + offset := 0 + if wantOptions.TS { + offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:]) + } + if wantOptions.SACKPermitted { + offset += header.EncodeSACKPermittedOption(synAckOptions[offset:]) + } + + offset += header.AddTCPOptionPadding(synAckOptions, offset) + + // Build SYN-ACK. + c.IRS = seqnum.Value(tcpSeg.SequenceNumber()) + iss := seqnum.Value(testInitialSequenceNumber) + c.SendPacket(nil, &Headers{ + SrcPort: tcpSeg.DestinationPort(), + DstPort: tcpSeg.SourcePort(), + Flags: header.TCPFlagSyn | header.TCPFlagAck, + SeqNum: iss, + AckNum: c.IRS.Add(1), + RcvWnd: 30000, + TCPOpts: synAckOptions[:offset], + }) + + // Read ACK. + ackPacket := c.GetPacket() + + // Verify TCP header fields. + tcpCheckers := []checker.TransportChecker{ + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagAck), + checker.SeqNum(uint32(c.IRS) + 1), + checker.AckNum(uint32(iss) + 1), + } + + // Verify that tsEcr of ACK packet is wantOptions.TSVal if the + // timestamp option was enabled, if not then we verify that + // there is no timestamp in the ACK packet. + if wantOptions.TS { + tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal)) + } else { + tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) + } + + checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...)) + + ackSeg := header.TCP(header.IPv4(ackPacket).Payload()) + ackOptions := ackSeg.ParsedOptions() + + // Wait for connection to be established. + select { + case <-notifyCh: + err = c.EP.GetSockOpt(tcpip.ErrorOption{}) + if err != nil { + c.t.Fatalf("Unexpected error when connecting: %v", err) + } + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for connection") + } + + // Store the source port in use by the endpoint. + c.Port = tcpSeg.SourcePort() + + // Mark in context that timestamp option is enabled for this endpoint. + c.TimeStampEnabled = true + + return &RawEndpoint{ + C: c, + SrcPort: tcpSeg.DestinationPort(), + DstPort: tcpSeg.SourcePort(), + Flags: header.TCPFlagAck | header.TCPFlagPsh, + NextSeqNum: iss + 1, + AckNum: c.IRS.Add(1), + WndSize: 30000, + RecentTS: ackOptions.TSVal, + TSVal: wantOptions.TSVal, + SACKPermitted: wantOptions.SACKPermitted, + } +} + +// AcceptWithOptions initializes a listening endpoint and connects to it with the +// provided options enabled. It also verifies that the SYN-ACK has the expected +// values for the provided options. +// +// The function returns a RawEndpoint representing the other end of the accepted +// endpoint. +func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + // Create EP and start listening. + wq := &waiter.Queue{} + ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + defer ep.Close() + + if err := ep.Bind(tcpip.FullAddress{Port: StackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + if err := ep.Listen(10); err != nil { + c.t.Fatalf("Listen failed: %v", err) + } + + rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + c.EP, _, err = ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + c.EP, _, err = ep.Accept() + if err != nil { + c.t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for accept") + } + } + return rep +} + +// PassiveConnect just disables WindowScaling and delegates the call to +// PassiveConnectWithOptions. +func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) { + synOptions.WS = -1 + c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions) +} + +// PassiveConnectWithOptions initiates a new connection (with the specified TCP +// options enabled) to the port on which the Context.ep is listening for new +// connections. It also validates that the SYN-ACK has the expected values for +// the enabled options. +// +// NOTE: MSS is not a negotiated option and it can be asymmetric +// in each direction. This function uses the maxPayload to set the MSS to be +// sent to the peer on a connect and validates that the MSS in the SYN-ACK +// response is equal to the MTU - (tcphdr len + iphdr len). +// +// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the +// value of the window scaling option to be sent in the SYN. If synOptions.WS > +// 0 then we send the WindowScale option. +func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint { + opts := make([]byte, 40) + offset := 0 + offset += header.EncodeMSSOption(uint32(maxPayload), opts) + + if synOptions.WS >= 0 { + offset += header.EncodeWSOption(3, opts[offset:]) + } + if synOptions.TS { + offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:]) + } + + if synOptions.SACKPermitted { + offset += header.EncodeSACKPermittedOption(opts[offset:]) + } + + paddingToAdd := 4 - offset%4 + // Now add any padding bytes that might be required to quad align the + // options. + for i := offset; i < offset+paddingToAdd; i++ { + opts[i] = header.TCPOptionNOP + } + offset += paddingToAdd + + // Send a SYN request. + iss := seqnum.Value(testInitialSequenceNumber) + c.SendPacket(nil, &Headers{ + SrcPort: TestPort, + DstPort: StackPort, + Flags: header.TCPFlagSyn, + SeqNum: iss, + RcvWnd: 30000, + TCPOpts: opts[:offset], + }) + + // Receive the SYN-ACK reply. Make sure MSS and other expected options + // are present. + b := c.GetPacket() + tcp := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcp.SequenceNumber()) + + tcpCheckers := []checker.TransportChecker{ + checker.SrcPort(StackPort), + checker.DstPort(TestPort), + checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn), + checker.AckNum(uint32(iss) + 1), + checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}), + } + + // If TS option was enabled in the original SYN then add a checker to + // validate the Timestamp option in the SYN-ACK. + if synOptions.TS { + tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal)) + } else { + tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0)) + } + + checker.IPv4(c.t, b, checker.TCP(tcpCheckers...)) + rcvWnd := seqnum.Size(30000) + ackHeaders := &Headers{ + SrcPort: TestPort, + DstPort: StackPort, + Flags: header.TCPFlagAck, + SeqNum: iss + 1, + AckNum: c.IRS + 1, + RcvWnd: rcvWnd, + } + + // If WS was expected to be in effect then scale the advertised window + // correspondingly. + if synOptions.WS > 0 { + ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS) + } + + parsedOpts := tcp.ParsedOptions() + if synOptions.TS { + // Echo the tsVal back to the peer in the tsEcr field of the + // timestamp option. + // Increment TSVal by 1 from the value sent in the SYN and echo + // the TSVal in the SYN-ACK in the TSEcr field. + opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} + header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:]) + ackHeaders.TCPOpts = opts[:] + } + + // Send ACK. + c.SendPacket(nil, ackHeaders) + + c.Port = StackPort + + return &RawEndpoint{ + C: c, + SrcPort: TestPort, + DstPort: StackPort, + Flags: header.TCPFlagPsh | header.TCPFlagAck, + NextSeqNum: iss + 1, + AckNum: c.IRS + 1, + WndSize: rcvWnd, + SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(), + RecentTS: parsedOpts.TSVal, + TSVal: synOptions.TSVal + 1, + } +} + +// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true +// for the Stack in the context. +func (c *Context) SACKEnabled() bool { + var v tcp.SACKEnabled + if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil { + // Stack doesn't support SACK. So just return. + return false + } + return bool(v) +} diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go new file mode 100644 index 000000000..7aa824d8f --- /dev/null +++ b/pkg/tcpip/transport/tcp/timer.go @@ -0,0 +1,131 @@ +// Copyright 2017 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcp + +import ( + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" +) + +type timerState int + +const ( + timerStateDisabled timerState = iota + timerStateEnabled + timerStateOrphaned +) + +// timer is a timer implementation that reduces the interactions with the +// runtime timer infrastructure by letting timers run (and potentially +// eventually expire) even if they are stopped. It makes it cheaper to +// disable/reenable timers at the expense of spurious wakes. This is useful for +// cases when the same timer is disabled/reenabled repeatedly with relatively +// long timeouts farther into the future. +// +// TCP retransmit timers benefit from this because they the timeouts are long +// (currently at least 200ms), and get disabled when acks are received, and +// reenabled when new pending segments are sent. +// +// It is advantageous to avoid interacting with the runtime because it acquires +// a global mutex and performs O(log n) operations, where n is the global number +// of timers, whenever a timer is enabled or disabled, and may make a syscall. +// +// This struct is thread-compatible. +type timer struct { + // state is the current state of the timer, it can be one of the + // following values: + // disabled - the timer is disabled. + // orphaned - the timer is disabled, but the runtime timer is + // enabled, which means that it will evetually cause a + // spurious wake (unless it gets enabled again before + // then). + // enabled - the timer is enabled, but the runtime timer may be set + // to an earlier expiration time due to a previous + // orphaned state. + state timerState + + // target is the expiration time of the current timer. It is only + // meaningful in the enabled state. + target time.Time + + // runtimeTarget is the expiration time of the runtime timer. It is + // meaningful in the enabled and orphaned states. + runtimeTarget time.Time + + // timer is the runtime timer used to wait on. + timer *time.Timer +} + +// init initializes the timer. Once it expires, it the given waker will be +// asserted. +func (t *timer) init(w *sleep.Waker) { + t.state = timerStateDisabled + + // Initialize a runtime timer that will assert the waker, then + // immediately stop it. + t.timer = time.AfterFunc(time.Hour, func() { + w.Assert() + }) + t.timer.Stop() +} + +// cleanup frees all resources associated with the timer. +func (t *timer) cleanup() { + t.timer.Stop() +} + +// checkExpiration checks if the given timer has actually expired, it should be +// called whenever a sleeper wakes up due to the waker being asserted, and is +// used to check if it's a supurious wake (due to a previously orphaned timer) +// or a legitimate one. +func (t *timer) checkExpiration() bool { + // Transition to fully disabled state if we're just consuming an + // orphaned timer. + if t.state == timerStateOrphaned { + t.state = timerStateDisabled + return false + } + + // The timer is enabled, but it may have expired early. Check if that's + // the case, and if so, reset the runtime timer to the correct time. + now := time.Now() + if now.Before(t.target) { + t.runtimeTarget = t.target + t.timer.Reset(t.target.Sub(now)) + return false + } + + // The timer has actually expired, disable it for now and inform the + // caller. + t.state = timerStateDisabled + return true +} + +// disable disables the timer, leaving it in an orphaned state if it wasn't +// already disabled. +func (t *timer) disable() { + if t.state != timerStateDisabled { + t.state = timerStateOrphaned + } +} + +// enabled returns true if the timer is currently enabled, false otherwise. +func (t *timer) enabled() bool { + return t.state == timerStateEnabled +} + +// enable enables the timer, programming the runtime timer if necessary. +func (t *timer) enable(d time.Duration) { + t.target = time.Now().Add(d) + + // Check if we need to set the runtime timer. + if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) { + t.runtimeTarget = t.target + t.timer.Reset(d) + } + + t.state = timerStateEnabled +} diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD new file mode 100644 index 000000000..cf83ca134 --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/BUILD @@ -0,0 +1,24 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "tcpconntrack", + srcs = ["tcp_conntrack.go"], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcpconntrack", + visibility = ["//visibility:public"], + deps = [ + "//pkg/tcpip/header", + "//pkg/tcpip/seqnum", + ], +) + +go_test( + name = "tcpconntrack_test", + size = "small", + srcs = ["tcp_conntrack_test.go"], + deps = [ + ":tcpconntrack", + "//pkg/tcpip/header", + ], +) diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go new file mode 100644 index 000000000..487f2572d --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go @@ -0,0 +1,333 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tcpconntrack implements a TCP connection tracking object. It allows +// users with access to a segment stream to figure out when a connection is +// established, reset, and closed (and in the last case, who closed first). +package tcpconntrack + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" +) + +// Result is returned when the state of a TCB is updated in response to an +// inbound or outbound segment. +type Result int + +const ( + // ResultDrop indicates that the segment should be dropped. + ResultDrop Result = iota + + // ResultConnecting indicates that the connection remains in a + // connecting state. + ResultConnecting + + // ResultAlive indicates that the connection remains alive (connected). + ResultAlive + + // ResultReset indicates that the connection was reset. + ResultReset + + // ResultClosedByPeer indicates that the connection was gracefully + // closed, and the inbound stream was closed first. + ResultClosedByPeer + + // ResultClosedBySelf indicates that the connection was gracefully + // closed, and the outbound stream was closed first. + ResultClosedBySelf +) + +// TCB is a TCP Control Block. It holds state necessary to keep track of a TCP +// connection and inform the caller when the connection has been closed. +type TCB struct { + inbound stream + outbound stream + + // State handlers. + handlerInbound func(*TCB, header.TCP) Result + handlerOutbound func(*TCB, header.TCP) Result + + // firstFin holds a pointer to the first stream to send a FIN. + firstFin *stream + + // state is the current state of the stream. + state Result +} + +// Init initializes the state of the TCB according to the initial SYN. +func (t *TCB) Init(initialSyn header.TCP) { + t.handlerInbound = synSentStateInbound + t.handlerOutbound = synSentStateOutbound + + iss := seqnum.Value(initialSyn.SequenceNumber()) + t.outbound.una = iss + t.outbound.nxt = iss.Add(logicalLen(initialSyn)) + t.outbound.end = t.outbound.nxt + + // Even though "end" is a sequence number, we don't know the initial + // receive sequence number yet, so we store the window size until we get + // a SYN from the peer. + t.inbound.una = 0 + t.inbound.nxt = 0 + t.inbound.end = seqnum.Value(initialSyn.WindowSize()) + t.state = ResultConnecting +} + +// UpdateStateInbound updates the state of the TCB based on the supplied inbound +// segment. +func (t *TCB) UpdateStateInbound(tcp header.TCP) Result { + st := t.handlerInbound(t, tcp) + if st != ResultDrop { + t.state = st + } + return st +} + +// UpdateStateOutbound updates the state of the TCB based on the supplied +// outbound segment. +func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result { + st := t.handlerOutbound(t, tcp) + if st != ResultDrop { + t.state = st + } + return st +} + +// IsAlive returns true as long as the connection is established(Alive) +// or connecting state. +func (t *TCB) IsAlive() bool { + return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed()) +} + +// OutboundSendSequenceNumber returns the snd.NXT for the outbound stream. +func (t *TCB) OutboundSendSequenceNumber() seqnum.Value { + return t.outbound.nxt +} + +// adapResult modifies the supplied "Result" according to the state of the TCB; +// if r is anything other than "Alive", or if one of the streams isn't closed +// yet, it is returned unmodified. Otherwise it's converted to either +// ClosedBySelf or ClosedByPeer depending on which stream was closed first. +func (t *TCB) adaptResult(r Result) Result { + // Check the unmodified case. + if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() { + return r + } + + // Find out which was closed first. + if t.firstFin == &t.outbound { + return ResultClosedBySelf + } + + return ResultClosedByPeer +} + +// synSentStateInbound is the state handler for inbound segments when the +// connection is in SYN-SENT state. +func synSentStateInbound(t *TCB, tcp header.TCP) Result { + flags := tcp.Flags() + ackPresent := flags&header.TCPFlagAck != 0 + ack := seqnum.Value(tcp.AckNumber()) + + // Ignore segment if ack is present but not acceptable. + if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) { + return ResultConnecting + } + + // If reset is specified, we will let the packet through no matter what + // but we will also destroy the connection if the ACK is present (and + // implicitly acceptable). + if flags&header.TCPFlagRst != 0 { + if ackPresent { + t.inbound.rstSeen = true + return ResultReset + } + return ResultConnecting + } + + // Ignore segment if SYN is not set. + if flags&header.TCPFlagSyn == 0 { + return ResultConnecting + } + + // Update state informed by this SYN. + irs := seqnum.Value(tcp.SequenceNumber()) + t.inbound.una = irs + t.inbound.nxt = irs.Add(logicalLen(tcp)) + t.inbound.end += irs + + t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize())) + + // If the ACK was set (it is acceptable), update our unacknowledgement + // tracking. + if ackPresent { + // Advance the "una" and "end" indices of the outbound stream. + if t.outbound.una.LessThan(ack) { + t.outbound.una = ack + } + + if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) { + t.outbound.end = end + } + } + + // Update handlers so that new calls will be handled by new state. + t.handlerInbound = allOtherInbound + t.handlerOutbound = allOtherOutbound + + return ResultAlive +} + +// synSentStateOutbound is the state handler for outbound segments when the +// connection is in SYN-SENT state. +func synSentStateOutbound(t *TCB, tcp header.TCP) Result { + // Drop outbound segments that aren't retransmits of the original one. + if tcp.Flags() != header.TCPFlagSyn || + tcp.SequenceNumber() != uint32(t.outbound.una) { + return ResultDrop + } + + // Update the receive window. We only remember the largest value seen. + if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end { + t.inbound.end = wnd + } + + return ResultConnecting +} + +// update updates the state of inbound and outbound streams, given the supplied +// inbound segment. For outbound segments, this same function can be called with +// swapped inbound/outbound streams. +func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result { + // Ignore segments out of the window. + s := seqnum.Value(tcp.SequenceNumber()) + if !inbound.acceptable(s, dataLen(tcp)) { + return ResultAlive + } + + flags := tcp.Flags() + if flags&header.TCPFlagRst != 0 { + inbound.rstSeen = true + return ResultReset + } + + // Ignore segments that don't have the ACK flag, and those with the SYN + // flag. + if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 { + return ResultAlive + } + + // Ignore segments that acknowledge not yet sent data. + ack := seqnum.Value(tcp.AckNumber()) + if outbound.nxt.LessThan(ack) { + return ResultAlive + } + + // Advance the "una" and "end" indices of the outbound stream. + if outbound.una.LessThan(ack) { + outbound.una = ack + } + + if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) { + outbound.end = end + } + + // Advance the "nxt" index of the inbound stream. + end := s.Add(logicalLen(tcp)) + if inbound.nxt.LessThan(end) { + inbound.nxt = end + } + + // Note the index of the FIN segment. And stash away a pointer to the + // first stream to see a FIN. + if flags&header.TCPFlagFin != 0 && !inbound.finSeen { + inbound.finSeen = true + inbound.fin = end - 1 + + if *firstFin == nil { + *firstFin = inbound + } + } + + return ResultAlive +} + +// allOtherInbound is the state handler for inbound segments in all states +// except SYN-SENT. +func allOtherInbound(t *TCB, tcp header.TCP) Result { + return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin)) +} + +// allOtherOutbound is the state handler for outbound segments in all states +// except SYN-SENT. +func allOtherOutbound(t *TCB, tcp header.TCP) Result { + return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin)) +} + +// streams holds the state of a TCP unidirectional stream. +type stream struct { + // The interval [una, end) is the allowed interval as defined by the + // receiver, i.e., anything less than una has already been acknowledged + // and anything greater than or equal to end is beyond the receiver + // window. The interval [una, nxt) is the acknowledgable range, whose + // right edge indicates the sequence number of the next byte to be sent + // by the sender, i.e., anything greater than or equal to nxt hasn't + // been sent yet. + una seqnum.Value + nxt seqnum.Value + end seqnum.Value + + // finSeen indicates if a FIN has already been sent on this stream. + finSeen bool + + // fin is the sequence number of the FIN. It is only valid after finSeen + // is set to true. + fin seqnum.Value + + // rstSeen indicates if a RST has already been sent on this stream. + rstSeen bool +} + +// acceptable determines if the segment with the given sequence number and data +// length is acceptable, i.e., if it's within the [una, end) window or, in case +// the window is zero, if it's a packet with no payload and sequence number +// equal to una. +func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool { + wnd := s.una.Size(s.end) + if wnd == 0 { + return segLen == 0 && segSeq == s.una + } + + // Make sure [segSeq, seqSeq+segLen) is non-empty. + if segLen == 0 { + segLen = 1 + } + + return seqnum.Overlap(s.una, wnd, segSeq, segLen) +} + +// closed determines if the stream has already been closed. This happens when +// a FIN has been set by the sender and acknowledged by the receiver. +func (s *stream) closed() bool { + return s.finSeen && s.fin.LessThan(s.una) +} + +// dataLen returns the length of the TCP segment payload. +func dataLen(tcp header.TCP) seqnum.Size { + return seqnum.Size(len(tcp) - int(tcp.DataOffset())) +} + +// logicalLen calculates the logical length of the TCP segment. +func logicalLen(tcp header.TCP) seqnum.Size { + l := dataLen(tcp) + flags := tcp.Flags() + if flags&header.TCPFlagSyn != 0 { + l++ + } + if flags&header.TCPFlagFin != 0 { + l++ + } + return l +} diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go new file mode 100644 index 000000000..50cab3132 --- /dev/null +++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go @@ -0,0 +1,501 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package tcpconntrack_test + +import ( + "testing" + + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcpconntrack" +) + +// connected creates a connection tracker TCB and sets it to a connected state +// by performing a 3-way handshake. +func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: iss, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: irw, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN-ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: irs, + AckNum: iss + 1, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: isw, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: iss + 1, + AckNum: irs + 1, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: irw, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + return &tcb +} + +func TestConnectionRefused(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive RST. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst | header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestConnectionRefusedInSynRcvd(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive RST with no ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestConnectionResetInSynRcvd(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send RST with no ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagRst, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset) + } +} + +func TestRetransmitOnSynSent(t *testing.T) { + // Send initial SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Retransmit the same SYN. + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting) + } +} + +func TestRetransmitOnSynRcvd(t *testing.T) { + // Send initial SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive SYN. This will cause the state to go to SYN-RCVD. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Retransmit the original SYN. + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Transmit a SYN-ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } +} + +func TestClosedBySelf(t *testing.T) { + tcb := connected(t, 1234, 789, 30000, 50000) + + // Send FIN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive FIN/ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 1236, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1236, + AckNum: 791, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) + } +} + +func TestClosedByPeer(t *testing.T) { + tcb := connected(t, 1234, 789, 30000, 50000) + + // Receive FIN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 790, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send FIN/ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 791, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 791, + AckNum: 1236, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer) + } +} + +func TestSendAndReceiveDataClosedBySelf(t *testing.T) { + sseq := uint32(1234) + rseq := uint32(789) + tcb := connected(t, sseq, rseq, 30000, 50000) + sseq++ + rseq++ + + // Send some data. + tcp := make(header.TCP, header.TCPMinimumSize+1024) + + for i := uint32(0); i < 10; i++ { + // Send some data. + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + sseq += uint32(len(tcp)) - header.TCPMinimumSize + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive ack for data. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + } + + for i := uint32(0); i < 10; i++ { + // Receive some data. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 50000, + }) + rseq += uint32(len(tcp)) - header.TCPMinimumSize + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ack for data. + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + } + + // Send FIN. + tcp = tcp[:header.TCPMinimumSize] + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 30000, + }) + sseq++ + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Receive FIN/ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: rseq, + AckNum: sseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck | header.TCPFlagFin, + WindowSize: 50000, + }) + rseq++ + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: sseq, + AckNum: rseq, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf) + } +} + +func TestIgnoreBadResetOnSynSent(t *testing.T) { + // Send SYN. + tcp := make(header.TCP, header.TCPMinimumSize) + tcp.Encode(&header.TCPFields{ + SeqNum: 1234, + AckNum: 0, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn, + WindowSize: 30000, + }) + + tcb := tcpconntrack.TCB{} + tcb.Init(tcp) + + // Receive a RST with a bad ACK, it should not cause the connection to + // be reset. + acks := []uint32{1234, 1236, 1000, 5000} + flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck} + for _, a := range acks { + for _, f := range flags { + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: a, + DataOffset: header.TCPMinimumSize, + Flags: f, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + } + } + + // Complete the handshake. + // Receive SYN-ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 789, + AckNum: 1235, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagSyn | header.TCPFlagAck, + WindowSize: 50000, + }) + + if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } + + // Send ACK. + tcp.Encode(&header.TCPFields{ + SeqNum: 1235, + AckNum: 790, + DataOffset: header.TCPMinimumSize, + Flags: header.TCPFlagAck, + WindowSize: 30000, + }) + + if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive { + t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive) + } +} diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD new file mode 100644 index 000000000..ac34a932e --- /dev/null +++ b/pkg/tcpip/transport/udp/BUILD @@ -0,0 +1,77 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_stateify") + +go_stateify( + name = "udp_state", + srcs = [ + "endpoint.go", + "endpoint_state.go", + "udp_packet_list.go", + ], + out = "udp_state.go", + imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"], + package = "udp", +) + +go_template_instance( + name = "udp_packet_list", + out = "udp_packet_list.go", + package = "udp", + prefix = "udpPacket", + template = "//pkg/ilist:generic_list", + types = { + "Linker": "*udpPacket", + }, +) + +go_library( + name = "udp", + srcs = [ + "endpoint.go", + "endpoint_state.go", + "protocol.go", + "udp_packet_list.go", + "udp_state.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp", + visibility = ["//visibility:public"], + deps = [ + "//pkg/sleep", + "//pkg/state", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) + +go_test( + name = "udp_x_test", + size = "small", + srcs = ["udp_test.go"], + deps = [ + ":udp", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", + "//pkg/tcpip/header", + "//pkg/tcpip/link/channel", + "//pkg/tcpip/link/sniffer", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/waiter", + ], +) + +filegroup( + name = "autogen", + srcs = [ + "udp_packet_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go new file mode 100644 index 000000000..80fa88c4c --- /dev/null +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -0,0 +1,746 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package udp + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +type udpPacket struct { + udpPacketEntry + senderAddress tcpip.FullAddress + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + // views is used as buffer for data when its length is large + // enough to store a VectorisedView. + views [8]buffer.View `state:"nosave"` +} + +type endpointState int + +const ( + stateInitial endpointState = iota + stateBound + stateConnected + stateClosed +) + +// endpoint represents a UDP endpoint. This struct serves as the interface +// between users of the endpoint and the protocol implementation; it is legal to +// have concurrent goroutines make calls into the endpoint, they are properly +// synchronized. +type endpoint struct { + // The following fields are initialized at creation time and do not + // change throughout the lifetime of the endpoint. + stack *stack.Stack `state:"manual"` + netProto tcpip.NetworkProtocolNumber + waiterQueue *waiter.Queue + + // The following fields are used to manage the receive queue, and are + // protected by rcvMu. + rcvMu sync.Mutex `state:"nosave"` + rcvReady bool + rcvList udpPacketList + rcvBufSizeMax int `state:".(int)"` + rcvBufSize int + rcvClosed bool + + // The following fields are protected by the mu mutex. + mu sync.RWMutex `state:"nosave"` + sndBufSize int + id stack.TransportEndpointID + state endpointState + bindNICID tcpip.NICID + bindAddr tcpip.Address + regNICID tcpip.NICID + route stack.Route `state:"manual"` + dstPort uint16 + v6only bool + + // effectiveNetProtos contains the network protocols actually in use. In + // most cases it will only contain "netProto", but in cases like IPv6 + // endpoints with v6only set to false, this could include multiple + // protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g., + // IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped + // address). + effectiveNetProtos []tcpip.NetworkProtocolNumber +} + +func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint { + return &endpoint{ + stack: stack, + netProto: netProto, + waiterQueue: waiterQueue, + rcvBufSizeMax: 32 * 1024, + sndBufSize: 32 * 1024, + } +} + +// NewConnectedEndpoint creates a new endpoint in the connected state using the +// provided route. +func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.TransportEndpointID, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + ep := newEndpoint(stack, r.NetProto, waiterQueue) + + // Register new endpoint so that packets are routed to it. + if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil { + ep.Close() + return nil, err + } + + ep.id = id + ep.route = r.Clone() + ep.dstPort = id.RemotePort + ep.regNICID = r.NICID() + + ep.state = stateConnected + + return ep, nil +} + +// Close puts the endpoint in a closed state and frees all resources +// associated with it. +func (e *endpoint) Close() { + e.mu.Lock() + defer e.mu.Unlock() + + switch e.state { + case stateBound, stateConnected: + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + } + + // Close the receive list and drain it. + e.rcvMu.Lock() + e.rcvClosed = true + e.rcvBufSize = 0 + for !e.rcvList.Empty() { + p := e.rcvList.Front() + e.rcvList.Remove(p) + } + e.rcvMu.Unlock() + + e.route.Release() + + // Update the state. + e.state = stateClosed +} + +// Read reads data from the endpoint. This method does not block if +// there is no data pending. +func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, *tcpip.Error) { + e.rcvMu.Lock() + + if e.rcvList.Empty() { + err := tcpip.ErrWouldBlock + if e.rcvClosed { + err = tcpip.ErrClosedForReceive + } + e.rcvMu.Unlock() + return buffer.View{}, err + } + + p := e.rcvList.Front() + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() + + e.rcvMu.Unlock() + + if addr != nil { + *addr = p.senderAddress + } + + return p.data.ToView(), nil +} + +// prepareForWrite prepares the endpoint for sending data. In particular, it +// binds it if it's still in the initial state. To do so, it must first +// reacquire the mutex in exclusive mode. +// +// Returns true for retry if preparation should be retried. +func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) { + switch e.state { + case stateInitial: + case stateConnected: + return false, nil + + case stateBound: + if to == nil { + return false, tcpip.ErrDestinationRequired + } + return false, nil + default: + return false, tcpip.ErrInvalidEndpointState + } + + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // The state changed when we released the shared locked and re-acquired + // it in exclusive mode. Try again. + if e.state != stateInitial { + return true, nil + } + + // The state is still 'initial', so try to bind the endpoint. + if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil { + return false, err + } + + return true, nil +} + +// Write writes data to the endpoint's peer. This method does not block +// if the data cannot be written. +func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) { + // MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.) + if opts.More { + return 0, tcpip.ErrInvalidOptionValue + } + + to := opts.To + + e.mu.RLock() + defer e.mu.RUnlock() + + // Prepare for write. + for { + retry, err := e.prepareForWrite(to) + if err != nil { + return 0, err + } + + if !retry { + break + } + } + + var route *stack.Route + var dstPort uint16 + if to == nil { + route = &e.route + dstPort = e.dstPort + + if route.IsResolutionRequired() { + // Promote lock to exclusive if using a shared route, given that it may need to + // change in Route.Resolve() call below. + e.mu.RUnlock() + defer e.mu.RLock() + + e.mu.Lock() + defer e.mu.Unlock() + + // Recheck state after lock was re-acquired. + if e.state != stateConnected { + return 0, tcpip.ErrInvalidEndpointState + } + } + } else { + // Reject destination address if it goes through a different + // NIC than the endpoint was bound to. + nicid := to.NIC + if e.bindNICID != 0 { + if nicid != 0 && nicid != e.bindNICID { + return 0, tcpip.ErrNoRoute + } + + nicid = e.bindNICID + } + + toCopy := *to + to = &toCopy + netProto, err := e.checkV4Mapped(to, true) + if err != nil { + return 0, err + } + + // Find the enpoint. + r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto) + if err != nil { + return 0, err + } + defer r.Release() + + route = &r + dstPort = to.Port + } + + if route.IsResolutionRequired() { + waker := &sleep.Waker{} + if err := route.Resolve(waker); err != nil { + if err == tcpip.ErrWouldBlock { + // Link address needs to be resolved. Resolution was triggered the background. + // Better luck next time. + // + // TODO: queue up the request and send after link address + // is resolved. + route.RemoveWaker(waker) + return 0, tcpip.ErrNoLinkAddress + } + return 0, err + } + } + + v, err := p.Get(p.Size()) + if err != nil { + return 0, err + } + sendUDP(route, v, e.id.LocalPort, dstPort) + return uintptr(len(v)), nil +} + +// Peek only returns data from a single datagram, so do nothing here. +func (e *endpoint) Peek([][]byte) (uintptr, *tcpip.Error) { + return 0, nil +} + +// SetSockOpt sets a socket option. Currently not supported. +func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { + // TODO: Actually implement this. + switch v := opt.(type) { + case tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + // We only allow this to be set when we're in the initial state. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + e.v6only = v != 0 + } + return nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.SendBufferSizeOption: + e.mu.Lock() + *o = tcpip.SendBufferSizeOption(e.sndBufSize) + e.mu.Unlock() + return nil + + case *tcpip.ReceiveBufferSizeOption: + e.rcvMu.Lock() + *o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax) + e.rcvMu.Unlock() + return nil + + case *tcpip.V6OnlyOption: + // We only recognize this option on v6 endpoints. + if e.netProto != header.IPv6ProtocolNumber { + return tcpip.ErrUnknownProtocolOption + } + + e.mu.Lock() + v := e.v6only + e.mu.Unlock() + + *o = 0 + if v { + *o = 1 + } + return nil + + case *tcpip.ReceiveQueueSizeOption: + e.rcvMu.Lock() + if e.rcvList.Empty() { + *o = 0 + } else { + p := e.rcvList.Front() + *o = tcpip.ReceiveQueueSizeOption(p.data.Size()) + } + e.rcvMu.Unlock() + return nil + } + + return tcpip.ErrUnknownProtocolOption +} + +// sendUDP sends a UDP segment via the provided network endpoint and under the +// provided identity. +func sendUDP(r *stack.Route, data buffer.View, localPort, remotePort uint16) *tcpip.Error { + // Allocate a buffer for the UDP header. + hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength())) + + // Initialize the header. + udp := header.UDP(hdr.Prepend(header.UDPMinimumSize)) + + length := uint16(hdr.UsedLength()) + uint16(len(data)) + udp.Encode(&header.UDPFields{ + SrcPort: localPort, + DstPort: remotePort, + Length: length, + }) + + // Only calculate the checksum if offloading isn't supported. + if r.Capabilities()&stack.CapabilityChecksumOffload == 0 { + xsum := r.PseudoHeaderChecksum(ProtocolNumber) + if data != nil { + xsum = header.Checksum(data, xsum) + } + + udp.SetChecksum(^udp.CalculateChecksum(xsum, length)) + } + + return r.WritePacket(&hdr, data, ProtocolNumber) +} + +func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) { + netProto := e.netProto + if header.IsV4MappedAddress(addr.Addr) { + // Fail if using a v4 mapped address on a v6only endpoint. + if e.v6only { + return 0, tcpip.ErrNoRoute + } + + netProto = header.IPv4ProtocolNumber + addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:] + if addr.Addr == "\x00\x00\x00\x00" { + addr.Addr = "" + } + } + + // Fail if we're bound to an address length different from the one we're + // checking. + if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) { + return 0, tcpip.ErrInvalidEndpointState + } + + return netProto, nil +} + +// Connect connects the endpoint to its peer. Specifying a NIC is optional. +func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + if addr.Port == 0 { + // We don't support connecting to port zero. + return tcpip.ErrInvalidEndpointState + } + + e.mu.Lock() + defer e.mu.Unlock() + + nicid := addr.NIC + localPort := uint16(0) + switch e.state { + case stateInitial: + case stateBound, stateConnected: + localPort = e.id.LocalPort + if e.bindNICID == 0 { + break + } + + if nicid != 0 && nicid != e.bindNICID { + return tcpip.ErrInvalidEndpointState + } + + nicid = e.bindNICID + default: + return tcpip.ErrInvalidEndpointState + } + + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err + } + + // Find a route to the desired destination. + r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto) + if err != nil { + return err + } + defer r.Release() + + id := stack.TransportEndpointID{ + LocalAddress: r.LocalAddress, + LocalPort: localPort, + RemotePort: addr.Port, + RemoteAddress: r.RemoteAddress, + } + + // Even if we're connected, this endpoint can still be used to send + // packets on a different network protocol, so we register both even if + // v6only is set to false and this is an ipv6 endpoint. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if e.netProto == header.IPv6ProtocolNumber && !e.v6only { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv4ProtocolNumber, + header.IPv6ProtocolNumber, + } + } + + id, err = e.registerWithStack(nicid, netProtos, id) + if err != nil { + return err + } + + // Remove the old registration. + if e.id.LocalPort != 0 { + e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id) + } + + e.id = id + e.route = r.Clone() + e.dstPort = addr.Port + e.regNICID = nicid + e.effectiveNetProtos = netProtos + + e.state = stateConnected + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// ConnectEndpoint is not supported. +func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// Shutdown closes the read and/or write end of the endpoint connection +// to its peer. +func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.ErrNotConnected + } + + if flags&tcpip.ShutdownRead != 0 { + e.rcvMu.Lock() + wasClosed := e.rcvClosed + e.rcvClosed = true + e.rcvMu.Unlock() + + if !wasClosed { + e.waiterQueue.Notify(waiter.EventIn) + } + } + + return nil +} + +// Listen is not supported by UDP, it just fails. +func (*endpoint) Listen(int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept is not supported by UDP, it just fails. +func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, tcpip.ErrNotSupported +} + +func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) { + if id.LocalPort != 0 { + // The endpoint already has a local port, just attempt to + // register it. + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) + return id, err + } + + // We need to find a port for the endpoint. + _, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) { + id.LocalPort = p + err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e) + switch err { + case nil: + return true, nil + case tcpip.ErrPortInUse: + return false, nil + default: + return false, err + } + }) + + return id, err +} + +func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + // Don't allow binding once endpoint is not in the initial state + // anymore. + if e.state != stateInitial { + return tcpip.ErrInvalidEndpointState + } + + netProto, err := e.checkV4Mapped(&addr, false) + if err != nil { + return err + } + + // Expand netProtos to include v4 and v6 if the caller is binding to a + // wildcard (empty) address, and this is an IPv6 endpoint with v6only + // set to false. + netProtos := []tcpip.NetworkProtocolNumber{netProto} + if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { + netProtos = []tcpip.NetworkProtocolNumber{ + header.IPv6ProtocolNumber, + header.IPv4ProtocolNumber, + } + } + + if len(addr.Addr) != 0 { + // A local address was specified, verify that it's valid. + if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 { + return tcpip.ErrBadLocalAddress + } + } + + id := stack.TransportEndpointID{ + LocalPort: addr.Port, + LocalAddress: addr.Addr, + } + id, err = e.registerWithStack(addr.NIC, netProtos, id) + if err != nil { + return err + } + if commit != nil { + if err := commit(); err != nil { + // Unregister, the commit failed. + e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id) + return err + } + } + + e.id = id + e.regNICID = addr.NIC + e.effectiveNetProtos = netProtos + + // Mark endpoint as bound. + e.state = stateBound + + e.rcvMu.Lock() + e.rcvReady = true + e.rcvMu.Unlock() + + return nil +} + +// Bind binds the endpoint to a specific local address and port. +// Specifying a NIC is optional. +func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + e.mu.Lock() + defer e.mu.Unlock() + + err := e.bindLocked(addr, commit) + if err != nil { + return err + } + + e.bindNICID = addr.NIC + e.bindAddr = addr.Addr + + return nil +} + +// GetLocalAddress returns the address to which the endpoint is bound. +func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.LocalAddress, + Port: e.id.LocalPort, + }, nil +} + +// GetRemoteAddress returns the address to which the endpoint is connected. +func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() + + if e.state != stateConnected { + return tcpip.FullAddress{}, tcpip.ErrNotConnected + } + + return tcpip.FullAddress{ + NIC: e.regNICID, + Addr: e.id.RemoteAddress, + Port: e.id.RemotePort, + }, nil +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + // The endpoint is always writable. + result := waiter.EventOut & mask + + // Determine if the endpoint is readable if requested. + if (mask & waiter.EventIn) != 0 { + e.rcvMu.Lock() + if !e.rcvList.Empty() || e.rcvClosed { + result |= waiter.EventIn + } + e.rcvMu.Unlock() + } + + return result +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) { + // Get the header then trim it from the view. + hdr := header.UDP(vv.First()) + if int(hdr.Length()) > vv.Size() { + // Malformed packet. + return + } + + vv.TrimFront(header.UDPMinimumSize) + + e.rcvMu.Lock() + + // Drop the packet if our buffer is currently full. + if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax { + e.rcvMu.Unlock() + return + } + + wasEmpty := e.rcvBufSize == 0 + + // Push new packet into receive list and increment the buffer size. + pkt := &udpPacket{ + senderAddress: tcpip.FullAddress{ + NIC: r.NICID(), + Addr: id.RemoteAddress, + Port: hdr.SourcePort(), + }, + } + pkt.data = vv.Clone(pkt.views[:]) + e.rcvList.PushBack(pkt) + e.rcvBufSize += vv.Size() + + e.rcvMu.Unlock() + + // Notify any waiters that there's data to be read now. + if wasEmpty { + e.waiterQueue.Notify(waiter.EventIn) + } +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) { +} diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go new file mode 100644 index 000000000..41b98424a --- /dev/null +++ b/pkg/tcpip/transport/udp/endpoint_state.go @@ -0,0 +1,91 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package udp + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +// saveData saves udpPacket.data field. +func (u *udpPacket) saveData() buffer.VectorisedView { + // We canoot save u.data directly as u.data.views may alias to u.views, + // which is not allowed by state framework (in-struct pointer). + return u.data.Clone(nil) +} + +// loadData loads udpPacket.data field. +func (u *udpPacket) loadData(data buffer.VectorisedView) { + // NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization + // here because data.views is not guaranteed to be loaded by now. Plus, + // data.views will be allocated anyway so there really is little point + // of utilizing u.views for data.views. + u.data = data +} + +// beforeSave is invoked by stateify. +func (e *endpoint) beforeSave() { + // Stop incoming packets from being handled (and mutate endpoint state). + // The lock will be released after savercvBufSizeMax(), which would have + // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming + // packets. + e.rcvMu.Lock() +} + +// saveRcvBufSizeMax is invoked by stateify. +func (e *endpoint) saveRcvBufSizeMax() int { + max := e.rcvBufSizeMax + // Make sure no new packets will be handled regardless of the lock. + e.rcvBufSizeMax = 0 + // Release the lock acquired in beforeSave() so regular endpoint closing + // logic can proceed after save. + e.rcvMu.Unlock() + return max +} + +// loadRcvBufSizeMax is invoked by stateify. +func (e *endpoint) loadRcvBufSizeMax(max int) { + e.rcvBufSizeMax = max +} + +// afterLoad is invoked by stateify. +func (e *endpoint) afterLoad() { + e.stack = stack.StackFromEnv + + if e.state != stateBound && e.state != stateConnected { + return + } + + netProto := e.effectiveNetProtos[0] + // Connect() and bindLocked() both assert + // + // netProto == header.IPv6ProtocolNumber + // + // before creating a multi-entry effectiveNetProtos. + if len(e.effectiveNetProtos) > 1 { + netProto = header.IPv6ProtocolNumber + } + + var err *tcpip.Error + if e.state == stateConnected { + e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, netProto) + if err != nil { + panic(*err) + } + + e.id.LocalAddress = e.route.LocalAddress + } else if len(e.id.LocalAddress) != 0 { // stateBound + if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 { + panic(tcpip.ErrBadLocalAddress) + } + } + + e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, e.id) + if err != nil { + panic(*err) + } +} diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go new file mode 100644 index 000000000..fa30e7201 --- /dev/null +++ b/pkg/tcpip/transport/udp/protocol.go @@ -0,0 +1,73 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package udp contains the implementation of the UDP transport protocol. To use +// it in the networking stack, this package must be added to the project, and +// activated on the stack by passing udp.ProtocolName (or "udp") as one of the +// transport protocols when calling stack.New(). Then endpoints can be created +// by passing udp.ProtocolNumber as the transport protocol number when calling +// Stack.NewEndpoint(). +package udp + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // ProtocolName is the string representation of the udp protocol name. + ProtocolName = "udp" + + // ProtocolNumber is the udp protocol number. + ProtocolNumber = header.UDPProtocolNumber +) + +type protocol struct{} + +// Number returns the udp protocol number. +func (*protocol) Number() tcpip.TransportProtocolNumber { + return ProtocolNumber +} + +// NewEndpoint creates a new udp endpoint. +func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newEndpoint(stack, netProto, waiterQueue), nil +} + +// MinimumPacketSize returns the minimum valid udp packet size. +func (*protocol) MinimumPacketSize() int { + return header.UDPMinimumSize +} + +// ParsePorts returns the source and destination ports stored in the given udp +// packet. +func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { + h := header.UDP(v) + return h.SourcePort(), h.DestinationPort(), nil +} + +// HandleUnknownDestinationPacket handles packets targeted at this protocol but +// that don't match any existing endpoint. +func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool { + return true +} + +// SetOption implements TransportProtocol.SetOption. +func (p *protocol) SetOption(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +// Option implements TransportProtocol.Option. +func (p *protocol) Option(option interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption +} + +func init() { + stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol { + return &protocol{} + }) +} diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go new file mode 100644 index 000000000..65c567952 --- /dev/null +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -0,0 +1,625 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package udp_test + +import ( + "bytes" + "math/rand" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/checker" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4" + "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr + testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr + V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testAddr = "\x0a\x00\x00\x02" + testPort = 4096 + + // defaultMTU is the MTU, in bytes, used throughout the tests, except + // where another value is explicitly used. It is chosen to match the MTU + // of loopback interfaces on linux systems. + defaultMTU = 65536 +) + +type testContext struct { + t *testing.T + linkEP *channel.Endpoint + s *stack.Stack + + ep tcpip.Endpoint + wq waiter.Queue +} + +type headers struct { + srcPort uint16 + dstPort uint16 +} + +func newDualTestContext(t *testing.T, mtu uint32) *testContext { + s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName}) + + id, linkEP := channel.New(256, mtu, "") + if testing.Verbose() { + id = sniffer.New(id) + } + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: "\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, + { + Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Gateway: "", + NIC: 1, + }, + }) + + return &testContext{ + t: t, + s: s, + linkEP: linkEP, + } +} + +func (c *testContext) cleanup() { + if c.ep != nil { + c.ep.Close() + } +} + +func (c *testContext) createV6Endpoint(v4only bool) { + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v4only { + v = 1 + } + if err := c.ep.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } +} + +func (c *testContext) getV6Packet() []byte { + select { + case p := <-c.linkEP.C: + if p.Proto != ipv6.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber) + } + b := make([]byte, len(p.Header)+len(p.Payload)) + copy(b, p.Header) + copy(b[len(p.Header):], p.Payload) + + checker.IPv6(c.t, b, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr)) + return b + + case <-time.After(2 * time.Second): + c.t.Fatalf("Packet wasn't written out") + } + + return nil +} + +func (c *testContext) getPacket() []byte { + select { + case p := <-c.linkEP.C: + if p.Proto != ipv4.ProtocolNumber { + c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber) + } + b := make([]byte, len(p.Header)+len(p.Payload)) + copy(b, p.Header) + copy(b[len(p.Header):], p.Payload) + + checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr)) + return b + + case <-time.After(2 * time.Second): + c.t.Fatalf("Packet wasn't written out") + } + + return nil +} + +func (c *testContext) sendV6Packet(payload []byte, h *headers) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: testV6Addr, + DstAddr: stackV6Addr, + }) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.Checksum([]byte(testV6Addr), 0) + xsum = header.Checksum([]byte(stackV6Addr), xsum) + xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum) + + // Calculate the UDP checksum and set it. + length := uint16(header.UDPMinimumSize + len(payload)) + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum, length)) + + // Inject packet. + var views [1]buffer.View + vv := buf.ToVectorisedView(views) + c.linkEP.Inject(ipv6.ProtocolNumber, &vv) +} + +func (c *testContext) sendPacket(payload []byte, h *headers) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv4(buf) + ip.Encode(&header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TotalLength: uint16(len(buf)), + TTL: 65, + Protocol: uint8(udp.ProtocolNumber), + SrcAddr: testAddr, + DstAddr: stackAddr, + }) + ip.SetChecksum(^ip.CalculateChecksum()) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv4MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.Checksum([]byte(testAddr), 0) + xsum = header.Checksum([]byte(stackAddr), xsum) + xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum) + + // Calculate the UDP checksum and set it. + length := uint16(header.UDPMinimumSize + len(payload)) + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum, length)) + + // Inject packet. + var views [1]buffer.View + vv := buf.ToVectorisedView(views) + c.linkEP.Inject(ipv4.ProtocolNumber, &vv) +} + +func newPayload() []byte { + b := make([]byte, 30+rand.Intn(100)) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +func testV4Read(c *testContext) { + // Send a packet. + payload := newPayload() + c.sendPacket(payload, &headers{ + srcPort: testPort, + dstPort: stackPort, + }) + + // Try to receive the data. + we, ch := waiter.NewChannelEntry(nil) + c.wq.EventRegister(&we, waiter.EventIn) + defer c.wq.EventUnregister(&we) + + var addr tcpip.FullAddress + v, err := c.ep.Read(&addr) + if err == tcpip.ErrWouldBlock { + // Wait for data to become available. + select { + case <-ch: + v, err = c.ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read failed: %v", err) + } + + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for data") + } + } + + // Check the peer address. + if addr.Addr != testAddr { + c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) + } + + // Check the payload. + if !bytes.Equal(payload, v) { + c.t.Fatalf("Bad payload: got %x, want %x", v, payload) + } +} + +func TestV4ReadOnV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Read(c) +} + +func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to v4 mapped wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Read(c) +} + +func TestV4ReadOnBoundToV4Mapped(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to local address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Read(c) +} + +func TestV6ReadOnV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Send a packet. + payload := newPayload() + c.sendV6Packet(payload, &headers{ + srcPort: testPort, + dstPort: stackPort, + }) + + // Try to receive the data. + we, ch := waiter.NewChannelEntry(nil) + c.wq.EventRegister(&we, waiter.EventIn) + defer c.wq.EventUnregister(&we) + + var addr tcpip.FullAddress + v, err := c.ep.Read(&addr) + if err == tcpip.ErrWouldBlock { + // Wait for data to become available. + select { + case <-ch: + v, err = c.ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read failed: %v", err) + } + + case <-time.After(1 * time.Second): + c.t.Fatalf("Timed out waiting for data") + } + } + + // Check the peer address. + if addr.Addr != testV6Addr { + c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr) + } + + // Check the payload. + if !bytes.Equal(payload, v) { + c.t.Fatalf("Bad payload: got %x, want %x", v, payload) + } +} + +func TestV4ReadOnV4(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + // Create v4 UDP endpoint. + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Test acceptance. + testV4Read(c) +} + +func testDualWrite(c *testContext) uint16 { + // Write to V4 mapped address. + payload := buffer.View(newPayload()) + n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, + }) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) + } + + // Check that we received the packet. + b := c.getPacket() + udp := header.UDP(header.IPv4(b).Payload()) + checker.IPv4(c.t, b, + checker.UDP( + checker.DstPort(testPort), + ), + ) + + port := udp.SourcePort() + + // Check the payload. + if !bytes.Equal(payload, udp.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + } + + // Write to v6 address. + payload = buffer.View(newPayload()) + n, err = c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + }) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) + } + + // Check that we received the packet, and that the source port is the + // same as the one used in ipv4. + b = c.getV6Packet() + udp = header.UDP(header.IPv6(b).Payload()) + checker.IPv6(c.t, b, + checker.UDP( + checker.DstPort(testPort), + checker.SrcPort(port), + ), + ) + + // Check the payload. + if !bytes.Equal(payload, udp.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + } + + return port +} + +func TestDualWriteUnbound(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + testDualWrite(c) +} + +func TestDualWriteBoundToWildcard(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to wildcard. + if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + p := testDualWrite(c) + if p != stackPort { + c.t.Fatalf("Bad port: got %v, want %v", p, stackPort) + } +} + +func TestDualWriteConnectedToV6(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Connect to v6 address. + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + testDualWrite(c) +} + +func TestDualWriteConnectedToV4Mapped(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Connect to v4 mapped address. + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + testDualWrite(c) +} + +func TestV4WriteOnV6Only(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(true) + + // Write to V4 mapped address. + payload := buffer.View(newPayload()) + _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}, + }) + if err != tcpip.ErrNoRoute { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute) + } +} + +func TestV6WriteOnBoundToV4Mapped(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Bind to v4 mapped address. + if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil { + c.t.Fatalf("Bind failed: %v", err) + } + + // Write to v6 address. + payload := buffer.View(newPayload()) + _, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{ + To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort}, + }) + if err != tcpip.ErrNoRoute { + c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute) + } +} + +func TestV6WriteOnConnected(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Connect to v6 address. + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + // Write without destination. + payload := buffer.View(newPayload()) + n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) + } + + // Check that we received the packet. + b := c.getV6Packet() + udp := header.UDP(header.IPv6(b).Payload()) + checker.IPv6(c.t, b, + checker.UDP( + checker.DstPort(testPort), + ), + ) + + // Check the payload. + if !bytes.Equal(payload, udp.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + } +} + +func TestV4WriteOnConnected(t *testing.T) { + c := newDualTestContext(t, defaultMTU) + defer c.cleanup() + + c.createV6Endpoint(false) + + // Connect to v4 mapped address. + if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil { + c.t.Fatalf("Connect failed: %v", err) + } + + // Write without destination. + payload := buffer.View(newPayload()) + n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{}) + if err != nil { + c.t.Fatalf("Write failed: %v", err) + } + if n != uintptr(len(payload)) { + c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload)) + } + + // Check that we received the packet. + b := c.getPacket() + udp := header.UDP(header.IPv4(b).Payload()) + checker.IPv4(c.t, b, + checker.UDP( + checker.DstPort(testPort), + ), + ) + + // Check the payload. + if !bytes.Equal(payload, udp.Payload()) { + c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload) + } +} diff --git a/pkg/tcpip/transport/unix/BUILD b/pkg/tcpip/transport/unix/BUILD new file mode 100644 index 000000000..47bc7a649 --- /dev/null +++ b/pkg/tcpip/transport/unix/BUILD @@ -0,0 +1,37 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//tools/go_stateify:defs.bzl", "go_stateify") + +go_stateify( + name = "unix_state", + srcs = [ + "connectioned.go", + "connectionless.go", + "unix.go", + ], + out = "unix_state.go", + package = "unix", +) + +go_library( + name = "unix", + srcs = [ + "connectioned.go", + "connectioned_state.go", + "connectionless.go", + "unix.go", + "unix_state.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix", + visibility = ["//:sandbox"], + deps = [ + "//pkg/ilist", + "//pkg/log", + "//pkg/state", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/transport/queue", + "//pkg/waiter", + ], +) diff --git a/pkg/tcpip/transport/unix/connectioned.go b/pkg/tcpip/transport/unix/connectioned.go new file mode 100644 index 000000000..def1b2c99 --- /dev/null +++ b/pkg/tcpip/transport/unix/connectioned.go @@ -0,0 +1,431 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package unix + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/queue" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// UniqueIDProvider generates a sequence of unique identifiers useful for, +// among other things, lock ordering. +type UniqueIDProvider interface { + // UniqueID returns a new unique identifier. + UniqueID() uint64 +} + +// A ConnectingEndpoint is a connectioned unix endpoint that is attempting to +// establish a bidirectional connection with a BoundEndpoint. +type ConnectingEndpoint interface { + // ID returns the endpoint's globally unique identifier. This identifier + // must be used to determine locking order if more than one endpoint is + // to be locked in the same codepath. The endpoint with the smaller + // identifier must be locked before endpoints with larger identifiers. + ID() uint64 + + // Passcred implements socket.Credentialer.Passcred. + Passcred() bool + + // Type returns the socket type, typically either SockStream or + // SockSeqpacket. The connection attempt must be aborted if this + // value doesn't match the ConnectableEndpoint's type. + Type() SockType + + // GetLocalAddress returns the bound path. + GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + + // Locker protects the following methods. While locked, only the holder of + // the lock can change the return value of the protected methods. + sync.Locker + + // Connected returns true iff the ConnectingEndpoint is in the connected + // state. ConnectingEndpoints can only be connected to a single endpoint, + // so the connection attempt must be aborted if this returns true. + Connected() bool + + // Listening returns true iff the ConnectingEndpoint is in the listening + // state. ConnectingEndpoints cannot make connections while listening, so + // the connection attempt must be aborted if this returns true. + Listening() bool + + // WaiterQueue returns a pointer to the endpoint's waiter queue. + WaiterQueue() *waiter.Queue +} + +// connectionedEndpoint is a Unix-domain connected or connectable endpoint and implements +// ConnectingEndpoint, ConnectableEndpoint and tcpip.Endpoint. +// +// connectionedEndpoints must be in connected state in order to transfer data. +// +// This implementation includes STREAM and SEQPACKET Unix sockets created with +// socket(2), accept(2) or socketpair(2) and dgram unix sockets created with +// socketpair(2). See unix_connectionless.go for the implementation of DGRAM +// Unix sockets created with socket(2). +// +// The state is much simpler than a TCP endpoint, so it is not encoded +// explicitly. Instead we enforce the following invariants: +// +// receiver != nil, connected != nil => connected. +// path != "" && acceptedChan == nil => bound, not listening. +// path != "" && acceptedChan != nil => bound and listening. +// +// Only one of these will be true at any moment. +type connectionedEndpoint struct { + baseEndpoint + + // id is the unique endpoint identifier. This is used exclusively for + // lock ordering within connect. + id uint64 + + // idGenerator is used to generate new unique endpoint identifiers. + idGenerator UniqueIDProvider + + // stype is used by connecting sockets to ensure that they are the + // same type. The value is typically either tcpip.SockSeqpacket or + // tcpip.SockStream. + stype SockType + + // acceptedChan is per the TCP endpoint implementation. Note that the + // sockets in this channel are _already in the connected state_, and + // have another associated connectionedEndpoint. + // + // If nil, then no listen call has been made. + acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"` +} + +// NewConnectioned creates a new unbound connectionedEndpoint. +func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint { + return &connectionedEndpoint{ + baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, + id: uid.UniqueID(), + idGenerator: uid, + stype: stype, + } +} + +// NewPair allocates a new pair of connected unix-domain connectionedEndpoints. +func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { + a := &connectionedEndpoint{ + baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, + id: uid.UniqueID(), + idGenerator: uid, + stype: stype, + } + b := &connectionedEndpoint{ + baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, + id: uid.UniqueID(), + idGenerator: uid, + stype: stype, + } + + q1 := queue.New(a.Queue, b.Queue, initialLimit) + q2 := queue.New(b.Queue, a.Queue, initialLimit) + + if stype == SockStream { + a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} + b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}} + } else { + a.receiver = &queueReceiver{q1} + b.receiver = &queueReceiver{q2} + } + + a.connected = &connectedEndpoint{ + endpoint: b, + writeQueue: q2, + } + b.connected = &connectedEndpoint{ + endpoint: a, + writeQueue: q1, + } + + return a, b +} + +// ID implements ConnectingEndpoint.ID. +func (e *connectionedEndpoint) ID() uint64 { + return e.id +} + +// Type implements ConnectingEndpoint.Type and Endpoint.Type. +func (e *connectionedEndpoint) Type() SockType { + return e.stype +} + +// WaiterQueue implements ConnectingEndpoint.WaiterQueue. +func (e *connectionedEndpoint) WaiterQueue() *waiter.Queue { + return e.Queue +} + +// isBound returns true iff the connectionedEndpoint is bound (but not +// listening). +func (e *connectionedEndpoint) isBound() bool { + return e.path != "" && e.acceptedChan == nil +} + +// Listening implements ConnectingEndpoint.Listening. +func (e *connectionedEndpoint) Listening() bool { + return e.acceptedChan != nil +} + +// Close puts the connectionedEndpoint in a closed state and frees all +// resources associated with it. +// +// The socket will be a fresh state after a call to close and may be reused. +// That is, close may be used to "unbind" or "disconnect" the socket in error +// paths. +func (e *connectionedEndpoint) Close() { + e.Lock() + var c ConnectedEndpoint + var r Receiver + switch { + case e.Connected(): + e.connected.CloseSend() + e.receiver.CloseRecv() + c = e.connected + r = e.receiver + e.connected = nil + e.receiver = nil + case e.isBound(): + e.path = "" + case e.Listening(): + close(e.acceptedChan) + for n := range e.acceptedChan { + n.Close() + } + e.acceptedChan = nil + e.path = "" + } + e.Unlock() + if c != nil { + c.CloseNotify() + c.Release() + } + if r != nil { + r.CloseNotify() + r.Release() + } +} + +// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. +func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error { + if ce.Type() != e.stype { + return tcpip.ErrConnectionRefused + } + + // Check if ce is e to avoid a deadlock. + if ce, ok := ce.(*connectionedEndpoint); ok && ce == e { + return tcpip.ErrInvalidEndpointState + } + + // Do a dance to safely acquire locks on both endpoints. + if e.id < ce.ID() { + e.Lock() + ce.Lock() + } else { + ce.Lock() + e.Lock() + } + + // Check connecting state. + if ce.Connected() { + e.Unlock() + ce.Unlock() + return tcpip.ErrAlreadyConnected + } + if ce.Listening() { + e.Unlock() + ce.Unlock() + return tcpip.ErrInvalidEndpointState + } + + // Check bound state. + if !e.Listening() { + e.Unlock() + ce.Unlock() + return tcpip.ErrConnectionRefused + } + + // Create a newly bound connectionedEndpoint. + ne := &connectionedEndpoint{ + baseEndpoint: baseEndpoint{ + path: e.path, + Queue: &waiter.Queue{}, + }, + id: e.idGenerator.UniqueID(), + idGenerator: e.idGenerator, + stype: e.stype, + } + readQueue := queue.New(ce.WaiterQueue(), ne.Queue, initialLimit) + writeQueue := queue.New(ne.Queue, ce.WaiterQueue(), initialLimit) + ne.connected = &connectedEndpoint{ + endpoint: ce, + writeQueue: readQueue, + } + if e.stype == SockStream { + ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} + } else { + ne.receiver = &queueReceiver{readQueue: writeQueue} + } + + select { + case e.acceptedChan <- ne: + // Commit state. + connected := &connectedEndpoint{ + endpoint: ne, + writeQueue: writeQueue, + } + if e.stype == SockStream { + returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected) + } else { + returnConnect(&queueReceiver{readQueue: readQueue}, connected) + } + + // Notify can deadlock if we are holding these locks. + e.Unlock() + ce.Unlock() + + // Notify on both ends. + e.Notify(waiter.EventIn) + ce.WaiterQueue().Notify(waiter.EventOut) + + return nil + default: + // Busy; return ECONNREFUSED per spec. + ne.Close() + e.Unlock() + ce.Unlock() + return tcpip.ErrConnectionRefused + } +} + +// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect. +func (e *connectionedEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) { + return nil, tcpip.ErrConnectionRefused +} + +// Connect attempts to directly connect to another Endpoint. +// Implements Endpoint.Connect. +func (e *connectionedEndpoint) Connect(server BoundEndpoint) *tcpip.Error { + returnConnect := func(r Receiver, ce ConnectedEndpoint) { + e.receiver = r + e.connected = ce + } + + return server.BidirectionalConnect(e, returnConnect) +} + +// Listen starts listening on the connection. +func (e *connectionedEndpoint) Listen(backlog int) *tcpip.Error { + e.Lock() + defer e.Unlock() + if e.Listening() { + // Adjust the size of the channel iff we can fix existing + // pending connections into the new one. + if len(e.acceptedChan) > backlog { + return tcpip.ErrInvalidEndpointState + } + origChan := e.acceptedChan + e.acceptedChan = make(chan *connectionedEndpoint, backlog) + close(origChan) + for ep := range origChan { + e.acceptedChan <- ep + } + return nil + } + if !e.isBound() { + return tcpip.ErrInvalidEndpointState + } + + // Normal case. + e.acceptedChan = make(chan *connectionedEndpoint, backlog) + return nil +} + +// Accept accepts a new connection. +func (e *connectionedEndpoint) Accept() (Endpoint, *tcpip.Error) { + e.Lock() + defer e.Unlock() + + if !e.Listening() { + return nil, tcpip.ErrInvalidEndpointState + } + + select { + case ne := <-e.acceptedChan: + return ne, nil + + default: + // Nothing left. + return nil, tcpip.ErrWouldBlock + } +} + +// Bind binds the connection. +// +// For Unix connectionedEndpoints, this _only sets the address associated with +// the socket_. Work associated with sockets in the filesystem or finding those +// sockets must be done by a higher level. +// +// Bind will fail only if the socket is connected, bound or the passed address +// is invalid (the empty string). +func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + e.Lock() + defer e.Unlock() + if e.isBound() || e.Listening() { + return tcpip.ErrAlreadyBound + } + if addr.Addr == "" { + // The empty string is not permitted. + return tcpip.ErrBadLocalAddress + } + if commit != nil { + if err := commit(); err != nil { + return err + } + } + + // Save the bound address. + e.path = string(addr.Addr) + return nil +} + +// SendMsg writes data and a control message to the endpoint's peer. +// This method does not block if the data cannot be written. +func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) { + // Stream sockets do not support specifying the endpoint. Seqpacket + // sockets ignore the passed endpoint. + if e.stype == SockStream && to != nil { + return 0, tcpip.ErrNotSupported + } + return e.baseEndpoint.SendMsg(data, c, to) +} + +// Readiness returns the current readiness of the connectionedEndpoint. For +// example, if waiter.EventIn is set, the connectionedEndpoint is immediately +// readable. +func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + e.Lock() + defer e.Unlock() + + ready := waiter.EventMask(0) + switch { + case e.Connected(): + if mask&waiter.EventIn != 0 && e.receiver.Readable() { + ready |= waiter.EventIn + } + if mask&waiter.EventOut != 0 && e.connected.Writable() { + ready |= waiter.EventOut + } + case e.Listening(): + if mask&waiter.EventIn != 0 && len(e.acceptedChan) > 0 { + ready |= waiter.EventIn + } + } + + return ready +} diff --git a/pkg/tcpip/transport/unix/connectioned_state.go b/pkg/tcpip/transport/unix/connectioned_state.go new file mode 100644 index 000000000..5d835c8b2 --- /dev/null +++ b/pkg/tcpip/transport/unix/connectioned_state.go @@ -0,0 +1,43 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package unix + +// saveAcceptedChan is invoked by stateify. +func (e *connectionedEndpoint) saveAcceptedChan() []*connectionedEndpoint { + // If acceptedChan is nil (i.e. we are not listening) then we will save nil. + // Otherwise we create a (possibly empty) slice of the values in acceptedChan and + // save that. + var acceptedSlice []*connectionedEndpoint + if e.acceptedChan != nil { + // Swap out acceptedChan with a new empty channel of the same capacity. + saveChan := e.acceptedChan + e.acceptedChan = make(chan *connectionedEndpoint, cap(saveChan)) + + // Create a new slice with the same len and capacity as the channel. + acceptedSlice = make([]*connectionedEndpoint, len(saveChan), cap(saveChan)) + // Drain acceptedChan into saveSlice, and fill up the new acceptChan at the + // same time. + for i := range acceptedSlice { + ep := <-saveChan + acceptedSlice[i] = ep + e.acceptedChan <- ep + } + close(saveChan) + } + return acceptedSlice +} + +// loadAcceptedChan is invoked by stateify. +func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEndpoint) { + // If acceptedSlice is nil, then acceptedChan should also be nil. + if acceptedSlice != nil { + // Otherwise, create a new channel with the same capacity as acceptedSlice. + e.acceptedChan = make(chan *connectionedEndpoint, cap(acceptedSlice)) + // Seed the channel with values from acceptedSlice. + for _, ep := range acceptedSlice { + e.acceptedChan <- ep + } + } +} diff --git a/pkg/tcpip/transport/unix/connectionless.go b/pkg/tcpip/transport/unix/connectionless.go new file mode 100644 index 000000000..34d34f99a --- /dev/null +++ b/pkg/tcpip/transport/unix/connectionless.go @@ -0,0 +1,176 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package unix + +import ( + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/queue" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// connectionlessEndpoint is a unix endpoint for unix sockets that support operating in +// a conectionless fashon. +// +// Specifically, this means datagram unix sockets not created with +// socketpair(2). +type connectionlessEndpoint struct { + baseEndpoint +} + +// NewConnectionless creates a new unbound dgram endpoint. +func NewConnectionless() Endpoint { + ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}} + ep.receiver = &queueReceiver{readQueue: queue.New(&waiter.Queue{}, ep.Queue, initialLimit)} + return ep +} + +// isBound returns true iff the endpoint is bound. +func (e *connectionlessEndpoint) isBound() bool { + return e.path != "" +} + +// Close puts the endpoint in a closed state and frees all resources associated +// with it. +// +// The socket will be a fresh state after a call to close and may be reused. +// That is, close may be used to "unbind" or "disconnect" the socket in error +// paths. +func (e *connectionlessEndpoint) Close() { + e.Lock() + var r Receiver + if e.Connected() { + e.receiver.CloseRecv() + r = e.receiver + e.receiver = nil + + e.connected.Release() + e.connected = nil + } + if e.isBound() { + e.path = "" + } + e.Unlock() + if r != nil { + r.CloseNotify() + r.Release() + } +} + +// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. +func (e *connectionlessEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error { + return tcpip.ErrConnectionRefused +} + +// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect. +func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) { + return &connectedEndpoint{ + endpoint: e, + writeQueue: e.receiver.(*queueReceiver).readQueue, + }, nil +} + +// SendMsg writes data and a control message to the specified endpoint. +// This method does not block if the data cannot be written. +func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) { + if to == nil { + return e.baseEndpoint.SendMsg(data, c, nil) + } + + connected, err := to.UnidirectionalConnect() + if err != nil { + return 0, tcpip.ErrInvalidEndpointState + } + defer connected.Release() + + e.Lock() + n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + e.Unlock() + if err != nil { + return 0, err + } + if notify { + connected.SendNotify() + } + + return n, nil +} + +// Type implements Endpoint.Type. +func (e *connectionlessEndpoint) Type() SockType { + return SockDgram +} + +// Connect attempts to connect directly to server. +func (e *connectionlessEndpoint) Connect(server BoundEndpoint) *tcpip.Error { + connected, err := server.UnidirectionalConnect() + if err != nil { + return err + } + + e.Lock() + e.connected = connected + e.Unlock() + + return nil +} + +// Listen starts listening on the connection. +func (e *connectionlessEndpoint) Listen(int) *tcpip.Error { + return tcpip.ErrNotSupported +} + +// Accept accepts a new connection. +func (e *connectionlessEndpoint) Accept() (Endpoint, *tcpip.Error) { + return nil, tcpip.ErrNotSupported +} + +// Bind binds the connection. +// +// For Unix endpoints, this _only sets the address associated with the socket_. +// Work associated with sockets in the filesystem or finding those sockets must +// be done by a higher level. +// +// Bind will fail only if the socket is connected, bound or the passed address +// is invalid (the empty string). +func (e *connectionlessEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + e.Lock() + defer e.Unlock() + if e.isBound() { + return tcpip.ErrAlreadyBound + } + if addr.Addr == "" { + // The empty string is not permitted. + return tcpip.ErrBadLocalAddress + } + if commit != nil { + if err := commit(); err != nil { + return err + } + } + + // Save the bound address. + e.path = string(addr.Addr) + return nil +} + +// Readiness returns the current readiness of the endpoint. For example, if +// waiter.EventIn is set, the endpoint is immediately readable. +func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + e.Lock() + defer e.Unlock() + + ready := waiter.EventMask(0) + if mask&waiter.EventIn != 0 && e.receiver.Readable() { + ready |= waiter.EventIn + } + + if e.Connected() { + if mask&waiter.EventOut != 0 && e.connected.Writable() { + ready |= waiter.EventOut + } + } + + return ready +} diff --git a/pkg/tcpip/transport/unix/unix.go b/pkg/tcpip/transport/unix/unix.go new file mode 100644 index 000000000..5fe37eb71 --- /dev/null +++ b/pkg/tcpip/transport/unix/unix.go @@ -0,0 +1,902 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package unix contains the implementation of Unix endpoints. +package unix + +import ( + "sync" + "sync/atomic" + + "gvisor.googlesource.com/gvisor/pkg/ilist" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/queue" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// initialLimit is the starting limit for the socket buffers. +const initialLimit = 16 * 1024 + +// A SockType is a type (as opposed to family) of sockets. These are enumerated +// in the syscall package as syscall.SOCK_* constants. +type SockType int + +const ( + // SockStream corresponds to syscall.SOCK_STREAM. + SockStream SockType = 1 + // SockDgram corresponds to syscall.SOCK_DGRAM. + SockDgram SockType = 2 + // SockRaw corresponds to syscall.SOCK_RAW. + SockRaw SockType = 3 + // SockSeqpacket corresponds to syscall.SOCK_SEQPACKET. + SockSeqpacket SockType = 5 +) + +// A RightsControlMessage is a control message containing FDs. +type RightsControlMessage interface { + // Clone returns a copy of the RightsControlMessage. + Clone() RightsControlMessage + + // Release releases any resources owned by the RightsControlMessage. + Release() +} + +// A CredentialsControlMessage is a control message containing Unix credentials. +type CredentialsControlMessage interface { + // Equals returns true iff the two messages are equal. + Equals(CredentialsControlMessage) bool +} + +// A ControlMessages represents a collection of socket control messages. +type ControlMessages struct { + // Rights is a control message containing FDs. + Rights RightsControlMessage + + // Credentials is a control message containing Unix credentials. + Credentials CredentialsControlMessage +} + +// Empty returns true iff the ControlMessages does not contain either +// credentials or rights. +func (c *ControlMessages) Empty() bool { + return c.Rights == nil && c.Credentials == nil +} + +// Clone clones both the credentials and the rights. +func (c *ControlMessages) Clone() ControlMessages { + cm := ControlMessages{} + if c.Rights != nil { + cm.Rights = c.Rights.Clone() + } + cm.Credentials = c.Credentials + return cm +} + +// Release releases both the credentials and the rights. +func (c *ControlMessages) Release() { + if c.Rights != nil { + c.Rights.Release() + } + *c = ControlMessages{} +} + +// Endpoint is the interface implemented by Unix transport protocol +// implementations that expose functionality like sendmsg, recvmsg, connect, +// etc. to Unix socket implementations. +type Endpoint interface { + Credentialer + waiter.Waitable + + // Close puts the endpoint in a closed state and frees all resources + // associated with it. + Close() + + // RecvMsg reads data and a control message from the endpoint. This method + // does not block if there is no data pending. + // + // creds indicates if credential control messages are requested by the + // caller. This is useful for determining if control messages can be + // coalesced. creds is a hint and can be safely ignored by the + // implementation if no coalescing is possible. It is fine to return + // credential control messages when none were requested or to not return + // credential control messages when they were requested. + // + // numRights is the number of SCM_RIGHTS FDs requested by the caller. This + // is useful if one must allocate a buffer to receive a SCM_RIGHTS message + // or determine if control messages can be coalesced. numRights is a hint + // and can be safely ignored by the implementation if the number of + // available SCM_RIGHTS FDs is known and no coalescing is possible. It is + // fine for the returned number of SCM_RIGHTS FDs to be either higher or + // lower than the requested number. + // + // If peek is true, no data should be consumed from the Endpoint. Any and + // all data returned from a peek should be available in the next call to + // RecvMsg. + // + // recvLen is the number of bytes copied into data. + // + // msgLen is the length of the read message consumed for datagram Endpoints. + // msgLen is always the same as recvLen for stream Endpoints. + RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, err *tcpip.Error) + + // SendMsg writes data and a control message to the endpoint's peer. + // This method does not block if the data cannot be written. + // + // SendMsg does not take ownership of any of its arguments on error. + SendMsg([][]byte, ControlMessages, BoundEndpoint) (uintptr, *tcpip.Error) + + // Connect connects this endpoint directly to another. + // + // This should be called on the client endpoint, and the (bound) + // endpoint passed in as a parameter. + // + // The error codes are the same as Connect. + Connect(server BoundEndpoint) *tcpip.Error + + // Shutdown closes the read and/or write end of the endpoint connection + // to its peer. + Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error + + // Listen puts the endpoint in "listen" mode, which allows it to accept + // new connections. + Listen(backlog int) *tcpip.Error + + // Accept returns a new endpoint if a peer has established a connection + // to an endpoint previously set to listen mode. This method does not + // block if no new connections are available. + // + // The returned Queue is the wait queue for the newly created endpoint. + Accept() (Endpoint, *tcpip.Error) + + // Bind binds the endpoint to a specific local address and port. + // Specifying a NIC is optional. + // + // An optional commit function will be executed atomically with respect + // to binding the endpoint. If this returns an error, the bind will not + // occur and the error will be propagated back to the caller. + Bind(address tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error + + // Type return the socket type, typically either SockStream, SockDgram + // or SockSeqpacket. + Type() SockType + + // GetLocalAddress returns the address to which the endpoint is bound. + GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + + // GetRemoteAddress returns the address to which the endpoint is + // connected. + GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) + + // SetSockOpt sets a socket option. opt should be one of the tcpip.*Option + // types. + SetSockOpt(opt interface{}) *tcpip.Error + + // GetSockOpt gets a socket option. opt should be a pointer to one of the + // tcpip.*Option types. + GetSockOpt(opt interface{}) *tcpip.Error +} + +// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket +// option. +type Credentialer interface { + // Passcred returns whether or not the SO_PASSCRED socket option is + // enabled on this end. + Passcred() bool + + // ConnectedPasscred returns whether or not the SO_PASSCRED socket option + // is enabled on the connected end. + ConnectedPasscred() bool +} + +// A BoundEndpoint is a unix endpoint that can be connected to. +type BoundEndpoint interface { + // BidirectionalConnect establishes a bi-directional connection between two + // unix endpoints in an all-or-nothing manner. If an error occurs during + // connecting, the state of neither endpoint should be modified. + // + // In order for an endpoint to establish such a bidirectional connection + // with a BoundEndpoint, the endpoint calls the BidirectionalConnect method + // on the BoundEndpoint and sends a representation of itself (the + // ConnectingEndpoint) and a callback (returnConnect) to receive the + // connection information (Receiver and ConnectedEndpoint) upon a + // successful connect. The callback should only be called on a successful + // connect. + // + // For a connection attempt to be successful, the ConnectingEndpoint must + // be unconnected and not listening and the BoundEndpoint whose + // BidirectionalConnect method is being called must be listening. + // + // This method will return tcpip.ErrConnectionRefused on endpoints with a + // type that isn't SockStream or SockSeqpacket. + BidirectionalConnect(ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error + + // UnidirectionalConnect establishes a write-only connection to a unix endpoint. + // + // This method will return tcpip.ErrConnectionRefused on a non-SockDgram + // endpoint. + UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) + + // Release releases any resources held by the BoundEndpoint. It must be + // called before dropping all references to a BoundEndpoint returned by a + // function. + Release() +} + +// message represents a message passed over a Unix domain socket. +type message struct { + ilist.Entry + + // Data is the Message payload. + Data buffer.View + + // Control is auxiliary control message data that goes along with the + // data. + Control ControlMessages + + // Address is the bound address of the endpoint that sent the message. + // + // If the endpoint that sent the message is not bound, the Address is + // the empty string. + Address tcpip.FullAddress +} + +// Length returns number of bytes stored in the Message. +func (m *message) Length() int64 { + return int64(len(m.Data)) +} + +// Release releases any resources held by the Message. +func (m *message) Release() { + m.Control.Release() +} + +func (m *message) Peek() queue.Entry { + return &message{Data: m.Data, Control: m.Control.Clone(), Address: m.Address} +} + +// A Receiver can be used to receive Messages. +type Receiver interface { + // Recv receives a single message. This method does not block. + // + // See Endpoint.RecvMsg for documentation on shared arguments. + // + // notify indicates if RecvNotify should be called. + Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err *tcpip.Error) + + // RecvNotify notifies the Receiver of a successful Recv. This must not be + // called while holding any endpoint locks. + RecvNotify() + + // CloseRecv prevents the receiving of additional Messages. + // + // After CloseRecv is called, CloseNotify must also be called. + CloseRecv() + + // CloseNotify notifies the Receiver of recv being closed. This must not be + // called while holding any endpoint locks. + CloseNotify() + + // Readable returns if messages should be attempted to be received. This + // includes when read has been shutdown. + Readable() bool + + // RecvQueuedSize returns the total amount of data currently receivable. + // RecvQueuedSize should return -1 if the operation isn't supported. + RecvQueuedSize() int64 + + // RecvMaxQueueSize returns maximum value for RecvQueuedSize. + // RecvMaxQueueSize should return -1 if the operation isn't supported. + RecvMaxQueueSize() int64 + + // Release releases any resources owned by the Receiver. It should be + // called before droping all references to a Receiver. + Release() +} + +// queueReceiver implements Receiver for datagram sockets. +type queueReceiver struct { + readQueue *queue.Queue +} + +// Recv implements Receiver.Recv. +func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) { + var m queue.Entry + var notify bool + var err *tcpip.Error + if peek { + m, err = q.readQueue.Peek() + } else { + m, notify, err = q.readQueue.Dequeue() + } + if err != nil { + return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err + } + msg := m.(*message) + src := []byte(msg.Data) + var copied uintptr + for i := 0; i < len(data) && len(src) > 0; i++ { + n := copy(data[i], src) + copied += uintptr(n) + src = src[n:] + } + return copied, uintptr(len(msg.Data)), msg.Control, msg.Address, notify, nil +} + +// RecvNotify implements Receiver.RecvNotify. +func (q *queueReceiver) RecvNotify() { + q.readQueue.WriterQueue.Notify(waiter.EventOut) +} + +// CloseNotify implements Receiver.CloseNotify. +func (q *queueReceiver) CloseNotify() { + q.readQueue.ReaderQueue.Notify(waiter.EventIn) + q.readQueue.WriterQueue.Notify(waiter.EventOut) +} + +// CloseRecv implements Receiver.CloseRecv. +func (q *queueReceiver) CloseRecv() { + q.readQueue.Close() +} + +// Readable implements Receiver.Readable. +func (q *queueReceiver) Readable() bool { + return q.readQueue.IsReadable() +} + +// RecvQueuedSize implements Receiver.RecvQueuedSize. +func (q *queueReceiver) RecvQueuedSize() int64 { + return q.readQueue.QueuedSize() +} + +// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize. +func (q *queueReceiver) RecvMaxQueueSize() int64 { + return q.readQueue.MaxQueueSize() +} + +// Release implements Receiver.Release. +func (*queueReceiver) Release() {} + +// streamQueueReceiver implements Receiver for stream sockets. +type streamQueueReceiver struct { + queueReceiver + + mu sync.Mutex `state:"nosave"` + buffer []byte + control ControlMessages + addr tcpip.FullAddress +} + +func vecCopy(data [][]byte, buf []byte) (uintptr, [][]byte, []byte) { + var copied uintptr + for len(data) > 0 && len(buf) > 0 { + n := copy(data[0], buf) + copied += uintptr(n) + buf = buf[n:] + data[0] = data[0][n:] + if len(data[0]) == 0 { + data = data[1:] + } + } + return copied, data, buf +} + +// Readable implements Receiver.Readable. +func (q *streamQueueReceiver) Readable() bool { + // We're readable if we have data in our buffer or if the queue receiver is + // readable. + return len(q.buffer) > 0 || q.readQueue.IsReadable() +} + +// RecvQueuedSize implements Receiver.RecvQueuedSize. +func (q *streamQueueReceiver) RecvQueuedSize() int64 { + return int64(len(q.buffer)) + q.readQueue.QueuedSize() +} + +// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize. +func (q *streamQueueReceiver) RecvMaxQueueSize() int64 { + // The RecvMaxQueueSize() is the readQueue's MaxQueueSize() plus the largest + // message we can buffer which is also the largest message we can receive. + return 2 * q.readQueue.MaxQueueSize() +} + +// Recv implements Receiver.Recv. +func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) { + q.mu.Lock() + defer q.mu.Unlock() + + var notify bool + + // If we have no data in the endpoint, we need to get some. + if len(q.buffer) == 0 { + // Load the next message into a buffer, even if we are peeking. Peeking + // won't consume the message, so it will be still available to be read + // the next time Recv() is called. + m, n, err := q.readQueue.Dequeue() + if err != nil { + return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err + } + notify = n + msg := m.(*message) + q.buffer = []byte(msg.Data) + q.control = msg.Control + q.addr = msg.Address + } + + var copied uintptr + if peek { + // Don't consume control message if we are peeking. + c := q.control.Clone() + + // Don't consume data since we are peeking. + copied, data, _ = vecCopy(data, q.buffer) + + return copied, copied, c, q.addr, notify, nil + } + + // Consume data and control message since we are not peeking. + copied, data, q.buffer = vecCopy(data, q.buffer) + + // Save the original state of q.control. + c := q.control + + // Remove rights from q.control and leave behind just the creds. + q.control.Rights = nil + if !wantCreds { + c.Credentials = nil + } + + if c.Rights != nil && numRights == 0 { + c.Rights.Release() + c.Rights = nil + } + + haveRights := c.Rights != nil + + // If we have more capacity for data and haven't received any usable + // rights. + // + // Linux never coalesces rights control messages. + for !haveRights && len(data) > 0 { + // Get a message from the readQueue. + m, n, err := q.readQueue.Dequeue() + if err != nil { + // We already got some data, so ignore this error. This will + // manifest as a short read to the user, which is what Linux + // does. + break + } + notify = notify || n + msg := m.(*message) + q.buffer = []byte(msg.Data) + q.control = msg.Control + q.addr = msg.Address + + if wantCreds { + if (q.control.Credentials == nil) != (c.Credentials == nil) { + // One message has credentials, the other does not. + break + } + + if q.control.Credentials != nil && c.Credentials != nil && !q.control.Credentials.Equals(c.Credentials) { + // Both messages have credentials, but they don't match. + break + } + } + + if numRights != 0 && c.Rights != nil && q.control.Rights != nil { + // Both messages have rights. + break + } + + var cpd uintptr + cpd, data, q.buffer = vecCopy(data, q.buffer) + copied += cpd + + if cpd == 0 { + // data was actually full. + break + } + + if q.control.Rights != nil { + // Consume rights. + if numRights == 0 { + q.control.Rights.Release() + } else { + c.Rights = q.control.Rights + haveRights = true + } + q.control.Rights = nil + } + } + return copied, copied, c, q.addr, notify, nil +} + +// A ConnectedEndpoint is an Endpoint that can be used to send Messages. +type ConnectedEndpoint interface { + // Passcred implements Endpoint.Passcred. + Passcred() bool + + // GetLocalAddress implements Endpoint.GetLocalAddress. + GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + + // Send sends a single message. This method does not block. + // + // notify indicates if SendNotify should be called. + Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n uintptr, notify bool, err *tcpip.Error) + + // SendNotify notifies the ConnectedEndpoint of a successful Send. This + // must not be called while holding any endpoint locks. + SendNotify() + + // CloseSend prevents the sending of additional Messages. + // + // After CloseSend is call, CloseNotify must also be called. + CloseSend() + + // CloseNotify notifies the ConnectedEndpoint of send being closed. This + // must not be called while holding any endpoint locks. + CloseNotify() + + // Writable returns if messages should be attempted to be sent. This + // includes when write has been shutdown. + Writable() bool + + // EventUpdate lets the ConnectedEndpoint know that event registrations + // have changed. + EventUpdate() + + // SendQueuedSize returns the total amount of data currently queued for + // sending. SendQueuedSize should return -1 if the operation isn't + // supported. + SendQueuedSize() int64 + + // SendMaxQueueSize returns maximum value for SendQueuedSize. + // SendMaxQueueSize should return -1 if the operation isn't supported. + SendMaxQueueSize() int64 + + // Release releases any resources owned by the ConnectedEndpoint. It should + // be called before droping all references to a ConnectedEndpoint. + Release() +} + +type connectedEndpoint struct { + // endpoint represents the subset of the Endpoint functionality needed by + // the connectedEndpoint. It is implemented by both connectionedEndpoint + // and connectionlessEndpoint and allows the use of types which don't + // fully implement Endpoint. + endpoint interface { + // Passcred implements Endpoint.Passcred. + Passcred() bool + + // GetLocalAddress implements Endpoint.GetLocalAddress. + GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) + + // Type implements Endpoint.Type. + Type() SockType + } + + writeQueue *queue.Queue +} + +// Passcred implements ConnectedEndpoint.Passcred. +func (e *connectedEndpoint) Passcred() bool { + return e.endpoint.Passcred() +} + +// GetLocalAddress implements ConnectedEndpoint.GetLocalAddress. +func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return e.endpoint.GetLocalAddress() +} + +// Send implements ConnectedEndpoint.Send. +func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) { + var l int + for _, d := range data { + l += len(d) + } + // Discard empty stream packets. Since stream sockets don't preserve + // message boundaries, sending zero bytes is a no-op. In Linux, the + // receiver actually uses a zero-length receive as an indication that the + // stream was closed. + if l == 0 && e.endpoint.Type() == SockStream { + controlMessages.Release() + return 0, false, nil + } + v := make([]byte, 0, l) + for _, d := range data { + v = append(v, d...) + } + notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}) + return uintptr(l), notify, err +} + +// SendNotify implements ConnectedEndpoint.SendNotify. +func (e *connectedEndpoint) SendNotify() { + e.writeQueue.ReaderQueue.Notify(waiter.EventIn) +} + +// CloseNotify implements ConnectedEndpoint.CloseNotify. +func (e *connectedEndpoint) CloseNotify() { + e.writeQueue.ReaderQueue.Notify(waiter.EventIn) + e.writeQueue.WriterQueue.Notify(waiter.EventOut) +} + +// CloseSend implements ConnectedEndpoint.CloseSend. +func (e *connectedEndpoint) CloseSend() { + e.writeQueue.Close() +} + +// Writable implements ConnectedEndpoint.Writable. +func (e *connectedEndpoint) Writable() bool { + return e.writeQueue.IsWritable() +} + +// EventUpdate implements ConnectedEndpoint.EventUpdate. +func (*connectedEndpoint) EventUpdate() {} + +// SendQueuedSize implements ConnectedEndpoint.SendQueuedSize. +func (e *connectedEndpoint) SendQueuedSize() int64 { + return e.writeQueue.QueuedSize() +} + +// SendMaxQueueSize implements ConnectedEndpoint.SendMaxQueueSize. +func (e *connectedEndpoint) SendMaxQueueSize() int64 { + return e.writeQueue.MaxQueueSize() +} + +// Release implements ConnectedEndpoint.Release. +func (*connectedEndpoint) Release() {} + +// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless +// unix domain socket Endpoint implementations. +// +// Not to be used on its own. +type baseEndpoint struct { + *waiter.Queue + + // passcred specifies whether SCM_CREDENTIALS socket control messages are + // enabled on this endpoint. Must be accessed atomically. + passcred int32 + + // Mutex protects the below fields. + sync.Mutex `state:"nosave"` + + // receiver allows Messages to be received. + receiver Receiver + + // connected allows messages to be sent and state information about the + // connected endpoint to be read. + connected ConnectedEndpoint + + // path is not empty if the endpoint has been bound, + // or may be used if the endpoint is connected. + path string +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) { + e.Lock() + e.Queue.EventRegister(we, mask) + if e.connected != nil { + e.connected.EventUpdate() + } + e.Unlock() +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (e *baseEndpoint) EventUnregister(we *waiter.Entry) { + e.Lock() + e.Queue.EventUnregister(we) + if e.connected != nil { + e.connected.EventUpdate() + } + e.Unlock() +} + +// Passcred implements Credentialer.Passcred. +func (e *baseEndpoint) Passcred() bool { + return atomic.LoadInt32(&e.passcred) != 0 +} + +// ConnectedPasscred implements Credentialer.ConnectedPasscred. +func (e *baseEndpoint) ConnectedPasscred() bool { + e.Lock() + defer e.Unlock() + return e.connected != nil && e.connected.Passcred() +} + +func (e *baseEndpoint) setPasscred(pc bool) { + if pc { + atomic.StoreInt32(&e.passcred, 1) + } else { + atomic.StoreInt32(&e.passcred, 0) + } +} + +// Connected implements ConnectingEndpoint.Connected. +func (e *baseEndpoint) Connected() bool { + return e.receiver != nil && e.connected != nil +} + +// RecvMsg reads data and a control message from the endpoint. +func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, *tcpip.Error) { + e.Lock() + + if e.receiver == nil { + e.Unlock() + return 0, 0, ControlMessages{}, tcpip.ErrNotConnected + } + + recvLen, msgLen, cms, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) + e.Unlock() + if err != nil { + return 0, 0, ControlMessages{}, err + } + + if notify { + e.receiver.RecvNotify() + } + + if addr != nil { + *addr = a + } + return recvLen, msgLen, cms, nil +} + +// SendMsg writes data and a control message to the endpoint's peer. +// This method does not block if the data cannot be written. +func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) { + e.Lock() + if !e.Connected() { + e.Unlock() + return 0, tcpip.ErrNotConnected + } + if to != nil { + e.Unlock() + return 0, tcpip.ErrAlreadyConnected + } + + n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + e.Unlock() + if err != nil { + return 0, err + } + + if notify { + e.connected.SendNotify() + } + + return n, nil +} + +// SetSockOpt sets a socket option. Currently not supported. +func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { + switch v := opt.(type) { + case tcpip.PasscredOption: + e.setPasscred(v != 0) + return nil + } + return nil +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + case *tcpip.SendQueueSizeOption: + e.Lock() + if !e.Connected() { + e.Unlock() + return tcpip.ErrNotConnected + } + qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + e.Unlock() + if qs < 0 { + return tcpip.ErrQueueSizeNotSupported + } + *o = qs + return nil + case *tcpip.ReceiveQueueSizeOption: + e.Lock() + if !e.Connected() { + e.Unlock() + return tcpip.ErrNotConnected + } + qs := tcpip.ReceiveQueueSizeOption(e.receiver.RecvQueuedSize()) + e.Unlock() + if qs < 0 { + return tcpip.ErrQueueSizeNotSupported + } + *o = qs + return nil + case *tcpip.PasscredOption: + if e.Passcred() { + *o = tcpip.PasscredOption(1) + } else { + *o = tcpip.PasscredOption(0) + } + return nil + case *tcpip.SendBufferSizeOption: + e.Lock() + if !e.Connected() { + e.Unlock() + return tcpip.ErrNotConnected + } + qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize()) + e.Unlock() + if qs < 0 { + return tcpip.ErrQueueSizeNotSupported + } + *o = qs + return nil + case *tcpip.ReceiveBufferSizeOption: + e.Lock() + if e.receiver == nil { + e.Unlock() + return tcpip.ErrNotConnected + } + qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize()) + e.Unlock() + if qs < 0 { + return tcpip.ErrQueueSizeNotSupported + } + *o = qs + return nil + } + return tcpip.ErrUnknownProtocolOption +} + +// Shutdown closes the read and/or write end of the endpoint connection to its +// peer. +func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { + e.Lock() + if !e.Connected() { + e.Unlock() + return tcpip.ErrNotConnected + } + + if flags&tcpip.ShutdownRead != 0 { + e.receiver.CloseRecv() + } + + if flags&tcpip.ShutdownWrite != 0 { + e.connected.CloseSend() + } + + e.Unlock() + + if flags&tcpip.ShutdownRead != 0 { + e.receiver.CloseNotify() + } + + if flags&tcpip.ShutdownWrite != 0 { + e.connected.CloseNotify() + } + + return nil +} + +// GetLocalAddress returns the bound path. +func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + e.Lock() + defer e.Unlock() + return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil +} + +// GetRemoteAddress returns the local address of the connected endpoint (if +// available). +func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + e.Lock() + c := e.connected + e.Unlock() + if c != nil { + return c.GetLocalAddress() + } + return tcpip.FullAddress{}, tcpip.ErrNotConnected +} + +// Release implements BoundEndpoint.Release. +func (*baseEndpoint) Release() {} |