summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
authorGoogler <noreply@google.com>2018-04-27 10:37:02 -0700
committerAdin Scannell <ascannell@google.com>2018-04-28 01:44:26 -0400
commitd02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch)
tree54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/tcpip
parentf70210e742919f40aa2f0934a22f1c9ba6dada62 (diff)
Check in gVisor.
PiperOrigin-RevId: 194583126 Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/BUILD35
-rw-r--r--pkg/tcpip/adapters/gonet/BUILD36
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go605
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go424
-rw-r--r--pkg/tcpip/buffer/BUILD32
-rw-r--r--pkg/tcpip/buffer/prependable.go53
-rw-r--r--pkg/tcpip/buffer/view.go181
-rw-r--r--pkg/tcpip/buffer/view_test.go212
-rw-r--r--pkg/tcpip/checker/BUILD16
-rw-r--r--pkg/tcpip/checker/checker.go517
-rw-r--r--pkg/tcpip/header/BUILD51
-rw-r--r--pkg/tcpip/header/arp.go90
-rw-r--r--pkg/tcpip/header/checksum.go46
-rw-r--r--pkg/tcpip/header/eth.go64
-rw-r--r--pkg/tcpip/header/gue.go63
-rw-r--r--pkg/tcpip/header/icmpv4.go98
-rw-r--r--pkg/tcpip/header/icmpv6.go111
-rw-r--r--pkg/tcpip/header/interfaces.go82
-rw-r--r--pkg/tcpip/header/ipv4.go251
-rw-r--r--pkg/tcpip/header/ipv6.go191
-rw-r--r--pkg/tcpip/header/ipv6_fragment.go136
-rw-r--r--pkg/tcpip/header/ipversion_test.go57
-rw-r--r--pkg/tcpip/header/tcp.go518
-rw-r--r--pkg/tcpip/header/tcp_test.go134
-rw-r--r--pkg/tcpip/header/udp.go106
-rw-r--r--pkg/tcpip/link/channel/BUILD15
-rw-r--r--pkg/tcpip/link/channel/channel.go110
-rw-r--r--pkg/tcpip/link/fdbased/BUILD32
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go261
-rw-r--r--pkg/tcpip/link/fdbased/endpoint_test.go336
-rw-r--r--pkg/tcpip/link/loopback/BUILD15
-rw-r--r--pkg/tcpip/link/loopback/loopback.go74
-rw-r--r--pkg/tcpip/link/rawfile/BUILD17
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_amd64.s26
-rw-r--r--pkg/tcpip/link/rawfile/errors.go40
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go161
-rw-r--r--pkg/tcpip/link/sharedmem/BUILD42
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/BUILD23
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe.go68
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_test.go507
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go25
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/rx.go83
-rw-r--r--pkg/tcpip/link/sharedmem/pipe/tx.go151
-rw-r--r--pkg/tcpip/link/sharedmem/queue/BUILD28
-rw-r--r--pkg/tcpip/link/sharedmem/queue/queue_test.go507
-rw-r--r--pkg/tcpip/link/sharedmem/queue/rx.go211
-rw-r--r--pkg/tcpip/link/sharedmem/queue/tx.go141
-rw-r--r--pkg/tcpip/link/sharedmem/rx.go147
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem.go240
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_test.go703
-rw-r--r--pkg/tcpip/link/sharedmem/sharedmem_unsafe.go15
-rw-r--r--pkg/tcpip/link/sharedmem/tx.go262
-rw-r--r--pkg/tcpip/link/sniffer/BUILD23
-rw-r--r--pkg/tcpip/link/sniffer/pcap.go52
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go310
-rw-r--r--pkg/tcpip/link/tun/BUILD12
-rw-r--r--pkg/tcpip/link/tun/tun_unsafe.go50
-rw-r--r--pkg/tcpip/link/waitable/BUILD33
-rw-r--r--pkg/tcpip/link/waitable/waitable.go108
-rw-r--r--pkg/tcpip/link/waitable/waitable_test.go144
-rw-r--r--pkg/tcpip/network/BUILD19
-rw-r--r--pkg/tcpip/network/arp/BUILD34
-rw-r--r--pkg/tcpip/network/arp/arp.go170
-rw-r--r--pkg/tcpip/network/arp/arp_test.go138
-rw-r--r--pkg/tcpip/network/fragmentation/BUILD61
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap.go67
-rw-r--r--pkg/tcpip/network/fragmentation/frag_heap_test.go112
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation.go124
-rw-r--r--pkg/tcpip/network/fragmentation/fragmentation_test.go166
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler.go109
-rw-r--r--pkg/tcpip/network/fragmentation/reassembler_test.go95
-rw-r--r--pkg/tcpip/network/hash/BUILD11
-rw-r--r--pkg/tcpip/network/hash/hash.go83
-rw-r--r--pkg/tcpip/network/ip_test.go560
-rw-r--r--pkg/tcpip/network/ipv4/BUILD38
-rw-r--r--pkg/tcpip/network/ipv4/icmp.go282
-rw-r--r--pkg/tcpip/network/ipv4/icmp_test.go124
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go233
-rw-r--r--pkg/tcpip/network/ipv6/BUILD21
-rw-r--r--pkg/tcpip/network/ipv6/icmp.go80
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go172
-rw-r--r--pkg/tcpip/ports/BUILD20
-rw-r--r--pkg/tcpip/ports/ports.go148
-rw-r--r--pkg/tcpip/ports/ports_test.go134
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/BUILD20
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go208
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/BUILD20
-rw-r--r--pkg/tcpip/sample/tun_tcp_echo/main.go182
-rw-r--r--pkg/tcpip/seqnum/BUILD26
-rw-r--r--pkg/tcpip/seqnum/seqnum.go57
-rw-r--r--pkg/tcpip/stack/BUILD70
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go313
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go256
-rw-r--r--pkg/tcpip/stack/nic.go453
-rw-r--r--pkg/tcpip/stack/registration.go322
-rw-r--r--pkg/tcpip/stack/route.go133
-rw-r--r--pkg/tcpip/stack/stack.go811
-rw-r--r--pkg/tcpip/stack/stack_global_state.go9
-rw-r--r--pkg/tcpip/stack/stack_test.go760
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go166
-rw-r--r--pkg/tcpip/stack/transport_test.go420
-rw-r--r--pkg/tcpip/tcpip.go499
-rw-r--r--pkg/tcpip/tcpip_test.go130
-rw-r--r--pkg/tcpip/transport/queue/BUILD29
-rw-r--r--pkg/tcpip/transport/queue/queue.go166
-rw-r--r--pkg/tcpip/transport/tcp/BUILD97
-rw-r--r--pkg/tcpip/transport/tcp/accept.go407
-rw-r--r--pkg/tcpip/transport/tcp/connect.go953
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go550
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go1371
-rw-r--r--pkg/tcpip/transport/tcp/endpoint_state.go128
-rw-r--r--pkg/tcpip/transport/tcp/forwarder.go161
-rw-r--r--pkg/tcpip/transport/tcp/protocol.go192
-rw-r--r--pkg/tcpip/transport/tcp/rcv.go208
-rw-r--r--pkg/tcpip/transport/tcp/sack.go85
-rw-r--r--pkg/tcpip/transport/tcp/segment.go145
-rw-r--r--pkg/tcpip/transport/tcp/segment_heap.go36
-rw-r--r--pkg/tcpip/transport/tcp/segment_queue.go69
-rw-r--r--pkg/tcpip/transport/tcp/snd.go628
-rw-r--r--pkg/tcpip/transport/tcp/tcp_sack_test.go336
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go2759
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go302
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/BUILD27
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go900
-rw-r--r--pkg/tcpip/transport/tcp/timer.go131
-rw-r--r--pkg/tcpip/transport/tcpconntrack/BUILD24
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go333
-rw-r--r--pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go501
-rw-r--r--pkg/tcpip/transport/udp/BUILD77
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go746
-rw-r--r--pkg/tcpip/transport/udp/endpoint_state.go91
-rw-r--r--pkg/tcpip/transport/udp/protocol.go73
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go625
-rw-r--r--pkg/tcpip/transport/unix/BUILD37
-rw-r--r--pkg/tcpip/transport/unix/connectioned.go431
-rw-r--r--pkg/tcpip/transport/unix/connectioned_state.go43
-rw-r--r--pkg/tcpip/transport/unix/connectionless.go176
-rw-r--r--pkg/tcpip/transport/unix/unix.go902
138 files changed, 30676 insertions, 0 deletions
diff --git a/pkg/tcpip/BUILD b/pkg/tcpip/BUILD
new file mode 100644
index 000000000..5c38a4961
--- /dev/null
+++ b/pkg/tcpip/BUILD
@@ -0,0 +1,35 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "tcpip_state",
+ srcs = [
+ "tcpip.go",
+ ],
+ out = "tcpip_state.go",
+ package = "tcpip",
+)
+
+go_library(
+ name = "tcpip",
+ srcs = [
+ "tcpip.go",
+ "tcpip_state.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/state",
+ "//pkg/tcpip/buffer",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "tcpip_test",
+ size = "small",
+ srcs = ["tcpip_test.go"],
+ embed = [":tcpip"],
+)
diff --git a/pkg/tcpip/adapters/gonet/BUILD b/pkg/tcpip/adapters/gonet/BUILD
new file mode 100644
index 000000000..69cfc84ab
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/BUILD
@@ -0,0 +1,36 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "gonet",
+ srcs = ["gonet.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/adapters/gonet",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "gonet_test",
+ size = "small",
+ srcs = ["gonet_test.go"],
+ embed = [":gonet"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/loopback",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/waiter",
+ "@org_golang_x_net//nettest:go_default_library",
+ ],
+)
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
new file mode 100644
index 000000000..96a2d670d
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -0,0 +1,605 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package gonet provides a Go net package compatible wrapper for a tcpip stack.
+package gonet
+
+import (
+ "errors"
+ "io"
+ "net"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// errCanceled is returned by Accept when the Listener has been Shutdown.
+var errCanceled = errors.New("operation canceled")
+
+// timeoutError is how the net package reports timeouts.
+// It implements the net.Error interface so callers that type-assert on
+// net.Error see deadline expiries as timeouts.
+type timeoutError struct{}
+
+func (e *timeoutError) Error() string { return "i/o timeout" }
+func (e *timeoutError) Timeout() bool { return true }
+func (e *timeoutError) Temporary() bool { return true }
+
+// A Listener is a wrapper around a tcpip endpoint that implements
+// net.Listener.
+type Listener struct {
+	stack  *stack.Stack
+	ep     tcpip.Endpoint
+	wq     *waiter.Queue
+	// cancel is closed by Shutdown to wake goroutines blocked in Accept.
+	cancel chan struct{}
+}
+
+// NewListener creates a new Listener.
+//
+// It creates a TCP endpoint on the given stack, binds it to addr for the
+// given network protocol (IPv4 or IPv6), and puts it in the listening state
+// with a fixed backlog of 10. Failures are reported as *net.OpError so they
+// look like errors from the standard net package.
+func NewListener(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Listener, error) {
+	// Create TCP endpoint, bind it, then start listening.
+	var wq waiter.Queue
+	ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
+	if err != nil {
+		return nil, errors.New(err.String())
+	}
+
+	if err := ep.Bind(addr, nil); err != nil {
+		ep.Close()
+		return nil, &net.OpError{
+			Op:   "bind",
+			Net:  "tcp",
+			Addr: fullToTCPAddr(addr),
+			Err:  errors.New(err.String()),
+		}
+	}
+
+	if err := ep.Listen(10); err != nil {
+		ep.Close()
+		return nil, &net.OpError{
+			Op:   "listen",
+			Net:  "tcp",
+			Addr: fullToTCPAddr(addr),
+			Err:  errors.New(err.String()),
+		}
+	}
+
+	return &Listener{
+		stack:  s,
+		ep:     ep,
+		wq:     &wq,
+		cancel: make(chan struct{}),
+	}, nil
+}
+
+// Close implements net.Listener.Close.
+func (l *Listener) Close() error {
+	l.ep.Close()
+	return nil
+}
+
+// Shutdown stops accepting new connections: it shuts down the listening
+// endpoint for both reads and writes and wakes any goroutines blocked in
+// Accept (they return errCanceled).
+func (l *Listener) Shutdown() {
+	l.ep.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+	close(l.cancel) // broadcast cancellation
+}
+
+// Addr implements net.Listener.Addr.
+//
+// It returns nil if the local address cannot be obtained from the endpoint.
+func (l *Listener) Addr() net.Addr {
+	a, err := l.ep.GetLocalAddress()
+	if err != nil {
+		return nil
+	}
+	return fullToTCPAddr(a)
+}
+
+// deadlineTimer implements the deadline methods of net.Conn and
+// net.PacketConn. A deadline expiry is signalled by closing the
+// corresponding cancel channel; blocked readers/writers select on it.
+type deadlineTimer struct {
+	// mu protects the fields below.
+	mu sync.Mutex
+
+	readTimer     *time.Timer
+	readCancelCh  chan struct{}
+	writeTimer    *time.Timer
+	writeCancelCh chan struct{}
+}
+
+// init allocates the cancel channels. It must be called before any
+// deadline method is used.
+func (d *deadlineTimer) init() {
+	d.readCancelCh = make(chan struct{})
+	d.writeCancelCh = make(chan struct{})
+}
+
+// readCancel returns the current read-deadline channel; it is closed when
+// the read deadline expires.
+func (d *deadlineTimer) readCancel() <-chan struct{} {
+	d.mu.Lock()
+	c := d.readCancelCh
+	d.mu.Unlock()
+	return c
+}
+
+// writeCancel returns the current write-deadline channel; it is closed when
+// the write deadline expires.
+func (d *deadlineTimer) writeCancel() <-chan struct{} {
+	d.mu.Lock()
+	c := d.writeCancelCh
+	d.mu.Unlock()
+	return c
+}
+
+// setDeadline contains the shared logic for setting a deadline.
+//
+// cancelCh and timer must be pointers to deadlineTimer.readCancelCh and
+// deadlineTimer.readTimer or deadlineTimer.writeCancelCh and
+// deadlineTimer.writeTimer.
+//
+// setDeadline must only be called while holding d.mu.
+func (d *deadlineTimer) setDeadline(cancelCh *chan struct{}, timer **time.Timer, t time.Time) {
+	if *timer != nil && !(*timer).Stop() {
+		// The previous timer already fired and closed *cancelCh;
+		// replace the channel so future waiters don't see a stale
+		// expiry.
+		*cancelCh = make(chan struct{})
+	}
+
+	// Create a new channel if we already closed it due to setting an already
+	// expired time. We won't race with the timer because we already handled
+	// that above.
+	select {
+	case <-*cancelCh:
+		*cancelCh = make(chan struct{})
+	default:
+	}
+
+	// "A zero value for t means I/O operations will not time out."
+	// - net.Conn.SetDeadline
+	if t.IsZero() {
+		return
+	}
+
+	timeout := time.Until(t)
+	if timeout <= 0 {
+		// Deadline is already in the past: signal expiry immediately.
+		close(*cancelCh)
+		return
+	}
+
+	// Timer.Stop returns whether or not the AfterFunc has started, but
+	// does not indicate whether or not it has completed. Make a copy of
+	// the cancel channel to prevent this code from racing with the next
+	// call of setDeadline replacing *cancelCh.
+	ch := *cancelCh
+	*timer = time.AfterFunc(timeout, func() {
+		close(ch)
+	})
+}
+
+// SetReadDeadline implements net.Conn.SetReadDeadline and
+// net.PacketConn.SetReadDeadline.
+//
+// It always returns nil.
+func (d *deadlineTimer) SetReadDeadline(t time.Time) error {
+	d.mu.Lock()
+	d.setDeadline(&d.readCancelCh, &d.readTimer, t)
+	d.mu.Unlock()
+	return nil
+}
+
+// SetWriteDeadline implements net.Conn.SetWriteDeadline and
+// net.PacketConn.SetWriteDeadline.
+//
+// It always returns nil.
+func (d *deadlineTimer) SetWriteDeadline(t time.Time) error {
+	d.mu.Lock()
+	d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
+	d.mu.Unlock()
+	return nil
+}
+
+// SetDeadline implements net.Conn.SetDeadline and net.PacketConn.SetDeadline.
+//
+// It sets both the read and the write deadlines; it always returns nil.
+func (d *deadlineTimer) SetDeadline(t time.Time) error {
+	d.mu.Lock()
+	d.setDeadline(&d.readCancelCh, &d.readTimer, t)
+	d.setDeadline(&d.writeCancelCh, &d.writeTimer, t)
+	d.mu.Unlock()
+	return nil
+}
+
+// A Conn is a wrapper around a tcpip.Endpoint that implements the net.Conn
+// interface.
+type Conn struct {
+	deadlineTimer
+
+	wq *waiter.Queue
+	ep tcpip.Endpoint
+
+	// readMu serializes reads and implicitly protects read.
+	//
+	// Lock ordering:
+	// If both readMu and deadlineTimer.mu are to be used in a single
+	// request, readMu must be acquired before deadlineTimer.mu.
+	readMu sync.Mutex
+
+	// read contains bytes that have been read from the endpoint,
+	// but haven't yet been returned.
+	read buffer.View
+}
+
+// NewConn creates a new Conn wrapping the given endpoint and its wait queue.
+//
+// The caller keeps ownership of ep; Close must still be called to release it.
+func NewConn(wq *waiter.Queue, ep tcpip.Endpoint) *Conn {
+	c := &Conn{
+		wq: wq,
+		ep: ep,
+	}
+	c.deadlineTimer.init()
+	return c
+}
+
+// Accept implements net.Listener.Accept.
+//
+// It blocks until a connection is available, the Listener is Shutdown
+// (returning errCanceled), or the endpoint reports a hard error.
+func (l *Listener) Accept() (net.Conn, error) {
+	n, wq, err := l.ep.Accept()
+
+	if err == tcpip.ErrWouldBlock {
+		// Create wait queue entry that notifies a channel.
+		waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+		l.wq.EventRegister(&waitEntry, waiter.EventIn)
+		defer l.wq.EventUnregister(&waitEntry)
+
+		// Retry Accept each time the endpoint signals readability.
+		for {
+			n, wq, err = l.ep.Accept()
+
+			if err != tcpip.ErrWouldBlock {
+				break
+			}
+
+			select {
+			case <-l.cancel:
+				return nil, errCanceled
+			case <-notifyCh:
+			}
+		}
+	}
+
+	if err != nil {
+		return nil, &net.OpError{
+			Op:   "accept",
+			Net:  "tcp",
+			Addr: l.Addr(),
+			Err:  errors.New(err.String()),
+		}
+	}
+
+	return NewConn(wq, n), nil
+}
+
+// opErrorer constructs *net.OpError values appropriate for the connection
+// type (TCP Conn vs. UDP PacketConn).
+type opErrorer interface {
+	newOpError(op string, err error) *net.OpError
+}
+
+// commonRead implements the common logic between net.Conn.Read and
+// net.PacketConn.ReadFrom.
+//
+// It blocks until data is available, the deadline channel is closed
+// (returning a timeout OpError), or the endpoint fails. A closed-for-receive
+// endpoint is reported as io.EOF. If addr is non-nil it receives the sender's
+// address.
+func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) ([]byte, error) {
+	read, err := ep.Read(addr)
+
+	if err == tcpip.ErrWouldBlock {
+		// Create wait queue entry that notifies a channel.
+		waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+		wq.EventRegister(&waitEntry, waiter.EventIn)
+		defer wq.EventUnregister(&waitEntry)
+		for {
+			read, err = ep.Read(addr)
+			if err != tcpip.ErrWouldBlock {
+				break
+			}
+			select {
+			case <-deadline:
+				return nil, errorer.newOpError("read", &timeoutError{})
+			case <-notifyCh:
+			}
+		}
+	}
+
+	if err == tcpip.ErrClosedForReceive {
+		return nil, io.EOF
+	}
+
+	if err != nil {
+		return nil, errorer.newOpError("read", errors.New(err.String()))
+	}
+
+	return read, nil
+}
+
+// Read implements net.Conn.Read.
+//
+// Bytes read from the endpoint but not consumed by the caller are buffered
+// in c.read and returned by subsequent calls before reading again.
+func (c *Conn) Read(b []byte) (int, error) {
+	c.readMu.Lock()
+	defer c.readMu.Unlock()
+
+	deadline := c.readCancel()
+
+	// Check if deadline has already expired.
+	select {
+	case <-deadline:
+		return 0, c.newOpError("read", &timeoutError{})
+	default:
+	}
+
+	if len(c.read) == 0 {
+		var err error
+		c.read, err = commonRead(c.ep, c.wq, deadline, nil, c)
+		if err != nil {
+			return 0, err
+		}
+	}
+
+	n := copy(b, c.read)
+	c.read.TrimFront(n)
+	if len(c.read) == 0 {
+		// Drop the backing buffer once fully consumed.
+		c.read = nil
+	}
+	return n, nil
+}
+
+// Write implements net.Conn.Write.
+//
+// It keeps writing until all of b has been accepted by the endpoint, the
+// write deadline expires, or a hard error occurs; on early return the number
+// of bytes already written is reported.
+func (c *Conn) Write(b []byte) (int, error) {
+	deadline := c.writeCancel()
+
+	// Check if deadlineTimer has already expired.
+	select {
+	case <-deadline:
+		return 0, c.newOpError("write", &timeoutError{})
+	default:
+	}
+
+	v := buffer.NewView(len(b))
+	copy(v, b)
+
+	// We must handle two soft failure conditions simultaneously:
+	//  1. Write may write nothing and return tcpip.ErrWouldBlock.
+	//     If this happens, we need to register for notifications if we have
+	//     not already and wait to try again.
+	//  2. Write may write fewer than the full number of bytes and return
+	//     without error. In this case we need to try writing the remaining
+	//     bytes again. We do not need to register for notifications.
+	//
+	// What is more, these two soft failure conditions can be interspersed.
+	// There is no guarantee that all of the condition #1s will occur before
+	// all of the condition #2s or visa-versa.
+	var (
+		err      *tcpip.Error
+		nbytes   int
+		reg      bool
+		notifyCh chan struct{}
+	)
+	for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) {
+		if err == tcpip.ErrWouldBlock {
+			if !reg {
+				// Only register once.
+				reg = true
+
+				// Create wait queue entry that notifies a channel.
+				var waitEntry waiter.Entry
+				waitEntry, notifyCh = waiter.NewChannelEntry(nil)
+				c.wq.EventRegister(&waitEntry, waiter.EventOut)
+				defer c.wq.EventUnregister(&waitEntry)
+			} else {
+				// Don't wait immediately after registration in case more data
+				// became available between when we last checked and when we setup
+				// the notification.
+				select {
+				case <-deadline:
+					return nbytes, c.newOpError("write", &timeoutError{})
+				case <-notifyCh:
+				}
+			}
+		}
+
+		var n uintptr
+		n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+		nbytes += int(n)
+		// Trim the bytes the endpoint accepted; retry with the rest.
+		v.TrimFront(int(n))
+	}
+
+	if err == nil {
+		return nbytes, nil
+	}
+
+	return nbytes, c.newOpError("write", errors.New(err.String()))
+}
+
+// Close implements net.Conn.Close.
+func (c *Conn) Close() error {
+	c.ep.Close()
+	return nil
+}
+
+// LocalAddr implements net.Conn.LocalAddr.
+//
+// It returns nil if the address cannot be obtained from the endpoint.
+func (c *Conn) LocalAddr() net.Addr {
+	a, err := c.ep.GetLocalAddress()
+	if err != nil {
+		return nil
+	}
+	return fullToTCPAddr(a)
+}
+
+// RemoteAddr implements net.Conn.RemoteAddr.
+//
+// It returns nil if the address cannot be obtained from the endpoint.
+func (c *Conn) RemoteAddr() net.Addr {
+	a, err := c.ep.GetRemoteAddress()
+	if err != nil {
+		return nil
+	}
+	return fullToTCPAddr(a)
+}
+
+// newOpError wraps err in a *net.OpError annotated with this connection's
+// endpoints, matching the errors produced by the standard net package.
+func (c *Conn) newOpError(op string, err error) *net.OpError {
+	return &net.OpError{
+		Op:     op,
+		Net:    "tcp",
+		Source: c.LocalAddr(),
+		Addr:   c.RemoteAddr(),
+		Err:    err,
+	}
+}
+
+// fullToTCPAddr converts a tcpip.FullAddress to a *net.TCPAddr.
+func fullToTCPAddr(addr tcpip.FullAddress) *net.TCPAddr {
+	return &net.TCPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
+}
+
+// fullToUDPAddr converts a tcpip.FullAddress to a *net.UDPAddr.
+func fullToUDPAddr(addr tcpip.FullAddress) *net.UDPAddr {
+	return &net.UDPAddr{IP: net.IP(addr.Addr), Port: int(addr.Port)}
+}
+
+// DialTCP creates a new TCP Conn connected to the specified address.
+//
+// It blocks until the connection completes or fails; failures are reported
+// as *net.OpError.
+func DialTCP(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*Conn, error) {
+	// Create TCP endpoint, then connect.
+	var wq waiter.Queue
+	ep, err := s.NewEndpoint(tcp.ProtocolNumber, network, &wq)
+	if err != nil {
+		return nil, errors.New(err.String())
+	}
+
+	// Create wait queue entry that notifies a channel.
+	//
+	// We do this unconditionally as Connect will always return an error.
+	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&waitEntry, waiter.EventOut)
+	defer wq.EventUnregister(&waitEntry)
+
+	err = ep.Connect(addr)
+	if err == tcpip.ErrConnectStarted {
+		// Wait for the in-flight connect to finish, then fetch its result.
+		<-notifyCh
+		err = ep.GetSockOpt(tcpip.ErrorOption{})
+	}
+	if err != nil {
+		ep.Close()
+		return nil, &net.OpError{
+			Op:   "connect",
+			Net:  "tcp",
+			Addr: fullToTCPAddr(addr),
+			Err:  errors.New(err.String()),
+		}
+	}
+
+	return NewConn(&wq, ep), nil
+}
+
+// A PacketConn is a wrapper around a tcpip endpoint that implements
+// net.PacketConn.
+type PacketConn struct {
+	deadlineTimer
+
+	stack *stack.Stack
+	ep    tcpip.Endpoint
+	wq    *waiter.Queue
+}
+
+// NewPacketConn creates a new PacketConn.
+//
+// It creates a UDP endpoint on the given stack and binds it to addr for the
+// given network protocol; failures are reported as *net.OpError.
+func NewPacketConn(s *stack.Stack, addr tcpip.FullAddress, network tcpip.NetworkProtocolNumber) (*PacketConn, error) {
+	// Create UDP endpoint and bind it.
+	var wq waiter.Queue
+	ep, err := s.NewEndpoint(udp.ProtocolNumber, network, &wq)
+	if err != nil {
+		return nil, errors.New(err.String())
+	}
+
+	if err := ep.Bind(addr, nil); err != nil {
+		ep.Close()
+		return nil, &net.OpError{
+			Op:   "bind",
+			Net:  "udp",
+			Addr: fullToUDPAddr(addr),
+			Err:  errors.New(err.String()),
+		}
+	}
+
+	c := &PacketConn{
+		stack: s,
+		ep:    ep,
+		wq:    &wq,
+	}
+	c.deadlineTimer.init()
+	return c, nil
+}
+
+// newOpError builds a *net.OpError with no remote address; it satisfies the
+// opErrorer interface used by commonRead.
+func (c *PacketConn) newOpError(op string, err error) *net.OpError {
+	return c.newRemoteOpError(op, nil, err)
+}
+
+// newRemoteOpError builds a *net.OpError annotated with the local address
+// and the given remote address, matching the standard net package.
+func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *net.OpError {
+	return &net.OpError{
+		Op:     op,
+		Net:    "udp",
+		Source: c.LocalAddr(),
+		Addr:   remote,
+		Err:    err,
+	}
+}
+
+// ReadFrom implements net.PacketConn.ReadFrom.
+//
+// It blocks until a datagram arrives or the read deadline expires, then
+// copies as much of it as fits into b and reports the sender's address.
+func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
+	deadline := c.readCancel()
+
+	// Check if deadline has already expired.
+	select {
+	case <-deadline:
+		return 0, nil, c.newOpError("read", &timeoutError{})
+	default:
+	}
+
+	var addr tcpip.FullAddress
+	read, err := commonRead(c.ep, c.wq, deadline, &addr, c)
+	if err != nil {
+		return 0, nil, err
+	}
+
+	return copy(b, read), fullToUDPAddr(addr), nil
+}
+
+// WriteTo implements net.PacketConn.WriteTo.
+//
+// addr must be a *net.UDPAddr; it blocks (retrying on ErrWouldBlock) until
+// the datagram is sent, the write deadline expires, or a hard error occurs.
+func (c *PacketConn) WriteTo(b []byte, addr net.Addr) (int, error) {
+	deadline := c.writeCancel()
+
+	// Check if deadline has already expired.
+	select {
+	case <-deadline:
+		return 0, c.newRemoteOpError("write", addr, &timeoutError{})
+	default:
+	}
+
+	ua := addr.(*net.UDPAddr)
+	fullAddr := tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)}
+
+	v := buffer.NewView(len(b))
+	copy(v, b)
+
+	wopts := tcpip.WriteOptions{To: &fullAddr}
+	n, err := c.ep.Write(tcpip.SlicePayload(v), wopts)
+
+	if err == tcpip.ErrWouldBlock {
+		// Create wait queue entry that notifies a channel.
+		waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+		c.wq.EventRegister(&waitEntry, waiter.EventOut)
+		defer c.wq.EventUnregister(&waitEntry)
+		for {
+			n, err = c.ep.Write(tcpip.SlicePayload(v), wopts)
+			if err != tcpip.ErrWouldBlock {
+				break
+			}
+			select {
+			case <-deadline:
+				return int(n), c.newRemoteOpError("write", addr, &timeoutError{})
+			case <-notifyCh:
+			}
+		}
+	}
+
+	if err == nil {
+		return int(n), nil
+	}
+
+	return int(n), c.newRemoteOpError("write", addr, errors.New(err.String()))
+}
+
+// Close implements net.PacketConn.Close.
+func (c *PacketConn) Close() error {
+	c.ep.Close()
+	return nil
+}
+
+// LocalAddr implements net.PacketConn.LocalAddr.
+//
+// It returns nil if the address cannot be obtained from the endpoint.
+func (c *PacketConn) LocalAddr() net.Addr {
+	a, err := c.ep.GetLocalAddress()
+	if err != nil {
+		return nil
+	}
+	return fullToUDPAddr(a)
+}
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
new file mode 100644
index 000000000..2f86469eb
--- /dev/null
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -0,0 +1,424 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package gonet
+
+import (
+ "fmt"
+ "net"
+ "reflect"
+ "strings"
+ "testing"
+ "time"
+
+ "golang.org/x/net/nettest"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+	// NICID is the single NIC used by every test stack in this file.
+	NICID = 1
+)
+
+// TestTimeouts verifies that clearing a deadline (zero time.Time) succeeds
+// on all three deadline setters of Conn.
+func TestTimeouts(t *testing.T) {
+	nc := NewConn(nil, nil)
+	dlfs := []struct {
+		name string
+		f    func(time.Time) error
+	}{
+		{"SetDeadline", nc.SetDeadline},
+		{"SetReadDeadline", nc.SetReadDeadline},
+		{"SetWriteDeadline", nc.SetWriteDeadline},
+	}
+
+	for _, dlf := range dlfs {
+		if err := dlf.f(time.Time{}); err != nil {
+			t.Errorf("got %s(time.Time{}) = %v, want = %v", dlf.name, err, nil)
+		}
+	}
+}
+
+// newLoopbackStack creates a network stack with a single loopback NIC and
+// catch-all IPv4/IPv6 routes, for use by the tests in this file.
+func newLoopbackStack() (*stack.Stack, *tcpip.Error) {
+	// Create the stack and add a NIC.
+	s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName, udp.ProtocolName})
+
+	if err := s.CreateNIC(NICID, loopback.New()); err != nil {
+		return nil, err
+	}
+
+	// Add default route.
+	s.SetRouteTable([]tcpip.Route{
+		// IPv4
+		{
+			Destination: tcpip.Address(strings.Repeat("\x00", 4)),
+			Mask:        tcpip.Address(strings.Repeat("\x00", 4)),
+			Gateway:     "",
+			NIC:         NICID,
+		},
+
+		// IPv6
+		{
+			Destination: tcpip.Address(strings.Repeat("\x00", 16)),
+			Mask:        tcpip.Address(strings.Repeat("\x00", 16)),
+			Gateway:     "",
+			NIC:         NICID,
+		},
+	})
+
+	return s, nil
+}
+
+// testConnection bundles an endpoint with its wait queue, registered entry,
+// and notification channel so tests can drive a raw TCP connection.
+type testConnection struct {
+	wq *waiter.Queue
+	e  *waiter.Entry
+	ch chan struct{}
+	ep tcpip.Endpoint
+}
+
+// connect dials addr over TCP/IPv4 on s, blocking until the connect
+// completes, and returns the connection registered for EventIn.
+func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) {
+	wq := &waiter.Queue{}
+	ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+
+	entry, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&entry, waiter.EventOut)
+
+	err = ep.Connect(addr)
+	if err == tcpip.ErrConnectStarted {
+		// Wait for the in-flight connect to finish, then fetch its result.
+		<-ch
+		err = ep.GetSockOpt(tcpip.ErrorOption{})
+	}
+	if err != nil {
+		return nil, err
+	}
+
+	// Switch the registration from writability to readability for the tests.
+	wq.EventUnregister(&entry)
+	wq.EventRegister(&entry, waiter.EventIn)
+
+	return &testConnection{wq, &entry, ch, ep}, nil
+}
+
+// close unregisters the wait entry and closes the endpoint.
+func (c *testConnection) close() {
+	c.wq.EventUnregister(c.e)
+	c.ep.Close()
+}
+
+// TestCloseReader tests that Conn.Close() causes Conn.Read() to unblock.
+func TestCloseReader(t *testing.T) {
+	s, err := newLoopbackStack()
+	if err != nil {
+		t.Fatalf("newLoopbackStack() = %v", err)
+	}
+
+	addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+
+	s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+	l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+	if e != nil {
+		t.Fatalf("NewListener() = %v", e)
+	}
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		c, err := l.Accept()
+		if err != nil {
+			// Fatal must only be called from the test goroutine
+			// (testing.T docs); report the failure and bail out.
+			t.Errorf("l.Accept() = %v", err)
+			return
+		}
+
+		// Give c.Read() a chance to block before closing the connection.
+		time.AfterFunc(time.Millisecond*50, func() {
+			c.Close()
+		})
+
+		buf := make([]byte, 256)
+		n, err := c.Read(buf)
+		got, ok := err.(*net.OpError)
+		want := tcpip.ErrConnectionAborted
+		if n != 0 || !ok || got.Err.Error() != want.String() {
+			t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, err, want)
+		}
+	}()
+	sender, err := connect(s, addr)
+	if err != nil {
+		t.Fatalf("connect() = %v", err)
+	}
+
+	select {
+	case <-done:
+	case <-time.After(5 * time.Second):
+		t.Errorf("c.Read() didn't unblock")
+	}
+	sender.close()
+}
+
+// TestCloseReaderWithForwarder tests that Conn.Close() wakes Conn.Read() when
+// using tcp.Forwarder.
+func TestCloseReaderWithForwarder(t *testing.T) {
+	s, err := newLoopbackStack()
+	if err != nil {
+		t.Fatalf("newLoopbackStack() = %v", err)
+	}
+
+	addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+	s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+	done := make(chan struct{})
+
+	fwd := tcp.NewForwarder(s, 30000, 10, func(r *tcp.ForwarderRequest) {
+		defer close(done)
+
+		var wq waiter.Queue
+		ep, err := r.CreateEndpoint(&wq)
+		if err != nil {
+			// This callback runs on a stack goroutine, not the test
+			// goroutine, so Fatal is not allowed here (testing.T docs).
+			t.Errorf("r.CreateEndpoint() = %v", err)
+			return
+		}
+		defer ep.Close()
+		r.Complete(false)
+
+		c := NewConn(&wq, ep)
+
+		// Give c.Read() a chance to block before closing the connection.
+		time.AfterFunc(time.Millisecond*50, func() {
+			c.Close()
+		})
+
+		buf := make([]byte, 256)
+		n, e := c.Read(buf)
+		got, ok := e.(*net.OpError)
+		want := tcpip.ErrConnectionAborted
+		if n != 0 || !ok || got.Err.Error() != want.String() {
+			t.Errorf("c.Read() = (%d, %v), want (0, OpError(%v))", n, e, want)
+		}
+	})
+	s.SetTransportProtocolHandler(tcp.ProtocolNumber, fwd.HandlePacket)
+
+	sender, err := connect(s, addr)
+	if err != nil {
+		t.Fatalf("connect() = %v", err)
+	}
+
+	select {
+	case <-done:
+	case <-time.After(5 * time.Second):
+		t.Errorf("c.Read() didn't unblock")
+	}
+	sender.close()
+}
+
+// TestDeadlineChange tests that changing the deadline affects currently blocked reads.
+func TestDeadlineChange(t *testing.T) {
+	s, err := newLoopbackStack()
+	if err != nil {
+		t.Fatalf("newLoopbackStack() = %v", err)
+	}
+
+	addr := tcpip.FullAddress{NICID, tcpip.Address(net.IPv4(169, 254, 10, 1).To4()), 11211}
+
+	s.AddAddress(NICID, ipv4.ProtocolNumber, addr.Addr)
+
+	l, e := NewListener(s, addr, ipv4.ProtocolNumber)
+	if e != nil {
+		t.Fatalf("NewListener() = %v", e)
+	}
+	done := make(chan struct{})
+	go func() {
+		defer close(done)
+		c, err := l.Accept()
+		if err != nil {
+			// Fatal must only be called from the test goroutine
+			// (testing.T docs); report the failure and bail out.
+			t.Errorf("l.Accept() = %v", err)
+			return
+		}
+
+		c.SetDeadline(time.Now().Add(time.Minute))
+		// Give c.Read() a chance to block before closing the connection.
+		time.AfterFunc(time.Millisecond*50, func() {
+			c.SetDeadline(time.Now().Add(time.Millisecond * 10))
+		})
+
+		buf := make([]byte, 256)
+		n, err := c.Read(buf)
+		got, ok := err.(*net.OpError)
+		want := "i/o timeout"
+		if n != 0 || !ok || got.Err == nil || got.Err.Error() != want {
+			t.Errorf("c.Read() = (%d, %v), want (0, OpError(%s))", n, err, want)
+		}
+	}()
+	sender, err := connect(s, addr)
+	if err != nil {
+		t.Fatalf("connect() = %v", err)
+	}
+
+	select {
+	case <-done:
+	case <-time.After(time.Millisecond * 500):
+		t.Errorf("c.Read() didn't unblock")
+	}
+	sender.close()
+}
+
+// TestPacketConnTransfer sends a UDP datagram between two PacketConns on the
+// loopback stack and checks payload and reported sender address.
+func TestPacketConnTransfer(t *testing.T) {
+	s, e := newLoopbackStack()
+	if e != nil {
+		t.Fatalf("newLoopbackStack() = %v", e)
+	}
+
+	ip1 := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+	addr1 := tcpip.FullAddress{NICID, ip1, 11211}
+	s.AddAddress(NICID, ipv4.ProtocolNumber, ip1)
+	ip2 := tcpip.Address(net.IPv4(169, 254, 10, 2).To4())
+	addr2 := tcpip.FullAddress{NICID, ip2, 11311}
+	s.AddAddress(NICID, ipv4.ProtocolNumber, ip2)
+
+	c1, err := NewPacketConn(s, addr1, ipv4.ProtocolNumber)
+	if err != nil {
+		t.Fatal("NewPacketConn(port 4):", err)
+	}
+	c2, err := NewPacketConn(s, addr2, ipv4.ProtocolNumber)
+	if err != nil {
+		t.Fatal("NewPacketConn(port 5):", err)
+	}
+
+	// Bound the test with one-second deadlines on both sides.
+	c1.SetDeadline(time.Now().Add(time.Second))
+	c2.SetDeadline(time.Now().Add(time.Second))
+
+	sent := "abc123"
+	sendAddr := fullToUDPAddr(addr2)
+	if n, err := c1.WriteTo([]byte(sent), sendAddr); err != nil || n != len(sent) {
+		t.Errorf("got c1.WriteTo(%q, %v) = %d, %v, want = %d, %v", sent, sendAddr, n, err, len(sent), nil)
+	}
+	recv := make([]byte, len(sent))
+	n, recvAddr, err := c2.ReadFrom(recv)
+	if err != nil || n != len(recv) {
+		t.Errorf("got c2.ReadFrom() = %d, %v, want = %d, %v", n, err, len(recv), nil)
+	}
+
+	if recv := string(recv); recv != sent {
+		t.Errorf("got recv = %q, want = %q", recv, sent)
+	}
+
+	if want := fullToUDPAddr(addr1); !reflect.DeepEqual(recvAddr, want) {
+		t.Errorf("got recvAddr = %v, want = %v", recvAddr, want)
+	}
+
+	if err := c1.Close(); err != nil {
+		t.Error("c1.Close():", err)
+	}
+	if err := c2.Close(); err != nil {
+		t.Error("c2.Close():", err)
+	}
+}
+
+// makePipe builds a connected TCP pair over a fresh loopback stack. It
+// matches the nettest.MakePipe signature: c1 is the dialing side, c2 the
+// accepted side, and stop closes both.
+func makePipe() (c1, c2 net.Conn, stop func(), err error) {
+	s, e := newLoopbackStack()
+	if e != nil {
+		return nil, nil, nil, fmt.Errorf("newLoopbackStack() = %v", e)
+	}
+
+	ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+	addr := tcpip.FullAddress{NICID, ip, 11211}
+	s.AddAddress(NICID, ipv4.ProtocolNumber, ip)
+
+	l, err := NewListener(s, addr, ipv4.ProtocolNumber)
+	if err != nil {
+		return nil, nil, nil, fmt.Errorf("NewListener: %v", err)
+	}
+
+	c1, err = DialTCP(s, addr, ipv4.ProtocolNumber)
+	if err != nil {
+		l.Close()
+		return nil, nil, nil, fmt.Errorf("DialTCP: %v", err)
+	}
+
+	c2, err = l.Accept()
+	if err != nil {
+		l.Close()
+		c1.Close()
+		return nil, nil, nil, fmt.Errorf("l.Accept: %v", err)
+	}
+
+	stop = func() {
+		c1.Close()
+		c2.Close()
+	}
+
+	// The listener is no longer needed once both ends exist.
+	if err := l.Close(); err != nil {
+		stop()
+		return nil, nil, nil, fmt.Errorf("l.Close(): %v", err)
+	}
+
+	return c1, c2, stop, nil
+}
+
+// TestTCPConnTransfer sends data in both directions over a connected TCP
+// pair and verifies the payload round-trips intact.
+func TestTCPConnTransfer(t *testing.T) {
+	c1, c2, _, err := makePipe()
+	if err != nil {
+		t.Fatal(err)
+	}
+	defer func() {
+		if err := c1.Close(); err != nil {
+			t.Error("c1.Close():", err)
+		}
+		if err := c2.Close(); err != nil {
+			t.Error("c2.Close():", err)
+		}
+	}()
+
+	// Bound the test with one-second deadlines on both sides.
+	c1.SetDeadline(time.Now().Add(time.Second))
+	c2.SetDeadline(time.Now().Add(time.Second))
+
+	const sent = "abc123"
+
+	tests := []struct {
+		name string
+		c1   net.Conn
+		c2   net.Conn
+	}{
+		{"connected to accepted", c1, c2},
+		{"accepted to connected", c2, c1},
+	}
+
+	for _, test := range tests {
+		if n, err := test.c1.Write([]byte(sent)); err != nil || n != len(sent) {
+			t.Errorf("%s: got test.c1.Write(%q) = %d, %v, want = %d, %v", test.name, sent, n, err, len(sent), nil)
+			continue
+		}
+
+		recv := make([]byte, len(sent))
+		n, err := test.c2.Read(recv)
+		if err != nil || n != len(recv) {
+			t.Errorf("%s: got test.c2.Read() = %d, %v, want = %d, %v", test.name, n, err, len(recv), nil)
+			continue
+		}
+
+		if recv := string(recv); recv != sent {
+			t.Errorf("%s: got recv = %q, want = %q", test.name, recv, sent)
+		}
+	}
+}
+
+// TestTCPDialError verifies that dialing with no route (no address added to
+// the stack) fails with a *net.OpError wrapping ErrNoRoute.
+func TestTCPDialError(t *testing.T) {
+	s, e := newLoopbackStack()
+	if e != nil {
+		t.Fatalf("newLoopbackStack() = %v", e)
+	}
+
+	ip := tcpip.Address(net.IPv4(169, 254, 10, 1).To4())
+	addr := tcpip.FullAddress{NICID, ip, 11211}
+
+	_, err := DialTCP(s, addr, ipv4.ProtocolNumber)
+	got, ok := err.(*net.OpError)
+	want := tcpip.ErrNoRoute
+	if !ok || got.Err.Error() != want.String() {
+		t.Errorf("Got DialTCP() = %v, want = %v", err, tcpip.ErrNoRoute)
+	}
+}
+
+// TestNetTest runs the golang.org/x/net/nettest conformance suite against
+// Conn via makePipe.
+func TestNetTest(t *testing.T) {
+	nettest.TestConn(t, makePipe)
+}
diff --git a/pkg/tcpip/buffer/BUILD b/pkg/tcpip/buffer/BUILD
new file mode 100644
index 000000000..055e4b953
--- /dev/null
+++ b/pkg/tcpip/buffer/BUILD
@@ -0,0 +1,32 @@
package(licenses = ["notice"]) # BSD

load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
load("//tools/go_stateify:defs.bzl", "go_stateify")

# Generates save/restore (checkpoint) code for the types declared in view.go.
go_stateify(
    name = "buffer_state",
    srcs = [
        "view.go",
    ],
    out = "buffer_state.go",
    package = "buffer",
)

# Buffer views and prependable buffers used throughout the tcpip stack.
go_library(
    name = "buffer",
    srcs = [
        "buffer_state.go",
        "prependable.go",
        "view.go",
    ],
    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer",
    visibility = ["//visibility:public"],
    deps = ["//pkg/state"],
)

go_test(
    name = "buffer_test",
    size = "small",
    srcs = ["view_test.go"],
    embed = [":buffer"],
)
diff --git a/pkg/tcpip/buffer/prependable.go b/pkg/tcpip/buffer/prependable.go
new file mode 100644
index 000000000..fd84585f9
--- /dev/null
+++ b/pkg/tcpip/buffer/prependable.go
@@ -0,0 +1,53 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package buffer
+
+// Prependable is a buffer that grows backwards, that is, more data can be
+// prepended to it. It is useful when building networking packets, where each
+// protocol adds its own headers to the front of the higher-level protocol
+// header and payload; for example, TCP would prepend its header to the payload,
+// then IP would prepend its own, then ethernet.
+type Prependable struct {
+ // Buf is the buffer backing the prependable buffer.
+ buf View
+
+ // usedIdx is the index where the used part of the buffer begins.
+ usedIdx int
+}
+
+// NewPrependable allocates a new prependable buffer with the given size.
+func NewPrependable(size int) Prependable {
+ return Prependable{buf: NewView(size), usedIdx: size}
+}
+
+// Prepend reserves the requested space in front of the buffer, returning a
+// slice that represents the reserved space.
+func (p *Prependable) Prepend(size int) []byte {
+ if size > p.usedIdx {
+ return nil
+ }
+
+ p.usedIdx -= size
+ return p.buf[p.usedIdx:][:size:size]
+}
+
+// View returns a View of the backing buffer that contains all prepended
+// data so far.
+func (p *Prependable) View() View {
+ v := p.buf
+ v.TrimFront(p.usedIdx)
+ return v
+}
+
+// UsedBytes returns a slice of the backing buffer that contains all prepended
+// data so far.
+func (p *Prependable) UsedBytes() []byte {
+ return p.buf[p.usedIdx:]
+}
+
+// UsedLength returns the number of bytes used so far.
+func (p *Prependable) UsedLength() int {
+ return len(p.buf) - p.usedIdx
+}
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
new file mode 100644
index 000000000..241ccc7a8
--- /dev/null
+++ b/pkg/tcpip/buffer/view.go
@@ -0,0 +1,181 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package buffer provides the implementation of a buffer view.
+package buffer
+
// View is a slice of a buffer, with convenience methods.
type View []byte

// NewView allocates a new buffer and returns an initialized view that covers
// the whole buffer.
func NewView(size int) View {
	return make(View, size)
}

// NewViewFromBytes allocates a new buffer and copies in the given bytes.
func NewViewFromBytes(b []byte) View {
	return append(View(nil), b...)
}

// TrimFront removes the first "count" bytes from the visible section of the
// buffer.
func (v *View) TrimFront(count int) {
	*v = (*v)[count:]
}

// CapLength irreversibly reduces the length of the visible section of the
// buffer to the value specified.
func (v *View) CapLength(length int) {
	// We also set the slice cap because if we don't, one would be able to
	// expand the view back to include the region just excluded. We want to
	// prevent that to avoid potential data leak if we have uninitialized
	// data in excluded region.
	*v = (*v)[:length:length]
}

// ToVectorisedView transforms a View in a VectorisedView from an
// already-allocated slice of View.
func (v *View) ToVectorisedView(views [1]View) VectorisedView {
	views[0] = *v
	return NewVectorisedView(len(*v), views[:])
}

// VectorisedView is a vectorised version of View using non-contiguous memory.
// It supports all the convenience methods supported by View.
type VectorisedView struct {
	views []View
	size  int
}

// NewVectorisedView creates a new vectorised view from an already-allocated
// slice of View and sets its size.
func NewVectorisedView(size int, views []View) VectorisedView {
	return VectorisedView{views: views, size: size}
}

// TrimFront removes the first "count" bytes of the vectorised view.
func (vv *VectorisedView) TrimFront(count int) {
	for count > 0 && len(vv.views) > 0 {
		if count < len(vv.views[0]) {
			vv.size -= count
			vv.views[0].TrimFront(count)
			return
		}
		count -= len(vv.views[0])
		vv.RemoveFirst()
	}
}

// CapLength irreversibly reduces the length of the vectorised view.
// Negative lengths are treated as zero; lengths greater than the current
// size leave the view unchanged.
func (vv *VectorisedView) CapLength(length int) {
	if length < 0 {
		length = 0
	}
	if vv.size < length {
		return
	}
	vv.size = length
	for i := range vv.views {
		v := &vv.views[i]
		if len(*v) >= length {
			if length == 0 {
				vv.views = vv.views[:i]
			} else {
				v.CapLength(length)
				vv.views = vv.views[:i+1]
			}
			return
		}
		length -= len(*v)
	}
}

// Clone returns a clone of this VectorisedView.
// If the buffer argument is large enough to contain all the Views of this
// VectorisedView, the method will avoid allocations and use the buffer to
// store the Views of the clone.
func (vv *VectorisedView) Clone(buffer []View) VectorisedView {
	var views []View
	if len(buffer) >= len(vv.views) {
		views = buffer[:len(vv.views)]
	} else {
		views = make([]View, len(vv.views))
	}
	// This is a shallow copy: the cloned views share backing arrays with
	// the originals.
	copy(views, vv.views)
	return VectorisedView{views: views, size: vv.size}
}

// First returns the first view of the vectorised view.
// It returns nil if the vectorised view is empty.
func (vv *VectorisedView) First() View {
	if len(vv.views) == 0 {
		return nil
	}
	return vv.views[0]
}

// RemoveFirst removes the first view of the vectorised view.
func (vv *VectorisedView) RemoveFirst() {
	if len(vv.views) == 0 {
		return
	}
	vv.size -= len(vv.views[0])
	vv.views = vv.views[1:]
}

// SetSize unsafely sets the size of the VectorisedView.
func (vv *VectorisedView) SetSize(size int) {
	vv.size = size
}

// SetViews unsafely sets the views of the VectorisedView.
func (vv *VectorisedView) SetViews(views []View) {
	vv.views = views
}

// Size returns the size in bytes of the entire content stored in the
// vectorised view.
func (vv *VectorisedView) Size() int {
	return vv.size
}

// ToView returns a single view containing the content of the vectorised view.
func (vv *VectorisedView) ToView() View {
	v := make([]byte, vv.size)
	u := v
	for i := range vv.views {
		n := copy(u, vv.views[i])
		u = u[n:]
	}
	return v
}

// Views returns the slice containing all views.
func (vv *VectorisedView) Views() []View {
	return vv.views
}

// ByteSlice returns a slice containing all views as a []byte.
func (vv *VectorisedView) ByteSlice() [][]byte {
	s := make([][]byte, len(vv.views))
	for i := range vv.views {
		s[i] = []byte(vv.views[i])
	}
	return s
}

// copy returns a deep-copy of the vectorised view.
// It is an expensive method that should be used only in tests.
func (vv *VectorisedView) copy() *VectorisedView {
	uu := &VectorisedView{
		views: make([]View, len(vv.views)),
		size:  vv.size,
	}
	for i, v := range vv.views {
		uu.views[i] = make(View, len(v))
		copy(uu.views[i], v)
	}
	return uu
}
diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go
new file mode 100644
index 000000000..ff8535ba5
--- /dev/null
+++ b/pkg/tcpip/buffer/view_test.go
@@ -0,0 +1,212 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package buffer_test contains tests for the VectorisedView type.
+package buffer
+
+import (
+ "reflect"
+ "testing"
+)
+
+// vv is an helper to build VectorisedView from different strings.
+func vv(size int, pieces ...string) *VectorisedView {
+ views := make([]View, len(pieces))
+ for i, p := range pieces {
+ views[i] = []byte(p)
+ }
+
+ vv := NewVectorisedView(size, views)
+ return &vv
+}
+
// capLengthTestCases enumerates inputs and expected results for
// VectorisedView.CapLength, covering single- and multi-View inputs plus the
// negative, zero, exact-size and oversized corner cases.
var capLengthTestCases = []struct {
	comment string
	in      *VectorisedView
	length  int
	want    *VectorisedView
}{
	{
		comment: "Simple case",
		in:      vv(2, "12"),
		length:  1,
		want:    vv(1, "1"),
	},
	{
		comment: "Case spanning across two Views",
		in:      vv(4, "123", "4"),
		length:  2,
		want:    vv(2, "12"),
	},
	{
		comment: "Corner case with negative length",
		in:      vv(1, "1"),
		length:  -1,
		want:    vv(0),
	},
	{
		comment: "Corner case with length = 0",
		in:      vv(3, "12", "3"),
		length:  0,
		want:    vv(0),
	},
	{
		comment: "Corner case with length = size",
		in:      vv(1, "1"),
		length:  1,
		want:    vv(1, "1"),
	},
	{
		comment: "Corner case with length > size",
		in:      vv(1, "1"),
		length:  2,
		want:    vv(1, "1"),
	},
}
+
+func TestCapLength(t *testing.T) {
+ for _, c := range capLengthTestCases {
+ orig := c.in.copy()
+ c.in.CapLength(c.length)
+ if !reflect.DeepEqual(c.in, c.want) {
+ t.Errorf("Test \"%s\" failed when calling CapLength(%d) on %v. Got %v. Want %v",
+ c.comment, c.length, orig, c.in, c.want)
+ }
+ }
+}
+
// trimFrontTestCases enumerates inputs and expected results for
// VectorisedView.TrimFront, covering trims within a View, across Views, and
// the negative, zero, exact-size and oversized corner cases.
var trimFrontTestCases = []struct {
	comment string
	in      *VectorisedView
	count   int
	want    *VectorisedView
}{
	{
		comment: "Simple case",
		in:      vv(2, "12"),
		count:   1,
		want:    vv(1, "2"),
	},
	{
		comment: "Case where we trim an entire View",
		in:      vv(2, "1", "2"),
		count:   1,
		want:    vv(1, "2"),
	},
	{
		comment: "Case spanning across two Views",
		in:      vv(3, "1", "23"),
		count:   2,
		want:    vv(1, "3"),
	},
	{
		comment: "Corner case with negative count",
		in:      vv(1, "1"),
		count:   -1,
		want:    vv(1, "1"),
	},
	{
		comment: " Corner case with count = 0",
		in:      vv(1, "1"),
		count:   0,
		want:    vv(1, "1"),
	},
	{
		comment: "Corner case with count = size",
		in:      vv(1, "1"),
		count:   1,
		want:    vv(0),
	},
	{
		comment: "Corner case with count > size",
		in:      vv(1, "1"),
		count:   2,
		want:    vv(0),
	},
}
+
+func TestTrimFront(t *testing.T) {
+ for _, c := range trimFrontTestCases {
+ orig := c.in.copy()
+ c.in.TrimFront(c.count)
+ if !reflect.DeepEqual(c.in, c.want) {
+ t.Errorf("Test \"%s\" failed when calling TrimFront(%d) on %v. Got %v. Want %v",
+ c.comment, c.count, orig, c.in, c.want)
+ }
+ }
+}
+
// toViewCases enumerates inputs and expected flattened outputs for
// VectorisedView.ToView, including the empty view.
var toViewCases = []struct {
	comment string
	in      *VectorisedView
	want    View
}{
	{
		comment: "Simple case",
		in:      vv(2, "12"),
		want:    []byte("12"),
	},
	{
		comment: "Case with multiple views",
		in:      vv(2, "1", "2"),
		want:    []byte("12"),
	},
	{
		comment: "Empty case",
		in:      vv(0),
		want:    []byte(""),
	},
}
+
+func TestToView(t *testing.T) {
+ for _, c := range toViewCases {
+ got := c.in.ToView()
+ if !reflect.DeepEqual(got, c.want) {
+ t.Errorf("Test \"%s\" failed when calling ToView() on %v. Got %v. Want %v",
+ c.comment, c.in, got, c.want)
+ }
+ }
+}
+
// toCloneCases enumerates inputs for VectorisedView.Clone with caller-supplied
// buffers that are exactly right, too small, too large, and nil.
var toCloneCases = []struct {
	comment  string
	inView   *VectorisedView
	inBuffer []View
}{
	{
		comment:  "Simple case",
		inView:   vv(1, "1"),
		inBuffer: make([]View, 1),
	},
	{
		comment:  "Case with multiple views",
		inView:   vv(2, "1", "2"),
		inBuffer: make([]View, 2),
	},
	{
		comment:  "Case with buffer too small",
		inView:   vv(2, "1", "2"),
		inBuffer: make([]View, 1),
	},
	{
		comment:  "Case with buffer larger than needed",
		inView:   vv(1, "1"),
		inBuffer: make([]View, 2),
	},
	{
		comment:  "Case with nil buffer",
		inView:   vv(1, "1"),
		inBuffer: nil,
	},
}
+
+func TestToClone(t *testing.T) {
+ for _, c := range toCloneCases {
+ got := c.inView.Clone(c.inBuffer)
+ if !reflect.DeepEqual(&got, c.inView) {
+ t.Errorf("Test \"%s\" failed when calling Clone(%v) on %v. Got %v. Want %v",
+ c.comment, c.inBuffer, c.inView, got, c.inView)
+ }
+ }
+}
diff --git a/pkg/tcpip/checker/BUILD b/pkg/tcpip/checker/BUILD
new file mode 100644
index 000000000..ac5203031
--- /dev/null
+++ b/pkg/tcpip/checker/BUILD
@@ -0,0 +1,16 @@
package(licenses = ["notice"]) # BSD

load("@io_bazel_rules_go//go:def.bzl", "go_library")

# Test-only helpers for validating network/transport packet contents,
# shared by the tcpip protocol tests.
go_library(
    name = "checker",
    testonly = 1,
    srcs = ["checker.go"],
    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/checker",
    visibility = ["//visibility:public"],
    deps = [
        "//pkg/tcpip",
        "//pkg/tcpip/header",
        "//pkg/tcpip/seqnum",
    ],
)
diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go
new file mode 100644
index 000000000..209f9d60b
--- /dev/null
+++ b/pkg/tcpip/checker/checker.go
@@ -0,0 +1,517 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package checker provides helper functions to check networking packets for
+// validity.
+package checker
+
+import (
+ "encoding/binary"
+ "reflect"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
// NetworkChecker is a function to check a property of a network packet.
// The slice holds the outermost header first; nested headers (e.g. the
// fragment produced by IPv6Fragment) are appended after it.
type NetworkChecker func(*testing.T, []header.Network)

// TransportChecker is a function to check a property of a transport packet.
type TransportChecker func(*testing.T, header.Transport)
+
// IPv4 checks the validity and properties of the given IPv4 packet. It is
// expected to be used in conjunction with other network checkers for specific
// properties. For example, to check the source and destination address, one
// would call:
//
// checker.IPv4(t, b, checker.SrcAddr(x), checker.DstAddr(y))
func IPv4(t *testing.T, b []byte, checkers ...NetworkChecker) {
	ipv4 := header.IPv4(b)

	if !ipv4.IsValid(len(b)) {
		t.Fatalf("Not a valid IPv4 packet")
	}

	// Recomputing the checksum over a header with a correct checksum field
	// folds to 0 or 0xffff.
	xsum := ipv4.CalculateChecksum()
	if xsum != 0 && xsum != 0xffff {
		t.Fatalf("Bad checksum: 0x%x, checksum in packet: 0x%x", xsum, ipv4.Checksum())
	}

	for _, f := range checkers {
		f(t, []header.Network{ipv4})
	}
}
+
// IPv6 checks the validity and properties of the given IPv6 packet. The usage
// is similar to IPv4. Note that IPv6 has no header checksum to verify.
func IPv6(t *testing.T, b []byte, checkers ...NetworkChecker) {
	ipv6 := header.IPv6(b)
	if !ipv6.IsValid(len(b)) {
		t.Fatalf("Not a valid IPv6 packet")
	}

	for _, f := range checkers {
		f(t, []header.Network{ipv6})
	}
}
+
+// SrcAddr creates a checker that checks the source address.
+func SrcAddr(addr tcpip.Address) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ if a := h[0].SourceAddress(); a != addr {
+ t.Fatalf("Bad source address, got %v, want %v", a, addr)
+ }
+ }
+}
+
+// DstAddr creates a checker that checks the destination address.
+func DstAddr(addr tcpip.Address) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ if a := h[0].DestinationAddress(); a != addr {
+ t.Fatalf("Bad destination address, got %v, want %v", a, addr)
+ }
+ }
+}
+
+// PayloadLen creates a checker that checks the payload length.
+func PayloadLen(plen int) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ if l := len(h[0].Payload()); l != plen {
+ t.Fatalf("Bad payload length, got %v, want %v", l, plen)
+ }
+ }
+}
+
// FragmentOffset creates a checker that checks the FragmentOffset field.
func FragmentOffset(offset uint16) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		// We only do this for IPv4 for now; other network headers are
		// silently ignored.
		switch ip := h[0].(type) {
		case header.IPv4:
			if v := ip.FragmentOffset(); v != offset {
				t.Fatalf("Bad fragment offset, got %v, want %v", v, offset)
			}
		}
	}
}
+
+// FragmentFlags creates a checker that checks the fragment flags field.
+func FragmentFlags(flags uint8) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ // We only do this of IPv4 for now.
+ switch ip := h[0].(type) {
+ case header.IPv4:
+ if v := ip.Flags(); v != flags {
+ t.Fatalf("Bad fragment offset, got %v, want %v", v, flags)
+ }
+ }
+ }
+}
+
+// TOS creates a checker that checks the TOS field.
+func TOS(tos uint8, label uint32) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ if v, l := h[0].TOS(); v != tos || l != label {
+ t.Fatalf("Bad TOS, got (%v, %v), want (%v,%v)", v, l, tos, label)
+ }
+ }
+}
+
+// Raw creates a checker that checks the bytes of payload.
+// The checker always checks the payload of the last network header.
+// For instance, in case of IPv6 fragments, the payload that will be checked
+// is the one containing the actual data that the packet is carrying, without
+// the bytes added by the IPv6 fragmentation.
+func Raw(want []byte) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ if got := h[len(h)-1].Payload(); !reflect.DeepEqual(got, want) {
+ t.Fatalf("Wrong payload, got %v, want %v", got, want)
+ }
+ }
+}
+
+// IPv6Fragment creates a checker that validates an IPv6 fragment.
+func IPv6Fragment(checkers ...NetworkChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ if p := h[0].TransportProtocol(); p != header.IPv6FragmentHeader {
+ t.Fatalf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ }
+
+ ipv6Frag := header.IPv6Fragment(h[0].Payload())
+ if !ipv6Frag.IsValid() {
+ t.Fatalf("Not a valid IPv6 fragment")
+ }
+
+ for _, f := range checkers {
+ f(t, []header.Network{h[0], ipv6Frag})
+ }
+ }
+}
+
// TCP creates a checker that checks that the transport protocol is TCP and
// potentially additional transport header fields.
func TCP(checkers ...TransportChecker) NetworkChecker {
	return func(t *testing.T, h []header.Network) {
		first := h[0]
		last := h[len(h)-1]

		if p := last.TransportProtocol(); p != header.TCPProtocolNumber {
			t.Fatalf("Bad protocol, got %v, want %v", p, header.TCPProtocolNumber)
		}

		// Verify the checksum: fold the pseudo-header (source address,
		// destination address, zero-padded protocol, TCP length) and then
		// the segment itself.
		tcp := header.TCP(last.Payload())
		l := uint16(len(tcp))

		xsum := header.Checksum([]byte(first.SourceAddress()), 0)
		xsum = header.Checksum([]byte(first.DestinationAddress()), xsum)
		xsum = header.Checksum([]byte{0, byte(last.TransportProtocol())}, xsum)
		xsum = header.Checksum([]byte{byte(l >> 8), byte(l)}, xsum)
		xsum = header.Checksum(tcp, xsum)

		// A correct checksum folds to 0 or 0xffff.
		if xsum != 0 && xsum != 0xffff {
			t.Fatalf("Bad checksum: 0x%x, checksum in segment: 0x%x", xsum, tcp.Checksum())
		}

		// Run the transport checkers.
		for _, f := range checkers {
			f(t, tcp)
		}
	}
}
+
+// UDP creates a checker that checks that the transport protocol is UDP and
+// potentially additional transport header fields.
+func UDP(checkers ...TransportChecker) NetworkChecker {
+ return func(t *testing.T, h []header.Network) {
+ last := h[len(h)-1]
+
+ if p := last.TransportProtocol(); p != header.UDPProtocolNumber {
+ t.Fatalf("Bad protocol, got %v, want %v", p, header.UDPProtocolNumber)
+ }
+
+ udp := header.UDP(last.Payload())
+ for _, f := range checkers {
+ f(t, udp)
+ }
+ }
+}
+
+// SrcPort creates a checker that checks the source port.
+func SrcPort(port uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ if p := h.SourcePort(); p != port {
+ t.Fatalf("Bad source port, got %v, want %v", p, port)
+ }
+ }
+}
+
+// DstPort creates a checker that checks the destination port.
+func DstPort(port uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ if p := h.DestinationPort(); p != port {
+ t.Fatalf("Bad destination port, got %v, want %v", p, port)
+ }
+ }
+}
+
+// SeqNum creates a checker that checks the sequence number.
+func SeqNum(seq uint32) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+
+ if s := tcp.SequenceNumber(); s != seq {
+ t.Fatalf("Bad sequence number, got %v, want %v", s, seq)
+ }
+ }
+}
+
// AckNum creates a checker that checks the ack number.
// Non-TCP transports are silently ignored.
func AckNum(seq uint32) TransportChecker {
	return func(t *testing.T, h header.Transport) {
		// Attribute failures to the caller rather than this closure.
		t.Helper()
		tcp, ok := h.(header.TCP)
		if !ok {
			return
		}

		if s := tcp.AckNumber(); s != seq {
			t.Fatalf("Bad ack number, got %v, want %v", s, seq)
		}
	}
}
+
+// Window creates a checker that checks the tcp window.
+func Window(window uint16) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+
+ if w := tcp.WindowSize(); w != window {
+ t.Fatalf("Bad window, got 0x%x, want 0x%x", w, window)
+ }
+ }
+}
+
+// TCPFlags creates a checker that checks the tcp flags.
+func TCPFlags(flags uint8) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+
+ if f := tcp.Flags(); f != flags {
+ t.Fatalf("Bad flags, got 0x%x, want 0x%x", f, flags)
+ }
+ }
+}
+
+// TCPFlagsMatch creates a checker that checks that the tcp flags, masked by the
+// given mask, match the supplied flags.
+func TCPFlagsMatch(flags, mask uint8) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+
+ if f := tcp.Flags(); (f & mask) != (flags & mask) {
+ t.Fatalf("Bad masked flags, got 0x%x, want 0x%x, mask 0x%x", f, flags, mask)
+ }
+ }
+}
+
+// TCPSynOptions creates a checker that checks the presence of TCP options in
+// SYN segments.
+//
+// If wndscale is negative, the window scale option must not be present.
+func TCPSynOptions(wantOpts header.TCPSynOptions) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+ opts := tcp.Options()
+ limit := len(opts)
+ foundMSS := false
+ foundWS := false
+ foundTS := false
+ foundSACKPermitted := false
+ tsVal := uint32(0)
+ tsEcr := uint32(0)
+ for i := 0; i < limit; {
+ switch opts[i] {
+ case header.TCPOptionEOL:
+ i = limit
+ case header.TCPOptionNOP:
+ i++
+ case header.TCPOptionMSS:
+ v := uint16(opts[i+2])<<8 | uint16(opts[i+3])
+ if wantOpts.MSS != v {
+ t.Fatalf("Bad MSS: got %v, want %v", v, wantOpts.MSS)
+ }
+ foundMSS = true
+ i += 4
+ case header.TCPOptionWS:
+ if wantOpts.WS < 0 {
+ t.Fatalf("WS present when it shouldn't be")
+ }
+ v := int(opts[i+2])
+ if v != wantOpts.WS {
+ t.Fatalf("Bad WS: got %v, want %v", v, wantOpts.WS)
+ }
+ foundWS = true
+ i += 3
+ case header.TCPOptionTS:
+ if i+9 >= limit {
+ t.Fatalf("TS Option truncated , option is only: %d bytes, want 10", limit-i)
+ }
+ if opts[i+1] != 10 {
+ t.Fatalf("Bad length %d for TS option, limit: %d", opts[i+1], limit)
+ }
+ tsVal = binary.BigEndian.Uint32(opts[i+2:])
+ tsEcr = uint32(0)
+ if tcp.Flags()&header.TCPFlagAck != 0 {
+ // If the syn is an SYN-ACK then read
+ // the tsEcr value as well.
+ tsEcr = binary.BigEndian.Uint32(opts[i+6:])
+ }
+ foundTS = true
+ i += 10
+ case header.TCPOptionSACKPermitted:
+ if i+1 >= limit {
+ t.Fatalf("SACKPermitted option truncated, option is only : %d bytes, want 2", limit-i)
+ }
+ if opts[i+1] != 2 {
+ t.Fatalf("Bad length %d for SACKPermitted option, limit: %d", opts[i+1], limit)
+ }
+ foundSACKPermitted = true
+ i += 2
+
+ default:
+ i += int(opts[i+1])
+ }
+ }
+
+ if !foundMSS {
+ t.Fatalf("MSS option not found. Options: %x", opts)
+ }
+
+ if !foundWS && wantOpts.WS >= 0 {
+ t.Fatalf("WS option not found. Options: %x", opts)
+ }
+ if wantOpts.TS && !foundTS {
+ t.Fatalf("TS option not found. Options: %x", opts)
+ }
+ if foundTS && tsVal == 0 {
+ t.Fatalf("TS option specified but the timestamp value is zero")
+ }
+ if foundTS && tsEcr == 0 && wantOpts.TSEcr != 0 {
+ t.Fatalf("TS option specified but TSEcr is incorrect: got %d, want: %d", tsEcr, wantOpts.TSEcr)
+ }
+ if wantOpts.SACKPermitted && !foundSACKPermitted {
+ t.Fatalf("SACKPermitted option not found. Options: %x", opts)
+ }
+ }
+}
+
+// TCPTimestampChecker creates a checker that validates that a TCP segment has a
+// TCP Timestamp option if wantTS is true, it also compares the wantTSVal and
+// wantTSEcr values with those in the TCP segment (if present).
+//
+// If wantTSVal or wantTSEcr is zero then the corresponding comparison is
+// skipped.
+func TCPTimestampChecker(wantTS bool, wantTSVal uint32, wantTSEcr uint32) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+ opts := []byte(tcp.Options())
+ limit := len(opts)
+ foundTS := false
+ tsVal := uint32(0)
+ tsEcr := uint32(0)
+ for i := 0; i < limit; {
+ switch opts[i] {
+ case header.TCPOptionEOL:
+ i = limit
+ case header.TCPOptionNOP:
+ i++
+ case header.TCPOptionTS:
+ if i+9 >= limit {
+ t.Fatalf("TS option found, but option is truncated, option length: %d, want 10 bytes", limit-i)
+ }
+ if opts[i+1] != 10 {
+ t.Fatalf("TS option found, but bad length specified: %d, want: 10", opts[i+1])
+ }
+ tsVal = binary.BigEndian.Uint32(opts[i+2:])
+ tsEcr = binary.BigEndian.Uint32(opts[i+6:])
+ foundTS = true
+ i += 10
+ default:
+ // We don't recognize this option, just skip over it.
+ if i+2 > limit {
+ return
+ }
+ l := int(opts[i+1])
+ if i < 2 || i+l > limit {
+ return
+ }
+ i += l
+ }
+ }
+
+ if wantTS != foundTS {
+ t.Fatalf("TS Option mismatch: got TS= %v, want TS= %v", foundTS, wantTS)
+ }
+ if wantTS && wantTSVal != 0 && wantTSVal != tsVal {
+ t.Fatalf("Timestamp value is incorrect: got: %d, want: %d", tsVal, wantTSVal)
+ }
+ if wantTS && wantTSEcr != 0 && tsEcr != wantTSEcr {
+ t.Fatalf("Timestamp Echo Reply is incorrect: got: %d, want: %d", tsEcr, wantTSEcr)
+ }
+ }
+}
+
// TCPNoSACKBlockChecker creates a checker that verifies that the segment does
// not contain any SACK blocks in the TCP options. It is a convenience wrapper
// around TCPSACKBlockChecker with a nil expected list.
func TCPNoSACKBlockChecker() TransportChecker {
	return TCPSACKBlockChecker(nil)
}
+
+// TCPSACKBlockChecker creates a checker that verifies that the segment does
+// contain the specified SACK blocks in the TCP options.
+func TCPSACKBlockChecker(sackBlocks []header.SACKBlock) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ t.Helper()
+ tcp, ok := h.(header.TCP)
+ if !ok {
+ return
+ }
+ var gotSACKBlocks []header.SACKBlock
+
+ opts := []byte(tcp.Options())
+ limit := len(opts)
+ for i := 0; i < limit; {
+ switch opts[i] {
+ case header.TCPOptionEOL:
+ i = limit
+ case header.TCPOptionNOP:
+ i++
+ case header.TCPOptionSACK:
+ if i+2 > limit {
+ // Malformed SACK block.
+ t.Fatalf("malformed SACK option in options: %v", opts)
+ }
+ sackOptionLen := int(opts[i+1])
+ if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
+ // Malformed SACK block.
+ t.Fatalf("malformed SACK option length in options: %v", opts)
+ }
+ numBlocks := sackOptionLen / 8
+ for j := 0; j < numBlocks; j++ {
+ start := binary.BigEndian.Uint32(opts[i+2+j*8:])
+ end := binary.BigEndian.Uint32(opts[i+2+j*8+4:])
+ gotSACKBlocks = append(gotSACKBlocks, header.SACKBlock{
+ Start: seqnum.Value(start),
+ End: seqnum.Value(end),
+ })
+ }
+ i += sackOptionLen
+ default:
+ // We don't recognize this option, just skip over it.
+ if i+2 > limit {
+ break
+ }
+ l := int(opts[i+1])
+ if l < 2 || i+l > limit {
+ break
+ }
+ i += l
+ }
+ }
+
+ if !reflect.DeepEqual(gotSACKBlocks, sackBlocks) {
+ t.Fatalf("SACKBlocks are not equal, got: %v, want: %v", gotSACKBlocks, sackBlocks)
+ }
+ }
+}
+
+// Payload creates a checker that checks the payload.
+func Payload(want []byte) TransportChecker {
+ return func(t *testing.T, h header.Transport) {
+ if got := h.Payload(); !reflect.DeepEqual(got, want) {
+ t.Fatalf("Wrong payload, got %v, want %v", got, want)
+ }
+ }
+}
diff --git a/pkg/tcpip/header/BUILD b/pkg/tcpip/header/BUILD
new file mode 100644
index 000000000..167ea250d
--- /dev/null
+++ b/pkg/tcpip/header/BUILD
@@ -0,0 +1,51 @@
package(licenses = ["notice"]) # BSD

load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
load("//tools/go_stateify:defs.bzl", "go_stateify")

# Generates save/restore (checkpoint) code for the types declared in tcp.go.
go_stateify(
    name = "tcp_header_state",
    srcs = [
        "tcp.go",
    ],
    out = "tcp_header_state.go",
    package = "header",
)

# Encoding/decoding of network protocol headers (ARP, IP, ICMP, TCP, UDP, ...).
go_library(
    name = "header",
    srcs = [
        "arp.go",
        "checksum.go",
        "eth.go",
        "gue.go",
        "icmpv4.go",
        "icmpv6.go",
        "interfaces.go",
        "ipv4.go",
        "ipv6.go",
        "ipv6_fragment.go",
        "tcp.go",
        "tcp_header_state.go",
        "udp.go",
    ],
    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/header",
    visibility = ["//visibility:public"],
    deps = [
        "//pkg/state",
        "//pkg/tcpip",
        "//pkg/tcpip/seqnum",
    ],
)

go_test(
    name = "header_test",
    size = "small",
    srcs = [
        "ipversion_test.go",
        "tcp_test.go",
    ],
    deps = [
        ":header",
    ],
)
diff --git a/pkg/tcpip/header/arp.go b/pkg/tcpip/header/arp.go
new file mode 100644
index 000000000..af7f988f3
--- /dev/null
+++ b/pkg/tcpip/header/arp.go
@@ -0,0 +1,90 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import "gvisor.googlesource.com/gvisor/pkg/tcpip"
+
const (
	// ARPProtocolNumber is the ARP network protocol number.
	ARPProtocolNumber tcpip.NetworkProtocolNumber = 0x0806

	// ARPSize is the size of an IPv4-over-Ethernet ARP packet.
	// Layout: htype(2) + ptype(2) + hlen(1) + plen(1) + op(2) +
	// two 6-byte MAC addresses + two 4-byte IPv4 addresses.
	ARPSize = 2 + 2 + 1 + 1 + 2 + 2*6 + 2*4
)

// ARPOp is an ARP opcode.
type ARPOp uint16

// Typical ARP opcodes defined in RFC 826.
const (
	ARPRequest ARPOp = 1
	ARPReply   ARPOp = 2
)
+
// ARP is an ARP packet stored in a byte array as described in RFC 826.
type ARP []byte

// Accessors for the fixed ARP header. Byte offsets: 0-1 hardware address
// space, 2-3 protocol address space, 4 hardware address size, 5 protocol
// address size, 6-7 opcode. Multi-byte fields are big-endian.
func (a ARP) hardwareAddressSpace() uint16 { return uint16(a[0])<<8 | uint16(a[1]) }
func (a ARP) protocolAddressSpace() uint16 { return uint16(a[2])<<8 | uint16(a[3]) }
func (a ARP) hardwareAddressSize() int     { return int(a[4]) }
func (a ARP) protocolAddressSize() int     { return int(a[5]) }

// Op is the ARP opcode.
func (a ARP) Op() ARPOp { return ARPOp(a[6])<<8 | ARPOp(a[7]) }

// SetOp sets the ARP opcode.
func (a ARP) SetOp(op ARPOp) {
	a[6] = uint8(op >> 8)
	a[7] = uint8(op)
}

// SetIPv4OverEthernet configures the ARP packet for IPv4-over-Ethernet.
func (a ARP) SetIPv4OverEthernet() {
	a[0], a[1] = 0, 1 // htypeEthernet
	a[2], a[3] = 0x08, 0x00 // IPv4ProtocolNumber
	a[4] = 6 // macSize
	a[5] = uint8(IPv4AddressSize)
}

// HardwareAddressSender is the link address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressSender() []byte {
	const s = 8
	return a[s : s+6]
}

// ProtocolAddressSender is the protocol address of the sender.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressSender() []byte {
	const s = 8 + 6
	return a[s : s+4]
}

// HardwareAddressTarget is the link address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) HardwareAddressTarget() []byte {
	const s = 8 + 6 + 4
	return a[s : s+6]
}

// ProtocolAddressTarget is the protocol address of the target.
// It is a view on to the ARP packet so it can be used to set the value.
func (a ARP) ProtocolAddressTarget() []byte {
	const s = 8 + 6 + 4 + 6
	return a[s : s+4]
}

// IsValid reports whether this is an ARP packet for IPv4 over Ethernet.
// It checks the length and the four hardware/protocol descriptor fields.
func (a ARP) IsValid() bool {
	if len(a) < ARPSize {
		return false
	}
	const htypeEthernet = 1
	const macSize = 6
	return a.hardwareAddressSpace() == htypeEthernet &&
		a.protocolAddressSpace() == uint16(IPv4ProtocolNumber) &&
		a.hardwareAddressSize() == macSize &&
		a.protocolAddressSize() == IPv4AddressSize
}
diff --git a/pkg/tcpip/header/checksum.go b/pkg/tcpip/header/checksum.go
new file mode 100644
index 000000000..6399b1b95
--- /dev/null
+++ b/pkg/tcpip/header/checksum.go
@@ -0,0 +1,46 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package header provides the implementation of the encoding and decoding of
+// network protocol headers.
+package header
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
// Checksum calculates the checksum (as defined in RFC 1071) of the bytes
// in the given byte array, folding the given initial value into the sum.
func Checksum(buf []byte, initial uint16) uint16 {
	sum := uint32(initial)

	// Sum all complete 16-bit words, big endian.
	n := len(buf) &^ 1
	for i := 0; i < n; i += 2 {
		sum += uint32(buf[i])<<8 | uint32(buf[i+1])
	}

	// An odd trailing byte is treated as the high byte of a final word
	// padded with zero.
	if len(buf)&1 != 0 {
		sum += uint32(buf[n]) << 8
	}

	return ChecksumCombine(uint16(sum), uint16(sum>>16))
}

// ChecksumCombine combines the two uint16 to form their checksum. This is
// done by adding them and folding the carry back in (one's-complement
// addition).
func ChecksumCombine(a, b uint16) uint16 {
	sum := uint32(a) + uint32(b)
	sum += sum >> 16
	return uint16(sum)
}
+
+// PseudoHeaderChecksum calculates the pseudo-header checksum for the
+// given destination protocol and network address, ignoring the length
+// field. Pseudo-headers are needed by transport layers when calculating
+// their own checksum.
+func PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, srcAddr tcpip.Address, dstAddr tcpip.Address) uint16 {
+ xsum := Checksum([]byte(srcAddr), 0)
+ xsum = Checksum([]byte(dstAddr), xsum)
+ return Checksum([]byte{0, uint8(protocol)}, xsum)
+}
diff --git a/pkg/tcpip/header/eth.go b/pkg/tcpip/header/eth.go
new file mode 100644
index 000000000..23b7efdfc
--- /dev/null
+++ b/pkg/tcpip/header/eth.go
@@ -0,0 +1,64 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
const (
	// Byte offsets of the fields within an ethernet frame header.
	dstMAC  = 0
	srcMAC  = 6
	ethType = 12
)

// EthernetFields contains the fields of an ethernet frame header. It is used to
// describe the fields of a frame that needs to be encoded.
type EthernetFields struct {
	// SrcAddr is the "MAC source" field of an ethernet frame header.
	SrcAddr tcpip.LinkAddress

	// DstAddr is the "MAC destination" field of an ethernet frame header.
	DstAddr tcpip.LinkAddress

	// Type is the "ethertype" field of an ethernet frame header.
	Type tcpip.NetworkProtocolNumber
}

// Ethernet represents an ethernet frame header stored in a byte array.
type Ethernet []byte

const (
	// EthernetMinimumSize is the minimum size of a valid ethernet frame
	// header: two 6-byte addresses plus the 2-byte ethertype.
	EthernetMinimumSize = 14

	// EthernetAddressSize is the size, in bytes, of an ethernet address.
	EthernetAddressSize = 6
)
+
+// SourceAddress returns the "MAC source" field of the ethernet frame header.
+func (b Ethernet) SourceAddress() tcpip.LinkAddress {
+ return tcpip.LinkAddress(b[srcMAC:][:EthernetAddressSize])
+}
+
+// DestinationAddress returns the "MAC destination" field of the ethernet frame
+// header.
+func (b Ethernet) DestinationAddress() tcpip.LinkAddress {
+ return tcpip.LinkAddress(b[dstMAC:][:EthernetAddressSize])
+}
+
+// Type returns the "ethertype" field of the ethernet frame header.
+func (b Ethernet) Type() tcpip.NetworkProtocolNumber {
+ return tcpip.NetworkProtocolNumber(binary.BigEndian.Uint16(b[ethType:]))
+}
+
+// Encode encodes all the fields of the ethernet frame header.
+func (b Ethernet) Encode(e *EthernetFields) {
+ binary.BigEndian.PutUint16(b[ethType:], uint16(e.Type))
+ copy(b[srcMAC:][:EthernetAddressSize], e.SrcAddr)
+ copy(b[dstMAC:][:EthernetAddressSize], e.DstAddr)
+}
diff --git a/pkg/tcpip/header/gue.go b/pkg/tcpip/header/gue.go
new file mode 100644
index 000000000..a069fb669
--- /dev/null
+++ b/pkg/tcpip/header/gue.go
@@ -0,0 +1,63 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
const (
	// Byte offsets of the fields within a GUE header. The control bit,
	// type and header length all share the first byte.
	typeHLen   = 0
	encapProto = 1
)

// GUEFields contains the fields of a GUE packet. It is used to describe the
// fields of a packet that needs to be encoded.
type GUEFields struct {
	// Type is the "type" field of the GUE header.
	Type uint8

	// Control is the "control" field of the GUE header.
	Control bool

	// HeaderLength is the "header length" field of the GUE header. It must
	// be at least 4 octets, and a multiple of 4 as well.
	HeaderLength uint8

	// Protocol is the "protocol" field of the GUE header. This is one of
	// the IPPROTO_* values.
	Protocol uint8
}

// GUE represents a Generic UDP Encapsulation header stored in a byte array, the
// fields are described in https://tools.ietf.org/html/draft-ietf-nvo3-gue-01.
type GUE []byte

const (
	// GUEMinimumSize is the minimum size of a valid GUE packet.
	GUEMinimumSize = 4
)

// TypeAndControl returns the GUE packet type (top 3 bits of the first byte,
// which includes the control bit).
func (b GUE) TypeAndControl() uint8 {
	return b[typeHLen] >> 5
}

// HeaderLength returns the total length of the GUE header in bytes. The
// wire value is a 5-bit count of 4-byte units beyond the 4-byte fixed
// header.
func (b GUE) HeaderLength() uint8 {
	return 4 + 4*(b[typeHLen]&0x1f)
}

// Protocol returns the protocol field of the GUE header.
func (b GUE) Protocol() uint8 {
	return b[encapProto]
}

// Encode encodes all the fields of the GUE header.
func (b GUE) Encode(i *GUEFields) {
	first := i.Type<<6 | (i.HeaderLength-4)/4
	if i.Control {
		first |= 1 << 5 // the control bit sits between type and length
	}
	b[typeHLen] = first
	b[encapProto] = i.Protocol
}
diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go
new file mode 100644
index 000000000..9f1ad38fc
--- /dev/null
+++ b/pkg/tcpip/header/icmpv4.go
@@ -0,0 +1,98 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
// ICMPv4 represents an ICMPv4 header stored in a byte array.
type ICMPv4 []byte

const (
	// ICMPv4MinimumSize is the minimum size of a valid ICMP packet:
	// type, code and the 2-byte checksum.
	ICMPv4MinimumSize = 4

	// ICMPv4EchoMinimumSize is the minimum size of a valid ICMP echo packet.
	ICMPv4EchoMinimumSize = 6

	// ICMPv4DstUnreachableMinimumSize is the minimum size of a valid ICMP
	// destination unreachable packet.
	ICMPv4DstUnreachableMinimumSize = ICMPv4MinimumSize + 4

	// ICMPv4ProtocolNumber is the ICMP transport protocol number.
	ICMPv4ProtocolNumber tcpip.TransportProtocolNumber = 1
)

// ICMPv4Type is the ICMP type field described in RFC 792.
type ICMPv4Type byte

// Typical values of ICMPv4Type defined in RFC 792. These are wire values.
const (
	ICMPv4EchoReply      ICMPv4Type = 0
	ICMPv4DstUnreachable ICMPv4Type = 3
	ICMPv4SrcQuench      ICMPv4Type = 4
	ICMPv4Redirect       ICMPv4Type = 5
	ICMPv4Echo           ICMPv4Type = 8
	ICMPv4TimeExceeded   ICMPv4Type = 11
	ICMPv4ParamProblem   ICMPv4Type = 12
	ICMPv4Timestamp      ICMPv4Type = 13
	ICMPv4TimestampReply ICMPv4Type = 14
	ICMPv4InfoRequest    ICMPv4Type = 15
	ICMPv4InfoReply      ICMPv4Type = 16
)

// Values for ICMP code as defined in RFC 792. These codes qualify the
// ICMPv4DstUnreachable type.
const (
	ICMPv4PortUnreachable     = 3
	ICMPv4FragmentationNeeded = 4
)
+
+// Type is the ICMP type field.
+func (b ICMPv4) Type() ICMPv4Type { return ICMPv4Type(b[0]) }
+
+// SetType sets the ICMP type field.
+func (b ICMPv4) SetType(t ICMPv4Type) { b[0] = byte(t) }
+
+// Code is the ICMP code field. Its meaning depends on the value of Type.
+func (b ICMPv4) Code() byte { return b[1] }
+
+// SetCode sets the ICMP code field.
+func (b ICMPv4) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field.
+func (b ICMPv4) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[2:])
+}
+
+// SetChecksum sets the ICMP checksum field.
+func (b ICMPv4) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[2:], checksum)
+}
+
+// SourcePort implements Transport.SourcePort.
+func (ICMPv4) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (ICMPv4) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (ICMPv4) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (ICMPv4) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv4) Payload() []byte {
+ return b[ICMPv4MinimumSize:]
+}
diff --git a/pkg/tcpip/header/icmpv6.go b/pkg/tcpip/header/icmpv6.go
new file mode 100644
index 000000000..a061cd02b
--- /dev/null
+++ b/pkg/tcpip/header/icmpv6.go
@@ -0,0 +1,111 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
// ICMPv6 represents an ICMPv6 header stored in a byte array.
type ICMPv6 []byte

const (
	// ICMPv6MinimumSize is the minimum size of a valid ICMP packet:
	// type, code and the 2-byte checksum.
	ICMPv6MinimumSize = 4

	// ICMPv6ProtocolNumber is the ICMP transport protocol number.
	ICMPv6ProtocolNumber tcpip.TransportProtocolNumber = 58

	// ICMPv6NeighborSolicitMinimumSize is the minimum size of a
	// neighbor solicitation packet.
	ICMPv6NeighborSolicitMinimumSize = ICMPv6MinimumSize + 4 + 16

	// ICMPv6NeighborAdvertSize is size of a neighbor advertisement.
	ICMPv6NeighborAdvertSize = 32

	// ICMPv6EchoMinimumSize is the minimum size of a valid ICMP echo packet.
	ICMPv6EchoMinimumSize = 8

	// ICMPv6DstUnreachableMinimumSize is the minimum size of a valid ICMP
	// destination unreachable packet.
	ICMPv6DstUnreachableMinimumSize = ICMPv6MinimumSize + 4

	// ICMPv6PacketTooBigMinimumSize is the minimum size of a valid ICMP
	// packet-too-big packet.
	ICMPv6PacketTooBigMinimumSize = ICMPv6MinimumSize + 4
)

// ICMPv6Type is the ICMP type field described in RFC 4443 and friends.
type ICMPv6Type byte

// Typical values of ICMPv6Type defined in RFC 4443. These are wire values.
const (
	ICMPv6DstUnreachable ICMPv6Type = 1
	ICMPv6PacketTooBig   ICMPv6Type = 2
	ICMPv6TimeExceeded   ICMPv6Type = 3
	ICMPv6ParamProblem   ICMPv6Type = 4
	ICMPv6EchoRequest    ICMPv6Type = 128
	ICMPv6EchoReply      ICMPv6Type = 129

	// Neighbor Discovery Protocol (NDP) messages, see RFC 4861.

	ICMPv6RouterSolicit   ICMPv6Type = 133
	ICMPv6RouterAdvert    ICMPv6Type = 134
	ICMPv6NeighborSolicit ICMPv6Type = 135
	ICMPv6NeighborAdvert  ICMPv6Type = 136
	ICMPv6RedirectMsg     ICMPv6Type = 137
)

// Values for ICMP code as defined in RFC 4443. This code qualifies the
// ICMPv6DstUnreachable type.
const (
	ICMPv6PortUnreachable = 4
)
+
+// Type is the ICMP type field.
+func (b ICMPv6) Type() ICMPv6Type { return ICMPv6Type(b[0]) }
+
+// SetType sets the ICMP type field.
+func (b ICMPv6) SetType(t ICMPv6Type) { b[0] = byte(t) }
+
+// Code is the ICMP code field. Its meaning depends on the value of Type.
+func (b ICMPv6) Code() byte { return b[1] }
+
+// SetCode sets the ICMP code field.
+func (b ICMPv6) SetCode(c byte) { b[1] = c }
+
+// Checksum is the ICMP checksum field.
+func (b ICMPv6) Checksum() uint16 {
+ return binary.BigEndian.Uint16(b[2:])
+}
+
+// SetChecksum calculates and sets the ICMP checksum field.
+func (b ICMPv6) SetChecksum(checksum uint16) {
+ binary.BigEndian.PutUint16(b[2:], checksum)
+}
+
+// SourcePort implements Transport.SourcePort.
+func (ICMPv6) SourcePort() uint16 {
+ return 0
+}
+
+// DestinationPort implements Transport.DestinationPort.
+func (ICMPv6) DestinationPort() uint16 {
+ return 0
+}
+
+// SetSourcePort implements Transport.SetSourcePort.
+func (ICMPv6) SetSourcePort(uint16) {
+}
+
+// SetDestinationPort implements Transport.SetDestinationPort.
+func (ICMPv6) SetDestinationPort(uint16) {
+}
+
+// Payload implements Transport.Payload.
+func (b ICMPv6) Payload() []byte {
+ return b[ICMPv6MinimumSize:]
+}
diff --git a/pkg/tcpip/header/interfaces.go b/pkg/tcpip/header/interfaces.go
new file mode 100644
index 000000000..a92286761
--- /dev/null
+++ b/pkg/tcpip/header/interfaces.go
@@ -0,0 +1,82 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
const (
	// MaxIPPacketSize is the maximum supported IP packet size, excluding
	// jumbograms. The maximum IPv4 packet size is 64k-1 (total size must fit
	// in 16 bits). For IPv6, the payload max size (excluding jumbograms) is
	// 64k-1 (also needs to fit in 16 bits). So we use 64k - 1 + 2 * m, where
	// m is the minimum IPv6 header size; we leave room for some potential
	// IP options.
	MaxIPPacketSize = 0xffff + 2*IPv6MinimumSize
)
+
// Transport offers generic methods to query and/or update the fields of the
// header of a transport protocol buffer.
type Transport interface {
	// SourcePort returns the value of the "source port" field.
	SourcePort() uint16

	// DestinationPort returns the value of the "destination port" field.
	DestinationPort() uint16

	// Checksum returns the value of the "checksum" field.
	Checksum() uint16

	// SetSourcePort sets the value of the "source port" field.
	SetSourcePort(uint16)

	// SetDestinationPort sets the value of the "destination port" field.
	SetDestinationPort(uint16)

	// SetChecksum sets the value of the "checksum" field.
	SetChecksum(uint16)

	// Payload returns the data carried in the transport buffer.
	Payload() []byte
}
+
// Network offers generic methods to query and/or update the fields of the
// header of a network protocol buffer.
type Network interface {
	// SourceAddress returns the value of the "source address" field.
	SourceAddress() tcpip.Address

	// DestinationAddress returns the value of the "destination address"
	// field.
	DestinationAddress() tcpip.Address

	// Checksum returns the value of the "checksum" field.
	Checksum() uint16

	// SetSourceAddress sets the value of the "source address" field.
	SetSourceAddress(tcpip.Address)

	// SetDestinationAddress sets the value of the "destination address"
	// field.
	SetDestinationAddress(tcpip.Address)

	// SetChecksum sets the value of the "checksum" field.
	SetChecksum(uint16)

	// TransportProtocol returns the number of the transport protocol
	// stored in the payload.
	TransportProtocol() tcpip.TransportProtocolNumber

	// Payload returns a byte slice containing the payload of the network
	// packet.
	Payload() []byte

	// TOS returns the values of the "type of service" and "flow label"
	// fields. The IPv4 implementation always returns a zero flow label.
	TOS() (uint8, uint32)

	// SetTOS sets the values of the "type of service" and "flow label" fields.
	SetTOS(t uint8, l uint32)
}
diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go
new file mode 100644
index 000000000..cb0d42093
--- /dev/null
+++ b/pkg/tcpip/header/ipv4.go
@@ -0,0 +1,251 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
const (
	// Byte offsets of the fields within an IPv4 header.
	versIHL  = 0
	tos      = 1
	totalLen = 2
	id       = 4
	flagsFO  = 6
	ttl      = 8
	protocol = 9
	checksum = 10
	srcAddr  = 12
	dstAddr  = 16
)

// IPv4Fields contains the fields of an IPv4 packet. It is used to describe the
// fields of a packet that needs to be encoded.
type IPv4Fields struct {
	// IHL is the "internet header length" field of an IPv4 packet, in
	// bytes.
	IHL uint8

	// TOS is the "type of service" field of an IPv4 packet.
	TOS uint8

	// TotalLength is the "total length" field of an IPv4 packet.
	TotalLength uint16

	// ID is the "identification" field of an IPv4 packet.
	ID uint16

	// Flags is the "flags" field of an IPv4 packet.
	Flags uint8

	// FragmentOffset is the "fragment offset" field of an IPv4 packet,
	// in bytes.
	FragmentOffset uint16

	// TTL is the "time to live" field of an IPv4 packet.
	TTL uint8

	// Protocol is the "protocol" field of an IPv4 packet.
	Protocol uint8

	// Checksum is the "checksum" field of an IPv4 packet.
	Checksum uint16

	// SrcAddr is the "source ip address" of an IPv4 packet.
	SrcAddr tcpip.Address

	// DstAddr is the "destination ip address" of an IPv4 packet.
	DstAddr tcpip.Address
}

// IPv4 represents an ipv4 header stored in a byte array.
// Most of the methods of IPv4 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv4 before using other methods.
type IPv4 []byte

const (
	// IPv4MinimumSize is the minimum size of a valid IPv4 packet.
	IPv4MinimumSize = 20

	// IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given
	// that there are only 4 bits to represents the header length in 32-bit
	// units, the header cannot exceed 15*4 = 60 bytes.
	IPv4MaximumHeaderSize = 60

	// IPv4AddressSize is the size, in bytes, of an IPv4 address.
	IPv4AddressSize = 4

	// IPv4ProtocolNumber is IPv4's network protocol number.
	IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800

	// IPv4Version is the version of the ipv4 protocol.
	IPv4Version = 4
)

// Flags that may be set in an IPv4 packet. These are the bit positions
// within the 3-bit flags field.
const (
	IPv4FlagMoreFragments = 1 << iota
	IPv4FlagDontFragment
)
+
// IPVersion returns the version of IP used in the given packet. It
// returns -1 if the packet is not large enough to contain the version
// field.
func IPVersion(b []byte) int {
	if len(b) == 0 {
		// A nil or empty buffer has no version field.
		return -1
	}
	// The version is the high nibble of the first byte for both IPv4
	// and IPv6.
	return int(b[0] >> 4)
}
+
// HeaderLength returns the value of the "header length" field of the ipv4
// header, in bytes. The wire IHL field counts 32-bit words, so it is
// scaled by 4.
func (b IPv4) HeaderLength() uint8 {
	return (b[versIHL] & 0xf) * 4
}

// ID returns the value of the identifier field of the ipv4 header.
func (b IPv4) ID() uint16 {
	return binary.BigEndian.Uint16(b[id:])
}

// Protocol returns the value of the protocol field of the ipv4 header.
func (b IPv4) Protocol() uint8 {
	return b[protocol]
}

// Flags returns the "flags" field of the ipv4 header: the top 3 bits of
// the 16-bit flags/fragment-offset word.
func (b IPv4) Flags() uint8 {
	return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13)
}

// TTL returns the "TTL" field of the ipv4 header.
func (b IPv4) TTL() uint8 {
	return b[ttl]
}

// FragmentOffset returns the "fragment offset" field of the ipv4 header,
// in bytes. Shifting the 16-bit word left by 3 both discards the 3 flag
// bits and converts the 13-bit offset from 8-byte units to bytes.
func (b IPv4) FragmentOffset() uint16 {
	return binary.BigEndian.Uint16(b[flagsFO:]) << 3
}

// TotalLength returns the "total length" field of the ipv4 header.
func (b IPv4) TotalLength() uint16 {
	return binary.BigEndian.Uint16(b[totalLen:])
}

// Checksum returns the checksum field of the ipv4 header.
func (b IPv4) Checksum() uint16 {
	return binary.BigEndian.Uint16(b[checksum:])
}

// SourceAddress returns the "source address" field of the ipv4 header.
func (b IPv4) SourceAddress() tcpip.Address {
	return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize])
}

// DestinationAddress returns the "destination address" field of the ipv4
// header.
func (b IPv4) DestinationAddress() tcpip.Address {
	return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize])
}

// TransportProtocol implements Network.TransportProtocol.
func (b IPv4) TransportProtocol() tcpip.TransportProtocolNumber {
	return tcpip.TransportProtocolNumber(b.Protocol())
}

// Payload implements Network.Payload. It skips past any IP options and
// is bounded by the wire total length, not the buffer length.
func (b IPv4) Payload() []byte {
	return b[b.HeaderLength():][:b.PayloadLength()]
}

// PayloadLength returns the length of the payload portion of the ipv4 packet.
func (b IPv4) PayloadLength() uint16 {
	return b.TotalLength() - uint16(b.HeaderLength())
}
+
+// TOS returns the "type of service" field of the ipv4 header.
+func (b IPv4) TOS() (uint8, uint32) {
+ return b[tos], 0
+}
+
+// SetTOS sets the "type of service" field of the ipv4 header.
+func (b IPv4) SetTOS(v uint8, _ uint32) {
+ b[tos] = v
+}
+
+// SetTotalLength sets the "total length" field of the ipv4 header.
+func (b IPv4) SetTotalLength(totalLength uint16) {
+ binary.BigEndian.PutUint16(b[totalLen:], totalLength)
+}
+
+// SetChecksum sets the checksum field of the ipv4 header.
+func (b IPv4) SetChecksum(v uint16) {
+ binary.BigEndian.PutUint16(b[checksum:], v)
+}
+
+// SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the
+// ipv4 header.
+func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) {
+ v := (uint16(flags) << 13) | (offset >> 3)
+ binary.BigEndian.PutUint16(b[flagsFO:], v)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv4 header.
+func (b IPv4) SetSourceAddress(addr tcpip.Address) {
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], addr)
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv4
+// header.
+func (b IPv4) SetDestinationAddress(addr tcpip.Address) {
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], addr)
+}
+
+// CalculateChecksum calculates the checksum of the ipv4 header.
+func (b IPv4) CalculateChecksum() uint16 {
+ return Checksum(b[:b.HeaderLength()], 0)
+}
+
+// Encode encodes all the fields of the ipv4 header.
+func (b IPv4) Encode(i *IPv4Fields) {
+ b[versIHL] = (4 << 4) | ((i.IHL / 4) & 0xf)
+ b[tos] = i.TOS
+ b.SetTotalLength(i.TotalLength)
+ binary.BigEndian.PutUint16(b[id:], i.ID)
+ b.SetFlagsFragmentOffset(i.Flags, i.FragmentOffset)
+ b[ttl] = i.TTL
+ b[protocol] = i.Protocol
+ b.SetChecksum(i.Checksum)
+ copy(b[srcAddr:srcAddr+IPv4AddressSize], i.SrcAddr)
+ copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr)
+}
+
// EncodePartial updates the total length and checksum fields of ipv4 header,
// taking in the partial checksum, which is the checksum of the header without
// the total length and checksum fields. It is useful in cases when similar
// packets are produced.
func (b IPv4) EncodePartial(partialChecksum, totalLength uint16) {
	b.SetTotalLength(totalLength)
	// Fold only the freshly written total-length bytes into the partial
	// checksum, then store its one's complement as the header checksum.
	checksum := Checksum(b[totalLen:totalLen+2], partialChecksum)
	b.SetChecksum(^checksum)
}
+
+// IsValid performs basic validation on the packet.
+func (b IPv4) IsValid(pktSize int) bool {
+ if len(b) < IPv4MinimumSize {
+ return false
+ }
+
+ hlen := int(b.HeaderLength())
+ tlen := int(b.TotalLength())
+ if hlen > tlen || tlen > pktSize {
+ return false
+ }
+
+ return true
+}
diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go
new file mode 100644
index 000000000..d8dc138b3
--- /dev/null
+++ b/pkg/tcpip/header/ipv6.go
@@ -0,0 +1,191 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
const (
	// Byte offsets of the fields within an IPv6 header.
	versTCFL   = 0
	payloadLen = 4
	nextHdr    = 6
	hopLimit   = 7
	v6SrcAddr  = 8
	v6DstAddr  = 24
)

// IPv6Fields contains the fields of an IPv6 packet. It is used to describe the
// fields of a packet that needs to be encoded.
type IPv6Fields struct {
	// TrafficClass is the "traffic class" field of an IPv6 packet.
	TrafficClass uint8

	// FlowLabel is the "flow label" field of an IPv6 packet.
	FlowLabel uint32

	// PayloadLength is the "payload length" field of an IPv6 packet.
	PayloadLength uint16

	// NextHeader is the "next header" field of an IPv6 packet.
	NextHeader uint8

	// HopLimit is the "hop limit" field of an IPv6 packet.
	HopLimit uint8

	// SrcAddr is the "source ip address" of an IPv6 packet.
	SrcAddr tcpip.Address

	// DstAddr is the "destination ip address" of an IPv6 packet.
	DstAddr tcpip.Address
}

// IPv6 represents an ipv6 header stored in a byte array.
// Most of the methods of IPv6 access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv6 before using other methods.
type IPv6 []byte

const (
	// IPv6MinimumSize is the minimum size of a valid IPv6 packet.
	IPv6MinimumSize = 40

	// IPv6AddressSize is the size, in bytes, of an IPv6 address.
	IPv6AddressSize = 16

	// IPv6ProtocolNumber is IPv6's network protocol number.
	IPv6ProtocolNumber tcpip.NetworkProtocolNumber = 0x86dd

	// IPv6Version is the version of the ipv6 protocol.
	IPv6Version = 6

	// IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460,
	// section 5.
	IPv6MinimumMTU = 1280
)
+
+// PayloadLength returns the value of the "payload length" field of the ipv6
+// header.
+func (b IPv6) PayloadLength() uint16 {
+ return binary.BigEndian.Uint16(b[payloadLen:])
+}
+
+// HopLimit returns the value of the "hop limit" field of the ipv6 header.
+func (b IPv6) HopLimit() uint8 {
+ return b[hopLimit]
+}
+
+// NextHeader returns the value of the "next header" field of the ipv6 header.
+func (b IPv6) NextHeader() uint8 {
+ return b[nextHdr]
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv6) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.NextHeader())
+}
+
+// Payload implements Network.Payload.
+func (b IPv6) Payload() []byte {
+ return b[IPv6MinimumSize:][:b.PayloadLength()]
+}
+
+// SourceAddress returns the "source address" field of the ipv6 header.
+func (b IPv6) SourceAddress() tcpip.Address {
+ return tcpip.Address(b[v6SrcAddr : v6SrcAddr+IPv6AddressSize])
+}
+
+// DestinationAddress returns the "destination address" field of the ipv6
+// header.
+func (b IPv6) DestinationAddress() tcpip.Address {
+ return tcpip.Address(b[v6DstAddr : v6DstAddr+IPv6AddressSize])
+}
+
+// Checksum implements Network.Checksum. Given that IPv6 doesn't have a
+// checksum, it just returns 0.
+func (IPv6) Checksum() uint16 {
+ return 0
+}
+
+// TOS returns the "traffic class" and "flow label" fields of the ipv6 header.
+func (b IPv6) TOS() (uint8, uint32) {
+ v := binary.BigEndian.Uint32(b[versTCFL:])
+ return uint8(v >> 20), v & 0xfffff
+}
+
+// SetTOS sets the "traffic class" and "flow label" fields of the ipv6 header.
+func (b IPv6) SetTOS(t uint8, l uint32) {
+ vtf := (6 << 28) | (uint32(t) << 20) | (l & 0xfffff)
+ binary.BigEndian.PutUint32(b[versTCFL:], vtf)
+}
+
+// SetPayloadLength sets the "payload length" field of the ipv6 header.
+func (b IPv6) SetPayloadLength(payloadLength uint16) {
+ binary.BigEndian.PutUint16(b[payloadLen:], payloadLength)
+}
+
+// SetSourceAddress sets the "source address" field of the ipv6 header.
+func (b IPv6) SetSourceAddress(addr tcpip.Address) {
+ copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], addr)
+}
+
+// SetDestinationAddress sets the "destination address" field of the ipv6
+// header.
+func (b IPv6) SetDestinationAddress(addr tcpip.Address) {
+ copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], addr)
+}
+
+// SetNextHeader sets the value of the "next header" field of the ipv6 header.
+func (b IPv6) SetNextHeader(v uint8) {
+ b[nextHdr] = v
+}
+
+// SetChecksum implements Network.SetChecksum. Given that IPv6 doesn't have a
+// checksum, it is empty.
+func (IPv6) SetChecksum(uint16) {
+}
+
+// Encode encodes all the fields of the ipv6 header.
+func (b IPv6) Encode(i *IPv6Fields) {
+ b.SetTOS(i.TrafficClass, i.FlowLabel)
+ b.SetPayloadLength(i.PayloadLength)
+ b[nextHdr] = i.NextHeader
+ b[hopLimit] = i.HopLimit
+ copy(b[v6SrcAddr:v6SrcAddr+IPv6AddressSize], i.SrcAddr)
+ copy(b[v6DstAddr:v6DstAddr+IPv6AddressSize], i.DstAddr)
+}
+
+// IsValid performs basic validation on the packet.
+func (b IPv6) IsValid(pktSize int) bool {
+ if len(b) < IPv6MinimumSize {
+ return false
+ }
+
+ dlen := int(b.PayloadLength())
+ if dlen > pktSize-IPv6MinimumSize {
+ return false
+ }
+
+ return true
+}
+
+// IsV4MappedAddress determines if the provided address is an IPv4 mapped
+// address by checking if its prefix is 0:0:0:0:0:ffff::/96.
+func IsV4MappedAddress(addr tcpip.Address) bool {
+ if len(addr) != IPv6AddressSize {
+ return false
+ }
+
+ const prefix = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff"
+ for i := 0; i < len(prefix); i++ {
+ if prefix[i] != addr[i] {
+ return false
+ }
+ }
+
+ return true
+}
diff --git a/pkg/tcpip/header/ipv6_fragment.go b/pkg/tcpip/header/ipv6_fragment.go
new file mode 100644
index 000000000..04aa5c7b8
--- /dev/null
+++ b/pkg/tcpip/header/ipv6_fragment.go
@@ -0,0 +1,136 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
const (
	// Byte offsets of the fields within an IPv6 fragment header. Note
	// that the fragment offset shares bytes 2-3 with the "more" flag,
	// which is the lowest bit of byte 3.
	nextHdrFrag = 0
	fragOff     = 2
	more        = 3
	idV6        = 4
)

// IPv6FragmentFields contains the fields of an IPv6 fragment. It is used to describe the
// fields of a packet that needs to be encoded.
type IPv6FragmentFields struct {
	// NextHeader is the "next header" field of an IPv6 fragment.
	NextHeader uint8

	// FragmentOffset is the "fragment offset" field of an IPv6 fragment,
	// in 8-byte units.
	FragmentOffset uint16

	// M is the "more" field of an IPv6 fragment.
	M bool

	// Identification is the "identification" field of an IPv6 fragment.
	Identification uint32
}

// IPv6Fragment represents an ipv6 fragment header stored in a byte array.
// Most of the methods of IPv6Fragment access to the underlying slice without
// checking the boundaries and could panic because of 'index out of range'.
// Always call IsValid() to validate an instance of IPv6Fragment before using other methods.
type IPv6Fragment []byte

const (
	// IPv6FragmentHeader is the "next header" value that indicates a
	// fragment header follows, per RFC 2460.
	IPv6FragmentHeader = 44

	// IPv6FragmentHeaderSize is the size of the fragment header.
	IPv6FragmentHeaderSize = 8
)
+
+// Encode encodes all the fields of the ipv6 fragment.
+func (b IPv6Fragment) Encode(i *IPv6FragmentFields) {
+ b[nextHdrFrag] = i.NextHeader
+ binary.BigEndian.PutUint16(b[fragOff:], i.FragmentOffset<<3)
+ if i.M {
+ b[more] |= 1
+ }
+ binary.BigEndian.PutUint32(b[idV6:], i.Identification)
+}
+
+// IsValid performs basic validation on the fragment header.
+func (b IPv6Fragment) IsValid() bool {
+ return len(b) >= IPv6FragmentHeaderSize
+}
+
+// NextHeader returns the value of the "next header" field of the ipv6 fragment.
+func (b IPv6Fragment) NextHeader() uint8 {
+ return b[nextHdrFrag]
+}
+
+// FragmentOffset returns the "fragment offset" field of the ipv6 fragment.
+func (b IPv6Fragment) FragmentOffset() uint16 {
+ return binary.BigEndian.Uint16(b[fragOff:]) >> 3
+}
+
+// More returns the "more" field of the ipv6 fragment.
+func (b IPv6Fragment) More() bool {
+ return b[more]&1 > 0
+}
+
+// Payload implements Network.Payload.
+func (b IPv6Fragment) Payload() []byte {
+ return b[IPv6FragmentHeaderSize:]
+}
+
+// ID returns the value of the identifier field of the ipv6 fragment.
+func (b IPv6Fragment) ID() uint32 {
+ return binary.BigEndian.Uint32(b[idV6:])
+}
+
+// TransportProtocol implements Network.TransportProtocol.
+func (b IPv6Fragment) TransportProtocol() tcpip.TransportProtocolNumber {
+ return tcpip.TransportProtocolNumber(b.NextHeader())
+}
+
// The functions below have been added only to satisfy the Network interface.
// They must not be called on a fragment header; every one of them panics.

// Checksum is not supported by IPv6Fragment.
func (b IPv6Fragment) Checksum() uint16 {
	panic("not supported")
}

// SourceAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) SourceAddress() tcpip.Address {
	panic("not supported")
}

// DestinationAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) DestinationAddress() tcpip.Address {
	panic("not supported")
}

// SetSourceAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) SetSourceAddress(tcpip.Address) {
	panic("not supported")
}

// SetDestinationAddress is not supported by IPv6Fragment.
func (b IPv6Fragment) SetDestinationAddress(tcpip.Address) {
	panic("not supported")
}

// SetChecksum is not supported by IPv6Fragment.
func (b IPv6Fragment) SetChecksum(uint16) {
	panic("not supported")
}

// TOS is not supported by IPv6Fragment.
func (b IPv6Fragment) TOS() (uint8, uint32) {
	panic("not supported")
}

// SetTOS is not supported by IPv6Fragment.
func (b IPv6Fragment) SetTOS(t uint8, l uint32) {
	panic("not supported")
}
diff --git a/pkg/tcpip/header/ipversion_test.go b/pkg/tcpip/header/ipversion_test.go
new file mode 100644
index 000000000..5f3956160
--- /dev/null
+++ b/pkg/tcpip/header/ipversion_test.go
@@ -0,0 +1,57 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header_test
+
+import (
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
// TestIPv4 verifies that IPVersion recognizes an encoded IPv4 header.
func TestIPv4(t *testing.T) {
	// Encode writes the version nibble unconditionally, so an all-zero
	// fields struct still yields a well-formed version.
	b := header.IPv4(make([]byte, header.IPv4MinimumSize))
	b.Encode(&header.IPv4Fields{})

	const want = header.IPv4Version
	if v := header.IPVersion(b); v != want {
		t.Fatalf("Bad version, want %v, got %v", want, v)
	}
}

// TestIPv6 verifies that IPVersion recognizes an encoded IPv6 header.
func TestIPv6(t *testing.T) {
	b := header.IPv6(make([]byte, header.IPv6MinimumSize))
	b.Encode(&header.IPv6Fields{})

	const want = header.IPv6Version
	if v := header.IPVersion(b); v != want {
		t.Fatalf("Bad version, want %v, got %v", want, v)
	}
}

// TestOtherVersion verifies that IPVersion returns the version nibble
// verbatim even when it is neither 4 nor 6.
func TestOtherVersion(t *testing.T) {
	const want = header.IPv4Version + header.IPv6Version
	b := make([]byte, 1)
	// The version is the high nibble of the first byte.
	b[0] = want << 4

	if v := header.IPVersion(b); v != want {
		t.Fatalf("Bad version, want %v, got %v", want, v)
	}
}

// TestTooShort verifies that IPVersion returns -1 for buffers too short to
// contain a version nibble.
func TestTooShort(t *testing.T) {
	b := make([]byte, 1)
	b[0] = (header.IPv4Version + header.IPv6Version) << 4

	// Get the version of a zero-length slice.
	const want = -1
	if v := header.IPVersion(b[:0]); v != want {
		t.Fatalf("Bad version, want %v, got %v", want, v)
	}

	// Get the version of a nil slice.
	if v := header.IPVersion(nil); v != want {
		t.Fatalf("Bad version, want %v, got %v", want, v)
	}
}
diff --git a/pkg/tcpip/header/tcp.go b/pkg/tcpip/header/tcp.go
new file mode 100644
index 000000000..995df4076
--- /dev/null
+++ b/pkg/tcpip/header/tcp.go
@@ -0,0 +1,518 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+ srcPort = 0
+ dstPort = 2
+ seqNum = 4
+ ackNum = 8
+ dataOffset = 12
+ tcpFlags = 13
+ winSize = 14
+ tcpChecksum = 16
+ urgentPtr = 18
+)
+
+const (
+ // MaxWndScale is maximum allowed window scaling, as described in
+ // RFC 1323, section 2.3, page 11.
+ MaxWndScale = 14
+
+ // TCPMaxSACKBlocks is the maximum number of SACK blocks that can
+ // be encoded in a TCP option field.
+ TCPMaxSACKBlocks = 4
+)
+
+// Flags that may be set in a TCP segment.
+const (
+ TCPFlagFin = 1 << iota
+ TCPFlagSyn
+ TCPFlagRst
+ TCPFlagPsh
+ TCPFlagAck
+ TCPFlagUrg
+)
+
+// Options that may be present in a TCP segment.
+const (
+ TCPOptionEOL = 0
+ TCPOptionNOP = 1
+ TCPOptionMSS = 2
+ TCPOptionWS = 3
+ TCPOptionTS = 8
+ TCPOptionSACKPermitted = 4
+ TCPOptionSACK = 5
+)
+
+// TCPFields contains the fields of a TCP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type TCPFields struct {
+ // SrcPort is the "source port" field of a TCP packet.
+ SrcPort uint16
+
+ // DstPort is the "destination port" field of a TCP packet.
+ DstPort uint16
+
+ // SeqNum is the "sequence number" field of a TCP packet.
+ SeqNum uint32
+
+ // AckNum is the "acknowledgement number" field of a TCP packet.
+ AckNum uint32
+
+ // DataOffset is the "data offset" field of a TCP packet.
+ DataOffset uint8
+
+ // Flags is the "flags" field of a TCP packet.
+ Flags uint8
+
+ // WindowSize is the "window size" field of a TCP packet.
+ WindowSize uint16
+
+ // Checksum is the "checksum" field of a TCP packet.
+ Checksum uint16
+
+ // UrgentPointer is the "urgent pointer" field of a TCP packet.
+ UrgentPointer uint16
+}
+
+// TCPSynOptions is used to return the parsed TCP Options in a syn
+// segment.
+type TCPSynOptions struct {
+ // MSS is the maximum segment size provided by the peer in the SYN.
+ MSS uint16
+
+ // WS is the window scale option provided by the peer in the SYN.
+ //
+ // Set to -1 if no window scale option was provided.
+ WS int
+
+ // TS is true if the timestamp option was provided in the syn/syn-ack.
+ TS bool
+
+ // TSVal is the value of the TSVal field in the timestamp option.
+ TSVal uint32
+
+ // TSEcr is the value of the TSEcr field in the timestamp option.
+ TSEcr uint32
+
+ // SACKPermitted is true if the SACK option was provided in the SYN/SYN-ACK.
+ SACKPermitted bool
+}
+
+// SACKBlock represents a single contiguous SACK block.
+type SACKBlock struct {
+ // Start indicates the lowest sequence number in the block.
+ Start seqnum.Value
+
+ // End indicates the sequence number immediately following the last
+ // sequence number of this block.
+ End seqnum.Value
+}
+
+// TCPOptions are used to parse and cache the TCP segment options for a non
+// syn/syn-ack segment.
+type TCPOptions struct {
+ // TS is true if the TimeStamp option is enabled.
+ TS bool
+
+ // TSVal is the value in the TSVal field of the segment.
+ TSVal uint32
+
+ // TSEcr is the value in the TSEcr field of the segment.
+ TSEcr uint32
+
+ // SACKBlocks are the SACK blocks specified in the segment.
+ SACKBlocks []SACKBlock
+}
+
+// TCP represents a TCP header stored in a byte array.
+type TCP []byte
+
+const (
+ // TCPMinimumSize is the minimum size of a valid TCP packet.
+ TCPMinimumSize = 20
+
+ // TCPProtocolNumber is TCP's transport protocol number.
+ TCPProtocolNumber tcpip.TransportProtocolNumber = 6
+)
+
// SourcePort returns the "source port" field of the tcp header.
func (b TCP) SourcePort() uint16 {
	return binary.BigEndian.Uint16(b[srcPort:])
}

// DestinationPort returns the "destination port" field of the tcp header.
func (b TCP) DestinationPort() uint16 {
	return binary.BigEndian.Uint16(b[dstPort:])
}

// SequenceNumber returns the "sequence number" field of the tcp header.
func (b TCP) SequenceNumber() uint32 {
	return binary.BigEndian.Uint32(b[seqNum:])
}

// AckNumber returns the "ack number" field of the tcp header.
func (b TCP) AckNumber() uint32 {
	return binary.BigEndian.Uint32(b[ackNum:])
}

// DataOffset returns the "data offset" field of the tcp header, converted
// to the header length in bytes: the 4-bit wire value counts 32-bit words,
// hence the multiplication by 4.
func (b TCP) DataOffset() uint8 {
	return (b[dataOffset] >> 4) * 4
}

// Payload returns the data in the tcp packet. The returned slice aliases b.
func (b TCP) Payload() []byte {
	return b[b.DataOffset():]
}

// Flags returns the flags field of the tcp header.
func (b TCP) Flags() uint8 {
	return b[tcpFlags]
}

// WindowSize returns the "window size" field of the tcp header.
func (b TCP) WindowSize() uint16 {
	return binary.BigEndian.Uint16(b[winSize:])
}

// Checksum returns the "checksum" field of the tcp header.
func (b TCP) Checksum() uint16 {
	return binary.BigEndian.Uint16(b[tcpChecksum:])
}

// SetSourcePort sets the "source port" field of the tcp header.
func (b TCP) SetSourcePort(port uint16) {
	binary.BigEndian.PutUint16(b[srcPort:], port)
}

// SetDestinationPort sets the "destination port" field of the tcp header.
func (b TCP) SetDestinationPort(port uint16) {
	binary.BigEndian.PutUint16(b[dstPort:], port)
}

// SetChecksum sets the checksum field of the tcp header.
func (b TCP) SetChecksum(checksum uint16) {
	binary.BigEndian.PutUint16(b[tcpChecksum:], checksum)
}
+
+// CalculateChecksum calculates the checksum of the tcp segment given
+// the totalLen and partialChecksum(descriptions below)
+// totalLen is the total length of the segment
+// partialChecksum is the checksum of the network-layer pseudo-header
+// (excluding the total length) and the checksum of the segment data.
+func (b TCP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
+ // Add the length portion of the checksum to the pseudo-checksum.
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ checksum := Checksum(tmp, partialChecksum)
+
+ // Calculate the rest of the checksum.
+ return Checksum(b[:b.DataOffset()], checksum)
+}
+
// Options returns a slice that holds the unparsed TCP options in the segment.
// The returned slice aliases b.
func (b TCP) Options() []byte {
	return b[TCPMinimumSize:b.DataOffset()]
}

// ParsedOptions returns a TCPOptions structure which parses and caches the TCP
// option values in the TCP segment. NOTE: Invoking this function repeatedly is
// expensive as it reparses the options on each invocation.
func (b TCP) ParsedOptions() TCPOptions {
	return ParseTCPOptions(b.Options())
}
+
+func (b TCP) encodeSubset(seq, ack uint32, flags uint8, rcvwnd uint16) {
+ binary.BigEndian.PutUint32(b[seqNum:], seq)
+ binary.BigEndian.PutUint32(b[ackNum:], ack)
+ b[tcpFlags] = flags
+ binary.BigEndian.PutUint16(b[winSize:], rcvwnd)
+}
+
+// Encode encodes all the fields of the tcp header.
+func (b TCP) Encode(t *TCPFields) {
+ b.encodeSubset(t.SeqNum, t.AckNum, t.Flags, t.WindowSize)
+ binary.BigEndian.PutUint16(b[srcPort:], t.SrcPort)
+ binary.BigEndian.PutUint16(b[dstPort:], t.DstPort)
+ b[dataOffset] = (t.DataOffset / 4) << 4
+ binary.BigEndian.PutUint16(b[tcpChecksum:], t.Checksum)
+ binary.BigEndian.PutUint16(b[urgentPtr:], t.UrgentPointer)
+}
+
// EncodePartial updates a subset of the fields of the tcp header. It is useful
// in cases when similar segments are produced. It writes the sequence number,
// ack number, flags and window size, then computes and stores the checksum.
// partialChecksum is assumed to already cover the remaining header bytes,
// the pseudo-header and the payload — NOTE(review): confirm against callers.
func (b TCP) EncodePartial(partialChecksum, length uint16, seqnum, acknum uint32, flags byte, rcvwnd uint16) {
	// Add the total length and "flags" field contributions to the checksum.
	// We don't use the flags field directly from the header because it's a
	// one-byte field with an odd offset, so it would be accounted for
	// incorrectly by the Checksum routine.
	tmp := make([]byte, 4)
	binary.BigEndian.PutUint16(tmp, length)
	binary.BigEndian.PutUint16(tmp[2:], uint16(flags))
	checksum := Checksum(tmp, partialChecksum)

	// Encode the passed-in fields.
	b.encodeSubset(seqnum, acknum, flags, rcvwnd)

	// Add the contributions of the passed-in fields to the checksum.
	// seqNum..seqNum+8 covers both the sequence and ack numbers.
	checksum = Checksum(b[seqNum:seqNum+8], checksum)
	checksum = Checksum(b[winSize:winSize+2], checksum)

	// Encode the (one's complement of the) checksum.
	b.SetChecksum(^checksum)
}
+
// ParseSynOptions parses the options received in a SYN segment and returns the
// relevant ones. opts should point to the option part of the TCP Header.
//
// Parsing stops at the first malformed option; whatever was parsed up to
// that point is returned.
func ParseSynOptions(opts []byte, isAck bool) TCPSynOptions {
	limit := len(opts)

	synOpts := TCPSynOptions{
		// Per RFC 1122, page 85: "If an MSS option is not received at
		// connection setup, TCP MUST assume a default send MSS of 536."
		MSS: 536,
		// If no window scale option is specified, WS in options is
		// returned as -1; this is because the absence of the option
		// indicates that we cannot use window scaling on the
		// receive end either.
		WS: -1,
	}

	for i := 0; i < limit; {
		switch opts[i] {
		case TCPOptionEOL:
			// End-of-list: ignore everything that follows.
			i = limit
		case TCPOptionNOP:
			i++
		case TCPOptionMSS:
			if i+4 > limit || opts[i+1] != 4 {
				return synOpts
			}
			mss := uint16(opts[i+2])<<8 | uint16(opts[i+3])
			if mss == 0 {
				return synOpts
			}
			synOpts.MSS = mss
			i += 4

		case TCPOptionWS:
			if i+3 > limit || opts[i+1] != 3 {
				return synOpts
			}
			ws := int(opts[i+2])
			if ws > MaxWndScale {
				// Clamp to the maximum allowed by RFC 1323.
				ws = MaxWndScale
			}
			synOpts.WS = ws
			i += 3

		case TCPOptionTS:
			if i+10 > limit || opts[i+1] != 10 {
				return synOpts
			}
			synOpts.TSVal = binary.BigEndian.Uint32(opts[i+2:])
			if isAck {
				// If the segment is a SYN-ACK then store the Timestamp Echo Reply
				// in the segment.
				synOpts.TSEcr = binary.BigEndian.Uint32(opts[i+6:])
			}
			synOpts.TS = true
			i += 10
		case TCPOptionSACKPermitted:
			if i+2 > limit || opts[i+1] != 2 {
				return synOpts
			}
			synOpts.SACKPermitted = true
			i += 2

		default:
			// We don't recognize this option, just skip over it.
			if i+2 > limit {
				return synOpts
			}
			l := int(opts[i+1])
			// If the length is incorrect or if l+i overflows the
			// total options length then stop parsing and return
			// the options seen so far.
			if l < 2 || i+l > limit {
				return synOpts
			}
			i += l
		}
	}

	return synOpts
}
+
// ParseTCPOptions extracts and stores all known options in the provided byte
// slice in a TCPOptions structure.
//
// Parsing stops at the first malformed option; whatever was parsed up to
// that point is returned.
func ParseTCPOptions(b []byte) TCPOptions {
	opts := TCPOptions{}
	limit := len(b)
	for i := 0; i < limit; {
		switch b[i] {
		case TCPOptionEOL:
			// End-of-list: ignore everything that follows.
			i = limit
		case TCPOptionNOP:
			i++
		case TCPOptionTS:
			if i+10 > limit || (b[i+1] != 10) {
				return opts
			}
			opts.TS = true
			opts.TSVal = binary.BigEndian.Uint32(b[i+2:])
			opts.TSEcr = binary.BigEndian.Uint32(b[i+6:])
			i += 10
		case TCPOptionSACK:
			if i+2 > limit {
				// Malformed SACK block, just return and stop parsing.
				return opts
			}
			sackOptionLen := int(b[i+1])
			if i+sackOptionLen > limit || (sackOptionLen-2)%8 != 0 {
				// Malformed SACK block, just return and stop parsing.
				return opts
			}
			numBlocks := (sackOptionLen - 2) / 8
			// NOTE: a second SACK option in the same segment replaces
			// the blocks parsed from an earlier one rather than
			// appending to them.
			opts.SACKBlocks = []SACKBlock{}
			for j := 0; j < numBlocks; j++ {
				start := binary.BigEndian.Uint32(b[i+2+j*8:])
				end := binary.BigEndian.Uint32(b[i+2+j*8+4:])
				opts.SACKBlocks = append(opts.SACKBlocks, SACKBlock{
					Start: seqnum.Value(start),
					End:   seqnum.Value(end),
				})
			}
			i += sackOptionLen
		default:
			// We don't recognize this option, just skip over it.
			if i+2 > limit {
				return opts
			}
			l := int(b[i+1])
			// If the length is incorrect or if l+i overflows the
			// total options length then stop parsing and return
			// the options seen so far.
			if l < 2 || i+l > limit {
				return opts
			}
			i += l
		}
	}
	return opts
}
+
+// EncodeMSSOption encodes the MSS TCP option with the provided MSS values in
+// the supplied buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer.
+func EncodeMSSOption(mss uint32, b []byte) int {
+ // mssOptionSize is the number of bytes in a valid MSS option.
+ const mssOptionSize = 4
+
+ if len(b) < mssOptionSize {
+ return 0
+ }
+ b[0], b[1], b[2], b[3] = TCPOptionMSS, mssOptionSize, byte(mss>>8), byte(mss)
+ return mssOptionSize
+}
+
+// EncodeWSOption encodes the WS TCP option with the WS value in the
+// provided buffer. If the provided buffer is not large enough then it just
+// returns without encoding anything. It returns the number of bytes written to
+// the provided buffer.
+func EncodeWSOption(ws int, b []byte) int {
+ if len(b) < 3 {
+ return 0
+ }
+ b[0], b[1], b[2] = TCPOptionWS, 3, uint8(ws)
+ return int(b[1])
+}
+
+// EncodeTSOption encodes the provided tsVal and tsEcr values as a TCP timestamp
+// option into the provided buffer. If the buffer is smaller than expected it
+// just returns without encoding anything. It returns the number of bytes
+// written to the provided buffer.
+func EncodeTSOption(tsVal, tsEcr uint32, b []byte) int {
+ if len(b) < 10 {
+ return 0
+ }
+ b[0], b[1] = TCPOptionTS, 10
+ binary.BigEndian.PutUint32(b[2:], tsVal)
+ binary.BigEndian.PutUint32(b[6:], tsEcr)
+ return int(b[1])
+}
+
+// EncodeSACKPermittedOption encodes a SACKPermitted option into the provided
+// buffer. If the buffer is smaller than required it just returns without
+// encoding anything. It returns the number of bytes written to the provided
+// buffer.
+func EncodeSACKPermittedOption(b []byte) int {
+ if len(b) < 2 {
+ return 0
+ }
+
+ b[0], b[1] = TCPOptionSACKPermitted, 2
+ return int(b[1])
+}
+
+// EncodeSACKBlocks encodes the provided SACK blocks as a TCP SACK option block
+// in the provided slice. It tries to fit in as many blocks as possible based on
+// number of bytes available in the provided buffer. It returns the number of
+// bytes written to the provided buffer.
+func EncodeSACKBlocks(sackBlocks []SACKBlock, b []byte) int {
+ if len(sackBlocks) == 0 {
+ return 0
+ }
+ l := len(sackBlocks)
+ if l > TCPMaxSACKBlocks {
+ l = TCPMaxSACKBlocks
+ }
+ if ll := (len(b) - 2) / 8; ll < l {
+ l = ll
+ }
+ if l == 0 {
+ // There is not enough space in the provided buffer to add
+ // any SACK blocks.
+ return 0
+ }
+ b[0] = TCPOptionSACK
+ b[1] = byte(l*8 + 2)
+ for i := 0; i < l; i++ {
+ binary.BigEndian.PutUint32(b[i*8+2:], uint32(sackBlocks[i].Start))
+ binary.BigEndian.PutUint32(b[i*8+6:], uint32(sackBlocks[i].End))
+ }
+ return int(b[1])
+}
+
+// EncodeNOP adds an explicit NOP to the option list.
+func EncodeNOP(b []byte) int {
+ if len(b) == 0 {
+ return 0
+ }
+ b[0] = TCPOptionNOP
+ return 1
+}
+
+// AddTCPOptionPadding adds the required number of TCPOptionNOP to quad align
+// the option buffer. It adds padding bytes after the offset specified and
+// returns the number of padding bytes added. The passed in options slice
+// must have space for the padding bytes.
+func AddTCPOptionPadding(options []byte, offset int) int {
+ paddingToAdd := -offset & 3
+ // Now add any padding bytes that might be required to quad align the
+ // options.
+ for i := offset; i < offset+paddingToAdd; i++ {
+ options[i] = TCPOptionNOP
+ }
+ return paddingToAdd
+}
diff --git a/pkg/tcpip/header/tcp_test.go b/pkg/tcpip/header/tcp_test.go
new file mode 100644
index 000000000..27d43479a
--- /dev/null
+++ b/pkg/tcpip/header/tcp_test.go
@@ -0,0 +1,134 @@
+package header_test
+
+import (
+ "reflect"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
// TestEncodeSACKBlocks verifies that EncodeSACKBlocks writes at most
// TCPMaxSACKBlocks blocks and never more blocks than fit in the provided
// buffer, by round-tripping through ParseTCPOptions.
func TestEncodeSACKBlocks(t *testing.T) {
	testCases := []struct {
		sackBlocks []header.SACKBlock
		want       []header.SACKBlock
		bufSize    int
	}{
		{
			[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
			[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
			40,
		},
		{
			[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
			[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}},
			30,
		},
		{
			[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
			[]header.SACKBlock{{10, 20}, {22, 30}},
			20,
		},
		{
			[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
			[]header.SACKBlock{{10, 20}},
			10,
		},
		{
			// 8 bytes cannot hold even one block (2 bytes of
			// kind/length plus 8 per block), so nothing is encoded.
			[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
			nil,
			8,
		},
		{
			// A large buffer is still capped at TCPMaxSACKBlocks.
			[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}, {52, 60}, {62, 70}},
			[]header.SACKBlock{{10, 20}, {22, 30}, {32, 40}, {42, 50}},
			60,
		},
	}
	for _, tc := range testCases {
		b := make([]byte, tc.bufSize)
		t.Logf("testing: %v", tc)
		header.EncodeSACKBlocks(tc.sackBlocks, b)
		opts := header.ParseTCPOptions(b)
		if got, want := opts.SACKBlocks, tc.want; !reflect.DeepEqual(got, want) {
			t.Errorf("header.EncodeSACKBlocks(%v, %v), encoded blocks got: %v, want: %v", tc.sackBlocks, b, got, want)
		}
	}
}
+
// TestTCPParseOptions exercises ParseTCPOptions with well-formed, malformed
// and mixed option sequences, checking the parsed TCPOptions struct.
func TestTCPParseOptions(t *testing.T) {
	type tsOption struct {
		tsVal uint32
		tsEcr uint32
	}

	// generateOptions builds a raw option byte slice holding an optional
	// timestamp option followed by the given SACK blocks.
	generateOptions := func(tsOpt *tsOption, sackBlocks []header.SACKBlock) []byte {
		l := 0
		if tsOpt != nil {
			l += 10
		}
		if len(sackBlocks) != 0 {
			l += len(sackBlocks)*8 + 2
		}
		b := make([]byte, l)
		offset := 0
		if tsOpt != nil {
			offset = header.EncodeTSOption(tsOpt.tsVal, tsOpt.tsEcr, b)
		}
		header.EncodeSACKBlocks(sackBlocks, b[offset:])
		return b
	}

	testCases := []struct {
		b    []byte
		want header.TCPOptions
	}{
		// Trivial cases.
		{nil, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionNOP, header.TCPOptionNOP}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionEOL}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionNOP, header.TCPOptionEOL, header.TCPOptionTS, 10, 1, 1}, header.TCPOptions{false, 0, 0, nil}},

		// Test timestamp parsing.
		{[]byte{header.TCPOptionNOP, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
		{[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},

		// Test malformed timestamp option.
		{[]byte{header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 1, 1}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionNOP, header.TCPOptionTS, 8, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},

		// Test SACKBlock parsing.
		{[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}}}},
		{[]byte{header.TCPOptionSACK, 18, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{1, 10}, {11, 12}}}},

		// Test malformed SACK option.
		{[]byte{header.TCPOptionSACK, 0}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionSACK, 8, 0, 0, 0, 1, 0, 0, 0, 10}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionSACK, 17, 0, 0, 0, 1, 0, 0, 0, 10, 0, 0, 0, 11, 0, 0, 0, 12}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionSACK}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionSACK, 10}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionSACK, 10, 0, 0, 0, 1, 0, 0, 0}, header.TCPOptions{false, 0, 0, nil}},

		// Test Timestamp + SACK block parsing.
		{generateOptions(&tsOption{1, 1}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 1, []header.SACKBlock{{1, 10}, {11, 12}}}},
		{generateOptions(&tsOption{1, 2}, []header.SACKBlock{{1, 10}, {11, 12}}), header.TCPOptions{true, 1, 2, []header.SACKBlock{{1, 10}, {11, 12}}}},
		{generateOptions(&tsOption{1, 3}, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}, {15, 16}}), header.TCPOptions{true, 1, 3, []header.SACKBlock{{1, 10}, {11, 12}, {13, 14}, {14, 15}}}},

		// Test valid timestamp + malformed SACK block parsing.
		{[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK}, header.TCPOptions{true, 1, 1, nil}},
		{[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10}, header.TCPOptions{true, 1, 1, nil}},
		{[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 10, 0, 0, 0}, header.TCPOptions{true, 1, 1, nil}},
		{[]byte{header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{true, 1, 1, nil}},
		{[]byte{header.TCPOptionSACK, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
		{[]byte{header.TCPOptionSACK, 10, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{134873088, 65536}}}},
		{[]byte{header.TCPOptionSACK, 10, 0, 0, 0, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, []header.SACKBlock{{8, 167772160}}}},
		{[]byte{header.TCPOptionSACK, 11, 0, 0, 0, 1, 0, 0, 0, 1, header.TCPOptionTS, 10, 0, 0, 0, 1, 0, 0, 0, 1}, header.TCPOptions{false, 0, 0, nil}},
	}
	for _, tc := range testCases {
		if got, want := header.ParseTCPOptions(tc.b), tc.want; !reflect.DeepEqual(got, want) {
			t.Errorf("ParseTCPOptions(%v) = %v, want: %v", tc.b, got, tc.want)
		}
	}
}
diff --git a/pkg/tcpip/header/udp.go b/pkg/tcpip/header/udp.go
new file mode 100644
index 000000000..7c2548634
--- /dev/null
+++ b/pkg/tcpip/header/udp.go
@@ -0,0 +1,106 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package header
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ udpSrcPort = 0
+ udpDstPort = 2
+ udpLength = 4
+ udpChecksum = 6
+)
+
+// UDPFields contains the fields of a UDP packet. It is used to describe the
+// fields of a packet that needs to be encoded.
+type UDPFields struct {
+ // SrcPort is the "source port" field of a UDP packet.
+ SrcPort uint16
+
+ // DstPort is the "destination port" field of a UDP packet.
+ DstPort uint16
+
+ // Length is the "length" field of a UDP packet.
+ Length uint16
+
+ // Checksum is the "checksum" field of a UDP packet.
+ Checksum uint16
+}
+
+// UDP represents a UDP header stored in a byte array.
+type UDP []byte
+
+const (
+ // UDPMinimumSize is the minimum size of a valid UDP packet.
+ UDPMinimumSize = 8
+
+ // UDPProtocolNumber is UDP's transport protocol number.
+ UDPProtocolNumber tcpip.TransportProtocolNumber = 17
+)
+
// SourcePort returns the "source port" field of the udp header.
func (b UDP) SourcePort() uint16 {
	return binary.BigEndian.Uint16(b[udpSrcPort:])
}

// DestinationPort returns the "destination port" field of the udp header.
func (b UDP) DestinationPort() uint16 {
	return binary.BigEndian.Uint16(b[udpDstPort:])
}

// Length returns the "length" field of the udp header.
func (b UDP) Length() uint16 {
	return binary.BigEndian.Uint16(b[udpLength:])
}

// Payload returns the data contained in the UDP datagram. The returned
// slice aliases b and extends to the end of b, regardless of the value of
// the length field.
func (b UDP) Payload() []byte {
	return b[UDPMinimumSize:]
}

// Checksum returns the "checksum" field of the udp header.
func (b UDP) Checksum() uint16 {
	return binary.BigEndian.Uint16(b[udpChecksum:])
}

// SetSourcePort sets the "source port" field of the udp header.
func (b UDP) SetSourcePort(port uint16) {
	binary.BigEndian.PutUint16(b[udpSrcPort:], port)
}

// SetDestinationPort sets the "destination port" field of the udp header.
func (b UDP) SetDestinationPort(port uint16) {
	binary.BigEndian.PutUint16(b[udpDstPort:], port)
}

// SetChecksum sets the "checksum" field of the udp header.
func (b UDP) SetChecksum(checksum uint16) {
	binary.BigEndian.PutUint16(b[udpChecksum:], checksum)
}
+
+// CalculateChecksum calculates the checksum of the udp packet, given the total
+// length of the packet and the checksum of the network-layer pseudo-header
+// (excluding the total length) and the checksum of the payload.
+func (b UDP) CalculateChecksum(partialChecksum uint16, totalLen uint16) uint16 {
+ // Add the length portion of the checksum to the pseudo-checksum.
+ tmp := make([]byte, 2)
+ binary.BigEndian.PutUint16(tmp, totalLen)
+ checksum := Checksum(tmp, partialChecksum)
+
+ // Calculate the rest of the checksum.
+ return Checksum(b[:UDPMinimumSize], checksum)
+}
+
+// Encode encodes all the fields of the udp header.
+func (b UDP) Encode(u *UDPFields) {
+ binary.BigEndian.PutUint16(b[udpSrcPort:], u.SrcPort)
+ binary.BigEndian.PutUint16(b[udpDstPort:], u.DstPort)
+ binary.BigEndian.PutUint16(b[udpLength:], u.Length)
+ binary.BigEndian.PutUint16(b[udpChecksum:], u.Checksum)
+}
diff --git a/pkg/tcpip/link/channel/BUILD b/pkg/tcpip/link/channel/BUILD
new file mode 100644
index 000000000..b58a76699
--- /dev/null
+++ b/pkg/tcpip/link/channel/BUILD
@@ -0,0 +1,15 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "channel",
+ srcs = ["channel.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/channel/channel.go b/pkg/tcpip/link/channel/channel.go
new file mode 100644
index 000000000..cebc34553
--- /dev/null
+++ b/pkg/tcpip/link/channel/channel.go
@@ -0,0 +1,110 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package channel provides the implemention of channel-based data-link layer
+// endpoints. Such endpoints allow injection of inbound packets and store
+// outbound packets in a channel.
+package channel
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
// PacketInfo holds all the information about an outbound packet.
type PacketInfo struct {
	// Header holds the header bytes that were prepended to the packet.
	Header buffer.View
	// Payload holds a copy of the packet's payload, if any.
	Payload buffer.View
	// Proto is the network protocol number the packet was written with.
	Proto tcpip.NetworkProtocolNumber
}

// Endpoint is link layer endpoint that stores outbound packets in a channel
// and allows injection of inbound packets.
type Endpoint struct {
	// dispatcher is the network dispatcher registered via Attach; packets
	// injected with Inject are delivered to it.
	dispatcher stack.NetworkDispatcher
	// mtu is returned verbatim by MTU.
	mtu uint32
	// linkAddr is returned verbatim by LinkAddress.
	linkAddr tcpip.LinkAddress

	// C is where outbound packets are queued.
	C chan PacketInfo
}
+
// New creates a new channel endpoint and registers it with the stack. size
// bounds the number of outbound packets that can be queued in C.
func New(size int, mtu uint32, linkAddr tcpip.LinkAddress) (tcpip.LinkEndpointID, *Endpoint) {
	e := &Endpoint{
		C: make(chan PacketInfo, size),
		mtu: mtu,
		linkAddr: linkAddr,
	}

	return stack.RegisterLinkEndpoint(e), e
}
+
// Drain removes all outbound packets from the channel and counts them. It
// does not block: it returns as soon as the channel is empty.
func (e *Endpoint) Drain() int {
	c := 0
	for {
		select {
		case <-e.C:
			c++
		default:
			// Channel is empty; c is the number of packets removed.
			return c
		}
	}
}
+
// Inject injects an inbound packet. Attach must have been called first;
// otherwise e.dispatcher is nil and this panics.
func (e *Endpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
	// The view is cloned before delivery — presumably so the dispatcher
	// may retain it after Inject returns; confirm Clone's semantics.
	uu := vv.Clone(nil)
	e.dispatcher.DeliverNetworkPacket(e, "", protocol, &uu)
}
+
// Attach saves the stack network-layer dispatcher for use later when packets
// are injected.
func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
	e.dispatcher = dispatcher
}

// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
// during construction.
func (e *Endpoint) MTU() uint32 {
	return e.mtu
}

// Capabilities implements stack.LinkEndpoint.Capabilities. It returns 0,
// i.e. no capability flags set.
func (*Endpoint) Capabilities() stack.LinkEndpointCapabilities {
	return 0
}

// MaxHeaderLength returns the maximum size of the link layer header. Given it
// doesn't have a header, it just returns 0.
func (*Endpoint) MaxHeaderLength() uint16 {
	return 0
}

// LinkAddress returns the link address of this endpoint.
func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
	return e.linkAddr
}
+
// WritePacket stores outbound packets into the channel. If the channel is
// full the packet is silently dropped; WritePacket never blocks and always
// reports success.
func (e *Endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
	p := PacketInfo{
		Header: hdr.View(),
		Proto: protocol,
	}

	// Copy the payload so the queued packet does not alias the caller's
	// buffer.
	if payload != nil {
		p.Payload = make(buffer.View, len(payload))
		copy(p.Payload, payload)
	}

	// Non-blocking send: drop the packet if the channel is full.
	select {
	case e.C <- p:
	default:
	}

	return nil
}
diff --git a/pkg/tcpip/link/fdbased/BUILD b/pkg/tcpip/link/fdbased/BUILD
new file mode 100644
index 000000000..b5ab1ea6a
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/BUILD
@@ -0,0 +1,32 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "fdbased",
+ srcs = ["endpoint.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/fdbased",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "fdbased_test",
+ size = "small",
+ srcs = ["endpoint_test.go"],
+ embed = [":fdbased"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
new file mode 100644
index 000000000..da74cd644
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -0,0 +1,261 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fdbased provides the implementation of data-link layer endpoints
+// backed by boundary-preserving file descriptors (e.g., TUN devices,
+// seqpacket/datagram sockets).
+//
+// FD based endpoints can be used in the networking stack by calling New() to
+// create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+package fdbased
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// BufConfig defines the shape of the vectorised view used to read packets from the NIC.
+var BufConfig = []int{128, 256, 256, 512, 1024, 2048, 4096, 8192, 16384, 32768}
+
+// endpoint is the fd-backed implementation of stack.LinkEndpoint. Instances
+// are created via New and registered with the stack.
+type endpoint struct {
+ // fd is the file descriptor used to send and receive packets.
+ fd int
+
+ // mtu (maximum transmission unit) is the maximum size of a packet.
+ mtu uint32
+
+ // hdrSize specifies the link-layer header size. If set to 0, no header
+ // is added/removed; otherwise an ethernet header is used.
+ hdrSize int
+
+ // addr is the address of the endpoint.
+ addr tcpip.LinkAddress
+
+ // caps holds the endpoint capabilities.
+ caps stack.LinkEndpointCapabilities
+
+ // closed is a function to be called when the FD's peer (if any) closes
+ // its end of the communication pipe.
+ closed func(*tcpip.Error)
+
+ // vv, iovecs and views are scratch state reused by the dispatch loop
+ // so inbound reads do not allocate on every packet.
+ vv *buffer.VectorisedView
+ iovecs []syscall.Iovec
+ views []buffer.View
+}
+
+// Options specify the details about the fd-based endpoint to be created.
+type Options struct {
+ // FD is the file descriptor backing the endpoint. The endpoint does
+ // not take ownership; the FD must stay open for the endpoint's lifetime.
+ FD int
+ // MTU is the maximum transmission unit reported by the endpoint.
+ MTU uint32
+ // EthernetHeader, if true, makes the endpoint add/strip an ethernet
+ // header on outbound/inbound packets.
+ EthernetHeader bool
+ // ChecksumOffload, if true, advertises the checksum-offload capability.
+ ChecksumOffload bool
+ // ClosedFunc, if non-nil, is invoked when the dispatch loop stops
+ // (e.g. the FD's peer closed its end), with the terminating error.
+ ClosedFunc func(*tcpip.Error)
+ // Address is the link address of the endpoint.
+ Address tcpip.LinkAddress
+}
+
+// New creates a new fd-based endpoint.
+//
+// Makes fd non-blocking, but does not take ownership of fd, which must remain
+// open for the lifetime of the returned endpoint.
+func New(opts *Options) tcpip.LinkEndpointID {
+ // NOTE(review): the error from SetNonblock is ignored here — confirm
+ // that a failure is acceptable (reads would then block the loop).
+ syscall.SetNonblock(opts.FD, true)
+
+ caps := stack.LinkEndpointCapabilities(0)
+ if opts.ChecksumOffload {
+ caps |= stack.CapabilityChecksumOffload
+ }
+
+ hdrSize := 0
+ if opts.EthernetHeader {
+ // Ethernet framing implies link-level addressing, so address
+ // resolution is required.
+ hdrSize = header.EthernetMinimumSize
+ caps |= stack.CapabilityResolutionRequired
+ }
+
+ e := &endpoint{
+ fd: opts.FD,
+ mtu: opts.MTU,
+ caps: caps,
+ closed: opts.ClosedFunc,
+ addr: opts.Address,
+ hdrSize: hdrSize,
+ views: make([]buffer.View, len(BufConfig)),
+ iovecs: make([]syscall.Iovec, len(BufConfig)),
+ }
+ // The vectorised view is reused for every inbound packet; the dispatch
+ // loop swaps fresh views into it.
+ vv := buffer.NewVectorisedView(0, e.views)
+ e.vv = &vv
+ return stack.RegisterLinkEndpoint(e)
+}
+
+// Attach launches the goroutine that reads packets from the file descriptor
+// and dispatches them via the provided dispatcher. The goroutine runs until
+// dispatchLoop returns.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ go e.dispatchLoop(dispatcher) // S/R-FIXME
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction.
+func (e *endpoint) MTU() uint32 {
+ return e.mtu
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.caps
+}
+
+// MaxHeaderLength returns the maximum size of the link-layer header.
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return uint16(e.hdrSize)
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.addr
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if e.hdrSize > 0 {
+ // Add ethernet header if needed.
+ eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
+ eth.Encode(&header.EthernetFields{
+ DstAddr: r.RemoteLinkAddress,
+ SrcAddr: e.addr,
+ Type: protocol,
+ })
+ }
+
+ // Without a payload a single write suffices; otherwise use writev so
+ // header and payload go out in one syscall.
+ if len(payload) == 0 {
+ return rawfile.NonBlockingWrite(e.fd, hdr.UsedBytes())
+
+ }
+
+ return rawfile.NonBlockingWrite2(e.fd, hdr.UsedBytes(), payload)
+}
+
+// capViews caps the lengths of the views so that they add up to exactly n
+// bytes (the amount actually read) and returns the number of views that
+// contain data. If n covers all configured buffers, every view is used.
+func (e *endpoint) capViews(n int, buffers []int) int {
+ c := 0
+ for i, s := range buffers {
+ c += s
+ if c >= n {
+ // Trim the last used view down to the remaining bytes.
+ e.views[i].CapLength(s - (c - n))
+ return i + 1
+ }
+ }
+ return len(buffers)
+}
+
+// allocateViews backs the leading nil views (released by the previous
+// dispatch) with fresh buffers and points the matching iovecs at them. It
+// stops at the first non-nil view because every view after it is still
+// backed by an unconsumed buffer.
+func (e *endpoint) allocateViews(bufConfig []int) {
+ for i, v := range e.views {
+ if v != nil {
+ break
+ }
+ b := buffer.NewView(bufConfig[i])
+ e.views[i] = b
+ e.iovecs[i] = syscall.Iovec{
+ Base: &b[0],
+ Len: uint64(len(b)),
+ }
+ }
+}
+
+// dispatch reads one packet from the file descriptor and dispatches it to
+// the network stack via d. It returns false when the dispatch loop should
+// stop: on a read error, or when the read is no larger than the link-layer
+// header (presumably n == 0 means the peer closed its end — confirm).
+//
+// NOTE(review): largeV is currently unused; reads land in e.views through
+// e.iovecs instead.
+func (e *endpoint) dispatch(d stack.NetworkDispatcher, largeV buffer.View) (bool, *tcpip.Error) {
+ // Make sure the views released by the previous dispatch are backed by
+ // fresh buffers before reading into them.
+ e.allocateViews(BufConfig)
+
+ n, err := rawfile.BlockingReadv(e.fd, e.iovecs)
+ if err != nil {
+ return false, err
+ }
+
+ if n <= e.hdrSize {
+ return false, nil
+ }
+
+ var p tcpip.NetworkProtocolNumber
+ var addr tcpip.LinkAddress
+ if e.hdrSize > 0 {
+ // The first view (BufConfig[0] bytes) is large enough to hold
+ // the full ethernet header, so decode it from there.
+ eth := header.Ethernet(e.views[0])
+ p = eth.Type()
+ addr = eth.SourceAddress()
+ } else {
+ // We don't get any indication of what the packet is, so try to guess
+ // if it's an IPv4 or IPv6 packet.
+ switch header.IPVersion(e.views[0]) {
+ case header.IPv4Version:
+ p = header.IPv4ProtocolNumber
+ case header.IPv6Version:
+ p = header.IPv6ProtocolNumber
+ default:
+ // Unknown IP version: drop the packet but keep reading.
+ return true, nil
+ }
+ }
+
+ // Trim the views to the bytes actually read, strip the link-layer
+ // header, and hand the packet to the network layer.
+ used := e.capViews(n, BufConfig)
+ e.vv.SetViews(e.views[:used])
+ e.vv.SetSize(n)
+ e.vv.TrimFront(e.hdrSize)
+
+ d.DeliverNetworkPacket(e, addr, p, e.vv)
+
+ // Prepare e.views for another packet: release used views.
+ for i := 0; i < used; i++ {
+ e.views[i] = nil
+ }
+
+ return true, nil
+}
+
+// dispatchLoop reads packets from the file descriptor in a loop and dispatches
+// them to the network stack. It runs until dispatch reports an error or asks
+// to stop, invoking e.closed (if set) with the terminating error first.
+func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) *tcpip.Error {
+ // NOTE(review): v is passed to dispatch as largeV but is currently
+ // unused there.
+ v := buffer.NewView(header.MaxIPPacketSize)
+ for {
+ cont, err := e.dispatch(d, v)
+ if err != nil || !cont {
+ if e.closed != nil {
+ e.closed(err)
+ }
+ return err
+ }
+ }
+}
+
+// InjectableEndpoint is an injectable fd-based endpoint. The endpoint writes
+// to the FD, but does not read from it. All reads come from injected packets.
+type InjectableEndpoint struct {
+ endpoint
+
+ dispatcher stack.NetworkDispatcher
+}
+
+// Attach saves the stack network-layer dispatcher for use later when packets
+// are injected.
+func (e *InjectableEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// Inject injects an inbound packet.
+func (e *InjectableEndpoint) Inject(protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+ e.dispatcher.DeliverNetworkPacket(e, "", protocol, vv)
+}
+
+// NewInjectable creates a new fd-based InjectableEndpoint. As with New, fd is
+// made non-blocking but ownership is not taken; unlike New, no read loop is
+// ever started — inbound packets arrive only via Inject.
+func NewInjectable(fd int, mtu uint32) (tcpip.LinkEndpointID, *InjectableEndpoint) {
+ syscall.SetNonblock(fd, true)
+
+ e := &InjectableEndpoint{endpoint: endpoint{
+ fd: fd,
+ mtu: mtu,
+ }}
+
+ return stack.RegisterLinkEndpoint(e), e
+}
diff --git a/pkg/tcpip/link/fdbased/endpoint_test.go b/pkg/tcpip/link/fdbased/endpoint_test.go
new file mode 100644
index 000000000..f7bbb28e1
--- /dev/null
+++ b/pkg/tcpip/link/fdbased/endpoint_test.go
@@ -0,0 +1,336 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fdbased
+
+import (
+ "fmt"
+ "math/rand"
+ "reflect"
+ "syscall"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+type packetInfo struct {
+ raddr tcpip.LinkAddress
+ proto tcpip.NetworkProtocolNumber
+ contents buffer.View
+}
+
+type context struct {
+ t *testing.T
+ fds [2]int
+ ep stack.LinkEndpoint
+ ch chan packetInfo
+ done chan struct{}
+}
+
+func newContext(t *testing.T, opt *Options) *context {
+ fds, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_SEQPACKET, 0)
+ if err != nil {
+ t.Fatalf("Socketpair failed: %v", err)
+ }
+
+ done := make(chan struct{}, 1)
+ opt.ClosedFunc = func(*tcpip.Error) {
+ done <- struct{}{}
+ }
+
+ opt.FD = fds[1]
+ ep := stack.FindLinkEndpoint(New(opt)).(*endpoint)
+
+ c := &context{
+ t: t,
+ fds: fds,
+ ep: ep,
+ ch: make(chan packetInfo, 100),
+ done: done,
+ }
+
+ ep.Attach(c)
+
+ return c
+}
+
+func (c *context) cleanup() {
+ syscall.Close(c.fds[0])
+ <-c.done
+ syscall.Close(c.fds[1])
+}
+
+func (c *context) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+ c.ch <- packetInfo{remoteLinkAddr, protocol, vv.ToView()}
+}
+
+func TestNoEthernetProperties(t *testing.T) {
+ const mtu = 1500
+ c := newContext(t, &Options{MTU: mtu})
+ defer c.cleanup()
+
+ if want, v := uint16(0), c.ep.MaxHeaderLength(); want != v {
+ t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
+ }
+
+ if want, v := uint32(mtu), c.ep.MTU(); want != v {
+ t.Fatalf("MTU() = %v, want %v", v, want)
+ }
+}
+
+func TestEthernetProperties(t *testing.T) {
+ const mtu = 1500
+ c := newContext(t, &Options{EthernetHeader: true, MTU: mtu})
+ defer c.cleanup()
+
+ if want, v := uint16(header.EthernetMinimumSize), c.ep.MaxHeaderLength(); want != v {
+ t.Fatalf("MaxHeaderLength() = %v, want %v", v, want)
+ }
+
+ if want, v := uint32(mtu), c.ep.MTU(); want != v {
+ t.Fatalf("MTU() = %v, want %v", v, want)
+ }
+}
+
+func TestAddress(t *testing.T) {
+ const mtu = 1500
+ addrs := []tcpip.LinkAddress{"", "abc", "def"}
+ for _, a := range addrs {
+ t.Run(fmt.Sprintf("Address: %q", a), func(t *testing.T) {
+ c := newContext(t, &Options{Address: a, MTU: mtu})
+ defer c.cleanup()
+
+ if want, v := a, c.ep.LinkAddress(); want != v {
+ t.Fatalf("LinkAddress() = %v, want %v", v, want)
+ }
+ })
+ }
+}
+
+func TestWritePacket(t *testing.T) {
+ const (
+ mtu = 1500
+ laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
+ raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
+ proto = 10
+ )
+
+ lengths := []int{0, 100, 1000}
+ eths := []bool{true, false}
+
+ for _, eth := range eths {
+ for _, plen := range lengths {
+ t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
+ c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
+ defer c.cleanup()
+
+ r := &stack.Route{
+ RemoteLinkAddress: raddr,
+ }
+
+ // Build header.
+ hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()) + 100)
+ b := hdr.Prepend(100)
+ for i := range b {
+ b[i] = uint8(rand.Intn(256))
+ }
+
+ // Build payload and write.
+ payload := make([]byte, plen)
+ for i := range payload {
+ payload[i] = uint8(rand.Intn(256))
+ }
+ want := append(hdr.UsedBytes(), payload...)
+ if err := c.ep.WritePacket(r, &hdr, payload, proto); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+
+ // Read from fd, then compare with what we wrote.
+ b = make([]byte, mtu)
+ n, err := syscall.Read(c.fds[0], b)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ b = b[:n]
+ if eth {
+ h := header.Ethernet(b)
+ b = b[header.EthernetMinimumSize:]
+
+ if a := h.SourceAddress(); a != laddr {
+ t.Fatalf("SourceAddress() = %v, want %v", a, laddr)
+ }
+
+ if a := h.DestinationAddress(); a != raddr {
+ t.Fatalf("DestinationAddress() = %v, want %v", a, raddr)
+ }
+
+ if et := h.Type(); et != proto {
+ t.Fatalf("Type() = %v, want %v", et, proto)
+ }
+ }
+ if len(b) != len(want) {
+ t.Fatalf("Read returned %v bytes, want %v", len(b), len(want))
+ }
+ if !reflect.DeepEqual(b, want) {
+ t.Fatalf("Read returned %x, want %x", b, want)
+ }
+ })
+ }
+ }
+}
+
+func TestDeliverPacket(t *testing.T) {
+ const (
+ mtu = 1500
+ laddr = tcpip.LinkAddress("\x11\x22\x33\x44\x55\x66")
+ raddr = tcpip.LinkAddress("\x77\x88\x99\xaa\xbb\xcc")
+ proto = 10
+ )
+
+ lengths := []int{100, 1000}
+ eths := []bool{true, false}
+
+ for _, eth := range eths {
+ for _, plen := range lengths {
+ t.Run(fmt.Sprintf("Eth=%v,PayloadLen=%v", eth, plen), func(t *testing.T) {
+ c := newContext(t, &Options{Address: laddr, MTU: mtu, EthernetHeader: eth})
+ defer c.cleanup()
+
+ // Build packet.
+ b := make([]byte, plen)
+ all := b
+ for i := range b {
+ b[i] = uint8(rand.Intn(256))
+ }
+
+ if !eth {
+ // So that it looks like an IPv4 packet.
+ b[0] = 0x40
+ } else {
+ hdr := make(header.Ethernet, header.EthernetMinimumSize)
+ hdr.Encode(&header.EthernetFields{
+ SrcAddr: raddr,
+ DstAddr: laddr,
+ Type: proto,
+ })
+ all = append(hdr, b...)
+ }
+
+ // Write packet via the file descriptor.
+ if _, err := syscall.Write(c.fds[0], all); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Receive packet through the endpoint.
+ select {
+ case pi := <-c.ch:
+ want := packetInfo{
+ raddr: raddr,
+ proto: proto,
+ contents: b,
+ }
+ if !eth {
+ want.proto = header.IPv4ProtocolNumber
+ want.raddr = ""
+ }
+ if !reflect.DeepEqual(want, pi) {
+ t.Fatalf("Unexpected received packet: %+v, want %+v", pi, want)
+ }
+ case <-time.After(10 * time.Second):
+ t.Fatalf("Timed out waiting for packet")
+ }
+ })
+ }
+ }
+}
+
+func TestBufConfigMaxLength(t *testing.T) {
+ got := 0
+ for _, i := range BufConfig {
+ got += i
+ }
+ want := header.MaxIPPacketSize // maximum TCP packet size
+ if got < want {
+ t.Errorf("total buffer size is invalid: got %d, want >= %d", got, want)
+ }
+}
+
+func TestBufConfigFirst(t *testing.T) {
+ // The stack assumes that the TCP/IP header is entirely contained in the first view.
+ // Therefore, the first view needs to be large enough to contain the maximum TCP/IP
+ // header, which is 120 bytes (60 bytes for IP + 60 bytes for TCP).
+ want := 120
+ got := BufConfig[0]
+ if got < want {
+ t.Errorf("first view has an invalid size: got %d, want >= %d", got, want)
+ }
+}
+
+func build(bufConfig []int) *endpoint {
+ e := &endpoint{
+ views: make([]buffer.View, len(bufConfig)),
+ iovecs: make([]syscall.Iovec, len(bufConfig)),
+ }
+ e.allocateViews(bufConfig)
+ return e
+}
+
+var capLengthTestCases = []struct {
+ comment string
+ config []int
+ n int
+ wantUsed int
+ wantLengths []int
+}{
+ {
+ comment: "Single slice",
+ config: []int{2},
+ n: 1,
+ wantUsed: 1,
+ wantLengths: []int{1},
+ },
+ {
+ comment: "Multiple slices",
+ config: []int{1, 2},
+ n: 2,
+ wantUsed: 2,
+ wantLengths: []int{1, 1},
+ },
+ {
+ comment: "Entire buffer",
+ config: []int{1, 2},
+ n: 3,
+ wantUsed: 2,
+ wantLengths: []int{1, 2},
+ },
+ {
+ comment: "Entire buffer but not on the last slice",
+ config: []int{1, 2, 3},
+ n: 3,
+ wantUsed: 2,
+ wantLengths: []int{1, 2, 3},
+ },
+}
+
+func TestCapLength(t *testing.T) {
+ for _, c := range capLengthTestCases {
+ e := build(c.config)
+ used := e.capViews(c.n, c.config)
+ if used != c.wantUsed {
+ t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %d. Want %d", c.comment, c.n, c.config, used, c.wantUsed)
+ }
+ lengths := make([]int, len(e.views))
+ for i, v := range e.views {
+ lengths[i] = len(v)
+ }
+ if !reflect.DeepEqual(lengths, c.wantLengths) {
+ t.Errorf("Test \"%s\" failed when calling capViews(%d, %v). Got %v. Want %v", c.comment, c.n, c.config, lengths, c.wantLengths)
+ }
+
+ }
+}
diff --git a/pkg/tcpip/link/loopback/BUILD b/pkg/tcpip/link/loopback/BUILD
new file mode 100644
index 000000000..b454d0839
--- /dev/null
+++ b/pkg/tcpip/link/loopback/BUILD
@@ -0,0 +1,15 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "loopback",
+ srcs = ["loopback.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/loopback/loopback.go b/pkg/tcpip/link/loopback/loopback.go
new file mode 100644
index 000000000..1a9cd09d7
--- /dev/null
+++ b/pkg/tcpip/link/loopback/loopback.go
@@ -0,0 +1,74 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package loopback provides the implementation of loopback data-link layer
+// endpoints. Such endpoints just turn outbound packets into inbound ones.
+//
+// Loopback endpoints can be used in the networking stack by calling New() to
+// create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+package loopback
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+type endpoint struct {
+ dispatcher stack.NetworkDispatcher
+}
+
+// New creates a new loopback endpoint. This link-layer endpoint just turns
+// outbound packets into inbound packets.
+func New() tcpip.LinkEndpointID {
+ return stack.RegisterLinkEndpoint(&endpoint{})
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It just saves the stack network-
+// layer dispatcher for later use when packets need to be dispatched.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns a constant that matches the
+// linux loopback interface.
+func (*endpoint) MTU() uint32 {
+ return 65536
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. Loopback advertises
+// itself as supporting checksum offload, but in reality it's just omitted.
+func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return stack.CapabilityChecksumOffload
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. Given that the
+// loopback interface doesn't have a header, it just returns 0.
+func (*endpoint) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*endpoint) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket. It delivers outbound
+// packets straight back to the network-layer dispatcher, turning them into
+// inbound packets.
+func (e *endpoint) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if len(payload) == 0 {
+ // We don't have a payload, so just use the buffer from the
+ // header as the full packet.
+ v := hdr.View()
+ vv := v.ToVectorisedView([1]buffer.View{})
+ e.dispatcher.DeliverNetworkPacket(e, "", protocol, &vv)
+ } else {
+ // Stitch header and payload together into a two-view
+ // vectorised view before delivery.
+ views := []buffer.View{hdr.View(), payload}
+ vv := buffer.NewVectorisedView(len(views[0])+len(views[1]), views)
+ e.dispatcher.DeliverNetworkPacket(e, "", protocol, &vv)
+ }
+
+ return nil
+}
diff --git a/pkg/tcpip/link/rawfile/BUILD b/pkg/tcpip/link/rawfile/BUILD
new file mode 100644
index 000000000..4c63af0ea
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/BUILD
@@ -0,0 +1,17 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "rawfile",
+ srcs = [
+ "blockingpoll_amd64.s",
+ "errors.go",
+ "rawfile_unsafe.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = ["//pkg/tcpip"],
+)
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_amd64.s b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
new file mode 100644
index 000000000..88206fc87
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/blockingpoll_amd64.s
@@ -0,0 +1,26 @@
+#include "textflag.h"
+
+// blockingPoll makes the poll() syscall while calling the version of
+// entersyscall that relinquishes the P so that other Gs can run. This is meant
+// to be called in cases when the syscall is expected to block.
+//
+// func blockingPoll(fds unsafe.Pointer, nfds int, timeout int64) (n int, err syscall.Errno)
+TEXT ·blockingPoll(SB),NOSPLIT,$0-40
+ CALL runtime·entersyscallblock(SB)
+ MOVQ fds+0(FP), DI
+ MOVQ nfds+8(FP), SI
+ MOVQ timeout+16(FP), DX
+ MOVQ $0x7, AX // SYS_POLL
+ SYSCALL
+ CMPQ AX, $0xfffffffffffff001
+ JLS ok
+ MOVQ $-1, n+24(FP)
+ NEGQ AX
+ MOVQ AX, err+32(FP)
+ CALL runtime·exitsyscall(SB)
+ RET
+ok:
+ MOVQ AX, n+24(FP)
+ MOVQ $0, err+32(FP)
+ CALL runtime·exitsyscall(SB)
+ RET
diff --git a/pkg/tcpip/link/rawfile/errors.go b/pkg/tcpip/link/rawfile/errors.go
new file mode 100644
index 000000000..b6e7b3d71
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/errors.go
@@ -0,0 +1,40 @@
+package rawfile
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+var translations = map[syscall.Errno]*tcpip.Error{
+ syscall.EEXIST: tcpip.ErrDuplicateAddress,
+ syscall.ENETUNREACH: tcpip.ErrNoRoute,
+ syscall.EINVAL: tcpip.ErrInvalidEndpointState,
+ syscall.EALREADY: tcpip.ErrAlreadyConnecting,
+ syscall.EISCONN: tcpip.ErrAlreadyConnected,
+ syscall.EADDRINUSE: tcpip.ErrPortInUse,
+ syscall.EADDRNOTAVAIL: tcpip.ErrBadLocalAddress,
+ syscall.EPIPE: tcpip.ErrClosedForSend,
+ syscall.EWOULDBLOCK: tcpip.ErrWouldBlock,
+ syscall.ECONNREFUSED: tcpip.ErrConnectionRefused,
+ syscall.ETIMEDOUT: tcpip.ErrTimeout,
+ syscall.EINPROGRESS: tcpip.ErrConnectStarted,
+ syscall.EDESTADDRREQ: tcpip.ErrDestinationRequired,
+ syscall.ENOTSUP: tcpip.ErrNotSupported,
+ syscall.ENOTTY: tcpip.ErrQueueSizeNotSupported,
+ syscall.ENOTCONN: tcpip.ErrNotConnected,
+ syscall.ECONNRESET: tcpip.ErrConnectionReset,
+ syscall.ECONNABORTED: tcpip.ErrConnectionAborted,
+}
+
+// TranslateErrno translates an errno from the syscall package into a
+// *tcpip.Error.
+//
+// Not all errnos are supported; an unrecognized errno is mapped to
+// tcpip.ErrInvalidEndpointState (the earlier claim that this function
+// panics did not match the implementation).
+func TranslateErrno(e syscall.Errno) *tcpip.Error {
+ if err, ok := translations[e]; ok {
+ return err
+ }
+ return tcpip.ErrInvalidEndpointState
+}
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
new file mode 100644
index 000000000..d3660e1b4
--- /dev/null
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -0,0 +1,161 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package rawfile contains utilities for using the netstack with raw host
+// files on Linux hosts.
+package rawfile
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+//go:noescape
+func blockingPoll(fds unsafe.Pointer, nfds int, timeout int64) (n int, err syscall.Errno)
+
+// GetMTU determines the MTU of a network interface device by issuing a
+// SIOCGIFMTU ioctl on a throwaway AF_UNIX datagram socket.
+func GetMTU(name string) (uint32, error) {
+ fd, err := syscall.Socket(syscall.AF_UNIX, syscall.SOCK_DGRAM, 0)
+ if err != nil {
+ return 0, err
+ }
+
+ defer syscall.Close(fd)
+
+ // Mirrors the C struct ifreq: a 16-byte interface name followed by a
+ // union, of which only the 32-bit MTU field is used here; the trailing
+ // padding presumably keeps the struct at the size the kernel expects.
+ var ifreq struct {
+ name [16]byte
+ mtu int32
+ _ [20]byte
+ }
+
+ copy(ifreq.name[:], name)
+ _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.SIOCGIFMTU, uintptr(unsafe.Pointer(&ifreq)))
+ if errno != 0 {
+ return 0, errno
+ }
+
+ return uint32(ifreq.mtu), nil
+}
+
+// NonBlockingWrite writes the given buffer to a file descriptor. It fails if
+// partial data is written.
+//
+// NOTE(review): the byte count returned by write(2) is discarded, so a short
+// write is not actually detected here — confirm whether that is intended.
+func NonBlockingWrite(fd int, buf []byte) *tcpip.Error {
+ // A nil pointer with zero length is a valid zero-byte write.
+ var ptr unsafe.Pointer
+ if len(buf) > 0 {
+ ptr = unsafe.Pointer(&buf[0])
+ }
+
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITE, uintptr(fd), uintptr(ptr), uintptr(len(buf)))
+ if e != 0 {
+ return TranslateErrno(e)
+ }
+
+ return nil
+}
+
+// NonBlockingWrite2 writes up to two byte slices to a file descriptor in a
+// single syscall. It fails if partial data is written.
+//
+// NOTE(review): the byte count returned by writev(2) is discarded, so short
+// writes go undetected; and &b1[0] panics if b1 is empty — TODO confirm
+// callers always pass a non-empty first buffer.
+func NonBlockingWrite2(fd int, b1, b2 []byte) *tcpip.Error {
+ // If there is no second buffer, issue a regular write.
+ if len(b2) == 0 {
+ return NonBlockingWrite(fd, b1)
+ }
+
+ // We have two buffers. Build the iovec that represents them and issue
+ // a writev syscall.
+ iovec := [...]syscall.Iovec{
+ {
+ Base: &b1[0],
+ Len: uint64(len(b1)),
+ },
+ {
+ Base: &b2[0],
+ Len: uint64(len(b2)),
+ },
+ }
+
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec)))
+ if e != 0 {
+ return TranslateErrno(e)
+ }
+
+ return nil
+}
+
+// NonBlockingWriteN writes up to N byte slices to a file descriptor in a
+// single syscall, skipping empty slices. It fails if partial data is written.
+//
+// NOTE(review): if every slice is empty, iovec stays empty and &iovec[0]
+// panics — TODO confirm callers always pass at least one non-empty slice.
+// As with the other writers, the returned byte count is discarded.
+func NonBlockingWriteN(fd int, bs ...[]byte) *tcpip.Error {
+ iovec := make([]syscall.Iovec, 0, len(bs))
+
+ for _, b := range bs {
+ if len(b) == 0 {
+ continue
+ }
+ iovec = append(iovec, syscall.Iovec{
+ Base: &b[0],
+ Len: uint64(len(b)),
+ })
+ }
+
+ _, _, e := syscall.RawSyscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&iovec[0])), uintptr(len(iovec)))
+ if e != 0 {
+ return TranslateErrno(e)
+ }
+
+ return nil
+}
+
+// BlockingRead reads from a file descriptor that is set up as non-blocking. If
+// no data is available, it will block in a poll() syscall until the file
+// descriptor becomes readable.
+func BlockingRead(fd int, b []byte) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall(syscall.SYS_READ, uintptr(fd), uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)))
+ if e == 0 {
+ return int(n), nil
+ }
+
+ // Mirrors the C struct pollfd layout expected by poll(2).
+ event := struct {
+ fd int32
+ events int16
+ revents int16
+ }{
+ fd: int32(fd),
+ events: 1, // POLLIN
+ }
+
+ // Wait (no timeout) for readability, then retry the read.
+ // EINTR is benign: simply poll again.
+ _, e = blockingPoll(unsafe.Pointer(&event), 1, -1)
+ if e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
+
+// BlockingReadv reads from a file descriptor that is set up as non-blocking and
+// stores the data in a list of iovecs buffers. If no data is available, it will
+// block in a poll() syscall until the file descriptor becomes readable.
+func BlockingReadv(fd int, iovecs []syscall.Iovec) (int, *tcpip.Error) {
+ for {
+ n, _, e := syscall.RawSyscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
+ if e == 0 {
+ return int(n), nil
+ }
+
+ // Mirrors the C struct pollfd layout expected by poll(2).
+ event := struct {
+ fd int32
+ events int16
+ revents int16
+ }{
+ fd: int32(fd),
+ events: 1, // POLLIN
+ }
+
+ // Wait (no timeout) for readability, then retry the readv.
+ // EINTR is benign: simply poll again.
+ _, e = blockingPoll(unsafe.Pointer(&event), 1, -1)
+ if e != 0 && e != syscall.EINTR {
+ return 0, TranslateErrno(e)
+ }
+ }
+}
diff --git a/pkg/tcpip/link/sharedmem/BUILD b/pkg/tcpip/link/sharedmem/BUILD
new file mode 100644
index 000000000..a4a965924
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/BUILD
@@ -0,0 +1,42 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "sharedmem",
+ srcs = [
+ "rx.go",
+ "sharedmem.go",
+ "sharedmem_unsafe.go",
+ "tx.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem",
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/log",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/sharedmem/queue",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "sharedmem_test",
+ srcs = [
+ "sharedmem_test.go",
+ ],
+ embed = [":sharedmem"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/sharedmem/pipe",
+ "//pkg/tcpip/link/sharedmem/queue",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/sharedmem/pipe/BUILD b/pkg/tcpip/link/sharedmem/pipe/BUILD
new file mode 100644
index 000000000..e8d795500
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/BUILD
@@ -0,0 +1,23 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "pipe",
+ srcs = [
+ "pipe.go",
+ "pipe_unsafe.go",
+ "rx.go",
+ "tx.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/pipe",
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "pipe_test",
+ srcs = [
+ "pipe_test.go",
+ ],
+ embed = [":pipe"],
+)
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe.go b/pkg/tcpip/link/sharedmem/pipe/pipe.go
new file mode 100644
index 000000000..1173a60da
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe.go
@@ -0,0 +1,68 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package pipe implements a shared memory ring buffer on which a single reader
+// and a single writer can operate (read/write) concurrently. The ring buffer
+// allows for data of different sizes to be written, and preserves the boundary
+// of the written data.
+//
+// Example usage is as follows:
+//
+// wb := t.Push(20)
+// // Write data to wb.
+// t.Flush()
+//
+// rb := r.Pull()
+// // Do something with data in rb.
+// r.Flush()
+package pipe
+
+import (
+ "math"
+)
+
+const (
+ jump uint64 = math.MaxUint32 + 1
+ offsetMask uint64 = math.MaxUint32
+ revolutionMask uint64 = ^offsetMask
+
+ sizeOfSlotHeader = 8 // sizeof(uint64)
+ slotFree uint64 = 1 << 63
+ slotSizeMask uint64 = math.MaxUint32
+)
+
+// payloadToSlotSize calculates the total size of a slot based on its payload
+// size. The total size is the header size, plus the payload size, plus padding
+// if necessary to make the total size a multiple of sizeOfSlotHeader.
+func payloadToSlotSize(payloadSize uint64) uint64 {
+ s := sizeOfSlotHeader + payloadSize
+ return (s + sizeOfSlotHeader - 1) &^ (sizeOfSlotHeader - 1)
+}
+
+// slotToPayloadSize calculates the payload size of a slot based on the total
+// size of the slot. This is only meant to be used when creating slots that
+// don't carry information (e.g., free slots or wrap slots).
+func slotToPayloadSize(offset uint64) uint64 {
+ return offset - sizeOfSlotHeader
+}
+
+// pipe is a basic data structure used by both (transmit & receive) ends of a
+// pipe. Indices into this pipe are split into two fields: offset, which counts
+// the number of bytes from the beginning of the buffer, and revolution, which
+// counts the number of times the index has wrapped around.
+type pipe struct {
+ buffer []byte
+}
+
+// init initializes the pipe buffer such that its size is a multiple of the
+// size of the slot header, truncating any excess trailing bytes.
+func (p *pipe) init(b []byte) {
+ p.buffer = b[:len(b)&^(sizeOfSlotHeader-1)]
+}
+
+// data returns the payload section of the slot starting at the given index
+// (whose revolution bits are masked off here) with the given payload size;
+// the slot header itself is skipped.
+func (p *pipe) data(idx uint64, size uint64) []byte {
+ return p.buffer[(idx&offsetMask)+sizeOfSlotHeader:][:size]
+}
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
new file mode 100644
index 000000000..441ff5b25
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go
@@ -0,0 +1,507 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pipe
+
+import (
+ "math/rand"
+ "reflect"
+ "runtime"
+ "sync"
+ "testing"
+)
+
+func TestSimpleReadWrite(t *testing.T) {
+ // Check that a simple write can be properly read from the rx side.
+ tr := rand.New(rand.NewSource(99))
+ rr := rand.New(rand.NewSource(99))
+
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ wb := tx.Push(10)
+ if wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ for i := range wb {
+ wb[i] = byte(tr.Intn(256))
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ rb := rx.Pull()
+ if len(rb) != 10 {
+ t.Fatalf("Bad buffer size returned: got %v, want %v", len(rb), 10)
+ }
+
+ for i := range rb {
+ if v := byte(rr.Intn(256)); v != rb[i] {
+ t.Fatalf("Bad read buffer at index %v: got %v, want %v", i, rb[i], v)
+ }
+ }
+ rx.Flush()
+}
+
+func TestEmptyRead(t *testing.T) {
+ // Check that pulling from an empty pipe fails.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+}
+
+func TestTooLargeWrite(t *testing.T) {
+ // Check that writes that are too large are properly rejected.
+ b := make([]byte, 96)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(96); wb != nil {
+ t.Fatalf("Write of 96 bytes succeeded on 96-byte pipe")
+ }
+
+ if wb := tx.Push(88); wb != nil {
+ t.Fatalf("Write of 88 bytes succeeded on 96-byte pipe")
+ }
+
+ if wb := tx.Push(80); wb == nil {
+ t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
+ }
+}
+
+func TestFullWrite(t *testing.T) {
+ // Check that writes fail when the pipe is full.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(80); wb == nil {
+ t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
+ }
+
+ if wb := tx.Push(1); wb != nil {
+ t.Fatalf("Write succeeded on full pipe")
+ }
+}
+
+func TestFullAndFlushedWrite(t *testing.T) {
+ // Check that writes fail when the pipe is full and has already been
+ // flushed.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(80); wb == nil {
+ t.Fatalf("Write of 80 bytes failed on 96-byte pipe")
+ }
+
+ tx.Flush()
+
+ if wb := tx.Push(1); wb != nil {
+ t.Fatalf("Write succeeded on full pipe")
+ }
+}
+
+func TestTxFlushTwice(t *testing.T) {
+ // Checks that a second consecutive tx flush is a no-op.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ // Make copy of original tx queue, flush it, then check that it didn't
+ // change.
+ orig := tx
+ tx.Flush()
+
+ if !reflect.DeepEqual(orig, tx) {
+ t.Fatalf("Flush mutated tx pipe: got %v, want %v", tx, orig)
+ }
+}
+
+func TestRxFlushTwice(t *testing.T) {
+ // Checks that a second consecutive rx flush is a no-op.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // Make copy of original rx queue, flush it, then check that it didn't
+ // change.
+ orig := rx
+ rx.Flush()
+
+ if !reflect.DeepEqual(orig, rx) {
+ t.Fatalf("Flush mutated rx pipe: got %v, want %v", rx, orig)
+ }
+}
+
+func TestWrapInMiddleOfTransaction(t *testing.T) {
+ // Check that writes are not flushed when we need to wrap the buffer
+ // around.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment).
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ // We haven't flushed yet, so pull must return nil.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on non-flushed pipe")
+ }
+
+ tx.Flush()
+
+ // The two buffers must be available now.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+}
+
+func TestWriteAbort(t *testing.T) {
+ // Check that a read fails on a pipe that has had data pushed to it but
+ // has aborted the push.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Write failed on empty pipe")
+ }
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+
+ tx.Abort()
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+}
+
+func TestWrappedWriteAbort(t *testing.T) {
+ // Check that writes are properly aborted even if the writes wrap
+ // around.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment).
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ // We haven't flushed yet, so pull must return nil.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on non-flushed pipe")
+ }
+
+ tx.Abort()
+
+ // The pushes were aborted, so no data should be readable.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on non-flushed pipe")
+ }
+
+ // Try the same transactions again, but flush this time.
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ tx.Flush()
+
+ // The two buffers must be available now.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+}
+
+func TestEmptyReadOnNonFlushedWrite(t *testing.T) {
+	// Check that a read fails on a pipe that has had data pushed to it
+	// but not yet flushed, and succeeds once the push is flushed.
+	b := make([]byte, 100)
+	var tx Tx
+	tx.Init(b)
+
+	if wb := tx.Push(10); wb == nil {
+		t.Fatalf("Write failed on empty pipe")
+	}
+
+	var rx Rx
+	rx.Init(b)
+	if rb := rx.Pull(); rb != nil {
+		t.Fatalf("Pull succeeded on empty pipe")
+	}
+
+	tx.Flush()
+	// Fixed garbled failure message ("Pull on failed on ...").
+	if rb := rx.Pull(); rb == nil {
+		t.Fatalf("Pull failed on non-empty pipe")
+	}
+}
+
+func TestPullAfterPullingEntirePipe(t *testing.T) {
+ // Check that Pull fails when the pipe is full, but all of it has
+ // already been pulled but not yet flushed.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 3
+ // buffers that will fill the pipe.
+ if wb := tx.Push(10); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ if wb := tx.Push(20); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ if wb := tx.Push(24); wb == nil {
+ t.Fatalf("Push failed on non-full pipe")
+ }
+
+ tx.Flush()
+
+ // The three buffers must be available now.
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+
+ // Fourth pull must fail.
+ if rb := rx.Pull(); rb != nil {
+ t.Fatalf("Pull succeeded on empty pipe")
+ }
+}
+
+func TestNoRoomToWrapOnPush(t *testing.T) {
+ // Check that Push fails when it tries to allocate room to add a wrap
+ // message.
+ b := make([]byte, 100)
+ var tx Tx
+ tx.Init(b)
+
+ if wb := tx.Push(50); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+ tx.Flush()
+
+ var rx Rx
+ rx.Init(b)
+ if rb := rx.Pull(); rb == nil {
+ t.Fatalf("Pull failed on non-empty pipe")
+ }
+ rx.Flush()
+
+ // At this point the ring buffer is empty, but the write is at offset
+ // 64 (50 + sizeOfSlotHeader + padding-for-8-byte-alignment). Write 20,
+ // which won't fit (64+20+8+padding = 96, which wouldn't leave room for
+ // the padding), so it wraps around.
+ if wb := tx.Push(20); wb == nil {
+ t.Fatalf("Push failed on empty pipe")
+ }
+
+ tx.Flush()
+
+ // Buffer offset is at 28. Try to write 70, which would require a wrap
+ // slot which cannot be created now.
+ if wb := tx.Push(70); wb != nil {
+ t.Fatalf("Push succeeded on pipe with no room for wrap message")
+ }
+}
+
+func TestRxImplicitFlushOfWrapMessage(t *testing.T) {
+	// Check that, if the first read is that of a wrapping message, it
+	// gets immediately flushed.
+	b := make([]byte, 100)
+	var tx Tx
+	tx.Init(b)
+
+	if wb := tx.Push(50); wb == nil {
+		t.Fatalf("Push failed on empty pipe")
+	}
+	tx.Flush()
+
+	// This will cause a wrapping message to be written.
+	if wb := tx.Push(60); wb != nil {
+		t.Fatalf("Push succeeded when there is no room in pipe")
+	}
+
+	var rx Rx
+	rx.Init(b)
+
+	// Read the first message.
+	if rb := rx.Pull(); rb == nil {
+		t.Fatalf("Pull failed on non-empty pipe")
+	}
+	rx.Flush()
+
+	// This should fail because the wrapping message is taking up space.
+	if wb := tx.Push(60); wb != nil {
+		t.Fatalf("Push succeeded when there is no room in pipe")
+	}
+
+	// Try to read the next one. This should consume the wrapping message.
+	// (The return value is ignored: a wrap slot carries no payload, so
+	// nil is expected here.)
+	rx.Pull()
+
+	// This must now succeed.
+	if wb := tx.Push(60); wb == nil {
+		t.Fatalf("Push failed on empty pipe")
+	}
+}
+
+func TestConcurrentReaderWriter(t *testing.T) {
+	// Push a million buffers of random sizes and random contents. Check
+	// that buffers read match what was written. Both sides seed their
+	// generator with 99, so the reader can regenerate the exact byte
+	// stream the writer produced.
+	tr := rand.New(rand.NewSource(99))
+	rr := rand.New(rand.NewSource(99))
+
+	b := make([]byte, 100)
+	var tx Tx
+	tx.Init(b)
+
+	var rx Rx
+	rx.Init(b)
+
+	const count = 1000000
+	var wg sync.WaitGroup
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		runtime.Gosched()
+		for i := 0; i < count; i++ {
+			n := 1 + tr.Intn(80)
+			wb := tx.Push(uint64(n))
+			for wb == nil {
+				// Busy-wait for the reader to free up space.
+				wb = tx.Push(uint64(n))
+			}
+
+			for j := range wb {
+				wb[j] = byte(tr.Intn(256))
+			}
+
+			tx.Flush()
+		}
+	}()
+
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		runtime.Gosched()
+		// t.Fatalf must not be called from a goroutine other than the
+		// test's own: FailNow only exits the calling goroutine, which
+		// would leave the writer spinning and deadlock wg.Wait. Report
+		// the first mismatch with t.Errorf instead, and keep draining
+		// the pipe so the writer can finish.
+		failed := false
+		for i := 0; i < count; i++ {
+			n := 1 + rr.Intn(80)
+			rb := rx.Pull()
+			for rb == nil {
+				// Busy-wait for the writer to publish data.
+				rb = rx.Pull()
+			}
+
+			if !failed && n != len(rb) {
+				t.Errorf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n)
+				failed = true
+			}
+
+			for j := range rb {
+				if v := byte(rr.Intn(256)); !failed && v != rb[j] {
+					t.Errorf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v)
+					failed = true
+				}
+			}
+
+			rx.Flush()
+		}
+	}()
+
+	wg.Wait()
+}
diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go
new file mode 100644
index 000000000..d536abedf
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/pipe_unsafe.go
@@ -0,0 +1,25 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pipe
+
+import (
+ "sync/atomic"
+ "unsafe"
+)
+
+// write stores a slot header at idx using a plain (non-atomic) store. It is
+// used for headers the peer cannot yet observe — presumably safe because the
+// batch is only published later by an atomic store at the tail (see Tx.Flush);
+// confirm against the memory-ordering assumptions of the pipe.
+func (p *pipe) write(idx uint64, v uint64) {
+	ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0]))
+	*ptr = v
+}
+
+// writeAtomic stores a slot header at idx with an atomic 64-bit store, making
+// the header safe to publish to a concurrent peer.
+func (p *pipe) writeAtomic(idx uint64, v uint64) {
+	ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0]))
+	atomic.StoreUint64(ptr, v)
+}
+
+// readAtomic loads the slot header at idx with an atomic 64-bit load, pairing
+// with writeAtomic on the other end of the pipe.
+func (p *pipe) readAtomic(idx uint64) uint64 {
+	ptr := (*uint64)(unsafe.Pointer(&p.buffer[idx&offsetMask:][:8][0]))
+	return atomic.LoadUint64(ptr)
+}
diff --git a/pkg/tcpip/link/sharedmem/pipe/rx.go b/pkg/tcpip/link/sharedmem/pipe/rx.go
new file mode 100644
index 000000000..261e21f9e
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/rx.go
@@ -0,0 +1,83 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pipe
+
+// Rx is the receive side of the shared memory ring buffer.
+type Rx struct {
+	p pipe
+
+	// tail is the index of the first slot pulled since the last Flush();
+	// Flush collapses [tail, head) into one free slot for the sender.
+	tail uint64
+	// head is the index of the next slot to be inspected by Pull().
+	head uint64
+}
+
+// Init initializes the receive end of the pipe. In the initial state, the next
+// slot to be inspected is the very first one.
+func (r *Rx) Init(b []byte) {
+	r.p.init(b)
+	// Start with a large revolution count, mirroring Tx.Init, so the
+	// indices of both ends agree — NOTE(review): presumably also chosen
+	// so that revolution-counter wraparound is exercised early; confirm.
+	r.tail = 0xfffffffe * jump
+	r.head = r.tail
+}
+
+// Pull reads the next buffer from the pipe, returning nil if there isn't one
+// currently available.
+//
+// The returned slice is available until Flush() is next called. After that, it
+// must not be touched.
+func (r *Rx) Pull() []byte {
+	if r.head == r.tail+jump {
+		// We've already pulled the whole pipe.
+		return nil
+	}
+
+	header := r.p.readAtomic(r.head)
+	if header&slotFree != 0 {
+		// The next slot is free, we can't pull it yet.
+		return nil
+	}
+
+	payloadSize := header & slotSizeMask
+	newHead := r.head + payloadToSlotSize(payloadSize)
+	// headWrap is the index (same revolution as head) of the end of the
+	// underlying buffer.
+	headWrap := (r.head & revolutionMask) | uint64(len(r.p.buffer))
+
+	// Check if this is a wrapping slot. If that's the case, it carries no
+	// data, so we just skip it and try again from the first slot.
+	if int64(newHead-headWrap) >= 0 {
+		// A valid wrap slot ends exactly at offset 0 of the next
+		// revolution; anything else is a malformed header.
+		if int64(newHead-headWrap) > int64(jump) || newHead&offsetMask != 0 {
+			return nil
+		}
+
+		if r.tail == r.head {
+			// If this is the first pull since the last Flush()
+			// call, we flush the state so that the sender can use
+			// this space if it needs to.
+			r.p.writeAtomic(r.head, slotFree|slotToPayloadSize(newHead-r.head))
+			r.tail = newHead
+		}
+
+		r.head = newHead
+		// Retry from the start of the buffer (tail recursion; at most
+		// one wrap slot can precede a data slot).
+		return r.Pull()
+	}
+
+	// Grab the buffer before updating r.head.
+	b := r.p.data(r.head, payloadSize)
+	r.head = newHead
+	return b
+}
+
+// Flush tells the transmitter that all buffers pulled since the last Flush()
+// have been used, so the transmitter is free to use their slots for further
+// transmission.
+func (r *Rx) Flush() {
+	if r.head == r.tail {
+		// Nothing has been pulled since the last flush.
+		return
+	}
+	// Collapse everything in [tail, head) into a single free slot so the
+	// sender can reclaim the whole region in one step.
+	r.p.writeAtomic(r.tail, slotFree|slotToPayloadSize(r.head-r.tail))
+	r.tail = r.head
+}
+
+// Bytes returns the byte slice on which the pipe operates: the slice passed
+// to Init, truncated down to a multiple of the slot header size.
+func (r *Rx) Bytes() []byte {
+	return r.p.buffer
+}
diff --git a/pkg/tcpip/link/sharedmem/pipe/tx.go b/pkg/tcpip/link/sharedmem/pipe/tx.go
new file mode 100644
index 000000000..374f515ab
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/pipe/tx.go
@@ -0,0 +1,151 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package pipe
+
+// Tx is the transmit side of the shared memory ring buffer.
+type Tx struct {
+	p pipe
+	// maxPayloadSize is the largest payload a single Push can accept:
+	// the buffer size minus two slot headers (see Init).
+	maxPayloadSize uint64
+
+	// head is the index up to which slots are known to be reclaimable;
+	// reclaim() advances it over slots the receiver has marked free.
+	head uint64
+	// tail is the index of the first slot pushed since the last Flush();
+	// its header is withheld (in tailHeader) until Flush publishes it.
+	tail uint64
+	// next is the index at which the next Push will write.
+	next uint64
+
+	// tailHeader is the header value to be atomically stored at tail on
+	// Flush, making the whole pushed batch visible at once.
+	tailHeader uint64
+}
+
+// Init initializes the transmit end of the pipe. In the initial state, the next
+// slot to be written is the very first one, and the transmitter has the whole
+// ring buffer available to it.
+func (t *Tx) Init(b []byte) {
+	t.p.init(b)
+	// maxPayloadSize excludes the header of the payload, and the header
+	// of the wrapping message.
+	t.maxPayloadSize = uint64(len(t.p.buffer)) - 2*sizeOfSlotHeader
+	// Start with a large revolution count, mirroring Rx.Init so both
+	// ends agree on the initial index.
+	t.tail = 0xfffffffe * jump
+	t.next = t.tail
+	// head = tail + jump marks the entire buffer as available to write.
+	t.head = t.tail + jump
+	// Mark the first slot free so the receiver does not consume garbage.
+	t.p.write(t.tail, slotFree)
+}
+
+// Capacity determines how many records of the given size can be written to the
+// pipe before it fills up.
+func (t *Tx) Capacity(recordSize uint64) uint64 {
+	// One slot header is subtracted from the buffer size — NOTE(review):
+	// presumably reserving room for the wrap/tail slot header; confirm.
+	available := uint64(len(t.p.buffer)) - sizeOfSlotHeader
+	entryLen := payloadToSlotSize(recordSize)
+	return available / entryLen
+}
+
+// Push reserves "payloadSize" bytes for transmission in the pipe. The caller
+// populates the returned slice with the data to be transferred and eventually
+// calls Flush() to make the data visible to the reader, or Abort() to make the
+// pipe forget all Push() calls since the last Flush().
+//
+// The returned slice is available until Flush() or Abort() is next called.
+// After that, it must not be touched.
+func (t *Tx) Push(payloadSize uint64) []byte {
+	// Fail request if we know we will never have enough room.
+	if payloadSize > t.maxPayloadSize {
+		return nil
+	}
+
+	totalLen := payloadToSlotSize(payloadSize)
+	newNext := t.next + totalLen
+	// nextWrap is the index (same revolution as next) of the end of the
+	// underlying buffer.
+	nextWrap := (t.next & revolutionMask) | uint64(len(t.p.buffer))
+	if int64(newNext-nextWrap) >= 0 {
+		// The new buffer would overflow the pipe, so we push a wrapping
+		// slot, then try to add the actual slot to the front of the
+		// pipe.
+		newNext = (newNext & revolutionMask) + jump
+		wrappingPayloadSize := slotToPayloadSize(newNext - t.next)
+		if !t.reclaim(newNext) {
+			return nil
+		}
+
+		oldNext := t.next
+		t.next = newNext
+		if oldNext != t.tail {
+			t.p.write(oldNext, wrappingPayloadSize)
+		} else {
+			// The wrap slot is the tail slot: withhold its header
+			// in tailHeader so Flush publishes it atomically.
+			t.tailHeader = wrappingPayloadSize
+			t.Flush()
+		}
+
+		newNext += totalLen
+	}
+
+	// Check that we have enough room for the buffer.
+	if !t.reclaim(newNext) {
+		return nil
+	}
+
+	if t.next != t.tail {
+		t.p.write(t.next, payloadSize)
+	} else {
+		// Tail slot: header is withheld until Flush.
+		t.tailHeader = payloadSize
+	}
+
+	// Grab the buffer before updating t.next.
+	b := t.p.data(t.next, payloadSize)
+	t.next = newNext
+
+	return b
+}
+
+// reclaim attempts to advance the head until at least newNext. If the head is
+// already at or beyond newNext, nothing happens and true is returned; otherwise
+// it tries to reclaim slots that have already been consumed by the receive end
+// of the pipe (they will be marked as free) and returns a boolean indicating
+// whether it was successful in reclaiming enough slots.
+func (t *Tx) reclaim(newNext uint64) bool {
+	for int64(newNext-t.head) > 0 {
+		// Can't reclaim if slot is not free.
+		header := t.p.readAtomic(t.head)
+		if header&slotFree == 0 {
+			return false
+		}
+
+		payloadSize := header & slotSizeMask
+		newHead := t.head + payloadToSlotSize(payloadSize)
+
+		// Check newHead is within bounds and valid, so a corrupt
+		// header cannot make head run away past the buffer or more
+		// than one revolution ahead of tail.
+		if int64(newHead-t.tail) > int64(jump) || newHead&offsetMask >= uint64(len(t.p.buffer)) {
+			return false
+		}
+
+		t.head = newHead
+	}
+
+	return true
+}
+
+// Abort causes all Push() calls since the last Flush() to be forgotten and
+// therefore they will not be made visible to the receiver.
+func (t *Tx) Abort() {
+	// Rewinding next to tail discards the unflushed slots; the tail
+	// header was never published, so the receiver never observed them.
+	t.next = t.tail
+}
+
+// Flush causes all buffers pushed since the last Flush() [or Abort(), whichever
+// is the most recent] to be made visible to the receiver.
+func (t *Tx) Flush() {
+	if t.next == t.tail {
+		// Nothing to do if there are no pushed buffers.
+		return
+	}
+
+	if t.next != t.head {
+		// The receiver will spin on t.next, so we must make sure that
+		// the slotFree bit is set.
+		t.p.write(t.next, slotFree)
+	}
+
+	// Publishing the withheld header at tail is what makes the batch
+	// visible: the receiver cannot advance past tail until this atomic
+	// store clears the slotFree bit there.
+	t.p.writeAtomic(t.tail, t.tailHeader)
+	t.tail = t.next
+}
+
+// Bytes returns the byte slice on which the pipe operates: the slice passed
+// to Init, truncated down to a multiple of the slot header size.
+func (t *Tx) Bytes() []byte {
+	return t.p.buffer
+}
diff --git a/pkg/tcpip/link/sharedmem/queue/BUILD b/pkg/tcpip/link/sharedmem/queue/BUILD
new file mode 100644
index 000000000..56ea4641d
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queue/BUILD
@@ -0,0 +1,28 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "queue",
+ srcs = [
+ "rx.go",
+ "tx.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/queue",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//pkg/tcpip/link/sharedmem/pipe",
+ ],
+)
+
+go_test(
+ name = "queue_test",
+ srcs = [
+ "queue_test.go",
+ ],
+ embed = [":queue"],
+ deps = [
+ "//pkg/tcpip/link/sharedmem/pipe",
+ ],
+)
diff --git a/pkg/tcpip/link/sharedmem/queue/queue_test.go b/pkg/tcpip/link/sharedmem/queue/queue_test.go
new file mode 100644
index 000000000..b022c389c
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queue/queue_test.go
@@ -0,0 +1,507 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package queue
+
+import (
+ "encoding/binary"
+ "reflect"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/pipe"
+)
+
+func TestBasicTxQueue(t *testing.T) {
+ // Tests that a basic transmit on a queue works, and that completion
+ // gets properly reported as well.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Enqueue two buffers.
+ b := []TxBuffer{
+ {nil, 100, 60},
+ {nil, 200, 40},
+ }
+
+ b[0].Next = &b[1]
+
+ const usedID = 1002
+ const usedTotalSize = 100
+ if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Enqueue failed on empty queue")
+ }
+
+ // Check the contents of the pipe.
+ d := rxp.Pull()
+ if d == nil {
+ t.Fatalf("Tx pipe is empty after Enqueue")
+ }
+
+ want := []byte{
+ 234, 3, 0, 0, 0, 0, 0, 0, // id
+ 100, 0, 0, 0, // total size
+ 0, 0, 0, 0, // reserved
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 60, 0, 0, 0, // size 1
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 40, 0, 0, 0, // size 2
+ }
+
+ if !reflect.DeepEqual(want, d) {
+ t.Fatalf("Bad posted packet: got %v, want %v", d, want)
+ }
+
+ rxp.Flush()
+
+ // Check that there are no completions yet.
+ if _, ok := q.CompletedPacket(); ok {
+ t.Fatalf("Packet reported as completed too soon")
+ }
+
+ // Post a completion.
+ d = txp.Push(8)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ binary.LittleEndian.PutUint64(d, usedID)
+ txp.Flush()
+
+ // Check that completion is properly reported.
+ id, ok := q.CompletedPacket()
+ if !ok {
+ t.Fatalf("Completion not reported")
+ }
+
+ if id != usedID {
+ t.Fatalf("Bad completion id: got %v, want %v", id, usedID)
+ }
+}
+
+func TestBasicRxQueue(t *testing.T) {
+ // Tests that a basic receive on a queue works.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Post two buffers.
+ b := []RxBuffer{
+ {100, 60, 1077, 0},
+ {200, 40, 2123, 0},
+ }
+
+ if !q.PostBuffers(b) {
+ t.Fatalf("PostBuffers failed on empty queue")
+ }
+
+ // Check the contents of the pipe.
+ want := [][]byte{
+ {
+ 100, 0, 0, 0, 0, 0, 0, 0, // Offset1
+ 60, 0, 0, 0, // Size1
+ 0, 0, 0, 0, // Remaining in group 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // User data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+ },
+ {
+ 200, 0, 0, 0, 0, 0, 0, 0, // Offset2
+ 40, 0, 0, 0, // Size2
+ 0, 0, 0, 0, // Remaining in group 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // User data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ },
+ }
+
+ for i := range b {
+ d := rxp.Pull()
+ if d == nil {
+ t.Fatalf("Tx pipe is empty after PostBuffers")
+ }
+
+ if !reflect.DeepEqual(want[i], d) {
+ t.Fatalf("Bad posted packet: got %v, want %v", d, want[i])
+ }
+
+ rxp.Flush()
+ }
+
+ // Check that there are no completions.
+ if _, n := q.Dequeue(nil); n != 0 {
+ t.Fatalf("Packet reported as received too soon")
+ }
+
+ // Post a completion.
+ d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 60, 0, 0, 0, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 40, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+
+ // Check that completion is properly reported.
+ bufs, n := q.Dequeue(nil)
+ if n != 100 {
+ t.Fatalf("Bad packet size: got %v, want %v", n, 100)
+ }
+
+ if !reflect.DeepEqual(bufs, b) {
+ t.Fatalf("Bad returned buffers: got %v, want %v", bufs, b)
+ }
+}
+
+func TestBadTxCompletion(t *testing.T) {
+ // Check that tx completions with bad sizes are properly ignored.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Post a completion that is too short, and check that it is ignored.
+ if d := txp.Push(7); d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ txp.Flush()
+
+ if _, ok := q.CompletedPacket(); ok {
+ t.Fatalf("Bad completion not ignored")
+ }
+
+ // Post a completion that is too long, and check that it is ignored.
+ if d := txp.Push(10); d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ txp.Flush()
+
+ if _, ok := q.CompletedPacket(); ok {
+ t.Fatalf("Bad completion not ignored")
+ }
+}
+
+func TestBadRxCompletion(t *testing.T) {
+ // Check that bad rx completions are properly ignored.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Post a completion that is too short, and check that it is ignored.
+ if d := txp.Push(7); d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+ txp.Flush()
+
+ if b, _ := q.Dequeue(nil); b != nil {
+ t.Fatalf("Bad completion not ignored")
+ }
+
+ // Post a completion whose buffer sizes add up to less than the total
+ // size.
+ d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 10, 0, 0, 0, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 10, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+ if b, _ := q.Dequeue(nil); b != nil {
+ t.Fatalf("Bad completion not ignored")
+ }
+
+ // Post a completion whose buffer sizes will cause a 32-bit overflow,
+ // but adds up to the right number.
+ d = txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 255, 255, 255, 255, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 101, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+ if b, _ := q.Dequeue(nil); b != nil {
+ t.Fatalf("Bad completion not ignored")
+ }
+}
+
+func TestFillTxPipe(t *testing.T) {
+ // Check that transmitting a new buffer when the buffer pipe is full
+ // fails gracefully.
+ pb1 := make([]byte, 104)
+ pb2 := make([]byte, 104)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Transmit twice, which should fill the tx pipe.
+ b := []TxBuffer{
+ {nil, 100, 60},
+ {nil, 200, 40},
+ }
+
+ b[0].Next = &b[1]
+
+ const usedID = 1002
+ const usedTotalSize = 100
+ for i := uint64(0); i < 2; i++ {
+ if !q.Enqueue(usedID+i, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Failed to transmit buffer")
+ }
+ }
+
+ // Transmit another packet now that the tx pipe is full.
+ if q.Enqueue(usedID+2, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Enqueue succeeded when tx pipe is full")
+ }
+}
+
+func TestFillRxPipe(t *testing.T) {
+ // Check that posting a new buffer when the buffer pipe is full fails
+ // gracefully.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Post a buffer twice, it should fill the tx pipe.
+ b := []RxBuffer{
+ {100, 60, 1077, 0},
+ }
+
+ for i := 0; i < 2; i++ {
+ if !q.PostBuffers(b) {
+ t.Fatalf("PostBuffers failed on non-full queue")
+ }
+ }
+
+ // Post another buffer now that the tx pipe is full.
+ if q.PostBuffers(b) {
+ t.Fatalf("PostBuffers succeeded on full queue")
+ }
+}
+
+func TestLotsOfTransmissions(t *testing.T) {
+ // Make sure pipes are being properly flushed when transmitting packets.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Tx
+ q.Init(pb1, pb2)
+
+ // Prepare packet with two buffers.
+ b := []TxBuffer{
+ {nil, 100, 60},
+ {nil, 200, 40},
+ }
+
+ b[0].Next = &b[1]
+
+ const usedID = 1002
+ const usedTotalSize = 100
+
+ // Post 100000 packets and completions.
+ for i := 100000; i > 0; i-- {
+ if !q.Enqueue(usedID, usedTotalSize, 2, &b[0]) {
+ t.Fatalf("Enqueue failed on non-full queue")
+ }
+
+ if d := rxp.Pull(); d == nil {
+ t.Fatalf("Tx pipe is empty after Enqueue")
+ }
+ rxp.Flush()
+
+ d := txp.Push(8)
+ if d == nil {
+ t.Fatalf("Unable to write to rx pipe")
+ }
+ binary.LittleEndian.PutUint64(d, usedID)
+ txp.Flush()
+ if _, ok := q.CompletedPacket(); !ok {
+ t.Fatalf("Completion not returned")
+ }
+ }
+}
+
+func TestLotsOfReceptions(t *testing.T) {
+ // Make sure pipes are being properly flushed when receiving packets.
+ pb1 := make([]byte, 100)
+ pb2 := make([]byte, 100)
+
+ var rxp pipe.Rx
+ rxp.Init(pb1)
+
+ var txp pipe.Tx
+ txp.Init(pb2)
+
+ var q Rx
+ q.Init(pb1, pb2, nil)
+
+ // Prepare for posting two buffers.
+ b := []RxBuffer{
+ {100, 60, 1077, 0},
+ {200, 40, 2123, 0},
+ }
+
+ // Post 100000 buffers and completions.
+ for i := 100000; i > 0; i-- {
+ if !q.PostBuffers(b) {
+ t.Fatalf("PostBuffers failed on non-full queue")
+ }
+
+ if d := rxp.Pull(); d == nil {
+ t.Fatalf("Tx pipe is empty after PostBuffers")
+ }
+ rxp.Flush()
+
+ if d := rxp.Pull(); d == nil {
+ t.Fatalf("Tx pipe is empty after PostBuffers")
+ }
+ rxp.Flush()
+
+ d := txp.Push(sizeOfConsumedPacketHeader + 2*sizeOfConsumedBuffer)
+ if d == nil {
+ t.Fatalf("Unable to push to rx pipe")
+ }
+
+ copy(d, []byte{
+ 100, 0, 0, 0, // packet size
+ 0, 0, 0, 0, // reserved
+
+ 100, 0, 0, 0, 0, 0, 0, 0, // offset 1
+ 60, 0, 0, 0, // size 1
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 1
+ 53, 4, 0, 0, 0, 0, 0, 0, // ID 1
+
+ 200, 0, 0, 0, 0, 0, 0, 0, // offset 2
+ 40, 0, 0, 0, // size 2
+ 0, 0, 0, 0, 0, 0, 0, 0, // user data 2
+ 75, 8, 0, 0, 0, 0, 0, 0, // ID 2
+ })
+
+ txp.Flush()
+
+ if _, n := q.Dequeue(nil); n == 0 {
+ t.Fatalf("Dequeue failed when there is a completion")
+ }
+ }
+}
+
+func TestRxEnableNotification(t *testing.T) {
+	// Check that enabling notifications results in properly updated state.
+	pb1 := make([]byte, 100)
+	pb2 := make([]byte, 100)
+
+	var state uint32
+	var q Rx
+	q.Init(pb1, pb2, &state)
+
+	q.EnableNotification()
+	if state != eventFDEnabled {
+		t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDEnabled)
+	}
+}
+
+func TestRxDisableNotification(t *testing.T) {
+	// Check that disabling notifications results in properly updated state.
+	pb1 := make([]byte, 100)
+	pb2 := make([]byte, 100)
+
+	var state uint32
+	var q Rx
+	q.Init(pb1, pb2, &state)
+
+	q.DisableNotification()
+	if state != eventFDDisabled {
+		t.Fatalf("Bad value in shared state: got %v, want %v", state, eventFDDisabled)
+	}
+}
diff --git a/pkg/tcpip/link/sharedmem/queue/rx.go b/pkg/tcpip/link/sharedmem/queue/rx.go
new file mode 100644
index 000000000..91bb57190
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queue/rx.go
@@ -0,0 +1,211 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package queue provides the implementation of transmit and receive queues
+// based on shared memory ring buffers.
+package queue
+
+import (
+ "encoding/binary"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/pipe"
+)
+
+const (
+	// All offsets and sizes below are in bytes. Descriptors are encoded
+	// in little-endian byte order (see PostBuffers and Dequeue).
+
+	// Offsets within a posted buffer.
+	postedOffset = 0
+	postedSize = 8
+	postedRemainingInGroup = 12
+	postedUserData = 16
+	postedID = 24
+
+	sizeOfPostedBuffer = 32
+
+	// Offsets within a received packet header.
+	consumedPacketSize = 0
+	consumedPacketReserved = 4
+
+	sizeOfConsumedPacketHeader = 8
+
+	// Offsets within a consumed buffer.
+	consumedOffset = 0
+	consumedSize = 8
+	consumedUserData = 12
+	consumedID = 20
+
+	sizeOfConsumedBuffer = 28
+
+	// The following are the allowed states of the shared data area.
+	eventFDUninitialized = 0
+	eventFDDisabled = 1
+	eventFDEnabled = 2
+)
+
+// RxBuffer is the descriptor of a receive buffer.
+type RxBuffer struct {
+	Offset uint64 // Offset is the location of the buffer within the data area.
+	Size uint32 // Size is the length of the buffer in bytes.
+	ID uint64 // ID is encoded when the buffer is posted and decoded from completions.
+	UserData uint64 // UserData is opaque to the queue; it is encoded verbatim when posting.
+}
+
+// Rx is a receive queue. It is implemented with one tx and one rx pipe: the tx
+// pipe is used to "post" buffers, while the rx pipe is used to receive packets
+// whose contents have been written to previously posted buffers.
+//
+// This struct is thread-compatible.
+type Rx struct {
+	tx pipe.Tx // tx carries posted-buffer descriptors to the peer.
+	rx pipe.Rx // rx carries completed-packet descriptors from the peer.
+	sharedEventFDState *uint32 // sharedEventFDState is read/written atomically; shared with the peer.
+}
+
+// Init initializes the receive queue with the given pipes, and shared state
+// pointer -- the latter is used to enable/disable eventfd notifications.
+func (r *Rx) Init(tx, rx []byte, sharedEventFDState *uint32) {
+	r.sharedEventFDState = sharedEventFDState
+	r.tx.Init(tx)
+	r.rx.Init(rx)
+}
+
+// EnableNotification updates the shared state such that the peer will notify
+// the eventfd when there are packets to be dequeued.
+func (r *Rx) EnableNotification() {
+	// Atomic store: the state word lives in memory shared with the peer.
+	atomic.StoreUint32(r.sharedEventFDState, eventFDEnabled)
+}
+
+// DisableNotification updates the shared state such that the peer will not
+// notify the eventfd.
+func (r *Rx) DisableNotification() {
+	// Atomic store: the state word lives in memory shared with the peer.
+	atomic.StoreUint32(r.sharedEventFDState, eventFDDisabled)
+}
+
+// PostedBuffersLimit returns the maximum number of buffers that can be posted
+// before the tx queue fills up.
+func (r *Rx) PostedBuffersLimit() uint64 {
+	// Derived from how many fixed-size posted-buffer descriptors fit in
+	// the tx pipe.
+	return r.tx.Capacity(sizeOfPostedBuffer)
+}
+
+// PostBuffers makes the given buffers available for receiving data from the
+// peer. Once they are posted, the peer is free to write to them and will
+// eventually post them back for consumption.
+//
+// Posting is all-or-nothing: if the tx pipe cannot hold every descriptor,
+// any partially pushed descriptors are aborted and false is returned.
+func (r *Rx) PostBuffers(buffers []RxBuffer) bool {
+	for i := range buffers {
+		b := r.tx.Push(sizeOfPostedBuffer)
+		if b == nil {
+			// Undo the descriptors pushed so far in this call.
+			r.tx.Abort()
+			return false
+		}
+
+		pb := &buffers[i]
+		binary.LittleEndian.PutUint64(b[postedOffset:], pb.Offset)
+		binary.LittleEndian.PutUint32(b[postedSize:], pb.Size)
+		binary.LittleEndian.PutUint32(b[postedRemainingInGroup:], 0)
+		binary.LittleEndian.PutUint64(b[postedUserData:], pb.UserData)
+		binary.LittleEndian.PutUint64(b[postedID:], pb.ID)
+	}
+
+	// A single flush publishes all descriptors to the peer at once.
+	r.tx.Flush()
+
+	return true
+}
+
+// Dequeue receives buffers that have been previously posted by PostBuffers()
+// and that have been filled by the peer and posted back.
+//
+// This is similar to append() in that new buffers are appended to "bufs", with
+// reallocation only if "bufs" doesn't have enough capacity.
+//
+// It returns (bufs, 0) unchanged when no packet is available. Malformed or
+// corrupted descriptors are logged, consumed, and skipped.
+func (r *Rx) Dequeue(bufs []RxBuffer) ([]RxBuffer, uint32) {
+	for {
+		outBufs := bufs
+
+		// Pull the next descriptor from the rx pipe.
+		b := r.rx.Pull()
+		if b == nil {
+			return bufs, 0
+		}
+
+		if len(b) < sizeOfConsumedPacketHeader {
+			log.Warningf("Ignoring packet header: size (%v) is less than header size (%v)", len(b), sizeOfConsumedPacketHeader)
+			r.rx.Flush()
+			continue
+		}
+
+		totalDataSize := binary.LittleEndian.Uint32(b[consumedPacketSize:])
+
+		// Calculate the number of buffer descriptors and copy them
+		// over to the output.
+		count := (len(b) - sizeOfConsumedPacketHeader) / sizeOfConsumedBuffer
+		offset := sizeOfConsumedPacketHeader
+		buffersSize := uint32(0)
+		for i := count; i > 0; i-- {
+			s := binary.LittleEndian.Uint32(b[offset+consumedSize:])
+			buffersSize += s
+			if buffersSize < s {
+				// The buffer size overflows an unsigned 32-bit
+				// integer, so break out and force it to be
+				// ignored. Setting totalDataSize=1 with
+				// buffersSize=0 guarantees the corruption
+				// check below fires and drops the packet.
+				totalDataSize = 1
+				buffersSize = 0
+				break
+			}
+
+			// NOTE(review): the consumed descriptor's user-data
+			// field (consumedUserData) is not decoded here.
+			outBufs = append(outBufs, RxBuffer{
+				Offset: binary.LittleEndian.Uint64(b[offset+consumedOffset:]),
+				Size: s,
+				ID: binary.LittleEndian.Uint64(b[offset+consumedID:]),
+			})
+
+			offset += sizeOfConsumedBuffer
+		}
+
+		// Consume the descriptor regardless of validity.
+		r.rx.Flush()
+
+		if buffersSize < totalDataSize {
+			// The descriptor is corrupted, ignore it.
+			log.Warningf("Ignoring packet: actual data size (%v) less than expected size (%v)", buffersSize, totalDataSize)
+			continue
+		}
+
+		return outBufs, totalDataSize
+	}
+}
+
+// Bytes returns the byte slices on which the queue operates.
+func (r *Rx) Bytes() (tx, rx []byte) {
+	return r.tx.Bytes(), r.rx.Bytes()
+}
+
+// DecodeRxBufferHeader decodes the header of a buffer posted on an rx queue.
+// It is the inverse of the encoding performed by PostBuffers.
+func DecodeRxBufferHeader(b []byte) RxBuffer {
+	return RxBuffer{
+		Offset: binary.LittleEndian.Uint64(b[postedOffset:]),
+		Size: binary.LittleEndian.Uint32(b[postedSize:]),
+		ID: binary.LittleEndian.Uint64(b[postedID:]),
+		UserData: binary.LittleEndian.Uint64(b[postedUserData:]),
+	}
+}
+
+// RxCompletionSize returns the number of bytes needed to encode an rx
+// completion containing "count" buffers.
+func RxCompletionSize(count int) uint64 {
+	return sizeOfConsumedPacketHeader + uint64(count)*sizeOfConsumedBuffer
+}
+
+// EncodeRxCompletion encodes an rx completion header.
+func EncodeRxCompletion(b []byte, size, reserved uint32) {
+	binary.LittleEndian.PutUint32(b[consumedPacketSize:], size)
+	binary.LittleEndian.PutUint32(b[consumedPacketReserved:], reserved)
+}
+
+// EncodeRxCompletionBuffer encodes the i-th rx completion buffer header.
+func EncodeRxCompletionBuffer(b []byte, i int, rxb RxBuffer) {
+	// Skip the completion header and the first i buffer descriptors.
+	b = b[RxCompletionSize(i):]
+	binary.LittleEndian.PutUint64(b[consumedOffset:], rxb.Offset)
+	binary.LittleEndian.PutUint32(b[consumedSize:], rxb.Size)
+	binary.LittleEndian.PutUint64(b[consumedUserData:], rxb.UserData)
+	binary.LittleEndian.PutUint64(b[consumedID:], rxb.ID)
+}
diff --git a/pkg/tcpip/link/sharedmem/queue/tx.go b/pkg/tcpip/link/sharedmem/queue/tx.go
new file mode 100644
index 000000000..b04fb163b
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/queue/tx.go
@@ -0,0 +1,141 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package queue
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/pipe"
+)
+
+const (
+	// All offsets and sizes below are in bytes; headers are encoded in
+	// little-endian byte order (see Enqueue).
+
+	// Offsets within a packet header.
+	packetID = 0
+	packetSize = 8
+	packetReserved = 12
+
+	sizeOfPacketHeader = 16
+
+	// Offsets within a buffer descriptor.
+	bufferOffset = 0
+	bufferSize = 8
+
+	sizeOfBufferDescriptor = 12
+)
+
+// TxBuffer is the descriptor of a transmit buffer.
+type TxBuffer struct {
+	Next *TxBuffer // Next links the buffers that make up a single packet.
+	Offset uint64 // Offset is the location of the buffer within the data area.
+	Size uint32 // Size is the length of the buffer in bytes.
+}
+
+// Tx is a transmit queue. It is implemented with one tx and one rx pipe: the
+// tx pipe is used to request the transmission of packets, while the rx pipe
+// is used to receive which transmissions have completed.
+//
+// This struct is thread-compatible.
+type Tx struct {
+	tx pipe.Tx
+	rx pipe.Rx
+}
+
+// Init initializes the transmit queue with the given pipes.
+func (t *Tx) Init(tx, rx []byte) {
+	t.tx.Init(tx)
+	t.rx.Init(rx)
+}
+
+// Enqueue queues the given linked list of buffers for transmission as one
+// packet. While it is queued, the caller must not modify them.
+//
+// The caller must pass a bufferCount that matches the length of the linked
+// list starting at "buffer"; the loop below walks exactly that many nodes.
+// Returns false if the tx pipe doesn't have room for the descriptors.
+func (t *Tx) Enqueue(id uint64, totalDataLen, bufferCount uint32, buffer *TxBuffer) bool {
+	// Reserve room in the tx pipe.
+	totalLen := sizeOfPacketHeader + uint64(bufferCount)*sizeOfBufferDescriptor
+
+	b := t.tx.Push(totalLen)
+	if b == nil {
+		return false
+	}
+
+	// Initialize the packet and buffer descriptors.
+	binary.LittleEndian.PutUint64(b[packetID:], id)
+	binary.LittleEndian.PutUint32(b[packetSize:], totalDataLen)
+	binary.LittleEndian.PutUint32(b[packetReserved:], 0)
+
+	offset := sizeOfPacketHeader
+	for i := bufferCount; i != 0; i-- {
+		binary.LittleEndian.PutUint64(b[offset+bufferOffset:], buffer.Offset)
+		binary.LittleEndian.PutUint32(b[offset+bufferSize:], buffer.Size)
+		offset += sizeOfBufferDescriptor
+		buffer = buffer.Next
+	}
+
+	t.tx.Flush()
+
+	return true
+}
+
+// CompletedPacket returns the id of the last completed transmission. The
+// returned id, if any, refers to a value passed on a previous call to
+// Enqueue().
+//
+// Malformed completions are logged, consumed, and skipped; returns
+// (0, false) when the rx pipe is empty.
+func (t *Tx) CompletedPacket() (id uint64, ok bool) {
+	for {
+		b := t.rx.Pull()
+		if b == nil {
+			return 0, false
+		}
+
+		// A completion is exactly one 8-byte id.
+		// NOTE(review): the condition rejects any size != 8, but the
+		// log message says "less than expected" — the wording is
+		// misleading for oversized completions.
+		if len(b) != 8 {
+			t.rx.Flush()
+			log.Warningf("Ignoring completed packet: size (%v) is less than expected (%v)", len(b), 8)
+			continue
+		}
+
+		v := binary.LittleEndian.Uint64(b)
+
+		t.rx.Flush()
+
+		return v, true
+	}
+}
+
+// Bytes returns the byte slices on which the queue operates.
+func (t *Tx) Bytes() (tx, rx []byte) {
+	return t.tx.Bytes(), t.rx.Bytes()
+}
+
+// TxPacketInfo holds information about a packet sent on a tx queue.
+type TxPacketInfo struct {
+	ID uint64 // ID is the identifier passed to Enqueue.
+	Size uint32 // Size is the total data length of the packet.
+	Reserved uint32 // Reserved is always encoded as zero by Enqueue.
+	BufferCount int // BufferCount is the number of buffer descriptors in the packet.
+}
+
+// DecodeTxPacketHeader decodes the header of a packet sent over a tx queue.
+func DecodeTxPacketHeader(b []byte) TxPacketInfo {
+	return TxPacketInfo{
+		ID: binary.LittleEndian.Uint64(b[packetID:]),
+		Size: binary.LittleEndian.Uint32(b[packetSize:]),
+		Reserved: binary.LittleEndian.Uint32(b[packetReserved:]),
+		// Buffer count is implied by the descriptor's length.
+		BufferCount: (len(b) - sizeOfPacketHeader) / sizeOfBufferDescriptor,
+	}
+}
+
+// DecodeTxBufferHeader decodes the header of the i-th buffer of a packet sent
+// over a tx queue.
+func DecodeTxBufferHeader(b []byte, i int) TxBuffer {
+	// Skip the packet header and the first i buffer descriptors.
+	b = b[sizeOfPacketHeader+i*sizeOfBufferDescriptor:]
+	return TxBuffer{
+		Offset: binary.LittleEndian.Uint64(b[bufferOffset:]),
+		Size: binary.LittleEndian.Uint32(b[bufferSize:]),
+	}
+}
+
+// EncodeTxCompletion encodes a tx completion header.
+func EncodeTxCompletion(b []byte, id uint64) {
+	binary.LittleEndian.PutUint64(b, id)
+}
diff --git a/pkg/tcpip/link/sharedmem/rx.go b/pkg/tcpip/link/sharedmem/rx.go
new file mode 100644
index 000000000..951ed966b
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/rx.go
@@ -0,0 +1,147 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sharedmem
+
+import (
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+// rx holds all state associated with an rx queue.
+type rx struct {
+	data []byte // data is the mmapped area holding packet contents.
+	sharedData []byte // sharedData is the mmapped eventfd-state area.
+	q queue.Rx // q is the receive queue built on the mapped pipes.
+	eventFD int // eventFD is a duplicated, non-blocking eventfd.
+}
+
+// init initializes all state needed by the rx queue based on the information
+// provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+//
+// On any failure, all mappings/descriptors created so far are unwound so no
+// resources leak.
+//
+// NOTE(review): the mtu parameter is not referenced in this function, and
+// the caller in this file passes bufferSize as this argument — confirm
+// intent.
+func (r *rx) init(mtu uint32, c *QueueConfig) error {
+	// Map in all buffers.
+	txPipe, err := getBuffer(c.TxPipeFD)
+	if err != nil {
+		return err
+	}
+
+	rxPipe, err := getBuffer(c.RxPipeFD)
+	if err != nil {
+		syscall.Munmap(txPipe)
+		return err
+	}
+
+	data, err := getBuffer(c.DataFD)
+	if err != nil {
+		syscall.Munmap(txPipe)
+		syscall.Munmap(rxPipe)
+		return err
+	}
+
+	sharedData, err := getBuffer(c.SharedDataFD)
+	if err != nil {
+		syscall.Munmap(txPipe)
+		syscall.Munmap(rxPipe)
+		syscall.Munmap(data)
+		return err
+	}
+
+	// Duplicate the eventFD so that caller can close it but we can still
+	// use it.
+	efd, err := syscall.Dup(c.EventFD)
+	if err != nil {
+		syscall.Munmap(txPipe)
+		syscall.Munmap(rxPipe)
+		syscall.Munmap(data)
+		syscall.Munmap(sharedData)
+		return err
+	}
+
+	// Set the eventfd as non-blocking.
+	if err := syscall.SetNonblock(efd, true); err != nil {
+		syscall.Munmap(txPipe)
+		syscall.Munmap(rxPipe)
+		syscall.Munmap(data)
+		syscall.Munmap(sharedData)
+		syscall.Close(efd)
+		return err
+	}
+
+	// Initialize state based on buffers.
+	r.q.Init(txPipe, rxPipe, sharedDataPointer(sharedData))
+	r.data = data
+	r.eventFD = efd
+	r.sharedData = sharedData
+
+	return nil
+}
+
+// cleanup releases all resources allocated during init(). It must only be
+// called if init() has previously succeeded.
+func (r *rx) cleanup() {
+	// The queue owns the two pipe mappings; fetch and unmap them.
+	a, b := r.q.Bytes()
+	syscall.Munmap(a)
+	syscall.Munmap(b)
+
+	syscall.Munmap(r.data)
+	syscall.Munmap(r.sharedData)
+	syscall.Close(r.eventFD)
+}
+
+// postAndReceive posts the provided buffers (if any), and then tries to read
+// from the receive queue.
+//
+// Capacity permitting, it reuses the posted buffer slice to store the buffers
+// that were read as well.
+//
+// This function will block if there aren't any available packets.
+//
+// Returns (nil, 0) if a stop was requested while blocked. Both blocking
+// paths use the same enable-notification-then-retry pattern: notifications
+// are enabled first and the operation retried, so a wakeup that races with
+// the state change is not missed.
+func (r *rx) postAndReceive(b []queue.RxBuffer, stopRequested *uint32) ([]queue.RxBuffer, uint32) {
+	// Post the buffers first. If we cannot post, sleep until we can. We
+	// never post more than will fit concurrently, so it's safe to wait
+	// until enough room is available.
+	if len(b) != 0 && !r.q.PostBuffers(b) {
+		r.q.EnableNotification()
+		for !r.q.PostBuffers(b) {
+			var tmp [8]byte
+			rawfile.BlockingRead(r.eventFD, tmp[:])
+			if atomic.LoadUint32(stopRequested) != 0 {
+				r.q.DisableNotification()
+				return nil, 0
+			}
+		}
+		r.q.DisableNotification()
+	}
+
+	// Read the next set of descriptors, reusing b's storage.
+	b, n := r.q.Dequeue(b[:0])
+	if len(b) != 0 {
+		return b, n
+	}
+
+	// Data isn't immediately available. Enable eventfd notifications.
+	r.q.EnableNotification()
+	for {
+		b, n = r.q.Dequeue(b)
+		if len(b) != 0 {
+			break
+		}
+
+		// Wait for notification.
+		var tmp [8]byte
+		rawfile.BlockingRead(r.eventFD, tmp[:])
+		if atomic.LoadUint32(stopRequested) != 0 {
+			r.q.DisableNotification()
+			return nil, 0
+		}
+	}
+	r.q.DisableNotification()
+
+	return b, n
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem.go b/pkg/tcpip/link/sharedmem/sharedmem.go
new file mode 100644
index 000000000..2c0f1b294
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem.go
@@ -0,0 +1,240 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package sharedmem provides the implemention of data-link layer endpoints
+// backed by shared memory.
+//
+// Shared memory endpoints can be used in the networking stack by calling New()
+// to create a new endpoint, and then passing it as an argument to
+// Stack.CreateNIC().
+package sharedmem
+
+import (
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/queue"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// QueueConfig holds all the file descriptors needed to describe a tx or rx
+// queue over shared memory. It is used when creating new shared memory
+// endpoints to describe tx and rx queues.
+type QueueConfig struct {
+	// DataFD is a file descriptor for the file that contains the data to
+	// be transmitted via this queue. Descriptors contain offsets within
+	// this file.
+	DataFD int
+
+	// EventFD is a file descriptor for the event that is signaled when
+	// data becomes available in this queue.
+	EventFD int
+
+	// TxPipeFD is a file descriptor for the tx pipe associated with the
+	// queue.
+	TxPipeFD int
+
+	// RxPipeFD is a file descriptor for the rx pipe associated with the
+	// queue.
+	RxPipeFD int
+
+	// SharedDataFD is a file descriptor for the file that contains shared
+	// state between the two ends of the queue. This data specifies, for
+	// example, whether EventFD signaling is enabled or disabled.
+	SharedDataFD int
+}
+
+// endpoint is a shared-memory-backed link endpoint.
+type endpoint struct {
+	// mtu (maximum transmission unit) is the maximum size of a packet.
+	mtu uint32
+
+	// bufferSize is the size of each individual buffer.
+	bufferSize uint32
+
+	// addr is the local address of this endpoint.
+	addr tcpip.LinkAddress
+
+	// rx is the receive queue.
+	rx rx
+
+	// stopRequested is to be accessed atomically only, and determines if
+	// the worker goroutines should stop.
+	stopRequested uint32
+
+	// Wait group used to indicate that all workers have stopped.
+	completed sync.WaitGroup
+
+	// mu protects the following fields.
+	mu sync.Mutex
+
+	// tx is the transmit queue.
+	tx tx
+
+	// workerStarted specifies whether the worker goroutine was started.
+	workerStarted bool
+}
+
+// New creates a new shared-memory-based endpoint. Buffers will be broken up
+// into buffers of "bufferSize" bytes.
+//
+// Returns the registered link endpoint ID, or an error if either queue
+// fails to initialize (the tx queue is torn down if rx init fails).
+func New(mtu, bufferSize uint32, addr tcpip.LinkAddress, tx, rx QueueConfig) (tcpip.LinkEndpointID, error) {
+	e := &endpoint{
+		mtu: mtu,
+		bufferSize: bufferSize,
+		addr: addr,
+	}
+
+	if err := e.tx.init(bufferSize, &tx); err != nil {
+		return 0, err
+	}
+
+	if err := e.rx.init(bufferSize, &rx); err != nil {
+		// Unwind the already-initialized tx queue.
+		e.tx.cleanup()
+		return 0, err
+	}
+
+	return stack.RegisterLinkEndpoint(e), nil
+}
+
+// Close frees all resources associated with the endpoint.
+func (e *endpoint) Close() {
+	// Tell dispatch goroutine to stop, then write to the eventfd so that
+	// it wakes up in case it's sleeping.
+	atomic.StoreUint32(&e.stopRequested, 1)
+	syscall.Write(e.rx.eventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+	// Cleanup the queues inline if the worker hasn't started yet; we also
+	// know it won't start from now on because stopRequested is set to 1.
+	e.mu.Lock()
+	workerPresent := e.workerStarted
+	e.mu.Unlock()
+
+	if !workerPresent {
+		e.tx.cleanup()
+		e.rx.cleanup()
+	}
+	// If the worker is running, dispatchLoop performs the cleanup when it
+	// observes stopRequested and exits.
+}
+
+// Wait waits until all workers have stopped after a Close() call.
+func (e *endpoint) Wait() {
+	e.completed.Wait()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It launches the goroutine that
+// reads packets from the rx queue.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+	e.mu.Lock()
+	// Start at most one worker, and never after a stop was requested.
+	if !e.workerStarted && atomic.LoadUint32(&e.stopRequested) == 0 {
+		e.workerStarted = true
+		e.completed.Add(1)
+		go e.dispatchLoop(dispatcher) // S/R-FIXME
+	}
+	e.mu.Unlock()
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It returns the value initialized
+// during construction, minus the ethernet header size this endpoint adds.
+func (e *endpoint) MTU() uint32 {
+	return e.mtu - header.EthernetMinimumSize
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (*endpoint) Capabilities() stack.LinkEndpointCapabilities {
+	return 0
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It returns the
+// ethernet frame header size.
+func (*endpoint) MaxHeaderLength() uint16 {
+	return header.EthernetMinimumSize
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It returns the local
+// link address.
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+	return e.addr
+}
+
+// WritePacket writes outbound packets to the file descriptor. If it is not
+// currently writable, the packet is dropped.
+//
+// Returns tcpip.ErrWouldBlock when the tx queue cannot accept the packet.
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+	// Add the ethernet header here.
+	eth := header.Ethernet(hdr.Prepend(header.EthernetMinimumSize))
+	eth.Encode(&header.EthernetFields{
+		DstAddr: r.RemoteLinkAddress,
+		SrcAddr: e.addr,
+		Type: protocol,
+	})
+
+	// Transmit the packet; mu serializes access to the tx queue.
+	e.mu.Lock()
+	ok := e.tx.transmit(hdr.UsedBytes(), payload)
+	e.mu.Unlock()
+
+	if !ok {
+		return tcpip.ErrWouldBlock
+	}
+
+	return nil
+}
+
+// dispatchLoop reads packets from the rx queue in a loop and dispatches them
+// to the network stack. It runs until stopRequested is set, then releases
+// the tx/rx queue resources and signals completion.
+func (e *endpoint) dispatchLoop(d stack.NetworkDispatcher) {
+	// Post initial set of buffers: bounded by both the queue's posting
+	// capacity and how many bufferSize-sized buffers fit in the data area.
+	limit := e.rx.q.PostedBuffersLimit()
+	if l := uint64(len(e.rx.data)) / uint64(e.bufferSize); limit > l {
+		limit = l
+	}
+	for i := uint64(0); i < limit; i++ {
+		b := queue.RxBuffer{
+			Offset: i * uint64(e.bufferSize),
+			Size: e.bufferSize,
+			ID: i,
+		}
+		if !e.rx.q.PostBuffers([]queue.RxBuffer{b}) {
+			log.Warningf("Unable to post %v-th buffer", i)
+		}
+	}
+
+	// Read in a loop until a stop is requested. rxb, views and vv are
+	// reused across iterations to avoid per-packet allocations.
+	var rxb []queue.RxBuffer
+	views := []buffer.View{nil}
+	vv := buffer.NewVectorisedView(0, views)
+	for atomic.LoadUint32(&e.stopRequested) == 0 {
+		var n uint32
+		rxb, n = e.rx.postAndReceive(rxb, &e.stopRequested)
+
+		// Copy data from the shared area to its own buffer, then
+		// prepare to repost the buffer.
+		b := make([]byte, n)
+		offset := uint32(0)
+		for i := range rxb {
+			copy(b[offset:], e.rx.data[rxb[i].Offset:][:rxb[i].Size])
+			offset += rxb[i].Size
+
+			// Restore the full buffer size so it can be reposted.
+			rxb[i].Size = e.bufferSize
+		}
+
+		// Drop packets too small to carry an ethernet header.
+		if n < header.EthernetMinimumSize {
+			continue
+		}
+
+		// Send packet up the stack.
+		eth := header.Ethernet(b)
+		views[0] = b[header.EthernetMinimumSize:]
+		vv.SetSize(int(n) - header.EthernetMinimumSize)
+		d.DeliverNetworkPacket(e, eth.SourceAddress(), eth.Type(), &vv)
+	}
+
+	// Clean state.
+	e.tx.cleanup()
+	e.rx.cleanup()
+
+	e.completed.Done()
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_test.go b/pkg/tcpip/link/sharedmem/sharedmem_test.go
new file mode 100644
index 000000000..f71e4751f
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_test.go
@@ -0,0 +1,703 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sharedmem
+
+import (
+ "io/ioutil"
+ "math/rand"
+ "os"
+ "reflect"
+ "sync"
+ "syscall"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/pipe"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/queue"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+	localLinkAddr = "\xde\xad\xbe\xef\x56\x78"
+	remoteLinkAddr = "\xde\xad\xbe\xef\x12\x34"
+
+	queueDataSize = 1024 * 1024
+	queuePipeSize = 4096
+)
+
+// queueBuffers is the test's view of one queue. Note the role inversion:
+// the test acts as the endpoint's peer, so it holds a pipe.Tx for the
+// queue's rx direction and a pipe.Rx for the queue's tx direction.
+type queueBuffers struct {
+	data []byte
+	rx pipe.Tx
+	tx pipe.Rx
+}
+
+// initQueue maps the pipe and data files of c into q.
+func initQueue(t *testing.T, q *queueBuffers, c *QueueConfig) {
+	// Prepare tx pipe.
+	b, err := getBuffer(c.TxPipeFD)
+	if err != nil {
+		t.Fatalf("getBuffer failed: %v", err)
+	}
+	q.tx.Init(b)
+
+	// Prepare rx pipe.
+	b, err = getBuffer(c.RxPipeFD)
+	if err != nil {
+		t.Fatalf("getBuffer failed: %v", err)
+	}
+	q.rx.Init(b)
+
+	// Get data slice.
+	q.data, err = getBuffer(c.DataFD)
+	if err != nil {
+		t.Fatalf("getBuffer failed: %v", err)
+	}
+}
+
+// cleanup unmaps the buffers mapped by initQueue.
+func (q *queueBuffers) cleanup() {
+	syscall.Munmap(q.tx.Bytes())
+	syscall.Munmap(q.rx.Bytes())
+	syscall.Munmap(q.data)
+}
+
+// packetInfo records one packet delivered to the test dispatcher.
+type packetInfo struct {
+	addr tcpip.LinkAddress
+	proto tcpip.NetworkProtocolNumber
+	vv buffer.VectorisedView
+}
+
+// testContext bundles an endpoint under test with the peer-side queue
+// buffers and the packets captured from DeliverNetworkPacket.
+type testContext struct {
+	t *testing.T
+	ep *endpoint
+	txCfg QueueConfig
+	rxCfg QueueConfig
+	txq queueBuffers
+	rxq queueBuffers
+
+	packetCh chan struct{}
+	mu sync.Mutex
+	packets []packetInfo
+}
+
+// newTestContext creates an attached endpoint plus peer-side queue views.
+// The context itself acts as the network dispatcher (see
+// DeliverNetworkPacket).
+func newTestContext(t *testing.T, mtu, bufferSize uint32, addr tcpip.LinkAddress) *testContext {
+	var err error
+	c := &testContext{
+		t: t,
+		// Large buffer so DeliverNetworkPacket never blocks on send.
+		packetCh: make(chan struct{}, 1000000),
+	}
+	c.txCfg = createQueueFDs(t, queueSizes{
+		dataSize: queueDataSize,
+		txPipeSize: queuePipeSize,
+		rxPipeSize: queuePipeSize,
+		sharedDataSize: 4096,
+	})
+
+	c.rxCfg = createQueueFDs(t, queueSizes{
+		dataSize: queueDataSize,
+		txPipeSize: queuePipeSize,
+		rxPipeSize: queuePipeSize,
+		sharedDataSize: 4096,
+	})
+
+	initQueue(t, &c.txq, &c.txCfg)
+	initQueue(t, &c.rxq, &c.rxCfg)
+
+	id, err := New(mtu, bufferSize, addr, c.txCfg, c.rxCfg)
+	if err != nil {
+		t.Fatalf("New failed: %v", err)
+	}
+
+	c.ep = stack.FindLinkEndpoint(id).(*endpoint)
+	c.ep.Attach(c)
+
+	return c
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher; it records the
+// packet and signals packetCh so tests can wait for deliveries.
+func (c *testContext) DeliverNetworkPacket(_ stack.LinkEndpoint, remoteAddr tcpip.LinkAddress, proto tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+	c.mu.Lock()
+	c.packets = append(c.packets, packetInfo{
+		addr: remoteAddr,
+		proto: proto,
+		// Clone: vv's backing storage is reused by the endpoint.
+		vv: vv.Clone(nil),
+	})
+	c.mu.Unlock()
+
+	c.packetCh <- struct{}{}
+}
+
+// cleanup closes the endpoint and releases all test-held resources.
+func (c *testContext) cleanup() {
+	c.ep.Close()
+	closeFDs(&c.txCfg)
+	closeFDs(&c.rxCfg)
+	c.txq.cleanup()
+	c.rxq.cleanup()
+}
+
+// waitForPackets blocks until n packets have been delivered or the timeout
+// channel fires, in which case the test fails with errorStr.
+// NOTE(review): errorStr is passed as a Fatalf format string; a literal '%'
+// in it would be misinterpreted.
+func (c *testContext) waitForPackets(n int, to <-chan time.Time, errorStr string) {
+	for i := 0; i < n; i++ {
+		select {
+		case <-c.packetCh:
+		case <-to:
+			c.t.Fatalf(errorStr)
+		}
+	}
+}
+
+// pushRxCompletion encodes an rx completion for the given buffers onto the
+// endpoint's rx queue.
+// NOTE(review): this does not flush the pipe; callers appear responsible
+// for flushing — confirm.
+func (c *testContext) pushRxCompletion(size uint32, bs []queue.RxBuffer) {
+	b := c.rxq.rx.Push(queue.RxCompletionSize(len(bs)))
+	queue.EncodeRxCompletion(b, size, 0)
+	for i := range bs {
+		queue.EncodeRxCompletionBuffer(b, i, queue.RxBuffer{
+			Offset: bs[i].Offset,
+			Size: bs[i].Size,
+			ID: bs[i].ID,
+		})
+	}
+}
+
+// randomFill fills b with random bytes.
+func randomFill(b []byte) {
+	for i := range b {
+		b[i] = byte(rand.Intn(256))
+	}
+}
+
+// shuffle permutes b in place (Fisher-Yates).
+func shuffle(b []int) {
+	for i := len(b) - 1; i >= 0; i-- {
+		j := rand.Intn(i + 1)
+		b[i], b[j] = b[j], b[i]
+	}
+}
+
+// createFile creates an unlinked temp file of the given size and returns a
+// dup'd descriptor for it. If initQueue is set, the pipe's initial
+// "slot-free" flag is written first.
+func createFile(t *testing.T, size int64, initQueue bool) int {
+	tmpDir := os.Getenv("TEST_TMPDIR")
+	if tmpDir == "" {
+		tmpDir = os.Getenv("TMPDIR")
+	}
+	f, err := ioutil.TempFile(tmpDir, "sharedmem_test")
+	if err != nil {
+		t.Fatalf("TempFile failed: %v", err)
+	}
+	defer f.Close()
+	// Unlink immediately so the file disappears when all FDs close.
+	syscall.Unlink(f.Name())
+
+	if initQueue {
+		// Write the "slot-free" flag in the initial queue.
+		_, err := f.WriteAt([]byte{0, 0, 0, 0, 0, 0, 0, 0x80}, 0)
+		if err != nil {
+			t.Fatalf("WriteAt failed: %v", err)
+		}
+	}
+
+	fd, err := syscall.Dup(int(f.Fd()))
+	if err != nil {
+		t.Fatalf("Dup failed: %v", err)
+	}
+
+	if err := syscall.Ftruncate(fd, size); err != nil {
+		syscall.Close(fd)
+		t.Fatalf("Ftruncate failed: %v", err)
+	}
+
+	return fd
+}
+
+// closeFDs closes every descriptor in c.
+func closeFDs(c *QueueConfig) {
+	syscall.Close(c.DataFD)
+	syscall.Close(c.EventFD)
+	syscall.Close(c.TxPipeFD)
+	syscall.Close(c.RxPipeFD)
+	syscall.Close(c.SharedDataFD)
+}
+
+// queueSizes describes the file sizes used to back one queue.
+type queueSizes struct {
+	dataSize int64
+	txPipeSize int64
+	rxPipeSize int64
+	sharedDataSize int64
+}
+
+// createQueueFDs creates the eventfd and backing files for a queue.
+func createQueueFDs(t *testing.T, s queueSizes) QueueConfig {
+	fd, _, err := syscall.RawSyscall(syscall.SYS_EVENTFD2, 0, 0, 0)
+	if err != 0 {
+		t.Fatalf("eventfd failed: %v", error(err))
+	}
+
+	return QueueConfig{
+		EventFD: int(fd),
+		DataFD: createFile(t, s.dataSize, false),
+		TxPipeFD: createFile(t, s.txPipeSize, true),
+		RxPipeFD: createFile(t, s.rxPipeSize, true),
+		SharedDataFD: createFile(t, s.sharedDataSize, false),
+	}
+}
+
+// TestSimpleSend sends 1000 packets with random header and payload sizes,
+// then checks that the right payload is received on the shared memory queues.
+func TestSimpleSend(t *testing.T) {
+	c := newTestContext(t, 20000, 1500, localLinkAddr)
+	defer c.cleanup()
+
+	// Prepare route.
+	r := stack.Route{
+		RemoteLinkAddress: remoteLinkAddr,
+	}
+
+	for iters := 1000; iters > 0; iters-- {
+		// Prepare and send packet.
+		n := rand.Intn(10000)
+		hdr := buffer.NewPrependable(n + int(c.ep.MaxHeaderLength()))
+		hdrBuf := hdr.Prepend(n)
+		randomFill(hdrBuf)
+
+		n = rand.Intn(10000)
+		buf := buffer.NewView(n)
+		randomFill(buf)
+
+		proto := tcpip.NetworkProtocolNumber(rand.Intn(0x10000))
+		err := c.ep.WritePacket(&r, &hdr, buf, proto)
+		if err != nil {
+			t.Fatalf("WritePacket failed: %v", err)
+		}
+
+		// Receive packet: reassemble the contents from the buffer
+		// descriptors referenced by the tx descriptor.
+		desc := c.txq.tx.Pull()
+		pi := queue.DecodeTxPacketHeader(desc)
+		contents := make([]byte, 0, pi.Size)
+		for i := 0; i < pi.BufferCount; i++ {
+			bi := queue.DecodeTxBufferHeader(desc, i)
+			contents = append(contents, c.txq.data[bi.Offset:][:bi.Size]...)
+		}
+		c.txq.tx.Flush()
+
+		if pi.Reserved != 0 {
+			t.Fatalf("Reserved value is non-zero: 0x%x", pi.Reserved)
+		}
+
+		// Check the ethernet header.
+		ethTemplate := make(header.Ethernet, header.EthernetMinimumSize)
+		ethTemplate.Encode(&header.EthernetFields{
+			SrcAddr: localLinkAddr,
+			DstAddr: remoteLinkAddr,
+			Type: proto,
+		})
+		if got := contents[:header.EthernetMinimumSize]; !reflect.DeepEqual(got, []byte(ethTemplate)) {
+			t.Fatalf("Bad ethernet header in packet: got %x, want %x", got, ethTemplate)
+		}
+
+		// Compare contents skipping the ethernet header added by the
+		// endpoint.
+		merged := append(hdrBuf, buf...)
+		if uint32(len(contents)) < pi.Size {
+			t.Fatalf("Sum of buffers is less than packet size: %v < %v", len(contents), pi.Size)
+		}
+		contents = contents[:pi.Size][header.EthernetMinimumSize:]
+
+		if !reflect.DeepEqual(contents, merged) {
+			t.Fatalf("Buffers are different: got %x (%v bytes), want %x (%v bytes)", contents, len(contents), merged, len(merged))
+		}
+
+		// Tell the endpoint about the completion of the write.
+		b := c.txq.rx.Push(8)
+		queue.EncodeTxCompletion(b, pi.ID)
+		c.txq.rx.Flush()
+	}
+}
+
+// TestFillTxQueue sends packets until the queue is full.
+func TestFillTxQueue(t *testing.T) {
+	c := newTestContext(t, 20000, 1500, localLinkAddr)
+	defer c.cleanup()
+
+	// Prepare to send a packet.
+	r := stack.Route{
+		RemoteLinkAddress: remoteLinkAddr,
+	}
+
+	buf := buffer.NewView(100)
+
+	// Each packet uses no more than 40 bytes, so write that many packets
+	// until the tx queue is full.
+	ids := make(map[uint64]struct{})
+	for i := queuePipeSize / 40; i > 0; i-- {
+		hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+		if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+			t.Fatalf("WritePacket failed unexpectedly: %v", err)
+		}
+
+		// Check that they have different IDs.
+		desc := c.txq.tx.Pull()
+		pi := queue.DecodeTxPacketHeader(desc)
+		if _, ok := ids[pi.ID]; ok {
+			t.Fatalf("ID (%v) reused", pi.ID)
+		}
+		ids[pi.ID] = struct{}{}
+	}
+
+	// Next attempt to write must fail.
+	hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+	err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber)
+	if want := tcpip.ErrWouldBlock; err != want {
+		t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+	}
+}
+
+// TestFillTxQueueAfterBadCompletion sends a bad completion, then sends packets
+// until the queue is full.
+func TestFillTxQueueAfterBadCompletion(t *testing.T) {
+	c := newTestContext(t, 20000, 1500, localLinkAddr)
+	defer c.cleanup()
+
+	// Send a bad completion (id 1 was never enqueued).
+	queue.EncodeTxCompletion(c.txq.rx.Push(8), 1)
+	c.txq.rx.Flush()
+
+	// Prepare to send a packet.
+	r := stack.Route{
+		RemoteLinkAddress: remoteLinkAddr,
+	}
+
+	buf := buffer.NewView(100)
+
+	// Send two packets so that the id slice has at least two slots.
+	for i := 2; i > 0; i-- {
+		hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+		if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+			t.Fatalf("WritePacket failed unexpectedly: %v", err)
+		}
+	}
+
+	// Complete the two writes twice (duplicate completions must be
+	// tolerated by the endpoint).
+	for i := 2; i > 0; i-- {
+		pi := queue.DecodeTxPacketHeader(c.txq.tx.Pull())
+
+		queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID)
+		queue.EncodeTxCompletion(c.txq.rx.Push(8), pi.ID)
+		c.txq.rx.Flush()
+	}
+	c.txq.tx.Flush()
+
+	// Each packet uses no more than 40 bytes, so write that many packets
+	// until the tx queue is full.
+	ids := make(map[uint64]struct{})
+	for i := queuePipeSize / 40; i > 0; i-- {
+		hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+		if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+			t.Fatalf("WritePacket failed unexpectedly: %v", err)
+		}
+
+		// Check that they have different IDs.
+		desc := c.txq.tx.Pull()
+		pi := queue.DecodeTxPacketHeader(desc)
+		if _, ok := ids[pi.ID]; ok {
+			t.Fatalf("ID (%v) reused", pi.ID)
+		}
+		ids[pi.ID] = struct{}{}
+	}
+
+	// Next attempt to write must fail.
+	hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+	err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber)
+	if want := tcpip.ErrWouldBlock; err != want {
+		t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+	}
+}
+
+// TestFillTxMemory sends packets until we run out of shared memory.
+func TestFillTxMemory(t *testing.T) {
+	const bufferSize = 1500
+	c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+	defer c.cleanup()
+
+	// Prepare to send a packet.
+	r := stack.Route{
+		RemoteLinkAddress: remoteLinkAddr,
+	}
+
+	buf := buffer.NewView(100)
+
+	// Each packet uses up one buffer, so write as many as possible until
+	// we fill the memory.
+	ids := make(map[uint64]struct{})
+	for i := queueDataSize / bufferSize; i > 0; i-- {
+		hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+		if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+			t.Fatalf("WritePacket failed unexpectedly: %v", err)
+		}
+
+		// Check that they have different IDs.
+		desc := c.txq.tx.Pull()
+		pi := queue.DecodeTxPacketHeader(desc)
+		if _, ok := ids[pi.ID]; ok {
+			t.Fatalf("ID (%v) reused", pi.ID)
+		}
+		ids[pi.ID] = struct{}{}
+		c.txq.tx.Flush()
+	}
+
+	// Next attempt to write must fail.
+	hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+	err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber)
+	if want := tcpip.ErrWouldBlock; err != want {
+		t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+	}
+}
+
+// TestFillTxMemoryWithMultiBuffer sends packets until we run out of
+// shared memory for a 2-buffer packet, but still with room for a 1-buffer
+// packet.
+func TestFillTxMemoryWithMultiBuffer(t *testing.T) {
+	const bufferSize = 1500
+	c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+	defer c.cleanup()
+
+	// Prepare to send a packet.
+	r := stack.Route{
+		RemoteLinkAddress: remoteLinkAddr,
+	}
+
+	buf := buffer.NewView(100)
+
+	// Each packet uses up one buffer, so write as many as possible
+	// until there is only one buffer left.
+	for i := queueDataSize/bufferSize - 1; i > 0; i-- {
+		hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+		if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+			t.Fatalf("WritePacket failed unexpectedly: %v", err)
+		}
+
+		// Pull the posted buffer.
+		c.txq.tx.Pull()
+		c.txq.tx.Flush()
+	}
+
+	// Attempt to write a two-buffer packet. It must fail.
+	hdr := buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+	err := c.ep.WritePacket(&r, &hdr, buffer.NewView(bufferSize), header.IPv4ProtocolNumber)
+	if want := tcpip.ErrWouldBlock; err != want {
+		t.Fatalf("WritePacket return unexpected result: got %v, want %v", err, want)
+	}
+
+	// Attempt to write a one-buffer packet. It must succeed.
+	hdr = buffer.NewPrependable(int(c.ep.MaxHeaderLength()))
+	if err := c.ep.WritePacket(&r, &hdr, buf, header.IPv4ProtocolNumber); err != nil {
+		t.Fatalf("WritePacket failed unexpectedly: %v", err)
+	}
+}
+
+// pollPull polls the given rx pipe every 10ms until a buffer becomes
+// available and returns it. If the "to" channel fires first, the test is
+// aborted with errStr.
+func pollPull(t *testing.T, p *pipe.Rx, to <-chan time.Time, errStr string) []byte {
+	for {
+		b := p.Pull()
+		if b != nil {
+			return b
+		}
+
+		select {
+		case <-time.After(10 * time.Millisecond):
+		case <-to:
+			// Use Fatal, not Fatalf: errStr is a plain message, and
+			// passing it as a format string would misrender any '%'
+			// it contains (and is flagged by "go vet").
+			t.Fatal(errStr)
+		}
+	}
+}
+
+// TestSimpleReceive completes 1000 different receives with random payload and
+// random number of buffers. It checks that the contents match the expected
+// values.
+func TestSimpleReceive(t *testing.T) {
+	const bufferSize = 1500
+	c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+	defer c.cleanup()
+
+	// Check that buffers have been posted.
+	limit := c.ep.rx.q.PostedBuffersLimit()
+	timeout := time.After(2 * time.Second)
+	for i := uint64(0); i < limit; i++ {
+		bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers to be posted"))
+
+		if want := i * bufferSize; want != bi.Offset {
+			t.Fatalf("Bad posted offset: got %v, want %v", bi.Offset, want)
+		}
+
+		if want := i; want != bi.ID {
+			t.Fatalf("Bad posted ID: got %v, want %v", bi.ID, want)
+		}
+
+		if bufferSize != bi.Size {
+			t.Fatalf("Bad posted bufferSize: got %v, want %v", bi.Size, bufferSize)
+		}
+	}
+	c.rxq.tx.Flush()
+
+	// Create a slice with the indices 0..limit-1.
+	idx := make([]int, limit)
+	for i := range idx {
+		idx[i] = i
+	}
+
+	// Complete random packets 1000 times.
+	for iters := 1000; iters > 0; iters-- {
+		// Prepare a random packet: 1 to 10 buffers, filled with random
+		// bytes, with up to 499 bytes of the last buffer left unused.
+		shuffle(idx)
+		n := 1 + rand.Intn(10)
+		bufs := make([]queue.RxBuffer, n)
+		contents := make([]byte, bufferSize*n-rand.Intn(500))
+		randomFill(contents)
+		for i := range bufs {
+			j := idx[i]
+			bufs[i].Size = bufferSize
+			bufs[i].Offset = uint64(bufferSize * j)
+			bufs[i].ID = uint64(j)
+
+			copy(c.rxq.data[bufs[i].Offset:][:bufferSize], contents[i*bufferSize:])
+		}
+
+		// Push completion.
+		c.pushRxCompletion(uint32(len(contents)), bufs)
+		c.rxq.rx.Flush()
+		syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+		// Wait for packet to be received, then check it.
+		c.waitForPackets(1, time.After(time.Second), "Error waiting for packet")
+		c.mu.Lock()
+		rcvd := []byte(c.packets[0].vv.First())
+		c.packets = c.packets[:0]
+		c.mu.Unlock()
+
+		// The leading ethernet header is consumed by the endpoint, so
+		// only the remainder should be delivered.
+		contents = contents[header.EthernetMinimumSize:]
+		if !reflect.DeepEqual(contents, rcvd) {
+			t.Fatalf("Unexpected buffer contents: got %x, want %x", rcvd, contents)
+		}
+
+		// Check that buffers have been reposted.
+		for i := range bufs {
+			bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffers to be reposted"))
+			if !reflect.DeepEqual(bi, bufs[i]) {
+				t.Fatalf("Unexpected buffer reposted: got %x, want %x", bi, bufs[i])
+			}
+		}
+		c.rxq.tx.Flush()
+	}
+}
+
+// TestRxBuffersReposted tests that rx buffers get reposted after they have been
+// completed, both when completed individually and when completed in pairs.
+func TestRxBuffersReposted(t *testing.T) {
+	const bufferSize = 1500
+	c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+	defer c.cleanup()
+
+	// Receive all posted buffers.
+	limit := c.ep.rx.q.PostedBuffersLimit()
+	buffers := make([]queue.RxBuffer, 0, limit)
+	timeout := time.After(2 * time.Second)
+	for i := limit; i > 0; i-- {
+		buffers = append(buffers, queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for all buffers")))
+	}
+	c.rxq.tx.Flush()
+
+	// Check that all buffers are reposted when individually completed.
+	timeout = time.After(2 * time.Second)
+	for i := range buffers {
+		// Complete the buffer.
+		c.pushRxCompletion(buffers[i].Size, buffers[i:][:1])
+		c.rxq.rx.Flush()
+		syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+		// Wait for it to be reposted.
+		bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
+		if !reflect.DeepEqual(bi, buffers[i]) {
+			t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[i])
+		}
+	}
+	c.rxq.tx.Flush()
+
+	// Check that all buffers are reposted when completed in pairs.
+	timeout = time.After(2 * time.Second)
+	for i := 0; i < len(buffers)/2; i++ {
+		// Complete with two buffers.
+		c.pushRxCompletion(2*bufferSize, buffers[2*i:][:2])
+		c.rxq.rx.Flush()
+		syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+		// Wait for them to be reposted.
+		bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
+		if !reflect.DeepEqual(bi, buffers[2*i]) {
+			t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i])
+		}
+		bi = queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, timeout, "Timeout waiting for buffer to be reposted"))
+		if !reflect.DeepEqual(bi, buffers[2*i+1]) {
+			t.Fatalf("Different buffer posted: got %v, want %v", bi, buffers[2*i+1])
+		}
+	}
+	c.rxq.tx.Flush()
+}
+
+// TestReceivePostingIsFull checks that the endpoint will properly handle the
+// case when a received buffer cannot be immediately reposted because it hasn't
+// been pulled from the tx pipe yet: the worker must block until the pipe is
+// flushed, then resume.
+func TestReceivePostingIsFull(t *testing.T) {
+	const bufferSize = 1500
+	c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+	defer c.cleanup()
+
+	// Complete first posted buffer before flushing it from the tx pipe.
+	first := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for first buffer to be posted"))
+	c.pushRxCompletion(first.Size, []queue.RxBuffer{first})
+	c.rxq.rx.Flush()
+	syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+	// Check that packet is received.
+	c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
+
+	// Complete another buffer.
+	second := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for second buffer to be posted"))
+	c.pushRxCompletion(second.Size, []queue.RxBuffer{second})
+	c.rxq.rx.Flush()
+	syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+	// Check that no packet is received yet, as the worker is blocked trying
+	// to repost.
+	select {
+	case <-time.After(500 * time.Millisecond):
+	case <-c.packetCh:
+		t.Fatalf("Unexpected packet received")
+	}
+
+	// Flush tx queue, which will allow the first buffer to be reposted,
+	// and the second completion to be pulled.
+	c.rxq.tx.Flush()
+	syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+	// Check that second packet completes.
+	c.waitForPackets(1, time.After(time.Second), "Timeout waiting for second completed packet")
+}
+
+// TestCloseWhileWaitingToPost closes the endpoint while it is waiting to
+// repost a buffer. Make sure it backs out.
+func TestCloseWhileWaitingToPost(t *testing.T) {
+	const bufferSize = 1500
+	c := newTestContext(t, 20000, bufferSize, localLinkAddr)
+	cleaned := false
+	defer func() {
+		// Only clean up here if the test failed before reaching the
+		// explicit cleanup below.
+		if !cleaned {
+			c.cleanup()
+		}
+	}()
+
+	// Complete first posted buffer before flushing it from the tx pipe.
+	bi := queue.DecodeRxBufferHeader(pollPull(t, &c.rxq.tx, time.After(time.Second), "Timeout waiting for initial buffer to be posted"))
+	c.pushRxCompletion(bi.Size, []queue.RxBuffer{bi})
+	c.rxq.rx.Flush()
+	syscall.Write(c.rxCfg.EventFD, []byte{1, 0, 0, 0, 0, 0, 0, 0})
+
+	// Wait for packet to be indicated.
+	c.waitForPackets(1, time.After(time.Second), "Timeout waiting for completed packet")
+
+	// Cleanup and wait for worker to complete.
+	c.cleanup()
+	cleaned = true
+	c.ep.Wait()
+}
diff --git a/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
new file mode 100644
index 000000000..52f93f480
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/sharedmem_unsafe.go
@@ -0,0 +1,15 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sharedmem
+
+import (
+ "unsafe"
+)
+
+// sharedDataPointer converts the shared data slice into a pointer so that it
+// can be used in atomic operations.
+//
+// The [0:4] reslice makes an undersized slice panic here, up front, rather
+// than allowing a *uint32 that covers bytes past the end of the slice.
+func sharedDataPointer(sharedData []byte) *uint32 {
+	return (*uint32)(unsafe.Pointer(&sharedData[0:4][0]))
+}
diff --git a/pkg/tcpip/link/sharedmem/tx.go b/pkg/tcpip/link/sharedmem/tx.go
new file mode 100644
index 000000000..bca1d79b4
--- /dev/null
+++ b/pkg/tcpip/link/sharedmem/tx.go
@@ -0,0 +1,262 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package sharedmem
+
+import (
+ "math"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sharedmem/queue"
+)
+
+const (
+	// nilID marks the end of idManager's free list; math.MaxUint64 is
+	// never handed out as a buffer ID.
+	nilID = math.MaxUint64
+)
+
+// tx holds all state associated with a tx queue.
+type tx struct {
+	data []byte        // mapped view of the shared data region
+	q    queue.Tx      // tx queue built on the tx/rx pipes
+	ids  idManager     // assigns IDs to in-flight packets
+	bufs bufferManager // allocates buffers within data
+}
+
+// init initializes all state needed by the tx queue based on the information
+// provided.
+//
+// The caller always retains ownership of all file descriptors passed in. The
+// queue implementation will duplicate any that it may need in the future.
+func (t *tx) init(mtu uint32, c *QueueConfig) error {
+	// Map in all buffers. On each failure, unmap whatever was mapped so
+	// far so that nothing leaks.
+	txPipe, err := getBuffer(c.TxPipeFD)
+	if err != nil {
+		return err
+	}
+
+	rxPipe, err := getBuffer(c.RxPipeFD)
+	if err != nil {
+		syscall.Munmap(txPipe)
+		return err
+	}
+
+	data, err := getBuffer(c.DataFD)
+	if err != nil {
+		syscall.Munmap(txPipe)
+		syscall.Munmap(rxPipe)
+		return err
+	}
+
+	// Initialize state based on buffers.
+	t.q.Init(txPipe, rxPipe)
+	t.ids.init()
+	t.bufs.init(0, len(data), int(mtu))
+	t.data = data
+
+	return nil
+}
+
+// cleanup releases all resources allocated during init(). It must only be
+// called if init() has previously succeeded.
+func (t *tx) cleanup() {
+	a, b := t.q.Bytes()
+	syscall.Munmap(a)
+	syscall.Munmap(b)
+	syscall.Munmap(t.data)
+}
+
+// transmit sends a packet made up of up to two buffers. Returns a boolean that
+// specifies whether the packet was successfully transmitted.
+func (t *tx) transmit(a, b []byte) bool {
+	// Pull completions from the tx queue and add their buffers back to the
+	// pool so that we can reuse them.
+	for {
+		id, ok := t.q.CompletedPacket()
+		if !ok {
+			break
+		}
+
+		if buf := t.ids.remove(id); buf != nil {
+			t.bufs.free(buf)
+		}
+	}
+
+	// Compute how many fixed-size buffers are needed to hold the packet.
+	bSize := t.bufs.entrySize
+	total := uint32(len(a) + len(b))
+	bufCount := (total + bSize - 1) / bSize
+
+	// Allocate enough buffers to hold all the data, building them into a
+	// linked list as we go.
+	var buf *queue.TxBuffer
+	for i := bufCount; i != 0; i-- {
+		b := t.bufs.alloc()
+		if b == nil {
+			// Failed to get all buffers. Return to the pool
+			// whatever we had managed to get.
+			if buf != nil {
+				t.bufs.free(buf)
+			}
+			return false
+		}
+		b.Next = buf
+		buf = b
+	}
+
+	// Copy data into allocated buffers. dBuf is the unwritten remainder of
+	// the current buffer; when it is exhausted, advance to the next one.
+	nBuf := buf
+	var dBuf []byte
+	for _, data := range [][]byte{a, b} {
+		for len(data) > 0 {
+			if len(dBuf) == 0 {
+				dBuf = t.data[nBuf.Offset:][:nBuf.Size]
+				nBuf = nBuf.Next
+			}
+			n := copy(dBuf, data)
+			data = data[n:]
+			dBuf = dBuf[n:]
+		}
+	}
+
+	// Get an id for this packet and send it out. On failure, release both
+	// the id and the buffers for reuse.
+	id := t.ids.add(buf)
+	if !t.q.Enqueue(id, total, bufCount, buf) {
+		t.ids.remove(id)
+		t.bufs.free(buf)
+		return false
+	}
+
+	return true
+}
+
+// getBuffer returns a memory region mapped to the full contents of the given
+// file descriptor. The mapping is shared and read-write.
+func getBuffer(fd int) ([]byte, error) {
+	var s syscall.Stat_t
+	if err := syscall.Fstat(fd, &s); err != nil {
+		return nil, err
+	}
+
+	// Check that size doesn't overflow an int.
+	if s.Size > int64(^uint(0)>>1) {
+		return nil, syscall.EDOM
+	}
+
+	return syscall.Mmap(fd, 0, int(s.Size), syscall.PROT_READ|syscall.PROT_WRITE, syscall.MAP_SHARED|syscall.MAP_FILE)
+}
+
+// idDescriptor is used by idManager to either point to a tx buffer (in case
+// the ID is assigned) or to the next free element (if the id is not assigned).
+type idDescriptor struct {
+	buf      *queue.TxBuffer
+	nextFree uint64
+}
+
+// idManager is a manager of tx buffer identifiers. It assigns unique IDs to
+// tx buffers that are added to it; the IDs can only be reused after they have
+// been removed.
+//
+// The ID assignments are stored so that the tx buffers can be retrieved from
+// the IDs previously assigned to them.
+type idManager struct {
+	// ids is a slice containing all tx buffers. The ID is the index into
+	// this slice.
+	ids []idDescriptor
+
+	// freeList is the index of the first free ID, or nilID if the free
+	// list is empty; free entries are chained through nextFree.
+	freeList uint64
+}
+
+// init initializes the id manager with an empty free list.
+func (m *idManager) init() {
+	m.freeList = nilID
+}
+
+// add assigns an ID to the given tx buffer, reusing a free ID if one is
+// available and growing the descriptor slice otherwise.
+func (m *idManager) add(b *queue.TxBuffer) uint64 {
+	if i := m.freeList; i != nilID {
+		// There is an id available in the free list, just use it.
+		m.ids[i].buf = b
+		m.freeList = m.ids[i].nextFree
+		return i
+	}
+
+	// We need to expand the id descriptor.
+	m.ids = append(m.ids, idDescriptor{buf: b})
+	return uint64(len(m.ids) - 1)
+}
+
+// remove retrieves the tx buffer associated with the given ID, and removes the
+// ID from the assigned table so that it can be reused in the future. It
+// returns nil if the ID is out of range or not currently assigned.
+func (m *idManager) remove(i uint64) *queue.TxBuffer {
+	if i >= uint64(len(m.ids)) {
+		return nil
+	}
+
+	desc := &m.ids[i]
+	b := desc.buf
+	if b == nil {
+		// The provided id is not currently assigned.
+		return nil
+	}
+
+	desc.buf = nil
+	desc.nextFree = m.freeList
+	m.freeList = i
+
+	return b
+}
+
+// bufferManager manages a buffer region broken up into smaller, equally sized
+// buffers. Smaller buffers can be allocated and freed.
+type bufferManager struct {
+	// freeList is a list of previously-used buffers available for reuse.
+	freeList *queue.TxBuffer
+
+	// curOffset and limit delimit the never-used portion of the region;
+	// new buffers are carved from [curOffset, limit).
+	curOffset uint64
+	limit     uint64
+
+	// entrySize is the fixed size of each buffer.
+	entrySize uint32
+}
+
+// init initializes the buffer manager. The limit is rounded down so that only
+// whole entrySize buffers are ever handed out.
+func (b *bufferManager) init(initialOffset, size, entrySize int) {
+	b.freeList = nil
+	b.curOffset = uint64(initialOffset)
+	b.limit = uint64(initialOffset + size/entrySize*entrySize)
+	b.entrySize = uint32(entrySize)
+}
+
+// alloc allocates a buffer from the manager, if one is available. It prefers
+// the free list, falling back to the never-used range; returns nil when both
+// are exhausted.
+func (b *bufferManager) alloc() *queue.TxBuffer {
+	if b.freeList != nil {
+		// There is a descriptor ready for reuse in the free list.
+		d := b.freeList
+		b.freeList = d.Next
+		d.Next = nil
+		return d
+	}
+
+	if b.curOffset < b.limit {
+		// There is room available in the never-used range, so create
+		// a new descriptor for it.
+		d := &queue.TxBuffer{
+			Offset: b.curOffset,
+			Size:   b.entrySize,
+		}
+		b.curOffset += uint64(b.entrySize)
+		return d
+	}
+
+	return nil
+}
+
+// free returns all buffers in the list to the buffer manager so that they can
+// be reused.
+func (b *bufferManager) free(d *queue.TxBuffer) {
+	// Find the last buffer in the list.
+	last := d
+	for last.Next != nil {
+		last = last.Next
+	}
+
+	// Push list onto free list.
+	last.Next = b.freeList
+	b.freeList = d
+}
diff --git a/pkg/tcpip/link/sniffer/BUILD b/pkg/tcpip/link/sniffer/BUILD
new file mode 100644
index 000000000..a912707c2
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/BUILD
@@ -0,0 +1,23 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# sniffer wraps another link endpoint and logs traversing packets,
+# optionally also writing them to a pcap file.
+go_library(
+    name = "sniffer",
+    srcs = [
+        "pcap.go",
+        "sniffer.go",
+    ],
+    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer",
+    visibility = [
+        "//visibility:public",
+    ],
+    deps = [
+        "//pkg/log",
+        "//pkg/tcpip",
+        "//pkg/tcpip/buffer",
+        "//pkg/tcpip/header",
+        "//pkg/tcpip/link/rawfile",
+        "//pkg/tcpip/stack",
+    ],
+)
diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go
new file mode 100644
index 000000000..6d0dd565a
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/pcap.go
@@ -0,0 +1,52 @@
+package sniffer
+
+import "time"
+
+// pcapHeader is the global file header of a pcap capture, as described at
+// https://wiki.wireshark.org/Development/LibpcapFileFormat.
+type pcapHeader struct {
+	// MagicNumber is the file magic number.
+	MagicNumber uint32
+
+	// VersionMajor is the major version number.
+	VersionMajor uint16
+
+	// VersionMinor is the minor version number.
+	VersionMinor uint16
+
+	// Thiszone is the GMT to local correction.
+	Thiszone int32
+
+	// Sigfigs is the accuracy of timestamps.
+	Sigfigs uint32
+
+	// Snaplen is the max length of captured packets, in octets.
+	Snaplen uint32
+
+	// Network is the data link type.
+	Network uint32
+}
+
+// pcapPacketHeaderLen is the encoded size, in bytes, of pcapPacketHeader.
+const pcapPacketHeaderLen = 16
+
+// pcapPacketHeader is the per-packet record header of a pcap capture.
+type pcapPacketHeader struct {
+	// Seconds is the timestamp seconds.
+	Seconds uint32
+
+	// Microseconds is the timestamp microseconds.
+	Microseconds uint32
+
+	// IncludedLength is the number of octets of packet saved in file.
+	IncludedLength uint32
+
+	// OriginalLength is the actual length of packet.
+	OriginalLength uint32
+}
+
+// newPCAPPacketHeader returns a packet record header stamped with the current
+// time, recording incLen captured bytes of an originally orgLen-byte packet.
+func newPCAPPacketHeader(incLen, orgLen uint32) pcapPacketHeader {
+	now := time.Now()
+	return pcapPacketHeader{
+		Seconds:        uint32(now.Unix()),
+		Microseconds:   uint32(now.Nanosecond() / 1000),
+		IncludedLength: incLen,
+		OriginalLength: orgLen,
+	}
+}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
new file mode 100644
index 000000000..da6969e94
--- /dev/null
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -0,0 +1,310 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package sniffer provides the implementation of data-link layer endpoints that
+// wrap another endpoint and logs inbound and outbound packets.
+//
+// Sniffer endpoints can be used in the networking stack by calling New(eID) to
+// create a new endpoint, where eID is the ID of the endpoint being wrapped,
+// and then passing it as an argument to Stack.CreateNIC().
+package sniffer
+
+import (
+ "bytes"
+ "encoding/binary"
+ "fmt"
+ "io"
+ "os"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// LogPackets is a flag used to enable or disable packet logging via the log
+// package. Valid values are 0 or 1.
+//
+// LogPackets must be accessed atomically.
+var LogPackets uint32 = 1
+
+// LogPacketsToFile is a flag used to enable or disable logging packets to a
+// pcap file. Valid values are 0 or 1. A file must have been specified when the
+// sniffer was created for this flag to have effect.
+//
+// LogPacketsToFile must be accessed atomically.
+var LogPacketsToFile uint32 = 1
+
+// endpoint is a link-layer endpoint that wraps a lower endpoint and logs the
+// packets that pass through it.
+type endpoint struct {
+	dispatcher stack.NetworkDispatcher // where inbound packets are delivered
+	lower      stack.LinkEndpoint      // the wrapped endpoint
+	file       *os.File                // optional pcap output; nil disables file logging
+	maxPCAPLen uint32                  // pcap snap length (max captured bytes per packet)
+}
+
+// New creates a new sniffer link-layer endpoint. It wraps around another
+// endpoint and logs packets as they traverse the endpoint.
+func New(lower tcpip.LinkEndpointID) tcpip.LinkEndpointID {
+	return stack.RegisterLinkEndpoint(&endpoint{
+		lower: stack.FindLinkEndpoint(lower),
+	})
+}
+
+// zoneOffset returns the local time zone's offset from GMT, in seconds, for
+// use in the pcap file header's Thiszone field.
+func zoneOffset() (int32, error) {
+	loc, err := time.LoadLocation("Local")
+	if err != nil {
+		return 0, err
+	}
+	date := time.Date(0, 0, 0, 0, 0, 0, 0, loc)
+	_, offset := date.Zone()
+	return int32(offset), nil
+}
+
+// writePCAPHeader writes the global pcap file header to w, declaring maxLen
+// as the per-packet snap length.
+func writePCAPHeader(w io.Writer, maxLen uint32) error {
+	offset, err := zoneOffset()
+	if err != nil {
+		return err
+	}
+	return binary.Write(w, binary.BigEndian, pcapHeader{
+		// From https://wiki.wireshark.org/Development/LibpcapFileFormat
+		MagicNumber: 0xa1b2c3d4,
+
+		VersionMajor: 2,
+		VersionMinor: 4,
+		Thiszone:     offset,
+		Sigfigs:      0,
+		Snaplen:      maxLen,
+		Network:      101, // LINKTYPE_RAW
+	})
+}
+
+// NewWithFile creates a new sniffer link-layer endpoint. It wraps around
+// another endpoint and logs packets as they traverse the endpoint.
+//
+// Packets can be logged to file in the pcap format in addition to the standard
+// human-readable logs.
+//
+// snapLen is the maximum amount of a packet to be saved. Packets with a length
+// less than or equal to snapLen will be saved in their entirety. Longer
+// packets will be truncated to snapLen.
+func NewWithFile(lower tcpip.LinkEndpointID, file *os.File, snapLen uint32) (tcpip.LinkEndpointID, error) {
+	if err := writePCAPHeader(file, snapLen); err != nil {
+		return 0, err
+	}
+	return stack.RegisterLinkEndpoint(&endpoint{
+		lower:      stack.FindLinkEndpoint(lower),
+		file:       file,
+		maxPCAPLen: snapLen,
+	}), nil
+}
+
+// DeliverNetworkPacket implements the stack.NetworkDispatcher interface. It is
+// called by the link-layer endpoint being wrapped when a packet arrives, and
+// logs the packet before forwarding to the actual dispatcher.
+func (e *endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+	if atomic.LoadUint32(&LogPackets) == 1 {
+		LogPacket("recv", protocol, vv.First(), nil)
+	}
+	if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+		// Gather the packet views, truncating the captured bytes at
+		// maxPCAPLen. bs[0] is reserved for the pcap record header,
+		// which is filled in once the captured length is known.
+		vs := vv.Views()
+		bs := make([][]byte, 1, 1+len(vs))
+		var length int
+		for _, v := range vs {
+			if length+len(v) > int(e.maxPCAPLen) {
+				l := int(e.maxPCAPLen) - length
+				bs = append(bs, []byte(v)[:l])
+				length += l
+				break
+			}
+			bs = append(bs, []byte(v))
+			length += len(v)
+		}
+		buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen))
+		binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(vv.Size())))
+		bs[0] = buf.Bytes()
+		if err := rawfile.NonBlockingWriteN(int(e.file.Fd()), bs...); err != nil {
+			panic(err)
+		}
+	}
+	e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, protocol, vv)
+}
+
+// Attach implements the stack.LinkEndpoint interface. It saves the dispatcher
+// and registers with the lower endpoint as its dispatcher so that "e" is called
+// for inbound packets, completing the two-way wiring between the layers.
+func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+	e.dispatcher = dispatcher
+	e.lower.Attach(e)
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
+// lower endpoint.
+func (e *endpoint) MTU() uint32 {
+	return e.lower.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
+// request to the lower endpoint.
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+	return e.lower.Capabilities()
+}
+
+// MaxHeaderLength implements the stack.LinkEndpoint interface. It just forwards
+// the request to the lower endpoint.
+func (e *endpoint) MaxHeaderLength() uint16 {
+	return e.lower.MaxHeaderLength()
+}
+
+// LinkAddress implements the stack.LinkEndpoint interface. It just forwards
+// the request to the lower endpoint.
+func (e *endpoint) LinkAddress() tcpip.LinkAddress {
+	return e.lower.LinkAddress()
+}
+
+// WritePacket implements the stack.LinkEndpoint interface. It is called by
+// higher-level protocols to write packets; it just logs the packet and forwards
+// the request to the lower endpoint.
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+	if atomic.LoadUint32(&LogPackets) == 1 {
+		LogPacket("send", protocol, hdr.UsedBytes(), payload)
+	}
+	if e.file != nil && atomic.LoadUint32(&LogPacketsToFile) == 1 {
+		// Truncate the captured bytes at maxPCAPLen; the record's
+		// OriginalLength still reflects the full packet size. bs[0] is
+		// reserved for the pcap record header, filled in below.
+		bs := [][]byte{nil, hdr.UsedBytes(), payload}
+		var length int
+
+		for i, b := range bs[1:] {
+			if rem := int(e.maxPCAPLen) - length; len(b) > rem {
+				b = b[:rem]
+			}
+			bs[i+1] = b
+			length += len(b)
+		}
+
+		buf := bytes.NewBuffer(make([]byte, 0, pcapPacketHeaderLen))
+		binary.Write(buf, binary.BigEndian, newPCAPPacketHeader(uint32(length), uint32(hdr.UsedLength()+len(payload))))
+		bs[0] = buf.Bytes()
+		if err := rawfile.NonBlockingWriteN(int(e.file.Fd()), bs...); err != nil {
+			panic(err)
+		}
+	}
+	return e.lower.WritePacket(r, hdr, payload, protocol)
+}
+
+// LogPacket logs the given packet: b is the network-layer header (and
+// following bytes), plb is the payload (may be nil, as on receive).
+func LogPacket(prefix string, protocol tcpip.NetworkProtocolNumber, b, plb []byte) {
+	// Figure out the network layer info.
+	var transProto uint8
+	src := tcpip.Address("unknown")
+	dst := tcpip.Address("unknown")
+	id := 0
+	size := uint16(0)
+	switch protocol {
+	case header.IPv4ProtocolNumber:
+		ipv4 := header.IPv4(b)
+		src = ipv4.SourceAddress()
+		dst = ipv4.DestinationAddress()
+		transProto = ipv4.Protocol()
+		size = ipv4.TotalLength() - uint16(ipv4.HeaderLength())
+		b = b[ipv4.HeaderLength():]
+		id = int(ipv4.ID())
+
+	case header.IPv6ProtocolNumber:
+		ipv6 := header.IPv6(b)
+		src = ipv6.SourceAddress()
+		dst = ipv6.DestinationAddress()
+		transProto = ipv6.NextHeader()
+		size = ipv6.PayloadLength()
+		b = b[header.IPv6MinimumSize:]
+
+	case header.ARPProtocolNumber:
+		// ARP has no transport layer; log it fully here and return.
+		arp := header.ARP(b)
+		log.Infof(
+			"%s arp %v (%v) -> %v (%v) valid:%v",
+			prefix,
+			tcpip.Address(arp.ProtocolAddressSender()), tcpip.LinkAddress(arp.HardwareAddressSender()),
+			tcpip.Address(arp.ProtocolAddressTarget()), tcpip.LinkAddress(arp.HardwareAddressTarget()),
+			arp.IsValid(),
+		)
+		return
+	default:
+		log.Infof("%s unknown network protocol", prefix)
+		return
+	}
+
+	// Figure out the transport layer info.
+	transName := "unknown"
+	srcPort := uint16(0)
+	dstPort := uint16(0)
+	details := ""
+	switch tcpip.TransportProtocolNumber(transProto) {
+	case header.ICMPv4ProtocolNumber:
+		transName = "icmp"
+		icmp := header.ICMPv4(b)
+		icmpType := "unknown"
+		switch icmp.Type() {
+		case header.ICMPv4EchoReply:
+			icmpType = "echo reply"
+		case header.ICMPv4DstUnreachable:
+			icmpType = "destination unreachable"
+		case header.ICMPv4SrcQuench:
+			icmpType = "source quench"
+		case header.ICMPv4Redirect:
+			icmpType = "redirect"
+		case header.ICMPv4Echo:
+			icmpType = "echo"
+		case header.ICMPv4TimeExceeded:
+			icmpType = "time exceeded"
+		case header.ICMPv4ParamProblem:
+			icmpType = "param problem"
+		case header.ICMPv4Timestamp:
+			icmpType = "timestamp"
+		case header.ICMPv4TimestampReply:
+			icmpType = "timestamp reply"
+		case header.ICMPv4InfoRequest:
+			icmpType = "info request"
+		case header.ICMPv4InfoReply:
+			icmpType = "info reply"
+		}
+		log.Infof("%s %s %v -> %v %s len:%d id:%04x code:%d", prefix, transName, src, dst, icmpType, size, id, icmp.Code())
+		return
+
+	case header.UDPProtocolNumber:
+		transName = "udp"
+		udp := header.UDP(b)
+		srcPort = udp.SourcePort()
+		dstPort = udp.DestinationPort()
+		size -= header.UDPMinimumSize
+
+		details = fmt.Sprintf("xsum: 0x%x", udp.Checksum())
+
+	case header.TCPProtocolNumber:
+		transName = "tcp"
+		tcp := header.TCP(b)
+		srcPort = tcp.SourcePort()
+		dstPort = tcp.DestinationPort()
+		size -= uint16(tcp.DataOffset())
+
+		// Initialize the TCP flags: blank out the letter of each flag
+		// that is not set.
+		flags := tcp.Flags()
+		flagsStr := []byte("FSRPAU")
+		for i := range flagsStr {
+			if flags&(1<<uint(i)) == 0 {
+				flagsStr[i] = ' '
+			}
+		}
+		details = fmt.Sprintf("flags:0x%02x (%v) seqnum: %v ack: %v win: %v xsum:0x%x", flags, string(flagsStr), tcp.SequenceNumber(), tcp.AckNumber(), tcp.WindowSize(), tcp.Checksum())
+		if flags&header.TCPFlagSyn != 0 {
+			details += fmt.Sprintf(" options: %+v", header.ParseSynOptions(tcp.Options(), flags&header.TCPFlagAck != 0))
+		} else {
+			details += fmt.Sprintf(" options: %+v", tcp.ParsedOptions())
+		}
+	default:
+		log.Infof("%s %v -> %v unknown transport protocol: %d", prefix, src, dst, transProto)
+		return
+	}
+
+	log.Infof("%s %s %v:%v -> %v:%v len:%d id:%04x %s", prefix, transName, src, srcPort, dst, dstPort, size, id, details)
+}
diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD
new file mode 100644
index 000000000..d627f00f1
--- /dev/null
+++ b/pkg/tcpip/link/tun/BUILD
@@ -0,0 +1,12 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# tun provides helpers for opening Linux TUN/TAP devices.
+go_library(
+    name = "tun",
+    srcs = ["tun_unsafe.go"],
+    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/tun",
+    visibility = [
+        "//visibility:public",
+    ],
+)
diff --git a/pkg/tcpip/link/tun/tun_unsafe.go b/pkg/tcpip/link/tun/tun_unsafe.go
new file mode 100644
index 000000000..5b6c9b4ab
--- /dev/null
+++ b/pkg/tcpip/link/tun/tun_unsafe.go
@@ -0,0 +1,50 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tun
+
+import (
+ "syscall"
+ "unsafe"
+)
+
+// Open opens the specified TUN device, sets it to non-blocking mode, and
+// returns its file descriptor.
+func Open(name string) (int, error) {
+	return open(name, syscall.IFF_TUN|syscall.IFF_NO_PI)
+}
+
+// OpenTAP opens the specified TAP device, sets it to non-blocking mode, and
+// returns its file descriptor.
+func OpenTAP(name string) (int, error) {
+	return open(name, syscall.IFF_TAP|syscall.IFF_NO_PI)
+}
+
+// open opens /dev/net/tun, binds the fd to the named device with the given
+// flags via TUNSETIFF, and switches the fd to non-blocking mode. The fd is
+// closed on any failure.
+func open(name string, flags uint16) (int, error) {
+	fd, err := syscall.Open("/dev/net/tun", syscall.O_RDWR, 0)
+	if err != nil {
+		return -1, err
+	}
+
+	// Layout must match what the TUNSETIFF ioctl expects (Linux's
+	// struct ifreq): a 16-byte name followed by 2 bytes of flags, padded
+	// to the kernel's expected size.
+	var ifr struct {
+		name  [16]byte
+		flags uint16
+		_     [22]byte
+	}
+
+	copy(ifr.name[:], name)
+	ifr.flags = flags
+	_, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), syscall.TUNSETIFF, uintptr(unsafe.Pointer(&ifr)))
+	if errno != 0 {
+		syscall.Close(fd)
+		return -1, errno
+	}
+
+	if err = syscall.SetNonblock(fd, true); err != nil {
+		syscall.Close(fd)
+		return -1, err
+	}
+
+	return fd, nil
+}
diff --git a/pkg/tcpip/link/waitable/BUILD b/pkg/tcpip/link/waitable/BUILD
new file mode 100644
index 000000000..63b648be7
--- /dev/null
+++ b/pkg/tcpip/link/waitable/BUILD
@@ -0,0 +1,33 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "waitable",
+ srcs = [
+ "waitable.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/link/waitable",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/gate",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "waitable_test",
+ srcs = [
+ "waitable_test.go",
+ ],
+ embed = [":waitable"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/link/waitable/waitable.go b/pkg/tcpip/link/waitable/waitable.go
new file mode 100644
index 000000000..2c6e73f22
--- /dev/null
+++ b/pkg/tcpip/link/waitable/waitable.go
@@ -0,0 +1,108 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package waitable provides the implementation of data-link layer endpoints
+// that wrap other endpoints, and can wait for inflight calls to WritePacket or
+// DeliverNetworkPacket to finish (and new ones to be prevented).
+//
+// Waitable endpoints can be used in the networking stack by calling New(eID) to
+// create a new endpoint, where eID is the ID of the endpoint being wrapped,
+// and then passing it as an argument to Stack.CreateNIC().
+package waitable
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/gate"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// Endpoint is a waitable link-layer endpoint.
+type Endpoint struct {
+ dispatchGate gate.Gate
+ dispatcher stack.NetworkDispatcher
+
+ writeGate gate.Gate
+ lower stack.LinkEndpoint
+}
+
+// New creates a new waitable link-layer endpoint. It wraps around another
+// endpoint and allows the caller to block new write/dispatch calls and wait for
+// the inflight ones to finish before returning.
+func New(lower tcpip.LinkEndpointID) (tcpip.LinkEndpointID, *Endpoint) {
+ e := &Endpoint{
+ lower: stack.FindLinkEndpoint(lower),
+ }
+ return stack.RegisterLinkEndpoint(e), e
+}
+
+// DeliverNetworkPacket implements stack.NetworkDispatcher.DeliverNetworkPacket.
+// It is called by the link-layer endpoint being wrapped when a packet arrives,
+// and only forwards to the actual dispatcher if Wait or WaitDispatch haven't
+// been called.
+func (e *Endpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+ if !e.dispatchGate.Enter() {
+ return
+ }
+
+ e.dispatcher.DeliverNetworkPacket(e, remoteLinkAddr, protocol, vv)
+ e.dispatchGate.Leave()
+}
+
+// Attach implements stack.LinkEndpoint.Attach. It saves the dispatcher and
+// registers with the lower endpoint as its dispatcher so that "e" is called
+// for inbound packets.
+func (e *Endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.dispatcher = dispatcher
+ e.lower.Attach(e)
+}
+
+// MTU implements stack.LinkEndpoint.MTU. It just forwards the request to the
+// lower endpoint.
+func (e *Endpoint) MTU() uint32 {
+ return e.lower.MTU()
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities. It just forwards the
+// request to the lower endpoint.
+func (e *Endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.lower.Capabilities()
+}
+
+// MaxHeaderLength implements stack.LinkEndpoint.MaxHeaderLength. It just
+// forwards the request to the lower endpoint.
+func (e *Endpoint) MaxHeaderLength() uint16 {
+ return e.lower.MaxHeaderLength()
+}
+
+// LinkAddress implements stack.LinkEndpoint.LinkAddress. It just forwards the
+// request to the lower endpoint.
+func (e *Endpoint) LinkAddress() tcpip.LinkAddress {
+ return e.lower.LinkAddress()
+}
+
+// WritePacket implements stack.LinkEndpoint.WritePacket. It is called by
+// higher-level protocols to write packets. It only forwards packets to the
+// lower endpoint if Wait or WaitWrite haven't been called.
+func (e *Endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ if !e.writeGate.Enter() {
+ return nil
+ }
+
+ err := e.lower.WritePacket(r, hdr, payload, protocol)
+ e.writeGate.Leave()
+ return err
+}
+
+// WaitWrite prevents new calls to WritePacket from reaching the lower endpoint,
+// and waits for inflight ones to finish before returning.
+func (e *Endpoint) WaitWrite() {
+ e.writeGate.Close()
+}
+
+// WaitDispatch prevents new calls to DeliverNetworkPacket from reaching the
+// actual dispatcher, and waits for inflight ones to finish before returning.
+func (e *Endpoint) WaitDispatch() {
+ e.dispatchGate.Close()
+}
diff --git a/pkg/tcpip/link/waitable/waitable_test.go b/pkg/tcpip/link/waitable/waitable_test.go
new file mode 100644
index 000000000..cb433dc19
--- /dev/null
+++ b/pkg/tcpip/link/waitable/waitable_test.go
@@ -0,0 +1,144 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package waitable
+
+import (
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+type countedEndpoint struct {
+ dispatchCount int
+ writeCount int
+ attachCount int
+
+ mtu uint32
+ capabilities stack.LinkEndpointCapabilities
+ hdrLen uint16
+ linkAddr tcpip.LinkAddress
+
+ dispatcher stack.NetworkDispatcher
+}
+
+func (e *countedEndpoint) DeliverNetworkPacket(linkEP stack.LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+ e.dispatchCount++
+}
+
+func (e *countedEndpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ e.attachCount++
+ e.dispatcher = dispatcher
+}
+
+func (e *countedEndpoint) MTU() uint32 {
+ return e.mtu
+}
+
+func (e *countedEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.capabilities
+}
+
+func (e *countedEndpoint) MaxHeaderLength() uint16 {
+ return e.hdrLen
+}
+
+func (e *countedEndpoint) LinkAddress() tcpip.LinkAddress {
+ return e.linkAddr
+}
+
+func (e *countedEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ e.writeCount++
+ return nil
+}
+
+func TestWaitWrite(t *testing.T) {
+ ep := &countedEndpoint{}
+ _, wep := New(stack.RegisterLinkEndpoint(ep))
+
+ // Write and check that it goes through.
+ wep.WritePacket(nil, nil, nil, 0)
+ if want := 1; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+
+ // Wait on dispatches, then try to write. It must go through.
+ wep.WaitDispatch()
+ wep.WritePacket(nil, nil, nil, 0)
+ if want := 2; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+
+ // Wait on writes, then try to write. It must not go through.
+ wep.WaitWrite()
+ wep.WritePacket(nil, nil, nil, 0)
+ if want := 2; ep.writeCount != want {
+ t.Fatalf("Unexpected writeCount: got=%v, want=%v", ep.writeCount, want)
+ }
+}
+
+func TestWaitDispatch(t *testing.T) {
+ ep := &countedEndpoint{}
+ _, wep := New(stack.RegisterLinkEndpoint(ep))
+
+ // Check that attach happens.
+ wep.Attach(ep)
+ if want := 1; ep.attachCount != want {
+ t.Fatalf("Unexpected attachCount: got=%v, want=%v", ep.attachCount, want)
+ }
+
+ // Dispatch and check that it goes through.
+ ep.dispatcher.DeliverNetworkPacket(ep, "", 0, nil)
+ if want := 1; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+
+ // Wait on writes, then try to dispatch. It must go through.
+ wep.WaitWrite()
+ ep.dispatcher.DeliverNetworkPacket(ep, "", 0, nil)
+ if want := 2; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+
+ // Wait on dispatches, then try to dispatch. It must not go through.
+ wep.WaitDispatch()
+ ep.dispatcher.DeliverNetworkPacket(ep, "", 0, nil)
+ if want := 2; ep.dispatchCount != want {
+ t.Fatalf("Unexpected dispatchCount: got=%v, want=%v", ep.dispatchCount, want)
+ }
+}
+
+func TestOtherMethods(t *testing.T) {
+ const (
+ mtu = 0xdead
+ capabilities = 0xbeef
+ hdrLen = 0x1234
+ linkAddr = "test address"
+ )
+ ep := &countedEndpoint{
+ mtu: mtu,
+ capabilities: capabilities,
+ hdrLen: hdrLen,
+ linkAddr: linkAddr,
+ }
+ _, wep := New(stack.RegisterLinkEndpoint(ep))
+
+ if v := wep.MTU(); v != mtu {
+ t.Fatalf("Unexpected mtu: got=%v, want=%v", v, mtu)
+ }
+
+ if v := wep.Capabilities(); v != capabilities {
+ t.Fatalf("Unexpected capabilities: got=%v, want=%v", v, capabilities)
+ }
+
+ if v := wep.MaxHeaderLength(); v != hdrLen {
+ t.Fatalf("Unexpected MaxHeaderLength: got=%v, want=%v", v, hdrLen)
+ }
+
+ if v := wep.LinkAddress(); v != linkAddr {
+ t.Fatalf("Unexpected LinkAddress: got=%q, want=%q", v, linkAddr)
+ }
+}
diff --git a/pkg/tcpip/network/BUILD b/pkg/tcpip/network/BUILD
new file mode 100644
index 000000000..36ddaa692
--- /dev/null
+++ b/pkg/tcpip/network/BUILD
@@ -0,0 +1,19 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+
+go_test(
+ name = "ip_test",
+ size = "small",
+ srcs = [
+ "ip_test.go",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD
new file mode 100644
index 000000000..e6d0899a9
--- /dev/null
+++ b/pkg/tcpip/network/arp/BUILD
@@ -0,0 +1,34 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "arp",
+ srcs = ["arp.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/arp",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ ],
+)
+
+go_test(
+ name = "arp_test",
+ size = "small",
+ srcs = ["arp_test.go"],
+ deps = [
+ ":arp",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ ],
+)
diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go
new file mode 100644
index 000000000..4e3d7f597
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp.go
@@ -0,0 +1,170 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package arp implements the ARP network protocol. It is used to resolve
+// IPv4 addresses into link-local MAC addresses, and advertises IPv4
+// addresses of its stack with the local network.
+//
+// To use it in the networking stack, pass arp.ProtocolName as one of the
+// network protocols when calling stack.New. Then add an "arp" address to
+// every NIC on the stack that should respond to ARP requests. That is:
+//
+// if err := s.AddAddress(1, arp.ProtocolNumber, "arp"); err != nil {
+// // handle err
+// }
+package arp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ // ProtocolName is the string representation of the ARP protocol name.
+ ProtocolName = "arp"
+
+ // ProtocolNumber is the ARP protocol number.
+ ProtocolNumber = header.ARPProtocolNumber
+
+ // ProtocolAddress is the address expected by the ARP endpoint.
+ ProtocolAddress = tcpip.Address("arp")
+)
+
+// endpoint implements stack.NetworkEndpoint.
+type endpoint struct {
+ nicid tcpip.NICID
+ addr tcpip.Address
+ linkEP stack.LinkEndpoint
+ linkAddrCache stack.LinkAddressCache
+}
+
+func (e *endpoint) MTU() uint32 {
+ lmtu := e.linkEP.MTU()
+ return lmtu - uint32(e.MaxHeaderLength())
+}
+
+func (e *endpoint) NICID() tcpip.NICID {
+ return e.nicid
+}
+
+func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
+ return e.linkEP.Capabilities()
+}
+
+func (e *endpoint) ID() *stack.NetworkEndpointID {
+ return &stack.NetworkEndpointID{ProtocolAddress}
+}
+
+func (e *endpoint) MaxHeaderLength() uint16 {
+ return e.linkEP.MaxHeaderLength() + header.ARPSize
+}
+
+func (e *endpoint) Close() {}
+
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+ return tcpip.ErrNotSupported
+}
+
+func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
+ v := vv.First()
+ h := header.ARP(v)
+ if !h.IsValid() {
+ return
+ }
+
+ switch h.Op() {
+ case header.ARPRequest:
+ localAddr := tcpip.Address(h.ProtocolAddressTarget())
+ if e.linkAddrCache.CheckLocalAddress(e.nicid, header.IPv4ProtocolNumber, localAddr) == 0 {
+ return // we have no useful answer, ignore the request
+ }
+ hdr := buffer.NewPrependable(int(e.linkEP.MaxHeaderLength()) + header.ARPSize)
+ pkt := header.ARP(hdr.Prepend(header.ARPSize))
+ pkt.SetIPv4OverEthernet()
+ pkt.SetOp(header.ARPReply)
+ copy(pkt.HardwareAddressSender(), r.LocalLinkAddress[:])
+ copy(pkt.ProtocolAddressSender(), h.ProtocolAddressTarget())
+ copy(pkt.ProtocolAddressTarget(), h.ProtocolAddressSender())
+ e.linkEP.WritePacket(r, &hdr, nil, ProtocolNumber)
+ fallthrough // also fill the cache from requests
+ case header.ARPReply:
+ addr := tcpip.Address(h.ProtocolAddressSender())
+ linkAddr := tcpip.LinkAddress(h.HardwareAddressSender())
+ e.linkAddrCache.AddLinkAddress(e.nicid, addr, linkAddr)
+ }
+}
+
+// protocol implements stack.NetworkProtocol and stack.LinkAddressResolver.
+type protocol struct {
+}
+
+func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber }
+func (p *protocol) MinimumPacketSize() int { return header.ARPSize }
+
+func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+ h := header.ARP(v)
+ return tcpip.Address(h.ProtocolAddressSender()), ProtocolAddress
+}
+
+func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, sender stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+ if addr != ProtocolAddress {
+ return nil, tcpip.ErrBadLocalAddress
+ }
+ return &endpoint{
+ nicid: nicid,
+ addr: addr,
+ linkEP: sender,
+ linkAddrCache: linkAddrCache,
+ }, nil
+}
+
+// LinkAddressProtocol implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return header.IPv4ProtocolNumber
+}
+
+// LinkAddressRequest implements stack.LinkAddressResolver.
+func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, linkEP stack.LinkEndpoint) *tcpip.Error {
+ r := &stack.Route{
+ RemoteLinkAddress: broadcastMAC,
+ }
+
+ hdr := buffer.NewPrependable(int(linkEP.MaxHeaderLength()) + header.ARPSize)
+ h := header.ARP(hdr.Prepend(header.ARPSize))
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), linkEP.LinkAddress())
+ copy(h.ProtocolAddressSender(), localAddr)
+ copy(h.ProtocolAddressTarget(), addr)
+
+ return linkEP.WritePacket(r, &hdr, nil, ProtocolNumber)
+}
+
+// ResolveStaticAddress implements stack.LinkAddressResolver.
+func (*protocol) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+ if addr == "\xff\xff\xff\xff" {
+ return broadcastMAC, true
+ }
+ return "", false
+}
+
+// SetOption implements NetworkProtocol.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements NetworkProtocol.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
+}
+
+var broadcastMAC = tcpip.LinkAddress([]byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff})
+
+func init() {
+ stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
+ return &protocol{}
+ })
+}
diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go
new file mode 100644
index 000000000..91ffdce4b
--- /dev/null
+++ b/pkg/tcpip/network/arp/arp_test.go
@@ -0,0 +1,138 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package arp_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/arp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ stackLinkAddr = tcpip.LinkAddress("\x0a\x0a\x0b\x0b\x0c\x0c")
+ stackAddr1 = tcpip.Address("\x0a\x00\x00\x01")
+ stackAddr2 = tcpip.Address("\x0a\x00\x00\x02")
+ stackAddrBad = tcpip.Address("\x0a\x00\x00\x03")
+)
+
+type testContext struct {
+ t *testing.T
+ linkEP *channel.Endpoint
+ s *stack.Stack
+}
+
+func newTestContext(t *testing.T) *testContext {
+ s := stack.New([]string{ipv4.ProtocolName, arp.ProtocolName}, []string{ipv4.PingProtocolName})
+
+ const defaultMTU = 65536
+ id, linkEP := channel.New(256, defaultMTU, stackLinkAddr)
+ if testing.Verbose() {
+ id = sniffer.New(id)
+ }
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr1); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr2); err != nil {
+ t.Fatalf("AddAddress for ipv4 failed: %v", err)
+ }
+ if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ t.Fatalf("AddAddress for arp failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{
+ Destination: "\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ }})
+
+ return &testContext{
+ t: t,
+ s: s,
+ linkEP: linkEP,
+ }
+}
+
+func (c *testContext) cleanup() {
+ close(c.linkEP.C)
+}
+
+func TestDirectRequest(t *testing.T) {
+ c := newTestContext(t)
+ defer c.cleanup()
+
+ const senderMAC = "\x01\x02\x03\x04\x05\x06"
+ const senderIPv4 = "\x0a\x00\x00\x02"
+
+ v := make(buffer.View, header.ARPSize)
+ h := header.ARP(v)
+ h.SetIPv4OverEthernet()
+ h.SetOp(header.ARPRequest)
+ copy(h.HardwareAddressSender(), senderMAC)
+ copy(h.ProtocolAddressSender(), senderIPv4)
+
+ // stackAddr1
+ copy(h.ProtocolAddressTarget(), stackAddr1)
+ vv := v.ToVectorisedView([1]buffer.View{})
+ c.linkEP.Inject(arp.ProtocolNumber, &vv)
+ pkt := <-c.linkEP.C
+ if pkt.Proto != arp.ProtocolNumber {
+ t.Fatalf("stackAddr1: expected ARP response, got network protocol number %v", pkt.Proto)
+ }
+ rep := header.ARP(pkt.Header)
+ if !rep.IsValid() {
+ t.Fatalf("stackAddr1: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ }
+ if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr1 {
+ t.Errorf("stackAddr1: expected sender to be set")
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+ t.Errorf("stackAddr1: expected sender to be stackLinkAddr, got %q", got)
+ }
+
+ // stackAddr2
+ copy(h.ProtocolAddressTarget(), stackAddr2)
+ vv = v.ToVectorisedView([1]buffer.View{})
+ c.linkEP.Inject(arp.ProtocolNumber, &vv)
+ pkt = <-c.linkEP.C
+ if pkt.Proto != arp.ProtocolNumber {
+ t.Fatalf("stackAddr2: expected ARP response, got network protocol number %v", pkt.Proto)
+ }
+ rep = header.ARP(pkt.Header)
+ if !rep.IsValid() {
+ t.Fatalf("stackAddr2: invalid ARP response len(pkt.Header)=%d", len(pkt.Header))
+ }
+ if tcpip.Address(rep.ProtocolAddressSender()) != stackAddr2 {
+ t.Errorf("stackAddr2: expected sender to be set")
+ }
+ if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr {
+ t.Errorf("stackAddr2: expected sender to be stackLinkAddr, got %q", got)
+ }
+
+ // stackAddrBad
+ copy(h.ProtocolAddressTarget(), stackAddrBad)
+ vv = v.ToVectorisedView([1]buffer.View{})
+ c.linkEP.Inject(arp.ProtocolNumber, &vv)
+ select {
+ case pkt := <-c.linkEP.C:
+ t.Errorf("stackAddrBad: unexpected packet sent, Proto=%v", pkt.Proto)
+ case <-time.After(100 * time.Millisecond):
+ // Sleep tests are gross, but this will only
+ // potentially fail flakily if there's a bug.
+ // If there is no bug this will reliably succeed.
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/BUILD b/pkg/tcpip/network/fragmentation/BUILD
new file mode 100644
index 000000000..78fe878ec
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/BUILD
@@ -0,0 +1,61 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "fragmentation_state",
+ srcs = ["reassembler_list.go"],
+ out = "fragmentation_state.go",
+ package = "fragmentation",
+)
+
+go_template_instance(
+ name = "reassembler_list",
+ out = "reassembler_list.go",
+ package = "fragmentation",
+ prefix = "reassembler",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Linker": "*reassembler",
+ },
+)
+
+go_library(
+ name = "fragmentation",
+ srcs = [
+ "frag_heap.go",
+ "fragmentation.go",
+ "fragmentation_state.go",
+ "reassembler.go",
+ "reassembler_list.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/fragmentation",
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/log",
+ "//pkg/state",
+ "//pkg/tcpip/buffer",
+ ],
+)
+
+go_test(
+ name = "fragmentation_test",
+ size = "small",
+ srcs = [
+ "frag_heap_test.go",
+ "fragmentation_test.go",
+ "reassembler_test.go",
+ ],
+ embed = [":fragmentation"],
+ deps = ["//pkg/tcpip/buffer"],
+)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "reassembler_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/network/fragmentation/frag_heap.go b/pkg/tcpip/network/fragmentation/frag_heap.go
new file mode 100644
index 000000000..2e8512909
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/frag_heap.go
@@ -0,0 +1,67 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+type fragment struct {
+ offset uint16
+ vv *buffer.VectorisedView
+}
+
+type fragHeap []fragment
+
+func (h *fragHeap) Len() int {
+ return len(*h)
+}
+
+func (h *fragHeap) Less(i, j int) bool {
+ return (*h)[i].offset < (*h)[j].offset
+}
+
+func (h *fragHeap) Swap(i, j int) {
+ (*h)[i], (*h)[j] = (*h)[j], (*h)[i]
+}
+
+func (h *fragHeap) Push(x interface{}) {
+ *h = append(*h, x.(fragment))
+}
+
+func (h *fragHeap) Pop() interface{} {
+ old := *h
+ n := len(old)
+ x := old[n-1]
+ *h = old[:n-1]
+ return x
+}
+
+// reassemble empties the heap and returns a VectorisedView
+// containing a reassembled version of the fragments inside the heap.
+func (h *fragHeap) reassemble() (buffer.VectorisedView, error) {
+ curr := heap.Pop(h).(fragment)
+ views := curr.vv.Views()
+ size := curr.vv.Size()
+
+ if curr.offset != 0 {
+ return buffer.NewVectorisedView(0, nil), fmt.Errorf("offset of the first packet is != 0 (%d)", curr.offset)
+ }
+
+ for h.Len() > 0 {
+ curr := heap.Pop(h).(fragment)
+ if int(curr.offset) < size {
+ curr.vv.TrimFront(size - int(curr.offset))
+ } else if int(curr.offset) > size {
+ return buffer.NewVectorisedView(0, nil), fmt.Errorf("packet has a hole, expected offset %d, got %d", size, curr.offset)
+ }
+ size += curr.vv.Size()
+ views = append(views, curr.vv.Views()...)
+ }
+ return buffer.NewVectorisedView(size, views), nil
+}
diff --git a/pkg/tcpip/network/fragmentation/frag_heap_test.go b/pkg/tcpip/network/fragmentation/frag_heap_test.go
new file mode 100644
index 000000000..218a24d7b
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/frag_heap_test.go
@@ -0,0 +1,112 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "reflect"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+var reassambleTestCases = []struct {
+ comment string
+ in []fragment
+ want *buffer.VectorisedView
+}{
+ {
+ comment: "Non-overlapping in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(1, "0")},
+ {offset: 1, vv: vv(1, "1")},
+ },
+ want: vv(2, "0", "1"),
+ },
+ {
+ comment: "Non-overlapping out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(1, "0")},
+ },
+ want: vv(2, "0", "1"),
+ },
+ {
+ comment: "Duplicated packets",
+ in: []fragment{
+ {offset: 0, vv: vv(1, "0")},
+ {offset: 0, vv: vv(1, "0")},
+ },
+ want: vv(1, "0"),
+ },
+ {
+ comment: "Overlapping in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(2, "01")},
+ {offset: 1, vv: vv(2, "12")},
+ },
+ want: vv(3, "01", "2"),
+ },
+ {
+ comment: "Overlapping out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(2, "12")},
+ {offset: 0, vv: vv(2, "01")},
+ },
+ want: vv(3, "01", "2"),
+ },
+ {
+ comment: "Overlapping subset in-order",
+ in: []fragment{
+ {offset: 0, vv: vv(3, "012")},
+ {offset: 1, vv: vv(1, "1")},
+ },
+ want: vv(3, "012"),
+ },
+ {
+ comment: "Overlapping subset out-of-order",
+ in: []fragment{
+ {offset: 1, vv: vv(1, "1")},
+ {offset: 0, vv: vv(3, "012")},
+ },
+ want: vv(3, "012"),
+ },
+}
+
+func TestReassamble(t *testing.T) {
+ for _, c := range reassambleTestCases {
+ h := (fragHeap)(make([]fragment, 0, 8))
+ heap.Init(&h)
+ for _, f := range c.in {
+ heap.Push(&h, f)
+ }
+ got, _ := h.reassemble()
+
+ if !reflect.DeepEqual(got, *c.want) {
+ t.Errorf("Test \"%s\" reassembling failed. Got %v. Want %v", c.comment, got, *c.want)
+ }
+ }
+}
+
+func TestReassambleFailsForNonZeroOffset(t *testing.T) {
+ h := (fragHeap)(make([]fragment, 0, 8))
+ heap.Init(&h)
+ heap.Push(&h, fragment{offset: 1, vv: vv(1, "0")})
+ _, err := h.reassemble()
+ if err == nil {
+ t.Errorf("reassemble() did not fail when the first packet had offset != 0")
+ }
+}
+
+func TestReassambleFailsForHoles(t *testing.T) {
+ h := (fragHeap)(make([]fragment, 0, 8))
+ heap.Init(&h)
+ heap.Push(&h, fragment{offset: 0, vv: vv(1, "0")})
+ heap.Push(&h, fragment{offset: 2, vv: vv(1, "1")})
+ _, err := h.reassemble()
+ if err == nil {
+ t.Errorf("reassemble() did not fail when there was a hole in the packet")
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go
new file mode 100644
index 000000000..a309a24c5
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation.go
@@ -0,0 +1,124 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package fragmentation contains the implementation of IP fragmentation.
+// It is based on RFC 791 and RFC 815.
+package fragmentation
+
+import (
+ "log"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// DefaultReassembleTimeout is based on the linux stack: net.ipv4.ipfrag_time.
+const DefaultReassembleTimeout = 30 * time.Second
+
+// HighFragThreshold is the threshold at which we start trimming old
+// fragmented packets. Linux uses a default value of 4 MB. See
+// net.ipv4.ipfrag_high_thresh for more information.
+const HighFragThreshold = 4 << 20 // 4MB
+
+// LowFragThreshold is the threshold to which we drop older fragmented
+// packets once HighFragThreshold has been reached. It's important to keep
+// enough room for newer packets to be re-assembled, so this needs to be
+// sufficiently lower than HighFragThreshold. Linux uses a default value of
+// 3 MB. See net.ipv4.ipfrag_low_thresh for more information.
+const LowFragThreshold = 3 << 20 // 3MB
+
+// Fragmentation is the main structure that other modules
+// of the stack should use to implement IP Fragmentation.
+type Fragmentation struct {
+ mu sync.Mutex
+ highLimit int
+ lowLimit int
+ reassemblers map[uint32]*reassembler
+ rList reassemblerList
+ size int
+ timeout time.Duration
+}
+
+// NewFragmentation creates a new Fragmentation.
+//
+// highMemoryLimit specifies the limit on the memory consumed
+// by the fragments stored by Fragmentation (overhead of internal data-structures
+// is not accounted). Fragments are dropped when the limit is reached.
+//
+// lowMemoryLimit specifies the limit that memory usage is reduced to, by
+// dropping fragments, after highMemoryLimit has been reached.
+//
+// reassemblingTimeout specifies the maximum time allowed to reassemble a packet.
+// Fragments are lazily evicted only when a new packet with an
+// already existing fragmentation-id arrives after the timeout.
+func NewFragmentation(highMemoryLimit, lowMemoryLimit int, reassemblingTimeout time.Duration) *Fragmentation {
+ if lowMemoryLimit >= highMemoryLimit {
+ lowMemoryLimit = highMemoryLimit
+ }
+
+ if lowMemoryLimit < 0 {
+ lowMemoryLimit = 0
+ }
+
+ return &Fragmentation{
+ reassemblers: make(map[uint32]*reassembler),
+ highLimit: highMemoryLimit,
+ lowLimit: lowMemoryLimit,
+ timeout: reassemblingTimeout,
+ }
+}
+
+// Process processes an incoming fragment belonging to an ID
+// and returns a complete packet when all the packets belonging to that ID have been received.
+func (f *Fragmentation) Process(id uint32, first, last uint16, more bool, vv *buffer.VectorisedView) (buffer.VectorisedView, bool) {
+ f.mu.Lock()
+ r, ok := f.reassemblers[id]
+ if ok && r.tooOld(f.timeout) {
+ // This is very likely to be an id-collision or someone performing a slow-rate attack.
+ f.release(r)
+ ok = false
+ }
+ if !ok {
+ r = newReassembler(id)
+ f.reassemblers[id] = r
+ f.rList.PushFront(r)
+ }
+ f.mu.Unlock()
+
+ res, done, consumed := r.process(first, last, more, vv)
+
+ f.mu.Lock()
+ f.size += consumed
+ if done {
+ f.release(r)
+ }
+ // Evict reassemblers if we are consuming more memory than highLimit until
+ // we reach lowLimit.
+ if f.size > f.highLimit {
+ tail := f.rList.Back()
+ for f.size > f.lowLimit && tail != nil {
+ f.release(tail)
+ tail = tail.Prev()
+ }
+ }
+ f.mu.Unlock()
+ return res, done
+}
+
+func (f *Fragmentation) release(r *reassembler) {
+ // Before releasing a fragment we need to check if r is already marked as done.
+ // Otherwise, we would delete it twice.
+ if r.checkDoneOrMark() {
+ return
+ }
+
+ delete(f.reassemblers, r.id)
+ f.rList.Remove(r)
+ f.size -= r.size
+ if f.size < 0 {
+ log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size)
+ f.size = 0
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go
new file mode 100644
index 000000000..2f0200d26
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go
@@ -0,0 +1,166 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fragmentation
+
+import (
+ "reflect"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// vv is a helper to build VectorisedView from different strings.
+func vv(size int, pieces ...string) *buffer.VectorisedView {
+ views := make([]buffer.View, len(pieces))
+ for i, p := range pieces {
+ views[i] = []byte(p)
+ }
+
+ vv := buffer.NewVectorisedView(size, views)
+ return &vv
+}
+
+func emptyVv() *buffer.VectorisedView {
+ vv := buffer.NewVectorisedView(0, nil)
+ return &vv
+}
+
+type processInput struct {
+ id uint32
+ first uint16
+ last uint16
+ more bool
+ vv *buffer.VectorisedView
+}
+
+type processOutput struct {
+ vv *buffer.VectorisedView
+ done bool
+}
+
+var processTestCases = []struct {
+ comment string
+ in []processInput
+ out []processOutput
+}{
+ {
+ comment: "One ID",
+ in: []processInput{
+ {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: emptyVv(), done: false},
+ {vv: vv(4, "01", "23"), done: true},
+ },
+ },
+ {
+ comment: "Two IDs",
+ in: []processInput{
+ {id: 0, first: 0, last: 1, more: true, vv: vv(2, "01")},
+ {id: 1, first: 0, last: 1, more: true, vv: vv(2, "ab")},
+ {id: 1, first: 2, last: 3, more: false, vv: vv(2, "cd")},
+ {id: 0, first: 2, last: 3, more: false, vv: vv(2, "23")},
+ },
+ out: []processOutput{
+ {vv: emptyVv(), done: false},
+ {vv: emptyVv(), done: false},
+ {vv: vv(4, "ab", "cd"), done: true},
+ {vv: vv(4, "01", "23"), done: true},
+ },
+ },
+}
+
+func TestFragmentationProcess(t *testing.T) {
+ for _, c := range processTestCases {
+ f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ for i, in := range c.in {
+ vv, done := f.Process(in.id, in.first, in.last, in.more, in.vv)
+ if !reflect.DeepEqual(vv, *(c.out[i].vv)) {
+ t.Errorf("Test \"%s\" Process() returned a wrong vv. Got %v. Want %v", c.comment, vv, *(c.out[i].vv))
+ }
+ if done != c.out[i].done {
+ t.Errorf("Test \"%s\" Process() returned a wrong done. Got %t. Want %t", c.comment, done, c.out[i].done)
+ }
+ if c.out[i].done {
+ if _, ok := f.reassemblers[in.id]; ok {
+ t.Errorf("Test \"%s\" Process() didn't remove buffer from reassemblers.", c.comment)
+ }
+ for n := f.rList.Front(); n != nil; n = n.Next() {
+ if n.id == in.id {
+ t.Errorf("Test \"%s\" Process() didn't remove buffer from rList.", c.comment)
+ }
+ }
+ }
+ }
+ }
+}
+
+func TestReassemblingTimeout(t *testing.T) {
+ timeout := time.Millisecond
+ f := NewFragmentation(1024, 512, timeout)
+ // Send first fragment with id = 0, first = 0, last = 0, and more = true.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Sleep more than the timeout.
+ time.Sleep(2 * timeout)
+ // Send another fragment that completes a packet.
+ // However, no packet should be reassembled because the fragment arrived after the timeout.
+ _, done := f.Process(0, 1, 1, false, vv(1, "1"))
+ if done {
+ t.Errorf("Fragmentation does not respect the reassembling timeout.")
+ }
+}
+
+func TestMemoryLimits(t *testing.T) {
+ f := NewFragmentation(3, 1, DefaultReassembleTimeout)
+ // Send first fragment with id = 0.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Send first fragment with id = 1.
+ f.Process(1, 0, 0, true, vv(1, "1"))
+ // Send first fragment with id = 2.
+ f.Process(2, 0, 0, true, vv(1, "2"))
+
+ // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be
+ // evicted.
+ f.Process(3, 0, 0, true, vv(1, "3"))
+
+ if _, ok := f.reassemblers[0]; ok {
+ t.Errorf("Memory limits are not respected: id=0 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[1]; ok {
+ t.Errorf("Memory limits are not respected: id=1 has not been evicted.")
+ }
+ if _, ok := f.reassemblers[3]; !ok {
+ t.Errorf("Implementation of memory limits is wrong: id=3 is not present.")
+ }
+}
+
+func TestMemoryLimitsIgnoresDuplicates(t *testing.T) {
+ f := NewFragmentation(1, 0, DefaultReassembleTimeout)
+ // Send first fragment with id = 0.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+ // Send the same packet again.
+ f.Process(0, 0, 0, true, vv(1, "0"))
+
+ got := f.size
+ want := 1
+ if got != want {
+ t.Errorf("Wrong size, duplicates are not handled correctly: got=%d, want=%d.", got, want)
+ }
+}
+
+func TestFragmentationViewsDoNotEscape(t *testing.T) {
+ f := NewFragmentation(1024, 512, DefaultReassembleTimeout)
+ in := vv(2, "0", "1")
+ f.Process(0, 0, 1, true, in)
+ // Modify input view.
+ in.RemoveFirst()
+ got, _ := f.Process(0, 2, 2, false, vv(1, "2"))
+ want := vv(3, "0", "1", "2")
+ if !reflect.DeepEqual(got, *want) {
+ t.Errorf("Process() returned a wrong vv. Got %v. Want %v", got, *want)
+ }
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go
new file mode 100644
index 000000000..0267a575d
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler.go
@@ -0,0 +1,109 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fragmentation
+
+import (
+ "container/heap"
+ "fmt"
+ "math"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// hole describes an inclusive byte range [first, last] of a packet that has
+// not been received yet. Filled holes are flagged via deleted rather than
+// removed from the slice.
+type hole struct {
+	first   uint16
+	last    uint16
+	deleted bool
+}
+
+// reassembler tracks one in-progress packet reassembly, keyed by fragment id.
+// It records the remaining holes, the received fragments (kept in a heap
+// ordered by offset), the payload bytes consumed so far (size), and its
+// creation time for timeout checks. mu guards the per-reassembler state;
+// done marks a reassembler that has completed or been released.
+type reassembler struct {
+	reassemblerEntry
+	id           uint32
+	size         int
+	mu           sync.Mutex
+	holes        []hole
+	deleted      int
+	heap         fragHeap
+	done         bool
+	creationTime time.Time
+}
+
+// newReassembler returns a reassembler for the given fragment id, seeded
+// with a single hole spanning the entire possible range [0, MaxUint16].
+func newReassembler(id uint32) *reassembler {
+	r := &reassembler{
+		id:           id,
+		holes:        make([]hole, 0, 16),
+		deleted:      0,
+		heap:         make(fragHeap, 0, 8),
+		creationTime: time.Now(),
+	}
+	r.holes = append(r.holes, hole{
+		first:   0,
+		last:    math.MaxUint16,
+		deleted: false})
+	return r
+}
+
+// updateHoles updates the list of holes for an incoming fragment and
+// returns true iff the fragment filled at least part of an existing hole.
+// The approach resembles RFC 815's hole-descriptor list.
+func (r *reassembler) updateHoles(first, last uint16, more bool) bool {
+	used := false
+	for i := range r.holes {
+		// Skip holes that are already filled or don't overlap the fragment.
+		if r.holes[i].deleted || first > r.holes[i].last || last < r.holes[i].first {
+			continue
+		}
+		used = true
+		r.deleted++
+		r.holes[i].deleted = true
+		// The fragment may cover the hole only partially; record the
+		// leftover pieces as new holes.
+		if first > r.holes[i].first {
+			r.holes = append(r.holes, hole{r.holes[i].first, first - 1, false})
+		}
+		// No trailing hole is created past the final fragment (more == false).
+		if last < r.holes[i].last && more {
+			r.holes = append(r.holes, hole{last + 1, r.holes[i].last, false})
+		}
+	}
+	return used
+}
+
+// process records the fragment covering [first, last] and, once no holes
+// remain, reassembles and returns the complete packet. The third return
+// value is the number of payload bytes stored by this call, which the
+// caller uses for memory accounting.
+func (r *reassembler) process(first, last uint16, more bool, vv *buffer.VectorisedView) (buffer.VectorisedView, bool, int) {
+	r.mu.Lock()
+	defer r.mu.Unlock()
+	consumed := 0
+	if r.done {
+		// A concurrent goroutine might have already reassembled
+		// the packet and emptied the heap while this goroutine
+		// was waiting on the mutex. We don't have to do anything in this case.
+		return buffer.NewVectorisedView(0, nil), false, consumed
+	}
+	if r.updateHoles(first, last, more) {
+		// We store the incoming packet only if it filled some holes.
+		// Clone so the caller may reuse or mutate vv after we return.
+		uu := vv.Clone(nil)
+		heap.Push(&r.heap, fragment{offset: first, vv: &uu})
+		consumed = vv.Size()
+		r.size += consumed
+	}
+	// Check if all the holes have been deleted and we are ready to reassemble.
+	if r.deleted < len(r.holes) {
+		return buffer.NewVectorisedView(0, nil), false, consumed
+	}
+	res, err := r.heap.reassemble()
+	if err != nil {
+		panic(fmt.Sprintf("reassemble failed with: %v. There is probably a bug in the code handling the holes.", err))
+	}
+	return res, true, consumed
+}
+
+// tooOld reports whether this reassembler was created more than timeout ago.
+// NOTE(review): time.Since(r.creationTime) is the idiomatic spelling.
+func (r *reassembler) tooOld(timeout time.Duration) bool {
+	return time.Now().Sub(r.creationTime) > timeout
+}
+
+// checkDoneOrMark marks the reassembler as done under its lock and reports
+// whether it had already been marked, guaranteeing exactly one caller
+// performs the subsequent release bookkeeping.
+func (r *reassembler) checkDoneOrMark() bool {
+	r.mu.Lock()
+	prev := r.done
+	r.done = true
+	r.mu.Unlock()
+	return prev
+}
diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go
new file mode 100644
index 000000000..b64604383
--- /dev/null
+++ b/pkg/tcpip/network/fragmentation/reassembler_test.go
@@ -0,0 +1,95 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package fragmentation
+
+import (
+ "math"
+ "reflect"
+ "testing"
+)
+
+type updateHolesInput struct {
+ first uint16
+ last uint16
+ more bool
+}
+
+var holesTestCases = []struct {
+ comment string
+ in []updateHolesInput
+ want []hole
+}{
+ {
+ comment: "No fragments. Expected holes: {[0 -> inf]}.",
+ in: []updateHolesInput{},
+ want: []hole{{first: 0, last: math.MaxUint16, deleted: false}},
+ },
+ {
+ comment: "One fragment at beginning. Expected holes: {[2, inf]}.",
+ in: []updateHolesInput{{first: 0, last: 1, more: true}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 2, last: math.MaxUint16, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment in the middle. Expected holes: {[0, 0], [3, inf]}.",
+ in: []updateHolesInput{{first: 1, last: 2, more: true}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: 0, deleted: false},
+ {first: 3, last: math.MaxUint16, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment at the end. Expected holes: {[0, 0]}.",
+ in: []updateHolesInput{{first: 1, last: 2, more: false}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 0, last: 0, deleted: false},
+ },
+ },
+ {
+ comment: "One fragment completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{{first: 0, last: 1, more: false}},
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ },
+ },
+ {
+ comment: "Two non-overlapping fragments completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{
+ {first: 0, last: 1, more: true},
+ {first: 2, last: 3, more: false},
+ },
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 2, last: math.MaxUint16, deleted: true},
+ },
+ },
+ {
+ comment: "Two overlapping fragments completing a packet. Expected holes: {}.",
+ in: []updateHolesInput{
+ {first: 0, last: 2, more: true},
+ {first: 2, last: 3, more: false},
+ },
+ want: []hole{
+ {first: 0, last: math.MaxUint16, deleted: true},
+ {first: 3, last: math.MaxUint16, deleted: true},
+ },
+ },
+}
+
+// TestUpdateHoles runs the table-driven hole-tracking cases and compares the
+// resulting hole list (including deleted entries) against the expectation.
+func TestUpdateHoles(t *testing.T) {
+	for _, c := range holesTestCases {
+		r := newReassembler(0)
+		for _, i := range c.in {
+			r.updateHoles(i.first, i.last, i.more)
+		}
+		if !reflect.DeepEqual(r.holes, c.want) {
+			// Fixed typo in the failure message: "unexepetced" -> "unexpected".
+			t.Errorf("Test \"%s\" produced unexpected holes. Got %v. Want %v", c.comment, r.holes, c.want)
+		}
+	}
+}
diff --git a/pkg/tcpip/network/hash/BUILD b/pkg/tcpip/network/hash/BUILD
new file mode 100644
index 000000000..96805c690
--- /dev/null
+++ b/pkg/tcpip/network/hash/BUILD
@@ -0,0 +1,11 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "hash",
+ srcs = ["hash.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/hash",
+ visibility = ["//visibility:public"],
+ deps = ["//pkg/tcpip/header"],
+)
diff --git a/pkg/tcpip/network/hash/hash.go b/pkg/tcpip/network/hash/hash.go
new file mode 100644
index 000000000..e5a696158
--- /dev/null
+++ b/pkg/tcpip/network/hash/hash.go
@@ -0,0 +1,83 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package hash contains utility functions for hashing.
+package hash
+
+import (
+ "crypto/rand"
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+// hashIV is a process-wide random value mixed into every hash so results
+// are unpredictable across runs.
+var hashIV = RandN32(1)[0]
+
+// RandN32 generates a slice of n cryptographic random 32-bit numbers.
+// It panics if the system source of randomness fails.
+func RandN32(n int) []uint32 {
+	b := make([]byte, 4*n)
+	if _, err := rand.Read(b); err != nil {
+		panic("unable to get random numbers: " + err.Error())
+	}
+	r := make([]uint32, n)
+	for i := range r {
+		// Decode each 4-byte chunk as a little-endian uint32.
+		r[i] = binary.LittleEndian.Uint32(b[4*i : (4*i + 4)])
+	}
+	return r
+}
+
+// Hash3Words calculates the Jenkins hash of 3 32-bit words. This is adapted
+// from linux.
+func Hash3Words(a, b, c, initval uint32) uint32 {
+	// Magic constant: 0xdeadbeef plus the number of words shifted,
+	// matching the linux jhash_3words setup.
+	const iv = 0xdeadbeef + (3 << 2)
+	initval += iv
+
+	a += initval
+	b += initval
+	c += initval
+
+	// Final avalanche mixing rounds (cf. linux __jhash_final).
+	c ^= b
+	c -= rol32(b, 14)
+	a ^= c
+	a -= rol32(c, 11)
+	b ^= a
+	b -= rol32(a, 25)
+	c ^= b
+	c -= rol32(b, 16)
+	a ^= c
+	a -= rol32(c, 4)
+	b ^= a
+	b -= rol32(a, 14)
+	c ^= b
+	c -= rol32(b, 24)
+
+	return c
+}
+
+// IPv4FragmentHash computes the hash of the IPv4 fragment as suggested in RFC 791.
+// The hash covers (ID, protocol) and the little-endian-packed source and
+// destination addresses, mixed with the process-wide random hashIV.
+func IPv4FragmentHash(h header.IPv4) uint32 {
+	x := uint32(h.ID())<<16 | uint32(h.Protocol())
+	t := h.SourceAddress()
+	y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+	t = h.DestinationAddress()
+	z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+	return Hash3Words(x, y, z, hashIV)
+}
+
+// IPv6FragmentHash computes the hash of the ipv6 fragment.
+// Unlike IPv4, the protocol is not used to compute the hash.
+// RFC 2460 (sec 4.5, the Fragment header) is not very sharp on this aspect.
+// (The original comment cited RFC 2640, which is FTP i18n — presumed typo.)
+// As a reference, also Linux ignores the protocol to compute
+// the hash (inet6_hash_frag).
+// NOTE(review): only the first 4 bytes of each 16-byte IPv6 address feed
+// the hash — confirm this collision behavior is intended.
+func IPv6FragmentHash(h header.IPv6, f header.IPv6Fragment) uint32 {
+	t := h.SourceAddress()
+	y := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+	t = h.DestinationAddress()
+	z := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
+	return Hash3Words(f.ID(), y, z, hashIV)
+}
+
+// rol32 rotates v left by shift bits. The (-shift) & 31 form keeps the
+// complementary right-shift amount in range even when shift is 0.
+func rol32(v, shift uint32) uint32 {
+	return (v << shift) | (v >> ((-shift) & 31))
+}
diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go
new file mode 100644
index 000000000..797501858
--- /dev/null
+++ b/pkg/tcpip/network/ip_test.go
@@ -0,0 +1,560 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ip_test
+
+import (
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// testObject implements two interfaces: LinkEndpoint and TransportDispatcher.
+// The former is used to pretend that it's a link endpoint so that we can
+// inspect packets written by the network endpoints. The latter is used to
+// pretend that it's the network stack so that it can inspect incoming packets
+// that have been handled by the network endpoints.
+//
+// Packets are checked by comparing their fields/values against the expected
+// values stored in the test object itself.
+type testObject struct {
+ t *testing.T
+ protocol tcpip.TransportProtocolNumber
+ contents []byte
+ srcAddr tcpip.Address
+ dstAddr tcpip.Address
+ v4 bool
+ typ stack.ControlType
+ extra uint32
+
+ dataCalls int
+ controlCalls int
+}
+
+// checkValues verifies that the transport protocol, data contents, src & dst
+// addresses of a packet match what's expected. If any field doesn't match, the
+// test fails.
+func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, srcAddr, dstAddr tcpip.Address) {
+ v := vv.ToView()
+ if protocol != t.protocol {
+ t.t.Errorf("protocol = %v, want %v", protocol, t.protocol)
+ }
+
+ if srcAddr != t.srcAddr {
+ t.t.Errorf("srcAddr = %v, want %v", srcAddr, t.srcAddr)
+ }
+
+ if dstAddr != t.dstAddr {
+ t.t.Errorf("dstAddr = %v, want %v", dstAddr, t.dstAddr)
+ }
+
+ if len(v) != len(t.contents) {
+ t.t.Fatalf("len(payload) = %v, want %v", len(v), len(t.contents))
+ }
+
+ for i := range t.contents {
+ if t.contents[i] != v[i] {
+ t.t.Fatalf("payload[%v] = %v, want %v", i, v[i], t.contents[i])
+ }
+ }
+}
+
+// DeliverTransportPacket is called by network endpoints after parsing incoming
+// packets. This is used by the test object to verify that the results of the
+// parsing are expected.
+func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) {
+ t.checkValues(protocol, vv, r.RemoteAddress, r.LocalAddress)
+ t.dataCalls++
+}
+
+// DeliverTransportControlPacket is called by network endpoints after parsing
+// incoming control (ICMP) packets. This is used by the test object to verify
+// that the results of the parsing are expected.
+func (t *testObject) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
+ t.checkValues(trans, vv, remote, local)
+ if typ != t.typ {
+ t.t.Errorf("typ = %v, want %v", typ, t.typ)
+ }
+ if extra != t.extra {
+ t.t.Errorf("extra = %v, want %v", extra, t.extra)
+ }
+ t.controlCalls++
+}
+
+// Attach is only implemented to satisfy the LinkEndpoint interface.
+func (*testObject) Attach(stack.NetworkDispatcher) {}
+
+// MTU implements stack.LinkEndpoint.MTU. It just returns a constant that
+// matches the linux loopback MTU.
+func (*testObject) MTU() uint32 {
+ return 65536
+}
+
+// Capabilities implements stack.LinkEndpoint.Capabilities.
+func (*testObject) Capabilities() stack.LinkEndpointCapabilities {
+ return 0
+}
+
+// MaxHeaderLength is only implemented to satisfy the LinkEndpoint interface.
+func (*testObject) MaxHeaderLength() uint16 {
+ return 0
+}
+
+// LinkAddress returns the link address of this endpoint.
+func (*testObject) LinkAddress() tcpip.LinkAddress {
+ return ""
+}
+
+// WritePacket is called by network endpoints after producing a packet and
+// writing it to the link endpoint. This is used by the test object to verify
+// that the produced packet is as expected.
+func (t *testObject) WritePacket(_ *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error {
+ var prot tcpip.TransportProtocolNumber
+ var srcAddr tcpip.Address
+ var dstAddr tcpip.Address
+
+ if t.v4 {
+ h := header.IPv4(hdr.UsedBytes())
+ prot = tcpip.TransportProtocolNumber(h.Protocol())
+ srcAddr = h.SourceAddress()
+ dstAddr = h.DestinationAddress()
+
+ } else {
+ h := header.IPv6(hdr.UsedBytes())
+ prot = tcpip.TransportProtocolNumber(h.NextHeader())
+ srcAddr = h.SourceAddress()
+ dstAddr = h.DestinationAddress()
+ }
+ var views [1]buffer.View
+ vv := payload.ToVectorisedView(views)
+ t.checkValues(prot, &vv, srcAddr, dstAddr)
+ return nil
+}
+
+func TestIPv4Send(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, nil, &o)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Allocate and initialize the payload view.
+ payload := buffer.NewView(100)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i)
+ }
+
+ // Allocate the header buffer.
+ hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+
+ // Issue the write.
+ o.protocol = 123
+ o.srcAddr = "\x0a\x00\x00\x01"
+ o.dstAddr = "\x0a\x00\x00\x02"
+ o.contents = payload
+
+ r := stack.Route{
+ RemoteAddress: o.dstAddr,
+ LocalAddress: o.srcAddr,
+ }
+ if err := ep.WritePacket(&r, &hdr, payload, 123); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+}
+
+func TestIPv4Receive(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv4MinimumSize + 30
+ view := buffer.NewView(totalLen)
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ SrcAddr: "\x0a\x00\x00\x02",
+ DstAddr: "\x0a\x00\x00\x01",
+ })
+
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = "\x0a\x00\x00\x02"
+ o.dstAddr = "\x0a\x00\x00\x01"
+ o.contents = view[header.IPv4MinimumSize:totalLen]
+
+ r := stack.Route{
+ LocalAddress: o.dstAddr,
+ RemoteAddress: o.srcAddr,
+ }
+ var views [1]buffer.View
+ vv := view.ToVectorisedView(views)
+ ep.HandlePacket(&r, &vv)
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv4ReceiveControl(t *testing.T) {
+ const mtu = 0xbeef - header.IPv4MinimumSize
+ cases := []struct {
+ name string
+ expectedCount int
+ fragmentOffset uint16
+ code uint8
+ expectedTyp stack.ControlType
+ expectedExtra uint32
+ trunc int
+ }{
+ {"FragmentationNeeded", 1, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 0},
+ {"Truncated (10 bytes missing)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 10},
+ {"Truncated (missing IPv4 header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.IPv4MinimumSize + 8},
+ {"Truncated (missing 'extra info')", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, 4 + header.IPv4MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, 0, header.ICMPv4FragmentationNeeded, stack.ControlPacketTooBig, mtu, header.ICMPv4DstUnreachableMinimumSize + header.IPv4MinimumSize + 8},
+ {"Port unreachable", 1, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Non-zero fragment offset", 0, 100, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Zero-length packet", 0, 0, header.ICMPv4PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv4MinimumSize + header.ICMPv4DstUnreachableMinimumSize + 8},
+ }
+ r := stack.Route{
+ LocalAddress: "\x0a\x00\x00\x01",
+ RemoteAddress: "\x0a\x00\x00\xbb",
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ var views [1]buffer.View
+ o := testObject{t: t}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ const dataOffset = header.IPv4MinimumSize*2 + header.ICMPv4MinimumSize + 4
+ view := buffer.NewView(dataOffset + 8)
+
+ // Create the outer IPv4 header.
+ ip := header.IPv4(view)
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(len(view) - c.trunc),
+ TTL: 20,
+ Protocol: uint8(header.ICMPv4ProtocolNumber),
+ SrcAddr: "\x0a\x00\x00\xbb",
+ DstAddr: "\x0a\x00\x00\x01",
+ })
+
+ // Create the ICMP header.
+ icmp := header.ICMPv4(view[header.IPv4MinimumSize:])
+ icmp.SetType(header.ICMPv4DstUnreachable)
+ icmp.SetCode(c.code)
+ copy(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
+
+ // Create the inner IPv4 header.
+ ip = header.IPv4(view[header.IPv4MinimumSize+header.ICMPv4MinimumSize+4:])
+ ip.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: 100,
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: c.fragmentOffset,
+ SrcAddr: "\x0a\x00\x00\x01",
+ DstAddr: "\x0a\x00\x00\x02",
+ })
+
+ // Make payload be non-zero.
+ for i := dataOffset; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to IPv4 endpoint, dispatcher will validate that
+ // it's ok.
+ o.protocol = 10
+ o.srcAddr = "\x0a\x00\x00\x02"
+ o.dstAddr = "\x0a\x00\x00\x01"
+ o.contents = view[dataOffset:]
+ o.typ = c.expectedTyp
+ o.extra = c.expectedExtra
+
+ vv := view.ToVectorisedView(views)
+ vv.CapLength(len(view) - c.trunc)
+ ep.HandlePacket(&r, &vv)
+ if want := c.expectedCount; o.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ }
+ })
+ }
+}
+
+func TestIPv4FragmentationReceive(t *testing.T) {
+ o := testObject{t: t, v4: true}
+ proto := ipv4.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x01", nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv4MinimumSize + 24
+
+ frag1 := buffer.NewView(totalLen)
+ ip1 := header.IPv4(frag1)
+ ip1.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: 0,
+ Flags: header.IPv4FlagMoreFragments,
+ SrcAddr: "\x0a\x00\x00\x02",
+ DstAddr: "\x0a\x00\x00\x01",
+ })
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ frag1[i] = uint8(i)
+ }
+
+ frag2 := buffer.NewView(totalLen)
+ ip2 := header.IPv4(frag2)
+ ip2.Encode(&header.IPv4Fields{
+ IHL: header.IPv4MinimumSize,
+ TotalLength: uint16(totalLen),
+ TTL: 20,
+ Protocol: 10,
+ FragmentOffset: 24,
+ SrcAddr: "\x0a\x00\x00\x02",
+ DstAddr: "\x0a\x00\x00\x01",
+ })
+ // Make payload be non-zero.
+ for i := header.IPv4MinimumSize; i < totalLen; i++ {
+ frag2[i] = uint8(i)
+ }
+
+ // Give packet to ipv4 endpoint, dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = "\x0a\x00\x00\x02"
+ o.dstAddr = "\x0a\x00\x00\x01"
+ o.contents = append(frag1[header.IPv4MinimumSize:totalLen], frag2[header.IPv4MinimumSize:totalLen]...)
+
+ r := stack.Route{
+ LocalAddress: o.dstAddr,
+ RemoteAddress: o.srcAddr,
+ }
+
+ // Send first segment.
+ var views1 [1]buffer.View
+ vv1 := frag1.ToVectorisedView(views1)
+ ep.HandlePacket(&r, &vv1)
+ if o.dataCalls != 0 {
+ t.Fatalf("Bad number of data calls: got %x, want 0", o.dataCalls)
+ }
+
+ // Send second segment.
+ var views2 [1]buffer.View
+ vv2 := frag2.ToVectorisedView(views2)
+ ep.HandlePacket(&r, &vv2)
+
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv6Send(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, nil, &o)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Allocate and initialize the payload view.
+ payload := buffer.NewView(100)
+ for i := 0; i < len(payload); i++ {
+ payload[i] = uint8(i)
+ }
+
+ // Allocate the header buffer.
+ hdr := buffer.NewPrependable(int(ep.MaxHeaderLength()))
+
+ // Issue the write.
+ o.protocol = 123
+ o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ o.contents = payload
+
+ r := stack.Route{
+ RemoteAddress: o.dstAddr,
+ LocalAddress: o.srcAddr,
+ }
+ if err := ep.WritePacket(&r, &hdr, payload, 123); err != nil {
+ t.Fatalf("WritePacket failed: %v", err)
+ }
+}
+
+func TestIPv6Receive(t *testing.T) {
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ totalLen := header.IPv6MinimumSize + 30
+ view := buffer.NewView(totalLen)
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(totalLen - header.IPv6MinimumSize),
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02",
+ DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ })
+
+ // Make payload be non-zero.
+ for i := header.IPv6MinimumSize; i < totalLen; i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to ipv6 endpoint, dispatcher will validate that it's ok.
+ o.protocol = 10
+ o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ o.contents = view[header.IPv6MinimumSize:totalLen]
+
+ r := stack.Route{
+ LocalAddress: o.dstAddr,
+ RemoteAddress: o.srcAddr,
+ }
+ var views [1]buffer.View
+ vv := view.ToVectorisedView(views)
+ ep.HandlePacket(&r, &vv)
+
+ if o.dataCalls != 1 {
+ t.Fatalf("Bad number of data calls: got %x, want 1", o.dataCalls)
+ }
+}
+
+func TestIPv6ReceiveControl(t *testing.T) {
+ newUint16 := func(v uint16) *uint16 { return &v }
+
+ const mtu = 0xffff
+ cases := []struct {
+ name string
+ expectedCount int
+ fragmentOffset *uint16
+ typ header.ICMPv6Type
+ code uint8
+ expectedTyp stack.ControlType
+ expectedExtra uint32
+ trunc int
+ }{
+ {"PacketTooBig", 1, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 0},
+ {"Truncated (10 bytes missing)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 10},
+ {"Truncated (missing IPv6 header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.IPv6MinimumSize + 8},
+ {"Truncated PacketTooBig (missing 'extra info')", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, 4 + header.IPv6MinimumSize + 8},
+ {"Truncated (missing ICMP header)", 0, nil, header.ICMPv6PacketTooBig, 0, stack.ControlPacketTooBig, mtu, header.ICMPv6PacketTooBigMinimumSize + header.IPv6MinimumSize + 8},
+ {"Port unreachable", 1, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Truncated DstUnreachable (missing 'extra info')", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 4 + header.IPv6MinimumSize + 8},
+ {"Fragmented, zero offset", 1, newUint16(0), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Non-zero fragment offset", 0, newUint16(100), header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 0},
+ {"Zero-length packet", 0, nil, header.ICMPv6DstUnreachable, header.ICMPv6PortUnreachable, stack.ControlPortUnreachable, 0, 2*header.IPv6MinimumSize + header.ICMPv6DstUnreachableMinimumSize + 8},
+ }
+ r := stack.Route{
+ LocalAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ RemoteAddress: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
+ }
+ for _, c := range cases {
+ t.Run(c.name, func(t *testing.T) {
+ var views [1]buffer.View
+ o := testObject{t: t}
+ proto := ipv6.NewProtocol()
+ ep, err := proto.NewEndpoint(1, "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", nil, &o, nil)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ defer ep.Close()
+
+ dataOffset := header.IPv6MinimumSize*2 + header.ICMPv6MinimumSize + 4
+ if c.fragmentOffset != nil {
+ dataOffset += header.IPv6FragmentHeaderSize
+ }
+ view := buffer.NewView(dataOffset + 8)
+
+ // Create the outer IPv6 header.
+ ip := header.IPv6(view)
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: uint16(len(view) - header.IPv6MinimumSize - c.trunc),
+ NextHeader: uint8(header.ICMPv6ProtocolNumber),
+ HopLimit: 20,
+ SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xaa",
+ DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ })
+
+ // Create the ICMP header.
+ icmp := header.ICMPv6(view[header.IPv6MinimumSize:])
+ icmp.SetType(c.typ)
+ icmp.SetCode(c.code)
+ copy(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize:], []byte{0xde, 0xad, 0xbe, 0xef})
+
+ // Create the inner IPv6 header.
+ ip = header.IPv6(view[header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: 100,
+ NextHeader: 10,
+ HopLimit: 20,
+ SrcAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01",
+ DstAddr: "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02",
+ })
+
+ // Build the fragmentation header if needed.
+ if c.fragmentOffset != nil {
+ ip.SetNextHeader(header.IPv6FragmentHeader)
+ frag := header.IPv6Fragment(view[2*header.IPv6MinimumSize+header.ICMPv6MinimumSize+4:])
+ frag.Encode(&header.IPv6FragmentFields{
+ NextHeader: 10,
+ FragmentOffset: *c.fragmentOffset,
+ M: true,
+ Identification: 0x12345678,
+ })
+ }
+
+ // Make payload be non-zero.
+ for i := dataOffset; i < len(view); i++ {
+ view[i] = uint8(i)
+ }
+
+ // Give packet to IPv6 endpoint, dispatcher will validate that
+ // it's ok.
+ o.protocol = 10
+ o.srcAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+ o.dstAddr = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+ o.contents = view[dataOffset:]
+ o.typ = c.expectedTyp
+ o.extra = c.expectedExtra
+
+ vv := view.ToVectorisedView(views)
+ vv.CapLength(len(view) - c.trunc)
+ ep.HandlePacket(&r, &vv)
+ if want := c.expectedCount; o.controlCalls != want {
+ t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, o.controlCalls, want)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
new file mode 100644
index 000000000..9df113df1
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -0,0 +1,38 @@
package(licenses = ["notice"]) # BSD

load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

# ipv4 implements the IPv4 network protocol for the netstack.
go_library(
    name = "ipv4",
    srcs = [
        "icmp.go",
        "ipv4.go",
    ],
    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4",
    visibility = [
        "//visibility:public",
    ],
    deps = [
        "//pkg/tcpip",
        "//pkg/tcpip/buffer",
        "//pkg/tcpip/header",
        "//pkg/tcpip/network/fragmentation",
        "//pkg/tcpip/network/hash",
        "//pkg/tcpip/stack",
        "//pkg/waiter",
    ],
)

# Black-box test (package ipv4_test) exercising the ping/echo path.
go_test(
    name = "ipv4_test",
    size = "small",
    srcs = ["icmp_test.go"],
    deps = [
        ":ipv4",
        "//pkg/tcpip",
        "//pkg/tcpip/buffer",
        "//pkg/tcpip/link/channel",
        "//pkg/tcpip/link/sniffer",
        "//pkg/tcpip/stack",
    ],
)
diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go
new file mode 100644
index 000000000..ffd761350
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/icmp.go
@@ -0,0 +1,282 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ipv4
+
+import (
+ "context"
+ "encoding/binary"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
// PingProtocolName is a pseudo transport protocol used to handle ping replies.
// Use it when constructing a stack that intends to use ipv4.Ping.
const PingProtocolName = "icmpv4ping"

// pingProtocolNumber is a fake transport protocol number used to deliver
// incoming ICMP echo replies to the pingEndpoint registered by Ping. The
// ICMP identifier is used as a port number for multiplexing, so replies are
// demuxed like ordinary transport packets. The value is deliberately outside
// the 8-bit range of real IP protocol numbers to avoid collisions.
const pingProtocolNumber tcpip.TransportProtocolNumber = 256 + 11
+
+// handleControl handles the case when an ICMP packet contains the headers of
+// the original packet that caused the ICMP one to be sent. This information is
+// used to find out which transport endpoint must be notified about the ICMP
+// packet.
+func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
+ h := header.IPv4(vv.First())
+
+ // We don't use IsValid() here because ICMP only requires that the IP
+ // header plus 8 bytes of the transport header be included. So it's
+ // likely that it is truncated, which would cause IsValid to return
+ // false.
+ //
+ // Drop packet if it doesn't have the basic IPv4 header or if the
+ // original source address doesn't match the endpoint's address.
+ if len(h) < header.IPv4MinimumSize || h.SourceAddress() != e.id.LocalAddress {
+ return
+ }
+
+ hlen := int(h.HeaderLength())
+ if vv.Size() < hlen || h.FragmentOffset() != 0 {
+ // We won't be able to handle this if it doesn't contain the
+ // full IPv4 header, or if it's a fragment not at offset 0
+ // (because it won't have the transport header).
+ return
+ }
+
+ // Skip the ip header, then deliver control message.
+ vv.TrimFront(hlen)
+ p := h.TransportProtocol()
+ e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
+}
+
+func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
+ v := vv.First()
+ if len(v) < header.ICMPv4MinimumSize {
+ return
+ }
+ h := header.ICMPv4(v)
+
+ switch h.Type() {
+ case header.ICMPv4Echo:
+ if len(v) < header.ICMPv4EchoMinimumSize {
+ return
+ }
+ vv.TrimFront(header.ICMPv4MinimumSize)
+ req := echoRequest{r: r.Clone(), v: vv.ToView()}
+ select {
+ case e.echoRequests <- req:
+ default:
+ req.r.Release()
+ }
+
+ case header.ICMPv4EchoReply:
+ e.dispatcher.DeliverTransportPacket(r, pingProtocolNumber, vv)
+
+ case header.ICMPv4DstUnreachable:
+ if len(v) < header.ICMPv4DstUnreachableMinimumSize {
+ return
+ }
+ vv.TrimFront(header.ICMPv4DstUnreachableMinimumSize)
+ switch h.Code() {
+ case header.ICMPv4PortUnreachable:
+ e.handleControl(stack.ControlPortUnreachable, 0, vv)
+
+ case header.ICMPv4FragmentationNeeded:
+ mtu := uint32(binary.BigEndian.Uint16(v[header.ICMPv4DstUnreachableMinimumSize-2:]))
+ e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)
+ }
+ }
+ // TODO: Handle other ICMP types.
+}
+
// echoRequest bundles a cloned route and the payload of an incoming ICMP
// echo request so a reply can be sent asynchronously by echoReplier.
type echoRequest struct {
	r stack.Route // cloned route back to the requester; released after use
	v buffer.View // echo payload with the ICMP header already trimmed
}

// echoReplier drains e.echoRequests, sending an ICMP echo reply for each
// queued request. It runs in its own goroutine (started by newEndpoint) and
// exits when the channel is closed by endpoint.Close.
func (e *endpoint) echoReplier() {
	for req := range e.echoRequests {
		sendICMPv4(&req.r, header.ICMPv4EchoReply, 0, req.v)
		req.r.Release()
	}
}
+
+func sendICMPv4(r *stack.Route, typ header.ICMPv4Type, code byte, data buffer.View) *tcpip.Error {
+ hdr := buffer.NewPrependable(header.ICMPv4MinimumSize + int(r.MaxHeaderLength()))
+
+ icmpv4 := header.ICMPv4(hdr.Prepend(header.ICMPv4MinimumSize))
+ icmpv4.SetType(typ)
+ icmpv4.SetCode(code)
+ icmpv4.SetChecksum(^header.Checksum(icmpv4, header.Checksum(data, 0)))
+
+ return r.WritePacket(&hdr, data, header.ICMPv4ProtocolNumber)
+}
+
// A Pinger can send echo requests to an address.
type Pinger struct {
	Stack     *stack.Stack  // stack used to find routes and register the reply endpoint
	NICID     tcpip.NICID   // NIC to ping through
	Addr      tcpip.Address // destination address
	LocalAddr tcpip.Address // optional; chosen by the stack's routing if empty
	Wait      time.Duration // interval between requests; if zero, defaults to 1 second
	Count     uint16        // number of requests; if zero, defaults to MaxUint16
}
+
// Ping sends echo requests to an ICMPv4 endpoint, one every p.Wait, and
// streams one PingReply per response (or per send error) to ch. It returns
// after all requests have been sent and either all replies were received or
// ctx is done.
//
// NOTE(review): if a reply is lost and ctx is never cancelled, the deferred
// wait on the collection goroutine blocks forever — callers should pass a
// cancellable context. Confirm whether callers always do.
func (p *Pinger) Ping(ctx context.Context, ch chan<- PingReply) *tcpip.Error {
	count := p.Count
	if count == 0 {
		count = 1<<16 - 1 // default documented on Pinger.Count: MaxUint16
	}
	wait := p.Wait
	if wait == 0 {
		wait = 1 * time.Second
	}

	r, err := p.Stack.FindRoute(p.NICID, p.LocalAddr, p.Addr, ProtocolNumber)
	if err != nil {
		return err
	}

	// Register a fake transport endpoint keyed by an ephemeral "port": the
	// ICMP identifier. Echo replies carrying that identifier are demuxed to
	// ep.pktCh like ordinary transport packets.
	netProtos := []tcpip.NetworkProtocolNumber{ProtocolNumber}
	ep := &pingEndpoint{
		stack: p.Stack,
		pktCh: make(chan buffer.View, 1),
	}
	id := stack.TransportEndpointID{
		LocalAddress:  r.LocalAddress,
		RemoteAddress: p.Addr,
	}

	_, err = p.Stack.PickEphemeralPort(func(port uint16) (bool, *tcpip.Error) {
		id.LocalPort = port
		err := p.Stack.RegisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id, ep)
		switch err {
		case nil:
			return true, nil
		case tcpip.ErrPortInUse:
			// Try the next candidate port.
			return false, nil
		default:
			return false, err
		}
	})
	if err != nil {
		return err
	}
	defer p.Stack.UnregisterTransportEndpoint(p.NICID, netProtos, pingProtocolNumber, id)

	// Echo payload: identifier in bytes 0-1, sequence number in bytes 2-3
	// (filled in per request below).
	v := buffer.NewView(4)
	binary.BigEndian.PutUint16(v[0:], id.LocalPort)

	start := time.Now()

	// Collect replies in the background until `count` replies arrive or
	// ctx is done.
	done := make(chan struct{})
	go func(count int) {
	loop:
		for ; count > 0; count-- {
			select {
			case v := <-ep.pktCh:
				seq := binary.BigEndian.Uint16(v[header.ICMPv4MinimumSize+2:])
				ch <- PingReply{
					// Request seq was sent roughly seq*wait after
					// start, so subtract that offset to get the RTT.
					Duration:  time.Since(start) - time.Duration(seq)*wait,
					SeqNumber: seq,
				}
			case <-ctx.Done():
				break loop
			}
		}
		close(done)
	}(int(count))
	defer func() { <-done }()

	t := time.NewTicker(wait)
	defer t.Stop()
	for seq := uint16(0); seq < count; seq++ {
		select {
		case <-t.C:
		case <-ctx.Done():
			return nil
		}
		binary.BigEndian.PutUint16(v[2:], seq)
		sent := time.Now()
		if err := sendICMPv4(&r, header.ICMPv4Echo, 0, v); err != nil {
			// Report send failures on ch too, so the caller sees one
			// PingReply per attempted request.
			ch <- PingReply{
				Error:     err,
				Duration:  time.Since(sent),
				SeqNumber: seq,
			}
		}
	}
	return nil
}
+
// PingReply summarizes an ICMP echo reply.
type PingReply struct {
	Error     *tcpip.Error // reports any errors sending a ping request
	Duration  time.Duration
	SeqNumber uint16
}

// pingProtocol is the fake transport protocol that routes echo replies to
// pingEndpoints. It is registered under PingProtocolName in init().
type pingProtocol struct{}

// NewEndpoint implements stack.TransportProtocol. It always fails because
// ping endpoints are created directly by Ping, never via the stack.
func (*pingProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
	return nil, tcpip.ErrNotSupported // endpoints are created directly
}

// Number implements stack.TransportProtocol.
func (*pingProtocol) Number() tcpip.TransportProtocolNumber { return pingProtocolNumber }

// MinimumPacketSize implements stack.TransportProtocol.
func (*pingProtocol) MinimumPacketSize() int { return header.ICMPv4EchoMinimumSize }

// ParsePorts implements stack.TransportProtocol. The ICMP identifier (bytes
// 4-5 of the echo header) plays the role of the destination port; there is
// no source port.
func (*pingProtocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
	ident := binary.BigEndian.Uint16(v[4:])
	return 0, ident, nil
}

// HandleUnknownDestinationPacket implements stack.TransportProtocol. Replies
// with no matching endpoint are silently accepted (and dropped).
func (*pingProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
	return true
}

// SetOption implements TransportProtocol.SetOption.
func (p *pingProtocol) SetOption(option interface{}) *tcpip.Error {
	return tcpip.ErrUnknownProtocolOption
}

// Option implements TransportProtocol.Option.
func (p *pingProtocol) Option(option interface{}) *tcpip.Error {
	return tcpip.ErrUnknownProtocolOption
}
+
// init registers the fake ping transport protocol with the stack so that
// stacks constructed with PingProtocolName can demux echo replies.
func init() {
	stack.RegisterTransportProtocolFactory(PingProtocolName, func() stack.TransportProtocol {
		return &pingProtocol{}
	})
}

// pingEndpoint receives echo replies for one in-flight Ping call and hands
// them to its collection goroutine via pktCh.
type pingEndpoint struct {
	stack *stack.Stack
	pktCh chan buffer.View // capacity 1; excess replies are dropped
}

// Close shuts down the endpoint by closing its packet channel.
func (e *pingEndpoint) Close() {
	close(e.pktCh)
}

// HandlePacket implements stack.TransportEndpoint. The send is non-blocking:
// if the consumer is not keeping up, the reply is dropped.
func (e *pingEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
	select {
	case e.pktCh <- vv.ToView():
	default:
	}
}

// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
// Control messages are ignored for pings.
func (e *pingEndpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
}
diff --git a/pkg/tcpip/network/ipv4/icmp_test.go b/pkg/tcpip/network/ipv4/icmp_test.go
new file mode 100644
index 000000000..378fba74b
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/icmp_test.go
@@ -0,0 +1,124 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ipv4_test
+
+import (
+ "context"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
// stackAddr is the IPv4 address (10.0.0.1) assigned to the test stack's NIC.
const stackAddr = "\x0a\x00\x00\x01"

// testContext holds the stack and link endpoint shared by the ping tests.
type testContext struct {
	t      *testing.T
	linkEP *channel.Endpoint // in-memory link; packets surface on linkEP.C
	s      *stack.Stack
}
+
// newTestContext builds a stack with the ipv4 and ping protocols, one
// channel-backed NIC holding stackAddr, and a default route through it.
func newTestContext(t *testing.T) *testContext {
	s := stack.New([]string{ipv4.ProtocolName}, []string{ipv4.PingProtocolName})

	const defaultMTU = 65536
	id, linkEP := channel.New(256, defaultMTU, "")
	if testing.Verbose() {
		// Wrap the link in a sniffer so -v runs log every packet.
		id = sniffer.New(id)
	}
	if err := s.CreateNIC(1, id); err != nil {
		t.Fatalf("CreateNIC failed: %v", err)
	}

	if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
		t.Fatalf("AddAddress failed: %v", err)
	}

	// Route everything (0.0.0.0/0) through NIC 1.
	s.SetRouteTable([]tcpip.Route{{
		Destination: "\x00\x00\x00\x00",
		Mask:        "\x00\x00\x00\x00",
		Gateway:     "",
		NIC:         1,
	}})

	return &testContext{
		t:      t,
		s:      s,
		linkEP: linkEP,
	}
}

// cleanup closes the link channel, which also stops any loopback goroutine
// ranging over it.
func (c *testContext) cleanup() {
	close(c.linkEP.C)
}
+
// loopback starts a goroutine that reflects every packet written to the
// link endpoint back into it, so the stack answers its own pings. The
// goroutine exits when cleanup closes c.linkEP.C.
func (c *testContext) loopback() {
	go func() {
		for pkt := range c.linkEP.C {
			// Flatten header+payload into one view before injecting.
			v := make(buffer.View, len(pkt.Header)+len(pkt.Payload))
			copy(v, pkt.Header)
			copy(v[len(pkt.Header):], pkt.Payload)
			vv := v.ToVectorisedView([1]buffer.View{})
			c.linkEP.Inject(pkt.Proto, &vv)
		}
	}()
}
+
// TestEcho sends a single ping over the loopback link and verifies that a
// reply arrives without error.
func TestEcho(t *testing.T) {
	c := newTestContext(t)
	defer c.cleanup()
	c.loopback()

	ch := make(chan ipv4.PingReply, 1)
	p := ipv4.Pinger{
		Stack: c.s,
		NICID: 1,
		Addr:  stackAddr,
		Wait:  10 * time.Millisecond,
		Count: 1, // one ping only
	}
	if err := p.Ping(context.Background(), ch); err != nil {
		t.Fatalf("icmp.Ping failed: %v", err)
	}

	ping := <-ch
	if ping.Error != nil {
		t.Errorf("bad ping response: %v", ping.Error)
	}
}

// TestEchoSequence sends several pings and verifies the replies come back
// error-free and in sequence-number order.
func TestEchoSequence(t *testing.T) {
	c := newTestContext(t)
	defer c.cleanup()
	c.loopback()

	const numPings = 3
	ch := make(chan ipv4.PingReply, numPings)
	p := ipv4.Pinger{
		Stack: c.s,
		NICID: 1,
		Addr:  stackAddr,
		Wait:  10 * time.Millisecond,
		Count: numPings,
	}
	if err := p.Ping(context.Background(), ch); err != nil {
		t.Fatalf("icmp.Ping failed: %v", err)
	}

	for i := uint16(0); i < numPings; i++ {
		ping := <-ch
		if ping.Error != nil {
			t.Errorf("i=%d bad ping response: %v", i, ping.Error)
		}
		if ping.SeqNumber != i {
			t.Errorf("SeqNumber=%d, want %d", ping.SeqNumber, i)
		}
	}
}
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
new file mode 100644
index 000000000..4cc2a2fd4
--- /dev/null
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -0,0 +1,233 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package ipv4 contains the implementation of the ipv4 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv4.ProtocolName (or "ipv4") as one of the
+// network protocols when calling stack.New(). Then endpoints can be created
+// by passing ipv4.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv4
+
+import (
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/fragmentation"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/hash"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
const (
	// ProtocolName is the string representation of the ipv4 protocol name.
	ProtocolName = "ipv4"

	// ProtocolNumber is the ipv4 protocol number.
	ProtocolNumber = header.IPv4ProtocolNumber

	// maxTotalSize is maximum size that can be encoded in the 16-bit
	// TotalLength field of the ipv4 header.
	maxTotalSize = 0xffff

	// buckets is the number of identifier buckets used by hashRoute/ids.
	buckets = 2048
)

// address is a fixed-size IPv4 address, so the endpoint owns its own copy
// rather than aliasing caller-provided memory.
type address [header.IPv4AddressSize]byte

// endpoint implements stack.NetworkEndpoint for IPv4 on one NIC/address.
type endpoint struct {
	nicid         tcpip.NICID
	id            stack.NetworkEndpointID
	address       address
	linkEP        stack.LinkEndpoint
	dispatcher    stack.TransportDispatcher
	echoRequests  chan echoRequest // drained by the echoReplier goroutine
	fragmentation *fragmentation.Fragmentation // reassembly state for fragments
}
+
// newEndpoint creates an IPv4 network endpoint for addr on the given NIC and
// starts the background goroutine that answers ICMP echo requests. The
// goroutine exits when Close closes the echoRequests channel.
func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint {
	e := &endpoint{
		nicid:         nicid,
		linkEP:        linkEP,
		dispatcher:    dispatcher,
		echoRequests:  make(chan echoRequest, 10),
		fragmentation: fragmentation.NewFragmentation(fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, fragmentation.DefaultReassembleTimeout),
	}
	copy(e.address[:], addr)
	e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}

	go e.echoReplier()

	return e
}
+
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU
// minus the network layer max header length.
func (e *endpoint) MTU() uint32 {
	return calculateMTU(e.linkEP.MTU())
}

// Capabilities implements stack.NetworkEndpoint.Capabilities. It defers to
// the underlying link endpoint.
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
	return e.linkEP.Capabilities()
}

// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
	return e.nicid
}

// ID returns the ipv4 endpoint ID.
func (e *endpoint) ID() *stack.NetworkEndpointID {
	return &e.id
}

// MaxHeaderLength returns the maximum length needed by ipv4 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
	return e.linkEP.MaxHeaderLength() + header.IPv4MinimumSize
}
+
// WritePacket writes a packet to the given destination address and protocol.
// It prepends the IPv4 header (no options) to hdr and hands the result to
// the link endpoint.
func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
	ip := header.IPv4(hdr.Prepend(header.IPv4MinimumSize))
	length := uint16(hdr.UsedLength() + len(payload))
	id := uint32(0)
	if length > header.IPv4MaximumHeaderSize+8 {
		// Packets of 68 bytes or less are required by RFC 791 to not be
		// fragmented, so we only assign ids to larger packets. The id is a
		// per-route-bucket counter, bumped atomically.
		id = atomic.AddUint32(&ids[hashRoute(r, protocol)%buckets], 1)
	}
	ip.Encode(&header.IPv4Fields{
		IHL:         header.IPv4MinimumSize,
		TotalLength: length,
		ID:          uint16(id),
		TTL:         65, // NOTE(review): non-standard default TTL — confirm this is intentional
		Protocol:    uint8(protocol),
		SrcAddr:     tcpip.Address(e.address[:]),
		DstAddr:     r.RemoteAddress,
	})
	ip.SetChecksum(^ip.CalculateChecksum())

	return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
}
+
// HandlePacket is called by the link layer when new ipv4 packets arrive for
// this endpoint. It strips the IP header, reassembles fragments, and routes
// the payload either to the ICMP handler or to the transport dispatcher.
func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
	h := header.IPv4(vv.First())
	if !h.IsValid(vv.Size()) {
		return
	}

	// Capture the lengths before trimming; h still aliases the bytes of
	// the original first view afterwards.
	hlen := int(h.HeaderLength())
	tlen := int(h.TotalLength())
	vv.TrimFront(hlen)
	vv.CapLength(tlen - hlen)

	more := (h.Flags() & header.IPv4FlagMoreFragments) != 0
	if more || h.FragmentOffset() != 0 {
		// The packet is a fragment, let's try to reassemble it.
		last := h.FragmentOffset() + uint16(vv.Size()) - 1
		tt, ready := e.fragmentation.Process(hash.IPv4FragmentHash(h), h.FragmentOffset(), last, more, vv)
		if !ready {
			// More fragments are needed before delivery.
			return
		}
		vv = &tt
	}
	p := h.TransportProtocol()
	if p == header.ICMPv4ProtocolNumber {
		e.handleICMP(r, vv)
		return
	}
	e.dispatcher.DeliverTransportPacket(r, p, vv)
}

// Close cleans up resources associated with the endpoint. Closing
// echoRequests terminates the echoReplier goroutine.
func (e *endpoint) Close() {
	close(e.echoRequests)
}
+
// protocol implements stack.NetworkProtocol for IPv4.
type protocol struct{}

// NewProtocol creates a new protocol ipv4 protocol descriptor. This is exported
// only for tests that short-circuit the stack. Regular use of the protocol is
// done via the stack, which gets a protocol descriptor from the init() function
// below.
func NewProtocol() stack.NetworkProtocol {
	return &protocol{}
}

// Number returns the ipv4 protocol number.
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
	return ProtocolNumber
}

// MinimumPacketSize returns the minimum valid ipv4 packet size.
func (p *protocol) MinimumPacketSize() int {
	return header.IPv4MinimumSize
}

// ParseAddresses implements NetworkProtocol.ParseAddresses. It extracts the
// source and destination addresses from the IPv4 header in v.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
	h := header.IPv4(v)
	return h.SourceAddress(), h.DestinationAddress()
}

// NewEndpoint creates a new ipv4 endpoint. linkAddrCache is unused by IPv4.
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
	return newEndpoint(nicid, addr, dispatcher, linkEP), nil
}

// SetOption implements NetworkProtocol.SetOption. No options are supported.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
	return tcpip.ErrUnknownProtocolOption
}

// Option implements NetworkProtocol.Option. No options are supported.
func (p *protocol) Option(option interface{}) *tcpip.Error {
	return tcpip.ErrUnknownProtocolOption
}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload mtu.
+func calculateMTU(mtu uint32) uint32 {
+ if mtu > maxTotalSize {
+ mtu = maxTotalSize
+ }
+ return mtu - header.IPv4MinimumSize
+}
+
// hashRoute calculates a hash value for the given route. It uses the source &
// destination address, the transport protocol number, and a random initial
// value (generated once on initialization) to generate the hash.
func hashRoute(r *stack.Route, protocol tcpip.TransportProtocolNumber) uint32 {
	// Fold each 4-byte address into a little-endian uint32.
	t := r.LocalAddress
	a := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
	t = r.RemoteAddress
	b := uint32(t[0]) | uint32(t[1])<<8 | uint32(t[2])<<16 | uint32(t[3])<<24
	return hash.Hash3Words(a, b, uint32(protocol), hashIV)
}

var (
	// ids holds per-bucket IP identification counters, indexed by
	// hashRoute(...) % buckets and bumped atomically in WritePacket.
	ids []uint32
	// hashIV is the random seed mixed into hashRoute.
	hashIV uint32
)

// init seeds the identification counters and hash IV with random values and
// registers the IPv4 protocol factory with the stack.
func init() {
	ids = make([]uint32, buckets)

	// Randomly initialize hashIV and the ids.
	r := hash.RandN32(1 + buckets)
	for i := range ids {
		ids[i] = r[i]
	}
	hashIV = r[buckets]

	stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
		return &protocol{}
	})
}
diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD
new file mode 100644
index 000000000..db7da0af3
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/BUILD
@@ -0,0 +1,21 @@
package(licenses = ["notice"]) # BSD

load("@io_bazel_rules_go//go:def.bzl", "go_library")

# ipv6 implements the IPv6 network protocol for the netstack.
go_library(
    name = "ipv6",
    srcs = [
        "icmp.go",
        "ipv6.go",
    ],
    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6",
    visibility = [
        "//visibility:public",
    ],
    deps = [
        "//pkg/tcpip",
        "//pkg/tcpip/buffer",
        "//pkg/tcpip/header",
        "//pkg/tcpip/stack",
    ],
)
diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go
new file mode 100644
index 000000000..0fc6dcce2
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/icmp.go
@@ -0,0 +1,80 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package ipv6
+
+import (
+ "encoding/binary"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
// handleControl handles the case when an ICMP packet contains the headers of
// the original packet that caused the ICMP one to be sent. This information is
// used to find out which transport endpoint must be notified about the ICMP
// packet.
func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
	h := header.IPv6(vv.First())

	// We don't use IsValid() here because ICMP only requires that up to
	// 1280 bytes of the original packet be included. So it's likely that it
	// is truncated, which would cause IsValid to return false.
	//
	// Drop packet if it doesn't have the basic IPv6 header or if the
	// original source address doesn't match the endpoint's address.
	if len(h) < header.IPv6MinimumSize || h.SourceAddress() != e.id.LocalAddress {
		return
	}

	// Skip the IP header, then handle the fragmentation header if there
	// is one.
	vv.TrimFront(header.IPv6MinimumSize)
	p := h.TransportProtocol()
	if p == header.IPv6FragmentHeader {
		f := header.IPv6Fragment(vv.First())
		if !f.IsValid() || f.FragmentOffset() != 0 {
			// We can't handle fragments that aren't at offset 0
			// because they don't have the transport headers.
			return
		}

		// Skip fragmentation header and find out the actual protocol
		// number carried in the fragment header's NextHeader field.
		vv.TrimFront(header.IPv6FragmentHeaderSize)
		p = f.TransportProtocol()
	}

	// Deliver the control packet to the transport endpoint.
	e.dispatcher.DeliverTransportControlPacket(e.id.LocalAddress, h.DestinationAddress(), ProtocolNumber, p, typ, extra, vv)
}
+
// handleICMP processes an incoming ICMPv6 packet, translating packet-too-big
// and destination-unreachable messages into control packets for the
// transport demuxer. Other ICMPv6 types are currently ignored.
func (e *endpoint) handleICMP(r *stack.Route, vv *buffer.VectorisedView) {
	v := vv.First()
	if len(v) < header.ICMPv6MinimumSize {
		return
	}
	h := header.ICMPv6(v)

	switch h.Type() {
	case header.ICMPv6PacketTooBig:
		if len(v) < header.ICMPv6PacketTooBigMinimumSize {
			return
		}
		vv.TrimFront(header.ICMPv6PacketTooBigMinimumSize)
		// The advertised next-hop MTU follows the 4-byte ICMPv6 header;
		// v still aliases the untrimmed bytes here.
		mtu := binary.BigEndian.Uint32(v[header.ICMPv6MinimumSize:])
		e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), vv)

	case header.ICMPv6DstUnreachable:
		if len(v) < header.ICMPv6DstUnreachableMinimumSize {
			return
		}
		vv.TrimFront(header.ICMPv6DstUnreachableMinimumSize)
		switch h.Code() {
		case header.ICMPv6PortUnreachable:
			e.handleControl(stack.ControlPortUnreachable, 0, vv)
		}
	}
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
new file mode 100644
index 000000000..15654cbbd
--- /dev/null
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -0,0 +1,172 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package ipv6 contains the implementation of the ipv6 network protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing ipv6.ProtocolName (or "ipv6") as one of the
+// network protocols when calling stack.New(). Then endpoints can be created
+// by passing ipv6.ProtocolNumber as the network protocol number when calling
+// Stack.NewEndpoint().
+package ipv6
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
const (
	// ProtocolName is the string representation of the ipv6 protocol name.
	ProtocolName = "ipv6"

	// ProtocolNumber is the ipv6 protocol number.
	ProtocolNumber = header.IPv6ProtocolNumber

	// maxPayloadSize is the maximum size that can be encoded in the 16-bit
	// PayloadLength field of the ipv6 header.
	maxPayloadSize = 0xffff
)

// address is a fixed-size IPv6 address, so the endpoint owns its own copy
// rather than aliasing caller-provided memory.
type address [header.IPv6AddressSize]byte

// endpoint implements stack.NetworkEndpoint for IPv6 on one NIC/address.
type endpoint struct {
	nicid      tcpip.NICID
	id         stack.NetworkEndpointID
	address    address
	linkEP     stack.LinkEndpoint
	dispatcher stack.TransportDispatcher
}

// newEndpoint creates an IPv6 network endpoint for addr on the given NIC.
func newEndpoint(nicid tcpip.NICID, addr tcpip.Address, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) *endpoint {
	e := &endpoint{nicid: nicid, linkEP: linkEP, dispatcher: dispatcher}
	copy(e.address[:], addr)
	e.id = stack.NetworkEndpointID{tcpip.Address(e.address[:])}
	return e
}
+
// MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU
// minus the network layer max header length.
func (e *endpoint) MTU() uint32 {
	return calculateMTU(e.linkEP.MTU())
}

// NICID returns the ID of the NIC this endpoint belongs to.
func (e *endpoint) NICID() tcpip.NICID {
	return e.nicid
}

// ID returns the ipv6 endpoint ID.
func (e *endpoint) ID() *stack.NetworkEndpointID {
	return &e.id
}

// Capabilities implements stack.NetworkEndpoint.Capabilities. It defers to
// the underlying link endpoint.
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
	return e.linkEP.Capabilities()
}

// MaxHeaderLength returns the maximum length needed by ipv6 headers (and
// underlying protocols).
func (e *endpoint) MaxHeaderLength() uint16 {
	return e.linkEP.MaxHeaderLength() + header.IPv6MinimumSize
}
+
+// WritePacket writes a packet to the given destination address and protocol.
+func (e *endpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+ length := uint16(hdr.UsedLength())
+ if payload != nil {
+ length += uint16(len(payload))
+ }
+ ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize))
+ ip.Encode(&header.IPv6Fields{
+ PayloadLength: length,
+ NextHeader: uint8(protocol),
+ HopLimit: 65,
+ SrcAddr: tcpip.Address(e.address[:]),
+ DstAddr: r.RemoteAddress,
+ })
+
+ return e.linkEP.WritePacket(r, hdr, payload, ProtocolNumber)
+}
+
// HandlePacket is called by the link layer when new ipv6 packets arrive for
// this endpoint. It strips the fixed IPv6 header and routes the payload to
// the ICMPv6 handler or the transport dispatcher.
func (e *endpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
	h := header.IPv6(vv.First())
	if !h.IsValid(vv.Size()) {
		return
	}

	// Trim the header, then cap to the declared payload length; h still
	// aliases the original bytes afterwards.
	vv.TrimFront(header.IPv6MinimumSize)
	vv.CapLength(int(h.PayloadLength()))

	p := h.TransportProtocol()
	if p == header.ICMPv6ProtocolNumber {
		e.handleICMP(r, vv)
		return
	}

	e.dispatcher.DeliverTransportPacket(r, p, vv)
}

// Close cleans up resources associated with the endpoint. IPv6 endpoints
// hold no resources, so this is a no-op.
func (*endpoint) Close() {}
+
// protocol implements stack.NetworkProtocol for IPv6.
type protocol struct{}

// NewProtocol creates a new protocol ipv6 protocol descriptor. This is exported
// only for tests that short-circuit the stack. Regular use of the protocol is
// done via the stack, which gets a protocol descriptor from the init() function
// below.
func NewProtocol() stack.NetworkProtocol {
	return &protocol{}
}

// Number returns the ipv6 protocol number.
func (p *protocol) Number() tcpip.NetworkProtocolNumber {
	return ProtocolNumber
}

// MinimumPacketSize returns the minimum valid ipv6 packet size.
func (p *protocol) MinimumPacketSize() int {
	return header.IPv6MinimumSize
}

// ParseAddresses implements NetworkProtocol.ParseAddresses. It extracts the
// source and destination addresses from the IPv6 header in v.
func (*protocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
	h := header.IPv6(v)
	return h.SourceAddress(), h.DestinationAddress()
}

// NewEndpoint creates a new ipv6 endpoint. linkAddrCache is unused by IPv6.
func (p *protocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
	return newEndpoint(nicid, addr, dispatcher, linkEP), nil
}

// SetOption implements NetworkProtocol.SetOption. No options are supported.
func (p *protocol) SetOption(option interface{}) *tcpip.Error {
	return tcpip.ErrUnknownProtocolOption
}

// Option implements NetworkProtocol.Option. No options are supported.
func (p *protocol) Option(option interface{}) *tcpip.Error {
	return tcpip.ErrUnknownProtocolOption
}
+
+// calculateMTU calculates the network-layer payload MTU based on the link-layer
+// payload mtu.
+func calculateMTU(mtu uint32) uint32 {
+ mtu -= header.IPv6MinimumSize
+ if mtu <= maxPayloadSize {
+ return mtu
+ }
+ return maxPayloadSize
+}
+
// init registers the IPv6 protocol factory with the stack under ProtocolName.
func init() {
	stack.RegisterNetworkProtocolFactory(ProtocolName, func() stack.NetworkProtocol {
		return &protocol{}
	})
}
diff --git a/pkg/tcpip/ports/BUILD b/pkg/tcpip/ports/BUILD
new file mode 100644
index 000000000..e0140cea6
--- /dev/null
+++ b/pkg/tcpip/ports/BUILD
@@ -0,0 +1,20 @@
package(licenses = ["notice"]) # BSD

load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")

# ports provides the PortManager used for ephemeral port allocation.
go_library(
    name = "ports",
    srcs = ["ports.go"],
    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/ports",
    visibility = ["//:sandbox"],
    deps = ["//pkg/tcpip"],
)

# White-box test (embed) of the port manager.
go_test(
    name = "ports_test",
    srcs = ["ports_test.go"],
    embed = [":ports"],
    deps = [
        "//pkg/tcpip",
    ],
)
diff --git a/pkg/tcpip/ports/ports.go b/pkg/tcpip/ports/ports.go
new file mode 100644
index 000000000..24f3095d6
--- /dev/null
+++ b/pkg/tcpip/ports/ports.go
@@ -0,0 +1,148 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package ports provides PortManager that manages allocating, reserving and releasing ports.
+package ports
+
+import (
+ "math"
+ "math/rand"
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ // firstEphemeral is the first ephemeral port.
+ firstEphemeral uint16 = 16000
+
+ anyIPAddress = tcpip.Address("")
+)
+
+// portDescriptor uniquely identifies a reserved port: the same port number
+// may be reserved independently for different network/transport protocol
+// pairs.
+type portDescriptor struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+ port uint16
+}
+
+// PortManager manages allocating, reserving and releasing ports.
+type PortManager struct {
+ // mu guards allocatedPorts.
+ mu sync.RWMutex
+ // allocatedPorts maps each reserved port to the set of addresses
+ // currently bound on it.
+ allocatedPorts map[portDescriptor]bindAddresses
+}
+
+// bindAddresses is a set of IP addresses.
+type bindAddresses map[tcpip.Address]struct{}
+
+// isAvailable checks whether an IP address is available to bind to.
+func (b bindAddresses) isAvailable(addr tcpip.Address) bool {
+ if addr == anyIPAddress {
+ // Binding to the wildcard address requires that no other
+ // address is bound on this port at all.
+ return len(b) == 0
+ }
+
+ // If the wildcard address is already bound on this port, every
+ // specific address conflicts with it.
+ if _, ok := b[anyIPAddress]; ok {
+ return false
+ }
+
+ // The specific address must not already be bound.
+ if _, ok := b[addr]; ok {
+ return false
+ }
+ return true
+}
+
+// NewPortManager creates new PortManager.
+func NewPortManager() *PortManager {
+ // The reservation map starts empty; entries are created lazily by
+ // reserveSpecificPort.
+ return &PortManager{allocatedPorts: make(map[portDescriptor]bindAddresses)}
+}
+
+// PickEphemeralPort randomly chooses a starting point and iterates over all
+// possible ephemeral ports, allowing the caller to decide whether a given port
+// is suitable for its needs, and stopping when a port is found or an error
+// occurs.
+func (s *PortManager) PickEphemeralPort(testPort func(p uint16) (bool, *tcpip.Error)) (port uint16, err *tcpip.Error) {
+ // count is the size of the ephemeral range [firstEphemeral, 65535].
+ count := uint16(math.MaxUint16 - firstEphemeral + 1)
+ // Start from a random offset so that successive calls don't always
+ // probe the same ports first.
+ offset := uint16(rand.Int31n(int32(count)))
+
+ for i := uint16(0); i < count; i++ {
+ // The modulo wraps the scan around to the start of the
+ // ephemeral range once it passes the end.
+ port = firstEphemeral + (offset+i)%count
+ ok, err := testPort(port)
+ if err != nil {
+ return 0, err
+ }
+
+ if ok {
+ return port, nil
+ }
+ }
+
+ // Every port in the ephemeral range was rejected by testPort.
+ return 0, tcpip.ErrNoPortAvailable
+}
+
+// ReservePort marks a port/IP combination as reserved so that it cannot be
+// reserved by another endpoint. If port is zero, ReservePort will search for
+// an unreserved ephemeral port and reserve it, returning its value in the
+// "port" return value.
+func (s *PortManager) ReservePort(network []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) (reservedPort uint16, err *tcpip.Error) {
+ // The lock is held for the whole operation so that checking
+ // availability and committing the reservation are atomic.
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // If a port is specified, just try to reserve it for all network
+ // protocols.
+ if port != 0 {
+ if !s.reserveSpecificPort(network, transport, addr, port) {
+ return 0, tcpip.ErrPortInUse
+ }
+ return port, nil
+ }
+
+ // A port wasn't specified, so try to find one. The tester never
+ // returns an error, so PickEphemeralPort can only fail with
+ // ErrNoPortAvailable.
+ return s.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+ return s.reserveSpecificPort(network, transport, addr, p), nil
+ })
+}
+
+// reserveSpecificPort tries to reserve the given port on all given protocols.
+// It either reserves the port on every network protocol or on none (check
+// phase runs fully before the commit phase). Callers must hold s.mu.
+func (s *PortManager) reserveSpecificPort(network []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) bool {
+ // Check that the port is available on all network protocols.
+ desc := portDescriptor{0, transport, port}
+ for _, n := range network {
+ desc.network = n
+ if addrs, ok := s.allocatedPorts[desc]; ok {
+ if !addrs.isAvailable(addr) {
+ return false
+ }
+ }
+ }
+
+ // Reserve port on all network protocols.
+ for _, n := range network {
+ desc.network = n
+ m, ok := s.allocatedPorts[desc]
+ if !ok {
+ // First reservation for this (network, transport, port)
+ // tuple; create its address set.
+ m = make(bindAddresses)
+ s.allocatedPorts[desc] = m
+ }
+ m[addr] = struct{}{}
+ }
+
+ return true
+}
+
+// ReleasePort releases the reservation on a port/IP combination so that it can
+// be reserved by other endpoints. Releasing a port that was never reserved is
+// a no-op.
+func (s *PortManager) ReleasePort(network []tcpip.NetworkProtocolNumber, transport tcpip.TransportProtocolNumber, addr tcpip.Address, port uint16) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ for _, n := range network {
+ desc := portDescriptor{n, transport, port}
+ // delete on a nil map (unknown descriptor) is a safe no-op.
+ m := s.allocatedPorts[desc]
+ delete(m, addr)
+ // Drop the descriptor entirely once no address holds it, so
+ // the map doesn't accumulate empty sets.
+ if len(m) == 0 {
+ delete(s.allocatedPorts, desc)
+ }
+ }
+}
diff --git a/pkg/tcpip/ports/ports_test.go b/pkg/tcpip/ports/ports_test.go
new file mode 100644
index 000000000..9a4c702b2
--- /dev/null
+++ b/pkg/tcpip/ports/ports_test.go
@@ -0,0 +1,134 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+package ports
+
+import (
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const (
+ fakeTransNumber tcpip.TransportProtocolNumber = 1
+ fakeNetworkNumber tcpip.NetworkProtocolNumber = 2
+
+ fakeIPAddress = tcpip.Address("\x08\x08\x08\x08")
+ fakeIPAddress1 = tcpip.Address("\x08\x08\x08\x09")
+)
+
+// TestPortReservation exercises ReservePort/ReleasePort: specific-address
+// reservations, wildcard-address conflicts, and ephemeral allocation.
+func TestPortReservation(t *testing.T) {
+ pm := NewPortManager()
+ net := []tcpip.NetworkProtocolNumber{fakeNetworkNumber}
+
+ // N.B. The cases below run against a single shared PortManager, so
+ // their order matters: later expectations depend on earlier
+ // reservations.
+ for _, test := range []struct {
+ port uint16
+ ip tcpip.Address
+ want *tcpip.Error
+ }{
+ {
+ port: 80,
+ ip: fakeIPAddress,
+ want: nil,
+ },
+ {
+ // A different specific address may share the port.
+ port: 80,
+ ip: fakeIPAddress1,
+ want: nil,
+ },
+ {
+ /* N.B. Order of tests matters! */
+ // The wildcard conflicts with the specific bindings above.
+ port: 80,
+ ip: anyIPAddress,
+ want: tcpip.ErrPortInUse,
+ },
+ {
+ port: 22,
+ ip: anyIPAddress,
+ want: nil,
+ },
+ {
+ // A specific address conflicts with an existing wildcard.
+ port: 22,
+ ip: fakeIPAddress,
+ want: tcpip.ErrPortInUse,
+ },
+ {
+ // port 0 asks for an ephemeral port.
+ port: 0,
+ ip: fakeIPAddress,
+ want: nil,
+ },
+ {
+ port: 0,
+ ip: fakeIPAddress,
+ want: nil,
+ },
+ } {
+ gotPort, err := pm.ReservePort(net, fakeTransNumber, test.ip, test.port)
+ if err != test.want {
+ t.Fatalf("ReservePort(.., .., %s, %d) = %v, want %v", test.ip, test.port, err, test.want)
+ }
+ if test.port == 0 && (gotPort == 0 || gotPort < firstEphemeral) {
+ t.Fatalf("ReservePort(.., .., .., 0) = %d, want port number >= %d to be picked", gotPort, firstEphemeral)
+ }
+ }
+
+ // Release port 22 from any IP address, then try to reserve fake IP
+ // address on 22.
+ pm.ReleasePort(net, fakeTransNumber, anyIPAddress, 22)
+
+ if port, err := pm.ReservePort(net, fakeTransNumber, fakeIPAddress, 22); port != 22 || err != nil {
+ t.Fatalf("ReservePort(.., .., .., %d) = (port %d, err %v), want (22, nil); failed to reserve port after it should have been released", 22, port, err)
+ }
+}
+
+// TestPickEphemeralPort verifies PickEphemeralPort's behavior for each
+// outcome of the caller-supplied tester: rejection of every port, a tester
+// error, a single acceptable port, and ports outside the ephemeral range.
+func TestPickEphemeralPort(t *testing.T) {
+ pm := NewPortManager()
+ customErr := &tcpip.Error{}
+ for _, test := range []struct {
+ name string
+ f func(port uint16) (bool, *tcpip.Error)
+ wantErr *tcpip.Error
+ wantPort uint16
+ }{
+ {
+ name: "no-port-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ {
+ // A tester error must be propagated verbatim.
+ name: "port-tester-error",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ return false, customErr
+ },
+ wantErr: customErr,
+ },
+ {
+ name: "only-port-16042-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port == firstEphemeral+42 {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantPort: firstEphemeral + 42,
+ },
+ {
+ // Ports below firstEphemeral are never probed, so
+ // accepting only those must still fail.
+ name: "only-port-under-16000-available",
+ f: func(port uint16) (bool, *tcpip.Error) {
+ if port < firstEphemeral {
+ return true, nil
+ }
+ return false, nil
+ },
+ wantErr: tcpip.ErrNoPortAvailable,
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ if port, err := pm.PickEphemeralPort(test.f); port != test.wantPort || err != test.wantErr {
+ t.Errorf("PickEphemeralPort(..) = (port %d, err %v); want (port %d, err %v)", port, err, test.wantPort, test.wantErr)
+ }
+ })
+ }
+}
diff --git a/pkg/tcpip/sample/tun_tcp_connect/BUILD b/pkg/tcpip/sample/tun_tcp_connect/BUILD
new file mode 100644
index 000000000..870ee0433
--- /dev/null
+++ b/pkg/tcpip/sample/tun_tcp_connect/BUILD
@@ -0,0 +1,20 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+
+go_binary(
+ name = "tun_tcp_connect",
+ srcs = ["main.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/link/fdbased",
+ "//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
new file mode 100644
index 000000000..332929c85
--- /dev/null
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -0,0 +1,208 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
+// device, and connects to a peer. Similar to "nc <address> <port>". While the
+// sample is running, attempts to connect to its IPv4 address will result in
+// a RST segment.
+//
+// As an example of how to run it, a TUN device can be created and enabled on
+// a linux host as follows (this only needs to be done once per boot):
+//
+// [sudo] ip tuntap add user <username> mode tun <device-name>
+// [sudo] ip link set <device-name> up
+// [sudo] ip addr add <ipv4-address>/<mask-length> dev <device-name>
+//
+// A concrete example:
+//
+// $ sudo ip tuntap add user wedsonaf mode tun tun0
+// $ sudo ip link set tun0 up
+// $ sudo ip addr add 192.168.1.1/24 dev tun0
+//
+// Then one can run tun_tcp_connect as such:
+//
+// $ ./tun/tun_tcp_connect tun0 192.168.1.2 0 192.168.1.1 1234
+//
+// This will attempt to connect to the linux host's stack. One can run nc in
+// listen mode to accept a connect from tun_tcp_connect and exchange data.
+package main
+
+import (
+ "bufio"
+ "fmt"
+ "log"
+ "math/rand"
+ "net"
+ "os"
+ "strconv"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/fdbased"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/tun"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// writer reads from standard input and writes to the endpoint until standard
+// input is closed. It signals that it's done by closing the provided channel.
+func writer(ch chan struct{}, ep tcpip.Endpoint) {
+ defer func() {
+ // Tell the peer we are done sending, then signal completion.
+ ep.Shutdown(tcpip.ShutdownWrite)
+ close(ch)
+ }()
+
+ r := bufio.NewReader(os.Stdin)
+ for {
+ // Read up to 1 KiB of input per iteration.
+ v := buffer.NewView(1024)
+ n, err := r.Read(v)
+ if err != nil {
+ // EOF (or any read error) terminates the writer.
+ return
+ }
+
+ // Shrink the view to the bytes actually read, then write it
+ // out in full; Write may accept fewer bytes than offered, so
+ // loop until the view is drained.
+ v.CapLength(n)
+ for len(v) > 0 {
+ n, err := ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ if err != nil {
+ fmt.Println("Write failed:", err)
+ return
+ }
+
+ v.TrimFront(int(n))
+ }
+ }
+}
+
+// main parses the command line, builds a netstack instance on top of a TUN
+// device, connects to the remote peer, and then shuttles data between the
+// connection and stdin/stdout until both directions are closed.
+func main() {
+ if len(os.Args) != 6 {
+ log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-ipv4-address> <local-port> <remote-ipv4-address> <remote-port>")
+ }
+
+ tunName := os.Args[1]
+ addrName := os.Args[2]
+ portName := os.Args[3]
+ remoteAddrName := os.Args[4]
+ remotePortName := os.Args[5]
+
+ // Seed math/rand, which the stack uses (e.g. for ephemeral ports).
+ rand.Seed(time.Now().UnixNano())
+
+ addr := tcpip.Address(net.ParseIP(addrName).To4())
+ remote := tcpip.FullAddress{
+ NIC: 1,
+ Addr: tcpip.Address(net.ParseIP(remoteAddrName).To4()),
+ }
+
+ var localPort uint16
+ if v, err := strconv.Atoi(portName); err != nil {
+ log.Fatalf("Unable to convert port %v: %v", portName, err)
+ } else {
+ localPort = uint16(v)
+ }
+
+ if v, err := strconv.Atoi(remotePortName); err != nil {
+ log.Fatalf("Unable to convert port %v: %v", remotePortName, err)
+ } else {
+ remote.Port = uint16(v)
+ }
+
+ // Create the stack with ipv4 and tcp protocols, then add a tun-based
+ // NIC and ipv4 address.
+ s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName})
+
+ mtu, err := rawfile.GetMTU(tunName)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fd, err := tun.Open(tunName)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Wrap the fd-based endpoint in a sniffer so traffic is logged.
+ linkID := fdbased.New(&fdbased.Options{FD: fd, MTU: mtu})
+ if err := s.CreateNIC(1, sniffer.New(linkID)); err != nil {
+ log.Fatal(err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, addr); err != nil {
+ log.Fatal(err)
+ }
+
+ // Add default route.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: "\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ },
+ })
+
+ // Create TCP endpoint.
+ var wq waiter.Queue
+ ep, e := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if e != nil {
+ log.Fatal(e)
+ }
+
+ // Bind if a port is specified.
+ if localPort != 0 {
+ if err := ep.Bind(tcpip.FullAddress{0, "", localPort}, nil); err != nil {
+ log.Fatal("Bind failed: ", err)
+ }
+ }
+
+ // Issue connect request and wait for it to complete.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventOut)
+ terr := ep.Connect(remote)
+ if terr == tcpip.ErrConnectStarted {
+ fmt.Println("Connect is pending...")
+ <-notifyCh
+ // The final connect status is retrieved via the error option.
+ terr = ep.GetSockOpt(tcpip.ErrorOption{})
+ }
+ wq.EventUnregister(&waitEntry)
+
+ if terr != nil {
+ log.Fatal("Unable to connect: ", terr)
+ }
+
+ fmt.Println("Connected")
+
+ // Start the writer in its own goroutine.
+ writerCompletedCh := make(chan struct{})
+ go writer(writerCompletedCh, ep) // S/R-FIXME
+
+ // Read data and write to standard output until the peer closes the
+ // connection from its side.
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ for {
+ v, err := ep.Read(nil)
+ if err != nil {
+ if err == tcpip.ErrClosedForReceive {
+ break
+ }
+
+ if err == tcpip.ErrWouldBlock {
+ <-notifyCh
+ continue
+ }
+
+ log.Fatal("Read() failed:", err)
+ }
+
+ os.Stdout.Write(v)
+ }
+ wq.EventUnregister(&waitEntry)
+
+ // The reader has completed. Now wait for the writer as well.
+ <-writerCompletedCh
+
+ ep.Close()
+}
diff --git a/pkg/tcpip/sample/tun_tcp_echo/BUILD b/pkg/tcpip/sample/tun_tcp_echo/BUILD
new file mode 100644
index 000000000..c51528a12
--- /dev/null
+++ b/pkg/tcpip/sample/tun_tcp_echo/BUILD
@@ -0,0 +1,20 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_binary")
+
+go_binary(
+ name = "tun_tcp_echo",
+ srcs = ["main.go"],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/link/fdbased",
+ "//pkg/tcpip/link/rawfile",
+ "//pkg/tcpip/link/tun",
+ "//pkg/tcpip/network/arp",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go
new file mode 100644
index 000000000..10cd701af
--- /dev/null
+++ b/pkg/tcpip/sample/tun_tcp_echo/main.go
@@ -0,0 +1,182 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// This sample creates a stack with TCP and IPv4 protocols on top of a TUN
+// device, and listens on a port. Data received by the server in the accepted
+// connections is echoed back to the clients.
+package main
+
+import (
+ "flag"
+ "log"
+ "math/rand"
+ "net"
+ "os"
+ "strconv"
+ "strings"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/fdbased"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/rawfile"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/tun"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/arp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+var tap = flag.Bool("tap", false, "use tap istead of tun")
+var mac = flag.String("mac", "aa:00:01:01:01:01", "mac address to use in tap device")
+
+// echo copies every byte received on ep back to the sender until the
+// connection is closed, then closes the endpoint. It blocks on the wait queue
+// whenever no data is available.
+func echo(wq *waiter.Queue, ep tcpip.Endpoint) {
+ defer ep.Close()
+
+ // Create wait queue entry that notifies a channel.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ defer wq.EventUnregister(&waitEntry)
+
+ for {
+ v, err := ep.Read(nil)
+ if err != nil {
+ if err == tcpip.ErrWouldBlock {
+ // No data yet; wait for a readable event.
+ <-notifyCh
+ continue
+ }
+
+ // Any other error (including close) ends the echo loop.
+ return
+ }
+
+ // Best-effort echo: write errors are ignored here.
+ ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{})
+ }
+}
+
+// main parses flags and arguments, builds a netstack instance on top of a
+// TUN/TAP device with IPv4, IPv6 and ARP support, then listens on the given
+// address/port and echoes every accepted connection.
+func main() {
+ flag.Parse()
+ if len(flag.Args()) != 3 {
+ log.Fatal("Usage: ", os.Args[0], " <tun-device> <local-address> <local-port>")
+ }
+
+ tunName := flag.Arg(0)
+ addrName := flag.Arg(1)
+ portName := flag.Arg(2)
+
+ // Seed math/rand, which the stack uses (e.g. for ephemeral ports).
+ rand.Seed(time.Now().UnixNano())
+
+ // Parse the mac address.
+ maddr, err := net.ParseMAC(*mac)
+ if err != nil {
+ log.Fatalf("Bad MAC address: %v", *mac)
+ }
+
+ // Parse the IP address. Support both ipv4 and ipv6.
+ parsedAddr := net.ParseIP(addrName)
+ if parsedAddr == nil {
+ log.Fatalf("Bad IP address: %v", addrName)
+ }
+
+ var addr tcpip.Address
+ var proto tcpip.NetworkProtocolNumber
+ if parsedAddr.To4() != nil {
+ addr = tcpip.Address(parsedAddr.To4())
+ proto = ipv4.ProtocolNumber
+ } else if parsedAddr.To16() != nil {
+ addr = tcpip.Address(parsedAddr.To16())
+ proto = ipv6.ProtocolNumber
+ } else {
+ log.Fatalf("Unknown IP type: %v", addrName)
+ }
+
+ localPort, err := strconv.Atoi(portName)
+ if err != nil {
+ log.Fatalf("Unable to convert port %v: %v", portName, err)
+ }
+
+ // Create the stack with ip and tcp protocols, then add a tun-based
+ // NIC and address.
+ s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName, arp.ProtocolName}, []string{tcp.ProtocolName})
+
+ mtu, err := rawfile.GetMTU(tunName)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ // Open either a TAP (ethernet frames) or TUN (raw IP) device,
+ // depending on the -tap flag.
+ var fd int
+ if *tap {
+ fd, err = tun.OpenTAP(tunName)
+ } else {
+ fd, err = tun.Open(tunName)
+ }
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ linkID := fdbased.New(&fdbased.Options{
+ FD: fd,
+ MTU: mtu,
+ EthernetHeader: *tap,
+ Address: tcpip.LinkAddress(maddr),
+ })
+ if err := s.CreateNIC(1, linkID); err != nil {
+ log.Fatal(err)
+ }
+
+ if err := s.AddAddress(1, proto, addr); err != nil {
+ log.Fatal(err)
+ }
+
+ // ARP is needed so the stack can resolve link addresses in TAP mode.
+ if err := s.AddAddress(1, arp.ProtocolNumber, arp.ProtocolAddress); err != nil {
+ log.Fatal(err)
+ }
+
+ // Add default route. The all-zero destination/mask length is matched
+ // to the address family in use.
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: tcpip.Address(strings.Repeat("\x00", len(addr))),
+ Mask: tcpip.Address(strings.Repeat("\x00", len(addr))),
+ Gateway: "",
+ NIC: 1,
+ },
+ })
+
+ // Create TCP endpoint, bind it, then start listening.
+ var wq waiter.Queue
+ ep, e := s.NewEndpoint(tcp.ProtocolNumber, proto, &wq)
+ if e != nil {
+ log.Fatal(e)
+ }
+
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{0, "", uint16(localPort)}, nil); err != nil {
+ log.Fatal("Bind failed: ", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ log.Fatal("Listen failed: ", err)
+ }
+
+ // Wait for connections to appear.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ defer wq.EventUnregister(&waitEntry)
+
+ for {
+ // Accept returns the accepted endpoint and its own wait queue;
+ // the latter intentionally shadows the listener's wq here.
+ n, wq, err := ep.Accept()
+ if err != nil {
+ if err == tcpip.ErrWouldBlock {
+ <-notifyCh
+ continue
+ }
+
+ log.Fatal("Accept() failed:", err)
+ }
+
+ go echo(wq, n) // S/R-FIXME
+ }
+}
diff --git a/pkg/tcpip/seqnum/BUILD b/pkg/tcpip/seqnum/BUILD
new file mode 100644
index 000000000..0c717ec8d
--- /dev/null
+++ b/pkg/tcpip/seqnum/BUILD
@@ -0,0 +1,26 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "seqnum_state",
+ srcs = [
+ "seqnum.go",
+ ],
+ out = "seqnum_state.go",
+ package = "seqnum",
+)
+
+go_library(
+ name = "seqnum",
+ srcs = [
+ "seqnum.go",
+ "seqnum_state.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = ["//pkg/state"],
+)
diff --git a/pkg/tcpip/seqnum/seqnum.go b/pkg/tcpip/seqnum/seqnum.go
new file mode 100644
index 000000000..f689be984
--- /dev/null
+++ b/pkg/tcpip/seqnum/seqnum.go
@@ -0,0 +1,57 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package seqnum defines the types and methods for TCP sequence numbers such
+// that they fit in 32-bit words and work properly when overflows occur.
+package seqnum
+
+// Value represents the value of a sequence number.
+type Value uint32
+
+// Size represents the size (length) of a sequence number window.
+type Size uint32
+
+// LessThan checks if v is before w, i.e., v < w.
+func (v Value) LessThan(w Value) bool {
+ // The signed interpretation of the 32-bit difference implements
+ // wrap-around sequence comparison.
+ return int32(v-w) < 0
+}
+
+// LessThanEq returns true if v == w or v is before w, i.e., v <= w in
+// sequence space.
+func (v Value) LessThanEq(w Value) bool {
+ if v == w {
+ return true
+ }
+ return v.LessThan(w)
+}
+
+// InRange checks if v is in the range [a,b), i.e., a <= v < b.
+func (v Value) InRange(a, b Value) bool {
+ // A single unsigned comparison of the offsets from a handles
+ // wrap-around correctly.
+ return v-a < b-a
+}
+
+// InWindow checks if v is in the window that starts at 'first' and spans 'size'
+// sequence numbers.
+func (v Value) InWindow(first Value, size Size) bool {
+ return v.InRange(first, first.Add(size))
+}
+
+// Overlap checks if the window [a,a+b) overlaps with the window [x, x+y).
+func Overlap(a Value, b Size, x Value, y Size) bool {
+ // Two windows overlap iff each one starts before the other ends.
+ return a.LessThan(x.Add(y)) && x.LessThan(a.Add(b))
+}
+
+// Add calculates the sequence number following the [v, v+s) window.
+func (v Value) Add(s Size) Value {
+ return v + Value(s)
+}
+
+// Size calculates the size of the window defined by [v, w).
+func (v Value) Size(w Value) Size {
+ // Unsigned subtraction handles wrap-around (w may be numerically
+ // smaller than v).
+ return Size(w - v)
+}
+
+// UpdateForward updates v such that it becomes v + s.
+func (v *Value) UpdateForward(s Size) {
+ *v += Value(s)
+}
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
new file mode 100644
index 000000000..079ade2c8
--- /dev/null
+++ b/pkg/tcpip/stack/BUILD
@@ -0,0 +1,70 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "stack_state",
+ srcs = [
+ "registration.go",
+ "stack.go",
+ ],
+ out = "stack_state.go",
+ package = "stack",
+)
+
+go_library(
+ name = "stack",
+ srcs = [
+ "linkaddrcache.go",
+ "nic.go",
+ "registration.go",
+ "route.go",
+ "stack.go",
+ "stack_global_state.go",
+ "stack_state.go",
+ "transport_demuxer.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/stack",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/ilist",
+ "//pkg/sleep",
+ "//pkg/state",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/ports",
+ "//pkg/tcpip/seqnum",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "stack_x_test",
+ size = "small",
+ srcs = [
+ "stack_test.go",
+ "transport_test.go",
+ ],
+ deps = [
+ ":stack",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/link/channel",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "stack_test",
+ size = "small",
+ srcs = ["linkaddrcache_test.go"],
+ embed = [":stack"],
+ deps = [
+ "//pkg/sleep",
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
new file mode 100644
index 000000000..789f97882
--- /dev/null
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -0,0 +1,313 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const linkAddrCacheSize = 512 // max cache entries
+
+// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
+//
+// The entries are stored in a ring buffer, oldest entry replaced first.
+//
+// This struct is safe for concurrent use.
+type linkAddrCache struct {
+ // ageLimit is how long a cache entry is valid for.
+ ageLimit time.Duration
+
+ // resolutionTimeout is the amount of time to wait for a link request to
+ // resolve an address.
+ resolutionTimeout time.Duration
+
+ // resolutionAttempts is the number of times an address is attempted to be
+ // resolved before failing.
+ resolutionAttempts int
+
+ // mu guards cache, next and entries.
+ mu sync.Mutex
+ // cache indexes live entries in the ring buffer by address.
+ cache map[tcpip.FullAddress]*linkAddrEntry
+ next int // array index of next available entry
+ // entries is the fixed ring-buffer backing store; next points at the
+ // slot that will be recycled on the next insertion.
+ entries [linkAddrCacheSize]linkAddrEntry
+}
+
+// entryState controls the state of a single entry in the cache.
+type entryState int
+
+const (
+ // incomplete means that there is an outstanding request to resolve the
+ // address. This is the initial state.
+ incomplete entryState = iota
+ // ready means that the address has been resolved and can be used.
+ ready
+ // failed means that address resolution timed out and the address
+ // could not be resolved.
+ failed
+ // expired means that the cache entry has expired and the address must be
+ // resolved again.
+ expired
+)
+
+// String implements Stringer.
+func (s entryState) String() string {
+ switch s {
+ case incomplete:
+ return "incomplete"
+ case ready:
+ return "ready"
+ case failed:
+ return "failed"
+ case expired:
+ return "expired"
+ default:
+ // Unknown values are rendered rather than panicking, since this
+ // is only used for diagnostics.
+ return fmt.Sprintf("invalid entryState: %d", s)
+ }
+}
+
+// A linkAddrEntry is an entry in the linkAddrCache.
+// This struct is thread-compatible.
+type linkAddrEntry struct {
+ // addr is the network address being resolved.
+ addr tcpip.FullAddress
+ // linkAddr is the resolved link address, valid once s == ready.
+ linkAddr tcpip.LinkAddress
+ // expiration is when this entry stops being valid.
+ expiration time.Time
+ s entryState
+
+ // wakers is a set of waiters for address resolution result. Anytime
+ // state transitions out of 'incomplete' these waiters are notified.
+ wakers map[*sleep.Waker]struct{}
+
+ // cancel stops the in-flight resolution goroutine (if any) when the
+ // entry is recycled.
+ cancel chan struct{}
+}
+
+// state returns the entry's current state, first forcing a transition to
+// expired if the entry's deadline has passed.
+func (e *linkAddrEntry) state() entryState {
+ if e.s != expired && time.Now().After(e.expiration) {
+ // Force the transition to ensure waiters are notified.
+ e.changeState(expired)
+ }
+ return e.s
+}
+
+// changeState moves the entry to state ns, validating the transition and
+// waking any waiters when the entry leaves 'incomplete'. Valid transitions:
+// incomplete may go anywhere; ready/failed may only expire; expired is
+// terminal. Invalid transitions panic (programmer error).
+func (e *linkAddrEntry) changeState(ns entryState) {
+ if e.s == ns {
+ return
+ }
+
+ // Validate state transition.
+ switch e.s {
+ case incomplete:
+ // All transitions are valid.
+ case ready, failed:
+ if ns != expired {
+ panic(fmt.Sprintf("invalid state transition from %v to %v", e.s, ns))
+ }
+ case expired:
+ // Terminal state.
+ panic(fmt.Sprintf("invalid state transition from %v to %v", e.s, ns))
+ default:
+ panic(fmt.Sprintf("invalid state: %v", e.s))
+ }
+
+ // Notify whoever is waiting on address resolution when transitioning
+ // out of 'incomplete'.
+ if e.s == incomplete {
+ for w := range e.wakers {
+ w.Assert()
+ }
+ e.wakers = nil
+ }
+ e.s = ns
+}
+
+// addWaker registers w to be notified when the entry leaves the 'incomplete'
+// state.
+func (e *linkAddrEntry) addWaker(w *sleep.Waker) {
+ e.wakers[w] = struct{}{}
+}
+
+// removeWaker unregisters a previously registered waker.
+func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
+ delete(e.wakers, w)
+}
+
+// add adds a k -> v mapping to the cache, completing any pending resolution
+// for k and replacing any stale entry.
+func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ entry := c.cache[k]
+ if entry != nil {
+ s := entry.state()
+ if s != expired && entry.linkAddr == v {
+ // Disregard repeated calls.
+ return
+ }
+ // Check if entry is waiting for address resolution.
+ if s == incomplete {
+ // Fill in the result; the changeState(ready) below will
+ // wake the waiters.
+ entry.linkAddr = v
+ } else {
+ // Otherwise create a new entry to replace it.
+ entry = c.makeAndAddEntry(k, v)
+ }
+ } else {
+ entry = c.makeAndAddEntry(k, v)
+ }
+
+ entry.changeState(ready)
+}
+
+// makeAndAddEntry is a helper function to create and add a new
+// entry to the cache map and evict older entry as needed. Callers must hold
+// c.mu.
+func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
+ // Take over the next entry.
+ entry := &c.entries[c.next]
+ if c.cache[entry.addr] == entry {
+ delete(c.cache, entry.addr)
+ }
+
+ // Mark the soon-to-be-replaced entry as expired, just in case there is
+ // someone waiting for address resolution on it.
+ entry.changeState(expired)
+ if entry.cancel != nil {
+ // Stop the old entry's resolution goroutine; the channel is
+ // buffered (size 1) so this send cannot block.
+ entry.cancel <- struct{}{}
+ }
+
+ *entry = linkAddrEntry{
+ addr: k,
+ linkAddr: v,
+ expiration: time.Now().Add(c.ageLimit),
+ wakers: make(map[*sleep.Waker]struct{}),
+ cancel: make(chan struct{}, 1),
+ }
+
+ c.cache[k] = entry
+ // Advance the ring-buffer cursor, wrapping at the end.
+ c.next++
+ if c.next == len(c.entries) {
+ c.next = 0
+ }
+ return entry
+}
+
+// get reports any known link address for k. If the address is unknown and a
+// resolver is supplied, resolution is kicked off and ErrWouldBlock is
+// returned; waker (if non-nil) is asserted once resolution completes.
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) {
+ if linkRes != nil {
+ // Static addresses (e.g. broadcast) bypass the cache entirely.
+ if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+ return addr, nil
+ }
+ }
+
+ c.mu.Lock()
+ entry := c.cache[k]
+ if entry == nil || entry.state() == expired {
+ // Unlock before starting resolution, which retakes the lock.
+ c.mu.Unlock()
+ if linkRes == nil {
+ return "", tcpip.ErrNoLinkAddress
+ }
+ c.startAddressResolution(k, linkRes, localAddr, linkEP, waker)
+ return "", tcpip.ErrWouldBlock
+ }
+ defer c.mu.Unlock()
+
+ switch s := entry.state(); s {
+ case expired:
+ // It's possible that entry expired between state() call above and here
+ // in that case it's safe to consider it ready.
+ fallthrough
+ case ready:
+ return entry.linkAddr, nil
+ case failed:
+ return "", tcpip.ErrNoLinkAddress
+ case incomplete:
+ // Address resolution is still in progress.
+ entry.addWaker(waker)
+ return "", tcpip.ErrWouldBlock
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %d", s))
+ }
+}
+
+// removeWaker removes a waker previously added through get(). It is a no-op
+// if the entry is no longer cached.
+func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if entry := c.cache[k]; entry != nil {
+ entry.removeWaker(waker)
+ }
+}
+
+// startAddressResolution inserts an 'incomplete' entry for k and spawns a
+// goroutine that repeatedly sends link requests until the entry is resolved,
+// fails, is cancelled, or is evicted.
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ // Look up again with lock held to ensure entry wasn't added by someone else.
+ if e := c.cache[k]; e != nil && e.state() != expired {
+ return
+ }
+
+ // Add 'incomplete' entry in the cache to mark that resolution is in progress.
+ e := c.makeAndAddEntry(k, "")
+ e.addWaker(waker)
+
+ go func() { // S/R-FIXME
+ for i := 0; ; i++ {
+ // Send link request, then wait for the timeout limit and check
+ // whether the request succeeded.
+ linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+ // Snapshot the cancel channel under the lock; the entry
+ // may be recycled concurrently.
+ c.mu.Lock()
+ cancel := e.cancel
+ c.mu.Unlock()
+
+ select {
+ case <-time.After(c.resolutionTimeout):
+ if stop := c.checkLinkRequest(k, i); stop {
+ return
+ }
+ case <-cancel:
+ // Entry was recycled; abandon this resolution.
+ return
+ }
+ }
+ }()
+}
+
+// checkLinkRequest checks whether previous attempt to resolve address has succeeded
+// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request
+// can stop, false if another request should be sent.
+func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ entry, ok := c.cache[k]
+ if !ok {
+ // Entry was evicted from the cache.
+ return true
+ }
+
+ switch s := entry.state(); s {
+ case ready, failed, expired:
+ // Entry was made ready by resolver or failed. Either way we're done.
+ return true
+ case incomplete:
+ if attempt+1 >= c.resolutionAttempts {
+ // Max number of retries reached, mark entry as failed.
+ // This wakes any waiters via changeState.
+ entry.changeState(failed)
+ return true
+ }
+ // No response yet, need to send another ARP request.
+ return false
+ default:
+ panic(fmt.Sprintf("invalid cache entry state: %d", s))
+ }
+}
+
+// newLinkAddrCache creates a linkAddrCache with the given entry lifetime,
+// per-request timeout and retry count.
+func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
+ return &linkAddrCache{
+ ageLimit: ageLimit,
+ resolutionTimeout: resolutionTimeout,
+ resolutionAttempts: resolutionAttempts,
+ cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize),
+ }
+}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
new file mode 100644
index 000000000..e9897b2bd
--- /dev/null
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -0,0 +1,256 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// testaddr pairs a network address with the link address it should resolve to.
+type testaddr struct {
+	addr     tcpip.FullAddress
+	linkAddr tcpip.LinkAddress
+}
+
+// testaddrs is the fixed set of address pairs used by the tests; it is
+// populated in init below with more entries than the cache can hold.
+var testaddrs []testaddr
+
+// testLinkAddressResolver is a fake LinkAddressResolver that answers
+// requests out of testaddrs, optionally after an artificial delay.
+type testLinkAddressResolver struct {
+	cache *linkAddrCache
+	delay time.Duration
+}
+
+// LinkAddressRequest implements LinkAddressResolver. It asynchronously
+// resolves addr from testaddrs and adds the result to the cache, sleeping
+// for r.delay first to simulate a slow network.
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+	go func() {
+		if r.delay > 0 {
+			time.Sleep(r.delay)
+		}
+		r.fakeRequest(addr)
+	}()
+	return nil
+}
+
+// fakeRequest looks up addr in testaddrs and inserts the matching link
+// address into the cache, simulating a reply to a resolution request.
+// Unknown addresses are silently ignored (no reply).
+func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
+	for _, ta := range testaddrs {
+		if ta.addr.Addr == addr {
+			r.cache.add(ta.addr, ta.linkAddr)
+			break
+		}
+	}
+}
+
+// ResolveStaticAddress implements LinkAddressResolver. Only the "broadcast"
+// address resolves statically.
+func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+	if addr == "broadcast" {
+		return "mac_broadcast", true
+	}
+	return "", false
+}
+
+// LinkAddressProtocol implements LinkAddressResolver.
+func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+	return 1
+}
+
+// getBlocking wraps linkAddrCache.get, sleeping and retrying until the
+// resolution completes (successfully or not) instead of returning
+// ErrWouldBlock to the caller.
+func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
+	w := sleep.Waker{}
+	s := sleep.Sleeper{}
+	s.AddWaker(&w, 123)
+	defer s.Done()
+
+	for {
+		if got, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+			return got, err
+		}
+		// Resolution is in progress; block until the waker fires.
+		s.Fetch(true)
+	}
+}
+
+// init fills testaddrs with 4x more entries than the cache can hold so the
+// tests can exercise eviction.
+func init() {
+	for i := 0; i < 4*linkAddrCacheSize; i++ {
+		addr := fmt.Sprintf("Addr%06d", i)
+		testaddrs = append(testaddrs, testaddr{
+			addr:     tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+			linkAddr: tcpip.LinkAddress("Link" + addr),
+		})
+	}
+}
+
+// TestCacheOverflow adds more entries than the cache can hold (iterating from
+// the highest index down, so low indices are the most recently added) and
+// verifies that recent entries survive while the oldest are evicted.
+func TestCacheOverflow(t *testing.T) {
+	c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+	for i := len(testaddrs) - 1; i >= 0; i-- {
+		e := testaddrs[i]
+		c.add(e.addr, e.linkAddr)
+		got, err := c.get(e.addr, nil, "", nil, nil)
+		if err != nil {
+			t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+		}
+		if got != e.linkAddr {
+			t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+		}
+	}
+	// Expect to find at least half of the most recent entries.
+	for i := 0; i < linkAddrCacheSize/2; i++ {
+		e := testaddrs[i]
+		got, err := c.get(e.addr, nil, "", nil, nil)
+		if err != nil {
+			t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+		}
+		if got != e.linkAddr {
+			t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+		}
+	}
+	// The earliest entries should no longer be in the cache.
+	for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- {
+		e := testaddrs[i]
+		if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+			t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
+		}
+	}
+}
+
+// TestCacheConcurrent hammers the cache from 16 goroutines and then checks
+// the expected eviction outcome. Intended to be run under the race detector.
+func TestCacheConcurrent(t *testing.T) {
+	c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+
+	var wg sync.WaitGroup
+	for r := 0; r < 16; r++ {
+		wg.Add(1)
+		go func() {
+			for _, e := range testaddrs {
+				c.add(e.addr, e.linkAddr)
+				c.get(e.addr, nil, "", nil, nil) // make work for gotsan
+			}
+			wg.Done()
+		}()
+	}
+	wg.Wait()
+
+	// All goroutines add in the same order and add more values than
+	// can fit in the cache, so our eviction strategy requires that
+	// the last entry be present and the first be missing.
+	e := testaddrs[len(testaddrs)-1]
+	got, err := c.get(e.addr, nil, "", nil, nil)
+	if err != nil {
+		t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+	}
+	if got != e.linkAddr {
+		t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+	}
+
+	e = testaddrs[0]
+	if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+		t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+	}
+}
+
+// TestCacheAgeLimit verifies that an entry expires once it outlives the
+// cache's age limit.
+func TestCacheAgeLimit(t *testing.T) {
+	c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+	e := testaddrs[0]
+	c.add(e.addr, e.linkAddr)
+	time.Sleep(50 * time.Millisecond)
+	if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+		t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+	}
+}
+
+// TestCacheReplace verifies that re-adding an address overwrites its
+// previously cached link address.
+func TestCacheReplace(t *testing.T) {
+	c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+	e := testaddrs[0]
+	l2 := e.linkAddr + "2"
+	c.add(e.addr, e.linkAddr)
+	got, err := c.get(e.addr, nil, "", nil, nil)
+	if err != nil {
+		t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+	}
+	if got != e.linkAddr {
+		t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+	}
+
+	c.add(e.addr, l2)
+	got, err = c.get(e.addr, nil, "", nil, nil)
+	if err != nil {
+		t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+	}
+	if got != l2 {
+		t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2)
+	}
+}
+
+// TestCacheResolution resolves every test address through the fake resolver,
+// then verifies that resolved entries are afterwards served straight from the
+// cache without ever returning ErrWouldBlock.
+func TestCacheResolution(t *testing.T) {
+	c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
+	linkRes := &testLinkAddressResolver{cache: c}
+	for i, ta := range testaddrs {
+		got, err := getBlocking(c, ta.addr, linkRes)
+		if err != nil {
+			t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
+		}
+		if got != ta.linkAddr {
+			t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
+		}
+	}
+
+	// Check that after resolved, address stays in the cache and never returns WouldBlock.
+	for i := 0; i < 10; i++ {
+		e := testaddrs[len(testaddrs)-1]
+		got, err := c.get(e.addr, linkRes, "", nil, nil)
+		if err != nil {
+			t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+		}
+		if got != e.linkAddr {
+			t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+		}
+	}
+}
+
+// TestCacheResolutionFailed verifies that resolving an address the resolver
+// does not know about eventually fails with ErrNoLinkAddress after the
+// configured number of attempts.
+func TestCacheResolutionFailed(t *testing.T) {
+	c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
+	linkRes := &testLinkAddressResolver{cache: c}
+
+	// First, sanity check that resolution is working...
+	e := testaddrs[0]
+	got, err := getBlocking(c, e.addr, linkRes)
+	if err != nil {
+		t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+	}
+	if got != e.linkAddr {
+		t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+	}
+
+	// ...then ask for an address that has no testaddrs entry.
+	e.addr.Addr += "2"
+	if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
+		t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+	}
+}
+
+// TestCacheResolutionTimeout verifies that resolution fails when the resolver
+// replies only after the entry has already expired.
+func TestCacheResolutionTimeout(t *testing.T) {
+	resolverDelay := 50 * time.Millisecond
+	expiration := resolverDelay / 2
+	c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
+	linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
+
+	e := testaddrs[0]
+	if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
+		t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+	}
+}
+
+// TestStaticResolution checks that static link addresses are resolved
+// immediately and don't send resolution requests (the resolver's delay of a
+// minute would otherwise stall the test).
+func TestStaticResolution(t *testing.T) {
+	c := newLinkAddrCache(1<<63-1, time.Millisecond, 1)
+	linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute}
+
+	addr := tcpip.Address("broadcast")
+	want := tcpip.LinkAddress("mac_broadcast")
+	got, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
+	if err != nil {
+		t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err)
+	}
+	if got != want {
+		t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
+	}
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
new file mode 100644
index 000000000..8ff4310d5
--- /dev/null
+++ b/pkg/tcpip/stack/nic.go
@@ -0,0 +1,453 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "strings"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/ilist"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// NIC represents a "network interface card" to which the networking stack is
+// attached.
+type NIC struct {
+	stack  *Stack
+	id     tcpip.NICID
+	name   string
+	linkEP LinkEndpoint
+
+	// demux routes incoming transport packets to endpoints registered on
+	// this NIC.
+	demux *transportDemuxer
+
+	// mu protects the fields below.
+	mu          sync.RWMutex
+	spoofing    bool
+	promiscuous bool
+	// primary holds, per network protocol, the ordered list of endpoints
+	// considered for outgoing traffic.
+	primary   map[tcpip.NetworkProtocolNumber]*ilist.List
+	endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+	subnets   []tcpip.Subnet
+}
+
+// newNIC returns a new NIC with an empty endpoint table and its own
+// transport demuxer.
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
+	return &NIC{
+		stack:     stack,
+		id:        id,
+		name:      name,
+		linkEP:    ep,
+		demux:     newTransportDemuxer(stack),
+		primary:   make(map[tcpip.NetworkProtocolNumber]*ilist.List),
+		endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint),
+	}
+}
+
+// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
+// to start delivering packets.
+func (n *NIC) attachLinkEndpoint() {
+	n.linkEP.Attach(n)
+}
+
+// setPromiscuousMode enables or disables promiscuous mode.
+func (n *NIC) setPromiscuousMode(enable bool) {
+	n.mu.Lock()
+	n.promiscuous = enable
+	n.mu.Unlock()
+}
+
+// setSpoofing enables or disables address spoofing.
+func (n *NIC) setSpoofing(enable bool) {
+	n.mu.Lock()
+	n.spoofing = enable
+	n.mu.Unlock()
+}
+
+// primaryEndpoint returns the primary endpoint of n for the given network
+// protocol, or nil if there is none. Endpoints whose reference count has
+// already dropped to zero are skipped; the returned endpoint has its
+// reference count incremented and the caller owns that reference.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
+	n.mu.RLock()
+	defer n.mu.RUnlock()
+
+	list := n.primary[protocol]
+	if list == nil {
+		return nil
+	}
+
+	for e := list.Front(); e != nil; e = e.Next() {
+		r := e.(*referencedNetworkEndpoint)
+		if r.tryIncRef() {
+			return r
+		}
+	}
+
+	return nil
+}
+
+// findEndpoint finds the endpoint, if any, with the given address. If none
+// exists but spoofing is enabled, a temporary endpoint is created for the
+// address; it lives only as long as a route references it. A non-nil result
+// carries a reference owned by the caller.
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address) *referencedNetworkEndpoint {
+	id := NetworkEndpointID{address}
+
+	n.mu.RLock()
+	ref := n.endpoints[id]
+	if ref != nil && !ref.tryIncRef() {
+		ref = nil
+	}
+	spoofing := n.spoofing
+	n.mu.RUnlock()
+
+	if ref != nil || !spoofing {
+		return ref
+	}
+
+	// Try again with the lock in exclusive mode. If we still can't get the
+	// endpoint, create a new "temporary" endpoint. It will only exist while
+	// there's a route through it.
+	n.mu.Lock()
+	ref = n.endpoints[id]
+	if ref == nil || !ref.tryIncRef() {
+		ref, _ = n.addAddressLocked(protocol, address, true)
+		if ref != nil {
+			// Temporary endpoints do not hold the insertion reference,
+			// so they are cleaned up when the last route drops them.
+			ref.holdsInsertRef = false
+		}
+	}
+	n.mu.Unlock()
+	return ref
+}
+
+// addAddressLocked creates a network endpoint for the given address and
+// registers it with n. If an endpoint for the address already exists, it is
+// replaced when replace is true, otherwise ErrDuplicateAddress is returned.
+// The caller must hold n.mu exclusively.
+func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+	netProto, ok := n.stack.networkProtocols[protocol]
+	if !ok {
+		return nil, tcpip.ErrUnknownProtocol
+	}
+
+	// Create the new network endpoint.
+	ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP)
+	if err != nil {
+		return nil, err
+	}
+
+	id := *ep.ID()
+	if ref, ok := n.endpoints[id]; ok {
+		if !replace {
+			return nil, tcpip.ErrDuplicateAddress
+		}
+
+		n.removeEndpointLocked(ref)
+	}
+
+	ref := &referencedNetworkEndpoint{
+		refs:           1,
+		ep:             ep,
+		nic:            n,
+		protocol:       protocol,
+		holdsInsertRef: true,
+	}
+
+	// Set up cache if link address resolution exists for this protocol.
+	if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+		if linkRes := n.stack.linkAddrResolvers[protocol]; linkRes != nil {
+			ref.linkCache = n.stack
+		}
+	}
+
+	n.endpoints[id] = ref
+
+	// Append to the protocol's primary endpoint list, creating the list
+	// on first use.
+	l, ok := n.primary[protocol]
+	if !ok {
+		l = &ilist.List{}
+		n.primary[protocol] = l
+	}
+
+	l.PushBack(ref)
+
+	return ref, nil
+}
+
+// AddAddress adds a new address to n, so that it starts accepting packets
+// targeted at the given address (and network protocol).
+func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+	// Add the endpoint.
+	n.mu.Lock()
+	_, err := n.addAddressLocked(protocol, addr, false)
+	n.mu.Unlock()
+
+	return err
+}
+
+// Addresses returns the addresses associated with this NIC.
+func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+	n.mu.RLock()
+	defer n.mu.RUnlock()
+	addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
+	for nid, ep := range n.endpoints {
+		addrs = append(addrs, tcpip.ProtocolAddress{
+			Protocol: ep.protocol,
+			Address:  nid.LocalAddress,
+		})
+	}
+	return addrs
+}
+
+// AddSubnet adds a new subnet to n, so that it starts accepting packets
+// targeted at addresses within the given subnet for the given network
+// protocol.
+func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
+	n.mu.Lock()
+	n.subnets = append(n.subnets, subnet)
+	n.mu.Unlock()
+}
+
+// Subnets returns the Subnets associated with this NIC: one single-address
+// subnet per registered endpoint, followed by the explicitly added subnets.
+func (n *NIC) Subnets() []tcpip.Subnet {
+	n.mu.RLock()
+	defer n.mu.RUnlock()
+	sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints))
+	for nid := range n.endpoints {
+		// An all-ones mask turns a single address into a /N subnet.
+		sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress))))
+		if err != nil {
+			// This should never happen as the mask has been carefully crafted to
+			// match the address.
+			panic("Invalid endpoint subnet: " + err.Error())
+		}
+		sns = append(sns, sn)
+	}
+	return append(sns, n.subnets...)
+}
+
+// removeEndpointLocked unregisters r from n and closes its network endpoint.
+// The caller must hold n.mu exclusively and r's reference count must have
+// already reached zero.
+func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
+	id := *r.ep.ID()
+
+	// Nothing to do if the reference has already been replaced with a
+	// different one.
+	if n.endpoints[id] != r {
+		return
+	}
+
+	if r.holdsInsertRef {
+		panic("Reference count dropped to zero before being removed")
+	}
+
+	delete(n.endpoints, id)
+	n.primary[r.protocol].Remove(r)
+	r.ep.Close()
+}
+
+// removeEndpoint is the locking wrapper around removeEndpointLocked.
+func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
+	n.mu.Lock()
+	n.removeEndpointLocked(r)
+	n.mu.Unlock()
+}
+
+// RemoveAddress removes an address from n by dropping the insertion
+// reference; the endpoint is destroyed once all other references are gone.
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
+	n.mu.Lock()
+	r := n.endpoints[NetworkEndpointID{addr}]
+	if r == nil || !r.holdsInsertRef {
+		n.mu.Unlock()
+		return tcpip.ErrBadLocalAddress
+	}
+
+	r.holdsInsertRef = false
+	n.mu.Unlock()
+
+	r.decRef()
+
+	return nil
+}
+
+// DeliverNetworkPacket finds the appropriate network protocol endpoint and
+// hands the packet over for further processing. This function is called when
+// the NIC receives a packet from the physical interface.
+// Note that the ownership of the slice backing vv is retained by the caller.
+// This rule applies only to the slice itself, not to the items of the slice;
+// the ownership of the items is not retained by the caller.
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+	netProto, ok := n.stack.networkProtocols[protocol]
+	if !ok {
+		atomic.AddUint64(&n.stack.stats.UnknownProtocolRcvdPackets, 1)
+		return
+	}
+
+	// Drop packets too short to even carry the network header.
+	if len(vv.First()) < netProto.MinimumPacketSize() {
+		atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+		return
+	}
+
+	src, dst := netProto.ParseAddresses(vv.First())
+	id := NetworkEndpointID{dst}
+
+	n.mu.RLock()
+	ref := n.endpoints[id]
+	if ref != nil && !ref.tryIncRef() {
+		ref = nil
+	}
+	promiscuous := n.promiscuous
+	subnets := n.subnets
+	n.mu.RUnlock()
+
+	if ref == nil {
+		// Check if the packet is for a subnet this NIC cares about.
+		if !promiscuous {
+			for _, sn := range subnets {
+				if sn.Contains(dst) {
+					promiscuous = true
+					break
+				}
+			}
+		}
+		if promiscuous {
+			// Try again with the lock in exclusive mode. If we still can't
+			// get the endpoint, create a new "temporary" one. It will only
+			// exist while there's a route through it.
+			n.mu.Lock()
+			ref = n.endpoints[id]
+			if ref == nil || !ref.tryIncRef() {
+				ref, _ = n.addAddressLocked(protocol, dst, true)
+				if ref != nil {
+					ref.holdsInsertRef = false
+				}
+			}
+			n.mu.Unlock()
+		}
+	}
+
+	if ref == nil {
+		atomic.AddUint64(&n.stack.stats.UnknownNetworkEndpointRcvdPackets, 1)
+		return
+	}
+
+	// Build an inbound route and hand the packet to the network endpoint.
+	r := makeRoute(protocol, dst, src, ref)
+	r.LocalLinkAddress = linkEP.LinkAddress()
+	r.RemoteLinkAddress = remoteLinkAddr
+	ref.ep.HandlePacket(&r, vv)
+	ref.decRef()
+}
+
+// DeliverTransportPacket delivers the packets to the appropriate transport
+// protocol endpoint. Delivery is attempted in order: the NIC's demuxer, the
+// stack-wide demuxer, the per-stack default handler, and finally the
+// protocol's unknown-destination handler.
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) {
+	state, ok := n.stack.transportProtocols[protocol]
+	if !ok {
+		atomic.AddUint64(&n.stack.stats.UnknownProtocolRcvdPackets, 1)
+		return
+	}
+
+	transProto := state.proto
+	if len(vv.First()) < transProto.MinimumPacketSize() {
+		atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+		return
+	}
+
+	srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+	if err != nil {
+		atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+		return
+	}
+
+	id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
+	if n.demux.deliverPacket(r, protocol, vv, id) {
+		return
+	}
+	if n.stack.demux.deliverPacket(r, protocol, vv, id) {
+		return
+	}
+
+	// Try to deliver to per-stack default handler.
+	if state.defaultHandler != nil {
+		if state.defaultHandler(r, id, vv) {
+			return
+		}
+	}
+
+	// We could not find an appropriate destination for this packet, so
+	// deliver it to the global handler.
+	if !transProto.HandleUnknownDestinationPacket(r, id, vv) {
+		atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+	}
+}
+
+// DeliverTransportControlPacket delivers control packets to the appropriate
+// transport protocol endpoint. Packets that cannot be attributed to a
+// transport protocol or endpoint are silently dropped.
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView) {
+	state, ok := n.stack.transportProtocols[trans]
+	if !ok {
+		return
+	}
+
+	transProto := state.proto
+
+	// ICMPv4 only guarantees that 8 bytes of the transport protocol will
+	// be present in the payload. We know that the ports are within the
+	// first 8 bytes for all known transport protocols.
+	if len(vv.First()) < 8 {
+		return
+	}
+
+	srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+	if err != nil {
+		return
+	}
+
+	// Note: src/dst are swapped relative to DeliverTransportPacket because
+	// the payload is the original outgoing packet echoed back to us.
+	id := TransportEndpointID{srcPort, local, dstPort, remote}
+	if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+		return
+	}
+	if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+		return
+	}
+}
+
+// ID returns the identifier of n.
+func (n *NIC) ID() tcpip.NICID {
+	return n.id
+}
+
+// referencedNetworkEndpoint is a reference-counted wrapper around a network
+// endpoint registered on a NIC. It is also a member of the NIC's per-protocol
+// primary endpoint list (via the embedded ilist.Entry).
+type referencedNetworkEndpoint struct {
+	ilist.Entry
+	refs     int32
+	ep       NetworkEndpoint
+	nic      *NIC
+	protocol tcpip.NetworkProtocolNumber
+
+	// linkCache is set if link address resolution is enabled for this
+	// protocol. Set to nil otherwise.
+	linkCache LinkAddressCache
+
+	// holdsInsertRef is protected by the NIC's mutex. It indicates whether
+	// the reference count is biased by 1 due to the insertion of the
+	// endpoint. It is reset to false when RemoveAddress is called on the
+	// NIC.
+	holdsInsertRef bool
+}
+
+// decRef decrements the ref count and cleans up the endpoint once it reaches
+// zero.
+func (r *referencedNetworkEndpoint) decRef() {
+	if atomic.AddInt32(&r.refs, -1) == 0 {
+		r.nic.removeEndpoint(r)
+	}
+}
+
+// incRef increments the ref count. It must only be called when the caller is
+// known to be holding a reference to the endpoint, otherwise tryIncRef should
+// be used.
+func (r *referencedNetworkEndpoint) incRef() {
+	atomic.AddInt32(&r.refs, 1)
+}
+
+// tryIncRef attempts to increment the ref count from n to n+1, but only if n is
+// not zero. That is, it will increment the count if the endpoint is still
+// alive, and do nothing if it has already been cleaned up.
+func (r *referencedNetworkEndpoint) tryIncRef() bool {
+	for {
+		v := atomic.LoadInt32(&r.refs)
+		if v == 0 {
+			return false
+		}
+
+		// CAS loop: retry if another goroutine changed the count between
+		// the load and the swap.
+		if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
+			return true
+		}
+	}
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
new file mode 100644
index 000000000..e7e6381ac
--- /dev/null
+++ b/pkg/tcpip/stack/registration.go
@@ -0,0 +1,322 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// NetworkEndpointID is the identifier of a network layer protocol endpoint.
+// Currently the local address is sufficient because all supported protocols
+// (i.e., IPv4 and IPv6) have different sizes for their addresses.
+type NetworkEndpointID struct {
+	LocalAddress tcpip.Address
+}
+
+// TransportEndpointID is the identifier of a transport layer protocol endpoint.
+type TransportEndpointID struct {
+	// LocalPort is the local port associated with the endpoint.
+	LocalPort uint16
+
+	// LocalAddress is the local [network layer] address associated with
+	// the endpoint.
+	LocalAddress tcpip.Address
+
+	// RemotePort is the remote port associated with the endpoint.
+	RemotePort uint16
+
+	// RemoteAddress is the remote [network layer] address associated with
+	// the endpoint.
+	RemoteAddress tcpip.Address
+}
+
+// ControlType is the type of network control message.
+type ControlType int
+
+// The following are the allowed values of ControlType.
+const (
+	ControlPacketTooBig ControlType = iota
+	ControlPortUnreachable
+	ControlUnknown
+)
+
+// TransportEndpoint is the interface that needs to be implemented by transport
+// protocol (e.g., tcp, udp) endpoints that can handle packets.
+type TransportEndpoint interface {
+	// HandlePacket is called by the stack when new packets arrive to
+	// this transport endpoint.
+	HandlePacket(r *Route, id TransportEndpointID, vv *buffer.VectorisedView)
+
+	// HandleControlPacket is called by the stack when new control (e.g.,
+	// ICMP) packets arrive to this transport endpoint.
+	HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv *buffer.VectorisedView)
+}
+
+// TransportProtocol is the interface that needs to be implemented by transport
+// protocols (e.g., tcp, udp) that want to be part of the networking stack.
+type TransportProtocol interface {
+	// Number returns the transport protocol number.
+	Number() tcpip.TransportProtocolNumber
+
+	// NewEndpoint creates a new endpoint of the transport protocol.
+	NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+	// MinimumPacketSize returns the minimum valid packet size of this
+	// transport protocol. The stack automatically drops any packets smaller
+	// than this targeted at this protocol.
+	MinimumPacketSize() int
+
+	// ParsePorts returns the source and destination ports stored in a
+	// packet of this protocol.
+	ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
+
+	// HandleUnknownDestinationPacket handles packets targeted at this
+	// protocol but that don't match any existing endpoint. For example,
+	// it is targeted at a port that has no listeners.
+	//
+	// The return value indicates whether the packet was well-formed (for
+	// stats purposes only).
+	HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv *buffer.VectorisedView) bool
+
+	// SetOption allows enabling/disabling protocol specific features.
+	// SetOption returns an error if the option is not supported or the
+	// provided option value is invalid.
+	SetOption(option interface{}) *tcpip.Error
+
+	// Option allows retrieving protocol specific option values.
+	// Option returns an error if the option is not supported or the
+	// provided option value is invalid.
+	Option(option interface{}) *tcpip.Error
+}
+
+// TransportDispatcher contains the methods used by the network stack to deliver
+// packets to the appropriate transport endpoint after it has been handled by
+// the network layer.
+type TransportDispatcher interface {
+	// DeliverTransportPacket delivers packets to the appropriate
+	// transport protocol endpoint.
+	DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView)
+
+	// DeliverTransportControlPacket delivers control packets to the
+	// appropriate transport protocol endpoint.
+	DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView)
+}
+
+// NetworkEndpoint is the interface that needs to be implemented by endpoints
+// of network layer protocols (e.g., ipv4, ipv6).
+type NetworkEndpoint interface {
+	// MTU is the maximum transmission unit for this endpoint. This is
+	// generally calculated as the MTU of the underlying data link endpoint
+	// minus the network endpoint max header length.
+	MTU() uint32
+
+	// Capabilities returns the set of capabilities supported by the
+	// underlying link-layer endpoint.
+	Capabilities() LinkEndpointCapabilities
+
+	// MaxHeaderLength returns the maximum size the network (and lower
+	// level layers combined) headers can have. Higher levels use this
+	// information to reserve space in the front of the packets they're
+	// building.
+	MaxHeaderLength() uint16
+
+	// WritePacket writes a packet to the given destination address and
+	// protocol.
+	WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error
+
+	// ID returns the network protocol endpoint ID.
+	ID() *NetworkEndpointID
+
+	// NICID returns the id of the NIC this endpoint belongs to.
+	NICID() tcpip.NICID
+
+	// HandlePacket is called by the link layer when new packets arrive to
+	// this network endpoint.
+	HandlePacket(r *Route, vv *buffer.VectorisedView)
+
+	// Close is called when the endpoint is removed from a stack.
+	Close()
+}
+
+// NetworkProtocol is the interface that needs to be implemented by network
+// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack.
+type NetworkProtocol interface {
+	// Number returns the network protocol number.
+	Number() tcpip.NetworkProtocolNumber
+
+	// MinimumPacketSize returns the minimum valid packet size of this
+	// network protocol. The stack automatically drops any packets smaller
+	// than this targeted at this protocol.
+	MinimumPacketSize() int
+
+	// ParseAddresses returns the source and destination addresses stored
+	// in a packet of this protocol.
+	ParseAddresses(v buffer.View) (src, dst tcpip.Address)
+
+	// NewEndpoint creates a new endpoint of this protocol.
+	NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
+
+	// SetOption allows enabling/disabling protocol specific features.
+	// SetOption returns an error if the option is not supported or the
+	// provided option value is invalid.
+	SetOption(option interface{}) *tcpip.Error
+
+	// Option allows retrieving protocol specific option values.
+	// Option returns an error if the option is not supported or the
+	// provided option value is invalid.
+	Option(option interface{}) *tcpip.Error
+}
+
+// NetworkDispatcher contains the methods used by the network stack to deliver
+// packets to the appropriate network endpoint after it has been handled by
+// the data link layer.
+type NetworkDispatcher interface {
+	// DeliverNetworkPacket finds the appropriate network protocol
+	// endpoint and hands the packet over for further processing.
+	DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView)
+}
+
+// LinkEndpointCapabilities is the type associated with the capabilities
+// supported by a link-layer endpoint. It is a set of bitfields.
+type LinkEndpointCapabilities uint
+
+// The following are the supported link endpoint capabilities.
+const (
+	CapabilityChecksumOffload LinkEndpointCapabilities = 1 << iota
+	CapabilityResolutionRequired
+)
+
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint.
+type LinkEndpoint interface {
+	// MTU is the maximum transmission unit for this endpoint. This is
+	// usually dictated by the backing physical network; when such a
+	// physical network doesn't exist, the limit is generally 64k, which
+	// includes the maximum size of an IP packet.
+	MTU() uint32
+
+	// Capabilities returns the set of capabilities supported by the
+	// endpoint.
+	Capabilities() LinkEndpointCapabilities
+
+	// MaxHeaderLength returns the maximum size the data link (and
+	// lower level layers combined) headers can have. Higher levels use this
+	// information to reserve space in the front of the packets they're
+	// building.
+	MaxHeaderLength() uint16
+
+	// LinkAddress returns the link address (typically a MAC) of the
+	// link endpoint.
+	LinkAddress() tcpip.LinkAddress
+
+	// WritePacket writes a packet with the given protocol through the given
+	// route.
+	WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+
+	// Attach attaches the data link layer endpoint to the network-layer
+	// dispatcher of the stack.
+	Attach(dispatcher NetworkDispatcher)
+}
+
+// A LinkAddressResolver is an extension to a NetworkProtocol that
+// can resolve link addresses.
+type LinkAddressResolver interface {
+	// LinkAddressRequest sends a request for the LinkAddress of addr.
+	// The request is sent on linkEP with localAddr as the source.
+	//
+	// A valid response will cause the discovery protocol's network
+	// endpoint to call AddLinkAddress.
+	LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+
+	// ResolveStaticAddress attempts to resolve address without sending
+	// requests. It either resolves the name immediately or returns the
+	// empty LinkAddress.
+	//
+	// It can be used to resolve broadcast addresses for example.
+	ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool)
+
+	// LinkAddressProtocol returns the network protocol of the
+	// addresses this resolver can resolve.
+	LinkAddressProtocol() tcpip.NetworkProtocolNumber
+}
+
+// A LinkAddressCache caches link addresses.
+type LinkAddressCache interface {
+	// CheckLocalAddress determines if the given local address exists.
+	// NOTE(review): the meaning of the returned NICID when the address
+	// does not exist is not visible here — presumably 0; confirm against
+	// the implementation.
+	CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
+
+	// AddLinkAddress adds a link address to the cache.
+	AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+
+	// GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
+	// If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
+	// registered with the network protocol, the cache attempts to resolve the address
+	// and returns ErrWouldBlock. Waker is notified when address resolution is
+	// complete (success or not).
+	GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error)
+
+	// RemoveWaker removes a waker that has been added in GetLinkAddress().
+	RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+}
+
+// TransportProtocolFactory functions are used by the stack to instantiate
+// transport protocols.
+type TransportProtocolFactory func() TransportProtocol
+
+// NetworkProtocolFactory provides methods to be used by the stack to
+// instantiate network protocols.
+type NetworkProtocolFactory func() NetworkProtocol
+
+var (
+	// transportProtocols and networkProtocols hold the factories
+	// registered by protocol packages' init() functions.
+	transportProtocols = make(map[string]TransportProtocolFactory)
+	networkProtocols   = make(map[string]NetworkProtocolFactory)
+
+	// linkEPMu protects the link endpoint registry below.
+	linkEPMu           sync.RWMutex
+	nextLinkEndpointID tcpip.LinkEndpointID = 1
+	linkEndpoints      = make(map[tcpip.LinkEndpointID]LinkEndpoint)
+)
+
+// RegisterTransportProtocolFactory registers a new transport protocol factory
+// with the stack so that it becomes available to users of the stack. This
+// function is intended to be called by init() functions of the protocols.
+func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
+	transportProtocols[name] = p
+}
+
+// RegisterNetworkProtocolFactory registers a new network protocol factory with
+// the stack so that it becomes available to users of the stack. This function
+// is intended to be called by init() functions of the protocols.
+func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
+	networkProtocols[name] = p
+}
+
+// RegisterLinkEndpoint registers a link-layer protocol endpoint and returns an
+// ID that can be used to refer to it.
+func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
+	linkEPMu.Lock()
+	defer linkEPMu.Unlock()
+
+	v := nextLinkEndpointID
+	nextLinkEndpointID++
+
+	linkEndpoints[v] = linkEP
+
+	return v
+}
+
+// FindLinkEndpoint finds the link endpoint associated with the given ID.
+func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint {
+	linkEPMu.RLock()
+	defer linkEPMu.RUnlock()
+
+	return linkEndpoints[id]
+}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
new file mode 100644
index 000000000..12f5efba5
--- /dev/null
+++ b/pkg/tcpip/stack/route.go
@@ -0,0 +1,133 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+// Route represents a route through the networking stack to a given destination.
+type Route struct {
+	// RemoteAddress is the final destination of the route.
+	RemoteAddress tcpip.Address
+
+	// RemoteLinkAddress is the link-layer (MAC) address of the
+	// final destination of the route. It may be empty until Resolve()
+	// fills it in.
+	RemoteLinkAddress tcpip.LinkAddress
+
+	// LocalAddress is the local address where the route starts.
+	LocalAddress tcpip.Address
+
+	// LocalLinkAddress is the link-layer (MAC) address of the NIC
+	// where the route starts.
+	LocalLinkAddress tcpip.LinkAddress
+
+	// NextHop is the next node in the path to the destination; empty if
+	// the destination is directly reachable.
+	NextHop tcpip.Address
+
+	// NetProto is the network-layer protocol.
+	NetProto tcpip.NetworkProtocolNumber
+
+	// ref a reference to the network endpoint through which the route
+	// starts. It is released by Release().
+	ref *referencedNetworkEndpoint
+}
+
+// makeRoute initializes a new route. It takes ownership of the provided
+// reference to a network endpoint; the caller must not decRef it.
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, ref *referencedNetworkEndpoint) Route {
+	return Route{
+		NetProto: netProto,
+		LocalAddress: localAddr,
+		RemoteAddress: remoteAddr,
+		ref: ref,
+	}
+}
+
+// NICID returns the id of the NIC from which this route originates, as
+// reported by the underlying network endpoint.
+func (r *Route) NICID() tcpip.NICID {
+	return r.ref.ep.NICID()
+}
+
+// MaxHeaderLength forwards the call to the network endpoint's implementation.
+func (r *Route) MaxHeaderLength() uint16 {
+	return r.ref.ep.MaxHeaderLength()
+}
+
+// PseudoHeaderChecksum computes the checksum of the pseudo-header built from
+// this route's local and remote addresses and the given transport protocol.
+func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber) uint16 {
+	return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress)
+}
+
+// Capabilities returns the link-layer capabilities of the route, as reported
+// by the underlying network endpoint.
+func (r *Route) Capabilities() LinkEndpointCapabilities {
+	return r.ref.ep.Capabilities()
+}
+
+// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
+// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
+// notified when address resolution is complete (success or not).
+func (r *Route) Resolve(waker *sleep.Waker) *tcpip.Error {
+	if !r.IsResolutionRequired() {
+		// Nothing to do if there is no cache (which does the resolution on cache miss) or
+		// link address is already known.
+		return nil
+	}
+
+	// Resolve the gateway if one is configured, otherwise resolve the
+	// final destination directly. RemoveWaker must use the same address
+	// choice so a pending waker can be found again.
+	nextAddr := r.NextHop
+	if nextAddr == "" {
+		nextAddr = r.RemoteAddress
+	}
+	linkAddr, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+	if err != nil {
+		// May be ErrWouldBlock, in which case the waker fires when the
+		// cache entry is resolved (or fails).
+		return err
+	}
+	r.RemoteLinkAddress = linkAddr
+	return nil
+}
+
+// RemoveWaker removes a waker that has been added in Resolve(). The next-hop
+// address is computed the same way as in Resolve() so the waker is removed
+// from the same cache entry it was registered on.
+func (r *Route) RemoveWaker(waker *sleep.Waker) {
+	nextAddr := r.NextHop
+	if nextAddr == "" {
+		nextAddr = r.RemoteAddress
+	}
+	r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
+}
+
+// IsResolutionRequired returns true if Resolve() must be called to resolve
+// the link address before this route can be written to.
+func (r *Route) IsResolutionRequired() bool {
+	return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+}
+
+// WritePacket writes the packet through the given route. hdr holds the
+// already-built transport header; the network endpoint prepends its own.
+func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+	return r.ref.ep.WritePacket(r, hdr, payload, protocol)
+}
+
+// MTU returns the MTU of the underlying network endpoint.
+func (r *Route) MTU() uint32 {
+	return r.ref.ep.MTU()
+}
+
+// Release frees all resources associated with the route. It is safe to call
+// more than once: the reference is nil'd out after the first release.
+func (r *Route) Release() {
+	if r.ref != nil {
+		r.ref.decRef()
+		r.ref = nil
+	}
+}
+
+// Clone clones the route such that the original one can be released and the
+// new one will remain valid. The clone holds its own endpoint reference and
+// must be Release()d independently.
+func (r *Route) Clone() Route {
+	r.ref.incRef()
+	return *r
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
new file mode 100644
index 000000000..558ecdb72
--- /dev/null
+++ b/pkg/tcpip/stack/stack.go
@@ -0,0 +1,811 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package stack provides the glue between networking protocols and the
+// consumers of the networking stack.
+//
+// For consumers, the only function of interest is New(), everything else is
+// provided by the tcpip/public package.
+//
+// For protocol implementers, RegisterTransportProtocolFactory() and
+// RegisterNetworkProtocolFactory() are used to register protocol factories with
+// the stack, which will then be used to instantiate protocol objects when
+// consumers interact with the stack.
+package stack
+
+import (
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/ports"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+	// ageLimit is set to the same cache stale time used in Linux.
+	ageLimit = 1 * time.Minute
+	// resolutionTimeout is set to the same ARP timeout used in Linux.
+	resolutionTimeout = 1 * time.Second
+	// resolutionAttempts is set to the same ARP retries used in Linux.
+	resolutionAttempts = 3
+)
+
+// transportProtocolState holds a transport protocol instance together with
+// the stack-wide default handler (if any) for packets that match no endpoint.
+type transportProtocolState struct {
+	proto TransportProtocol
+	// defaultHandler, if non-nil, is invoked for packets that no
+	// registered endpoint claims; it returns true if it consumed the
+	// packet. Set via SetTransportProtocolHandler.
+	defaultHandler func(*Route, TransportEndpointID, *buffer.VectorisedView) bool
+}
+
+// TCPProbeFunc is the expected function type for a TCP probe function to be
+// passed to stack.AddTCPProbe.
+type TCPProbeFunc func(s TCPEndpointState)
+
+// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
+type TCPEndpointID struct {
+	// LocalPort is the local port associated with the endpoint.
+	LocalPort uint16
+
+	// LocalAddress is the local [network layer] address associated with
+	// the endpoint.
+	LocalAddress tcpip.Address
+
+	// RemotePort is the remote port associated with the endpoint.
+	RemotePort uint16
+
+	// RemoteAddress is the remote [network layer] address associated with
+	// the endpoint.
+	RemoteAddress tcpip.Address
+}
+
+// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
+// TCP endpoint.
+type TCPFastRecoveryState struct {
+	// Active if true indicates the endpoint is in fast recovery.
+	Active bool
+
+	// First is the first unacknowledged sequence number being recovered.
+	First seqnum.Value
+
+	// Last is the 'recover' sequence number that indicates the point at
+	// which we should exit recovery barring any timeouts etc.
+	Last seqnum.Value
+
+	// MaxCwnd is the maximum value we are permitted to grow the congestion
+	// window during recovery. This is set at the time we enter recovery.
+	MaxCwnd int
+}
+
+// TCPReceiverState holds a copy of the internal state of the receiver for
+// a given TCP endpoint.
+type TCPReceiverState struct {
+	// RcvNxt is the TCP variable RCV.NXT.
+	RcvNxt seqnum.Value
+
+	// RcvAcc is the TCP variable RCV.ACC.
+	RcvAcc seqnum.Value
+
+	// RcvWndScale is the window scaling to use for inbound segments.
+	RcvWndScale uint8
+
+	// PendingBufUsed is the number of bytes pending in the receive
+	// queue.
+	PendingBufUsed seqnum.Size
+
+	// PendingBufSize is the size of the socket receive buffer.
+	PendingBufSize seqnum.Size
+}
+
+// TCPSenderState holds a copy of the internal state of the sender for
+// a given TCP Endpoint.
+type TCPSenderState struct {
+	// LastSendTime is the time at which we sent the last segment.
+	LastSendTime time.Time
+
+	// DupAckCount is the number of Duplicate ACK's received.
+	DupAckCount int
+
+	// SndCwnd is the size of the sending congestion window in packets.
+	SndCwnd int
+
+	// Ssthresh is the slow start threshold in packets.
+	Ssthresh int
+
+	// SndCAAckCount is the number of packets consumed in congestion
+	// avoidance mode.
+	SndCAAckCount int
+
+	// Outstanding is the number of packets in flight.
+	Outstanding int
+
+	// SndWnd is the send window size in bytes.
+	SndWnd seqnum.Size
+
+	// SndUna is the next unacknowledged sequence number.
+	SndUna seqnum.Value
+
+	// SndNxt is the sequence number of the next segment to be sent.
+	SndNxt seqnum.Value
+
+	// RTTMeasureSeqNum is the sequence number being used for the latest RTT
+	// measurement.
+	RTTMeasureSeqNum seqnum.Value
+
+	// RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
+	RTTMeasureTime time.Time
+
+	// Closed indicates that the caller has closed the endpoint for sending.
+	Closed bool
+
+	// SRTT is the smoothed round-trip time as defined in section 2 of
+	// RFC 6298.
+	SRTT time.Duration
+
+	// RTO is the retransmit timeout as defined in section of 2 of RFC 6298.
+	RTO time.Duration
+
+	// RTTVar is the round-trip time variation as defined in section 2 of
+	// RFC 6298.
+	RTTVar time.Duration
+
+	// SRTTInited if true indicates that a valid RTT measurement has been
+	// completed.
+	SRTTInited bool
+
+	// MaxPayloadSize is the maximum size of the payload of a given segment.
+	// It is initialized on demand.
+	MaxPayloadSize int
+
+	// SndWndScale is the number of bits to shift left when reading the send
+	// window size from a segment.
+	SndWndScale uint8
+
+	// MaxSentAck is the highest acknowledgment number sent till now.
+	MaxSentAck seqnum.Value
+
+	// FastRecovery holds the fast recovery state for the endpoint.
+	FastRecovery TCPFastRecoveryState
+}
+
+// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
+type TCPSACKInfo struct {
+	// Blocks is the list of SACK blocks currently received by the
+	// TCP endpoint.
+	Blocks []header.SACKBlock
+}
+
+// TCPEndpointState is a copy of the internal state of a TCP endpoint. It is
+// what gets passed to a TCPProbeFunc installed via Stack.AddTCPProbe.
+type TCPEndpointState struct {
+	// ID is a copy of the TransportEndpointID for the endpoint.
+	ID TCPEndpointID
+
+	// SegTime denotes the absolute time when this segment was received.
+	SegTime time.Time
+
+	// RcvBufSize is the size of the receive socket buffer for the endpoint.
+	RcvBufSize int
+
+	// RcvBufUsed is the amount of bytes actually held in the receive socket
+	// buffer for the endpoint.
+	RcvBufUsed int
+
+	// RcvClosed if true, indicates the endpoint has been closed for reading.
+	RcvClosed bool
+
+	// SendTSOk is used to indicate when the TS Option has been negotiated.
+	// When SendTSOk is true every non-RST segment should carry a TS as per
+	// RFC7323#section-1.1.
+	SendTSOk bool
+
+	// RecentTS is the timestamp that should be sent in the TSEcr field of
+	// the timestamp for future segments sent by the endpoint. This field is
+	// updated if required when a new segment is received by this endpoint.
+	RecentTS uint32
+
+	// TSOffset is a randomized offset added to the value of the TSVal field
+	// in the timestamp option.
+	TSOffset uint32
+
+	// SACKPermitted is set to true if the peer sends the TCPSACKPermitted
+	// option in the SYN/SYN-ACK.
+	SACKPermitted bool
+
+	// SACK holds TCP SACK related information for this endpoint.
+	SACK TCPSACKInfo
+
+	// SndBufSize is the size of the socket send buffer.
+	SndBufSize int
+
+	// SndBufUsed is the number of bytes held in the socket send buffer.
+	SndBufUsed int
+
+	// SndClosed indicates that the endpoint has been closed for sends.
+	SndClosed bool
+
+	// SndBufInQueue is the number of bytes in the send queue.
+	SndBufInQueue seqnum.Size
+
+	// PacketTooBigCount is used to notify the main protocol routine how
+	// many times a "packet too big" control packet is received.
+	PacketTooBigCount int
+
+	// SndMTU is the smallest MTU seen in the control packets received.
+	SndMTU int
+
+	// Receiver holds variables related to the TCP receiver for the endpoint.
+	Receiver TCPReceiverState
+
+	// Sender holds state related to the TCP Sender for the endpoint.
+	Sender TCPSenderState
+}
+
+// Stack is a networking stack, with all supported protocols, NICs, and route
+// table.
+type Stack struct {
+	// The protocol and resolver maps are populated in New() and read-only
+	// afterwards.
+	transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState
+	networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
+	linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+
+	// demux is the global (non-NIC-specific) transport demuxer.
+	demux *transportDemuxer
+
+	// stats is updated with atomic operations; see Stats().
+	stats tcpip.Stats
+
+	linkAddrCache *linkAddrCache
+
+	// mu protects the fields below: nics, routeTable and tcpProbeFunc.
+	mu sync.RWMutex
+	nics map[tcpip.NICID]*NIC
+
+	// route is the route table passed in by the user via SetRouteTable(),
+	// it is used by FindRoute() to build a route for a specific
+	// destination.
+	routeTable []tcpip.Route
+
+	*ports.PortManager
+
+	// If not nil, then any new endpoints will have this probe function
+	// invoked every time they receive a TCP segment.
+	tcpProbeFunc TCPProbeFunc
+}
+
+// New allocates a new networking stack with only the requested networking and
+// transport protocols configured with default options.
+//
+// Protocol names that were never registered via the Register*Factory
+// functions are silently skipped.
+//
+// Protocol options can be changed by calling the
+// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
+// stack. Please refer to individual protocol implementations as to what options
+// are supported.
+func New(network []string, transport []string) *Stack {
+	s := &Stack{
+		transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+		networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+		linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+		nics: make(map[tcpip.NICID]*NIC),
+		linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+		PortManager: ports.NewPortManager(),
+	}
+
+	// Add specified network protocols.
+	for _, name := range network {
+		netProtoFactory, ok := networkProtocols[name]
+		if !ok {
+			// Unknown protocol name; skip silently.
+			continue
+		}
+		netProto := netProtoFactory()
+		s.networkProtocols[netProto.Number()] = netProto
+		if r, ok := netProto.(LinkAddressResolver); ok {
+			// The network protocol can also resolve link addresses
+			// (e.g. ARP for IPv4); register it as a resolver.
+			s.linkAddrResolvers[r.LinkAddressProtocol()] = r
+		}
+	}
+
+	// Add specified transport protocols.
+	for _, name := range transport {
+		transProtoFactory, ok := transportProtocols[name]
+		if !ok {
+			// Unknown protocol name; skip silently.
+			continue
+		}
+		transProto := transProtoFactory()
+		s.transportProtocols[transProto.Number()] = &transportProtocolState{
+			proto: transProto,
+		}
+	}
+
+	// Create the global transport demuxer.
+	s.demux = newTransportDemuxer(s)
+
+	return s
+}
+
+// SetNetworkProtocolOption allows configuring individual protocol level
+// options. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation or the provided value
+// is incorrect.
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+	netProto, ok := s.networkProtocols[network]
+	if !ok {
+		return tcpip.ErrUnknownProtocol
+	}
+	return netProto.SetOption(option)
+}
+
+// NetworkProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation. The option argument
+// must be a pointer so the protocol can write the value into it.
+// e.g.
+// var v ipv4.MyOption
+// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v)
+// if err != nil {
+//   ...
+// }
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+	netProto, ok := s.networkProtocols[network]
+	if !ok {
+		return tcpip.ErrUnknownProtocol
+	}
+	return netProto.Option(option)
+}
+
+// SetTransportProtocolOption allows configuring individual protocol level
+// options. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation or the provided value
+// is incorrect.
+func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+	transProtoState, ok := s.transportProtocols[transport]
+	if !ok {
+		return tcpip.ErrUnknownProtocol
+	}
+	return transProtoState.proto.SetOption(option)
+}
+
+// TransportProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation. The option argument
+// must be a pointer so the protocol can write the value into it.
+// e.g.
+// var v tcp.SACKEnabled
+// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
+//   ...
+// }
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+	transProtoState, ok := s.transportProtocols[transport]
+	if !ok {
+		return tcpip.ErrUnknownProtocol
+	}
+	return transProtoState.proto.Option(option)
+}
+
+// SetTransportProtocolHandler sets the per-stack default handler for the given
+// protocol. Setting a handler for an unknown protocol is a silent no-op.
+//
+// It must be called only during initialization of the stack. Changing it as the
+// stack is operating is not supported (the field is written without holding
+// any lock).
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *buffer.VectorisedView) bool) {
+	state := s.transportProtocols[p]
+	if state != nil {
+		state.defaultHandler = h
+	}
+}
+
+// Stats returns a snapshot of the current stats.
+//
+// NOTE: The underlying stats are updated using atomic instructions as a result
+// the snapshot returned does not represent the value of all the stats at any
+// single given point of time (each counter is read atomically, but the set of
+// reads is not one atomic operation).
+// TODO: Make stats available in sentry for debugging/diag.
+func (s *Stack) Stats() tcpip.Stats {
+	return tcpip.Stats{
+		UnknownProtocolRcvdPackets: atomic.LoadUint64(&s.stats.UnknownProtocolRcvdPackets),
+		UnknownNetworkEndpointRcvdPackets: atomic.LoadUint64(&s.stats.UnknownNetworkEndpointRcvdPackets),
+		MalformedRcvdPackets: atomic.LoadUint64(&s.stats.MalformedRcvdPackets),
+		DroppedPackets: atomic.LoadUint64(&s.stats.DroppedPackets),
+	}
+}
+
+// MutableStats returns a mutable copy of the current stats.
+//
+// This is not generally exported via the public interface, but is available
+// internally. Callers must update the returned counters with atomic
+// operations (see Stats).
+func (s *Stack) MutableStats() *tcpip.Stats {
+	return &s.stats
+}
+
+// SetRouteTable assigns the route table to be used by this stack. It
+// specifies which NIC to use for given destination address ranges. The table
+// is replaced wholesale; routes established earlier via FindRoute are not
+// affected.
+func (s *Stack) SetRouteTable(table []tcpip.Route) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.routeTable = table
+}
+
+// NewEndpoint creates a new transport layer endpoint of the given protocol.
+// It returns ErrUnknownProtocol if the transport protocol was not configured
+// in New().
+func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	t, ok := s.transportProtocols[transport]
+	if !ok {
+		return nil, tcpip.ErrUnknownProtocol
+	}
+
+	return t.proto.NewEndpoint(s, network, waiterQueue)
+}
+
+// createNIC creates a NIC with the provided id and link-layer endpoint, and
+// optionally enable it. It is the common implementation behind CreateNIC,
+// CreateNamedNIC and CreateDisabledNIC.
+func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
+	ep := FindLinkEndpoint(linkEP)
+	if ep == nil {
+		return tcpip.ErrBadLinkEndpoint
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Make sure id is unique.
+	if _, ok := s.nics[id]; ok {
+		return tcpip.ErrDuplicateNICID
+	}
+
+	n := newNIC(s, id, name, ep)
+
+	s.nics[id] = n
+	if enabled {
+		// Start delivering packets from the link endpoint to the NIC.
+		n.attachLinkEndpoint()
+	}
+
+	return nil
+}
+
+// CreateNIC creates a NIC with the provided id and link-layer endpoint.
+// The NIC is enabled immediately.
+func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+	return s.createNIC(id, "", linkEP, true)
+}
+
+// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
+// and a human-readable name. The NIC is enabled immediately.
+func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+	return s.createNIC(id, name, linkEP, true)
+}
+
+// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
+// but leaves it disabled. Stack.EnableNIC must be called before the link-layer
+// endpoint starts delivering packets to it.
+func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+	return s.createNIC(id, "", linkEP, false)
+}
+
+// EnableNIC enables the given NIC so that the link-layer endpoint can start
+// delivering packets to it. Returns ErrUnknownNICID if no NIC with that id
+// exists.
+func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nic := s.nics[id]
+	if nic == nil {
+		return tcpip.ErrUnknownNICID
+	}
+
+	nic.attachLinkEndpoint()
+
+	return nil
+}
+
+// NICSubnets returns a map of NICIDs to their associated subnets. The map is
+// freshly allocated on each call and safe for the caller to modify.
+func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nics := map[tcpip.NICID][]tcpip.Subnet{}
+
+	for id, nic := range s.nics {
+		nics[id] = append(nics[id], nic.Subnets()...)
+	}
+	return nics
+}
+
+// NICInfo captures the name and addresses assigned to a NIC.
+type NICInfo struct {
+	// Name is the human-readable name given at NIC creation, if any.
+	Name string
+	// LinkAddress is the link-layer address of the NIC's endpoint.
+	LinkAddress tcpip.LinkAddress
+	// ProtocolAddresses are the network-layer addresses bound to the NIC.
+	ProtocolAddresses []tcpip.ProtocolAddress
+}
+
+// NICInfo returns a map of NICIDs to their associated information. The map is
+// freshly allocated on each call and safe for the caller to modify.
+func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nics := make(map[tcpip.NICID]NICInfo)
+	for id, nic := range s.nics {
+		nics[id] = NICInfo{
+			Name: nic.name,
+			LinkAddress: nic.linkEP.LinkAddress(),
+			ProtocolAddresses: nic.Addresses(),
+		}
+	}
+	return nics
+}
+
+// AddAddress adds a new network-layer address to the specified NIC. Returns
+// ErrUnknownNICID if no NIC with that id exists.
+func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nic := s.nics[id]
+	if nic == nil {
+		return tcpip.ErrUnknownNICID
+	}
+
+	return nic.AddAddress(protocol, addr)
+}
+
+// AddSubnet adds a subnet range to the specified NIC. Returns
+// ErrUnknownNICID if no NIC with that id exists.
+func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nic := s.nics[id]
+	if nic == nil {
+		return tcpip.ErrUnknownNICID
+	}
+
+	nic.AddSubnet(protocol, subnet)
+	return nil
+}
+
+// RemoveAddress removes an existing network-layer address from the specified
+// NIC. Returns ErrUnknownNICID if no NIC with that id exists.
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nic := s.nics[id]
+	if nic == nil {
+		return tcpip.ErrUnknownNICID
+	}
+
+	return nic.RemoveAddress(addr)
+}
+
+// FindRoute creates a route to the given destination address, leaving through
+// the given nic and local address (if provided). An id of 0 means any NIC; an
+// empty localAddr means use the NIC's primary endpoint. The returned Route
+// owns an endpoint reference and must be Release()d by the caller.
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	// Scan the route table in order; the first entry that matches the NIC
+	// constraint and the remote address wins.
+	for i := range s.routeTable {
+		if (id != 0 && id != s.routeTable[i].NIC) || (len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) {
+			continue
+		}
+
+		nic := s.nics[s.routeTable[i].NIC]
+		if nic == nil {
+			continue
+		}
+
+		// Pick the endpoint: the one bound to localAddr if given,
+		// otherwise the NIC's primary endpoint for this protocol.
+		var ref *referencedNetworkEndpoint
+		if len(localAddr) != 0 {
+			ref = nic.findEndpoint(netProto, localAddr)
+		} else {
+			ref = nic.primaryEndpoint(netProto)
+		}
+		if ref == nil {
+			continue
+		}
+
+		if len(remoteAddr) == 0 {
+			// If no remote address was provided, then the route
+			// provided will refer to the link local address.
+			remoteAddr = ref.ep.ID().LocalAddress
+		}
+
+		// makeRoute takes ownership of ref.
+		r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, ref)
+		r.NextHop = s.routeTable[i].Gateway
+		return r, nil
+	}
+
+	return Route{}, tcpip.ErrNoRoute
+}
+
+// CheckNetworkProtocol checks if a given network protocol is enabled in the
+// stack (i.e. was configured in New()).
+func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool {
+	_, ok := s.networkProtocols[protocol]
+	return ok
+}
+
+// CheckLocalAddress determines if the given local address exists, and if it
+// does, returns the id of the NIC it's bound to. Returns 0 if the address
+// does not exist.
+func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	// If a NIC is specified, we try to find the address there only.
+	if nicid != 0 {
+		nic := s.nics[nicid]
+		if nic == nil {
+			return 0
+		}
+
+		ref := nic.findEndpoint(protocol, addr)
+		if ref == nil {
+			return 0
+		}
+
+		// findEndpoint returned with a reference held; we only needed
+		// the existence check, so drop it immediately.
+		ref.decRef()
+
+		return nic.id
+	}
+
+	// Go through all the NICs.
+	for _, nic := range s.nics {
+		ref := nic.findEndpoint(protocol, addr)
+		if ref != nil {
+			// Drop the reference acquired by findEndpoint.
+			ref.decRef()
+			return nic.id
+		}
+	}
+
+	return 0
+}
+
+// SetPromiscuousMode enables or disables promiscuous mode in the given NIC.
+// Returns ErrUnknownNICID if no NIC with that id exists.
+func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nic := s.nics[nicID]
+	if nic == nil {
+		return tcpip.ErrUnknownNICID
+	}
+
+	nic.setPromiscuousMode(enable)
+
+	return nil
+}
+
+// SetSpoofing enables or disables address spoofing in the given NIC, allowing
+// endpoints to bind to any address in the NIC. Returns ErrUnknownNICID if no
+// NIC with that id exists.
+func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nic := s.nics[nicID]
+	if nic == nil {
+		return tcpip.ErrUnknownNICID
+	}
+
+	nic.setSpoofing(enable)
+
+	return nil
+}
+
+// AddLinkAddress adds a link address to the stack link cache. Note that it
+// does not validate that nicid refers to an existing NIC.
+func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+	fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+	s.linkAddrCache.add(fullAddr, linkAddr)
+	// TODO: provide a way for a
+	// transport endpoint to receive a signal that AddLinkAddress
+	// for a particular address has been called.
+}
+
+// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
+//
+// If the address is not yet resolved, the waker (if non-nil) is registered
+// with the cache entry and notified when resolution completes.
+func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) {
+	s.mu.RLock()
+	nic := s.nics[nicid]
+	if nic == nil {
+		s.mu.RUnlock()
+		return "", tcpip.ErrUnknownNICID
+	}
+	// NOTE(review): nic is used after the lock is dropped; this assumes
+	// NICs are never removed from s.nics once created — confirm.
+	s.mu.RUnlock()
+
+	fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+	// linkRes may be nil if no resolver is registered for this protocol;
+	// the cache's get() is handed whatever the map lookup yields.
+	linkRes := s.linkAddrResolvers[protocol]
+	return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker)
+}
+
+// RemoveWaker implements LinkAddressCache.RemoveWaker.
+//
+// It removes a waker previously registered on the link address cache entry
+// for (nicid, addr) via GetLinkAddress.
+func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	// A waker can only have been registered for a known NIC, since
+	// GetLinkAddress returns ErrUnknownNICID before touching the cache
+	// when the NIC does not exist. The check must therefore be nic != nil;
+	// the previous (inverted) nic == nil check meant wakers were never
+	// removed in the normal case.
+	if nic := s.nics[nicid]; nic != nil {
+		fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+		s.linkAddrCache.removeWaker(fullAddr, waker)
+	}
+}
+
+// RegisterTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided id will be
+// delivered to the given endpoint; specifying a nic is optional, but
+// nic-specific IDs have precedence over global ones.
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+	// nicID == 0 means register globally (any NIC).
+	if nicID == 0 {
+		return s.demux.registerEndpoint(netProtos, protocol, id, ep)
+	}
+
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nic := s.nics[nicID]
+	if nic == nil {
+		return tcpip.ErrUnknownNICID
+	}
+
+	return nic.demux.registerEndpoint(netProtos, protocol, id, ep)
+}
+
+// UnregisterTransportEndpoint removes the endpoint with the given id from the
+// stack transport dispatcher. Unregistering from an unknown NIC is a silent
+// no-op.
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
+	// nicID == 0 means the endpoint was registered globally.
+	if nicID == 0 {
+		s.demux.unregisterEndpoint(netProtos, protocol, id)
+		return
+	}
+
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nic := s.nics[nicID]
+	if nic != nil {
+		nic.demux.unregisterEndpoint(netProtos, protocol, id)
+	}
+}
+
+// NetworkProtocolInstance returns the protocol instance in the stack for the
+// specified network protocol. It returns nil if the protocol was not
+// configured. This method is public for protocol implementers and tests to
+// use.
+func (s *Stack) NetworkProtocolInstance(num tcpip.NetworkProtocolNumber) NetworkProtocol {
+	p, ok := s.networkProtocols[num]
+	if !ok {
+		return nil
+	}
+	return p
+}
+
+// TransportProtocolInstance returns the protocol instance in the stack for the
+// specified transport protocol. It returns nil if the protocol was not
+// configured. This method is public for protocol implementers and tests to
+// use.
+func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) TransportProtocol {
+	state, ok := s.transportProtocols[num]
+	if !ok {
+		return nil
+	}
+	return state.proto
+}
+
+// AddTCPProbe installs a probe function that will be invoked on every segment
+// received by a given TCP endpoint. The probe function is passed a copy of the
+// TCP endpoint state.
+//
+// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints
+// created prior to this call will not call the probe function.
+//
+// Further, installing two different probes back to back can result in some
+// endpoints calling the first one and some the second one. There is no
+// guarantee provided on which probe will be invoked. Ideally this should only
+// be called once per stack.
+func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
+	s.mu.Lock()
+	s.tcpProbeFunc = probe
+	s.mu.Unlock()
+}
+
+// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
+// otherwise.
+func (s *Stack) GetTCPProbe() TCPProbeFunc {
+	s.mu.Lock()
+	p := s.tcpProbeFunc
+	s.mu.Unlock()
+	return p
+}
+
+// RemoveTCPProbe removes an installed TCP probe.
+//
+// NOTE: This only ensures that endpoints created after this call do not
+// have a probe attached. Endpoints already created will continue to invoke
+// the TCP probe.
+func (s *Stack) RemoveTCPProbe() {
+	s.mu.Lock()
+	s.tcpProbeFunc = nil
+	s.mu.Unlock()
+}
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
new file mode 100644
index 000000000..030ae98d1
--- /dev/null
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -0,0 +1,9 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+// StackFromEnv is the global stack created in restore run.
+// FIXME: remove this variable once tcpip S/R is fully supported.
+var StackFromEnv *Stack
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
new file mode 100644
index 000000000..b416065d7
--- /dev/null
+++ b/pkg/tcpip/stack/stack_test.go
@@ -0,0 +1,760 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package stack_test contains tests for the stack. It is in its own package so
+// that the tests can also validate that all definitions needed to implement
+// transport and network protocols are properly exported by the stack package.
+package stack_test
+
+import (
+ "math"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+	// fakeNetNumber is the protocol number of the fake network protocol;
+	// MaxUint32 avoids colliding with any real protocol number.
+	fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32
+	// fakeNetHeaderLen is the length, in bytes, of the fake network header.
+	fakeNetHeaderLen = 12
+
+	// fakeControlProtocol is used for control packets that represent
+	// destination port unreachable.
+	fakeControlProtocol tcpip.TransportProtocolNumber = 2
+
+	// defaultMTU is the MTU, in bytes, used throughout the tests, except
+	// where another value is explicitly used. It is chosen to match the MTU
+	// of loopback interfaces on linux systems.
+	defaultMTU = 65536
+)
+
+// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
+// received packets; the counts of all endpoints are aggregated in the protocol
+// descriptor.
+//
+// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only
+// use the first three: destination address, source address, and transport
+// protocol. They're all one byte fields to simplify parsing.
+type fakeNetworkEndpoint struct {
+	nicid      tcpip.NICID               // NIC this endpoint is attached to.
+	id         stack.NetworkEndpointID   // local address of the endpoint.
+	proto      *fakeNetworkProtocol      // shared descriptor holding the counters.
+	dispatcher stack.TransportDispatcher // where inbound packets are delivered.
+	linkEP     stack.LinkEndpoint        // link endpoint used for output.
+}
+
+// MTU returns the link MTU minus the fake network header overhead.
+func (f *fakeNetworkEndpoint) MTU() uint32 {
+	return f.linkEP.MTU() - uint32(f.MaxHeaderLength())
+}
+
+// NICID returns the ID of the NIC this endpoint is attached to.
+func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
+	return f.nicid
+}
+
+// ID returns the endpoint's network endpoint ID (its local address).
+func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
+	return &f.id
+}
+
+// HandlePacket consumes the fake network header of an inbound packet and
+// either delivers a port-unreachable control notification (for control
+// packets) or dispatches the payload to the transport layer.
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
+	// Increment the received packet count in the protocol descriptor.
+	f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++
+
+	// Consume the network header. b is captured before trimming so the
+	// header bytes remain readable below.
+	b := vv.First()
+	vv.TrimFront(fakeNetHeaderLen)
+
+	// Handle control packets. A control packet carries the original
+	// packet's network header as the start of its payload.
+	if b[2] == uint8(fakeControlProtocol) {
+		nb := vv.First()
+		if len(nb) < fakeNetHeaderLen {
+			// Embedded header is truncated; drop the packet.
+			return
+		}
+
+		vv.TrimFront(fakeNetHeaderLen)
+		f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv)
+		return
+	}
+
+	// Dispatch the packet to the transport protocol.
+	f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv)
+}
+
+// MaxHeaderLength returns the header space needed by this endpoint plus its
+// underlying link endpoint.
+func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
+	return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen
+}
+
+// PseudoHeaderChecksum is a no-op for the fake protocol.
+func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+	return 0
+}
+
+// Capabilities forwards the capabilities of the underlying link endpoint.
+func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+	return f.linkEP.Capabilities()
+}
+
+// WritePacket prepends the fake network header (destination, source,
+// transport protocol — one byte each) and hands the packet to the link
+// endpoint.
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+	// Increment the sent packet count in the protocol descriptor.
+	f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++
+
+	// Add the protocol's header to the packet and send it to the link
+	// endpoint.
+	b := hdr.Prepend(fakeNetHeaderLen)
+	b[0] = r.RemoteAddress[0]
+	b[1] = f.id.LocalAddress[0]
+	b[2] = byte(protocol)
+	return f.linkEP.WritePacket(r, hdr, payload, fakeNetNumber)
+}
+
+// Close is a no-op; the fake endpoint holds no resources.
+func (*fakeNetworkEndpoint) Close() {}
+
+// fakeNetGoodOption is a protocol option the fake protocol accepts; used to
+// exercise the option get/set plumbing.
+type fakeNetGoodOption bool
+
+// fakeNetBadOption is an option type the fake protocol does not recognize.
+type fakeNetBadOption bool
+
+// fakeNetInvalidValueOption is an option whose value is always rejected.
+type fakeNetInvalidValueOption int
+
+// fakeNetOptions holds the current option state of a fakeNetworkProtocol.
+type fakeNetOptions struct {
+	good bool
+}
+
+// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
+// number of packets sent and received via endpoints of this protocol. The index
+// where packets are added is given by the packet's destination address MOD 10.
+type fakeNetworkProtocol struct {
+	packetCount     [10]int // received packets, indexed by destination address.
+	sendPacketCount [10]int // sent packets, indexed by remote address.
+	opts            fakeNetOptions
+}
+
+// Number returns the fake network protocol number.
+func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+	return fakeNetNumber
+}
+
+// MinimumPacketSize returns the smallest valid fake packet: just the header.
+func (f *fakeNetworkProtocol) MinimumPacketSize() int {
+	return fakeNetHeaderLen
+}
+
+// ParseAddresses extracts the source (byte 1) and destination (byte 0)
+// addresses from a fake network header.
+func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+	return tcpip.Address(v[1:2]), tcpip.Address(v[0:1])
+}
+
+// NewEndpoint creates a fake network endpoint attached to the given NIC,
+// address, dispatcher and link endpoint.
+func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+	return &fakeNetworkEndpoint{
+		nicid:      nicid,
+		id:         stack.NetworkEndpointID{addr},
+		proto:      f,
+		dispatcher: dispatcher,
+		linkEP:     linkEP,
+	}, nil
+}
+
+// SetOption accepts fakeNetGoodOption, rejects fakeNetInvalidValueOption
+// values, and reports every other option type as unknown.
+func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+	switch v := option.(type) {
+	case fakeNetGoodOption:
+		f.opts.good = bool(v)
+		return nil
+	case fakeNetInvalidValueOption:
+		return tcpip.ErrInvalidOptionValue
+	default:
+		return tcpip.ErrUnknownProtocolOption
+	}
+}
+
+// Option writes the current value of a supported option through the pointer
+// the caller passed in; unsupported option types are reported as unknown.
+func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
+	switch v := option.(type) {
+	case *fakeNetGoodOption:
+		*v = fakeNetGoodOption(f.opts.good)
+		return nil
+	default:
+		return tcpip.ErrUnknownProtocolOption
+	}
+}
+
+// TestNetworkReceive verifies that injected packets are delivered to the
+// endpoint matching their destination address, and that packets with an
+// unknown address, wrong protocol number, or truncated header are dropped.
+func TestNetworkReceive(t *testing.T) {
+	// Create a stack with the fake network protocol, one nic, and two
+	// addresses attached to it: 1 & 2.
+	id, linkEP := channel.New(10, defaultMTU, "")
+	s := stack.New([]string{"fakeNet"}, nil)
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+	var views [1]buffer.View
+	// Allocate the buffer containing the packet that will be injected into
+	// the stack. Byte 0 is the destination address.
+	buf := buffer.NewView(30)
+
+	// Make sure packet with wrong address is not delivered.
+	buf[0] = 3
+	vv := buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 0 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
+	}
+	if fakeNet.packetCount[2] != 0 {
+		t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
+	}
+
+	// Make sure packet is delivered to first endpoint.
+	buf[0] = 1
+	vv = buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+	}
+	if fakeNet.packetCount[2] != 0 {
+		t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0)
+	}
+
+	// Make sure packet is delivered to second endpoint.
+	buf[0] = 2
+	vv = buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+	}
+	if fakeNet.packetCount[2] != 1 {
+		t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
+	}
+
+	// Make sure packet is not delivered if protocol number is wrong.
+	vv = buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber-1, &vv)
+	if fakeNet.packetCount[1] != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+	}
+	if fakeNet.packetCount[2] != 1 {
+		t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
+	}
+
+	// Make sure packet that is too small is dropped.
+	buf.CapLength(2)
+	vv = buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+	}
+	if fakeNet.packetCount[2] != 1 {
+		t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1)
+	}
+}
+
+// sendTo finds a route to addr and writes a packet with an empty payload
+// through it, failing the test if no route exists or the write fails.
+func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) {
+	r, err := s.FindRoute(0, "", addr, fakeNetNumber)
+	if err != nil {
+		t.Fatalf("FindRoute failed: %v", err)
+	}
+	defer r.Release()
+
+	hdr := buffer.NewPrependable(int(r.MaxHeaderLength()))
+	err = r.WritePacket(&hdr, nil, fakeTransNumber)
+	if err != nil {
+		t.Errorf("WritePacket failed: %v", err)
+		return
+	}
+}
+
+// TestNetworkSend verifies that an outbound packet routed through the only
+// NIC reaches the link-layer endpoint.
+func TestNetworkSend(t *testing.T) {
+	// Create a stack with the fake network protocol, one nic, and one
+	// address: 1. The route table sends all packets through the only
+	// existing nic.
+	id, linkEP := channel.New(10, defaultMTU, "")
+	s := stack.New([]string{"fakeNet"}, nil)
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("NewNIC failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	// Make sure that the link-layer endpoint received the outbound packet.
+	sendTo(t, s, "\x03")
+	if c := linkEP.Drain(); c != 1 {
+		t.Errorf("packetCount = %d, want %d", c, 1)
+	}
+}
+
+// TestNetworkSendMultiRoute verifies that the route table steers outbound
+// packets to the correct NIC: odd destinations via NIC 1, even via NIC 2.
+func TestNetworkSendMultiRoute(t *testing.T) {
+	// Create a stack with the fake network protocol, two nics, and two
+	// addresses per nic, the first nic has odd address, the second one has
+	// even addresses.
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id1, linkEP1 := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id1); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	id2, linkEP2 := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(2, id2); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	// Set a route table that sends all packets with odd destination
+	// addresses through the first NIC, and all even destination address
+	// through the second one.
+	s.SetRouteTable([]tcpip.Route{
+		{"\x01", "\x01", "\x00", 1},
+		{"\x00", "\x01", "\x00", 2},
+	})
+
+	// Send a packet to an odd destination.
+	sendTo(t, s, "\x05")
+
+	if c := linkEP1.Drain(); c != 1 {
+		t.Errorf("packetCount = %d, want %d", c, 1)
+	}
+
+	// Send a packet to an even destination.
+	sendTo(t, s, "\x06")
+
+	if c := linkEP2.Drain(); c != 1 {
+		t.Errorf("packetCount = %d, want %d", c, 1)
+	}
+}
+
+// testRoute asserts that FindRoute succeeds for the given NIC, source and
+// destination, and that the route resolves to the expected local address.
+func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
+	r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber)
+	if err != nil {
+		t.Fatalf("FindRoute failed: %v", err)
+	}
+
+	defer r.Release()
+
+	if r.LocalAddress != expectedSrcAddr {
+		t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress)
+	}
+
+	if r.RemoteAddress != dstAddr {
+		t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress)
+	}
+}
+
+// testNoRoute asserts that FindRoute fails with ErrNoRoute for the given
+// NIC, source and destination combination.
+func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
+	_, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber)
+	if err != tcpip.ErrNoRoute {
+		t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err)
+	}
+}
+
+// TestRoutes verifies route selection across two NICs: routes to odd
+// destinations resolve through NIC 1's addresses, even destinations through
+// NIC 2's, and cross combinations yield no route.
+func TestRoutes(t *testing.T) {
+	// Create a stack with the fake network protocol, two nics, and two
+	// addresses per nic, the first nic has odd address, the second one has
+	// even addresses.
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id1, _ := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id1); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	id2, _ := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(2, id2); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	// Set a route table that sends all packets with odd destination
+	// addresses through the first NIC, and all even destination address
+	// through the second one.
+	s.SetRouteTable([]tcpip.Route{
+		{"\x01", "\x01", "\x00", 1},
+		{"\x00", "\x01", "\x00", 2},
+	})
+
+	// Test routes to odd address.
+	testRoute(t, s, 0, "", "\x05", "\x01")
+	testRoute(t, s, 0, "\x01", "\x05", "\x01")
+	testRoute(t, s, 1, "\x01", "\x05", "\x01")
+	testRoute(t, s, 0, "\x03", "\x05", "\x03")
+	testRoute(t, s, 1, "\x03", "\x05", "\x03")
+
+	// Test routes to even address.
+	testRoute(t, s, 0, "", "\x06", "\x02")
+	testRoute(t, s, 0, "\x02", "\x06", "\x02")
+	testRoute(t, s, 2, "\x02", "\x06", "\x02")
+	testRoute(t, s, 0, "\x04", "\x06", "\x04")
+	testRoute(t, s, 2, "\x04", "\x06", "\x04")
+
+	// Try to send to odd numbered address from even numbered ones, then
+	// vice-versa.
+	testNoRoute(t, s, 0, "\x02", "\x05")
+	testNoRoute(t, s, 2, "\x02", "\x05")
+	testNoRoute(t, s, 0, "\x04", "\x05")
+	testNoRoute(t, s, 2, "\x04", "\x05")
+
+	testNoRoute(t, s, 0, "\x01", "\x06")
+	testNoRoute(t, s, 1, "\x01", "\x06")
+	testNoRoute(t, s, 0, "\x03", "\x06")
+	testNoRoute(t, s, 1, "\x03", "\x06")
+}
+
+// TestAddressRemoval verifies that packets stop being delivered once their
+// destination address is removed, and that removing an address twice fails.
+func TestAddressRemoval(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id, linkEP := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	var views [1]buffer.View
+	buf := buffer.NewView(30)
+
+	fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+	// Write a packet, and check that it gets delivered.
+	fakeNet.packetCount[1] = 0
+	buf[0] = 1
+	vv := buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+	}
+
+	// Remove the address, then check that packet doesn't get delivered
+	// anymore.
+	if err := s.RemoveAddress(1, "\x01"); err != nil {
+		t.Fatalf("RemoveAddress failed: %v", err)
+	}
+
+	vv = buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+	}
+
+	// Check that removing the same address fails.
+	// NOTE(review): the failure message below says "failed" but this branch
+	// actually fires when the error is anything other than the expected
+	// tcpip.ErrBadLocalAddress.
+	if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+		t.Fatalf("RemoveAddress failed: %v", err)
+	}
+}
+
+// TestDelayedRemovalDueToRoute verifies that an address removed while a route
+// still references it keeps receiving packets until the route is released.
+func TestDelayedRemovalDueToRoute(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id, linkEP := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{
+		{"\x00", "\x00", "\x00", 1},
+	})
+
+	fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+	var views [1]buffer.View
+	buf := buffer.NewView(30)
+
+	// Write a packet, and check that it gets delivered.
+	fakeNet.packetCount[1] = 0
+	buf[0] = 1
+	vv := buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+	}
+
+	// Get a route, check that packet is still deliverable.
+	r, err := s.FindRoute(0, "", "\x02", fakeNetNumber)
+	if err != nil {
+		t.Fatalf("FindRoute failed: %v", err)
+	}
+
+	vv = buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 2 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2)
+	}
+
+	// Remove the address, then check that packet is still deliverable
+	// because the route is keeping the address alive.
+	if err := s.RemoveAddress(1, "\x01"); err != nil {
+		t.Fatalf("RemoveAddress failed: %v", err)
+	}
+
+	vv = buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 3 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
+	}
+
+	// Check that removing the same address fails.
+	if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+		t.Fatalf("RemoveAddress failed: %v", err)
+	}
+
+	// Release the route, then check that packet is not deliverable anymore.
+	r.Release()
+	vv = buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 3 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3)
+	}
+}
+
+// TestPromiscuousMode verifies that packets for an address not assigned to
+// the NIC are delivered only while promiscuous mode is enabled, and that
+// promiscuous mode does not make routes available.
+func TestPromiscuousMode(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id, linkEP := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{
+		{"\x00", "\x00", "\x00", 1},
+	})
+
+	fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+	var views [1]buffer.View
+	buf := buffer.NewView(30)
+
+	// Write a packet, and check that it doesn't get delivered as we don't
+	// have a matching endpoint.
+	fakeNet.packetCount[1] = 0
+	buf[0] = 1
+	vv := buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 0 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
+	}
+
+	// Set promiscuous mode, then check that packet is delivered.
+	if err := s.SetPromiscuousMode(1, true); err != nil {
+		t.Fatalf("SetPromiscuousMode failed: %v", err)
+	}
+
+	vv = buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+	}
+
+	// Check that we can't get a route as there is no local address.
+	_, err := s.FindRoute(0, "", "\x02", fakeNetNumber)
+	if err != tcpip.ErrNoRoute {
+		t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err)
+	}
+
+	// Set promiscuous mode to false, then check that packet can't be
+	// delivered anymore.
+	if err := s.SetPromiscuousMode(1, false); err != nil {
+		t.Fatalf("SetPromiscuousMode failed: %v", err)
+	}
+
+	vv = buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+	}
+}
+
+// TestAddressSpoofing verifies that FindRoute rejects a source address not
+// assigned to the NIC unless spoofing is enabled on that NIC.
+func TestAddressSpoofing(t *testing.T) {
+	srcAddr := tcpip.Address("\x01")
+	dstAddr := tcpip.Address("\x02")
+
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id, _ := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{
+		{"\x00", "\x00", "\x00", 1},
+	})
+
+	// With address spoofing disabled, FindRoute does not permit an address
+	// that was not added to the NIC to be used as the source.
+	r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
+	if err == nil {
+		t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+	}
+
+	// With address spoofing enabled, FindRoute permits any address to be used
+	// as the source.
+	if err := s.SetSpoofing(1, true); err != nil {
+		t.Fatalf("SetSpoofing failed: %v", err)
+	}
+	r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
+	if err != nil {
+		t.Fatalf("FindRoute failed: %v", err)
+	}
+	if r.LocalAddress != srcAddr {
+		t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
+	}
+	if r.RemoteAddress != dstAddr {
+		t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+	}
+}
+
+// TestSubnetAcceptsMatchingPacket sets a subnet on the NIC, then checks that
+// a packet whose destination falls inside the subnet is delivered.
+func TestSubnetAcceptsMatchingPacket(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id, linkEP := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{
+		{"\x00", "\x00", "\x00", 1},
+	})
+
+	fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+	var views [1]buffer.View
+	buf := buffer.NewView(30)
+	buf[0] = 1
+	fakeNet.packetCount[1] = 0
+	// Destination 0x01 matches subnet 0x00/0xF0.
+	subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
+	if err != nil {
+		t.Fatalf("NewSubnet failed: %v", err)
+	}
+	if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
+		t.Fatalf("AddSubnet failed: %v", err)
+	}
+
+	vv := buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1)
+	}
+}
+
+// TestSubnetRejectsNonmatchingPacket sets a subnet on the NIC, then checks
+// that a packet whose destination falls outside it is not delivered.
+func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id, linkEP := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{
+		{"\x00", "\x00", "\x00", 1},
+	})
+
+	fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+	var views [1]buffer.View
+	buf := buffer.NewView(30)
+	buf[0] = 1
+	fakeNet.packetCount[1] = 0
+	// Destination 0x01 does not match subnet 0x10/0xF0.
+	subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
+	if err != nil {
+		t.Fatalf("NewSubnet failed: %v", err)
+	}
+	if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil {
+		t.Fatalf("AddSubnet failed: %v", err)
+	}
+	vv := buf.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if fakeNet.packetCount[1] != 0 {
+		t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0)
+	}
+}
+
+// TestNetworkOptions exercises the network protocol option plumbing:
+// setting an option on an unknown protocol, setting supported, unknown and
+// invalid-valued options, and reading a set value back.
+func TestNetworkOptions(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, []string{})
+
+	// Try an unsupported network protocol.
+	if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
+		t.Fatalf("s.SetNetworkProtocolOption(99999, fakeNetGoodOption(false)) = %v, want = tcpip.ErrUnknownProtocol", err)
+	}
+
+	testCases := []struct {
+		option   interface{}
+		wantErr  *tcpip.Error
+		verifier func(t *testing.T, p stack.NetworkProtocol)
+	}{
+		{fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) {
+			t.Helper()
+			fakeNet := p.(*fakeNetworkProtocol)
+			if fakeNet.opts.good != true {
+				t.Fatalf("fakeNet.opts.good = false, want = true")
+			}
+			var v fakeNetGoodOption
+			if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil {
+				// Fix: err and v were previously passed in the wrong order,
+				// so %v printed the option value and %T the error's type.
+				t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", err, v)
+			}
+			if v != true {
+				t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v)
+			}
+		}},
+		{fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
+		{fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
+	}
+	for _, tc := range testCases {
+		if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr {
+			t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr)
+		}
+		if tc.verifier != nil {
+			tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber))
+		}
+	}
+}
+
+func init() {
+	// Register the fake network protocol so tests can instantiate it by
+	// name through stack.New([]string{"fakeNet"}, ...).
+	stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol {
+		return &fakeNetworkProtocol{}
+	})
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
new file mode 100644
index 000000000..3c0d7aa31
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -0,0 +1,166 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// protocolIDs identifies a (network protocol, transport protocol) pair; used
+// as the first-level demux key.
+type protocolIDs struct {
+	network   tcpip.NetworkProtocolNumber
+	transport tcpip.TransportProtocolNumber
+}
+
+// transportEndpoints manages all endpoints of a given protocol. It has its own
+// mutex so as to reduce interference between protocols.
+type transportEndpoints struct {
+	mu        sync.RWMutex
+	endpoints map[TransportEndpointID]TransportEndpoint // guarded by mu.
+}
+
+// transportDemuxer demultiplexes packets targeted at a transport endpoint
+// (i.e., after they've been parsed by the network layer). It does two levels
+// of demultiplexing: first based on the network and transport protocols, then
+// based on endpoints IDs.
+type transportDemuxer struct {
+	// protocol is populated once at construction and read-only thereafter.
+	protocol map[protocolIDs]*transportEndpoints
+}
+
+// newTransportDemuxer creates a demuxer with one endpoint table per
+// (network, transport) protocol pair registered with the stack.
+func newTransportDemuxer(stack *Stack) *transportDemuxer {
+	d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)}
+
+	// Add each network and transport pair to the demuxer.
+	for netProto := range stack.networkProtocols {
+		for proto := range stack.transportProtocols {
+			d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)}
+		}
+	}
+
+	return d
+}
+
+// registerEndpoint registers the given endpoint with the dispatcher such that
+// packets that match the endpoint ID are delivered to it.
+//
+// Registration is all-or-nothing: on failure for any network protocol, the
+// registrations already made for the preceding protocols are rolled back.
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+	for i, n := range netProtos {
+		if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil {
+			d.unregisterEndpoint(netProtos[:i], protocol, id)
+			return err
+		}
+	}
+
+	return nil
+}
+
+// singleRegisterEndpoint registers ep for id under one network protocol.
+// Returns ErrPortInUse if another endpoint already holds the same ID.
+//
+// NOTE(review): a (network, transport) pair unknown to the demuxer silently
+// returns success without recording the endpoint.
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+	eps, ok := d.protocol[protocolIDs{netProto, protocol}]
+	if !ok {
+		return nil
+	}
+
+	eps.mu.Lock()
+	defer eps.mu.Unlock()
+
+	if _, ok := eps.endpoints[id]; ok {
+		return tcpip.ErrPortInUse
+	}
+
+	eps.endpoints[id] = ep
+
+	return nil
+}
+
+// unregisterEndpoint unregisters the endpoint with the given id such that it
+// won't receive any more packets. Unknown protocol pairs and IDs are ignored.
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
+	for _, n := range netProtos {
+		if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok {
+			eps.mu.Lock()
+			delete(eps.endpoints, id)
+			eps.mu.Unlock()
+		}
+	}
+}
+
+// deliverPacket attempts to deliver the given packet. Returns true if it found
+// an endpoint, false otherwise.
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, id TransportEndpointID) bool {
+	eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}]
+	if !ok {
+		return false
+	}
+
+	// Look up the endpoint under the read lock, but invoke it outside the
+	// lock.
+	eps.mu.RLock()
+	ep := d.findEndpointLocked(eps, vv, id)
+	eps.mu.RUnlock()
+
+	// Fail if we didn't find one.
+	if ep == nil {
+		return false
+	}
+
+	// Deliver the packet.
+	ep.HandlePacket(r, id, vv)
+
+	return true
+}
+
+// deliverControlPacket attempts to deliver the given control packet. Returns
+// true if it found an endpoint, false otherwise.
+func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView, id TransportEndpointID) bool {
+	eps, ok := d.protocol[protocolIDs{net, trans}]
+	if !ok {
+		return false
+	}
+
+	// Try to find the endpoint; invoke it outside the lock.
+	eps.mu.RLock()
+	ep := d.findEndpointLocked(eps, vv, id)
+	eps.mu.RUnlock()
+
+	// Fail if we didn't find one.
+	if ep == nil {
+		return false
+	}
+
+	// Deliver the packet.
+	ep.HandleControlPacket(id, typ, extra, vv)
+
+	return true
+}
+
+// findEndpointLocked looks up the endpoint for id, trying progressively less
+// specific matches: full 4-tuple, then without local address, then without
+// the remote address/port, and finally with only the local port.
+//
+// Precondition: eps.mu must be held (read or write).
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv *buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+	// Try to find a match with the id as provided.
+	if ep := eps.endpoints[id]; ep != nil {
+		return ep
+	}
+
+	// Try to find a match with the id minus the local address.
+	nid := id
+
+	nid.LocalAddress = ""
+	if ep := eps.endpoints[nid]; ep != nil {
+		return ep
+	}
+
+	// Try to find a match with the id minus the remote part.
+	nid.LocalAddress = id.LocalAddress
+	nid.RemoteAddress = ""
+	nid.RemotePort = 0
+	if ep := eps.endpoints[nid]; ep != nil {
+		return ep
+	}
+
+	// Try to find a match with only the local port.
+	nid.LocalAddress = ""
+	if ep := eps.endpoints[nid]; ep != nil {
+		return ep
+	}
+
+	return nil
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
new file mode 100644
index 000000000..7e072e96e
--- /dev/null
+++ b/pkg/tcpip/stack/transport_test.go
@@ -0,0 +1,420 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack_test
+
+import (
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ fakeTransNumber tcpip.TransportProtocolNumber = 1
+ fakeTransHeaderLen = 3
+)
+
+// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
+// received packets; the counts of all endpoints are aggregated in the protocol
+// descriptor.
+//
+// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't
+// use it.
+type fakeTransportEndpoint struct {
+	// id is the registration key; its RemoteAddress is filled in by Connect.
+	id stack.TransportEndpointID
+	stack *stack.Stack
+	netProto tcpip.NetworkProtocolNumber
+	// proto holds the shared packet/control counters.
+	proto *fakeTransportProtocol
+	peerAddr tcpip.Address
+	// route is established by Connect and released by Close.
+	route stack.Route
+}
+
+// newFakeTransportEndpoint creates a fake endpoint associated with the given
+// protocol descriptor and network protocol.
+func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
+	return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto}
+}
+
+// Close releases the route held by the endpoint.
+func (f *fakeTransportEndpoint) Close() {
+	f.route.Release()
+}
+
+// Readiness implements tcpip.Endpoint.Readiness; the fake endpoint is always
+// ready, so the requested mask is returned unchanged.
+func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+	return mask
+}
+
+// Read implements tcpip.Endpoint.Read. No data is ever queued, so it always
+// returns an empty view.
+func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, *tcpip.Error) {
+	return buffer.View{}, nil
+}
+
+// Write implements tcpip.Endpoint.Write. It sends the entire payload over the
+// route established by Connect, reserving room for lower-layer headers.
+func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+	// A route is only set up by a successful Connect.
+	if len(f.route.RemoteAddress) == 0 {
+		return 0, tcpip.ErrNoRoute
+	}
+
+	// Reserve space for the headers of the layers below.
+	hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
+	v, err := p.Get(p.Size())
+	if err != nil {
+		return 0, err
+	}
+	if err := f.route.WritePacket(&hdr, v, fakeTransNumber); err != nil {
+		return 0, err
+	}
+
+	return uintptr(len(v)), nil
+}
+
+func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, *tcpip.Error) {
+ return 0, nil
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
+ return tcpip.ErrInvalidEndpointState
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+ }
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Connect implements tcpip.Endpoint.Connect. It resolves a route to the peer
+// and registers the endpoint with the stack so it starts receiving packets.
+func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+	f.peerAddr = addr.Addr
+
+	// Find the route.
+	r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber)
+	if err != nil {
+		return tcpip.ErrNoRoute
+	}
+	defer r.Release()
+
+	// Try to register so that we can start receiving packets.
+	f.id.RemoteAddress = addr.Addr
+	err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f)
+	if err != nil {
+		return err
+	}
+
+	// Keep our own reference to the route; the deferred Release above only
+	// drops the reference taken by FindRoute.
+	f.route = r.Clone()
+
+	return nil
+}
+
+func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
+ return nil
+}
+
+func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error {
+ return nil
+}
+
+func (*fakeTransportEndpoint) Reset() {
+}
+
+func (*fakeTransportEndpoint) Listen(int) *tcpip.Error {
+ return nil
+}
+
+func (*fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, nil
+}
+
+func (*fakeTransportEndpoint) Bind(_ tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+ return commit()
+}
+
+func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, nil
+}
+
+func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, nil
+}
+
+// HandlePacket is called by the stack when a packet arrives for this endpoint.
+func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) {
+	// Increment the number of received packets.
+	f.proto.packetCount++
+}
+
+// HandleControlPacket is called by the stack when a control packet arrives for
+// this endpoint.
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *buffer.VectorisedView) {
+	// Increment the number of received control packets.
+	f.proto.controlCount++
+}
+
+type fakeTransportGoodOption bool
+
+type fakeTransportBadOption bool
+
+type fakeTransportInvalidValueOption int
+
+type fakeTransportProtocolOptions struct {
+ good bool
+}
+
+// fakeTransportProtocol is a transport-layer protocol descriptor. It
+// aggregates the number of packets received via endpoints of this protocol.
+type fakeTransportProtocol struct {
+ packetCount int
+ controlCount int
+ opts fakeTransportProtocolOptions
+}
+
+func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
+ return fakeTransNumber
+}
+
+func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+ return newFakeTransportEndpoint(stack, f, netProto), nil
+}
+
+func (*fakeTransportProtocol) MinimumPacketSize() int {
+ return fakeTransHeaderLen
+}
+
+func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) {
+ return 0, 0, nil
+}
+
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
+ return true
+}
+
+// SetOption implements stack.TransportProtocol.SetOption. Only
+// fakeTransportGoodOption is accepted; fakeTransportInvalidValueOption always
+// reports an invalid value, and any other type is an unknown option.
+func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error {
+	switch v := option.(type) {
+	case fakeTransportGoodOption:
+		f.opts.good = bool(v)
+		return nil
+	case fakeTransportInvalidValueOption:
+		return tcpip.ErrInvalidOptionValue
+	default:
+		return tcpip.ErrUnknownProtocolOption
+	}
+}
+
+// Option implements stack.TransportProtocol.Option. Only
+// *fakeTransportGoodOption may be queried.
+func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
+	switch v := option.(type) {
+	case *fakeTransportGoodOption:
+		*v = fakeTransportGoodOption(f.opts.good)
+		return nil
+	default:
+		return tcpip.ErrUnknownProtocolOption
+	}
+}
+
+func TestTransportReceive(t *testing.T) {
+ id, linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ // Create endpoint and connect to remote address.
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
+ t.Fatalf("Connect failed: %v", err)
+ }
+
+ fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
+
+ var views [1]buffer.View
+ // Create buffer that will hold the packet.
+ buf := buffer.NewView(30)
+
+ // Make sure packet with wrong protocol is not delivered.
+ buf[0] = 1
+ buf[2] = 0
+ vv := buf.ToVectorisedView(views)
+ linkEP.Inject(fakeNetNumber, &vv)
+ if fakeTrans.packetCount != 0 {
+ t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
+ }
+
+ // Make sure packet from the wrong source is not delivered.
+ buf[0] = 1
+ buf[1] = 3
+ buf[2] = byte(fakeTransNumber)
+ vv = buf.ToVectorisedView(views)
+ linkEP.Inject(fakeNetNumber, &vv)
+ if fakeTrans.packetCount != 0 {
+ t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0)
+ }
+
+ // Make sure packet is delivered.
+ buf[0] = 1
+ buf[1] = 2
+ buf[2] = byte(fakeTransNumber)
+ vv = buf.ToVectorisedView(views)
+ linkEP.Inject(fakeNetNumber, &vv)
+ if fakeTrans.packetCount != 1 {
+ t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1)
+ }
+}
+
+func TestTransportControlReceive(t *testing.T) {
+ id, linkEP := channel.New(10, defaultMTU, "")
+ s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ // Create endpoint and connect to remote address.
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
+ t.Fatalf("Connect failed: %v", err)
+ }
+
+ fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
+
+ var views [1]buffer.View
+ // Create buffer that will hold the control packet.
+ buf := buffer.NewView(2*fakeNetHeaderLen + 30)
+
+ // Outer packet contains the control protocol number.
+ buf[0] = 1
+ buf[1] = 0xfe
+ buf[2] = uint8(fakeControlProtocol)
+
+ // Make sure packet with wrong protocol is not delivered.
+ buf[fakeNetHeaderLen+0] = 0
+ buf[fakeNetHeaderLen+1] = 1
+ buf[fakeNetHeaderLen+2] = 0
+ vv := buf.ToVectorisedView(views)
+ linkEP.Inject(fakeNetNumber, &vv)
+ if fakeTrans.controlCount != 0 {
+ t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
+ }
+
+ // Make sure packet from the wrong source is not delivered.
+ buf[fakeNetHeaderLen+0] = 3
+ buf[fakeNetHeaderLen+1] = 1
+ buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
+ vv = buf.ToVectorisedView(views)
+ linkEP.Inject(fakeNetNumber, &vv)
+ if fakeTrans.controlCount != 0 {
+ t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0)
+ }
+
+ // Make sure packet is delivered.
+ buf[fakeNetHeaderLen+0] = 2
+ buf[fakeNetHeaderLen+1] = 1
+ buf[fakeNetHeaderLen+2] = byte(fakeTransNumber)
+ vv = buf.ToVectorisedView(views)
+ linkEP.Inject(fakeNetNumber, &vv)
+ if fakeTrans.controlCount != 1 {
+ t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1)
+ }
+}
+
+func TestTransportSend(t *testing.T) {
+ id, _ := channel.New(10, defaultMTU, "")
+ s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+ // Create endpoint and bind it.
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
+ t.Fatalf("Connect failed: %v", err)
+ }
+
+ // Create buffer that will hold the payload.
+ view := buffer.NewView(30)
+ _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+ if fakeNet.sendPacketCount[2] != 1 {
+ t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1)
+ }
+}
+
+// TestTransportOptions exercises SetTransportProtocolOption and
+// TransportProtocolOption for a supported option, an invalid value, and an
+// unknown option type.
+func TestTransportOptions(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+
+	// Try an unsupported transport protocol.
+	if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
+		t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
+	}
+
+	testCases := []struct {
+		option   interface{}
+		wantErr  *tcpip.Error
+		verifier func(t *testing.T, p stack.TransportProtocol)
+	}{
+		{fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) {
+			t.Helper()
+			fakeTrans := p.(*fakeTransportProtocol)
+			if fakeTrans.opts.good != true {
+				t.Fatalf("fakeTrans.opts.good = false, want = true")
+			}
+			var v fakeTransportGoodOption
+			// Bug fix: the error goes with %v and the option with %T,
+			// so pass (err, v), not (v, err).
+			if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
+				t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", err, v)
+			}
+			if v != true {
+				t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v)
+			}
+		}},
+		{fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
+		{fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
+	}
+	for _, tc := range testCases {
+		if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr {
+			t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr)
+		}
+		if tc.verifier != nil {
+			tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber))
+		}
+	}
+}
+
+func init() {
+ stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol {
+ return &fakeTransportProtocol{}
+ })
+}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
new file mode 100644
index 000000000..f3a94f353
--- /dev/null
+++ b/pkg/tcpip/tcpip.go
@@ -0,0 +1,499 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package tcpip provides the interfaces and related types that users of the
+// tcpip stack will use in order to create endpoints used to send and receive
+// data over the network stack.
+//
+// The starting point is the creation and configuration of a stack. A stack can
+// be created by calling the New() function of the tcpip/stack/stack package;
+// configuring a stack involves creating NICs (via calls to Stack.CreateNIC()),
+// adding network addresses (via calls to Stack.AddAddress()), and
+// setting a route table (via a call to Stack.SetRouteTable()).
+//
+// Once a stack is configured, endpoints can be created by calling
+// Stack.NewEndpoint(). Such endpoints can be used to send/receive data, connect
+// to peers, listen for connections, accept connections, etc., depending on the
+// transport protocol selected.
+package tcpip
+
+import (
+ "errors"
+ "fmt"
+ "strconv"
+ "strings"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Error represents an error in the netstack error space. Using a special type
+// ensures that errors outside of this space are not accidentally introduced.
+type Error struct {
+ string
+}
+
+// String implements fmt.Stringer.String.
+func (e *Error) String() string {
+ return e.string
+}
+
+// Errors that can be returned by the network stack.
+var (
+ ErrUnknownProtocol = &Error{"unknown protocol"}
+ ErrUnknownNICID = &Error{"unknown nic id"}
+ ErrUnknownProtocolOption = &Error{"unknown option for protocol"}
+ ErrDuplicateNICID = &Error{"duplicate nic id"}
+ ErrDuplicateAddress = &Error{"duplicate address"}
+ ErrNoRoute = &Error{"no route"}
+ ErrBadLinkEndpoint = &Error{"bad link layer endpoint"}
+ ErrAlreadyBound = &Error{"endpoint already bound"}
+ ErrInvalidEndpointState = &Error{"endpoint is in invalid state"}
+ ErrAlreadyConnecting = &Error{"endpoint is already connecting"}
+ ErrAlreadyConnected = &Error{"endpoint is already connected"}
+ ErrNoPortAvailable = &Error{"no ports are available"}
+ ErrPortInUse = &Error{"port is in use"}
+ ErrBadLocalAddress = &Error{"bad local address"}
+ ErrClosedForSend = &Error{"endpoint is closed for send"}
+ ErrClosedForReceive = &Error{"endpoint is closed for receive"}
+ ErrWouldBlock = &Error{"operation would block"}
+ ErrConnectionRefused = &Error{"connection was refused"}
+ ErrTimeout = &Error{"operation timed out"}
+ ErrAborted = &Error{"operation aborted"}
+ ErrConnectStarted = &Error{"connection attempt started"}
+ ErrDestinationRequired = &Error{"destination address is required"}
+ ErrNotSupported = &Error{"operation not supported"}
+ ErrQueueSizeNotSupported = &Error{"queue size querying not supported"}
+ ErrNotConnected = &Error{"endpoint not connected"}
+ ErrConnectionReset = &Error{"connection reset by peer"}
+ ErrConnectionAborted = &Error{"connection aborted"}
+ ErrNoSuchFile = &Error{"no such file"}
+ ErrInvalidOptionValue = &Error{"invalid option value specified"}
+ ErrNoLinkAddress = &Error{"no remote link address"}
+ ErrBadAddress = &Error{"bad address"}
+)
+
+// Errors related to Subnet
+var (
+ errSubnetLengthMismatch = errors.New("subnet length of address and mask differ")
+ errSubnetAddressMasked = errors.New("subnet address has bits set outside the mask")
+)
+
+// Address is a byte slice cast as a string that represents the address of a
+// network node. Or, in the case of unix endpoints, it may represent a path.
+type Address string
+
+// AddressMask is a bitmask for an address.
+type AddressMask string
+
+// Subnet is a subnet defined by its address and mask.
+type Subnet struct {
+ address Address
+ mask AddressMask
+}
+
+// NewSubnet creates a new Subnet, checking that the address and mask are the same length.
+func NewSubnet(a Address, m AddressMask) (Subnet, error) {
+	if len(a) != len(m) {
+		return Subnet{}, errSubnetLengthMismatch
+	}
+	for i := 0; i < len(a); i++ {
+		// Reject addresses that have bits set outside the mask.
+		if a[i]&^m[i] != 0 {
+			return Subnet{}, errSubnetAddressMasked
+		}
+	}
+	return Subnet{a, m}, nil
+}
+
+// Contains returns true iff the address is of the same length and matches the
+// subnet address and mask.
+func (s *Subnet) Contains(a Address) bool {
+	if len(a) != len(s.address) {
+		return false
+	}
+	for i := 0; i < len(a); i++ {
+		if a[i]&s.mask[i] != s.address[i] {
+			return false
+		}
+	}
+	return true
+}
+
+// ID returns the subnet ID.
+func (s *Subnet) ID() Address {
+	return s.address
+}
+
+// Bits returns the number of ones (network bits) and zeros (host bits) in the
+// subnet mask.
+func (s *Subnet) Bits() (ones int, zeros int) {
+	for _, b := range []byte(s.mask) {
+		for i := uint(0); i < 8; i++ {
+			if b&(1<<i) == 0 {
+				zeros++
+			} else {
+				ones++
+			}
+		}
+	}
+	return
+}
+
+// Prefix returns the number of bits before the first host bit.
+func (s *Subnet) Prefix() int {
+	// Scan each byte from its most significant bit for the first zero.
+	for i, b := range []byte(s.mask) {
+		for j := 7; j >= 0; j-- {
+			if b&(1<<uint(j)) == 0 {
+				return i*8 + 7 - j
+			}
+		}
+	}
+	// The mask is all ones.
+	return len(s.mask) * 8
+}
+
+// NICID is a number that uniquely identifies a NIC.
+type NICID int32
+
+// ShutdownFlags represents flags that can be passed to the Shutdown() method
+// of the Endpoint interface.
+type ShutdownFlags int
+
+// Values of the flags that can be passed to the Shutdown() method. They can
+// be OR'ed together.
+const (
+ ShutdownRead ShutdownFlags = 1 << iota
+ ShutdownWrite
+)
+
+// FullAddress represents a full transport node address, as required by the
+// Connect() and Bind() methods.
+type FullAddress struct {
+ // NIC is the ID of the NIC this address refers to.
+ //
+ // This may not be used by all endpoint types.
+ NIC NICID
+
+ // Addr is the network address.
+ Addr Address
+
+ // Port is the transport port.
+ //
+ // This may not be used by all endpoint types.
+ Port uint16
+}
+
+// Payload provides an interface around data that is being sent to an endpoint.
+// This allows the endpoint to request the amount of data it needs based on
+// internal buffers without exposing them. 'p.Get(p.Size())' reads all the data.
+type Payload interface {
+ // Get returns a slice containing exactly 'min(size, p.Size())' bytes.
+ Get(size int) ([]byte, *Error)
+
+ // Size returns the payload size.
+ Size() int
+}
+
+// SlicePayload implements Payload on top of slices for convenience.
+type SlicePayload []byte
+
+// Get implements Payload.
+func (s SlicePayload) Get(size int) ([]byte, *Error) {
+ if size > s.Size() {
+ size = s.Size()
+ }
+ return s[:size], nil
+}
+
+// Size implements Payload.
+func (s SlicePayload) Size() int {
+ return len(s)
+}
+
+// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
+// that exposes functionality like read, write, connect, etc. to users of the
+// networking stack.
+type Endpoint interface {
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it.
+ Close()
+
+ // Read reads data from the endpoint and optionally returns the sender.
+ // This method does not block if there is no data pending.
+ // It will also either return an error or data, never both.
+ Read(*FullAddress) (buffer.View, *Error)
+
+ // Write writes data to the endpoint's peer. This method does not block if
+ // the data cannot be written.
+ //
+ // Unlike io.Writer.Write, Endpoint.Write transfers ownership of any bytes
+ // successfully written to the Endpoint. That is, if a call to
+ // Write(SlicePayload{data}) returns (n, err), it may retain data[:n], and
+ // the caller should not use data[:n] after Write returns.
+ //
+ // Note that unlike io.Writer.Write, it is not an error for Write to
+ // perform a partial write.
+ Write(Payload, WriteOptions) (uintptr, *Error)
+
+ // Peek reads data without consuming it from the endpoint.
+ //
+ // This method does not block if there is no data pending.
+ Peek([][]byte) (uintptr, *Error)
+
+ // Connect connects the endpoint to its peer. Specifying a NIC is
+ // optional.
+ //
+ // There are three classes of return values:
+ // nil -- the attempt to connect succeeded.
+ // ErrConnectStarted/ErrAlreadyConnecting -- the connect attempt started
+ // but hasn't completed yet. In this case, the caller must call Connect
+ // or GetSockOpt(ErrorOption) when the endpoint becomes writable to
+ // get the actual result. The first call to Connect after the socket has
+ // connected returns nil. Calling connect again results in ErrAlreadyConnected.
+ // Anything else -- the attempt to connect failed.
+ Connect(address FullAddress) *Error
+
+ // Shutdown closes the read and/or write end of the endpoint connection
+ // to its peer.
+ Shutdown(flags ShutdownFlags) *Error
+
+ // Listen puts the endpoint in "listen" mode, which allows it to accept
+ // new connections.
+ Listen(backlog int) *Error
+
+ // Accept returns a new endpoint if a peer has established a connection
+ // to an endpoint previously set to listen mode. This method does not
+ // block if no new connections are available.
+ //
+ // The returned Queue is the wait queue for the newly created endpoint.
+ Accept() (Endpoint, *waiter.Queue, *Error)
+
+ // Bind binds the endpoint to a specific local address and port.
+ // Specifying a NIC is optional.
+ //
+ // An optional commit function will be executed atomically with respect
+ // to binding the endpoint. If this returns an error, the bind will not
+ // occur and the error will be propagated back to the caller.
+ Bind(address FullAddress, commit func() *Error) *Error
+
+ // GetLocalAddress returns the address to which the endpoint is bound.
+ GetLocalAddress() (FullAddress, *Error)
+
+ // GetRemoteAddress returns the address to which the endpoint is
+ // connected.
+ GetRemoteAddress() (FullAddress, *Error)
+
+ // Readiness returns the current readiness of the endpoint. For example,
+ // if waiter.EventIn is set, the endpoint is immediately readable.
+ Readiness(mask waiter.EventMask) waiter.EventMask
+
+ // SetSockOpt sets a socket option. opt should be one of the *Option types.
+ SetSockOpt(opt interface{}) *Error
+
+ // GetSockOpt gets a socket option. opt should be a pointer to one of the
+ // *Option types.
+ GetSockOpt(opt interface{}) *Error
+}
+
+// WriteOptions contains options for Endpoint.Write.
+type WriteOptions struct {
+ // If To is not nil, write to the given address instead of the endpoint's
+ // peer.
+ To *FullAddress
+
+ // More has the same semantics as Linux's MSG_MORE.
+ More bool
+
+ // EndOfRecord has the same semantics as Linux's MSG_EOR.
+ EndOfRecord bool
+}
+
+// ErrorOption is used in GetSockOpt to specify that the last error reported by
+// the endpoint should be cleared and returned.
+type ErrorOption struct{}
+
+// SendBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the send
+// buffer size option.
+type SendBufferSizeOption int
+
+// ReceiveBufferSizeOption is used by SetSockOpt/GetSockOpt to specify the
+// receive buffer size option.
+type ReceiveBufferSizeOption int
+
+// SendQueueSizeOption is used in GetSockOpt to specify that the number of
+// unread bytes in the output buffer should be returned.
+type SendQueueSizeOption int
+
+// ReceiveQueueSizeOption is used in GetSockOpt to specify that the number of
+// unread bytes in the input buffer should be returned.
+type ReceiveQueueSizeOption int
+
+// V6OnlyOption is used by SetSockOpt/GetSockOpt to specify whether an IPv6
+// socket is to be restricted to sending and receiving IPv6 packets only.
+type V6OnlyOption int
+
+// NoDelayOption is used by SetSockOpt/GetSockOpt to specify if data should be
+// sent out immediately by the transport protocol. For TCP, it determines if the
+// Nagle algorithm is on or off.
+type NoDelayOption int
+
+// ReuseAddressOption is used by SetSockOpt/GetSockOpt to specify whether Bind()
+// should allow reuse of local address.
+type ReuseAddressOption int
+
+// PasscredOption is used by SetSockOpt/GetSockOpt to specify whether
+// SCM_CREDENTIALS socket control messages are enabled.
+//
+// Only supported on Unix sockets.
+type PasscredOption int
+
+// TCPInfoOption is used by GetSockOpt to expose TCP statistics.
+//
+// TODO: Add and populate stat fields.
+type TCPInfoOption struct{}
+
+// Route is a row in the routing table. It specifies through which NIC (and
+// gateway) sets of packets should be routed. A row is considered viable if the
+// masked target address matches the destination address in the row.
+type Route struct {
+	// Destination is the address that must be matched against the masked
+	// target address to check if this row is viable.
+	Destination Address
+
+	// Mask specifies which bits of the Destination and the target address
+	// must match for this row to be viable.
+	Mask Address
+
+	// Gateway is the gateway to be used if this row is viable.
+	Gateway Address
+
+	// NIC is the id of the nic to be used if this row is viable.
+	NIC NICID
+}
+
+// Match determines if r is viable for the given destination address.
+func (r *Route) Match(addr Address) bool {
+	// Addresses of different lengths can never match.
+	if len(addr) != len(r.Destination) {
+		return false
+	}
+
+	for i := 0; i < len(r.Destination); i++ {
+		if (addr[i] & r.Mask[i]) != r.Destination[i] {
+			return false
+		}
+	}
+
+	return true
+}
+
+// LinkEndpointID represents a data link layer endpoint.
+type LinkEndpointID uint64
+
+// TransportProtocolNumber is the number of a transport protocol.
+type TransportProtocolNumber uint32
+
+// NetworkProtocolNumber is the number of a network protocol.
+type NetworkProtocolNumber uint32
+
+// Stats holds statistics about the networking stack.
+type Stats struct {
+	// UnknownProtocolRcvdPackets is the number of packets received by the
+	// stack that were for an unknown or unsupported protocol.
+	UnknownProtocolRcvdPackets uint64
+
+	// UnknownNetworkEndpointRcvdPackets is the number of packets received
+	// by the stack that were for a supported network protocol, but whose
+	// destination address didn't have a matching endpoint.
+	UnknownNetworkEndpointRcvdPackets uint64
+
+	// MalformedRcvdPackets is the number of packets received by the stack
+	// that were deemed malformed.
+	MalformedRcvdPackets uint64
+
+	// DroppedPackets is the number of packets dropped due to full queues.
+	DroppedPackets uint64
+}
+
+// String implements the fmt.Stringer interface.
+func (a Address) String() string {
+ switch len(a) {
+ case 4:
+ return fmt.Sprintf("%d.%d.%d.%d", int(a[0]), int(a[1]), int(a[2]), int(a[3]))
+ default:
+ return fmt.Sprintf("%x", []byte(a))
+ }
+}
+
+// To4 converts the IPv4 address to a 4-byte representation.
+// If the address is not an IPv4 address, To4 returns "".
+func (a Address) To4() Address {
+	const (
+		ipv4len = 4
+		ipv6len = 16
+	)
+	if len(a) == ipv4len {
+		return a
+	}
+	// Accept IPv4-mapped IPv6 addresses (::ffff:a.b.c.d) and return the
+	// embedded 4-byte IPv4 address.
+	if len(a) == ipv6len &&
+		isZeros(a[0:10]) &&
+		a[10] == 0xff &&
+		a[11] == 0xff {
+		return a[12:16]
+	}
+	return ""
+}
+
+// isZeros reports whether a is all zeros.
+func isZeros(a Address) bool {
+	for i := 0; i < len(a); i++ {
+		if a[i] != 0 {
+			return false
+		}
+	}
+	return true
+}
+
+// LinkAddress is a byte slice cast as a string that represents a link address.
+// It is typically a 6-byte MAC address.
+type LinkAddress string
+
+// String implements the fmt.Stringer interface.
+func (a LinkAddress) String() string {
+ switch len(a) {
+ case 6:
+ return fmt.Sprintf("%02x:%02x:%02x:%02x:%02x:%02x", a[0], a[1], a[2], a[3], a[4], a[5])
+ default:
+ return fmt.Sprintf("%x", []byte(a))
+ }
+}
+
+// ParseMACAddress parses an IEEE 802 address.
+//
+// It must be in the format aa:bb:cc:dd:ee:ff or aa-bb-cc-dd-ee-ff.
+func ParseMACAddress(s string) (LinkAddress, error) {
+	// Split on either ':' or '-' separators.
+	parts := strings.FieldsFunc(s, func(c rune) bool {
+		return c == ':' || c == '-'
+	})
+	if len(parts) != 6 {
+		return "", fmt.Errorf("inconsistent parts: %s", s)
+	}
+	addr := make([]byte, 0, len(parts))
+	for _, part := range parts {
+		// Each part must be exactly one byte of hex (bitSize 8).
+		u, err := strconv.ParseUint(part, 16, 8)
+		if err != nil {
+			return "", fmt.Errorf("invalid hex digits: %s", s)
+		}
+		addr = append(addr, byte(u))
+	}
+	return LinkAddress(addr), nil
+}
+
+// ProtocolAddress is an address and the network protocol it is associated
+// with.
+type ProtocolAddress struct {
+ // Protocol is the protocol of the address.
+ Protocol NetworkProtocolNumber
+
+ // Address is a network address.
+ Address Address
+}
diff --git a/pkg/tcpip/tcpip_test.go b/pkg/tcpip/tcpip_test.go
new file mode 100644
index 000000000..fd4d8346f
--- /dev/null
+++ b/pkg/tcpip/tcpip_test.go
@@ -0,0 +1,130 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcpip
+
+import (
+ "testing"
+)
+
+func TestSubnetContains(t *testing.T) {
+ tests := []struct {
+ s Address
+ m AddressMask
+ a Address
+ want bool
+ }{
+ {"\xa0", "\xf0", "\x90", false},
+ {"\xa0", "\xf0", "\xa0", true},
+ {"\xa0", "\xf0", "\xa5", true},
+ {"\xa0", "\xf0", "\xaf", true},
+ {"\xa0", "\xf0", "\xb0", false},
+ {"\xa0", "\xf0", "", false},
+ {"\xa0", "\xf0", "\xa0\x00", false},
+ {"\xc2\x80", "\xff\xf0", "\xc2\x80", true},
+ {"\xc2\x80", "\xff\xf0", "\xc2\x00", false},
+ {"\xc2\x00", "\xff\xf0", "\xc2\x00", true},
+ {"\xc2\x00", "\xff\xf0", "\xc2\x80", false},
+ }
+ for _, tt := range tests {
+ s, err := NewSubnet(tt.s, tt.m)
+ if err != nil {
+ t.Errorf("NewSubnet(%v, %v) = %v", tt.s, tt.m, err)
+ continue
+ }
+ if got := s.Contains(tt.a); got != tt.want {
+ t.Errorf("Subnet(%v).Contains(%v) = %v, want %v", s, tt.a, got, tt.want)
+ }
+ }
+}
+
+func TestSubnetBits(t *testing.T) {
+ tests := []struct {
+ a AddressMask
+ want1 int
+ want0 int
+ }{
+ {"\x00", 0, 8},
+ {"\x00\x00", 0, 16},
+ {"\x36", 4, 4},
+ {"\x5c", 4, 4},
+ {"\x5c\x5c", 8, 8},
+ {"\x5c\x36", 8, 8},
+ {"\x36\x5c", 8, 8},
+ {"\x36\x36", 8, 8},
+ {"\xff", 8, 0},
+ {"\xff\xff", 16, 0},
+ }
+ for _, tt := range tests {
+ s := &Subnet{mask: tt.a}
+ got1, got0 := s.Bits()
+ if got1 != tt.want1 || got0 != tt.want0 {
+ t.Errorf("Subnet{mask: %x}.Bits() = %d, %d, want %d, %d", tt.a, got1, got0, tt.want1, tt.want0)
+ }
+ }
+}
+
+// TestSubnetPrefix verifies that Subnet.Prefix returns the number of leading
+// one bits in the mask.
+func TestSubnetPrefix(t *testing.T) {
+	tests := []struct {
+		a    AddressMask
+		want int
+	}{
+		{"\x00", 0},
+		{"\x00\x00", 0},
+		{"\x36", 0},
+		{"\x86", 1},
+		{"\xc5", 2},
+		{"\xff\x00", 8},
+		{"\xff\x36", 8},
+		{"\xff\x8c", 9},
+		{"\xff\xc8", 10},
+		{"\xff", 8},
+		{"\xff\xff", 16},
+	}
+	for _, tt := range tests {
+		s := &Subnet{mask: tt.a}
+		got := s.Prefix()
+		if got != tt.want {
+			// Bug fix: the message previously named Bits() even though
+			// this test exercises Prefix().
+			t.Errorf("Subnet{mask: %x}.Prefix() = %d, want %d", tt.a, got, tt.want)
+		}
+	}
+}
+
+func TestSubnetCreation(t *testing.T) {
+ tests := []struct {
+ a Address
+ m AddressMask
+ want error
+ }{
+ {"\xa0", "\xf0", nil},
+ {"\xa0\xa0", "\xf0", errSubnetLengthMismatch},
+ {"\xaa", "\xf0", errSubnetAddressMasked},
+ {"", "", nil},
+ }
+ for _, tt := range tests {
+ if _, err := NewSubnet(tt.a, tt.m); err != tt.want {
+ t.Errorf("NewSubnet(%v, %v) = %v, want %v", tt.a, tt.m, err, tt.want)
+ }
+ }
+}
+
+func TestRouteMatch(t *testing.T) {
+ tests := []struct {
+ d Address
+ m Address
+ a Address
+ want bool
+ }{
+ {"\xc2\x80", "\xff\xf0", "\xc2\x80", true},
+ {"\xc2\x80", "\xff\xf0", "\xc2\x00", false},
+ {"\xc2\x00", "\xff\xf0", "\xc2\x00", true},
+ {"\xc2\x00", "\xff\xf0", "\xc2\x80", false},
+ }
+ for _, tt := range tests {
+ r := Route{Destination: tt.d, Mask: tt.m}
+ if got := r.Match(tt.a); got != tt.want {
+ t.Errorf("Route(%v).Match(%v) = %v, want %v", r, tt.a, got, tt.want)
+ }
+ }
+}
diff --git a/pkg/tcpip/transport/queue/BUILD b/pkg/tcpip/transport/queue/BUILD
new file mode 100644
index 000000000..162af574c
--- /dev/null
+++ b/pkg/tcpip/transport/queue/BUILD
@@ -0,0 +1,29 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+# Generates queue_state.go, the save/restore (S/R) support code for queue.go.
+go_stateify(
+    name = "queue_state",
+    srcs = [
+        "queue.go",
+    ],
+    out = "queue_state.go",
+    package = "queue",
+)
+
+# Library target for the buffer-queue package used by transport endpoints.
+go_library(
+    name = "queue",
+    srcs = [
+        "queue.go",
+        "queue_state.go",
+    ],
+    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/queue",
+    visibility = ["//:sandbox"],
+    deps = [
+        "//pkg/ilist",
+        "//pkg/state",
+        "//pkg/tcpip",
+        "//pkg/waiter",
+    ],
+)
diff --git a/pkg/tcpip/transport/queue/queue.go b/pkg/tcpip/transport/queue/queue.go
new file mode 100644
index 000000000..2d2918504
--- /dev/null
+++ b/pkg/tcpip/transport/queue/queue.go
@@ -0,0 +1,166 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package queue provides the implementation of buffer queue
+// and interface of queue entry with Length method.
+package queue
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/ilist"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Entry implements Linker interface and has both Length and Release methods.
+type Entry interface {
+	ilist.Linker
+	// Length returns the size of the entry in bytes; Queue uses it to
+	// account usage against the queue's limit.
+	Length() int64
+	// Release releases any resources held by the entry.
+	Release()
+	// Peek returns a non-destructive view of the entry; exact semantics
+	// are defined by implementers.
+	Peek() Entry
+}
+
+// Queue is a buffer queue.
+//
+// ReaderQueue and WriterQueue are the waiter queues that callers must notify
+// when Enqueue/Dequeue/Close/Reset report that a notification is due. The
+// remaining fields (closed flag, byte accounting and the entry list) are all
+// protected by mu.
+type Queue struct {
+	ReaderQueue *waiter.Queue
+	WriterQueue *waiter.Queue
+
+	mu       sync.Mutex `state:"nosave"`
+	closed   bool
+	used     int64
+	limit    int64
+	dataList ilist.List
+}
+
+// New allocates and initializes a new queue.
+//
+// limit is the maximum number of bytes (as measured by Entry.Length) the
+// queue may hold before Enqueue returns ErrWouldBlock.
+func New(ReaderQueue *waiter.Queue, WriterQueue *waiter.Queue, limit int64) *Queue {
+	return &Queue{ReaderQueue: ReaderQueue, WriterQueue: WriterQueue, limit: limit}
+}
+
+// Close closes q for reading and writing. It is immediately not writable and
+// will become unreadable when no more data is pending.
+//
+// Both the read and write queues must be notified after closing:
+// q.ReaderQueue.Notify(waiter.EventIn)
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *Queue) Close() {
+	q.mu.Lock()
+	q.closed = true
+	q.mu.Unlock()
+}
+
+// Reset empties the queue and Releases all of the Entries.
+//
+// Both the read and write queues must be notified after resetting:
+// q.ReaderQueue.Notify(waiter.EventIn)
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *Queue) Reset() {
+	q.mu.Lock()
+	// Release every queued entry before dropping the list wholesale.
+	for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
+		cur.(Entry).Release()
+	}
+	q.dataList.Reset()
+	q.used = 0
+	q.mu.Unlock()
+}
+
+// IsReadable determines if q is currently readable.
+//
+// A closed queue is always reported readable so that blocked readers wake up
+// and observe the closed state.
+func (q *Queue) IsReadable() bool {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	return q.closed || q.dataList.Front() != nil
+}
+
+// IsWritable determines if q is currently writable.
+//
+// A closed queue is always reported writable so that blocked writers wake up
+// and observe the closed state.
+func (q *Queue) IsWritable() bool {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	return q.closed || q.used < q.limit
+}
+
+// Enqueue adds an entry to the data queue if room is available.
+//
+// Returns ErrClosedForSend if the queue is closed, and ErrWouldBlock if the
+// queue is at or above its byte limit.
+//
+// If notify is true, ReaderQueue.Notify must be called:
+// q.ReaderQueue.Notify(waiter.EventIn)
+func (q *Queue) Enqueue(e Entry) (notify bool, err *tcpip.Error) {
+	q.mu.Lock()
+
+	if q.closed {
+		q.mu.Unlock()
+		return false, tcpip.ErrClosedForSend
+	}
+
+	if q.used >= q.limit {
+		q.mu.Unlock()
+		return false, tcpip.ErrWouldBlock
+	}
+
+	// Readers need a notification only on the empty -> non-empty
+	// transition.
+	notify = q.dataList.Front() == nil
+	q.used += e.Length()
+	q.dataList.PushBack(e)
+
+	q.mu.Unlock()
+
+	return notify, nil
+}
+
+// Dequeue removes the first entry in the data queue, if one exists.
+//
+// Returns ErrWouldBlock when the queue is empty, or ErrClosedForReceive when
+// it is both empty and closed.
+//
+// If notify is true, WriterQueue.Notify must be called:
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *Queue) Dequeue() (e Entry, notify bool, err *tcpip.Error) {
+	q.mu.Lock()
+
+	if q.dataList.Front() == nil {
+		err := tcpip.ErrWouldBlock
+		if q.closed {
+			err = tcpip.ErrClosedForReceive
+		}
+		q.mu.Unlock()
+
+		return nil, false, err
+	}
+
+	// Writers need a notification only if this dequeue transitions the
+	// queue from full (used >= limit) ...
+	notify = q.used >= q.limit
+
+	e = q.dataList.Front().(Entry)
+	q.dataList.Remove(e)
+	q.used -= e.Length()
+
+	// ... to non-full (used < limit).
+	notify = notify && q.used < q.limit
+
+	q.mu.Unlock()
+
+	return e, notify, nil
+}
+
+// Peek returns the first entry in the data queue, if one exists.
+//
+// The entry is not removed; the value returned is the entry's own Peek()
+// view. Error semantics match Dequeue.
+func (q *Queue) Peek() (Entry, *tcpip.Error) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	if q.dataList.Front() == nil {
+		err := tcpip.ErrWouldBlock
+		if q.closed {
+			err = tcpip.ErrClosedForReceive
+		}
+		return nil, err
+	}
+
+	return q.dataList.Front().(Entry).Peek(), nil
+}
+
+// QueuedSize returns the number of bytes currently in the queue, that is, the
+// number of readable bytes.
+//
+// NOTE(review): q.used is read here without holding q.mu, so the value may
+// be stale under concurrent Enqueue/Dequeue — confirm callers treat this as
+// advisory only.
+func (q *Queue) QueuedSize() int64 {
+	return q.used
+}
+
+// MaxQueueSize returns the maximum number of bytes storable in the queue.
+// q.limit is set once in New and never changed, so no locking is needed.
+func (q *Queue) MaxQueueSize() int64 {
+	return q.limit
+}
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
new file mode 100644
index 000000000..d0eb8b8bd
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -0,0 +1,97 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+# Generates tcp_state.go, the save/restore (S/R) support code for the listed
+# sources.
+go_stateify(
+    name = "tcp_state",
+    srcs = [
+        "endpoint.go",
+        "endpoint_state.go",
+        "rcv.go",
+        "segment_heap.go",
+        "snd.go",
+        "tcp_segment_list.go",
+    ],
+    out = "tcp_state.go",
+    package = "tcp",
+)
+
+# Instantiates the generic ilist template as a typed list of *segment.
+go_template_instance(
+    name = "tcp_segment_list",
+    out = "tcp_segment_list.go",
+    package = "tcp",
+    prefix = "segment",
+    template = "//pkg/ilist:generic_list",
+    types = {
+        "Linker": "*segment",
+    },
+)
+
+# The TCP transport protocol implementation.
+go_library(
+    name = "tcp",
+    srcs = [
+        "accept.go",
+        "connect.go",
+        "endpoint.go",
+        "endpoint_state.go",
+        "forwarder.go",
+        "protocol.go",
+        "rcv.go",
+        "sack.go",
+        "segment.go",
+        "segment_heap.go",
+        "segment_queue.go",
+        "snd.go",
+        "tcp_segment_list.go",
+        "tcp_state.go",
+        "timer.go",
+    ],
+    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp",
+    visibility = ["//visibility:public"],
+    deps = [
+        "//pkg/sleep",
+        "//pkg/state",
+        "//pkg/tcpip",
+        "//pkg/tcpip/buffer",
+        "//pkg/tcpip/header",
+        "//pkg/tcpip/seqnum",
+        "//pkg/tcpip/stack",
+        "//pkg/tmutex",
+        "//pkg/waiter",
+    ],
+)
+
+go_test(
+    name = "tcp_test",
+    size = "small",
+    srcs = [
+        "dual_stack_test.go",
+        "tcp_sack_test.go",
+        "tcp_test.go",
+        "tcp_timestamp_test.go",
+    ],
+    deps = [
+        ":tcp",
+        "//pkg/tcpip",
+        "//pkg/tcpip/buffer",
+        "//pkg/tcpip/checker",
+        "//pkg/tcpip/header",
+        "//pkg/tcpip/link/loopback",
+        "//pkg/tcpip/link/sniffer",
+        "//pkg/tcpip/network/ipv4",
+        "//pkg/tcpip/seqnum",
+        "//pkg/tcpip/stack",
+        "//pkg/tcpip/transport/tcp/testing/context",
+        "//pkg/waiter",
+    ],
+)
+
+# Exposes the generated sources to the autogen aggregation target.
+filegroup(
+    name = "autogen",
+    srcs = [
+        "tcp_segment_list.go",
+    ],
+    visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go
new file mode 100644
index 000000000..9a5b13066
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/accept.go
@@ -0,0 +1,407 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+import (
+ "crypto/rand"
+ "crypto/sha1"
+ "encoding/binary"
+ "hash"
+ "io"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+	// tsLen is the length, in bits, of the timestamp in the SYN cookie.
+	tsLen = 8
+
+	// tsMask is a mask for timestamp values (i.e., tsLen bits).
+	tsMask = (1 << tsLen) - 1
+
+	// tsOffset is the offset, in bits, of the timestamp in the SYN cookie.
+	tsOffset = 24
+
+	// hashMask is the mask for hash values (i.e., tsOffset bits).
+	hashMask = (1 << tsOffset) - 1
+
+	// maxTSDiff is the maximum allowed difference between a received cookie
+	// timestamp and the current timestamp. If the difference is greater
+	// than maxTSDiff, the cookie is expired.
+	maxTSDiff = 2
+)
+
+var (
+	// SynRcvdCountThreshold is the global maximum number of connections
+	// that are allowed to be in SYN-RCVD state before TCP starts using SYN
+	// cookies to accept connections.
+	//
+	// It is an exported variable only for testing, and should not otherwise
+	// be used by importers of this package.
+	SynRcvdCountThreshold uint64 = 1000
+
+	// mssTable is a slice containing the possible MSS values that we
+	// encode in the SYN cookie with two bits.
+	mssTable = []uint16{536, 1300, 1440, 1460}
+)
+
+// encodeMSS returns the index of the largest mssTable value that does not
+// exceed mss; this two-bit index is what gets encoded in the SYN cookie.
+func encodeMSS(mss uint16) uint32 {
+	for i := len(mssTable) - 1; i > 0; i-- {
+		if mss >= mssTable[i] {
+			return uint32(i)
+		}
+	}
+	return 0
+}
+
+// synRcvdCount is the number of endpoints in the SYN-RCVD state. The value is
+// protected by a mutex so that we can increment only when it's guaranteed not
+// to go above a threshold.
+var synRcvdCount struct {
+	sync.Mutex
+	value uint64
+}
+
+// listenContext is used by a listening endpoint to store state used while
+// listening for connections. This struct is allocated by the listen goroutine
+// and must not be accessed or have its methods called concurrently as they
+// may mutate the stored objects.
+//
+// nonce holds the two secret keys mixed into SYN-cookie hashes; hasher is
+// the shared SHA-1 instance guarded by hasherMu.
+type listenContext struct {
+	stack  *stack.Stack
+	rcvWnd seqnum.Size
+	nonce  [2][sha1.BlockSize]byte
+
+	hasherMu sync.Mutex
+	hasher   hash.Hash
+	v6only   bool
+	netProto tcpip.NetworkProtocolNumber
+}
+
+// timeStamp returns an 8-bit timestamp with a granularity of 64 seconds.
+// It is the coarse clock embedded in SYN cookies (tsLen bits).
+func timeStamp() uint32 {
+	return uint32(time.Now().Unix()>>6) & tsMask
+}
+
+// incSynRcvdCount tries to increment the global number of endpoints in SYN-RCVD
+// state. It succeeds if the increment doesn't make the count go beyond the
+// threshold, and fails otherwise.
+//
+// Called by the listen loop before spawning a handleSynSegment goroutine; a
+// false return means SYN cookies should be used instead.
+func incSynRcvdCount() bool {
+	synRcvdCount.Lock()
+	defer synRcvdCount.Unlock()
+
+	if synRcvdCount.value >= SynRcvdCountThreshold {
+		return false
+	}
+
+	synRcvdCount.value++
+
+	return true
+}
+
+// decSynRcvdCount atomically decrements the global number of endpoints in
+// SYN-RCVD state. It must only be called if a previous call to incSynRcvdCount
+// succeeded.
+func decSynRcvdCount() {
+	synRcvdCount.Lock()
+	defer synRcvdCount.Unlock()
+
+	synRcvdCount.value--
+}
+
+// newListenContext creates a new listen context.
+func newListenContext(stack *stack.Stack, rcvWnd seqnum.Size, v6only bool, netProto tcpip.NetworkProtocolNumber) *listenContext {
+	l := &listenContext{
+		stack:    stack,
+		rcvWnd:   rcvWnd,
+		hasher:   sha1.New(),
+		v6only:   v6only,
+		netProto: netProto,
+	}
+
+	// NOTE(review): errors from rand.Read are ignored; on failure the
+	// nonces would be (partially) zero, weakening the SYN-cookie hashes —
+	// consider checking the returned error.
+	rand.Read(l.nonce[0][:])
+	rand.Read(l.nonce[1][:])
+
+	return l
+}
+
+// cookieHash calculates the cookieHash for the given id, timestamp and nonce
+// index. The hash is used to create and validate cookies.
+func (l *listenContext) cookieHash(id stack.TransportEndpointID, ts uint32, nonceIndex int) uint32 {
+
+	// Initialize block with fixed-size data: the local and remote ports,
+	// and the timestamp.
+	var payload [8]byte
+	binary.BigEndian.PutUint16(payload[0:], id.LocalPort)
+	binary.BigEndian.PutUint16(payload[2:], id.RemotePort)
+	binary.BigEndian.PutUint32(payload[4:], ts)
+
+	// Feed everything to the hasher. The hasher is shared, so hold
+	// hasherMu for the whole Reset/Write/Sum sequence.
+	l.hasherMu.Lock()
+	l.hasher.Reset()
+	l.hasher.Write(payload[:])
+	l.hasher.Write(l.nonce[nonceIndex][:])
+	io.WriteString(l.hasher, string(id.LocalAddress))
+	io.WriteString(l.hasher, string(id.RemoteAddress))
+
+	// Finalize the calculation of the hash and return the first 4 bytes.
+	h := make([]byte, 0, sha1.Size)
+	h = l.hasher.Sum(h)
+	l.hasherMu.Unlock()
+
+	return binary.BigEndian.Uint32(h[:])
+}
+
+// createCookie creates a SYN cookie for the given id and incoming sequence
+// number.
+//
+// Cookie layout: the top tsLen bits carry the coarse timestamp; the low
+// tsOffset bits mix a keyed hash with the caller-supplied data (the encoded
+// MSS).
+func (l *listenContext) createCookie(id stack.TransportEndpointID, seq seqnum.Value, data uint32) seqnum.Value {
+	ts := timeStamp()
+	v := l.cookieHash(id, 0, 0) + uint32(seq) + (ts << tsOffset)
+	v += (l.cookieHash(id, ts, 1) + data) & hashMask
+	return seqnum.Value(v)
+}
+
+// isCookieValid checks if the supplied cookie is valid for the given id and
+// sequence number. If it is, it also returns the data originally encoded in the
+// cookie when createCookie was called.
+//
+// A cookie is rejected when its embedded timestamp is more than maxTSDiff
+// ticks (of 64 seconds each) in the past.
+func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnum.Value, seq seqnum.Value) (uint32, bool) {
+	ts := timeStamp()
+	v := uint32(cookie) - l.cookieHash(id, 0, 0) - uint32(seq)
+	cookieTS := v >> tsOffset
+	if ((ts - cookieTS) & tsMask) > maxTSDiff {
+		return 0, false
+	}
+
+	return (v - l.cookieHash(id, cookieTS, 1)) & hashMask, true
+}
+
+// createConnectedEndpoint creates a new connected endpoint, with the connection
+// parameters given by the arguments.
+//
+// The returned endpoint is registered with the stack and already in the
+// connected state, with its sender and receiver initialized. On registration
+// failure the endpoint is closed and the error returned.
+func (l *listenContext) createConnectedEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+	// Create a new endpoint.
+	netProto := l.netProto
+	if netProto == 0 {
+		netProto = s.route.NetProto
+	}
+	n := newEndpoint(l.stack, netProto, nil)
+	n.v6only = l.v6only
+	n.id = s.id
+	n.boundNICID = s.route.NICID()
+	n.route = s.route.Clone()
+	n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto}
+	n.rcvBufSize = int(l.rcvWnd)
+
+	n.maybeEnableTimestamp(rcvdSynOpts)
+	n.maybeEnableSACKPermitted(rcvdSynOpts)
+
+	// Register new endpoint so that packets are routed to it.
+	if err := n.stack.RegisterTransportEndpoint(n.boundNICID, n.effectiveNetProtos, ProtocolNumber, n.id, n); err != nil {
+		n.Close()
+		return nil, err
+	}
+
+	n.isRegistered = true
+	n.state = stateConnected
+
+	// Create sender and receiver.
+	//
+	// The receiver at least temporarily has a zero receive window scale,
+	// but the caller may change it (before starting the protocol loop).
+	n.snd = newSender(n, iss, irs, s.window, rcvdSynOpts.MSS, rcvdSynOpts.WS)
+	n.rcv = newReceiver(n, irs, l.rcvWnd, 0)
+
+	return n, nil
+}
+
+// createEndpointAndPerformHandshake creates a new endpoint in connected state
+// and then performs the TCP 3-way handshake.
+func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *header.TCPSynOptions) (*endpoint, *tcpip.Error) {
+	// Create new endpoint.
+	irs := s.sequenceNumber
+	cookie := l.createCookie(s.id, irs, encodeMSS(opts.MSS))
+	ep, err := l.createConnectedEndpoint(s, cookie, irs, opts)
+	if err != nil {
+		return nil, err
+	}
+
+	// Perform the 3-way handshake.
+	h, err := newHandshake(ep, l.rcvWnd)
+	if err != nil {
+		ep.Close()
+		return nil, err
+	}
+
+	h.resetToSynRcvd(cookie, irs, opts)
+	if err := h.execute(); err != nil {
+		ep.Close()
+		return nil, err
+	}
+
+	// Update the receive window scaling. We can't do it before the
+	// handshake because it's possible that the peer doesn't support window
+	// scaling.
+	ep.rcv.rcvWndScale = h.effectiveRcvWndScale()
+
+	return ep, nil
+}
+
+// deliverAccepted delivers the newly-accepted endpoint to the listener. If the
+// endpoint has transitioned out of the listen state, the new endpoint is closed
+// instead.
+//
+// NOTE(review): the send on e.acceptedChan happens while holding e.mu for
+// read — confirm the accepting side never needs e.mu while draining the
+// channel, otherwise this could block indefinitely.
+func (e *endpoint) deliverAccepted(n *endpoint) {
+	e.mu.RLock()
+	if e.state == stateListen {
+		e.acceptedChan <- n
+		e.waiterQueue.Notify(waiter.EventIn)
+	} else {
+		n.Close()
+	}
+	e.mu.RUnlock()
+}
+
+// handleSynSegment is called in its own goroutine once the listening endpoint
+// receives a SYN segment. It is responsible for completing the handshake and
+// queueing the new endpoint for acceptance.
+//
+// A limited number of these goroutines are allowed before TCP starts using SYN
+// cookies to accept connections.
+func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) {
+	// The caller incremented the SYN-RCVD count and the segment refcount;
+	// undo both when done.
+	defer decSynRcvdCount()
+	defer s.decRef()
+
+	n, err := ctx.createEndpointAndPerformHandshake(s, opts)
+	if err != nil {
+		return
+	}
+
+	e.deliverAccepted(n)
+}
+
+// handleListenSegment is called when a listening endpoint receives a segment
+// and needs to handle it.
+//
+// A pure SYN either spawns a handshake goroutine (below threshold) or is
+// answered with a SYN-cookie SYN-ACK; a pure ACK is checked against the
+// cookie to complete a cookie-based handshake.
+func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) {
+	switch s.flags {
+	case flagSyn:
+		opts := parseSynSegmentOptions(s)
+		if incSynRcvdCount() {
+			s.incRef()
+			go e.handleSynSegment(ctx, s, &opts) // S/R-FIXME
+		} else {
+			cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS))
+			// Send SYN-ACK without window scaling (WS: -1),
+			// because we currently don't encode this information
+			// in the cookie.
+			//
+			// Enable Timestamp option if the original syn did have
+			// the timestamp option specified.
+			synOpts := header.TCPSynOptions{
+				WS:    -1,
+				TS:    opts.TS,
+				TSVal: tcpTimeStamp(timeStampOffset()),
+				TSEcr: opts.TSVal,
+			}
+			sendSynTCP(&s.route, s.id, flagSyn|flagAck, cookie, s.sequenceNumber+1, ctx.rcvWnd, synOpts)
+		}
+
+	case flagAck:
+		if data, ok := ctx.isCookieValid(s.id, s.ackNumber-1, s.sequenceNumber-1); ok && int(data) < len(mssTable) {
+			// Create newly accepted endpoint and deliver it.
+			rcvdSynOptions := &header.TCPSynOptions{
+				MSS: mssTable[data],
+				// Disable Window scaling as original SYN is
+				// lost.
+				WS: -1,
+			}
+			// When syn cookies are in use we enable timestamp only
+			// if the ack specifies the timestamp option assuming
+			// that the other end did in fact negotiate the
+			// timestamp option in the original SYN.
+			if s.parsedOptions.TS {
+				rcvdSynOptions.TS = true
+				rcvdSynOptions.TSVal = s.parsedOptions.TSVal
+				rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr
+			}
+			n, err := ctx.createConnectedEndpoint(s, s.ackNumber-1, s.sequenceNumber-1, rcvdSynOptions)
+			if err == nil {
+				// clear the tsOffset for the newly created
+				// endpoint as the Timestamp was already
+				// randomly offset when the original SYN-ACK was
+				// sent above.
+				n.tsOffset = 0
+				e.deliverAccepted(n)
+			}
+		}
+	}
+}
+
+// protocolListenLoop is the main loop of a listening TCP endpoint. It runs in
+// its own goroutine and is responsible for handling connection requests.
+//
+// rcvWnd is the receive window advertised to peers for connections accepted
+// by this listener. The loop exits on a close notification or after a
+// save/restore drain.
+func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error {
+	defer func() {
+		// Mark endpoint as closed. This will prevent goroutines running
+		// handleSynSegment() from attempting to queue new connections
+		// to the endpoint.
+		e.mu.Lock()
+		e.state = stateClosed
+		e.mu.Unlock()
+
+		// Notify waiters that the endpoint is shutdown.
+		e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+
+		// Do cleanup if needed.
+		e.completeWorker()
+	}()
+
+	e.mu.Lock()
+	v6only := e.v6only
+	e.mu.Unlock()
+
+	ctx := newListenContext(e.stack, rcvWnd, v6only, e.netProto)
+
+	s := sleep.Sleeper{}
+	s.AddWaker(&e.notificationWaker, wakerForNotification)
+	s.AddWaker(&e.newSegmentWaker, wakerForNewSegment)
+	for {
+		switch index, _ := s.Fetch(true); index {
+		case wakerForNotification:
+			n := e.fetchNotifications()
+			if n&notifyClose != 0 {
+				return nil
+			}
+			// Drain the whole segment queue synchronously before
+			// acknowledging a save/restore drain request.
+			if n&notifyDrain != 0 {
+				for s := e.segmentQueue.dequeue(); s != nil; s = e.segmentQueue.dequeue() {
+					e.handleListenSegment(ctx, s)
+					s.decRef()
+				}
+				e.drainDone <- struct{}{}
+				return nil
+			}
+
+		case wakerForNewSegment:
+			// Process at most maxSegmentsPerWake segments.
+			mayRequeue := true
+			for i := 0; i < maxSegmentsPerWake; i++ {
+				s := e.segmentQueue.dequeue()
+				if s == nil {
+					mayRequeue = false
+					break
+				}
+
+				e.handleListenSegment(ctx, s)
+				s.decRef()
+			}
+
+			// If the queue is not empty, make sure we'll wake up
+			// in the next iteration.
+			if mayRequeue && !e.segmentQueue.empty() {
+				e.newSegmentWaker.Assert()
+			}
+		}
+	}
+}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
new file mode 100644
index 000000000..4d20f4d3f
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -0,0 +1,953 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+import (
+ "crypto/rand"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// maxSegmentsPerWake is the maximum number of segments to process in the main
+// protocol goroutine per wake-up. Yielding [after this number of segments are
+// processed] allows other events to be processed as well (e.g., timeouts,
+// resets, etc.).
+const maxSegmentsPerWake = 100
+
+type handshakeState int
+
+// The following are the possible states of the TCP connection during a 3-way
+// handshake. A depiction of the states and transitions can be found in RFC 793,
+// page 23.
+const (
+	handshakeSynSent handshakeState = iota
+	handshakeSynRcvd
+	handshakeCompleted
+)
+
+// The following are used to set up sleepers. They are waker indices shared
+// by the handshake and endpoint protocol loops.
+const (
+	wakerForNotification = iota
+	wakerForNewSegment
+	wakerForResend
+	wakerForResolution
+)
+
+const (
+	// Maximum space available for options.
+	maxOptionSize = 40
+)
+
+// handshake holds the state used during a TCP 3-way handshake.
+type handshake struct {
+	ep    *endpoint
+	state handshakeState
+	// active is true for the connecting (active open) side; it is cleared
+	// by resetToSynRcvd for passive opens.
+	active bool
+	// flags is the set of TCP flags to send in the next handshake segment.
+	flags uint8
+	// ackNum is the sequence number to be acknowledged next.
+	ackNum seqnum.Value
+
+	// iss is the initial send sequence number, as defined in RFC 793.
+	iss seqnum.Value
+
+	// rcvWnd is the receive window, as defined in RFC 793.
+	rcvWnd seqnum.Size
+
+	// sndWnd is the send window, as defined in RFC 793.
+	sndWnd seqnum.Size
+
+	// mss is the maximum segment size received from the peer.
+	mss uint16
+
+	// sndWndScale is the send window scale, as defined in RFC 1323. A
+	// negative value means no scaling is supported by the peer.
+	sndWndScale int
+
+	// rcvWndScale is the receive window scale, as defined in RFC 1323.
+	rcvWndScale int
+}
+
+// newHandshake returns a handshake for an active open on ep, initialized
+// with a fresh random ISS and a receive window scale derived from rcvWnd.
+func newHandshake(ep *endpoint, rcvWnd seqnum.Size) (handshake, *tcpip.Error) {
+	h := handshake{
+		ep:          ep,
+		active:      true,
+		rcvWnd:      rcvWnd,
+		rcvWndScale: FindWndScale(rcvWnd),
+	}
+	if err := h.resetState(); err != nil {
+		return handshake{}, err
+	}
+
+	return h, nil
+}
+
+// FindWndScale determines the window scale to use for the given maximum window
+// size.
+//
+// It returns the smallest scale s (capped at header.MaxWndScale) such that
+// wnd fits in 16 bits when shifted right by s; windows below 64KiB need no
+// scaling.
+func FindWndScale(wnd seqnum.Size) int {
+	if wnd < 0x10000 {
+		return 0
+	}
+
+	max := seqnum.Size(0xffff)
+	s := 0
+	for wnd > max && s < header.MaxWndScale {
+		s++
+		max <<= 1
+	}
+
+	return s
+}
+
+// resetState resets the state of the handshake object such that it becomes
+// ready for a new 3-way handshake.
+//
+// A new random ISS is drawn from crypto/rand; failure to read randomness is
+// treated as fatal (panic).
+func (h *handshake) resetState() *tcpip.Error {
+	b := make([]byte, 4)
+	if _, err := rand.Read(b); err != nil {
+		panic(err)
+	}
+
+	h.state = handshakeSynSent
+	h.flags = flagSyn
+	h.ackNum = 0
+	h.mss = 0
+	h.iss = seqnum.Value(uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24)
+
+	return nil
+}
+
+// effectiveRcvWndScale returns the effective receive window scale to be used.
+// If the peer doesn't support window scaling, the effective rcv wnd scale is
+// zero; otherwise it's the value calculated based on the initial rcv wnd.
+func (h *handshake) effectiveRcvWndScale() uint8 {
+	if h.sndWndScale < 0 {
+		return 0
+	}
+	return uint8(h.rcvWndScale)
+}
+
+// resetToSynRcvd resets the state of the handshake object to the SYN-RCVD
+// state.
+//
+// This configures a passive open: active is cleared, the next segment will
+// carry SYN|ACK, and the peer's SYN parameters (irs, MSS, window scale) are
+// recorded.
+func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *header.TCPSynOptions) {
+	h.active = false
+	h.state = handshakeSynRcvd
+	h.flags = flagSyn | flagAck
+	h.iss = iss
+	h.ackNum = irs + 1
+	h.mss = opts.MSS
+	h.sndWndScale = opts.WS
+}
+
+// checkAck checks if the ACK number, if present, of a segment received during
+// a TCP 3-way handshake is valid. If it's not, a RST segment is sent back in
+// response.
+//
+// Returns false when a RST was sent and the segment should be ignored.
+func (h *handshake) checkAck(s *segment) bool {
+	if s.flagIsSet(flagAck) && s.ackNumber != h.iss+1 {
+		// RFC 793, page 36, states that a reset must be generated when
+		// the connection is in any non-synchronized state and an
+		// incoming segment acknowledges something not yet sent. The
+		// connection remains in the same state.
+		ack := s.sequenceNumber.Add(s.logicalLen())
+		h.ep.sendRaw(nil, flagRst|flagAck, s.ackNumber, ack, 0)
+		return false
+	}
+
+	return true
+}
+
+// synSentState handles a segment received when the TCP 3-way handshake is in
+// the SYN-SENT state.
+//
+// A nil return means "keep waiting" (the segment was ignored or advanced the
+// handshake); a non-nil error terminates the handshake.
+func (h *handshake) synSentState(s *segment) *tcpip.Error {
+	// RFC 793, page 37, states that in the SYN-SENT state, a reset is
+	// acceptable if the ack field acknowledges the SYN.
+	if s.flagIsSet(flagRst) {
+		if s.flagIsSet(flagAck) && s.ackNumber == h.iss+1 {
+			return tcpip.ErrConnectionRefused
+		}
+		return nil
+	}
+
+	if !h.checkAck(s) {
+		return nil
+	}
+
+	// We are in the SYN-SENT state. We only care about segments that have
+	// the SYN flag.
+	if !s.flagIsSet(flagSyn) {
+		return nil
+	}
+
+	// Parse the SYN options.
+	rcvSynOpts := parseSynSegmentOptions(s)
+
+	// Remember if the Timestamp option was negotiated.
+	h.ep.maybeEnableTimestamp(&rcvSynOpts)
+
+	// Remember if the SACKPermitted option was negotiated.
+	h.ep.maybeEnableSACKPermitted(&rcvSynOpts)
+
+	// Remember the sequence we'll ack from now on.
+	h.ackNum = s.sequenceNumber + 1
+	h.flags |= flagAck
+	h.mss = rcvSynOpts.MSS
+	h.sndWndScale = rcvSynOpts.WS
+
+	// If this is a SYN ACK response, we only need to acknowledge the SYN
+	// and the handshake is completed.
+	if s.flagIsSet(flagAck) {
+		h.state = handshakeCompleted
+		h.ep.sendRaw(nil, flagAck, h.iss+1, h.ackNum, h.rcvWnd>>h.effectiveRcvWndScale())
+		return nil
+	}
+
+	// A SYN segment was received, but no ACK in it. We acknowledge the SYN
+	// but resend our own SYN and wait for it to be acknowledged in the
+	// SYN-RCVD state.
+	h.state = handshakeSynRcvd
+	synOpts := header.TCPSynOptions{
+		WS:    h.rcvWndScale,
+		TS:    rcvSynOpts.TS,
+		TSVal: h.ep.timestamp(),
+		TSEcr: h.ep.recentTS,
+
+		// We only send SACKPermitted if the other side indicated it
+		// permits SACK. This is not explicitly defined in the RFC but
+		// this is the behaviour implemented by Linux.
+		SACKPermitted: rcvSynOpts.SACKPermitted,
+	}
+	sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+	return nil
+}
+
+// synRcvdState handles a segment received when the TCP 3-way handshake is in
+// the SYN-RCVD state.
+//
+// Returns ErrConnectionRefused for an in-window RST (or a restarted passive
+// handshake), nil otherwise; handshakeCompleted is reached once the peer
+// acks our SYN.
+func (h *handshake) synRcvdState(s *segment) *tcpip.Error {
+	if s.flagIsSet(flagRst) {
+		// RFC 793, page 37, states that in the SYN-RCVD state, a reset
+		// is acceptable if the sequence number is in the window.
+		if s.sequenceNumber.InWindow(h.ackNum, h.rcvWnd) {
+			return tcpip.ErrConnectionRefused
+		}
+		return nil
+	}
+
+	if !h.checkAck(s) {
+		return nil
+	}
+
+	if s.flagIsSet(flagSyn) && s.sequenceNumber != h.ackNum-1 {
+		// We received two SYN segments with different sequence
+		// numbers, so we reset this and restart the whole
+		// process, except that we don't reset the timer.
+		ack := s.sequenceNumber.Add(s.logicalLen())
+		seq := seqnum.Value(0)
+		if s.flagIsSet(flagAck) {
+			seq = s.ackNumber
+		}
+		h.ep.sendRaw(nil, flagRst|flagAck, seq, ack, 0)
+
+		if !h.active {
+			return tcpip.ErrInvalidEndpointState
+		}
+
+		if err := h.resetState(); err != nil {
+			return err
+		}
+		synOpts := header.TCPSynOptions{
+			WS:            h.rcvWndScale,
+			TS:            h.ep.sendTSOk,
+			TSVal:         h.ep.timestamp(),
+			TSEcr:         h.ep.recentTS,
+			SACKPermitted: h.ep.sackPermitted,
+		}
+		sendSynTCP(&s.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+		return nil
+	}
+
+	// We have previously received (and acknowledged) the peer's SYN. If the
+	// peer acknowledges our SYN, the handshake is completed.
+	if s.flagIsSet(flagAck) {
+
+		// If the timestamp option is negotiated and the segment does
+		// not carry a timestamp option then the segment must be dropped
+		// as per https://tools.ietf.org/html/rfc7323#section-3.2.
+		if h.ep.sendTSOk && !s.parsedOptions.TS {
+			atomic.AddUint64(&h.ep.stack.MutableStats().DroppedPackets, 1)
+			return nil
+		}
+
+		// Update timestamp if required. See RFC7323, section-4.3.
+		h.ep.updateRecentTimestamp(s.parsedOptions.TSVal, h.ackNum, s.sequenceNumber)
+
+		h.state = handshakeCompleted
+		return nil
+	}
+
+	return nil
+}
+
+// processSegments goes through the segment queue and processes up to
+// maxSegmentsPerWake (if they're available).
+func (h *handshake) processSegments() *tcpip.Error {
+	for i := 0; i < maxSegmentsPerWake; i++ {
+		s := h.ep.segmentQueue.dequeue()
+		if s == nil {
+			return nil
+		}
+
+		// Record the peer's advertised window. The window scale is
+		// never applied to segments carrying a SYN.
+		h.sndWnd = s.window
+		if !s.flagIsSet(flagSyn) && h.sndWndScale > 0 {
+			h.sndWnd <<= uint8(h.sndWndScale)
+		}
+
+		var err *tcpip.Error
+		switch h.state {
+		case handshakeSynRcvd:
+			err = h.synRcvdState(s)
+		case handshakeSynSent:
+			err = h.synSentState(s)
+		}
+		s.decRef()
+		if err != nil {
+			return err
+		}
+
+		// We stop processing packets once the handshake is completed,
+		// otherwise we may process packets meant to be processed by
+		// the main protocol goroutine.
+		if h.state == handshakeCompleted {
+			break
+		}
+	}
+
+	// If the queue is not empty, make sure we'll wake up in the next
+	// iteration.
+	if !h.ep.segmentQueue.empty() {
+		h.ep.newSegmentWaker.Assert()
+	}
+
+	return nil
+}
+
+// resolveRoute blocks until the endpoint's route is resolved (e.g. link
+// address resolution completes), the resolution fails, or the endpoint is
+// closed (ErrAborted).
+func (h *handshake) resolveRoute() *tcpip.Error {
+	// Set up the wakers.
+	s := sleep.Sleeper{}
+	resolutionWaker := &sleep.Waker{}
+	s.AddWaker(resolutionWaker, wakerForResolution)
+	s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+	defer s.Done()
+
+	// Initial action is to resolve route.
+	index := wakerForResolution
+	for {
+		switch index {
+		case wakerForResolution:
+			if err := h.ep.route.Resolve(resolutionWaker); err != tcpip.ErrWouldBlock {
+				// Either success (err == nil) or failure.
+				return err
+			}
+			// Resolution not completed. Keep trying...
+
+		case wakerForNotification:
+			n := h.ep.fetchNotifications()
+			if n&notifyClose != 0 {
+				h.ep.route.RemoveWaker(resolutionWaker)
+				return tcpip.ErrAborted
+			}
+		}
+
+		// Wait for notification.
+		index, _ = s.Fetch(true)
+	}
+}
+
+// execute executes the TCP 3-way handshake.
+//
+// It resolves the route if needed, sends the initial SYN (or SYN-ACK), and
+// retransmits it with exponential backoff (1s doubling, capped at 60s, after
+// which ErrTimeout is returned) until the handshake completes or the
+// endpoint is closed.
+func (h *handshake) execute() *tcpip.Error {
+	if h.ep.route.IsResolutionRequired() {
+		if err := h.resolveRoute(); err != nil {
+			return err
+		}
+	}
+
+	// Initialize the resend timer.
+	resendWaker := sleep.Waker{}
+	timeOut := time.Duration(time.Second)
+	rt := time.AfterFunc(timeOut, func() {
+		resendWaker.Assert()
+	})
+	defer rt.Stop()
+
+	// Set up the wakers.
+	s := sleep.Sleeper{}
+	s.AddWaker(&resendWaker, wakerForResend)
+	s.AddWaker(&h.ep.notificationWaker, wakerForNotification)
+	s.AddWaker(&h.ep.newSegmentWaker, wakerForNewSegment)
+	defer s.Done()
+
+	var sackEnabled SACKEnabled
+	if err := h.ep.stack.TransportProtocolOption(ProtocolNumber, &sackEnabled); err != nil {
+		// If stack returned an error when checking for SACKEnabled
+		// status then just default to switching off SACK negotiation.
+		sackEnabled = false
+	}
+
+	// Send the initial SYN segment and loop until the handshake is
+	// completed.
+	synOpts := header.TCPSynOptions{
+		WS:            h.rcvWndScale,
+		TS:            true,
+		TSVal:         h.ep.timestamp(),
+		TSEcr:         h.ep.recentTS,
+		SACKPermitted: bool(sackEnabled),
+	}
+
+	// Execute is also called in a listen context so we want to make sure we
+	// only send the TS/SACK option when we received the TS/SACK in the
+	// initial SYN.
+	if h.state == handshakeSynRcvd {
+		synOpts.TS = h.ep.sendTSOk
+		synOpts.SACKPermitted = h.ep.sackPermitted && bool(sackEnabled)
+	}
+	sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+	for h.state != handshakeCompleted {
+		switch index, _ := s.Fetch(true); index {
+		case wakerForResend:
+			timeOut *= 2
+			if timeOut > 60*time.Second {
+				return tcpip.ErrTimeout
+			}
+			rt.Reset(timeOut)
+			sendSynTCP(&h.ep.route, h.ep.id, h.flags, h.iss, h.ackNum, h.rcvWnd, synOpts)
+
+		case wakerForNotification:
+			n := h.ep.fetchNotifications()
+			if n&notifyClose != 0 {
+				return tcpip.ErrAborted
+			}
+
+		case wakerForNewSegment:
+			if err := h.processSegments(); err != nil {
+				return err
+			}
+		}
+	}
+
+	return nil
+}
+
+// parseSynSegmentOptions parses the TCP options carried by the given
+// SYN (or SYN-ACK) segment. When a timestamp option is present, the
+// timestamp values are also mirrored into the segment's parsedOptions
+// so later processing can consume them uniformly.
+func parseSynSegmentOptions(s *segment) header.TCPSynOptions {
+	opts := header.ParseSynOptions(s.options, s.flagIsSet(flagAck))
+	if !opts.TS {
+		return opts
+	}
+	s.parsedOptions.TSVal = opts.TSVal
+	s.parsedOptions.TSEcr = opts.TSEcr
+	return opts
+}
+
+// optionPool caches reusable maxOptionSize-byte scratch buffers for
+// encoding TCP options. It stores *[]byte rather than []byte so that
+// Put does not allocate a fresh interface value for the slice header on
+// every call (staticcheck SA6002).
+var optionPool = sync.Pool{
+	New: func() interface{} {
+		b := make([]byte, maxOptionSize)
+		return &b
+	},
+}
+
+// getOptions returns a maxOptionSize-byte scratch buffer for encoding
+// TCP options. The buffer must be released with putOptions when the
+// caller is done with it.
+func getOptions() []byte {
+	return *optionPool.Get().(*[]byte)
+}
+
+// putOptions returns a buffer obtained from getOptions to the pool.
+func putOptions(options []byte) {
+	// Reslice to full capacity so the next getOptions caller sees the
+	// whole buffer regardless of how much of it we used.
+	b := options[0:cap(options)]
+	optionPool.Put(&b)
+}
+
+// makeSynOptions encodes the given SYN options into a pooled buffer and
+// returns the encoded slice. The caller owns the returned slice and
+// must release it with putOptions once the segment has been sent.
+func makeSynOptions(opts header.TCPSynOptions) []byte {
+	// Emulate linux option order. This is as follows:
+	//
+	// if md5: NOP NOP MD5SIG 18 md5sig(16)
+	// if mss: MSS 4 mss(2)
+	// if ts and sack_advertise:
+	//	SACK 2 TIMESTAMP 2 timestamp(8)
+	// elif ts: NOP NOP TIMESTAMP 10 timestamp(8)
+	// elif sack: NOP NOP SACK 2
+	// if wscale: NOP WINDOW 3 ws(1)
+	// if sack_blocks: NOP NOP SACK ((2 + (#blocks * 8))
+	//	[for each block] start_seq(4) end_seq(4)
+	// if fastopen_cookie:
+	//	if exp: EXP (4 + len(cookie)) FASTOPEN_MAGIC(2)
+	//	else: FASTOPEN (2 + len(cookie))
+	//	cookie(variable) [padding to four bytes]
+	//
+	options := getOptions()
+
+	// Always encode the mss.
+	offset := header.EncodeMSSOption(uint32(opts.MSS), options)
+
+	// Special ordering is required here. If both TS and SACK are enabled,
+	// then the SACK option precedes TS, with no padding. If they are
+	// enabled individually, then we see padding before the option.
+	if opts.TS && opts.SACKPermitted {
+		offset += header.EncodeSACKPermittedOption(options[offset:])
+		offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+	} else if opts.TS {
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeTSOption(opts.TSVal, opts.TSEcr, options[offset:])
+	} else if opts.SACKPermitted {
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeSACKPermittedOption(options[offset:])
+	}
+
+	// Initialize the WS option.
+	// NOTE(review): a negative WS is presumably used to mean "window
+	// scaling disabled" — confirm against the callers that build
+	// TCPSynOptions.
+	if opts.WS >= 0 {
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeWSOption(opts.WS, options[offset:])
+	}
+
+	// Pad to the end. The option combinations above always end on a
+	// 4-byte boundary, so any non-zero padding indicates an encoding
+	// bug (this may change if a fastopen option is added later).
+	if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+		panic("unexpected option encoding")
+	}
+
+	return options[:offset]
+}
+
+// sendSynTCP encodes the given SYN options and sends a SYN (or SYN-ACK)
+// segment on the provided route under the provided identity.
+func sendSynTCP(r *stack.Route, id stack.TransportEndpointID, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts header.TCPSynOptions) *tcpip.Error {
+	// The MSS in opts is automatically calculated as this function is
+	// called from many places and we don't want every call point being
+	// embedded with the MSS calculation.
+	if opts.MSS == 0 {
+		opts.MSS = uint16(r.MTU() - header.TCPMinimumSize)
+	}
+
+	options := makeSynOptions(opts)
+	defer putOptions(options)
+	return sendTCPWithOptions(r, id, nil, flags, seq, ack, rcvWnd, options)
+}
+
+// sendTCPWithOptions sends a TCP segment with the provided options via the
+// provided network endpoint and under the provided identity. opts must
+// already be encoded and padded to a multiple of 4 bytes (see
+// makeSynOptions / makeOptions).
+func sendTCPWithOptions(r *stack.Route, id stack.TransportEndpointID, data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size, opts []byte) *tcpip.Error {
+	optLen := len(opts)
+	// Allocate a buffer for the TCP header.
+	hdr := buffer.NewPrependable(header.TCPMinimumSize + int(r.MaxHeaderLength()) + optLen)
+
+	// The TCP window field is only 16 bits wide; clamp larger windows.
+	if rcvWnd > 0xffff {
+		rcvWnd = 0xffff
+	}
+
+	// Initialize the header.
+	tcp := header.TCP(hdr.Prepend(header.TCPMinimumSize + optLen))
+	tcp.Encode(&header.TCPFields{
+		SrcPort:    id.LocalPort,
+		DstPort:    id.RemotePort,
+		SeqNum:     uint32(seq),
+		AckNum:     uint32(ack),
+		DataOffset: uint8(header.TCPMinimumSize + optLen),
+		Flags:      flags,
+		WindowSize: uint16(rcvWnd),
+	})
+	copy(tcp[header.TCPMinimumSize:], opts)
+
+	// Only calculate the checksum if offloading isn't supported.
+	if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
+		length := uint16(hdr.UsedLength())
+		xsum := r.PseudoHeaderChecksum(ProtocolNumber)
+		if data != nil {
+			length += uint16(len(data))
+			xsum = header.Checksum(data, xsum)
+		}
+
+		// The TCP checksum field carries the one's complement of the
+		// computed sum.
+		tcp.SetChecksum(^tcp.CalculateChecksum(xsum, length))
+	}
+
+	return r.WritePacket(&hdr, data, ProtocolNumber)
+}
+
+// sendTCP sends a TCP segment with no options via the provided network
+// endpoint and under the provided identity.
+//
+// It delegates to sendTCPWithOptions with a nil option list: with
+// optLen == 0 that function builds the exact same header, performs a
+// no-op option copy and computes the same checksum, so keeping a second
+// hand-expanded copy of the logic here only risked the two drifting
+// apart.
+func sendTCP(r *stack.Route, id stack.TransportEndpointID, data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+	return sendTCPWithOptions(r, id, data, flags, seq, ack, rcvWnd, nil)
+}
+
+// makeOptions encodes the post-handshake TCP options (timestamp and/or
+// SACK blocks) for this endpoint into a pooled buffer and returns the
+// encoded slice. The caller must release the slice with putOptions.
+func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte {
+	options := getOptions()
+	offset := 0
+
+	// N.B. the ordering here matches the ordering used by Linux internally
+	// and described in the raw makeOptions function. We don't include
+	// unnecessary cases here (post connection.)
+	if e.sendTSOk {
+		// Embed the timestamp if timestamp has been enabled.
+		//
+		// We only use the lower 32 bits of the unix time in
+		// milliseconds. This is similar to what Linux does where it
+		// uses the lower 32 bits of the jiffies value in the tsVal
+		// field of the timestamp option.
+		//
+		// Further, RFC7323 section-5.4 recommends millisecond
+		// resolution as the lowest recommended resolution for the
+		// timestamp clock.
+		//
+		// Ref: https://tools.ietf.org/html/rfc7323#section-5.4.
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeTSOption(e.timestamp(), uint32(e.recentTS), options[offset:])
+	}
+	if e.sackPermitted && len(sackBlocks) > 0 {
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeNOP(options[offset:])
+		offset += header.EncodeSACKBlocks(sackBlocks, options[offset:])
+	}
+
+	// We expect the above to produce an aligned offset, so any non-zero
+	// padding indicates an encoding bug.
+	if delta := header.AddTCPOptionPadding(options, offset); delta != 0 {
+		panic("unexpected option encoding")
+	}
+
+	return options[:offset]
+}
+
+// sendRaw sends a TCP segment to the endpoint's peer, attaching SACK
+// blocks (when negotiated) and any other post-handshake options.
+func (e *endpoint) sendRaw(data buffer.View, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error {
+	var sackBlocks []header.SACKBlock
+	if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&flagAck != 0) {
+		sackBlocks = e.sack.Blocks[:e.sack.NumBlocks]
+	}
+
+	options := e.makeOptions(sackBlocks)
+	defer putOptions(options)
+
+	if len(options) > 0 {
+		return sendTCPWithOptions(&e.route, e.id, data, flags, seq, ack, rcvWnd, options)
+	}
+	return sendTCP(&e.route, e.id, data, flags, seq, ack, rcvWnd)
+}
+
+// handleWrite moves pending data from the shared send queue into the
+// sender's private write list and pushes out any new segments. It
+// always returns true so the protocol main loop keeps running.
+func (e *endpoint) handleWrite() bool {
+	// Move packets from send queue to send list. The queue is accessible
+	// from other goroutines and protected by the send mutex, while the send
+	// list is only accessible from the handler goroutine, so it needs no
+	// mutexes.
+	e.sndBufMu.Lock()
+
+	first := e.sndQueue.Front()
+	if first != nil {
+		e.snd.writeList.PushBackList(&e.sndQueue)
+		e.snd.sndNxtList.UpdateForward(e.sndBufInQueue)
+		e.sndBufInQueue = 0
+	}
+
+	e.sndBufMu.Unlock()
+
+	// Initialize the next segment to write if it's currently nil.
+	if e.snd.writeNext == nil {
+		e.snd.writeNext = first
+	}
+
+	// Push out any new packets.
+	e.snd.sendData()
+
+	return true
+}
+
+// handleClose drains any remaining queued data and then marks the send
+// side as closed. The drain must happen first so queued data is handed
+// to the sender before it stops accepting new data. It always returns
+// true so the protocol main loop keeps running.
+func (e *endpoint) handleClose() bool {
+	// Drain the send queue.
+	e.handleWrite()
+
+	// Mark send side as closed.
+	e.snd.closed = true
+
+	return true
+}
+
+// resetConnection sends a RST segment and puts the endpoint in an error state
+// with the given error code.
+// This method must only be called from the protocol goroutine.
+func (e *endpoint) resetConnection(err *tcpip.Error) {
+	// The RST carries the current send-unacknowledged sequence number
+	// and acks rcvNxt so the peer will accept it as in-window.
+	e.sendRaw(nil, flagAck|flagRst, e.snd.sndUna, e.rcv.rcvNxt, 0)
+
+	e.mu.Lock()
+	e.state = stateError
+	e.hardError = err
+	e.mu.Unlock()
+}
+
+// completeWorker is called by the worker goroutine when it's about to exit. It
+// marks the worker as completed and performs cleanup work if requested by
+// Close().
+func (e *endpoint) completeWorker() {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	e.workerRunning = false
+	if !e.workerCleanup {
+		return
+	}
+	e.cleanup()
+}
+
+// handleSegments pulls segments from the queue and processes them. It returns
+// true if the protocol loop should continue, false otherwise.
+func (e *endpoint) handleSegments() bool {
+	checkRequeue := true
+	// Bound the work done per wakeup so a single busy connection cannot
+	// monopolize the protocol goroutine.
+	for i := 0; i < maxSegmentsPerWake; i++ {
+		s := e.segmentQueue.dequeue()
+		if s == nil {
+			checkRequeue = false
+			break
+		}
+
+		// Invoke the tcp probe if installed.
+		if e.probe != nil {
+			e.probe(e.completeState())
+		}
+
+		if s.flagIsSet(flagRst) {
+			if e.rcv.acceptable(s.sequenceNumber, 0) {
+				// RFC 793, page 37 states that "in all states
+				// except SYN-SENT, all reset (RST) segments are
+				// validated by checking their SEQ-fields." So
+				// we only process it if it's acceptable.
+				s.decRef()
+				e.mu.Lock()
+				e.state = stateError
+				e.hardError = tcpip.ErrConnectionReset
+				e.mu.Unlock()
+				return false
+			}
+			// NOTE(review): an out-of-window RST is silently
+			// dropped (decRef at the bottom of the loop); RFC 5961
+			// would instead answer with a challenge ACK — confirm
+			// whether that behavior is wanted here.
+		} else if s.flagIsSet(flagAck) {
+			// Patch the window size in the segment according to the
+			// send window scale.
+			s.window <<= e.snd.sndWndScale
+
+			// If the timestamp option is negotiated and the segment
+			// does not carry a timestamp option then the segment
+			// must be dropped as per
+			// https://tools.ietf.org/html/rfc7323#section-3.2.
+			if e.sendTSOk && !s.parsedOptions.TS {
+				atomic.AddUint64(&e.stack.MutableStats().DroppedPackets, 1)
+				s.decRef()
+				continue
+			}
+
+			// RFC 793, page 41 states that "once in the ESTABLISHED
+			// state all segments must carry current acknowledgment
+			// information."
+			e.rcv.handleRcvdSegment(s)
+			e.snd.handleRcvdSegment(s)
+		}
+		s.decRef()
+	}
+
+	// If the queue is not empty, make sure we'll wake up in the next
+	// iteration.
+	if checkRequeue && !e.segmentQueue.empty() {
+		e.newSegmentWaker.Assert()
+	}
+
+	// Send an ACK for all processed packets if needed.
+	if e.rcv.rcvNxt != e.snd.maxSentAck {
+		e.snd.sendAck()
+	}
+
+	return true
+}
+
+// protocolMainLoop is the main loop of the TCP protocol. It runs in its own
+// goroutine and is responsible for sending segments and handling received
+// segments. For active (non-passive) endpoints it first performs the 3-way
+// handshake. It returns a non-nil error only when the handshake fails;
+// post-handshake failures are recorded on the endpoint instead.
+func (e *endpoint) protocolMainLoop(passive bool) *tcpip.Error {
+	var closeTimer *time.Timer
+	var closeWaker sleep.Waker
+
+	// On exit, wake any blocked readers/writers, mark the worker done
+	// (possibly triggering cleanup), and stop outstanding timers.
+	defer func() {
+		e.waiterQueue.Notify(waiter.EventIn | waiter.EventOut)
+		e.completeWorker()
+
+		if e.snd != nil {
+			e.snd.resendTimer.cleanup()
+		}
+
+		if closeTimer != nil {
+			closeTimer.Stop()
+		}
+	}()
+
+	if !passive {
+		// This is an active connection, so we must initiate the 3-way
+		// handshake, and then inform potential waiters about its
+		// completion.
+		h, err := newHandshake(e, seqnum.Size(e.receiveBufferAvailable()))
+		if err == nil {
+			err = h.execute()
+		}
+		if err != nil {
+			e.lastErrorMu.Lock()
+			e.lastError = err
+			e.lastErrorMu.Unlock()
+
+			e.mu.Lock()
+			e.state = stateError
+			e.hardError = err
+			e.mu.Unlock()
+
+			return err
+		}
+
+		// Transfer handshake state to TCP connection. We disable
+		// receive window scaling if the peer doesn't support it
+		// (indicated by a negative send window scale).
+		e.snd = newSender(e, h.iss, h.ackNum-1, h.sndWnd, h.mss, h.sndWndScale)
+
+		e.rcvListMu.Lock()
+		e.rcv = newReceiver(e, h.ackNum-1, h.rcvWnd, h.effectiveRcvWndScale())
+		e.rcvListMu.Unlock()
+	}
+
+	// Tell waiters that the endpoint is connected and writable.
+	e.mu.Lock()
+	e.state = stateConnected
+	e.mu.Unlock()
+
+	e.waiterQueue.Notify(waiter.EventOut)
+
+	// When the protocol loop exits we should wake up our waiters with EventHUp.
+	defer e.waiterQueue.Notify(waiter.EventHUp)
+
+	// Set up the functions that will be called when the main protocol loop
+	// wakes up. Each entry pairs a waker with the handler to run when it
+	// fires; a handler returning false terminates the loop.
+	funcs := []struct {
+		w *sleep.Waker
+		f func() bool
+	}{
+		{
+			w: &e.sndWaker,
+			f: e.handleWrite,
+		},
+		{
+			w: &e.sndCloseWaker,
+			f: e.handleClose,
+		},
+		{
+			w: &e.newSegmentWaker,
+			f: e.handleSegments,
+		},
+		{
+			// Fired by the close timer armed below: abort the
+			// connection with a RST.
+			w: &closeWaker,
+			f: func() bool {
+				e.resetConnection(tcpip.ErrConnectionAborted)
+				return false
+			},
+		},
+		{
+			// Retransmit timer: reset the connection if the
+			// retransmit limit has been exceeded.
+			w: &e.snd.resendWaker,
+			f: func() bool {
+				if !e.snd.retransmitTimerExpired() {
+					e.resetConnection(tcpip.ErrTimeout)
+					return false
+				}
+				return true
+			},
+		},
+		{
+			// Notifications posted by other goroutines (window
+			// changes, MTU updates, close requests).
+			w: &e.notificationWaker,
+			f: func() bool {
+				n := e.fetchNotifications()
+				if n&notifyNonZeroReceiveWindow != 0 {
+					e.rcv.nonZeroWindow()
+				}
+
+				if n&notifyReceiveWindowChanged != 0 {
+					e.rcv.pendingBufSize = seqnum.Size(e.receiveBufferSize())
+				}
+
+				if n&notifyMTUChanged != 0 {
+					e.sndBufMu.Lock()
+					count := e.packetTooBigCount
+					e.packetTooBigCount = 0
+					mtu := e.sndMTU
+					e.sndBufMu.Unlock()
+
+					e.snd.updateMaxPayloadSize(mtu, count)
+				}
+
+				if n&notifyClose != 0 && closeTimer == nil {
+					// Reset the connection 3 seconds after the
+					// endpoint has been closed.
+					closeTimer = time.AfterFunc(3*time.Second, func() {
+						closeWaker.Assert()
+					})
+				}
+				return true
+			},
+		},
+	}
+
+	// Initialize the sleeper based on the wakers in funcs.
+	s := sleep.Sleeper{}
+	for i := range funcs {
+		s.AddWaker(funcs[i].w, i)
+	}
+
+	// Main loop. Handle segments until both send and receive ends of the
+	// connection have completed. workMu is released while sleeping so
+	// other goroutines may opportunistically perform protocol work.
+	for !e.rcv.closed || !e.snd.closed || e.snd.sndUna != e.snd.sndNxtList {
+		e.workMu.Unlock()
+		v, _ := s.Fetch(true)
+		e.workMu.Lock()
+		if !funcs[v].f() {
+			return nil
+		}
+	}
+
+	// Mark endpoint as closed.
+	e.mu.Lock()
+	e.state = stateClosed
+	e.mu.Unlock()
+
+	return nil
+}
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
new file mode 100644
index 000000000..a89af4559
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -0,0 +1,550 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp_test
+
+import (
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// TestV4MappedConnectOnV6Only verifies that connecting to a v4-mapped
+// address from a v6-only endpoint fails with ErrNoRoute.
+func TestV4MappedConnectOnV6Only(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(true)
+
+	// Start connection attempt, it must fail.
+	err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
+	if err != tcpip.ErrNoRoute {
+		t.Fatalf("Unexpected return value from Connect: %v", err)
+	}
+}
+
+// testV4Connect drives a full active connect to a v4-mapped address over
+// c.EP: it checks the outgoing SYN, replies with SYN-ACK, checks the
+// final ACK and waits for the connection to report established.
+func testV4Connect(t *testing.T, c *context.Context) {
+	// Start connection attempt.
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventOut)
+	defer c.WQ.EventUnregister(&we)
+
+	err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV4MappedAddr, Port: context.TestPort})
+	if err != tcpip.ErrConnectStarted {
+		t.Fatalf("Unexpected return value from Connect: %v", err)
+	}
+
+	// Receive SYN packet.
+	b := c.GetPacket()
+	checker.IPv4(t, b,
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.TCPFlags(header.TCPFlagSyn),
+		),
+	)
+
+	tcp := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcp.SequenceNumber())
+
+	// Send the SYN-ACK reply from the "remote" side.
+	iss := seqnum.Value(789)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: tcp.DestinationPort(),
+		DstPort: tcp.SourcePort(),
+		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+		SeqNum:  iss,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Receive ACK packet.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(iss)+1),
+		),
+	)
+
+	// Wait for connection to be established.
+	select {
+	case <-ch:
+		err = c.EP.GetSockOpt(tcpip.ErrorOption{})
+		if err != nil {
+			t.Fatalf("Unexpected error when connecting: %v", err)
+		}
+	case <-time.After(1 * time.Second):
+		t.Fatalf("Timed out waiting for connection")
+	}
+}
+
+// TestV4MappedConnect verifies a v6 endpoint can connect to a v4-mapped
+// address when not restricted to v6-only.
+func TestV4MappedConnect(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Test the connection request.
+	testV4Connect(t, c)
+}
+
+// TestV4ConnectWhenBoundToWildcard verifies v4-mapped connect works when
+// the endpoint is bound to the unspecified (wildcard) address.
+func TestV4ConnectWhenBoundToWildcard(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Bind to wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test the connection request.
+	testV4Connect(t, c)
+}
+
+// TestV4ConnectWhenBoundToV4MappedWildcard verifies v4-mapped connect
+// works when the endpoint is bound to the v4-mapped wildcard address.
+func TestV4ConnectWhenBoundToV4MappedWildcard(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Bind to v4 mapped wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test the connection request.
+	testV4Connect(t, c)
+}
+
+// TestV4ConnectWhenBoundToV4Mapped verifies v4-mapped connect works when
+// the endpoint is bound to a specific v4-mapped local address.
+func TestV4ConnectWhenBoundToV4Mapped(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Bind to v4 mapped address.
+	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test the connection request.
+	testV4Connect(t, c)
+}
+
+// testV6Connect drives a full active connect to a native IPv6 address
+// over c.EP: it checks the outgoing SYN, replies with SYN-ACK, checks
+// the final ACK and waits for the connection to report established.
+func testV6Connect(t *testing.T, c *context.Context) {
+	// Start connection attempt to IPv6 address.
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventOut)
+	defer c.WQ.EventUnregister(&we)
+
+	err := c.EP.Connect(tcpip.FullAddress{Addr: context.TestV6Addr, Port: context.TestPort})
+	if err != tcpip.ErrConnectStarted {
+		t.Fatalf("Unexpected return value from Connect: %v", err)
+	}
+
+	// Receive SYN packet.
+	b := c.GetV6Packet()
+	checker.IPv6(t, b,
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.TCPFlags(header.TCPFlagSyn),
+		),
+	)
+
+	tcp := header.TCP(header.IPv6(b).Payload())
+	c.IRS = seqnum.Value(tcp.SequenceNumber())
+
+	// Send the SYN-ACK reply from the "remote" side.
+	iss := seqnum.Value(789)
+	c.SendV6Packet(nil, &context.Headers{
+		SrcPort: tcp.DestinationPort(),
+		DstPort: tcp.SourcePort(),
+		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+		SeqNum:  iss,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Receive ACK packet.
+	checker.IPv6(t, c.GetV6Packet(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(iss)+1),
+		),
+	)
+
+	// Wait for connection to be established.
+	select {
+	case <-ch:
+		err = c.EP.GetSockOpt(tcpip.ErrorOption{})
+		if err != nil {
+			t.Fatalf("Unexpected error when connecting: %v", err)
+		}
+	case <-time.After(1 * time.Second):
+		t.Fatalf("Timed out waiting for connection")
+	}
+}
+
+// TestV6Connect verifies a basic active IPv6 connect on a dual-stack
+// endpoint.
+func TestV6Connect(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Test the connection request.
+	testV6Connect(t, c)
+}
+
+// TestV6ConnectV6Only verifies that a v6-only endpoint can still connect
+// to a native IPv6 address.
+func TestV6ConnectV6Only(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(true)
+
+	// Test the connection request.
+	testV6Connect(t, c)
+}
+
+// TestV6ConnectWhenBoundToWildcard verifies IPv6 connect works when the
+// endpoint is bound to the unspecified (wildcard) address.
+func TestV6ConnectWhenBoundToWildcard(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Bind to wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test the connection request.
+	testV6Connect(t, c)
+}
+
+// TestV6ConnectWhenBoundToLocalAddress verifies IPv6 connect works when
+// the endpoint is bound to a specific local IPv6 address.
+func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Bind to local address.
+	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV6Addr}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test the connection request.
+	testV6Connect(t, c)
+}
+
+// TestV4RefuseOnV6Only verifies that a v6-only listener answers an
+// incoming IPv4 SYN with RST|ACK.
+func TestV4RefuseOnV6Only(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(true)
+
+	// Bind to wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Start listening.
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+
+	// Send a SYN request.
+	irs := seqnum.Value(789)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  irs,
+		RcvWnd:  30000,
+	})
+
+	// Receive the RST reply.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.SrcPort(context.StackPort),
+			checker.DstPort(context.TestPort),
+			checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+			checker.AckNum(uint32(irs)+1),
+		),
+	)
+}
+
+// TestV6RefuseOnBoundToV4Mapped verifies that a listener bound to the
+// v4-mapped wildcard answers an incoming native IPv6 SYN with RST|ACK.
+func TestV6RefuseOnBoundToV4Mapped(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Bind and listen.
+	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+
+	// Send a SYN request.
+	irs := seqnum.Value(789)
+	c.SendV6Packet(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  irs,
+		RcvWnd:  30000,
+	})
+
+	// Receive the RST reply.
+	checker.IPv6(t, c.GetV6Packet(),
+		checker.TCP(
+			checker.SrcPort(context.StackPort),
+			checker.DstPort(context.TestPort),
+			checker.TCPFlags(header.TCPFlagRst|header.TCPFlagAck),
+			checker.AckNum(uint32(irs)+1),
+		),
+	)
+}
+
+// testV4Accept performs a passive open over IPv4 against the bound
+// endpoint c.EP: it listens, completes a 3-way handshake from the
+// "remote" side, accepts the new endpoint and verifies its sockopt
+// behavior and remote address.
+//
+// Fixes over the previous version: an unexpected (non-ErrWouldBlock)
+// Accept error is now reported instead of causing a nil-pointer panic
+// on nep below, and the duplicated word in the GetRemoteAddress failure
+// message ("failed failed") is corrected.
+func testV4Accept(t *testing.T, c *context.Context) {
+	// Start listening.
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+
+	// Send a SYN request.
+	irs := seqnum.Value(789)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  irs,
+		RcvWnd:  30000,
+	})
+
+	// Receive the SYN-ACK reply.
+	b := c.GetPacket()
+	tcp := header.TCP(header.IPv4(b).Payload())
+	iss := seqnum.Value(tcp.SequenceNumber())
+	checker.IPv4(t, b,
+		checker.TCP(
+			checker.SrcPort(context.StackPort),
+			checker.DstPort(context.TestPort),
+			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+			checker.AckNum(uint32(irs)+1),
+		),
+	)
+
+	// Send ACK.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  irs + 1,
+		AckNum:  iss + 1,
+		RcvWnd:  30000,
+	})
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	nep, _, err := c.EP.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			nep, _, err = c.EP.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %v", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	} else if err != nil {
+		// Any other error means accept can never succeed; bail out
+		// before dereferencing a nil endpoint below.
+		t.Fatalf("Accept failed: %v", err)
+	}
+
+	// Make sure we get the same error when calling the original ep and the
+	// new one. This validates that v4-mapped endpoints are still able to
+	// query the V6Only flag, whereas pure v4 endpoints are not.
+	var v tcpip.V6OnlyOption
+	expected := c.EP.GetSockOpt(&v)
+	if err := nep.GetSockOpt(&v); err != expected {
+		t.Fatalf("GetSockOpt returned unexpected value: got %v, want %v", err, expected)
+	}
+
+	// Check the peer address.
+	addr, err := nep.GetRemoteAddress()
+	if err != nil {
+		t.Fatalf("GetRemoteAddress failed: %v", err)
+	}
+
+	if addr.Addr != context.TestAddr {
+		t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestAddr)
+	}
+}
+
+// TestV4AcceptOnV6 verifies a dual-stack v6 listener bound to the
+// wildcard accepts an incoming IPv4 connection.
+func TestV4AcceptOnV6(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Bind to wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test acceptance.
+	testV4Accept(t, c)
+}
+
+// TestV4AcceptOnBoundToV4MappedWildcard verifies a listener bound to the
+// v4-mapped wildcard accepts an incoming IPv4 connection.
+func TestV4AcceptOnBoundToV4MappedWildcard(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Bind to v4 mapped wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.V4MappedWildcardAddr, Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test acceptance.
+	testV4Accept(t, c)
+}
+
+// TestV4AcceptOnBoundToV4Mapped verifies a listener bound to a specific
+// v4-mapped address accepts an incoming IPv4 connection.
+func TestV4AcceptOnBoundToV4Mapped(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Bind and listen.
+	if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackV4MappedAddr, Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test acceptance.
+	testV4Accept(t, c)
+}
+
+// TestV6AcceptOnV6 verifies a dual-stack v6 listener accepts an incoming
+// native IPv6 connection and that the accepted endpoint is a true v6
+// socket (V6Only is queryable) with the expected remote address.
+//
+// Fixes over the previous version: an unexpected (non-ErrWouldBlock)
+// Accept error is now reported instead of causing a nil-pointer panic
+// on nep below, and the duplicated word in two failure messages
+// ("failed failed") is corrected.
+func TestV6AcceptOnV6(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateV6Endpoint(false)
+
+	// Bind and listen.
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	if err := c.EP.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+
+	// Send a SYN request.
+	irs := seqnum.Value(789)
+	c.SendV6Packet(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  irs,
+		RcvWnd:  30000,
+	})
+
+	// Receive the SYN-ACK reply.
+	b := c.GetV6Packet()
+	tcp := header.TCP(header.IPv6(b).Payload())
+	iss := seqnum.Value(tcp.SequenceNumber())
+	checker.IPv6(t, b,
+		checker.TCP(
+			checker.SrcPort(context.StackPort),
+			checker.DstPort(context.TestPort),
+			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn),
+			checker.AckNum(uint32(irs)+1),
+		),
+	)
+
+	// Send ACK.
+	c.SendV6Packet(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: context.StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  irs + 1,
+		AckNum:  iss + 1,
+		RcvWnd:  30000,
+	})
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	nep, _, err := c.EP.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			nep, _, err = c.EP.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %v", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	} else if err != nil {
+		// Any other error means accept can never succeed; bail out
+		// before dereferencing a nil endpoint below.
+		t.Fatalf("Accept failed: %v", err)
+	}
+
+	// Make sure we can still query the v6 only status of the new endpoint,
+	// that is, that it is in fact a v6 socket.
+	var v tcpip.V6OnlyOption
+	if err := nep.GetSockOpt(&v); err != nil {
+		t.Fatalf("GetSockOpt failed: %v", err)
+	}
+
+	// Check the peer address.
+	addr, err := nep.GetRemoteAddress()
+	if err != nil {
+		t.Fatalf("GetRemoteAddress failed: %v", err)
+	}
+
+	if addr.Addr != context.TestV6Addr {
+		t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, context.TestV6Addr)
+	}
+}
+
+// TestV4AcceptOnV4 verifies that a plain IPv4 listener (not dual-stack)
+// accepts an incoming IPv4 connection.
+func TestV4AcceptOnV4(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Create TCP endpoint.
+	var err *tcpip.Error
+	c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	// Bind to wildcard.
+	if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test acceptance.
+	testV4Accept(t, c)
+}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
new file mode 100644
index 000000000..5d62589d8
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -0,0 +1,1371 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+import (
+ "crypto/rand"
+ "math"
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tmutex"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// endpointState enumerates the lifecycle states of a TCP endpoint.
+type endpointState int
+
+const (
+	stateInitial    endpointState = iota // Freshly created, not yet bound.
+	stateBound                           // Bound to a local address/port.
+	stateListen                          // Listening for incoming connections.
+	stateConnecting                      // Connection attempt in progress.
+	stateConnected                       // Connection established.
+	stateClosed                          // Closed; no further use possible.
+	stateError                           // Terminated with an error (see hardError).
+)
+
+// Reasons for notifying the protocol goroutine. These are OR-ed into
+// endpoint.notifyFlags atomically; see notifyProtocolGoroutine.
+const (
+	notifyNonZeroReceiveWindow = 1 << iota // Receive window became non-zero again.
+	notifyReceiveWindowChanged             // Receive buffer size was changed.
+	notifyClose                            // Endpoint is being closed.
+	notifyMTUChanged                       // A "packet too big" control packet arrived.
+	notifyDrain                            // Endpoint must drain for save/restore.
+)
+
+// SACKInfo holds TCP SACK related information for a given endpoint.
+type SACKInfo struct {
+	// Blocks is the fixed-size array of SACK blocks tracked for this
+	// endpoint; at most MaxSACKBlocks are kept.
+	Blocks [MaxSACKBlocks]header.SACKBlock
+
+	// NumBlocks is the number of valid SACK blocks stored in the
+	// Blocks array above.
+	NumBlocks int
+}
+
+// endpoint represents a TCP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized. The protocol implementation, however, runs in a single
+// goroutine.
+type endpoint struct {
+	// workMu is used to arbitrate which goroutine may perform protocol
+	// work. Only the main protocol goroutine is expected to call Lock() on
+	// it, but other goroutines (e.g., send) may call TryLock() to eagerly
+	// perform work without having to wait for the main one to wake up.
+	workMu tmutex.Mutex `state:"nosave"`
+
+	// The following fields are initialized at creation time and do not
+	// change throughout the lifetime of the endpoint.
+	stack *stack.Stack `state:"manual"`
+	netProto tcpip.NetworkProtocolNumber
+	waiterQueue *waiter.Queue
+
+	// lastError represents the last error that the endpoint reported;
+	// access to it is protected by the following mutex.
+	lastErrorMu sync.Mutex `state:"nosave"`
+	lastError *tcpip.Error
+
+	// The following fields are used to manage the receive queue. The
+	// protocol goroutine adds ready-for-delivery segments to rcvList,
+	// which are returned by Read() calls to users.
+	//
+	// Once the peer has closed its send side, rcvClosed is set to true
+	// to indicate to users that no more data is coming.
+	rcvListMu sync.Mutex `state:"nosave"`
+	rcvList segmentList
+	rcvClosed bool
+	rcvBufSize int
+	rcvBufUsed int
+
+	// The following fields are protected by the mutex. They cover the
+	// endpoint's identity (id), lifecycle state, and binding/registration
+	// bookkeeping used by Bind/Connect/Listen/Close.
+	mu sync.RWMutex `state:"nosave"`
+	id stack.TransportEndpointID
+	state endpointState
+	isPortReserved bool
+	isRegistered bool
+	boundNICID tcpip.NICID
+	route stack.Route `state:"manual"`
+	v6only bool
+	isConnectNotified bool
+
+	// effectiveNetProtos contains the network protocols actually in use. In
+	// most cases it will only contain "netProto", but in cases like IPv6
+	// endpoints with v6only set to false, this could include multiple
+	// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+	// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+	// address).
+	effectiveNetProtos []tcpip.NetworkProtocolNumber
+
+	// hardError is meaningful only when state is stateError, it stores the
+	// error to be returned when read/write syscalls are called and the
+	// endpoint is in this state.
+	hardError *tcpip.Error
+
+	// workerRunning specifies if a worker goroutine is running.
+	workerRunning bool
+
+	// workerCleanup specifies if the worker goroutine must perform cleanup
+	// before exitting. This can only be set to true when workerRunning is
+	// also true, and they're both protected by the mutex.
+	workerCleanup bool
+
+	// sendTSOk is used to indicate when the TS Option has been negotiated.
+	// When sendTSOk is true every non-RST segment should carry a TS as per
+	// RFC7323#section-1.1
+	sendTSOk bool
+
+	// recentTS is the timestamp that should be sent in the TSEcr field of
+	// the timestamp for future segments sent by the endpoint. This field is
+	// updated if required when a new segment is received by this endpoint.
+	recentTS uint32
+
+	// tsOffset is a randomized offset added to the value of the
+	// TSVal field in the timestamp option.
+	tsOffset uint32
+
+	// shutdownFlags represent the current shutdown state of the endpoint.
+	shutdownFlags tcpip.ShutdownFlags
+
+	// sackPermitted is set to true if the peer sends the TCPSACKPermitted
+	// option in the SYN/SYN-ACK.
+	sackPermitted bool
+
+	// sack holds TCP SACK related information for this endpoint.
+	sack SACKInfo
+
+	// The options below aren't implemented, but we remember the user
+	// settings because applications expect to be able to set/query these
+	// options.
+	noDelay bool
+	reuseAddr bool
+
+	// segmentQueue is used to hand received segments to the protocol
+	// goroutine. Segments are queued as long as the queue is not full,
+	// and dropped when it is.
+	segmentQueue segmentQueue `state:"zerovalue"`
+
+	// The following fields are used to manage the send buffer. When
+	// segments are ready to be sent, they are added to sndQueue and the
+	// protocol goroutine is signaled via sndWaker.
+	//
+	// When the send side is closed, the protocol goroutine is notified via
+	// sndCloseWaker, and sndClosed is set to true.
+	sndBufMu sync.Mutex `state:"nosave"`
+	sndBufSize int
+	sndBufUsed int
+	sndClosed bool
+	sndBufInQueue seqnum.Size
+	sndQueue segmentList
+	sndWaker sleep.Waker `state:"manual"`
+	sndCloseWaker sleep.Waker `state:"manual"`
+
+	// The following are used when a "packet too big" control packet is
+	// received. They are protected by sndBufMu. They are used to
+	// communicate to the main protocol goroutine how many such control
+	// messages have been received since the last notification was processed
+	// and what was the smallest MTU seen.
+	packetTooBigCount int
+	sndMTU int
+
+	// newSegmentWaker is used to indicate to the protocol goroutine that
+	// it needs to wake up and handle new segments queued to it.
+	newSegmentWaker sleep.Waker `state:"manual"`
+
+	// notificationWaker is used to indicate to the protocol goroutine that
+	// it needs to wake up and check for notifications.
+	notificationWaker sleep.Waker `state:"manual"`
+
+	// notifyFlags is a bitmask of flags used to indicate to the protocol
+	// goroutine what it was notified; this is only accessed atomically.
+	notifyFlags uint32 `state:"zerovalue"`
+
+	// acceptedChan is used by a listening endpoint protocol goroutine to
+	// send newly accepted connections to the endpoint so that they can be
+	// read by Accept() calls.
+	acceptedChan chan *endpoint `state:".(endpointChan)"`
+
+	// The following are only used from the protocol goroutine, and
+	// therefore don't need locks to protect them.
+	rcv *receiver
+	snd *sender
+
+	// The goroutine drain completion notification channel.
+	drainDone chan struct{} `state:"nosave"`
+
+	// probe if not nil is invoked on every received segment. It is passed
+	// a copy of the current state of the endpoint.
+	probe stack.TCPProbeFunc `state:"nosave"`
+}
+
+// newEndpoint creates a new TCP endpoint for the given network protocol,
+// initializing buffer sizes from the stack's transport protocol options (when
+// set) and installing the stack's TCP probe, if any. The endpoint starts in
+// stateInitial.
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+	e := &endpoint{
+		stack: stack,
+		netProto: netProto,
+		waiterQueue: waiterQueue,
+		rcvBufSize: DefaultBufferSize,
+		sndBufSize: DefaultBufferSize,
+		sndMTU: int(math.MaxInt32),
+		noDelay: false,
+		reuseAddr: true,
+	}
+
+	// Stack-level option values, if configured, override the defaults.
+	var ss SendBufferSizeOption
+	if err := stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+		e.sndBufSize = ss.Default
+	}
+
+	var rs ReceiveBufferSizeOption
+	if err := stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+		e.rcvBufSize = rs.Default
+	}
+
+	if p := stack.GetTCPProbe(); p != nil {
+		e.probe = p
+	}
+
+	// The segment queue may buffer up to twice the receive buffer size.
+	e.segmentQueue.setLimit(2 * e.rcvBufSize)
+	// NOTE(review): workMu is acquired at creation; other goroutines use
+	// TryLock (see Write), so work is performed inline until a protocol
+	// goroutine takes ownership — confirm against connect.go.
+	e.workMu.Init()
+	e.workMu.Lock()
+	e.tsOffset = timeStampOffset()
+	return e
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable. Only the
+// events present in mask are computed and reported.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+	result := waiter.EventMask(0)
+
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	switch e.state {
+	case stateInitial, stateBound, stateConnecting:
+		// Ready for nothing.
+
+	case stateClosed, stateError:
+		// Ready for anything.
+		result = mask
+
+	case stateListen:
+		// Check if there's anything in the accepted channel.
+		if (mask & waiter.EventIn) != 0 {
+			if len(e.acceptedChan) > 0 {
+				result |= waiter.EventIn
+			}
+		}
+
+	case stateConnected:
+		// Determine if the endpoint is writable if requested.
+		if (mask & waiter.EventOut) != 0 {
+			e.sndBufMu.Lock()
+			// Writable when closed-for-send (writes fail fast) or
+			// when there is room in the send buffer.
+			if e.sndClosed || e.sndBufUsed < e.sndBufSize {
+				result |= waiter.EventOut
+			}
+			e.sndBufMu.Unlock()
+		}
+
+		// Determine if the endpoint is readable if requested.
+		if (mask & waiter.EventIn) != 0 {
+			e.rcvListMu.Lock()
+			// Readable when data is queued or the peer closed its
+			// send side (reads return ErrClosedForReceive).
+			if e.rcvBufUsed > 0 || e.rcvClosed {
+				result |= waiter.EventIn
+			}
+			e.rcvListMu.Unlock()
+		}
+	}
+
+	return result
+}
+
+// fetchNotifications atomically retrieves and clears the pending notification
+// flags. It is called by the protocol goroutine to learn why it was woken.
+func (e *endpoint) fetchNotifications() uint32 {
+	return atomic.SwapUint32(&e.notifyFlags, 0)
+}
+
+// notifyProtocolGoroutine sets the given notification flags and wakes up the
+// protocol goroutine if this call caused a transition from no flags set to at
+// least one flag set. Safe for concurrent use; the flags are merged with a
+// compare-and-swap loop.
+func (e *endpoint) notifyProtocolGoroutine(n uint32) {
+	for {
+		v := atomic.LoadUint32(&e.notifyFlags)
+		if v&n == n {
+			// The flags are already set.
+			return
+		}
+
+		if atomic.CompareAndSwapUint32(&e.notifyFlags, v, v|n) {
+			if v == 0 {
+				// We are causing a transition from no flags to
+				// at least one flag set, so we must cause the
+				// protocol goroutine to wake up.
+				e.notificationWaker.Assert()
+			}
+			return
+		}
+	}
+}
+
+// Close puts the endpoint in a closed state and frees all resources associated
+// with it. It must be called only once and with no other concurrent calls to
+// the endpoint. Cleanup is performed inline when no worker goroutine is
+// running; otherwise the worker is asked to do it.
+func (e *endpoint) Close() {
+	// Issue a shutdown so that the peer knows we won't send any more data
+	// if we're connected, or stop accepting if we're listening.
+	e.Shutdown(tcpip.ShutdownWrite | tcpip.ShutdownRead)
+
+	// While we hold the lock, determine if the cleanup should happen
+	// inline or if we should tell the worker (if any) to do the cleanup.
+	e.mu.Lock()
+	worker := e.workerRunning
+	if worker {
+		e.workerCleanup = true
+	}
+
+	// We always release ports inline so that they are immediately available
+	// for reuse after Close() is called. If also registered, it means this
+	// is a listening socket, so we must unregister as well otherwise the
+	// next user would fail in Listen() when trying to register.
+	if e.isPortReserved {
+		e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, e.id.LocalAddress, e.id.LocalPort)
+		e.isPortReserved = false
+
+		if e.isRegistered {
+			e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+			e.isRegistered = false
+		}
+	}
+
+	e.mu.Unlock()
+
+	// Now that we don't hold the lock anymore, either perform the local
+	// cleanup or kick the worker to make sure it knows it needs to cleanup.
+	if !worker {
+		e.cleanup()
+	} else {
+		e.notifyProtocolGoroutine(notifyClose)
+	}
+}
+
+// cleanup frees all resources associated with the endpoint. It is called after
+// Close() is called and the worker goroutine (if any) is done with its work.
+func (e *endpoint) cleanup() {
+	// Close all endpoints that might have been accepted by TCP but not by
+	// the client. Closing the channel first lets the drain loop below
+	// terminate once all queued endpoints have been aborted.
+	if e.acceptedChan != nil {
+		close(e.acceptedChan)
+		for n := range e.acceptedChan {
+			n.resetConnection(tcpip.ErrConnectionAborted)
+			n.Close()
+		}
+	}
+
+	if e.isRegistered {
+		e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+	}
+
+	e.route.Release()
+}
+
+// Read reads data from the endpoint. The address argument is ignored for TCP;
+// at most one buffered view is returned per call. Returns ErrWouldBlock when
+// no data is available yet (via readLocked).
+func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, *tcpip.Error) {
+	e.mu.RLock()
+	// The endpoint can be read if it's connected, or if it's already closed
+	// but has some pending unread data. Also note that a RST being received
+	// would cause the state to become stateError so we should allow the
+	// reads to proceed before returning a ECONNRESET.
+	if s := e.state; s != stateConnected && s != stateClosed && e.rcvBufUsed == 0 {
+		e.mu.RUnlock()
+		if s == stateError {
+			return buffer.View{}, e.hardError
+		}
+		return buffer.View{}, tcpip.ErrInvalidEndpointState
+	}
+
+	e.rcvListMu.Lock()
+	v, err := e.readLocked()
+	e.rcvListMu.Unlock()
+
+	e.mu.RUnlock()
+
+	return v, err
+}
+
+// readLocked removes and returns the next ready-for-delivery view from the
+// receive list, notifying the protocol goroutine if consuming it reopened a
+// previously zero receive window.
+//
+// Precondition: e.rcvListMu must be held (see Read).
+func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
+	if e.rcvBufUsed == 0 {
+		if e.rcvClosed || e.state != stateConnected {
+			return buffer.View{}, tcpip.ErrClosedForReceive
+		}
+		return buffer.View{}, tcpip.ErrWouldBlock
+	}
+
+	// Deliver the next undelivered view of the front segment; drop the
+	// segment once all of its views have been handed out.
+	s := e.rcvList.Front()
+	views := s.data.Views()
+	v := views[s.viewToDeliver]
+	s.viewToDeliver++
+
+	if s.viewToDeliver >= len(views) {
+		e.rcvList.Remove(s)
+		s.decRef()
+	}
+
+	// If the receive window was zero before this read and is non-zero now,
+	// tell the protocol goroutine so it can announce the new window.
+	scale := e.rcv.rcvWndScale
+	wasZero := e.zeroReceiveWindow(scale)
+	e.rcvBufUsed -= len(v)
+	if wasZero && !e.zeroReceiveWindow(scale) {
+		e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+	}
+
+	return v, nil
+}
+
+// Write writes data to the endpoint's peer. It queues as much of the payload
+// as fits in the send buffer and returns the number of bytes queued; when the
+// payload is larger than the available space, the queued prefix is sent and
+// ErrWouldBlock is returned alongside the count.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+	// Linux completely ignores any address passed to sendto(2) for TCP sockets
+	// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+	// and opts.EndOfRecord are also ignored.
+
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	// The endpoint cannot be written to if it's not connected.
+	if e.state != stateConnected {
+		switch e.state {
+		case stateError:
+			return 0, e.hardError
+		default:
+			return 0, tcpip.ErrClosedForSend
+		}
+	}
+
+	// Nothing to do if the buffer is empty.
+	if p.Size() == 0 {
+		return 0, nil
+	}
+
+	e.sndBufMu.Lock()
+
+	// Check if the connection has already been closed for sends.
+	if e.sndClosed {
+		e.sndBufMu.Unlock()
+		return 0, tcpip.ErrClosedForSend
+	}
+
+	// Check against the limit.
+	avail := e.sndBufSize - e.sndBufUsed
+	if avail <= 0 {
+		e.sndBufMu.Unlock()
+		return 0, tcpip.ErrWouldBlock
+	}
+
+	// Fetch at most "avail" bytes of the payload.
+	v, perr := p.Get(avail)
+	if perr != nil {
+		e.sndBufMu.Unlock()
+		return 0, perr
+	}
+
+	// A partial write still succeeds, but reports ErrWouldBlock so the
+	// caller knows to retry with the remainder.
+	var err *tcpip.Error
+	if p.Size() > avail {
+		err = tcpip.ErrWouldBlock
+	}
+	l := len(v)
+	s := newSegmentFromView(&e.route, e.id, v)
+
+	// Add data to the send queue.
+	e.sndBufUsed += l
+	e.sndBufInQueue += seqnum.Size(l)
+	e.sndQueue.PushBack(s)
+
+	e.sndBufMu.Unlock()
+
+	if e.workMu.TryLock() {
+		// Do the work inline.
+		e.handleWrite()
+		e.workMu.Unlock()
+	} else {
+		// Let the protocol goroutine do the work.
+		e.sndWaker.Assert()
+	}
+	return uintptr(l), err
+}
+
+// Peek reads data without consuming it from the endpoint. The data is copied
+// into vec and the number of bytes copied is returned.
+//
+// This method does not block if there is no data pending.
+func (e *endpoint) Peek(vec [][]byte) (uintptr, *tcpip.Error) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	// The endpoint can be read if it's connected, or if it's already closed
+	// but has some pending unread data.
+	if s := e.state; s != stateConnected && s != stateClosed {
+		if s == stateError {
+			return 0, e.hardError
+		}
+		return 0, tcpip.ErrInvalidEndpointState
+	}
+
+	e.rcvListMu.Lock()
+	defer e.rcvListMu.Unlock()
+
+	if e.rcvBufUsed == 0 {
+		if e.rcvClosed || e.state != stateConnected {
+			return 0, tcpip.ErrClosedForReceive
+		}
+		return 0, tcpip.ErrWouldBlock
+	}
+
+	// Make a copy of vec so we can modify the slice headers without
+	// affecting the caller's slices.
+	vec = append([][]byte(nil), vec...)
+
+	var num uintptr
+
+	// Walk every queued segment, starting each at its first undelivered
+	// view, filling the destination slices in order.
+	for s := e.rcvList.Front(); s != nil; s = s.Next() {
+		views := s.data.Views()
+
+		for i := s.viewToDeliver; i < len(views); i++ {
+			v := views[i]
+
+			for len(v) > 0 {
+				if len(vec) == 0 {
+					return num, nil
+				}
+				if len(vec[0]) == 0 {
+					vec = vec[1:]
+					continue
+				}
+
+				n := copy(vec[0], v)
+				v = v[n:]
+				vec[0] = vec[0][n:]
+				num += uintptr(n)
+			}
+		}
+	}
+
+	return num, nil
+}
+
+// zeroReceiveWindow checks if the receive window to be announced now would be
+// zero, based on the amount of available buffer and the receive window scaling.
+//
+// It must be called with rcvListMu held.
+func (e *endpoint) zeroReceiveWindow(scale uint8) bool {
+	// Anything non-positive is a definite zero window.
+	avail := e.rcvBufSize - e.rcvBufUsed
+	if avail <= 0 {
+		return true
+	}
+
+	// The advertised window is right-shifted by the window scale, so a
+	// small positive availability can still round down to zero.
+	return avail>>scale == 0
+}
+
+// SetSockOpt sets a socket option. Unrecognized options are silently accepted
+// and return nil.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+	switch v := opt.(type) {
+	case tcpip.NoDelayOption:
+		// Remembered but not implemented; see the noDelay field.
+		e.mu.Lock()
+		e.noDelay = v != 0
+		e.mu.Unlock()
+		return nil
+
+	case tcpip.ReuseAddressOption:
+		// Remembered but not implemented; see the reuseAddr field.
+		e.mu.Lock()
+		e.reuseAddr = v != 0
+		e.mu.Unlock()
+		return nil
+
+	case tcpip.ReceiveBufferSizeOption:
+		// Make sure the receive buffer size is within the min and max
+		// allowed.
+		var rs ReceiveBufferSizeOption
+		size := int(v)
+		if err := e.stack.TransportProtocolOption(ProtocolNumber, &rs); err == nil {
+			if size < rs.Min {
+				size = rs.Min
+			}
+			if size > rs.Max {
+				size = rs.Max
+			}
+		}
+
+		mask := uint32(notifyReceiveWindowChanged)
+
+		e.rcvListMu.Lock()
+
+		// Make sure the receive buffer size allows us to send a
+		// non-zero window size.
+		scale := uint8(0)
+		if e.rcv != nil {
+			scale = e.rcv.rcvWndScale
+		}
+		if size>>scale == 0 {
+			size = 1 << scale
+		}
+
+		// Make sure 2*size doesn't overflow.
+		if size > math.MaxInt32/2 {
+			size = math.MaxInt32 / 2
+		}
+
+		// If growing the buffer reopened a zero window, the protocol
+		// goroutine must be told so it announces the new window.
+		wasZero := e.zeroReceiveWindow(scale)
+		e.rcvBufSize = size
+		if wasZero && !e.zeroReceiveWindow(scale) {
+			mask |= notifyNonZeroReceiveWindow
+		}
+		e.rcvListMu.Unlock()
+
+		e.segmentQueue.setLimit(2 * size)
+
+		e.notifyProtocolGoroutine(mask)
+		return nil
+
+	case tcpip.SendBufferSizeOption:
+		// Make sure the send buffer size is within the min and max
+		// allowed.
+		size := int(v)
+		var ss SendBufferSizeOption
+		if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+			if size < ss.Min {
+				size = ss.Min
+			}
+			if size > ss.Max {
+				size = ss.Max
+			}
+		}
+
+		e.sndBufMu.Lock()
+		e.sndBufSize = size
+		e.sndBufMu.Unlock()
+
+		return nil
+
+	case tcpip.V6OnlyOption:
+		// We only recognize this option on v6 endpoints.
+		if e.netProto != header.IPv6ProtocolNumber {
+			return tcpip.ErrInvalidEndpointState
+		}
+
+		e.mu.Lock()
+		defer e.mu.Unlock()
+
+		// We only allow this to be set when we're in the initial state.
+		if e.state != stateInitial {
+			return tcpip.ErrInvalidEndpointState
+		}
+
+		e.v6only = v != 0
+	}
+
+	return nil
+}
+
+// readyReceiveSize returns the number of bytes ready to be received.
+func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	// Listening endpoints have no receive queue to report on.
+	if e.state == stateListen {
+		return 0, tcpip.ErrInvalidEndpointState
+	}
+
+	e.rcvListMu.Lock()
+	n := e.rcvBufUsed
+	e.rcvListMu.Unlock()
+
+	return n, nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt. It returns
+// ErrUnknownProtocolOption for options it does not recognize.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+	switch o := opt.(type) {
+	case tcpip.ErrorOption:
+		// Fetch-and-clear the last error (SO_ERROR semantics).
+		e.lastErrorMu.Lock()
+		err := e.lastError
+		e.lastError = nil
+		e.lastErrorMu.Unlock()
+		return err
+
+	case *tcpip.SendBufferSizeOption:
+		e.sndBufMu.Lock()
+		*o = tcpip.SendBufferSizeOption(e.sndBufSize)
+		e.sndBufMu.Unlock()
+		return nil
+
+	case *tcpip.ReceiveBufferSizeOption:
+		e.rcvListMu.Lock()
+		*o = tcpip.ReceiveBufferSizeOption(e.rcvBufSize)
+		e.rcvListMu.Unlock()
+		return nil
+
+	case *tcpip.ReceiveQueueSizeOption:
+		v, err := e.readyReceiveSize()
+		if err != nil {
+			return err
+		}
+
+		*o = tcpip.ReceiveQueueSizeOption(v)
+		return nil
+
+	case *tcpip.NoDelayOption:
+		e.mu.RLock()
+		v := e.noDelay
+		e.mu.RUnlock()
+
+		*o = 0
+		if v {
+			*o = 1
+		}
+		return nil
+
+	case *tcpip.ReuseAddressOption:
+		e.mu.RLock()
+		v := e.reuseAddr
+		e.mu.RUnlock()
+
+		*o = 0
+		if v {
+			*o = 1
+		}
+		return nil
+
+	case *tcpip.V6OnlyOption:
+		// We only recognize this option on v6 endpoints.
+		if e.netProto != header.IPv6ProtocolNumber {
+			return tcpip.ErrUnknownProtocolOption
+		}
+
+		e.mu.Lock()
+		v := e.v6only
+		e.mu.Unlock()
+
+		*o = 0
+		if v {
+			*o = 1
+		}
+		return nil
+
+	case *tcpip.TCPInfoOption:
+		// Not implemented; report an empty info struct.
+		*o = tcpip.TCPInfoOption{}
+		return nil
+	}
+
+	return tcpip.ErrUnknownProtocolOption
+}
+
+// checkV4Mapped validates addr against this endpoint's configuration and
+// returns the network protocol to use for it. If addr is a v4-mapped v6
+// address, addr.Addr is rewritten in place to its v4 form (and cleared when it
+// is the v4 unspecified address), and IPv4 is returned; this fails on v6only
+// endpoints.
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+	netProto := e.netProto
+	if header.IsV4MappedAddress(addr.Addr) {
+		// Fail if using a v4 mapped address on a v6only endpoint.
+		if e.v6only {
+			return 0, tcpip.ErrNoRoute
+		}
+
+		netProto = header.IPv4ProtocolNumber
+		addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+		if addr.Addr == "\x00\x00\x00\x00" {
+			addr.Addr = ""
+		}
+	}
+
+	// Fail if we're bound to an address length different from the one we're
+	// checking.
+	if l := len(e.id.LocalAddress); l != 0 && len(addr.Addr) != 0 && l != len(addr.Addr) {
+		return 0, tcpip.ErrInvalidEndpointState
+	}
+
+	return netProto, nil
+}
+
+// Connect connects the endpoint to its peer. On success it transitions the
+// endpoint to stateConnecting, starts the protocol goroutine, and returns
+// ErrConnectStarted; completion is reported asynchronously.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	netProto, err := e.checkV4Mapped(&addr)
+	if err != nil {
+		return err
+	}
+
+	nicid := addr.NIC
+	switch e.state {
+	case stateBound:
+		// If we're already bound to a NIC but the caller is requesting
+		// that we use a different one now, we cannot proceed.
+		if e.boundNICID == 0 {
+			break
+		}
+
+		if nicid != 0 && nicid != e.boundNICID {
+			return tcpip.ErrNoRoute
+		}
+
+		nicid = e.boundNICID
+
+	case stateInitial:
+		// Nothing to do. We'll eventually fill-in the gaps in the ID
+		// (if any) when we find a route.
+
+	case stateConnecting:
+		// A connection request has already been issued but hasn't
+		// completed yet.
+		return tcpip.ErrAlreadyConnecting
+
+	case stateConnected:
+		// The endpoint is already connected. If caller hasn't been notified yet, return success.
+		if !e.isConnectNotified {
+			e.isConnectNotified = true
+			return nil
+		}
+		// Otherwise return that it's already connected.
+		return tcpip.ErrAlreadyConnected
+
+	case stateError:
+		return e.hardError
+
+	default:
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	// Find a route to the desired destination.
+	r, err := e.stack.FindRoute(nicid, e.id.LocalAddress, addr.Addr, netProto)
+	if err != nil {
+		return err
+	}
+	defer r.Release()
+
+	// Remember the pre-connect ID so any earlier Bind reservation can be
+	// released below.
+	origID := e.id
+
+	netProtos := []tcpip.NetworkProtocolNumber{netProto}
+	e.id.LocalAddress = r.LocalAddress
+	e.id.RemoteAddress = r.RemoteAddress
+	e.id.RemotePort = addr.Port
+
+	if e.id.LocalPort != 0 {
+		// The endpoint is bound to a port, attempt to register it.
+		err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
+		if err != nil {
+			return err
+		}
+	} else {
+		// The endpoint doesn't have a local port yet, so try to get
+		// one. Make sure that it isn't one that will result in the same
+		// address/port for both local and remote (otherwise this
+		// endpoint would be trying to connect to itself).
+		sameAddr := e.id.LocalAddress == e.id.RemoteAddress
+		_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+			if sameAddr && p == e.id.RemotePort {
+				return false, nil
+			}
+
+			// Keep probing ports while registration reports
+			// ErrPortInUse; abort on any other error.
+			e.id.LocalPort = p
+			err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, e.id, e)
+			switch err {
+			case nil:
+				return true, nil
+			case tcpip.ErrPortInUse:
+				return false, nil
+			default:
+				return false, err
+			}
+		})
+		if err != nil {
+			return err
+		}
+	}
+
+	// Remove the port reservation. This can happen when Bind is called
+	// before Connect: in such a case we don't want to hold on to
+	// reservations anymore.
+	if e.isPortReserved {
+		e.stack.ReleasePort(e.effectiveNetProtos, ProtocolNumber, origID.LocalAddress, origID.LocalPort)
+		e.isPortReserved = false
+	}
+
+	e.isRegistered = true
+	e.state = stateConnecting
+	e.route = r.Clone()
+	e.boundNICID = nicid
+	e.effectiveNetProtos = netProtos
+	e.workerRunning = true
+
+	go e.protocolMainLoop(false) // S/R-FIXME
+
+	return tcpip.ErrConnectStarted
+}
+
+// ConnectEndpoint is not supported; TCP endpoints connect to network
+// addresses, not to other endpoints.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+	return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection to its
+// peer. On a connected endpoint, ShutdownWrite queues a FIN; on a listening
+// endpoint, ShutdownRead stops the listen loop.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+	e.shutdownFlags |= flags
+
+	switch e.state {
+	case stateConnected:
+		// Close for write.
+		if (flags & tcpip.ShutdownWrite) != 0 {
+			e.sndBufMu.Lock()
+
+			if e.sndClosed {
+				// Already closed.
+				e.sndBufMu.Unlock()
+				break
+			}
+
+			// Queue fin segment. A nil view marks the FIN in the
+			// send queue.
+			s := newSegmentFromView(&e.route, e.id, nil)
+			e.sndQueue.PushBack(s)
+			e.sndBufInQueue++
+
+			// Mark endpoint as closed.
+			e.sndClosed = true
+
+			e.sndBufMu.Unlock()
+
+			// Tell protocol goroutine to close.
+			e.sndCloseWaker.Assert()
+		}
+
+	case stateListen:
+		// Tell protocolListenLoop to stop.
+		if flags&tcpip.ShutdownRead != 0 {
+			e.notifyProtocolGoroutine(notifyClose)
+		}
+
+	default:
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	return nil
+}
+
+// Listen puts the endpoint in "listen" mode, which allows it to accept
+// new connections. Calling Listen again on an already-listening endpoint
+// adjusts the backlog (when possible) instead of failing.
+func (e *endpoint) Listen(backlog int) *tcpip.Error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// Allow the backlog to be adjusted if the endpoint is not shutting down.
+	// When the endpoint shuts down, it sets workerCleanup to true, and from
+	// that point onward, acceptedChan is the responsibility of the cleanup()
+	// method (and should not be touched anywhere else, including here).
+	if e.state == stateListen && !e.workerCleanup {
+		// Adjust the size of the channel iff we can fix existing
+		// pending connections into the new one.
+		if len(e.acceptedChan) > backlog {
+			return tcpip.ErrInvalidEndpointState
+		}
+		// Migrate pending connections into a channel with the new
+		// capacity.
+		origChan := e.acceptedChan
+		e.acceptedChan = make(chan *endpoint, backlog)
+		close(origChan)
+		for ep := range origChan {
+			e.acceptedChan <- ep
+		}
+		return nil
+	}
+
+	// Endpoint must be bound before it can transition to listen mode.
+	if e.state != stateBound {
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	// Register the endpoint.
+	if err := e.stack.RegisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e); err != nil {
+		return err
+	}
+
+	e.isRegistered = true
+	e.state = stateListen
+	if e.acceptedChan == nil {
+		e.acceptedChan = make(chan *endpoint, backlog)
+	}
+	e.workerRunning = true
+
+	go e.protocolListenLoop( // S/R-SAFE: drained on save.
+		seqnum.Size(e.receiveBufferAvailable()))
+
+	return nil
+}
+
+// startAcceptedLoop sets up required state and starts a goroutine with the
+// main loop for accepted connections. It is called from Accept on the newly
+// accepted endpoint, with the waiter queue that will be handed to the caller.
+func (e *endpoint) startAcceptedLoop(waiterQueue *waiter.Queue) {
+	e.waiterQueue = waiterQueue
+	e.workerRunning = true
+	go e.protocolMainLoop(true) // S/R-FIXME
+}
+
+// Accept returns a new endpoint if a peer has established a connection
+// to an endpoint previously set to listen mode. It does not block; when no
+// connection is pending it returns ErrWouldBlock.
+func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	// Endpoint must be in listen state before it can accept connections.
+	if e.state != stateListen {
+		return nil, nil, tcpip.ErrInvalidEndpointState
+	}
+
+	// Get the new accepted endpoint. The non-blocking receive makes
+	// Accept itself non-blocking.
+	var n *endpoint
+	select {
+	case n = <-e.acceptedChan:
+	default:
+		return nil, nil, tcpip.ErrWouldBlock
+	}
+
+	// Start the protocol goroutine.
+	wq := &waiter.Queue{}
+	n.startAcceptedLoop(wq)
+
+	return n, wq, nil
+}
+
+// Bind binds the endpoint to a specific local port and optionally address.
+// On success the endpoint transitions to stateBound; any failure after the
+// port is reserved unwinds the reservation via the deferred handler.
+func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) (retErr *tcpip.Error) {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// Don't allow binding once endpoint is not in the initial state
+	// anymore. This is because once the endpoint goes into a connected or
+	// listen state, it is already bound.
+	if e.state != stateInitial {
+		return tcpip.ErrAlreadyBound
+	}
+
+	netProto, err := e.checkV4Mapped(&addr)
+	if err != nil {
+		return err
+	}
+
+	// Expand netProtos to include v4 and v6 if the caller is binding to a
+	// wildcard (empty) address, and this is an IPv6 endpoint with v6only
+	// set to false.
+	netProtos := []tcpip.NetworkProtocolNumber{netProto}
+	if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+		netProtos = []tcpip.NetworkProtocolNumber{
+			header.IPv6ProtocolNumber,
+			header.IPv4ProtocolNumber,
+		}
+	}
+
+	// Reserve the port.
+	port, err := e.stack.ReservePort(netProtos, ProtocolNumber, addr.Addr, addr.Port)
+	if err != nil {
+		return err
+	}
+
+	e.isPortReserved = true
+	e.effectiveNetProtos = netProtos
+	e.id.LocalPort = port
+
+	// Any failures beyond this point must remove the port registration.
+	defer func() {
+		if retErr != nil {
+			e.stack.ReleasePort(netProtos, ProtocolNumber, addr.Addr, port)
+			e.isPortReserved = false
+			e.effectiveNetProtos = nil
+			e.id.LocalPort = 0
+			e.id.LocalAddress = ""
+			e.boundNICID = 0
+		}
+	}()
+
+	// If an address is specified, we must ensure that it's one of our
+	// local addresses.
+	if len(addr.Addr) != 0 {
+		nic := e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr)
+		if nic == 0 {
+			return tcpip.ErrBadLocalAddress
+		}
+
+		e.boundNICID = nic
+		e.id.LocalAddress = addr.Addr
+	}
+
+	// Check the commit function.
+	if commit != nil {
+		if err := commit(); err != nil {
+			// The defer takes care of unwind.
+			return err
+		}
+	}
+
+	// Mark endpoint as bound.
+	e.state = stateBound
+
+	return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+	e.mu.RLock()
+	addr := tcpip.FullAddress{
+		Addr: e.id.LocalAddress,
+		Port: e.id.LocalPort,
+		NIC:  e.boundNICID,
+	}
+	e.mu.RUnlock()
+
+	return addr, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected.
+// It fails with ErrNotConnected unless the endpoint is in the connected
+// state.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	// Only a connected endpoint has a meaningful peer address.
+	if e.state != stateConnected {
+		return tcpip.FullAddress{}, tcpip.ErrNotConnected
+	}
+
+	addr := tcpip.FullAddress{
+		Addr: e.id.RemoteAddress,
+		Port: e.id.RemotePort,
+		NIC:  e.boundNICID,
+	}
+	return addr, nil
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint. Malformed segments are counted and dropped; valid segments are
+// queued for the worker goroutine, or dropped (and counted) when the queue is
+// full.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
+	s := newSegment(r, id, vv)
+	if !s.parse() {
+		atomic.AddUint64(&e.stack.MutableStats().MalformedRcvdPackets, 1)
+		s.decRef()
+		return
+	}
+
+	// Send packet to worker goroutine.
+	if e.segmentQueue.enqueue(s) {
+		e.newSegmentWaker.Assert()
+	} else {
+		// The queue is full, so we drop the segment.
+		atomic.AddUint64(&e.stack.MutableStats().DroppedPackets, 1)
+		s.decRef()
+	}
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+// Only "packet too big" messages are handled: the smallest reported MTU is
+// recorded and the protocol goroutine is notified.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
+	switch typ {
+	case stack.ControlPacketTooBig:
+		e.sndBufMu.Lock()
+		e.packetTooBigCount++
+		// Track the smallest MTU seen since the last notification.
+		if v := int(extra); v < e.sndMTU {
+			e.sndMTU = v
+		}
+		e.sndBufMu.Unlock()
+
+		e.notifyProtocolGoroutine(notifyMTUChanged)
+	}
+}
+
+// updateSndBufferUsage is called by the protocol goroutine when room opens up
+// in the send buffer. The number of newly available bytes is v. Blocked
+// writers are woken only when usage crosses from at least half-full to below
+// half-full.
+func (e *endpoint) updateSndBufferUsage(v int) {
+	e.sndBufMu.Lock()
+	notify := e.sndBufUsed >= e.sndBufSize>>1
+	e.sndBufUsed -= v
+	// We only notify when there is half the sndBufSize available after
+	// a full buffer event occurs. This ensures that we don't wake up
+	// writers to queue just 1-2 segments and go back to sleep.
+	notify = notify && e.sndBufUsed < e.sndBufSize>>1
+	e.sndBufMu.Unlock()
+
+	if notify {
+		e.waiterQueue.Notify(waiter.EventOut)
+	}
+}
+
+// readyToRead is called by the protocol goroutine when a new segment is ready
+// to be read, or when the connection is closed for receiving (in which case
+// s will be nil).
+func (e *endpoint) readyToRead(s *segment) {
+ e.rcvListMu.Lock()
+ if s != nil {
+ s.incRef()
+ e.rcvBufUsed += s.data.Size()
+ e.rcvList.PushBack(s)
+ } else {
+ e.rcvClosed = true
+ }
+ e.rcvListMu.Unlock()
+
+ e.waiterQueue.Notify(waiter.EventIn)
+}
+
+// receiveBufferAvailable calculates how many bytes are still available in the
+// receive buffer.
+func (e *endpoint) receiveBufferAvailable() int {
+ e.rcvListMu.Lock()
+ size := e.rcvBufSize
+ used := e.rcvBufUsed
+ e.rcvListMu.Unlock()
+
+ // We may use more bytes than the buffer size when the receive buffer
+ // shrinks.
+ if used >= size {
+ return 0
+ }
+
+ return size - used
+}
+
+func (e *endpoint) receiveBufferSize() int {
+ e.rcvListMu.Lock()
+ size := e.rcvBufSize
+ e.rcvListMu.Unlock()
+
+ return size
+}
+
// updateRecentTimestamp updates the recent timestamp using the algorithm
// described in https://tools.ietf.org/html/rfc7323#section-4.3
//
// The value is only advanced when timestamps were negotiated for this
// connection, the new TSval is newer than the stored one (compared with
// sequence-number wraparound semantics), and the segment's sequence number is
// at or below the highest ACK we have sent (i.e. it fills or precedes the
// left edge of the window).
func (e *endpoint) updateRecentTimestamp(tsVal uint32, maxSentAck seqnum.Value, segSeq seqnum.Value) {
	if e.sendTSOk && seqnum.Value(e.recentTS).LessThan(seqnum.Value(tsVal)) && segSeq.LessThanEq(maxSentAck) {
		e.recentTS = tsVal
	}
}
+
+// maybeEnableTimestamp marks the timestamp option enabled for this endpoint if
+// the SYN options indicate that timestamp option was negotiated. It also
+// initializes the recentTS with the value provided in synOpts.TSval.
+func (e *endpoint) maybeEnableTimestamp(synOpts *header.TCPSynOptions) {
+ if synOpts.TS {
+ e.sendTSOk = true
+ e.recentTS = synOpts.TSVal
+ }
+}
+
// timestamp returns the timestamp value to be used in the TSVal field of the
// timestamp option for outgoing TCP segments for a given endpoint. It applies
// the endpoint's randomized per-connection offset to the current time.
func (e *endpoint) timestamp() uint32 {
	return tcpTimeStamp(e.tsOffset)
}
+
// tcpTimeStamp returns the current time in milliseconds, shifted by the
// provided offset. This is not inlined above as it's also used when SYN
// cookies are in use and the endpoint does not yet exist at the time the SYN
// cookie is sent.
func tcpTimeStamp(offset uint32) uint32 {
	t := time.Now()
	ms := t.Unix()*1000 + int64(t.Nanosecond())/1e6
	return uint32(ms) + offset
}
+
+// timeStampOffset returns a randomized timestamp offset to be used when sending
+// timestamp values in a timestamp option for a TCP segment.
+func timeStampOffset() uint32 {
+ b := make([]byte, 4)
+ if _, err := rand.Read(b); err != nil {
+ panic(err)
+ }
+ // Initialize a random tsOffset that will be added to the recentTS
+ // everytime the timestamp is sent when the Timestamp option is enabled.
+ //
+ // See https://tools.ietf.org/html/rfc7323#section-5.4 for details on
+ // why this is required.
+ //
+ // NOTE: This is not completely to spec as normally this should be
+ // initialized in a manner analogous to how sequence numbers are
+ // randomized per connection basis. But for now this is sufficient.
+ return uint32(b[0]) | uint32(b[1])<<8 | uint32(b[2])<<16 | uint32(b[3])<<24
+}
+
+// maybeEnableSACKPermitted marks the SACKPermitted option enabled for this endpoint
+// if the SYN options indicate that the SACK option was negotiated and the TCP
+// stack is configured to enable TCP SACK option.
+func (e *endpoint) maybeEnableSACKPermitted(synOpts *header.TCPSynOptions) {
+ var v SACKEnabled
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &v); err != nil {
+ // Stack doesn't support SACK. So just return.
+ return
+ }
+ if bool(v) && synOpts.SACKPermitted {
+ e.sackPermitted = true
+ }
+}
+
// completeState makes a full copy of the endpoint and returns it. This is used
// before invoking the probe. The state returned may not be fully consistent if
// there are intervening syscalls when the state is being copied.
func (e *endpoint) completeState() stack.TCPEndpointState {
	var s stack.TCPEndpointState
	s.SegTime = time.Now()

	// Copy EndpointID.
	e.mu.Lock()
	s.ID = stack.TCPEndpointID(e.id)
	e.mu.Unlock()

	// Copy endpoint rcv state.
	e.rcvListMu.Lock()
	s.RcvBufSize = e.rcvBufSize
	s.RcvBufUsed = e.rcvBufUsed
	s.RcvClosed = e.rcvClosed
	e.rcvListMu.Unlock()

	// Endpoint TCP Option state.
	// NOTE(review): these fields are read without holding any lock here —
	// presumably acceptable given the "may not be fully consistent"
	// contract above; confirm against the probe's requirements.
	s.SendTSOk = e.sendTSOk
	s.RecentTS = e.recentTS
	s.TSOffset = e.tsOffset
	s.SACKPermitted = e.sackPermitted
	s.SACK.Blocks = make([]header.SACKBlock, e.sack.NumBlocks)
	copy(s.SACK.Blocks, e.sack.Blocks[:e.sack.NumBlocks])

	// Copy endpoint send state.
	e.sndBufMu.Lock()
	s.SndBufSize = e.sndBufSize
	s.SndBufUsed = e.sndBufUsed
	s.SndClosed = e.sndClosed
	s.SndBufInQueue = e.sndBufInQueue
	s.PacketTooBigCount = e.packetTooBigCount
	s.SndMTU = e.sndMTU
	e.sndBufMu.Unlock()

	// Copy receiver state. e.rcv/e.snd are owned by the protocol
	// goroutine, so like the option state above this snapshot is
	// best-effort.
	s.Receiver = stack.TCPReceiverState{
		RcvNxt:         e.rcv.rcvNxt,
		RcvAcc:         e.rcv.rcvAcc,
		RcvWndScale:    e.rcv.rcvWndScale,
		PendingBufUsed: e.rcv.pendingBufUsed,
		PendingBufSize: e.rcv.pendingBufSize,
	}

	// Copy sender state.
	s.Sender = stack.TCPSenderState{
		LastSendTime: e.snd.lastSendTime,
		DupAckCount:  e.snd.dupAckCount,
		FastRecovery: stack.TCPFastRecoveryState{
			Active:  e.snd.fr.active,
			First:   e.snd.fr.first,
			Last:    e.snd.fr.last,
			MaxCwnd: e.snd.fr.maxCwnd,
		},
		SndCwnd:          e.snd.sndCwnd,
		Ssthresh:         e.snd.sndSsthresh,
		SndCAAckCount:    e.snd.sndCAAckCount,
		Outstanding:      e.snd.outstanding,
		SndWnd:           e.snd.sndWnd,
		SndUna:           e.snd.sndUna,
		SndNxt:           e.snd.sndNxt,
		RTTMeasureSeqNum: e.snd.rttMeasureSeqNum,
		RTTMeasureTime:   e.snd.rttMeasureTime,
		Closed:           e.snd.closed,
		SRTT:             e.snd.srtt,
		RTO:              e.snd.rto,
		SRTTInited:       e.snd.srttInited,
		MaxPayloadSize:   e.snd.maxPayloadSize,
		SndWndScale:      e.snd.sndWndScale,
		MaxSentAck:       e.snd.maxSentAck,
	}
	return s
}
diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go
new file mode 100644
index 000000000..dbb70ff21
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/endpoint_state.go
@@ -0,0 +1,128 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// ErrSaveRejection indicates a failed save due to unsupported tcp endpoint
+// state.
+type ErrSaveRejection struct {
+ Err error
+}
+
+// Error returns a sensible description of the save rejection error.
+func (e ErrSaveRejection) Error() string {
+ return "save rejected due to unsupported endpoint state: " + e.Err.Error()
+}
+
// beforeSave is invoked by stateify. It quiesces the endpoint so its state can
// be captured: incoming packets are stopped and, for listening endpoints, any
// queued segments are drained first. Endpoints in connecting or connected
// states cannot currently be saved and cause the save to be rejected.
func (e *endpoint) beforeSave() {
	// Stop incoming packets.
	e.segmentQueue.setLimit(0)

	e.mu.RLock()
	defer e.mu.RUnlock()

	switch e.state {
	case stateInitial:
	case stateBound:
	case stateListen:
		if !e.segmentQueue.empty() {
			// Drop the lock while waiting for the protocol
			// goroutine to drain the queue; it may need the lock
			// to make progress. The deferred RUnlock pairs with
			// the RLock reacquired below.
			e.mu.RUnlock()
			e.drainDone = make(chan struct{}, 1)
			e.notificationWaker.Assert()
			<-e.drainDone
			e.mu.RLock()
		}
	case stateConnecting:
		panic(ErrSaveRejection{fmt.Errorf("endpoint in connecting state upon save: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
	case stateConnected:
		// FIXME
		panic(ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%v, remote %v:%v", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)})
	case stateClosed:
	case stateError:
	default:
		panic(fmt.Sprintf("endpoint in unknown state %v", e.state))
	}
}
+
+// afterLoad is invoked by stateify.
+func (e *endpoint) afterLoad() {
+ e.stack = stack.StackFromEnv
+
+ if e.state == stateListen {
+ e.state = stateBound
+ backlog := cap(e.acceptedChan)
+ e.acceptedChan = nil
+ defer func() {
+ if err := e.Listen(backlog); err != nil {
+ panic("endpoint listening failed: " + err.String())
+ }
+ }()
+ }
+
+ if e.state == stateBound {
+ e.state = stateInitial
+ defer func() {
+ if err := e.Bind(tcpip.FullAddress{Addr: e.id.LocalAddress, Port: e.id.LocalPort}, nil); err != nil {
+ panic("endpoint binding failed: " + err.String())
+ }
+ }()
+ }
+
+ if e.state == stateInitial {
+ var ss SendBufferSizeOption
+ if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil {
+ if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.sndBufSize %d is outside the min and max allowed [%d, %d]", e.sndBufSize, ss.Min, ss.Max))
+ }
+ if e.rcvBufSize < ss.Min || e.rcvBufSize > ss.Max {
+ panic(fmt.Sprintf("endpoint.rcvBufSize %d is outside the min and max allowed [%d, %d]", e.rcvBufSize, ss.Min, ss.Max))
+ }
+ }
+ }
+
+ e.segmentQueue.setLimit(2 * e.rcvBufSize)
+ e.workMu.Init()
+}
+
// saveAcceptedChan is invoked by stateify. It converts the accepted-connections
// channel into a plain buffer+capacity pair so it can be serialized. The
// channel is closed and fully drained; if any element was consumed by a
// background goroutine while draining, the invariant is violated and we panic.
func (e *endpoint) saveAcceptedChan() endpointChan {
	if e.acceptedChan == nil {
		return endpointChan{}
	}
	close(e.acceptedChan)
	buffer := make([]*endpoint, 0, len(e.acceptedChan))
	for ep := range e.acceptedChan {
		buffer = append(buffer, ep)
	}
	// len(e.acceptedChan) above captured the count before draining; any
	// mismatch means another goroutine raced with the save.
	if len(buffer) != cap(buffer) {
		panic("endpoint.acceptedChan buffer got consumed by background context")
	}
	c := cap(e.acceptedChan)
	e.acceptedChan = nil
	return endpointChan{buffer: buffer, cap: c}
}
+
+// loadAcceptedChan is invoked by stateify.
+func (e *endpoint) loadAcceptedChan(c endpointChan) {
+ if c.cap == 0 {
+ return
+ }
+ e.acceptedChan = make(chan *endpoint, c.cap)
+ for _, ep := range c.buffer {
+ e.acceptedChan <- ep
+ }
+}
+
// endpointChan is the serializable representation of the accepted-connections
// channel used during save/restore.
type endpointChan struct {
	// buffer holds the endpoints that were queued in the channel.
	buffer []*endpoint
	// cap is the capacity of the original channel.
	cap int
}
diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go
new file mode 100644
index 000000000..657ac524f
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/forwarder.go
@@ -0,0 +1,161 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
// Forwarder is a connection request forwarder, which allows clients to decide
// what to do with a connection request, for example: ignore it, send a RST, or
// attempt to complete the 3-way handshake.
//
// The canonical way of using it is to pass the Forwarder.HandlePacket function
// to stack.SetTransportProtocolHandler.
type Forwarder struct {
	// maxInFlight bounds the number of concurrent connection attempts.
	maxInFlight int
	// handler is invoked (in its own goroutine) for each new request.
	handler func(*ForwarderRequest)

	// mu protects inFlight.
	mu       sync.Mutex
	inFlight map[stack.TransportEndpointID]struct{}
	listen   *listenContext
}
+
+// NewForwarder allocates and initializes a new forwarder with the given
+// maximum number of in-flight connection attempts. Once the maximum is reached
+// new incoming connection requests will be ignored.
+//
+// If rcvWnd is set to zero, the default buffer size is used instead.
+func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*ForwarderRequest)) *Forwarder {
+ if rcvWnd == 0 {
+ rcvWnd = DefaultBufferSize
+ }
+ return &Forwarder{
+ maxInFlight: maxInFlight,
+ handler: handler,
+ inFlight: make(map[stack.TransportEndpointID]struct{}),
+ listen: newListenContext(s, seqnum.Size(rcvWnd), true, 0),
+ }
+}
+
// HandlePacket handles a packet if it is of interest to the forwarder (i.e., if
// it's a SYN packet), returning true if it's the case. Otherwise the packet
// is not handled and false is returned.
//
// This function is expected to be passed as an argument to the
// stack.SetTransportProtocolHandler function.
func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) bool {
	s := newSegment(r, id, vv)
	defer s.decRef()

	// We only care about well-formed SYN packets. Note the exact equality:
	// a segment with any additional flag set (e.g. SYN|ACK) is ignored.
	if !s.parse() || s.flags != flagSyn {
		return false
	}

	opts := parseSynSegmentOptions(s)

	f.mu.Lock()
	defer f.mu.Unlock()

	// We have an inflight request for this id, ignore this one for now.
	if _, ok := f.inFlight[id]; ok {
		return true
	}

	// Ignore the segment if we're beyond the limit.
	if len(f.inFlight) >= f.maxInFlight {
		return true
	}

	// Launch a new goroutine to handle the request. The extra reference
	// taken here is released when the request is completed.
	f.inFlight[id] = struct{}{}
	s.incRef()
	go f.handler(&ForwarderRequest{ // S/R-FIXME
		forwarder:  f,
		segment:    s,
		synOptions: opts,
	})

	return true
}
+
// ForwarderRequest represents a connection request received by the forwarder
// and passed to the client. Clients must eventually call Complete() on it, and
// may optionally create an endpoint to represent it via CreateEndpoint.
type ForwarderRequest struct {
	// mu serializes Complete and CreateEndpoint.
	mu sync.Mutex
	// forwarder and segment are nilled out once the request is completed.
	forwarder  *Forwarder
	segment    *segment
	synOptions header.TCPSynOptions
}
+
// ID returns the 4-tuple (src address, src port, dst address, dst port) that
// represents the connection request.
func (r *ForwarderRequest) ID() stack.TransportEndpointID {
	return r.segment.id
}
+
// Complete completes the request, and optionally sends a RST segment back to the
// sender. It must be called exactly once per request; calling it again panics.
func (r *ForwarderRequest) Complete(sendReset bool) {
	r.mu.Lock()
	defer r.mu.Unlock()

	// A nil segment means Complete was already called.
	if r.segment == nil {
		panic("Completing already completed forwarder request")
	}

	// Remove request from the forwarder.
	r.forwarder.mu.Lock()
	delete(r.forwarder.inFlight, r.segment.id)
	r.forwarder.mu.Unlock()

	// If the caller requested, send a reset.
	if sendReset {
		replyWithReset(r.segment)
	}

	// Release all resources (drops the reference taken in HandlePacket).
	r.segment.decRef()
	r.segment = nil
	r.forwarder = nil
}
+
// CreateEndpoint creates a TCP endpoint for the connection request, performing
// the 3-way handshake in the process. Returns ErrInvalidEndpointState if the
// request was already completed.
func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
	r.mu.Lock()
	defer r.mu.Unlock()

	if r.segment == nil {
		return nil, tcpip.ErrInvalidEndpointState
	}

	f := r.forwarder
	// Replay the SYN options captured in HandlePacket when performing the
	// handshake.
	ep, err := f.listen.createEndpointAndPerformHandshake(r.segment, &header.TCPSynOptions{
		MSS:           r.synOptions.MSS,
		WS:            r.synOptions.WS,
		TS:            r.synOptions.TS,
		TSVal:         r.synOptions.TSVal,
		TSEcr:         r.synOptions.TSEcr,
		SACKPermitted: r.synOptions.SACKPermitted,
	})
	if err != nil {
		return nil, err
	}

	// Start the protocol goroutine.
	ep.startAcceptedLoop(queue)

	return ep, nil
}
diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go
new file mode 100644
index 000000000..d81a1dd9b
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/protocol.go
@@ -0,0 +1,192 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package tcp contains the implementation of the TCP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing tcp.ProtocolName (or "tcp") as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing tcp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package tcp
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
const (
	// ProtocolName is the string representation of the tcp protocol name.
	ProtocolName = "tcp"

	// ProtocolNumber is the tcp protocol number.
	ProtocolNumber = header.TCPProtocolNumber

	// minBufferSize is the smallest size of a receive or send buffer.
	minBufferSize = 4 << 10 // 4096 bytes.

	// DefaultBufferSize is the default size of the receive and send buffers.
	DefaultBufferSize = 1 << 20 // 1MB

	// maxBufferSize is the largest size a receive or send buffer can grow to.
	maxBufferSize = 4 << 20 // 4MB
)
+
// SACKEnabled option can be used to enable SACK support in the TCP
// protocol. See: https://tools.ietf.org/html/rfc2018.
type SACKEnabled bool

// SendBufferSizeOption allows the default, min and max send buffer sizes for
// TCP endpoints to be queried or configured.
type SendBufferSizeOption struct {
	// Min is the smallest permitted send buffer size, in bytes.
	Min int
	// Default is the send buffer size assigned to new endpoints, in bytes.
	Default int
	// Max is the largest permitted send buffer size, in bytes.
	Max int
}

// ReceiveBufferSizeOption allows the default, min and max receive buffer size
// for TCP endpoints to be queried or configured.
type ReceiveBufferSizeOption struct {
	// Min is the smallest permitted receive buffer size, in bytes.
	Min int
	// Default is the receive buffer size assigned to new endpoints, in bytes.
	Default int
	// Max is the largest permitted receive buffer size, in bytes.
	Max int
}
+
// protocol implements stack.TransportProtocol for TCP.
type protocol struct {
	// mu protects all fields below.
	mu             sync.Mutex
	sackEnabled    bool
	sendBufferSize SendBufferSizeOption
	recvBufferSize ReceiveBufferSizeOption
}
+
// Number returns the tcp protocol number.
func (*protocol) Number() tcpip.TransportProtocolNumber {
	return ProtocolNumber
}

// NewEndpoint creates a new tcp endpoint.
func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
	return newEndpoint(stack, netProto, waiterQueue), nil
}

// MinimumPacketSize returns the minimum valid tcp packet size.
func (*protocol) MinimumPacketSize() int {
	return header.TCPMinimumSize
}

// ParsePorts returns the source and destination ports stored in the given tcp
// packet.
//
// NOTE(review): v is indexed without a length check — presumably the stack
// guarantees at least MinimumPacketSize bytes before calling; confirm against
// the caller.
func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
	h := header.TCP(v)
	return h.SourcePort(), h.DestinationPort(), nil
}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint.
+//
+// RFC 793, page 36, states that "If the connection does not exist (CLOSED) then
+// a reset is sent in response to any incoming segment except another reset. In
+// particular, SYNs addressed to a non-existent connection are rejected by this
+// means."
+func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) bool {
+ s := newSegment(r, id, vv)
+ defer s.decRef()
+
+ if !s.parse() {
+ return false
+ }
+
+ // There's nothing to do if this is already a reset packet.
+ if s.flagIsSet(flagRst) {
+ return true
+ }
+
+ replyWithReset(s)
+ return true
+}
+
+// replyWithReset replies to the given segment with a reset segment.
+func replyWithReset(s *segment) {
+ // Get the seqnum from the packet if the ack flag is set.
+ seq := seqnum.Value(0)
+ if s.flagIsSet(flagAck) {
+ seq = s.ackNumber
+ }
+
+ ack := s.sequenceNumber.Add(s.logicalLen())
+
+ sendTCP(&s.route, s.id, nil, flagRst|flagAck, seq, ack, 0)
+}
+
+// SetOption implements TransportProtocol.SetOption.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case SACKEnabled:
+ p.mu.Lock()
+ p.sackEnabled = bool(v)
+ p.mu.Unlock()
+ return nil
+
+ case SendBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.sendBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ case ReceiveBufferSizeOption:
+ if v.Min <= 0 || v.Default < v.Min || v.Default > v.Max {
+ return tcpip.ErrInvalidOptionValue
+ }
+ p.mu.Lock()
+ p.recvBufferSize = v
+ p.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Option implements TransportProtocol.Option.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *SACKEnabled:
+ p.mu.Lock()
+ *v = SACKEnabled(p.sackEnabled)
+ p.mu.Unlock()
+ return nil
+
+ case *SendBufferSizeOption:
+ p.mu.Lock()
+ *v = p.sendBufferSize
+ p.mu.Unlock()
+ return nil
+
+ case *ReceiveBufferSizeOption:
+ p.mu.Lock()
+ *v = p.recvBufferSize
+ p.mu.Unlock()
+ return nil
+
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
// init registers the TCP protocol factory with the stack, with the default
// buffer-size limits declared above.
func init() {
	stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
		return &protocol{
			sendBufferSize: SendBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
			recvBufferSize: ReceiveBufferSizeOption{minBufferSize, DefaultBufferSize, maxBufferSize},
		}
	})
}
diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go
new file mode 100644
index 000000000..574602105
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/rcv.go
@@ -0,0 +1,208 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+import (
+ "container/heap"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
// receiver holds the state necessary to receive TCP segments and turn them
// into a stream of bytes.
type receiver struct {
	ep *endpoint

	// rcvNxt is the next sequence number expected from the peer.
	rcvNxt seqnum.Value

	// rcvAcc is one beyond the last acceptable sequence number. That is,
	// the "largest" sequence value that the receiver has announced to the
	// its peer that it's willing to accept. This may be different than
	// rcvNxt + rcvWnd if the receive window is reduced; in that case we
	// have to reduce the window as we receive more data instead of
	// shrinking it.
	rcvAcc seqnum.Value

	// rcvWndScale is the window-scale shift applied to advertised windows.
	rcvWndScale uint8

	// closed is set once a FIN has been processed.
	closed bool

	// pendingRcvdSegments holds out-of-order segments, ordered by sequence
	// number (min-heap).
	pendingRcvdSegments segmentHeap
	// pendingBufUsed/pendingBufSize bound how much out-of-order data is
	// buffered.
	pendingBufUsed seqnum.Size
	pendingBufSize seqnum.Size
}
+
+func newReceiver(ep *endpoint, irs seqnum.Value, rcvWnd seqnum.Size, rcvWndScale uint8) *receiver {
+ return &receiver{
+ ep: ep,
+ rcvNxt: irs + 1,
+ rcvAcc: irs.Add(rcvWnd + 1),
+ rcvWndScale: rcvWndScale,
+ pendingBufSize: rcvWnd,
+ }
+}
+
+// acceptable checks if the segment sequence number range is acceptable
+// according to the table on page 26 of RFC 793.
+func (r *receiver) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
+ rcvWnd := r.rcvNxt.Size(r.rcvAcc)
+ if rcvWnd == 0 {
+ return segLen == 0 && segSeq == r.rcvNxt
+ }
+
+ return segSeq.InWindow(r.rcvNxt, rcvWnd) ||
+ seqnum.Overlap(r.rcvNxt, rcvWnd, segSeq, segLen)
+}
+
// getSendParams returns the parameters needed by the sender when building
// segments to send.
//
// Note that this grows rcvAcc (the announced right edge) when more buffer
// space has become available; the right edge is never moved backwards.
func (r *receiver) getSendParams() (rcvNxt seqnum.Value, rcvWnd seqnum.Size) {
	// Calculate the window size based on the current buffer size.
	n := r.ep.receiveBufferAvailable()
	acc := r.rcvNxt.Add(seqnum.Size(n))
	if r.rcvAcc.LessThan(acc) {
		r.rcvAcc = acc
	}

	// The window is advertised scaled down by rcvWndScale.
	return r.rcvNxt, r.rcvNxt.Size(r.rcvAcc) >> r.rcvWndScale
}
+
// nonZeroWindow is called when the receive window grows from zero to nonzero;
// in such cases we may need to send an ack to indicate to our peer that it can
// resume sending data.
func (r *receiver) nonZeroWindow() {
	// The peer only ever saw the scaled-down window, so the window was
	// announced as zero only if (rcvAcc-rcvNxt)>>rcvWndScale was zero.
	if (r.rcvAcc-r.rcvNxt)>>r.rcvWndScale != 0 {
		// We never got around to announcing a zero window size, so we
		// don't need to immediately announce a nonzero one.
		return
	}

	// Immediately send an ack.
	r.ep.snd.sendAck()
}
+
// consumeSegment attempts to consume a segment that was received by r. The
// segment may have just been received or may have been received earlier but
// wasn't ready to be consumed then.
//
// Returns true if the segment was consumed, false if it cannot be consumed
// yet because of a missing segment.
func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum.Size) bool {
	if segLen > 0 {
		// If the segment doesn't include the seqnum we're expecting to
		// consume now, we're missing a segment. We cannot proceed until
		// we receive that segment though.
		if !r.rcvNxt.InWindow(segSeq, segLen) {
			return false
		}

		// Trim segment to eliminate already acknowledged data. Note
		// that segSeq, the segment's own sequence number, and its data
		// are all advanced together so they stay consistent.
		if segSeq.LessThan(r.rcvNxt) {
			diff := segSeq.Size(r.rcvNxt)
			segLen -= diff
			segSeq.UpdateForward(diff)
			s.sequenceNumber.UpdateForward(diff)
			s.data.TrimFront(int(diff))
		}

		// Move segment to ready-to-deliver list. Wakeup any waiters.
		r.ep.readyToRead(s)

	} else if segSeq != r.rcvNxt {
		// Zero-length segments are only consumable at exactly rcvNxt.
		return false
	}

	// Update the segment that we're expecting to consume.
	r.rcvNxt = segSeq.Add(segLen)

	// Trim SACK Blocks to remove any SACK information that covers
	// sequence numbers that have been consumed.
	TrimSACKBlockList(&r.ep.sack, r.rcvNxt)

	if s.flagIsSet(flagFin) {
		// The FIN consumes one sequence number.
		r.rcvNxt++

		// Send ACK immediately.
		r.ep.snd.sendAck()

		// Tell any readers that no more data will come.
		r.closed = true
		r.ep.readyToRead(nil)

		// Flush out any pending segments, except the very first one if
		// it happens to be the one we're handling now because the
		// caller is using it.
		first := 0
		if len(r.pendingRcvdSegments) != 0 && r.pendingRcvdSegments[0] == s {
			first = 1
		}

		for i := first; i < len(r.pendingRcvdSegments); i++ {
			r.pendingRcvdSegments[i].decRef()
		}
		r.pendingRcvdSegments = r.pendingRcvdSegments[:first]
	}

	return true
}
+
// handleRcvdSegment handles TCP segments directed at the connection managed by
// r as they arrive. It is called by the protocol main loop.
func (r *receiver) handleRcvdSegment(s *segment) {
	// We don't care about receive processing anymore if the receive side
	// is closed.
	if r.closed {
		return
	}

	segLen := seqnum.Size(s.data.Size())
	segSeq := s.sequenceNumber

	// If the sequence number range is outside the acceptable range, just
	// send an ACK. This is according to RFC 793, page 37.
	if !r.acceptable(segSeq, segLen) {
		r.ep.snd.sendAck()
		return
	}

	// Defer segment processing if it can't be consumed now.
	if !r.consumeSegment(s, segSeq, segLen) {
		if segLen > 0 || s.flagIsSet(flagFin) {
			// We only store the segment if it's within our buffer
			// size limit. Segments beyond the limit are dropped
			// here and will be retransmitted by the peer.
			if r.pendingBufUsed < r.pendingBufSize {
				r.pendingBufUsed += s.logicalLen()
				s.incRef()
				heap.Push(&r.pendingRcvdSegments, s)
			}

			// Record the out-of-order range so it can be
			// advertised via the SACK option.
			UpdateSACKBlocks(&r.ep.sack, segSeq, segSeq.Add(segLen), r.rcvNxt)

			// Immediately send an ack so that the peer knows it may
			// have to retransmit.
			r.ep.snd.sendAck()
		}
		return
	}

	// By consuming the current segment, we may have filled a gap in the
	// sequence number domain that allows pending segments to be consumed
	// now. So try to do it.
	for !r.closed && r.pendingRcvdSegments.Len() > 0 {
		s := r.pendingRcvdSegments[0]
		segLen := seqnum.Size(s.data.Size())
		segSeq := s.sequenceNumber

		// Skip segment altogether if it has already been acknowledged.
		if !segSeq.Add(segLen-1).LessThan(r.rcvNxt) &&
			!r.consumeSegment(s, segSeq, segLen) {
			break
		}

		// Drop the reference taken when the segment was queued.
		heap.Pop(&r.pendingRcvdSegments)
		r.pendingBufUsed -= s.logicalLen()
		s.decRef()
	}
}
diff --git a/pkg/tcpip/transport/tcp/sack.go b/pkg/tcpip/transport/tcp/sack.go
new file mode 100644
index 000000000..0b66305a5
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/sack.go
@@ -0,0 +1,85 @@
+package tcp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
const (
	// MaxSACKBlocks is the maximum number of SACK blocks stored
	// at receiver side.
	MaxSACKBlocks = 6
)
+
// UpdateSACKBlocks updates the list of SACK blocks to include the segment
// specified by segStart->segEnd. If the segment happens to be an out of order
// delivery then the first block in the sack.blocks always includes the
// segment identified by segStart->segEnd.
func UpdateSACKBlocks(sack *SACKInfo, segStart seqnum.Value, segEnd seqnum.Value, rcvNxt seqnum.Value) {
	newSB := header.SACKBlock{Start: segStart, End: segEnd}
	if sack.NumBlocks == 0 {
		sack.Blocks[0] = newSB
		sack.NumBlocks = 1
		return
	}
	// n counts blocks kept in place; blocks merged into newSB or invalid
	// are compacted away.
	var n = 0
	for i := 0; i < sack.NumBlocks; i++ {
		start, end := sack.Blocks[i].Start, sack.Blocks[i].End
		if end.LessThanEq(start) || start.LessThanEq(rcvNxt) {
			// Discard any invalid blocks where end is before start
			// and discard any sack blocks that are before rcvNxt as
			// those have already been acked.
			continue
		}
		if newSB.Start.LessThanEq(end) && start.LessThanEq(newSB.End) {
			// Merge this SACK block into newSB and discard this SACK
			// block.
			if start.LessThan(newSB.Start) {
				newSB.Start = start
			}
			if newSB.End.LessThan(end) {
				newSB.End = end
			}
		} else {
			// Save this block.
			sack.Blocks[n] = sack.Blocks[i]
			n++
		}
	}
	if rcvNxt.LessThan(newSB.Start) {
		// If this was an out of order segment then make sure that the
		// first SACK block is the one that includes the segment.
		//
		// See the first bullet point in
		// https://tools.ietf.org/html/rfc2018#section-4
		if n == MaxSACKBlocks {
			// If the number of SACK blocks is equal to
			// MaxSACKBlocks then discard the last SACK block.
			n--
		}
		// Shift existing blocks right by one and put newSB at the
		// front.
		for i := n - 1; i >= 0; i-- {
			sack.Blocks[i+1] = sack.Blocks[i]
		}
		sack.Blocks[0] = newSB
		n++
	}
	sack.NumBlocks = n
}
+
+// TrimSACKBlockList updates the sack block list by removing/modifying any block
+// where start is < rcvNxt.
+func TrimSACKBlockList(sack *SACKInfo, rcvNxt seqnum.Value) {
+ n := 0
+ for i := 0; i < sack.NumBlocks; i++ {
+ if sack.Blocks[i].End.LessThanEq(rcvNxt) {
+ continue
+ }
+ if sack.Blocks[i].Start.LessThan(rcvNxt) {
+ // Shrink this SACK block.
+ sack.Blocks[i].Start = rcvNxt
+ }
+ sack.Blocks[n] = sack.Blocks[i]
+ n++
+ }
+ sack.NumBlocks = n
+}
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
new file mode 100644
index 000000000..c742fc394
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -0,0 +1,145 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+import (
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
// Flags that may be set in a TCP segment. The values match the bit layout of
// the flags byte in the TCP header (FIN is the least significant bit).
const (
	flagFin = 1 << iota
	flagSyn
	flagRst
	flagPsh
	flagAck
	flagUrg
)
+
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
// segment is mostly immutable, the only field allowed to change is viewToDeliver.
type segment struct {
	segmentEntry
	// refCnt is manipulated atomically via incRef/decRef; the route is
	// released when it drops to zero.
	refCnt int32
	id     stack.TransportEndpointID
	route  stack.Route
	data   buffer.VectorisedView
	// views is used as buffer for data when its length is large
	// enough to store a VectorisedView.
	views [8]buffer.View
	// viewToDeliver keeps track of the next View that should be
	// delivered by the Read endpoint.
	viewToDeliver  int
	sequenceNumber seqnum.Value
	ackNumber      seqnum.Value
	flags          uint8
	window         seqnum.Size

	// parsedOptions stores the parsed values from the options in the segment.
	parsedOptions header.TCPOptions
	options       []byte
}
+
+func newSegment(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) *segment {
+ s := &segment{
+ refCnt: 1,
+ id: id,
+ route: r.Clone(),
+ }
+ s.data = vv.Clone(s.views[:])
+ return s
+}
+
+func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment {
+ s := &segment{
+ refCnt: 1,
+ id: id,
+ route: r.Clone(),
+ }
+ s.views[0] = v
+ s.data = buffer.NewVectorisedView(len(v), s.views[:1])
+ return s
+}
+
// clone returns a deep copy of the segment with its own reference count,
// cloned route, and data copied into the new segment's views buffer. The
// parsed options are not copied.
func (s *segment) clone() *segment {
	t := &segment{
		refCnt:         1,
		id:             s.id,
		sequenceNumber: s.sequenceNumber,
		ackNumber:      s.ackNumber,
		flags:          s.flags,
		window:         s.window,
		route:          s.route.Clone(),
		viewToDeliver:  s.viewToDeliver,
	}
	t.data = s.data.Clone(t.views[:])
	return t
}
+
// flagIsSet returns whether the given flag bit is set in the segment's flags.
func (s *segment) flagIsSet(flag uint8) bool {
	return (s.flags & flag) != 0
}

// decRef drops a reference; when the last reference is dropped the held route
// is released.
func (s *segment) decRef() {
	if atomic.AddInt32(&s.refCnt, -1) == 0 {
		s.route.Release()
	}
}

// incRef takes an additional reference on the segment.
func (s *segment) incRef() {
	atomic.AddInt32(&s.refCnt, 1)
}
+
+// logicalLen is the segment length in the sequence number space. It's defined
+// as the data length plus one for each of the SYN and FIN bits set.
+func (s *segment) logicalLen() seqnum.Size {
+ l := seqnum.Size(s.data.Size())
+ if s.flagIsSet(flagSyn) {
+ l++
+ }
+ if s.flagIsSet(flagFin) {
+ l++
+ }
+ return l
+}
+
// parse populates the sequence & ack numbers, flags, and window fields of the
// segment from the TCP header stored in the data. It then updates the view to
// skip the data. Returns boolean indicating if the parsing was successful.
func (s *segment) parse() bool {
	h := header.TCP(s.data.First())

	// h is the header followed by the payload. We check that the offset to
	// the data respects the following constraints:
	// 1. That it's at least the minimum header size; if we don't do this
	//    then part of the header would be delivered to user.
	// 2. That the header fits within the buffer; if we don't do this, we
	//    would panic when we tried to access data beyond the buffer.
	//
	// N.B. The segment has already been validated as having at least the
	// minimum TCP size before reaching here, so it's safe to read the
	// fields.
	offset := int(h.DataOffset())
	if offset < header.TCPMinimumSize || offset > len(h) {
		return false
	}

	// Split off the options bytes and strip the whole header from the
	// payload view.
	s.options = []byte(h[header.TCPMinimumSize:offset])
	s.parsedOptions = header.ParseTCPOptions(s.options)
	s.data.TrimFront(offset)

	s.sequenceNumber = seqnum.Value(h.SequenceNumber())
	s.ackNumber = seqnum.Value(h.AckNumber())
	s.flags = h.Flags()
	s.window = seqnum.Size(h.WindowSize())

	return true
}
diff --git a/pkg/tcpip/transport/tcp/segment_heap.go b/pkg/tcpip/transport/tcp/segment_heap.go
new file mode 100644
index 000000000..137ddbdd2
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_heap.go
@@ -0,0 +1,36 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+// segmentHeap implements heap.Interface over TCP segments, ordered by
+// sequence number.
+type segmentHeap []*segment
+
+// Len returns the length of h.
+func (h segmentHeap) Len() int {
+	return len(h)
+}
+
+// Less determines whether the i-th element of h is less than the j-th element.
+func (h segmentHeap) Less(i, j int) bool {
+	return h[i].sequenceNumber.LessThan(h[j].sequenceNumber)
+}
+
+// Swap swaps the i-th and j-th elements of h.
+func (h segmentHeap) Swap(i, j int) {
+	h[i], h[j] = h[j], h[i]
+}
+
+// Push adds x as the last element of h.
+func (h *segmentHeap) Push(x interface{}) {
+	*h = append(*h, x.(*segment))
+}
+
+// Pop removes the last element of h and returns it.
+func (h *segmentHeap) Pop() interface{} {
+	old := *h
+	n := len(old)
+	x := old[n-1]
+	// Clear the vacated slot so the backing array doesn't retain a
+	// reference to the popped segment, allowing it to be collected.
+	old[n-1] = nil
+	*h = old[:n-1]
+	return x
+}
diff --git a/pkg/tcpip/transport/tcp/segment_queue.go b/pkg/tcpip/transport/tcp/segment_queue.go
new file mode 100644
index 000000000..c4a7f7d5b
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/segment_queue.go
@@ -0,0 +1,69 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+// segmentQueue is a bounded, thread-safe queue of TCP segments.
+type segmentQueue struct {
+	mu    sync.Mutex  // mu protects all fields below.
+	list  segmentList // list holds queued segments in arrival order.
+	limit int         // limit is the maximum number of bytes the queue may hold.
+	used  int         // used is the number of bytes currently queued (data + header size per segment).
+}
+
+// empty determines if the queue is empty.
+func (q *segmentQueue) empty() bool {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	return q.used == 0
+}
+
+
+// setLimit updates the limit. No segments are immediately dropped in case the
+// queue becomes full due to the new limit.
+func (q *segmentQueue) setLimit(limit int) {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+	q.limit = limit
+}
+
+// enqueue adds the given segment to the queue.
+//
+// Returns true when the segment is successfully added to the queue, in which
+// case ownership of the reference is transferred to the queue. And returns
+// false if the queue is full, in which case ownership is retained by the
+// caller.
+func (q *segmentQueue) enqueue(s *segment) bool {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	if q.used >= q.limit {
+		// Queue is full; reject the segment.
+		return false
+	}
+	q.list.PushBack(s)
+	q.used += s.data.Size() + header.TCPMinimumSize
+	return true
+}
+
+// dequeue removes and returns the next segment from queue, if one exists.
+// Ownership is transferred to the caller, who is responsible for decrementing
+// the ref count when done.
+func (q *segmentQueue) dequeue() *segment {
+	q.mu.Lock()
+	defer q.mu.Unlock()
+
+	s := q.list.Front()
+	if s == nil {
+		return nil
+	}
+	q.list.Remove(s)
+	q.used -= s.data.Size() + header.TCPMinimumSize
+	return s
+}
diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go
new file mode 100644
index 000000000..ad94aecd8
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/snd.go
@@ -0,0 +1,628 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+import (
+ "math"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+const (
+	// minRTO is the minimum allowed value for the retransmit timeout.
+	minRTO = 200 * time.Millisecond
+
+	// InitialCwnd is the initial congestion window, in packets.
+	InitialCwnd = 10
+)
+
+// sender holds the state necessary to send TCP segments.
+type sender struct {
+	ep *endpoint
+
+	// lastSendTime is the timestamp when the last packet was sent.
+	lastSendTime time.Time
+
+	// dupAckCount is the number of duplicated acks received. It is used for
+	// fast retransmit.
+	dupAckCount int
+
+	// fr holds state related to fast recovery.
+	fr fastRecovery
+
+	// sndCwnd is the congestion window, in packets.
+	sndCwnd int
+
+	// sndSsthresh is the threshold between slow start and congestion
+	// avoidance.
+	sndSsthresh int
+
+	// sndCAAckCount is the number of packets acknowledged during congestion
+	// avoidance. When enough packets have been ack'd (typically cwnd
+	// packets), the congestion window is incremented by one.
+	sndCAAckCount int
+
+	// outstanding is the number of outstanding packets, that is, packets
+	// that have been sent but not yet acknowledged.
+	outstanding int
+
+	// sndWnd is the send window size.
+	sndWnd seqnum.Size
+
+	// sndUna is the next unacknowledged sequence number.
+	sndUna seqnum.Value
+
+	// sndNxt is the sequence number of the next segment to be sent.
+	sndNxt seqnum.Value
+
+	// sndNxtList is the sequence number of the next segment to be added to
+	// the send list.
+	sndNxtList seqnum.Value
+
+	// rttMeasureSeqNum is the sequence number being used for the latest RTT
+	// measurement.
+	rttMeasureSeqNum seqnum.Value
+
+	// rttMeasureTime is the time when the rttMeasureSeqNum was sent.
+	rttMeasureTime time.Time
+
+	closed      bool
+	writeNext   *segment
+	writeList   segmentList
+	resendTimer timer       `state:"nosave"`
+	resendWaker sleep.Waker `state:"nosave"`
+
+	// srtt, rttvar & rto are the "smoothed round-trip time", "round-trip
+	// time variation" and "retransmit timeout", as defined in section 2 of
+	// RFC 6298.
+	srtt       time.Duration
+	rttvar     time.Duration
+	rto        time.Duration
+	srttInited bool
+
+	// maxPayloadSize is the maximum size of the payload of a given segment.
+	// It is initialized on demand.
+	maxPayloadSize int
+
+	// sndWndScale is the number of bits to shift left when reading the send
+	// window size from a segment.
+	sndWndScale uint8
+
+	// maxSentAck is the maximum acknowledgement actually sent.
+	maxSentAck seqnum.Value
+}
+
+// fastRecovery holds information related to fast recovery from a packet loss.
+type fastRecovery struct {
+	// active whether the endpoint is in fast recovery. The following fields
+	// are only meaningful when active is true.
+	active bool
+
+	// first and last represent the inclusive sequence number range being
+	// recovered.
+	first seqnum.Value
+	last  seqnum.Value
+
+	// maxCwnd is the maximum value the congestion window may be inflated to
+	// due to duplicate acks. This exists to avoid attacks where the
+	// receiver intentionally sends duplicate acks to artificially inflate
+	// the sender's cwnd.
+	maxCwnd int
+}
+
+// newSender creates a sender for the given endpoint, initialized from the
+// handshake results: iss and irs are the initial send/receive sequence
+// numbers, sndWnd is the peer's advertised window, mss the maximum segment
+// size, and sndWndScale the peer's window scale (negative when scaling is
+// not in use).
+func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint16, sndWndScale int) *sender {
+	s := &sender{
+		ep:               ep,
+		sndCwnd:          InitialCwnd,
+		sndSsthresh:      math.MaxInt64,
+		sndWnd:           sndWnd,
+		sndUna:           iss + 1,
+		sndNxt:           iss + 1,
+		sndNxtList:       iss + 1,
+		rto:              1 * time.Second,
+		rttMeasureSeqNum: iss + 1,
+		lastSendTime:     time.Now(),
+		maxPayloadSize:   int(mss),
+		maxSentAck:       irs + 1,
+		fr: fastRecovery{
+			// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 1.
+			last: iss,
+		},
+	}
+
+	// A negative sndWndScale means that no scaling is in use, otherwise we
+	// store the scaling value.
+	if sndWndScale > 0 {
+		s.sndWndScale = uint8(sndWndScale)
+	}
+
+	// Clamp the payload size to what the route's MTU allows.
+	s.updateMaxPayloadSize(int(ep.route.MTU()), 0)
+
+	s.resendTimer.init(&s.resendWaker)
+
+	return s
+}
+
+// updateMaxPayloadSize updates the maximum payload size based on the given
+// MTU. If this is in response to "packet too big" control packets (indicated
+// by the count argument), it also reduces the number of outstanding packets and
+// attempts to retransmit the first packet above the MTU size.
+func (s *sender) updateMaxPayloadSize(mtu, count int) {
+	m := mtu - header.TCPMinimumSize
+
+	// Calculate the maximum option size.
+	var maxSackBlocks [header.TCPMaxSACKBlocks]header.SACKBlock
+	options := s.ep.makeOptions(maxSackBlocks[:])
+	m -= len(options)
+	putOptions(options)
+
+	// We don't adjust up for now.
+	if m >= s.maxPayloadSize {
+		return
+	}
+
+	// Make sure we can transmit at least one byte.
+	if m <= 0 {
+		m = 1
+	}
+
+	s.maxPayloadSize = m
+
+	// Account for the packets reported too big; never go negative.
+	s.outstanding -= count
+	if s.outstanding < 0 {
+		s.outstanding = 0
+	}
+
+	// Rewind writeNext to the first segment exceeding the MTU. Do nothing
+	// if it is already before such a packet.
+	for seg := s.writeList.Front(); seg != nil; seg = seg.Next() {
+		if seg == s.writeNext {
+			// We got to writeNext before we could find a segment
+			// exceeding the MTU.
+			break
+		}
+
+		if seg.data.Size() > m {
+			// We found a segment exceeding the MTU. Rewind
+			// writeNext and try to retransmit it.
+			s.writeNext = seg
+			break
+		}
+	}
+
+	// Since we likely reduced the number of outstanding packets, we may be
+	// ready to send some more.
+	s.sendData()
+}
+
+// sendAck sends a pure ACK segment (no payload) carrying sndNxt.
+func (s *sender) sendAck() {
+	s.sendSegment(nil, flagAck, s.sndNxt)
+}
+
+// updateRTO updates the retransmit timeout when a new round-trip time is
+// available. This is done in accordance with section 2 of RFC 6298.
+func (s *sender) updateRTO(rtt time.Duration) {
+	if !s.srttInited {
+		// First measurement: RFC 6298 (2.2).
+		s.rttvar = rtt / 2
+		s.srtt = rtt
+		s.srttInited = true
+	} else {
+		// Subsequent measurements: RFC 6298 (2.3), with the standard
+		// alpha=1/8 and beta=1/4 gains.
+		diff := s.srtt - rtt
+		if diff < 0 {
+			diff = -diff
+		}
+		s.rttvar = (3*s.rttvar + diff) / 4
+		s.srtt = (7*s.srtt + rtt) / 8
+	}
+
+	s.rto = s.srtt + 4*s.rttvar
+	if s.rto < minRTO {
+		s.rto = minRTO
+	}
+}
+
+// resendSegment resends the first unacknowledged segment.
+func (s *sender) resendSegment() {
+	// Don't use any segments we already sent to measure RTT as they may
+	// have been affected by packets being lost.
+	s.rttMeasureSeqNum = s.sndNxt
+
+	// Resend the segment at the front of the write list, if any.
+	if seg := s.writeList.Front(); seg != nil {
+		s.sendSegment(&seg.data, seg.flags, seg.sequenceNumber)
+	}
+}
+
+// reduceSlowStartThreshold reduces the slow-start threshold per RFC 5681,
+// page 6, eq. 4. It is called when we detect congestion in the network.
+func (s *sender) reduceSlowStartThreshold() {
+	// Halve the flight size, but never drop below 2 segments.
+	if half := s.outstanding / 2; half >= 2 {
+		s.sndSsthresh = half
+	} else {
+		s.sndSsthresh = 2
+	}
+}
+
+// retransmitTimerExpired is called when the retransmit timer expires, and
+// unacknowledged segments are assumed lost, and thus need to be resent.
+// Returns true if the connection is still usable, or false if the connection
+// is deemed lost.
+func (s *sender) retransmitTimerExpired() bool {
+	// Check if the timer actually expired or if it's a spurious wake due
+	// to a previously orphaned runtime timer.
+	if !s.resendTimer.checkExpiration() {
+		return true
+	}
+
+	// Give up if we've waited more than a minute since the last resend.
+	if s.rto >= 60*time.Second {
+		return false
+	}
+
+	// Set new timeout. The timer will be restarted by the call to sendData
+	// below. Exponential backoff per RFC 6298 (5.5).
+	s.rto *= 2
+
+	if s.fr.active {
+		// We were attempting fast recovery but were not successful.
+		// Leave the state. We don't need to update ssthresh because it
+		// has already been updated when entered fast-recovery.
+		s.leaveFastRecovery()
+	}
+
+	// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 4.
+	// We store the highest sequence number transmitted in cases where
+	// we were not in fast recovery.
+	s.fr.last = s.sndNxt - 1
+
+	// We lost a packet, so reduce ssthresh.
+	s.reduceSlowStartThreshold()
+
+	// Reduce the congestion window to 1, i.e., enter slow-start. Per
+	// RFC 5681, page 7, we must use 1 regardless of the value of the
+	// initial congestion window.
+	s.sndCwnd = 1
+
+	// Mark the next segment to be sent as the first unacknowledged one and
+	// start sending again. Set the number of outstanding packets to 0 so
+	// that we'll be able to retransmit.
+	//
+	// We'll keep on transmitting (or retransmitting) as we get acks for
+	// the data we transmit.
+	s.outstanding = 0
+	s.writeNext = s.writeList.Front()
+	s.sendData()
+
+	return true
+}
+
+// sendData sends new data segments. It is called when data becomes available or
+// when the send window opens up.
+func (s *sender) sendData() {
+	limit := s.maxPayloadSize
+
+	// Reduce the congestion window to min(IW, cwnd) per RFC 5681, page 10.
+	// "A TCP SHOULD set cwnd to no more than RW before beginning
+	// transmission if the TCP has not sent data in the interval exceeding
+	// the retransmission timeout."
+	if !s.fr.active && time.Now().Sub(s.lastSendTime) > s.rto {
+		if s.sndCwnd > InitialCwnd {
+			s.sndCwnd = InitialCwnd
+		}
+	}
+
+	// TODO: We currently don't merge multiple send buffers
+	// into one segment if they happen to fit. We should do that
+	// eventually.
+	var seg *segment
+	end := s.sndUna.Add(s.sndWnd)
+	for seg = s.writeNext; seg != nil && s.outstanding < s.sndCwnd; seg = seg.Next() {
+		// We abuse the flags field to determine if we have already
+		// assigned a sequence number to this segment.
+		if seg.flags == 0 {
+			seg.sequenceNumber = s.sndNxt
+			seg.flags = flagAck | flagPsh
+		}
+
+		var segEnd seqnum.Value
+		if seg.data.Size() == 0 {
+			// Zero-length segment: this is the connection-closing
+			// segment (FIN, or RST if unread data remains).
+			seg.flags = flagAck
+
+			s.ep.rcvListMu.Lock()
+			rcvBufUsed := s.ep.rcvBufUsed
+			s.ep.rcvListMu.Unlock()
+
+			s.ep.mu.Lock()
+			// We're sending a FIN by default
+			fl := flagFin
+			if (s.ep.shutdownFlags&tcpip.ShutdownRead) != 0 && rcvBufUsed > 0 {
+				// If there is unread data we must send a RST.
+				// For more information see RFC 2525 section 2.17.
+				fl = flagRst
+			}
+			s.ep.mu.Unlock()
+			seg.flags |= uint8(fl)
+
+			segEnd = seg.sequenceNumber.Add(1)
+		} else {
+			// We're sending a non-FIN segment.
+			if !seg.sequenceNumber.LessThan(end) {
+				break
+			}
+
+			available := int(seg.sequenceNumber.Size(end))
+			if available > limit {
+				available = limit
+			}
+
+			if seg.data.Size() > available {
+				// Split this segment up.
+				nSeg := seg.clone()
+				nSeg.data.TrimFront(available)
+				nSeg.sequenceNumber.UpdateForward(seqnum.Size(available))
+				s.writeList.InsertAfter(seg, nSeg)
+				seg.data.CapLength(available)
+			}
+
+			s.outstanding++
+			segEnd = seg.sequenceNumber.Add(seqnum.Size(seg.data.Size()))
+		}
+
+		s.sendSegment(&seg.data, seg.flags, seg.sequenceNumber)
+
+		// Update sndNxt if we actually sent new data (as opposed to
+		// retransmitting some previously sent data).
+		if s.sndNxt.LessThan(segEnd) {
+			s.sndNxt = segEnd
+		}
+	}
+
+	// Remember the next segment we'll write.
+	s.writeNext = seg
+
+	// Enable the timer if we have pending data and it's not enabled yet.
+	if !s.resendTimer.enabled() && s.sndUna != s.sndNxt {
+		s.resendTimer.enable(s.rto)
+	}
+}
+
+// enterFastRecovery starts a fast-recovery episode after three duplicate
+// acks have been observed (NewReno, RFC 6582).
+func (s *sender) enterFastRecovery() {
+	// Reduce ssthresh now that loss has been detected.
+	s.reduceSlowStartThreshold()
+	// Save state to reflect we're now in fast recovery.
+	// See : https://tools.ietf.org/html/rfc5681#section-3.2 Step 3.
+	// We inflate the cwnd by 3 to account for the 3 packets which triggered
+	// the 3 duplicate ACKs and are now not in flight.
+	s.sndCwnd = s.sndSsthresh + 3
+	s.fr.first = s.sndUna
+	s.fr.last = s.sndNxt - 1
+	s.fr.maxCwnd = s.sndCwnd + s.outstanding
+	s.fr.active = true
+}
+
+// leaveFastRecovery exits fast recovery, resetting the related state and
+// deflating the congestion window back to ssthresh.
+func (s *sender) leaveFastRecovery() {
+	s.fr.active = false
+	s.fr.first = 0
+	s.fr.last = s.sndNxt - 1
+	s.fr.maxCwnd = 0
+	s.dupAckCount = 0
+
+	// Deflate cwnd. It had been artificially inflated when new dups arrived.
+	s.sndCwnd = s.sndSsthresh
+}
+
+// checkDuplicateAck is called when an ack is received. It manages the state
+// related to duplicate acks and determines if a retransmit is needed according
+// to the rules in RFC 6582 (NewReno). Returns true when the caller should
+// retransmit the first unacknowledged segment.
+func (s *sender) checkDuplicateAck(seg *segment) bool {
+	ack := seg.ackNumber
+	if s.fr.active {
+		// We are in fast recovery mode. Ignore the ack if it's out of
+		// range.
+		if !ack.InRange(s.sndUna, s.sndNxt+1) {
+			return false
+		}
+
+		// Leave fast recovery if it acknowledges all the data covered by
+		// this fast recovery session.
+		if s.fr.last.LessThan(ack) {
+			s.leaveFastRecovery()
+			return false
+		}
+
+		// Don't count this as a duplicate if it is carrying data or
+		// updating the window.
+		if seg.logicalLen() != 0 || s.sndWnd != seg.window {
+			return false
+		}
+
+		// Inflate the congestion window if we're getting duplicate acks
+		// for the packet we retransmitted.
+		if ack == s.fr.first {
+			// We received a dup, inflate the congestion window by 1
+			// packet if we're not at the max yet.
+			if s.sndCwnd < s.fr.maxCwnd {
+				s.sndCwnd++
+			}
+			return false
+		}
+
+		// A partial ack was received. Retransmit this packet and
+		// remember it so that we don't retransmit it again. We don't
+		// inflate the window because we're putting the same packet back
+		// onto the wire.
+		//
+		// N.B. The retransmit timer will be reset by the caller.
+		s.fr.first = ack
+		return true
+	}
+
+	// We're not in fast recovery yet. A segment is considered a duplicate
+	// only if it doesn't carry any data and doesn't update the send window,
+	// because if it does, it wasn't sent in response to an out-of-order
+	// segment.
+	if ack != s.sndUna || seg.logicalLen() != 0 || s.sndWnd != seg.window || ack == s.sndNxt {
+		s.dupAckCount = 0
+		return false
+	}
+
+	// Enter fast recovery when we reach 3 dups.
+	s.dupAckCount++
+	if s.dupAckCount != 3 {
+		return false
+	}
+
+	// See: https://tools.ietf.org/html/rfc6582#section-3.2 Step 2
+	//
+	// We only do the check here, the incrementing of last to the highest
+	// sequence number transmitted till now is done when enterFastRecovery
+	// is invoked.
+	if !s.fr.last.LessThan(seg.ackNumber) {
+		s.dupAckCount = 0
+		return false
+	}
+	s.enterFastRecovery()
+	s.dupAckCount = 0
+	return true
+}
+
+// updateCwnd updates the congestion window based on the number of packets that
+// were acknowledged. Growth is exponential while below ssthresh (slow start)
+// and roughly one packet per RTT above it (congestion avoidance).
+func (s *sender) updateCwnd(packetsAcked int) {
+	if s.sndCwnd < s.sndSsthresh {
+		// Don't let the congestion window cross into the congestion
+		// avoidance range.
+		newcwnd := s.sndCwnd + packetsAcked
+		if newcwnd >= s.sndSsthresh {
+			newcwnd = s.sndSsthresh
+			s.sndCAAckCount = 0
+		}
+
+		packetsAcked -= newcwnd - s.sndCwnd
+		s.sndCwnd = newcwnd
+		if packetsAcked == 0 {
+			// We've consumed all ack'd packets.
+			return
+		}
+	}
+
+	// Consume the packets in congestion avoidance mode.
+	s.sndCAAckCount += packetsAcked
+	if s.sndCAAckCount >= s.sndCwnd {
+		s.sndCwnd += s.sndCAAckCount / s.sndCwnd
+		s.sndCAAckCount = s.sndCAAckCount % s.sndCwnd
+	}
+}
+
+// handleRcvdSegment is called when a segment is received; it is responsible for
+// updating the send-related state.
+func (s *sender) handleRcvdSegment(seg *segment) {
+	// Check if we can extract an RTT measurement from this ack.
+	if s.rttMeasureSeqNum.LessThan(seg.ackNumber) {
+		s.updateRTO(time.Now().Sub(s.rttMeasureTime))
+		s.rttMeasureSeqNum = s.sndNxt
+	}
+
+	// Update Timestamp if required. See RFC7323, section-4.3.
+	s.ep.updateRecentTimestamp(seg.parsedOptions.TSVal, s.maxSentAck, seg.sequenceNumber)
+
+	// Count the duplicates and do the fast retransmit if needed.
+	rtx := s.checkDuplicateAck(seg)
+
+	// Stash away the current window size.
+	s.sndWnd = seg.window
+
+	// Ignore ack if it doesn't acknowledge any new data.
+	ack := seg.ackNumber
+	if (ack - 1).InRange(s.sndUna, s.sndNxt) {
+		// When an ack is received we must reset the timer. We stop it
+		// here and it will be restarted later if needed.
+		s.resendTimer.disable()
+
+		// Remove all acknowledged data from the write list.
+		acked := s.sndUna.Size(ack)
+		s.sndUna = ack
+
+		ackLeft := acked
+		originalOutstanding := s.outstanding
+		for ackLeft > 0 {
+			// We use logicalLen here because we can have FIN
+			// segments (which are always at the end of list) that
+			// have no data, but do consume a sequence number.
+			seg := s.writeList.Front()
+			datalen := seg.logicalLen()
+
+			if datalen > ackLeft {
+				// Partially acknowledged segment: trim the
+				// ack'd prefix and keep it on the list.
+				seg.data.TrimFront(int(ackLeft))
+				break
+			}
+
+			if s.writeNext == seg {
+				s.writeNext = seg.Next()
+			}
+			s.writeList.Remove(seg)
+			s.outstanding--
+			seg.decRef()
+			ackLeft -= datalen
+		}
+
+		// Update the send buffer usage and notify potential waiters.
+		s.ep.updateSndBufferUsage(int(acked))
+
+		// If we are not in fast recovery then update the congestion
+		// window based on the number of acknowledged packets.
+		if !s.fr.active {
+			s.updateCwnd(originalOutstanding - s.outstanding)
+		}
+
+		// It is possible for s.outstanding to drop below zero if we get
+		// a retransmit timeout, reset outstanding to zero but later
+		// get an ack that covers previously sent data.
+		if s.outstanding < 0 {
+			s.outstanding = 0
+		}
+	}
+
+	// Now that we've popped all acknowledged data from the retransmit
+	// queue, retransmit if needed.
+	if rtx {
+		s.resendSegment()
+	}
+
+	// Send more data now that some of the pending data has been ack'd, or
+	// that the window opened up, or the congestion window was inflated due
+	// to a duplicate ack during fast recovery. This will also re-enable
+	// the retransmit timer if needed.
+	s.sendData()
+}
+
+// sendSegment sends a new segment containing the given payload, flags and
+// sequence number. A nil data argument sends a segment with no payload.
+func (s *sender) sendSegment(data *buffer.VectorisedView, flags byte, seq seqnum.Value) *tcpip.Error {
+	s.lastSendTime = time.Now()
+	if seq == s.rttMeasureSeqNum {
+		s.rttMeasureTime = s.lastSendTime
+	}
+
+	rcvNxt, rcvWnd := s.ep.rcv.getSendParams()
+
+	// Remember the max sent ack.
+	s.maxSentAck = rcvNxt
+
+	if data == nil {
+		return s.ep.sendRaw(nil, flags, seq, rcvNxt, rcvWnd)
+	}
+
+	// Only single-view payloads are supported on the send path.
+	if len(data.Views()) > 1 {
+		panic("send path does not support views with multiple buffers")
+	}
+
+	return s.ep.sendRaw(data.First(), flags, seq, rcvNxt, rcvWnd)
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_sack_test.go b/pkg/tcpip/transport/tcp/tcp_sack_test.go
new file mode 100644
index 000000000..19de96dcb
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_sack_test.go
@@ -0,0 +1,336 @@
+package tcp_test
+
+import (
+ "fmt"
+ "reflect"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
+)
+
+// createConnectedWithSACKPermittedOption creates and connects c.ep with the
+// SACKPermitted option enabled if the stack in the context has the SACK support
+// enabled.
+func createConnectedWithSACKPermittedOption(c *context.Context) *context.RawEndpoint {
+	return c.CreateConnectedWithOptions(header.TCPSynOptions{SACKPermitted: c.SACKEnabled()})
+}
+
+// setStackSACKPermitted enables or disables SACK support on the context's
+// stack, failing the test on error.
+func setStackSACKPermitted(t *testing.T, c *context.Context, enable bool) {
+	t.Helper()
+	if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enable)); err != nil {
+		t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, SACKEnabled(%v) = %v", enable, err)
+	}
+}
+
+// TestSackPermittedConnect establishes a connection with the SACK option
+// enabled and verifies that SACK blocks are generated for an out-of-order
+// segment only when the stack has SACK support enabled.
+func TestSackPermittedConnect(t *testing.T) {
+	for _, sackEnabled := range []bool{false, true} {
+		t.Run(fmt.Sprintf("stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+
+			setStackSACKPermitted(t, c, sackEnabled)
+			rep := createConnectedWithSACKPermittedOption(c)
+			data := []byte{1, 2, 3}
+
+			rep.SendPacket(data, nil)
+			savedSeqNum := rep.NextSeqNum
+			rep.VerifyACKNoSACK()
+
+			// Make an out of order packet and send it.
+			rep.NextSeqNum += 3
+			sackBlocks := []header.SACKBlock{
+				{rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
+			}
+			rep.SendPacket(data, nil)
+
+			// Restore the saved sequence number so that the
+			// VerifyXXX calls use the right sequence number for
+			// checking ACK numbers.
+			rep.NextSeqNum = savedSeqNum
+			if sackEnabled {
+				rep.VerifyACKHasSACK(sackBlocks)
+			} else {
+				rep.VerifyACKNoSACK()
+			}
+
+			// Send the missing segment.
+			rep.SendPacket(data, nil)
+			// The ACK should contain the cumulative ACK for all 9
+			// bytes sent and no SACK blocks.
+			rep.NextSeqNum += 3
+			// Check that no SACK block is returned in the ACK.
+			rep.VerifyACKNoSACK()
+		})
+	}
+}
+
+// TestSackDisabledConnect establishes a connection with the SACK option
+// disabled (regardless of stack-level SACK support) and verifies that no
+// SACKs are sent for out of order segments.
+func TestSackDisabledConnect(t *testing.T) {
+	for _, sackEnabled := range []bool{false, true} {
+		t.Run(fmt.Sprintf("sackEnabled: %v", sackEnabled), func(t *testing.T) {
+			c := context.New(t, defaultMTU)
+			defer c.Cleanup()
+
+			setStackSACKPermitted(t, c, sackEnabled)
+
+			// Connect without the SACKPermitted option.
+			rep := c.CreateConnectedWithOptions(header.TCPSynOptions{})
+
+			data := []byte{1, 2, 3}
+
+			rep.SendPacket(data, nil)
+			savedSeqNum := rep.NextSeqNum
+			rep.VerifyACKNoSACK()
+
+			// Make an out of order packet and send it.
+			rep.NextSeqNum += 3
+			rep.SendPacket(data, nil)
+
+			// The ACK should contain the older sequence number and
+			// no SACK blocks.
+			rep.NextSeqNum = savedSeqNum
+			rep.VerifyACKNoSACK()
+
+			// Send the missing segment.
+			rep.SendPacket(data, nil)
+			// The ACK should contain the cumulative ACK for all 9
+			// bytes sent and no SACK blocks.
+			rep.NextSeqNum += 3
+			// Check that no SACK block is returned in the ACK.
+			rep.VerifyACKNoSACK()
+		})
+	}
+}
+
+// TestSackPermittedAccept accepts and establishes a connection with the
+// SACKPermitted option enabled if the connection request specifies the
+// SACKPermitted option. In case of SYN cookies SACK should be disabled as we
+// don't encode the SACK information in the cookie.
+func TestSackPermittedAccept(t *testing.T) {
+	type testCase struct {
+		cookieEnabled bool
+		sackPermitted bool
+		wndScale      int
+		wndSize       uint16
+	}
+
+	testCases := []testCase{
+		{true, false, -1, 0xffff}, // When cookie is used window scaling is disabled.
+		{false, true, 5, 0x8000},  // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+	}
+	// SynRcvdCountThreshold is mutated to force (or prevent) SYN-cookie
+	// use; restore it after the test.
+	savedSynCountThreshold := tcp.SynRcvdCountThreshold
+	defer func() {
+		tcp.SynRcvdCountThreshold = savedSynCountThreshold
+	}()
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
+			if tc.cookieEnabled {
+				tcp.SynRcvdCountThreshold = 0
+			} else {
+				tcp.SynRcvdCountThreshold = savedSynCountThreshold
+			}
+			for _, sackEnabled := range []bool{false, true} {
+				t.Run(fmt.Sprintf("test stack.sackEnabled: %v", sackEnabled), func(t *testing.T) {
+					c := context.New(t, defaultMTU)
+					defer c.Cleanup()
+					setStackSACKPermitted(t, c, sackEnabled)
+
+					rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, SACKPermitted: tc.sackPermitted})
+					// Now verify no SACK blocks are
+					// received when sack is disabled.
+					data := []byte{1, 2, 3}
+					rep.SendPacket(data, nil)
+					rep.VerifyACKNoSACK()
+
+					savedSeqNum := rep.NextSeqNum
+
+					// Make an out of order packet and send
+					// it.
+					rep.NextSeqNum += 3
+					sackBlocks := []header.SACKBlock{
+						{rep.NextSeqNum, rep.NextSeqNum.Add(seqnum.Size(len(data)))},
+					}
+					rep.SendPacket(data, nil)
+
+					// The ACK should contain the older
+					// sequence number.
+					rep.NextSeqNum = savedSeqNum
+					if sackEnabled && tc.sackPermitted {
+						rep.VerifyACKHasSACK(sackBlocks)
+					} else {
+						rep.VerifyACKNoSACK()
+					}
+
+					// Send the missing segment.
+					rep.SendPacket(data, nil)
+					// The ACK should contain the cumulative
+					// ACK for all 9 bytes sent and no SACK
+					// blocks.
+					rep.NextSeqNum += 3
+					// Check that no SACK block is returned
+					// in the ACK.
+					rep.VerifyACKNoSACK()
+				})
+			}
+		})
+	}
+}
+
+// TestSackDisabledAccept accepts and establishes a connection with
+// the SACKPermitted option disabled and verifies that no SACKs are
+// sent for out of order packets.
+func TestSackDisabledAccept(t *testing.T) {
+	type testCase struct {
+		cookieEnabled bool
+		wndScale      int
+		wndSize       uint16
+	}
+
+	testCases := []testCase{
+		{true, -1, 0xffff}, // When cookie is used window scaling is disabled.
+		{false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+	}
+	// SynRcvdCountThreshold is mutated to force (or prevent) SYN-cookie
+	// use; restore it after the test.
+	savedSynCountThreshold := tcp.SynRcvdCountThreshold
+	defer func() {
+		tcp.SynRcvdCountThreshold = savedSynCountThreshold
+	}()
+	for _, tc := range testCases {
+		t.Run(fmt.Sprintf("test: %#v", tc), func(t *testing.T) {
+			if tc.cookieEnabled {
+				tcp.SynRcvdCountThreshold = 0
+			} else {
+				tcp.SynRcvdCountThreshold = savedSynCountThreshold
+			}
+			for _, sackEnabled := range []bool{false, true} {
+				t.Run(fmt.Sprintf("test: sackEnabled: %v", sackEnabled), func(t *testing.T) {
+					c := context.New(t, defaultMTU)
+					defer c.Cleanup()
+					setStackSACKPermitted(t, c, sackEnabled)
+
+					rep := c.AcceptWithOptions(tc.wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+					// Now verify no SACK blocks are
+					// received when sack is disabled.
+					data := []byte{1, 2, 3}
+					rep.SendPacket(data, nil)
+					rep.VerifyACKNoSACK()
+					savedSeqNum := rep.NextSeqNum
+
+					// Make an out of order packet and send
+					// it.
+					rep.NextSeqNum += 3
+					rep.SendPacket(data, nil)
+
+					// The ACK should contain the older
+					// sequence number and no SACK blocks.
+					rep.NextSeqNum = savedSeqNum
+					rep.VerifyACKNoSACK()
+
+					// Send the missing segment.
+					rep.SendPacket(data, nil)
+					// The ACK should contain the cumulative
+					// ACK for all 9 bytes sent and no SACK
+					// blocks.
+					rep.NextSeqNum += 3
+					// Check that no SACK block is returned
+					// in the ACK.
+					rep.VerifyACKNoSACK()
+				})
+			}
+		})
+	}
+}
+
+// TestUpdateSACKBlocks exercises tcp.UpdateSACKBlocks with a table of
+// segment ranges and existing SACK block lists, checking the updated list.
+func TestUpdateSACKBlocks(t *testing.T) {
+	testCases := []struct {
+		segStart   seqnum.Value
+		segEnd     seqnum.Value
+		rcvNxt     seqnum.Value
+		sackBlocks []header.SACKBlock
+		updated    []header.SACKBlock
+	}{
+		// Trivial cases where current SACK block list is empty and we
+		// have an out of order delivery.
+		{10, 11, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 11}}},
+		{10, 12, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 12}}},
+		{10, 20, 2, []header.SACKBlock{}, []header.SACKBlock{{10, 20}}},
+
+		// Cases where current SACK block list is not empty and we have
+		// an out of order delivery. Tests that the updated SACK block
+		// list has the first block as the one that contains the new
+		// SACK block representing the segment that was just delivered.
+		{10, 11, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 11}, {12, 20}}},
+		{24, 30, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{24, 30}, {12, 20}}},
+		{24, 30, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{24, 30}, {12, 20}, {32, 40}}},
+
+		// Ensure that we only retain header.MaxSACKBlocks and drop the
+		// oldest one if adding a new block exceeds
+		// header.MaxSACKBlocks.
+		{24, 30, 9,
+			[]header.SACKBlock{{12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}, {72, 80}},
+			[]header.SACKBlock{{24, 30}, {12, 20}, {32, 40}, {42, 50}, {52, 60}, {62, 70}}},
+
+		// Cases where segment extends an existing SACK block.
+		{10, 12, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 20}}},
+		{10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
+		{10, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{10, 22}}},
+		{15, 22, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 22}}},
+		{15, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{12, 25}}},
+		{11, 25, 9, []header.SACKBlock{{12, 20}}, []header.SACKBlock{{11, 25}}},
+		{10, 12, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 20}, {32, 40}}},
+		{10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
+		{10, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{10, 22}, {32, 40}}},
+		{15, 22, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 22}, {32, 40}}},
+		{15, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{12, 25}, {32, 40}}},
+		{11, 25, 9, []header.SACKBlock{{12, 20}, {32, 40}}, []header.SACKBlock{{11, 25}, {32, 40}}},
+
+		// Cases where segment contains rcvNxt.
+		{10, 20, 15, []header.SACKBlock{{20, 30}, {40, 50}}, []header.SACKBlock{{40, 50}}},
+	}
+
+	for _, tc := range testCases {
+		var sack tcp.SACKInfo
+		copy(sack.Blocks[:], tc.sackBlocks)
+		sack.NumBlocks = len(tc.sackBlocks)
+		tcp.UpdateSACKBlocks(&sack, tc.segStart, tc.segEnd, tc.rcvNxt)
+		if got, want := sack.Blocks[:sack.NumBlocks], tc.updated; !reflect.DeepEqual(got, want) {
+			t.Errorf("UpdateSACKBlocks(%v, %v, %v, %v), got: %v, want: %v", tc.sackBlocks, tc.segStart, tc.segEnd, tc.rcvNxt, got, want)
+		}
+
+	}
+}
+
+// TestTrimSackBlockList exercises tcp.TrimSACKBlockList, verifying that
+// blocks fully or partially covered by rcvNxt are removed or shrunk.
+func TestTrimSackBlockList(t *testing.T) {
+	testCases := []struct {
+		rcvNxt     seqnum.Value
+		sackBlocks []header.SACKBlock
+		trimmed    []header.SACKBlock
+	}{
+		// Simple cases where we trim whole entries.
+		{2, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}},
+		{21, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{22, 30}, {32, 40}}},
+		{31, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{32, 40}}},
+		{40, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
+		// Cases where we need to update a block.
+		{12, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{12, 20}, {22, 30}, {32, 40}}},
+		{23, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{23, 30}, {32, 40}}},
+		{33, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{{33, 40}}},
+		{41, []header.SACKBlock{{10, 20}, {22, 30}, {32, 40}}, []header.SACKBlock{}},
+	}
+	for _, tc := range testCases {
+		var sack tcp.SACKInfo
+		copy(sack.Blocks[:], tc.sackBlocks)
+		sack.NumBlocks = len(tc.sackBlocks)
+		tcp.TrimSACKBlockList(&sack, tc.rcvNxt)
+		if got, want := sack.Blocks[:sack.NumBlocks], tc.trimmed; !reflect.DeepEqual(got, want) {
+			t.Errorf("TrimSackBlockList(%v, %v), got: %v, want: %v", tc.sackBlocks, tc.rcvNxt, got, want)
+		}
+	}
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
new file mode 100644
index 000000000..118d861ba
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -0,0 +1,2759 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp_test
+
+import (
+ "bytes"
+ "fmt"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/loopback"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+	// defaultMTU is the MTU, in bytes, used throughout the tests, except
+	// where another value is explicitly used. It is chosen to match the MTU
+	// of loopback interfaces on linux systems.
+	defaultMTU = 65535
+
+	// defaultIPv4MSS is the MSS sent by the network stack in SYN/SYN-ACK for an
+	// IPv4 endpoint when the MTU is set to defaultMTU in the test.
+	// It is derived from defaultMTU by subtracting the minimum IPv4 and
+	// TCP header sizes.
+	defaultIPv4MSS = defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
+)
+
+// TestGiveUpConnect verifies that closing an endpoint while a connection
+// attempt is still in flight aborts the attempt: the endpoint becomes
+// writable and reports tcpip.ErrAborted.
+func TestGiveUpConnect(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	var wq waiter.Queue
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	// Register for notification, then start connection attempt.
+	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&waitEntry, waiter.EventOut)
+	defer wq.EventUnregister(&waitEntry)
+
+	// Connect is non-blocking; ErrConnectStarted indicates the handshake
+	// is in progress (nothing is responding in this test, so it can't
+	// complete).
+	if err := ep.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort}); err != tcpip.ErrConnectStarted {
+		t.Fatalf("Unexpected return value from Connect: %v", err)
+	}
+
+	// Close the connection, wait for completion.
+	ep.Close()
+
+	// Wait for ep to become writable.
+	<-notifyCh
+	if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
+		t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %v, want = %v", err, tcpip.ErrAborted)
+	}
+}
+
+// TestActiveHandshake verifies that an active (client-initiated) 3-way
+// handshake completes; CreateConnected performs the SYN/SYN-ACK/ACK
+// exchange and fails the test internally if it does not succeed.
+func TestActiveHandshake(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, nil)
+}
+
+// TestNonBlockingClose verifies that Close on a connected endpoint
+// returns promptly rather than lingering for the protocol-level
+// teardown to finish.
+func TestNonBlockingClose(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, nil)
+	ep := c.EP
+	c.EP = nil
+
+	// Close the endpoint and measure how long it takes.
+	t0 := time.Now()
+	ep.Close()
+	// time.Since is the idiomatic spelling of time.Now().Sub(t0).
+	if diff := time.Since(t0); diff > 3*time.Second {
+		t.Fatalf("Took too long to close: %v", diff)
+	}
+}
+
+// TestConnectResetAfterClose closes an endpoint, ACKs its FIN but never
+// sends a FIN back, and verifies the endpoint eventually gives up and
+// sends a RST.
+func TestConnectResetAfterClose(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, nil)
+	ep := c.EP
+	c.EP = nil
+
+	// Close the endpoint, make sure we get a FIN segment, then acknowledge
+	// to complete closure of sender, but don't send our own FIN.
+	ep.Close()
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790),
+			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+		),
+	)
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Wait for the ep to give up waiting for a FIN, and send a RST.
+	time.Sleep(3 * time.Second)
+	for {
+		b := c.GetPacket()
+		tcp := header.TCP(header.IPv4(b).Payload())
+		if tcp.Flags() == header.TCPFlagAck|header.TCPFlagFin {
+			// This is a retransmit of the FIN, ignore it.
+			continue
+		}
+
+		// Anything other than a retransmitted FIN must be the RST.
+		checker.IPv4(t, b,
+			checker.TCP(
+				checker.DstPort(context.TestPort),
+				checker.SeqNum(uint32(c.IRS)+1),
+				checker.AckNum(790),
+				checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+			),
+		)
+		break
+	}
+}
+
+// TestSimpleReceive sends a small in-order segment and verifies the data
+// is readable from the endpoint and acknowledged on the wire.
+func TestSimpleReceive(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, nil)
+
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	// Nothing has arrived yet, so a read must not block but fail.
+	if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+		t.Fatalf("Unexpected error from Read: %v", err)
+	}
+
+	data := []byte{1, 2, 3}
+	c.SendPacket(data, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Wait for receive to be notified.
+	select {
+	case <-ch:
+	case <-time.After(1 * time.Second):
+		t.Fatalf("Timed out waiting for data to arrive")
+	}
+
+	// Receive data.
+	v, err := c.EP.Read(nil)
+	if err != nil {
+		t.Fatalf("Unexpected error from Read: %v", err)
+	}
+
+	// bytes.Equal is the idiomatic equality check (staticcheck S1004).
+	if !bytes.Equal(data, v) {
+		t.Fatalf("Data is different: expected %v, got %v", data, v)
+	}
+
+	// Check that ACK is received.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(790+len(data))),
+			checker.TCPFlags(header.TCPFlagAck),
+		),
+	)
+}
+
+// TestOutOfOrderReceive delivers the second half of a buffer before the
+// first and verifies that the endpoint holds the out-of-order data,
+// re-ACKs the expected sequence number, and only surfaces the data to
+// the reader once the gap is filled.
+func TestOutOfOrderReceive(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, nil)
+
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+		t.Fatalf("Unexpected error from Read: %v", err)
+	}
+
+	// Send second half of data first, with seqnum 3 ahead of expected.
+	data := []byte{1, 2, 3, 4, 5, 6}
+	c.SendPacket(data[3:], &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  793,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Check that we get an ACK specifying which seqnum is expected.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790),
+			checker.TCPFlags(header.TCPFlagAck),
+		),
+	)
+
+	// Wait 200ms and check that no data has been received.
+	time.Sleep(200 * time.Millisecond)
+	if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+		t.Fatalf("Unexpected error from Read: %v", err)
+	}
+
+	// Send the first 3 bytes now.
+	c.SendPacket(data[:3], &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Receive data.
+	read := make([]byte, 0, 6)
+	for len(read) < len(data) {
+		v, err := c.EP.Read(nil)
+		if err != nil {
+			if err == tcpip.ErrWouldBlock {
+				// Wait for receive to be notified.
+				select {
+				case <-ch:
+				case <-time.After(5 * time.Second):
+					t.Fatalf("Timed out waiting for data to arrive")
+				}
+				continue
+			}
+			t.Fatalf("Unexpected error from Read: %v", err)
+		}
+
+		read = append(read, v...)
+	}
+
+	// Check that we received the data in proper order.
+	// bytes.Equal is the idiomatic equality check (staticcheck S1004).
+	if !bytes.Equal(data, read) {
+		t.Fatalf("Data is different: expected %v, got %v", data, read)
+	}
+
+	// Check that the whole data is acknowledged.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(790+len(data))),
+			checker.TCPFlags(header.TCPFlagAck),
+		),
+	)
+}
+
+// TestOutOfOrderFlood fills the out-of-order buffer with many segments,
+// then verifies a further out-of-order segment is discarded: after the
+// expected in-order segment arrives, only it is acknowledged.
+func TestOutOfOrderFlood(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Create a new connection with initial window size of 10.
+	opt := tcpip.ReceiveBufferSizeOption(10)
+	c.CreateConnected(789, 30000, &opt)
+
+	if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+		t.Fatalf("Unexpected error from Read: %v", err)
+	}
+
+	// Send 100 packets before the actual one that is expected.
+	data := []byte{1, 2, 3, 4, 5, 6}
+	for i := 0; i < 100; i++ {
+		c.SendPacket(data[3:], &context.Headers{
+			SrcPort: context.TestPort,
+			DstPort: c.Port,
+			Flags:   header.TCPFlagAck,
+			SeqNum:  796,
+			AckNum:  c.IRS.Add(1),
+			RcvWnd:  30000,
+		})
+
+		// Each out-of-order segment elicits a duplicate ACK for 790.
+		checker.IPv4(t, c.GetPacket(),
+			checker.TCP(
+				checker.DstPort(context.TestPort),
+				checker.SeqNum(uint32(c.IRS)+1),
+				checker.AckNum(790),
+				checker.TCPFlags(header.TCPFlagAck),
+			),
+		)
+	}
+
+	// Send packet with seqnum 793. It must be discarded because the
+	// out-of-order buffer was filled by the previous packets.
+	c.SendPacket(data[3:], &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  793,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790),
+			checker.TCPFlags(header.TCPFlagAck),
+		),
+	)
+
+	// Now send the expected packet, seqnum 790.
+	c.SendPacket(data[:3], &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Check that only packet 790 is acknowledged (793 was dropped, so
+	// the cumulative ACK cannot advance past 793).
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(793),
+			checker.TCPFlags(header.TCPFlagAck),
+		),
+	)
+}
+
+// TestRstOnCloseWithUnreadData verifies that closing an endpoint while
+// received data is still unread causes a RST (not a FIN) to be sent.
+func TestRstOnCloseWithUnreadData(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, nil)
+
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+		t.Fatalf("Unexpected error from Read: %v", err)
+	}
+
+	data := []byte{1, 2, 3}
+	c.SendPacket(data, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Wait for receive to be notified.
+	select {
+	case <-ch:
+	case <-time.After(3 * time.Second):
+		t.Fatalf("Timed out waiting for data to arrive")
+	}
+
+	// Check that ACK is received, this happens regardless of the read.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(790+len(data))),
+			checker.TCPFlags(header.TCPFlagAck),
+		),
+	)
+
+	// Now that we know we have unread data, let's just close the connection
+	// and verify that netstack sends an RST rather than a FIN.
+	c.EP.Close()
+
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagRst),
+		))
+}
+
+// TestFullWindowReceive fills the receive window completely, checks the
+// advertised window drops to zero, and verifies a window update is sent
+// once the data is read.
+func TestFullWindowReceive(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	opt := tcpip.ReceiveBufferSizeOption(10)
+	c.CreateConnected(789, 30000, &opt)
+
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	_, err := c.EP.Read(nil)
+	if err != tcpip.ErrWouldBlock {
+		t.Fatalf("Unexpected error from Read: %v", err)
+	}
+
+	// Fill up the window.
+	data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+	c.SendPacket(data, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Wait for receive to be notified.
+	select {
+	case <-ch:
+	case <-time.After(5 * time.Second):
+		t.Fatalf("Timed out waiting for data to arrive")
+	}
+
+	// Check that data is acknowledged, and window goes to zero.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(790+len(data))),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.Window(0),
+		),
+	)
+
+	// Receive data and check it.
+	v, err := c.EP.Read(nil)
+	if err != nil {
+		t.Fatalf("Unexpected error from Read: %v", err)
+	}
+
+	// bytes.Equal is the idiomatic equality check (staticcheck S1004).
+	if !bytes.Equal(data, v) {
+		t.Fatalf("Data is different: expected %v, got %v", data, v)
+	}
+
+	// Check that we get an ACK for the newly non-zero window.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(790+len(data))),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.Window(10),
+		),
+	)
+}
+
+// TestNoWindowShrinking shrinks the receive buffer from 10 to 5 after
+// connecting and verifies the advertised window never shrinks abruptly:
+// it decays as data arrives and only re-opens at the new, smaller size.
+func TestNoWindowShrinking(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Start off with a window size of 10, then shrink it to 5.
+	opt := tcpip.ReceiveBufferSizeOption(10)
+	c.CreateConnected(789, 30000, &opt)
+
+	opt = 5
+	if err := c.EP.SetSockOpt(opt); err != nil {
+		t.Fatalf("SetSockOpt failed: %v", err)
+	}
+
+	we, ch := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&we, waiter.EventIn)
+	defer c.WQ.EventUnregister(&we)
+
+	_, err := c.EP.Read(nil)
+	if err != tcpip.ErrWouldBlock {
+		t.Fatalf("Unexpected error from Read: %v", err)
+	}
+
+	// Send 3 bytes, check that the peer acknowledges them.
+	data := []byte{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}
+	c.SendPacket(data[:3], &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Wait for receive to be notified.
+	select {
+	case <-ch:
+	case <-time.After(5 * time.Second):
+		t.Fatalf("Timed out waiting for data to arrive")
+	}
+
+	// Check that data is acknowledged, and that window doesn't go to zero
+	// just yet because it was previously set to 10. It must go to 7 now.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(793),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.Window(7),
+		),
+	)
+
+	// Send 7 more bytes, check that the window fills up.
+	c.SendPacket(data[3:], &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  793,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	select {
+	case <-ch:
+	case <-time.After(5 * time.Second):
+		t.Fatalf("Timed out waiting for data to arrive")
+	}
+
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(790+len(data))),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.Window(0),
+		),
+	)
+
+	// Receive data and check it.
+	read := make([]byte, 0, 10)
+	for len(read) < len(data) {
+		v, err := c.EP.Read(nil)
+		if err != nil {
+			t.Fatalf("Unexpected error from Read: %v", err)
+		}
+
+		read = append(read, v...)
+	}
+
+	// bytes.Equal is the idiomatic equality check (staticcheck S1004).
+	if !bytes.Equal(data, read) {
+		t.Fatalf("Data is different: expected %v, got %v", data, read)
+	}
+
+	// Check that we get an ACK for the newly non-zero window, which is the
+	// new size.
+	checker.IPv4(t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(790+len(data))),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.Window(5),
+		),
+	)
+}
+
+// TestSimpleSend writes a small buffer and verifies it arrives on the
+// wire with the expected sequence/ack numbers, then ACKs it.
+func TestSimpleSend(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, nil)
+
+	data := []byte{1, 2, 3}
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Unexpected error from Write: %v", err)
+	}
+
+	// Check that data is received.
+	b := c.GetPacket()
+	checker.IPv4(t, b,
+		checker.PayloadLen(len(data)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+
+	// bytes.Equal is the idiomatic equality check (staticcheck S1004).
+	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+		t.Fatalf("Data is different: expected %v, got %v", data, p)
+	}
+
+	// Acknowledge the data.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1 + seqnum.Size(len(data))),
+		RcvWnd:  30000,
+	})
+}
+
+// TestZeroWindowSend connects with a zero peer receive window, verifies
+// no data is transmitted, then opens the window and checks the queued
+// data flows out.
+func TestZeroWindowSend(t *testing.T) {
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 0, nil)
+
+	data := []byte{1, 2, 3}
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	_, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
+	if err != nil {
+		t.Fatalf("Unexpected error from Write: %v", err)
+	}
+
+	// Since the window is currently zero, check that no packet is received.
+	c.CheckNoPacket("Packet received when window is zero")
+
+	// Open up the window. Data should be received now.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+	})
+
+	// Check that data is received.
+	b := c.GetPacket()
+	checker.IPv4(t, b,
+		checker.PayloadLen(len(data)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+
+	// bytes.Equal is the idiomatic equality check (staticcheck S1004).
+	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(data, p) {
+		t.Fatalf("Data is different: expected %v, got %v", data, p)
+	}
+
+	// Acknowledge the data.
+	c.SendPacket(nil, &context.Headers{
+		SrcPort: context.TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  790,
+		AckNum:  c.IRS.Add(1 + seqnum.Size(len(data))),
+		RcvWnd:  30000,
+	})
+}
+
+// TestScaledWindowConnect checks window-scale negotiation on an actively
+// opened connection.
+func TestScaledWindowConnect(t *testing.T) {
+	// This test ensures that window scaling is used when the peer
+	// does advertise it and connection is established with Connect().
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Set the window size greater than the maximum non-scaled window.
+	opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
+	// Peer advertises a window-scale option (shift 3) in its SYN-ACK.
+	c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+		header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
+	})
+
+	data := []byte{1, 2, 3}
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Unexpected error from Write: %v", err)
+	}
+
+	// Check that data is received, and that advertised window is 0xbfff,
+	// that is, that it is scaled.
+	b := c.GetPacket()
+	checker.IPv4(t, b,
+		checker.PayloadLen(len(data)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790),
+			checker.Window(0xbfff),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+}
+
+// TestNonScaledWindowConnect checks that the advertised window is capped
+// at 0xffff when the peer never offered a window-scale option.
+func TestNonScaledWindowConnect(t *testing.T) {
+	// This test ensures that window scaling is not used when the peer
+	// doesn't advertise it and connection is established with Connect().
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Set the window size greater than the maximum non-scaled window.
+	opt := tcpip.ReceiveBufferSizeOption(65535 * 3)
+	c.CreateConnected(789, 30000, &opt)
+
+	data := []byte{1, 2, 3}
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Unexpected error from Write: %v", err)
+	}
+
+	// Check that data is received, and that advertised window is 0xffff,
+	// that is, that it's not scaled.
+	b := c.GetPacket()
+	checker.IPv4(t, b,
+		checker.PayloadLen(len(data)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790),
+			checker.Window(0xffff),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+}
+
+// TestScaledWindowAccept checks window-scale negotiation on a passively
+// accepted connection.
+func TestScaledWindowAccept(t *testing.T) {
+	// This test ensures that window scaling is used when the peer
+	// does advertise it and connection is established with Accept().
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Create EP and start listening.
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	defer ep.Close()
+
+	// Set the window size greater than the maximum non-scaled window.
+	// (Original message read "SetSockOpt failed failed" — fixed.)
+	if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+		t.Fatalf("SetSockOpt failed: %v", err)
+	}
+
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+
+	// Do 3-way handshake.
+	c.PassiveConnectWithOptions(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %v", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+
+	data := []byte{1, 2, 3}
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Unexpected error from Write: %v", err)
+	}
+
+	// Check that data is received, and that advertised window is 0xbfff,
+	// that is, that it is scaled.
+	b := c.GetPacket()
+	checker.IPv4(t, b,
+		checker.PayloadLen(len(data)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790),
+			checker.Window(0xbfff),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+}
+
+// TestNonScaledWindowAccept checks that a passively accepted connection
+// caps its advertised window at 0xffff when the peer sent no
+// window-scale option.
+func TestNonScaledWindowAccept(t *testing.T) {
+	// This test ensures that window scaling is not used when the peer
+	// doesn't advertise it and connection is established with Accept().
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Create EP and start listening.
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	defer ep.Close()
+
+	// Set the window size greater than the maximum non-scaled window.
+	// (Original message read "SetSockOpt failed failed" — fixed.)
+	if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(65535 * 3)); err != nil {
+		t.Fatalf("SetSockOpt failed: %v", err)
+	}
+
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+
+	// Do 3-way handshake.
+	c.PassiveConnect(100, 2, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %v", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+
+	data := []byte{1, 2, 3}
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Unexpected error from Write: %v", err)
+	}
+
+	// Check that data is received, and that advertised window is 0xffff,
+	// that is, that it's not scaled.
+	b := c.GetPacket()
+	checker.IPv4(t, b,
+		checker.PayloadLen(len(data)+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(790),
+			checker.Window(0xffff),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+}
+
+// TestZeroScaledWindowReceive exercises the transition of the *scaled*
+// advertised window through zero while the unscaled window is non-zero.
+func TestZeroScaledWindowReceive(t *testing.T) {
+	// This test ensures that the endpoint sends a non-zero window size
+	// advertisement when the scaled window transitions from 0 to non-zero,
+	// but the actual window (not scaled) hasn't gotten to zero.
+	c := context.New(t, defaultMTU)
+	defer c.Cleanup()
+
+	// Set the window size such that a window scale of 4 will be used.
+	const wnd = 65535 * 10
+	const ws = uint32(4)
+	opt := tcpip.ReceiveBufferSizeOption(wnd)
+	c.CreateConnectedWithRawOptions(789, 30000, &opt, []byte{
+		header.TCPOptionWS, 3, 0, header.TCPOptionNOP,
+	})
+
+	// Write chunks of 50000 bytes.
+	remain := wnd
+	sent := 0
+	data := make([]byte, 50000)
+	for remain > len(data) {
+		c.SendPacket(data, &context.Headers{
+			SrcPort: context.TestPort,
+			DstPort: c.Port,
+			Flags:   header.TCPFlagAck,
+			SeqNum:  seqnum.Value(790 + sent),
+			AckNum:  c.IRS.Add(1),
+			RcvWnd:  30000,
+		})
+		sent += len(data)
+		remain -= len(data)
+		// Each ACK advertises the remaining buffer, right-shifted by
+		// the window scale.
+		checker.IPv4(t, c.GetPacket(),
+			checker.PayloadLen(header.TCPMinimumSize),
+			checker.TCP(
+				checker.DstPort(context.TestPort),
+				checker.SeqNum(uint32(c.IRS)+1),
+				checker.AckNum(uint32(790+sent)),
+				checker.Window(uint16(remain>>ws)),
+				checker.TCPFlags(header.TCPFlagAck),
+			),
+		)
+	}
+
+	// Make the window non-zero, but the scaled window zero.
+	if remain >= 16 {
+		data = data[:remain-15]
+		c.SendPacket(data, &context.Headers{
+			SrcPort: context.TestPort,
+			DstPort: c.Port,
+			Flags:   header.TCPFlagAck,
+			SeqNum:  seqnum.Value(790 + sent),
+			AckNum:  c.IRS.Add(1),
+			RcvWnd:  30000,
+		})
+		sent += len(data)
+		remain -= len(data)
+		// 15 bytes remain unscaled, which is 0 after >>4.
+		checker.IPv4(t, c.GetPacket(),
+			checker.PayloadLen(header.TCPMinimumSize),
+			checker.TCP(
+				checker.DstPort(context.TestPort),
+				checker.SeqNum(uint32(c.IRS)+1),
+				checker.AckNum(uint32(790+sent)),
+				checker.Window(0),
+				checker.TCPFlags(header.TCPFlagAck),
+			),
+		)
+	}
+
+	// Read some data. An ack should be sent in response to that.
+	v, err := c.EP.Read(nil)
+	if err != nil {
+		t.Fatalf("Unexpected error from Read: %v", err)
+	}
+
+	checker.IPv4(t, c.GetPacket(),
+		checker.PayloadLen(header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(context.TestPort),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(790+sent)),
+			checker.Window(uint16(len(v)>>ws)),
+			checker.TCPFlags(header.TCPFlagAck),
+		),
+	)
+}
+
+// testBrokenUpWrite writes 10*maxPayload bytes through c.EP and verifies
+// the stack segments the write into multiple packets, each carrying
+// in-order payload, ACKing each segment as it arrives.
+func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) {
+	payloadMultiplier := 10
+	dataLen := payloadMultiplier * maxPayload
+	data := make([]byte, dataLen)
+	for i := range data {
+		data[i] = byte(i)
+	}
+
+	view := buffer.NewView(len(data))
+	copy(view, data)
+
+	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+		t.Fatalf("Unexpected error from Write: %v", err)
+	}
+
+	// Check that data is received in chunks.
+	bytesReceived := 0
+	numPackets := 0
+	for bytesReceived != dataLen {
+		b := c.GetPacket()
+		numPackets++
+		tcp := header.TCP(header.IPv4(b).Payload())
+		payloadLen := len(tcp.Payload())
+		checker.IPv4(t, b,
+			checker.TCP(
+				checker.DstPort(context.TestPort),
+				checker.SeqNum(uint32(c.IRS)+1+uint32(bytesReceived)),
+				checker.AckNum(790),
+				checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+			),
+		)
+
+		// bytes.Equal is the idiomatic equality check (staticcheck S1004).
+		pdata := data[bytesReceived : bytesReceived+payloadLen]
+		if p := tcp.Payload(); !bytes.Equal(pdata, p) {
+			t.Fatalf("Data is different: expected %v, got %v", pdata, p)
+		}
+		bytesReceived += payloadLen
+		var options []byte
+		if c.TimeStampEnabled {
+			// If timestamp option is enabled, echo back the timestamp and increment
+			// the TSEcr value included in the packet and send that back as the TSVal.
+			parsedOpts := tcp.ParsedOptions()
+			tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+			header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:])
+			options = tsOpt[:]
+		}
+		// Acknowledge the data.
+		c.SendPacket(nil, &context.Headers{
+			SrcPort: context.TestPort,
+			DstPort: c.Port,
+			Flags:   header.TCPFlagAck,
+			SeqNum:  790,
+			AckNum:  c.IRS.Add(1 + seqnum.Size(bytesReceived)),
+			RcvWnd:  30000,
+			TCPOpts: options,
+		})
+	}
+	if numPackets == 1 {
+		t.Fatalf("expected write to be broken up into multiple packets, but got 1 packet")
+	}
+}
+
+// TestSendGreaterThanMTU verifies that a write larger than the link MTU
+// is segmented; the MTU is sized so each segment carries maxPayload.
+func TestSendGreaterThanMTU(t *testing.T) {
+	const maxPayload = 100
+	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+	defer c.Cleanup()
+
+	c.CreateConnected(789, 30000, nil)
+	testBrokenUpWrite(t, c, maxPayload)
+}
+
+// TestActiveSendMSSLessThanMTU verifies that an actively connected
+// endpoint honors the peer's MSS option when it is smaller than the MTU.
+func TestActiveSendMSSLessThanMTU(t *testing.T) {
+	const maxPayload = 100
+	c := context.New(t, 65535)
+	defer c.Cleanup()
+
+	// Peer advertises an MSS of maxPayload via a raw TCP option.
+	c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+		header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
+	})
+	testBrokenUpWrite(t, c, maxPayload)
+}
+
+// TestPassiveSendMSSLessThanMTU verifies that a passively accepted
+// endpoint honors the client's MSS option when it is smaller than the
+// link MTU.
+func TestPassiveSendMSSLessThanMTU(t *testing.T) {
+	const maxPayload = 100
+	const mtu = 1200
+	c := context.New(t, mtu)
+	defer c.Cleanup()
+
+	// Create EP and start listening.
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	defer ep.Close()
+
+	// Set the buffer size to a deterministic size so that we can check the
+	// window scaling option.
+	// (Original message read "SetSockOpt failed failed" — fixed.)
+	const rcvBufferSize = 0x20000
+	const wndScale = 2
+	if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+		t.Fatalf("SetSockOpt failed: %v", err)
+	}
+
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+
+	// Do 3-way handshake.
+	c.PassiveConnect(maxPayload, wndScale, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %v", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+
+	// Check that data gets properly segmented.
+	testBrokenUpWrite(t, c, maxPayload)
+}
+
+// TestSynCookiePassiveSendMSSLessThanMTU forces the SYN-cookie accept
+// path (by zeroing SynRcvdCountThreshold) and verifies the negotiated
+// MSS still bounds segment size.
+func TestSynCookiePassiveSendMSSLessThanMTU(t *testing.T) {
+	const maxPayload = 536
+	const mtu = 2000
+	c := context.New(t, mtu)
+	defer c.Cleanup()
+
+	// Set the SynRcvd threshold to zero to force a syn cookie based accept
+	// to happen. The package-level variable is restored on exit so other
+	// tests are unaffected.
+	saved := tcp.SynRcvdCountThreshold
+	defer func() {
+		tcp.SynRcvdCountThreshold = saved
+	}()
+	tcp.SynRcvdCountThreshold = 0
+
+	// Create EP and start listening.
+	wq := &waiter.Queue{}
+	ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	defer ep.Close()
+
+	if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+		t.Fatalf("Bind failed: %v", err)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		t.Fatalf("Listen failed: %v", err)
+	}
+
+	// Do 3-way handshake. A wndScale of -1 means the client does not
+	// send a window-scale option.
+	c.PassiveConnect(maxPayload, -1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+			if err != nil {
+				t.Fatalf("Accept failed: %v", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			t.Fatalf("Timed out waiting for accept")
+		}
+	}
+
+	// Check that data gets properly segmented.
+	testBrokenUpWrite(t, c, maxPayload)
+}
+
+// TestForwarderSendMSSLessThanMTU verifies MSS handling for endpoints
+// created through the tcp.Forwarder API rather than Accept().
+func TestForwarderSendMSSLessThanMTU(t *testing.T) {
+	const maxPayload = 100
+	const mtu = 1200
+	c := context.New(t, mtu)
+	defer c.Cleanup()
+
+	s := c.Stack()
+	// The forwarder callback runs on another goroutine; hand the
+	// CreateEndpoint result back through a buffered channel.
+	ch := make(chan *tcpip.Error, 1)
+	f := tcp.NewForwarder(s, 65536, 10, func(r *tcp.ForwarderRequest) {
+		var err *tcpip.Error
+		c.EP, err = r.CreateEndpoint(&c.WQ)
+		ch <- err
+	})
+	s.SetTransportProtocolHandler(tcp.ProtocolNumber, f.HandlePacket)
+
+	// Do 3-way handshake.
+	c.PassiveConnect(maxPayload, 1, header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize})
+
+	// Wait for connection to be available.
+	select {
+	case err := <-ch:
+		if err != nil {
+			t.Fatalf("Error creating endpoint: %v", err)
+		}
+	case <-time.After(2 * time.Second):
+		t.Fatalf("Timed out waiting for connection")
+	}
+
+	// Check that data gets properly segmented.
+	testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestSynOptionsOnActiveConnect(t *testing.T) {
+ const mtu = 1400
+ c := context.New(t, mtu)
+ defer c.Cleanup()
+
+ // Create TCP endpoint.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ // Set the buffer size to a deterministic size so that we can check the
+ // window scaling option.
+ const rcvBufferSize = 0x20000
+ const wndScale = 2
+ if err := c.EP.SetSockOpt(tcpip.ReceiveBufferSizeOption(rcvBufferSize)); err != nil {
+ t.Fatalf("SetSockOpt failed failed: %v", err)
+ }
+
+ // Start connection attempt.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventOut)
+ defer c.WQ.EventUnregister(&we)
+
+ err = c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ // Receive SYN packet.
+ b := c.GetPacket()
+
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, WS: wndScale}),
+ ),
+ )
+
+ tcp := header.TCP(header.IPv4(b).Payload())
+ c.IRS = seqnum.Value(tcp.SequenceNumber())
+
+ // Wait for retransmit.
+ time.Sleep(1 * time.Second)
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagSyn),
+ checker.SrcPort(tcp.SourcePort()),
+ checker.SeqNum(tcp.SequenceNumber()),
+ checker.TCPSynOptions(header.TCPSynOptions{MSS: mtu - header.IPv4MinimumSize - header.TCPMinimumSize, WS: wndScale}),
+ ),
+ )
+
+ // Send SYN-ACK.
+ iss := seqnum.Value(789)
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: tcp.DestinationPort(),
+ DstPort: tcp.SourcePort(),
+ Flags: header.TCPFlagSyn | header.TCPFlagAck,
+ SeqNum: iss,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ // Receive ACK packet.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.TCPFlags(header.TCPFlagAck),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(uint32(iss)+1),
+ ),
+ )
+
+ // Wait for connection to be established.
+ select {
+ case <-ch:
+ err = c.EP.GetSockOpt(tcpip.ErrorOption{})
+ if err != nil {
+ t.Fatalf("Unexpected error when connecting: %v", err)
+ }
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for connection")
+ }
+}
+
+func TestCloseListener(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create listener.
+ var wq waiter.Queue
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Close the listener and measure how long it takes.
+ t0 := time.Now()
+ ep.Close()
+ if diff := time.Now().Sub(t0); diff > 3*time.Second {
+ t.Fatalf("Took too long to close: %v", diff)
+ }
+}
+
// TestReceiveOnResetConnection verifies that reading from an endpoint after
// the peer sends a RST eventually fails with ErrConnectionReset. Because
// the RST is processed asynchronously, the read is retried on
// ErrWouldBlock until the reset is observed.
func TestReceiveOnResetConnection(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	c.CreateConnected(789, 30000, nil)

	// Send RST segment.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagRst,
		SeqNum:  790,
		RcvWnd:  30000,
	})

	// Try to read.
	we, ch := waiter.NewChannelEntry(nil)
	c.WQ.EventRegister(&we, waiter.EventIn)
	defer c.WQ.EventUnregister(&we)

loop:
	for {
		switch _, err := c.EP.Read(nil); err {
		case nil:
			t.Fatalf("Unexpected success.")
		case tcpip.ErrWouldBlock:
			// The RST may not have been processed yet; wait for a
			// readability event and retry.
			select {
			case <-ch:
			case <-time.After(1 * time.Second):
				t.Fatalf("Timed out waiting for reset to arrive")
			}
		case tcpip.ErrConnectionReset:
			break loop
		default:
			t.Fatalf("Unexpected error: want %v, got %v", tcpip.ErrConnectionReset, err)
		}
	}
}
+
+func TestSendOnResetConnection(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ // Send RST segment.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagRst,
+ SeqNum: 790,
+ RcvWnd: 30000,
+ })
+
+ // Wait for the RST to be received.
+ time.Sleep(1 * time.Second)
+
+ // Try to write.
+ view := buffer.NewView(10)
+ _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
+ if err != tcpip.ErrConnectionReset {
+ t.Fatalf("Unexpected error from Write: want %v, got %v", tcpip.ErrConnectionReset, err)
+ }
+}
+
+func TestFinImmediately(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ // Shutdown immediately, check that we get a FIN.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Unexpected error from Shutdown: %v", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ // Ack and send FIN as well.
+ c.SendPacket(nil, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Check that the stack acks the FIN.
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(791),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+}
+
// TestFinRetransmit verifies that an unacknowledged FIN is retransmitted
// with the same sequence number, and that the connection then completes a
// normal FIN/ACK exchange once the peer responds.
func TestFinRetransmit(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	c.CreateConnected(789, 30000, nil)

	// Shutdown immediately, check that we get a FIN.
	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
		t.Fatalf("Unexpected error from Shutdown: %v", err)
	}

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(uint32(c.IRS)+1),
			checker.AckNum(790),
			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
		),
	)

	// Don't acknowledge yet. We should get a retransmit of the FIN.
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(uint32(c.IRS)+1),
			checker.AckNum(790),
			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
		),
	)

	// Ack and send FIN as well.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck | header.TCPFlagFin,
		SeqNum:  790,
		AckNum:  c.IRS.Add(2),
		RcvWnd:  30000,
	})

	// Check that the stack acks the FIN.
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(uint32(c.IRS)+2),
			checker.AckNum(791),
			checker.TCPFlags(header.TCPFlagAck),
		),
	)
}
+
// TestFinWithNoPendingData verifies that when all written data has been
// acknowledged, Shutdown sends the FIN immediately at the next sequence
// number, and the subsequent FIN/ACK exchange completes normally.
func TestFinWithNoPendingData(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	c.CreateConnected(789, 30000, nil)

	// Write something out, and have it acknowledged.
	view := buffer.NewView(10)
	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}

	// next tracks the sequence number we expect the stack to send next.
	next := uint32(c.IRS) + 1
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(len(view)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(790),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
		),
	)
	next += uint32(len(view))

	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck,
		SeqNum:  790,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})

	// Shutdown, check that we get a FIN.
	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
		t.Fatalf("Unexpected error from Shutdown: %v", err)
	}

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(790),
			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
		),
	)
	next++

	// Ack and send FIN as well.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck | header.TCPFlagFin,
		SeqNum:  790,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})

	// Check that the stack acks the FIN.
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(791),
			checker.TCPFlags(header.TCPFlagAck),
		),
	)
}
+
// TestFinWithPendingDataCwndFull verifies that when the congestion window
// is fully consumed by unacknowledged data, Shutdown does not send the FIN
// until the window opens: only after an ACK arrives is the FIN emitted.
func TestFinWithPendingDataCwndFull(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	c.CreateConnected(789, 30000, nil)

	// Write enough segments to fill the congestion window before ACK'ing
	// any of them.
	view := buffer.NewView(10)
	for i := tcp.InitialCwnd; i > 0; i-- {
		if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
			t.Fatalf("Unexpected error from Write: %v", err)
		}
	}

	// next tracks the sequence number we expect the stack to send next.
	next := uint32(c.IRS) + 1
	for i := tcp.InitialCwnd; i > 0; i-- {
		checker.IPv4(t, c.GetPacket(),
			checker.PayloadLen(len(view)+header.TCPMinimumSize),
			checker.TCP(
				checker.DstPort(context.TestPort),
				checker.SeqNum(next),
				checker.AckNum(790),
				checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
			),
		)
		next += uint32(len(view))
	}

	// Shutdown the connection, check that the FIN segment isn't sent
	// because the congestion window doesn't allow it. Wait until a
	// retransmit is received.
	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
		t.Fatalf("Unexpected error from Shutdown: %v", err)
	}

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(len(view)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(uint32(c.IRS)+1),
			checker.AckNum(790),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
		),
	)

	// Send the ACK that will allow the FIN to be sent as well.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck,
		SeqNum:  790,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(790),
			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
		),
	)
	next++

	// Send a FIN that acknowledges everything. Get an ACK back.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck | header.TCPFlagFin,
		SeqNum:  790,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(791),
			checker.TCPFlags(header.TCPFlagAck),
		),
	)
}
+
// TestFinWithPendingData verifies that Shutdown sends the FIN after
// already-queued (but unacknowledged) data, i.e. the FIN is sequenced
// behind the pending segment rather than replacing it.
func TestFinWithPendingData(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	c.CreateConnected(789, 30000, nil)

	// Write something out, and acknowledge it to get cwnd to 2.
	view := buffer.NewView(10)
	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}

	// next tracks the sequence number we expect the stack to send next.
	next := uint32(c.IRS) + 1
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(len(view)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(790),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
		),
	)
	next += uint32(len(view))

	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck,
		SeqNum:  790,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})

	// Write new data, but don't acknowledge it.
	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(len(view)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(790),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
		),
	)
	next += uint32(len(view))

	// Shutdown the connection, check that we do get a FIN.
	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
		t.Fatalf("Unexpected error from Shutdown: %v", err)
	}

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(790),
			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
		),
	)
	next++

	// Send a FIN that acknowledges everything. Get an ACK back.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck | header.TCPFlagFin,
		SeqNum:  790,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(791),
			checker.TCPFlags(header.TCPFlagAck),
		),
	)
}
+
// TestFinWithPartialAck verifies that once the peer acknowledges the data
// preceding our FIN (but not the FIN itself), the FIN is not retransmitted
// prematurely; the connection then completes when the FIN is finally acked.
func TestFinWithPartialAck(t *testing.T) {
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	c.CreateConnected(789, 30000, nil)

	// Write something out, and acknowledge it to get cwnd to 2. Also send
	// FIN from the test side.
	view := buffer.NewView(10)
	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}

	// next tracks the sequence number we expect the stack to send next.
	next := uint32(c.IRS) + 1
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(len(view)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(790),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
		),
	)
	next += uint32(len(view))

	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck | header.TCPFlagFin,
		SeqNum:  790,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})

	// Check that we get an ACK for the fin.
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(791),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
		),
	)

	// Write new data, but don't acknowledge it.
	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(len(view)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(791),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
		),
	)
	next += uint32(len(view))

	// Shutdown the connection, check that we do get a FIN.
	if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
		t.Fatalf("Unexpected error from Shutdown: %v", err)
	}

	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen(header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(next),
			checker.AckNum(791),
			checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
		),
	)
	next++

	// Send an ACK for the data, but not for the FIN yet.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck,
		SeqNum:  791,
		AckNum:  seqnum.Value(next - 1),
		RcvWnd:  30000,
	})

	// Check that we don't get a retransmit of the FIN.
	c.CheckNoPacketTimeout("FIN retransmitted when data was ack'd", 100*time.Millisecond)

	// Ack the FIN.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck | header.TCPFlagFin,
		SeqNum:  791,
		AckNum:  seqnum.Value(next),
		RcvWnd:  30000,
	})
}
+
// TestExponentialIncreaseDuringSlowStart checks that during slow start the
// congestion window doubles per round trip: on every iteration the stack
// emits exactly "expected" segments, and acknowledging them doubles the
// count expected on the next iteration.
func TestExponentialIncreaseDuringSlowStart(t *testing.T) {
	maxPayload := 10
	// MTU sized so that each segment carries exactly maxPayload bytes.
	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
	defer c.Cleanup()

	c.CreateConnected(789, 30000, nil)

	const iterations = 7
	data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
	for i := range data {
		data[i] = byte(i)
	}

	// Write all the data in one shot. Packets will only be written at the
	// MTU size though.
	if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}

	expected := tcp.InitialCwnd
	bytesRead := 0
	for i := 0; i < iterations; i++ {
		// Read all packets expected on this iteration. Don't
		// acknowledge any of them just yet, so that we can measure the
		// congestion window.
		for j := 0; j < expected; j++ {
			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
			bytesRead += maxPayload
		}

		// Check we don't receive any more packets on this iteration.
		// The timeout can't be too high or we'll trigger a timeout.
		c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)

		// Acknowledge all the data received so far.
		c.SendAck(790, bytesRead)

		// Double the number of expected packets for the next iteration.
		expected *= 2
	}
}
+
// TestCongestionAvoidance drives the sender through slow start, forces a
// retransmit timeout, and then verifies congestion-avoidance behavior:
// after the timeout, cwnd restarts at expected/2 + 1 and grows linearly
// (by one segment per acknowledged window).
func TestCongestionAvoidance(t *testing.T) {
	maxPayload := 10
	// MTU sized so that each segment carries exactly maxPayload bytes.
	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
	defer c.Cleanup()

	c.CreateConnected(789, 30000, nil)

	const iterations = 7
	data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
	for i := range data {
		data[i] = byte(i)
	}

	// Write all the data in one shot. Packets will only be written at the
	// MTU size though.
	if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}

	// Do slow start for a few iterations.
	expected := tcp.InitialCwnd
	bytesRead := 0
	for i := 0; i < iterations; i++ {
		expected = tcp.InitialCwnd << uint(i)
		if i > 0 {
			// Acknowledge all the data received so far if not on
			// first iteration.
			c.SendAck(790, bytesRead)
		}

		// Read all packets expected on this iteration. Don't
		// acknowledge any of them just yet, so that we can measure the
		// congestion window.
		for j := 0; j < expected; j++ {
			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
			bytesRead += maxPayload
		}

		// Check we don't receive any more packets on this iteration.
		// The timeout can't be too high or we'll trigger a timeout.
		c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
	}

	// Don't acknowledge the first packet of the last packet train. Let's
	// wait for them to time out, which will trigger a restart of slow
	// start, and initialization of ssthresh to cwnd/2.
	rtxOffset := bytesRead - maxPayload*expected
	c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)

	// Acknowledge all the data received so far.
	c.SendAck(790, bytesRead)

	// This part is tricky: when the timeout happened, we had "expected"
	// packets pending, cwnd reset to 1, and ssthresh set to expected/2.
	// By acknowledging "expected" packets, the slow-start part will
	// increase cwnd to expected/2 (which "consumes" expected/2-1 of the
	// acknowledgements), then the congestion avoidance part will consume
	// an extra expected/2 acks to take cwnd to expected/2 + 1. One ack
	// remains in the "ack count" (which will cause cwnd to be incremented
	// once it reaches cwnd acks).
	//
	// So we're straight into congestion avoidance with cwnd set to
	// expected/2 + 1.
	//
	// Check that packets trains of cwnd packets are sent, and that cwnd is
	// incremented by 1 after we acknowledge each packet.
	expected = expected/2 + 1
	for i := 0; i < iterations; i++ {
		// Read all packets expected on this iteration. Don't
		// acknowledge any of them just yet, so that we can measure the
		// congestion window.
		for j := 0; j < expected; j++ {
			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
			bytesRead += maxPayload
		}

		// Check we don't receive any more packets on this iteration.
		// The timeout can't be too high or we'll trigger a timeout.
		c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)

		// Acknowledge all the data received so far.
		c.SendAck(790, bytesRead)

		// In congestion avoidance, the packet trains increase by 1 in
		// each iteration.
		expected++
	}
}
+
// TestFastRecovery drives the sender through slow start, triggers fast
// retransmit with three duplicate ACKs, and verifies fast-recovery
// behavior: window inflation on further dupacks, retransmit on partial
// ACK, and window deflation when recovery completes.
func TestFastRecovery(t *testing.T) {
	maxPayload := 10
	// MTU sized so that each segment carries exactly maxPayload bytes.
	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
	defer c.Cleanup()

	c.CreateConnected(789, 30000, nil)

	const iterations = 7
	data := buffer.NewView(2 * maxPayload * (tcp.InitialCwnd << (iterations + 1)))
	for i := range data {
		data[i] = byte(i)
	}

	// Write all the data in one shot. Packets will only be written at the
	// MTU size though.
	if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}

	// Do slow start for a few iterations.
	expected := tcp.InitialCwnd
	bytesRead := 0
	for i := 0; i < iterations; i++ {
		expected = tcp.InitialCwnd << uint(i)
		if i > 0 {
			// Acknowledge all the data received so far if not on
			// first iteration.
			c.SendAck(790, bytesRead)
		}

		// Read all packets expected on this iteration. Don't
		// acknowledge any of them just yet, so that we can measure the
		// congestion window.
		for j := 0; j < expected; j++ {
			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
			bytesRead += maxPayload
		}

		// Check we don't receive any more packets on this iteration.
		// The timeout can't be too high or we'll trigger a timeout.
		c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
	}

	// Send 3 duplicate acks. This should force an immediate retransmit of
	// the pending packet and put the sender into fast recovery.
	rtxOffset := bytesRead - maxPayload*expected
	for i := 0; i < 3; i++ {
		c.SendAck(790, rtxOffset)
	}

	// Receive the retransmitted packet.
	c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)

	// Now send 7 more duplicate acks. Each of these should cause a window
	// inflation by 1 and cause the sender to send an extra packet.
	for i := 0; i < 7; i++ {
		c.SendAck(790, rtxOffset)
	}

	recover := bytesRead

	// Ensure no new packets arrive.
	c.CheckNoPacketTimeout("More packets received than expected during recovery after dupacks for this cwnd.",
		50*time.Millisecond)

	// Acknowledge half of the pending data.
	rtxOffset = bytesRead - expected*maxPayload/2
	c.SendAck(790, rtxOffset)

	// Receive the retransmit due to partial ack.
	c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)

	// Receive the 10 extra packets that should have been released due to
	// the congestion window inflation in recovery.
	for i := 0; i < 10; i++ {
		c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
		bytesRead += maxPayload
	}

	// A partial ACK during recovery should reduce congestion window by the
	// number acked. Since we had "expected" packets outstanding before sending
	// partial ack and we acked expected/2 , the cwnd and outstanding should
	// be expected/2 + 7. Which means the sender should not send any more packets
	// till we ack this one.
	c.CheckNoPacketTimeout("More packets received than expected during recovery after partial ack for this cwnd.",
		50*time.Millisecond)

	// Acknowledge all pending data to recover point.
	c.SendAck(790, recover)

	// At this point, the cwnd should reset to expected/2 and there are 10
	// packets outstanding.
	//
	// NOTE: Technically netstack is incorrect in that we adjust the cwnd on
	// the same segment that takes us out of recovery. But because of that
	// the actual cwnd at exit of recovery will be expected/2 + 1 as we
	// acked a cwnd worth of packets which will increase the cwnd further by
	// 1 in congestion avoidance.
	//
	// Now in the first iteration since there are 10 packets outstanding.
	// We would expect to get expected/2 +1 - 10 packets. But subsequent
	// iterations will send us expected/2 + 1 + 1 (per iteration).
	expected = expected/2 + 1 - 10
	for i := 0; i < iterations; i++ {
		// Read all packets expected on this iteration. Don't
		// acknowledge any of them just yet, so that we can measure the
		// congestion window.
		for j := 0; j < expected; j++ {
			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
			bytesRead += maxPayload
		}

		// Check we don't receive any more packets on this iteration.
		// The timeout can't be too high or we'll trigger a timeout.
		c.CheckNoPacketTimeout(fmt.Sprintf("More packets received(after deflation) than expected %d for this cwnd.", expected), 50*time.Millisecond)

		// Acknowledge all the data received so far.
		c.SendAck(790, bytesRead)

		// In congestion avoidance, the packet trains increase by 1 in
		// each iteration.
		if i == 0 {
			// After the first iteration we expect to get the full
			// congestion window worth of packets in every
			// iteration.
			expected += 10
		}
		expected++
	}
}
+
// TestRetransmit forces a retransmit timeout after slow start and then
// verifies that only unacknowledged data is retransmitted: after a partial
// ack, the remaining data is delivered exactly once.
func TestRetransmit(t *testing.T) {
	maxPayload := 10
	// MTU sized so that each segment carries exactly maxPayload bytes.
	c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
	defer c.Cleanup()

	c.CreateConnected(789, 30000, nil)

	const iterations = 7
	data := buffer.NewView(maxPayload * (tcp.InitialCwnd << (iterations + 1)))
	for i := range data {
		data[i] = byte(i)
	}

	// Write all the data in two shots. Packets will only be written at the
	// MTU size though.
	half := data[:len(data)/2]
	if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}
	half = data[len(data)/2:]
	if _, err := c.EP.Write(tcpip.SlicePayload(half), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}

	// Do slow start for a few iterations.
	expected := tcp.InitialCwnd
	bytesRead := 0
	for i := 0; i < iterations; i++ {
		expected = tcp.InitialCwnd << uint(i)
		if i > 0 {
			// Acknowledge all the data received so far if not on
			// first iteration.
			c.SendAck(790, bytesRead)
		}

		// Read all packets expected on this iteration. Don't
		// acknowledge any of them just yet, so that we can measure the
		// congestion window.
		for j := 0; j < expected; j++ {
			c.ReceiveAndCheckPacket(data, bytesRead, maxPayload)
			bytesRead += maxPayload
		}

		// Check we don't receive any more packets on this iteration.
		// The timeout can't be too high or we'll trigger a timeout.
		c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
	}

	// Wait for a timeout and retransmit.
	rtxOffset := bytesRead - maxPayload*expected
	c.ReceiveAndCheckPacket(data, rtxOffset, maxPayload)

	// Acknowledge half of the pending data.
	rtxOffset = bytesRead - expected*maxPayload/2
	c.SendAck(790, rtxOffset)

	// Receive the remaining data, making sure that acknowledged data is not
	// retransmitted.
	for offset := rtxOffset; offset < len(data); offset += maxPayload {
		c.ReceiveAndCheckPacket(data, offset, maxPayload)
		c.SendAck(790, offset+maxPayload)
	}

	c.CheckNoPacketTimeout("More packets received than expected for this cwnd.", 50*time.Millisecond)
}
+
+func TestUpdateListenBacklog(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // Create listener.
+ var wq waiter.Queue
+ ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Bind(tcpip.FullAddress{}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ if err := ep.Listen(10); err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+
+ // Update the backlog with another Listen() on the same endpoint.
+ if err := ep.Listen(20); err != nil {
+ t.Fatalf("Listen failed to update backlog: %v", err)
+ }
+
+ ep.Close()
+}
+
// scaledSendWindow is the body shared by TestScaledSendWindow: it opens a
// 1-unit receive window scaled by "scale" and checks that the endpoint
// sends exactly 1<<scale bytes, no more.
func scaledSendWindow(t *testing.T, scale uint8) {
	// This test ensures that the endpoint is using the right scaling by
	// sending a buffer that is larger than the window size, and ensuring
	// that the endpoint doesn't send more than allowed.
	c := context.New(t, defaultMTU)
	defer c.Cleanup()

	maxPayload := defaultMTU - header.IPv4MinimumSize - header.TCPMinimumSize
	// Raw SYN options: MSS (kind 2, len 4) followed by window scale
	// (kind 3, len 3) padded with a NOP.
	c.CreateConnectedWithRawOptions(789, 0, nil, []byte{
		header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
		header.TCPOptionWS, 3, scale, header.TCPOptionNOP,
	})

	// Open up the window with a scaled value.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagAck,
		SeqNum:  790,
		AckNum:  c.IRS.Add(1),
		RcvWnd:  1,
	})

	// Send some data. Check that it's capped by the window size.
	view := buffer.NewView(65535)
	if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
		t.Fatalf("Unexpected error from Write: %v", err)
	}

	// Check that only data that fits in the scaled window is sent.
	checker.IPv4(t, c.GetPacket(),
		checker.PayloadLen((1<<scale)+header.TCPMinimumSize),
		checker.TCP(
			checker.DstPort(context.TestPort),
			checker.SeqNum(uint32(c.IRS)+1),
			checker.AckNum(790),
			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
		),
	)

	// Reset the connection to free resources.
	c.SendPacket(nil, &context.Headers{
		SrcPort: context.TestPort,
		DstPort: c.Port,
		Flags:   header.TCPFlagRst,
		SeqNum:  790,
	})
}
+
+func TestScaledSendWindow(t *testing.T) {
+ for scale := uint8(0); scale <= 14; scale++ {
+ scaledSendWindow(t, scale)
+ }
+}
+
+func TestReceivedSegmentQueuing(t *testing.T) {
+ // This test sends 200 segments containing a few bytes each to an
+ // endpoint and checks that they're all received and acknowledged by
+ // the endpoint, that is, that none of the segments are dropped by
+ // internal queues.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ // Send 200 segments.
+ data := []byte{1, 2, 3}
+ for i := 0; i < 200; i++ {
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: seqnum.Value(790 + i*len(data)),
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+ }
+
+ // Receive ACKs for all segments.
+ last := seqnum.Value(790 + 200*len(data))
+ for {
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+ tcp := header.TCP(header.IPv4(b).Payload())
+ ack := seqnum.Value(tcp.AckNumber())
+ if ack == last {
+ break
+ }
+
+ if last.LessThan(ack) {
+ t.Fatalf("Acknowledge (%v) beyond the expected (%v)", ack, last)
+ }
+ }
+}
+
+func TestReadAfterClosedState(t *testing.T) {
+ // This test ensures that calling Read() or Peek() after the endpoint
+ // has transitioned to the closed state still works if there is pending
+ // data. To transition to the closed state without calling Close(), we must
+ // shutdown the send path and the peer must send its own FIN.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnected(789, 30000, nil)
+
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ if _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
+ t.Fatalf("Unexpected error from Read: %v", err)
+ }
+
+ // Shutdown immediately for write, check that we get a FIN.
+ if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
+ t.Fatalf("Unexpected error from Shutdown: %v", err)
+ }
+
+ checker.IPv4(t, c.GetPacket(),
+ checker.PayloadLen(header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.TCPFlags(header.TCPFlagAck|header.TCPFlagFin),
+ ),
+ )
+
+ // Send some data and acknowledge the FIN.
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck | header.TCPFlagFin,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(2),
+ RcvWnd: 30000,
+ })
+
+ // Check that ACK is received.
+ checker.IPv4(t, c.GetPacket(),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+2),
+ checker.AckNum(uint32(791+len(data))),
+ checker.TCPFlags(header.TCPFlagAck),
+ ),
+ )
+
+ // Give the stack the chance to transition to closed state.
+ time.Sleep(1 * time.Second)
+
+ // Wait for receive to be notified.
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Check that peek works.
+ peekBuf := make([]byte, 10)
+ n, err := c.EP.Peek([][]byte{peekBuf})
+ if err != nil {
+ t.Fatalf("Unexpected error from Peek: %v", err)
+ }
+
+ peekBuf = peekBuf[:n]
+ if bytes.Compare(data, peekBuf) != 0 {
+ t.Fatalf("Data is different: expected %v, got %v", data, peekBuf)
+ }
+
+ // Receive data.
+ v, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Unexpected error from Read: %v", err)
+ }
+
+ if bytes.Compare(data, v) != 0 {
+ t.Fatalf("Data is different: expected %v, got %v", data, v)
+ }
+
+ // Now that we drained the queue, check that functions fail with the
+ // right error code.
+ if _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("Unexpected return from Read: got %v, want %v", err, tcpip.ErrClosedForReceive)
+ }
+
+ if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("Unexpected return from Peek: got %v, want %v", err, tcpip.ErrClosedForReceive)
+ }
+}
+
+func TestReusePort(t *testing.T) {
+ // This test ensures that ports are immediately available for reuse
+ // after Close on the endpoints using them returns.
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ // First case, just an endpoint that was bound.
+ var err *tcpip.Error
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ c.EP.Close()
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+ c.EP.Close()
+
+ // Second case, an endpoint that was bound and is connecting.
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+ err = c.EP.Connect(tcpip.FullAddress{Addr: context.TestAddr, Port: context.TestPort})
+ if err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+ c.EP.Close()
+
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+ c.EP.Close()
+
+ // Third case, an endpoint that was bound and is listening.
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+ err = c.EP.Listen(10)
+ if err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+ c.EP.Close()
+
+ c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ if err := c.EP.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+ err = c.EP.Listen(10)
+ if err != nil {
+ t.Fatalf("Listen failed: %v", err)
+ }
+}
+
+func checkRecvBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
+ t.Helper()
+
+ var s tcpip.ReceiveBufferSizeOption
+ if err := ep.GetSockOpt(&s); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+
+ if int(s) != v {
+ t.Fatalf("Bad receive buffer size: want=%v, got=%v", v, s)
+ }
+}
+
+func checkSendBufferSize(t *testing.T, ep tcpip.Endpoint, v int) {
+ t.Helper()
+
+ var s tcpip.SendBufferSizeOption
+ if err := ep.GetSockOpt(&s); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+
+ if int(s) != v {
+ t.Fatalf("Bad send buffer size: want=%v, got=%v", v, s)
+ }
+}
+
+func TestDefaultBufferSizes(t *testing.T) {
+ s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName})
+
+ // Check the default values.
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ defer func() {
+ if ep != nil {
+ ep.Close()
+ }
+ }()
+
+ checkSendBufferSize(t, ep, tcp.DefaultBufferSize)
+ checkRecvBufferSize(t, ep, tcp.DefaultBufferSize)
+
+ // Change the default send buffer size.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ ep.Close()
+ ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+
+ checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2)
+ checkRecvBufferSize(t, ep, tcp.DefaultBufferSize)
+
+ // Change the default receive buffer size.
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ ep.Close()
+ ep, err = s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+
+ checkSendBufferSize(t, ep, tcp.DefaultBufferSize*2)
+ checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*3)
+}
+
+func TestMinMaxBufferSizes(t *testing.T) {
+ s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName})
+
+ // Check the default values.
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &waiter.Queue{})
+ if err != nil {
+ t.Fatalf("NewEndpoint failed; %v", err)
+ }
+ defer ep.Close()
+
+ // Change the min/max values for send/receive
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{200, tcp.DefaultBufferSize * 2, tcp.DefaultBufferSize * 20}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{300, tcp.DefaultBufferSize * 3, tcp.DefaultBufferSize * 30}); err != nil {
+ t.Fatalf("SetTransportProtocolOption failed: %v", err)
+ }
+
+ // Set values below the min.
+ if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(199)); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+
+ checkRecvBufferSize(t, ep, 200)
+
+ if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(299)); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+
+ checkSendBufferSize(t, ep, 300)
+
+ // Set values above the max.
+ if err := ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(1 + tcp.DefaultBufferSize*20)); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+
+ checkRecvBufferSize(t, ep, tcp.DefaultBufferSize*20)
+
+ if err := ep.SetSockOpt(tcpip.SendBufferSizeOption(1 + tcp.DefaultBufferSize*30)); err != nil {
+ t.Fatalf("GetSockOpt failed: %v", err)
+ }
+
+ checkSendBufferSize(t, ep, tcp.DefaultBufferSize*30)
+}
+
+func TestSelfConnect(t *testing.T) {
+ // This test ensures that intentional self-connects work. In particular,
+ // it checks that if an endpoint binds to say 127.0.0.1:1000 then
+ // connects to 127.0.0.1:1000, then it will be connected to itself, and
+ // is able to send and receive data through the same endpoint.
+ s := stack.New([]string{ipv4.ProtocolName}, []string{tcp.ProtocolName})
+
+ id := loopback.New()
+ if testing.Verbose() {
+ id = sniffer.New(id)
+ }
+
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, ipv4.ProtocolNumber, context.StackAddr); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{
+ {
+ Destination: "\x00\x00\x00\x00",
+ Mask: "\x00\x00\x00\x00",
+ Gateway: "",
+ NIC: 1,
+ },
+ })
+
+ var wq waiter.Queue
+ ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+ defer ep.Close()
+
+ if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}, nil); err != nil {
+ t.Fatalf("Bind failed: %v", err)
+ }
+
+ // Register for notification, then start connection attempt.
+ waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+ wq.EventRegister(&waitEntry, waiter.EventOut)
+ defer wq.EventUnregister(&waitEntry)
+
+ err = ep.Connect(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort})
+ if err != tcpip.ErrConnectStarted {
+ t.Fatalf("Unexpected return value from Connect: %v", err)
+ }
+
+ <-notifyCh
+ err = ep.GetSockOpt(tcpip.ErrorOption{})
+ if err != nil {
+ t.Fatalf("Connect failed: %v", err)
+ }
+
+ // Write something.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+ if _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ // Read back what was written.
+ wq.EventUnregister(&waitEntry)
+ wq.EventRegister(&waitEntry, waiter.EventIn)
+ rd, err := ep.Read(nil)
+ if err != nil {
+ if err != tcpip.ErrWouldBlock {
+ t.Fatalf("Read failed: %v", err)
+ }
+ <-notifyCh
+ rd, err = ep.Read(nil)
+ if err != nil {
+ t.Fatalf("Read failed: %v", err)
+ }
+ }
+
+ if bytes.Compare(data, rd) != 0 {
+ t.Fatalf("Data is different: want=%v, got=%v", data, rd)
+ }
+}
+
+func TestPathMTUDiscovery(t *testing.T) {
+ // This test verifies the stack retransmits packets after it receives an
+ // ICMP packet indicating that the path MTU has been exceeded.
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ // Create new connection with MSS of 1460.
+ const maxPayload = 1500 - header.TCPMinimumSize - header.IPv4MinimumSize
+ c.CreateConnectedWithRawOptions(789, 30000, nil, []byte{
+ header.TCPOptionMSS, 4, byte(maxPayload / 256), byte(maxPayload % 256),
+ })
+
+ // Send 3200 bytes of data.
+ const writeSize = 3200
+ data := buffer.NewView(writeSize)
+ for i := range data {
+ data[i] = byte(i)
+ }
+
+ if _, err := c.EP.Write(tcpip.SlicePayload(data), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Write failed: %v", err)
+ }
+
+ receivePackets := func(c *context.Context, sizes []int, which int, seqNum uint32) []byte {
+ var ret []byte
+ for i, size := range sizes {
+ p := c.GetPacket()
+ if i == which {
+ ret = p
+ }
+ checker.IPv4(t, p,
+ checker.PayloadLen(size+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(seqNum),
+ checker.AckNum(790),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ ),
+ )
+ seqNum += uint32(size)
+ }
+ return ret
+ }
+
+ // Receive three packets.
+ sizes := []int{maxPayload, maxPayload, writeSize - 2*maxPayload}
+ first := receivePackets(c, sizes, 0, uint32(c.IRS)+1)
+
+ // Send "packet too big" messages back to netstack.
+ const newMTU = 1200
+ const newMaxPayload = newMTU - header.IPv4MinimumSize - header.TCPMinimumSize
+ mtu := []byte{0, 0, newMTU / 256, newMTU % 256}
+ c.SendICMPPacket(header.ICMPv4DstUnreachable, header.ICMPv4FragmentationNeeded, mtu, first, newMTU)
+
+ // See retransmitted packets. None exceeding the new max.
+ sizes = []int{newMaxPayload, maxPayload - newMaxPayload, newMaxPayload, maxPayload - newMaxPayload, writeSize - 2*maxPayload}
+ receivePackets(c, sizes, -1, uint32(c.IRS)+1)
+}
+
+func TestTCPEndpointProbe(t *testing.T) {
+ c := context.New(t, 1500)
+ defer c.Cleanup()
+
+ invoked := make(chan struct{})
+ c.Stack().AddTCPProbe(func(state stack.TCPEndpointState) {
+ // Validate that the endpoint ID is what we expect.
+ //
+ // We don't do an extensive validation of every field but a
+ // basic sanity test.
+ if got, want := state.ID.LocalAddress, tcpip.Address(context.StackAddr); got != want {
+ t.Fatalf("unexpected LocalAddress got: %d, want: %d", got, want)
+ }
+ if got, want := state.ID.LocalPort, c.Port; got != want {
+ t.Fatalf("unexpected LocalPort got: %d, want: %d", got, want)
+ }
+ if got, want := state.ID.RemoteAddress, tcpip.Address(context.TestAddr); got != want {
+ t.Fatalf("unexpected RemoteAddress got: %d, want: %d", got, want)
+ }
+ if got, want := state.ID.RemotePort, uint16(context.TestPort); got != want {
+ t.Fatalf("unexpected RemotePort got: %d, want: %d", got, want)
+ }
+
+ invoked <- struct{}{}
+ })
+
+ c.CreateConnected(789, 30000, nil)
+
+ data := []byte{1, 2, 3}
+ c.SendPacket(data, &context.Headers{
+ SrcPort: context.TestPort,
+ DstPort: c.Port,
+ Flags: header.TCPFlagAck,
+ SeqNum: 790,
+ AckNum: c.IRS.Add(1),
+ RcvWnd: 30000,
+ })
+
+ select {
+ case <-invoked:
+ case <-time.After(100 * time.Millisecond):
+ t.Fatalf("TCP Probe function was not called")
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
new file mode 100644
index 000000000..ae1b578c5
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -0,0 +1,302 @@
+package tcp_test
+
+import (
+ "bytes"
+ "math/rand"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// createConnectedWithTimestampOption creates and connects c.ep with the
+// timestamp option enabled.
+func createConnectedWithTimestampOption(c *context.Context) *context.RawEndpoint {
+ return c.CreateConnectedWithOptions(header.TCPSynOptions{TS: true, TSVal: 1})
+}
+
+// TestTimeStampEnabledConnect tests that netstack sends the timestamp option on
+// an active connect and sets the TS Echo Reply fields correctly when the
+// SYN-ACK also indicates support for the TS option and provides a TSVal.
+func TestTimeStampEnabledConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ rep := createConnectedWithTimestampOption(c)
+
+ // Register for read and validate that we have data to read.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ // The following tests ensure that TS option once enabled behaves
+ // correctly as described in
+ // https://tools.ietf.org/html/rfc7323#section-4.3.
+ //
+ // We are not testing delayed ACKs here, but we do test out of order
+ // packet delivery and filling the sequence number hole created due to
+ // the out of order packet.
+ //
+ // The test also verifies that the sequence numbers and timestamps are
+ // as expected.
+ data := []byte{1, 2, 3}
+
+ // First we increment tsVal by a small amount.
+ tsVal := rep.TSVal + 100
+ rep.SendPacketWithTS(data, tsVal)
+ rep.VerifyACKWithTS(tsVal)
+
+ // Next we send an out of order packet.
+ rep.NextSeqNum += 3
+ tsVal += 200
+ rep.SendPacketWithTS(data, tsVal)
+
+ // The ACK should contain the original sequenceNumber and an older TS.
+ rep.NextSeqNum -= 6
+ rep.VerifyACKWithTS(tsVal - 200)
+
+ // Next we fill the hole and the returned ACK should contain the
+ // cumulative sequence number acking all data sent till now and have the
+ // latest timestamp sent below in its TSEcr field.
+ tsVal -= 100
+ rep.SendPacketWithTS(data, tsVal)
+ rep.NextSeqNum += 3
+ rep.VerifyACKWithTS(tsVal)
+
+ // Increment tsVal by a large value that doesn't result in a wrap around.
+ tsVal += 0x7fffffff
+ rep.SendPacketWithTS(data, tsVal)
+ rep.VerifyACKWithTS(tsVal)
+
+ // Increment tsVal again by a large value which should cause the
+ // timestamp value to wrap around. The returned ACK should contain the
+ // wrapped around timestamp in its tsEcr field and not the tsVal from
+ // the previous packet sent above.
+ tsVal += 0x7fffffff
+ rep.SendPacketWithTS(data, tsVal)
+ rep.VerifyACKWithTS(tsVal)
+
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // There should be 5 views to read and each of them should
+ // contain the same data.
+ for i := 0; i < 5; i++ {
+ got, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Unexpected error from Read: %v", err)
+ }
+ if want := data; bytes.Compare(got, want) != 0 {
+ t.Fatalf("Data is different: got: %v, want: %v", got, want)
+ }
+ }
+}
+
+// TestTimeStampDisabledConnect tests that netstack sends timestamp option on an
+// active connect but if the SYN-ACK doesn't specify the TS option then
+// timestamp option is not enabled and future packets do not contain a
+// timestamp.
+func TestTimeStampDisabledConnect(t *testing.T) {
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ c.CreateConnectedWithOptions(header.TCPSynOptions{})
+}
+
+func timeStampEnabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
+ savedSynCountThreshold := tcp.SynRcvdCountThreshold
+ defer func() {
+ tcp.SynRcvdCountThreshold = savedSynCountThreshold
+ }()
+
+ if cookieEnabled {
+ tcp.SynRcvdCountThreshold = 0
+ }
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
+ tsVal := rand.Uint32()
+ c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS, TS: true, TSVal: tsVal})
+
+ // Now send some data and validate that timestamp is echoed correctly in the ACK.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Unexpected error from Write: %v", err)
+ }
+
+ // Check that data is received and that the timestamp option TSEcr field
+ // matches the expected value.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ // Add 12 bytes for the timestamp option + 2 NOPs to align at 4
+ // byte boundary.
+ checker.PayloadLen(len(data)+header.TCPMinimumSize+12),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.Window(wndSize),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPTimestampChecker(true, 0, tsVal+1),
+ ),
+ )
+}
+
+// TestTimeStampEnabledAccept tests that if the SYN on a passive connect
+// specifies the Timestamp option then the Timestamp option is sent on a SYN-ACK
+// and echoes the tsVal field of the original SYN in the tsEcr field of the
+// SYN-ACK. We cover the cases where SYN cookies are enabled/disabled and verify
+// that Timestamp option is enabled in both cases if requested in the original
+// SYN.
+func TestTimeStampEnabledAccept(t *testing.T) {
+ testCases := []struct {
+ cookieEnabled bool
+ wndScale int
+ wndSize uint16
+ }{
+ {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
+ {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+ }
+ for _, tc := range testCases {
+ timeStampEnabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
+ }
+}
+
+func timeStampDisabledAccept(t *testing.T, cookieEnabled bool, wndScale int, wndSize uint16) {
+ savedSynCountThreshold := tcp.SynRcvdCountThreshold
+ defer func() {
+ tcp.SynRcvdCountThreshold = savedSynCountThreshold
+ }()
+ if cookieEnabled {
+ tcp.SynRcvdCountThreshold = 0
+ }
+
+ c := context.New(t, defaultMTU)
+ defer c.Cleanup()
+
+ t.Logf("Test w/ CookieEnabled = %v", cookieEnabled)
+ c.AcceptWithOptions(wndScale, header.TCPSynOptions{MSS: defaultIPv4MSS})
+
+ // Now send some data with the accepted connection endpoint and validate
+ // that no timestamp option is sent in the TCP segment.
+ data := []byte{1, 2, 3}
+ view := buffer.NewView(len(data))
+ copy(view, data)
+
+ if _, err := c.EP.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ t.Fatalf("Unexpected error from Write: %v", err)
+ }
+
+ // Check that data is received and that the timestamp option is disabled
+ // when SYN cookies are enabled/disabled.
+ b := c.GetPacket()
+ checker.IPv4(t, b,
+ checker.PayloadLen(len(data)+header.TCPMinimumSize),
+ checker.TCP(
+ checker.DstPort(context.TestPort),
+ checker.SeqNum(uint32(c.IRS)+1),
+ checker.AckNum(790),
+ checker.Window(wndSize),
+ checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+ checker.TCPTimestampChecker(false, 0, 0),
+ ),
+ )
+}
+
+// TestTimeStampDisabledAccept tests that Timestamp option is not used when the
+// peer doesn't advertise it and connection is established with Accept().
+func TestTimeStampDisabledAccept(t *testing.T) {
+ testCases := []struct {
+ cookieEnabled bool
+ wndScale int
+ wndSize uint16
+ }{
+ {true, -1, 0xffff}, // When cookie is used window scaling is disabled.
+ {false, 5, 0x8000}, // 0x8000 * 2^5 = 1<<20 = 1MB window (the default).
+ }
+ for _, tc := range testCases {
+ timeStampDisabledAccept(t, tc.cookieEnabled, tc.wndScale, tc.wndSize)
+ }
+}
+
+func TestSendGreaterThanMTUWithOptions(t *testing.T) {
+ const maxPayload = 100
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ createConnectedWithTimestampOption(c)
+ testBrokenUpWrite(t, c, maxPayload)
+}
+
+func TestSegmentDropWhenTimestampMissing(t *testing.T) {
+ const maxPayload = 100
+ c := context.New(t, uint32(header.TCPMinimumSize+header.IPv4MinimumSize+maxPayload))
+ defer c.Cleanup()
+
+ rep := createConnectedWithTimestampOption(c)
+
+ // Register for read.
+ we, ch := waiter.NewChannelEntry(nil)
+ c.WQ.EventRegister(&we, waiter.EventIn)
+ defer c.WQ.EventUnregister(&we)
+
+ stk := c.Stack()
+ droppedPackets := stk.Stats().DroppedPackets
+ data := []byte{1, 2, 3}
+ // Save the sequence number as we will reset it later down
+ // in the test.
+ savedSeqNum := rep.NextSeqNum
+ rep.SendPacket(data, nil)
+
+ select {
+ case <-ch:
+ t.Fatalf("Got data to read when we expect packet to be dropped")
+ case <-time.After(1 * time.Second):
+ // We expect that no data will be available to read.
+ }
+
+ // Assert that DroppedPackets was incremented by 1.
+ if got, want := stk.Stats().DroppedPackets, droppedPackets+1; got != want {
+ t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want)
+ }
+
+ droppedPackets = stk.Stats().DroppedPackets
+ // Reset the sequence number so that the other endpoint accepts
+ // this segment and does not treat it like an out of order delivery.
+ rep.NextSeqNum = savedSeqNum
+ // Now send a packet with timestamp and we should get the data.
+ rep.SendPacketWithTS(data, rep.TSVal+1)
+
+ select {
+ case <-ch:
+ case <-time.After(1 * time.Second):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+
+ // Assert that DroppedPackets was not incremented by 1.
+ if got, want := stk.Stats().DroppedPackets, droppedPackets; got != want {
+ t.Fatalf("incorrect number of dropped packets, got: %v, want: %v", got, want)
+ }
+
+ // Issue a read and we should get the data.
+ got, err := c.EP.Read(nil)
+ if err != nil {
+ t.Fatalf("Unexpected error from Read: %v", err)
+ }
+ if want := data; bytes.Compare(got, want) != 0 {
+ t.Fatalf("Data is different: got: %v, want: %v", got, want)
+ }
+}
diff --git a/pkg/tcpip/transport/tcp/testing/context/BUILD b/pkg/tcpip/transport/tcp/testing/context/BUILD
new file mode 100644
index 000000000..40850c3e7
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/testing/context/BUILD
@@ -0,0 +1,27 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+go_library(
+ name = "context",
+ testonly = 1,
+ srcs = ["context.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp/testing/context",
+ visibility = [
+ "//:sandbox",
+ ],
+ deps = [
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/seqnum",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
new file mode 100644
index 000000000..6a402d150
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -0,0 +1,900 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package context provides a test context for use in tcp tests. It also
+// provides helper methods to assert/check certain behaviours.
+package context
+
+import (
+ "bytes"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // StackAddr is the IPv4 address assigned to the stack.
+ StackAddr = "\x0a\x00\x00\x01"
+
+ // StackPort is used as the listening port in tests for passive
+ // connects.
+ StackPort = 1234
+
+ // TestAddr is the source address for packets sent to the stack via the
+ // link layer endpoint.
+ TestAddr = "\x0a\x00\x00\x02"
+
+ // TestPort is the TCP port used for packets sent to the stack
+ // via the link layer endpoint.
+ TestPort = 4096
+
+ // StackV6Addr is the IPv6 address assigned to the stack.
+ StackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+
+ // TestV6Addr is the source address for packets sent to the stack via
+ // the link layer endpoint.
+ TestV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+
+ // StackV4MappedAddr is StackAddr as a mapped v6 address.
+ StackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + StackAddr
+
+ // TestV4MappedAddr is TestAddr as a mapped v6 address.
+ TestV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + TestAddr
+
+ // V4MappedWildcardAddr is the mapped v6 representation of 0.0.0.0.
+ V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+
+ // testInitialSequenceNumber is the initial sequence number sent in packets that
+ // are sent in response to a SYN or in the initial SYN sent to the stack.
+ testInitialSequenceNumber = 789
+)
+
+// defaultWindowScale value specified here depends on the tcp.DefaultBufferSize
+// constant defined in the tcp/endpoint.go because the tcp.DefaultBufferSize is
+// used in tcp.newHandshake to determine the window scale to use when sending a
+// SYN/SYN-ACK.
+var defaultWindowScale = tcp.FindWndScale(tcp.DefaultBufferSize)
+
+// Headers is used to represent the TCP header fields when building a
+// new packet.
+type Headers struct {
+	// SrcPort holds the src port value to be used in the packet.
+	SrcPort uint16
+
+	// DstPort holds the destination port value to be used in the packet.
+	DstPort uint16
+
+	// SeqNum is the value of the sequence number field in the TCP header.
+	SeqNum seqnum.Value
+
+	// AckNum represents the acknowledgement number field in the TCP header.
+	AckNum seqnum.Value
+
+	// Flags are the TCP flags in the TCP header (a combination of the
+	// header.TCPFlag* values).
+	Flags int
+
+	// RcvWnd is the window to be advertised in the ReceiveWindow field of
+	// the TCP header.
+	RcvWnd seqnum.Size
+
+	// TCPOpts holds the options to be sent in the option field of the TCP
+	// header. It is copied into the segment verbatim, so callers are
+	// expected to pad it to a 4-byte multiple where required (see e.g.
+	// PassiveConnectWithOptions).
+	TCPOpts []byte
+}
+
+// Context provides an initialized Network stack and a link layer endpoint
+// for use in TCP tests.
+type Context struct {
+	t      *testing.T
+	linkEP *channel.Endpoint
+	s      *stack.Stack
+
+	// IRS holds the initial sequence number in the SYN sent by endpoint in
+	// case of an active connect or the sequence number sent by the endpoint
+	// in the SYN-ACK sent in response to a SYN when listening in passive
+	// mode.
+	IRS seqnum.Value
+
+	// Port holds the port bound by EP below in case of an active connect or
+	// the listening port number in case of a passive connect.
+	Port uint16
+
+	// EP is the test endpoint in the stack owned by this context. This endpoint
+	// is used in various tests to either initiate an active connect or is used
+	// as a passive listening endpoint to accept inbound connections.
+	EP tcpip.Endpoint
+
+	// WQ is the wait queue associated with EP and is used to block for events
+	// on EP.
+	WQ waiter.Queue
+
+	// TimeStampEnabled is true if ep is connected with the timestamp option
+	// enabled.
+	TimeStampEnabled bool
+}
+
+// New allocates and initializes a test context containing a new
+// stack and a link-layer endpoint. The stack is configured with IPv4/IPv6
+// and TCP, a single NIC (id 1) carrying both StackAddr and StackV6Addr, and
+// default routes for both address families.
+func New(t *testing.T, mtu uint32) *Context {
+	s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{tcp.ProtocolName})
+
+	// Allow minimum send/receive buffer sizes to be 1 during tests.
+	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SendBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil {
+		t.Fatalf("SetTransportProtocolOption failed: %v", err)
+	}
+
+	if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.ReceiveBufferSizeOption{1, tcp.DefaultBufferSize, tcp.DefaultBufferSize * 10}); err != nil {
+		t.Fatalf("SetTransportProtocolOption failed: %v", err)
+	}
+
+	// Some of the congestion control tests send up to 640 packets, so we
+	// set the channel size to 1000.
+	id, linkEP := channel.New(1000, mtu, "")
+	if testing.Verbose() {
+		id = sniffer.New(id)
+	}
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, ipv4.ProtocolNumber, StackAddr); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, ipv6.ProtocolNumber, StackV6Addr); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	// Install catch-all (zero-mask) routes for v4 and v6 pointing at NIC 1.
+	s.SetRouteTable([]tcpip.Route{
+		{
+			Destination: "\x00\x00\x00\x00",
+			Mask:        "\x00\x00\x00\x00",
+			Gateway:     "",
+			NIC:         1,
+		},
+		{
+			Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+			Mask:        "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+			Gateway:     "",
+			NIC:         1,
+		},
+	})
+
+	return &Context{
+		t:      t,
+		s:      s,
+		linkEP: linkEP,
+	}
+}
+
+// Cleanup closes the context endpoint if required. It is safe to call even
+// if no endpoint was ever created (c.EP is nil until one of the Create*
+// helpers runs).
+func (c *Context) Cleanup() {
+	if c.EP != nil {
+		c.EP.Close()
+	}
+}
+
+// Stack returns a reference to the stack in the Context. The caller must not
+// assume ownership; the stack is shared with the context's helpers.
+func (c *Context) Stack() *stack.Stack {
+	return c.s
+}
+
+// CheckNoPacketTimeout verifies that no packet is received during the time
+// specified by wait.
+func (c *Context) CheckNoPacketTimeout(errMsg string, wait time.Duration) {
+	select {
+	case <-c.linkEP.C:
+		// Use Fatal rather than Fatalf: errMsg is not a format string,
+		// and passing it as one would misinterpret any '%' it contains
+		// (this is also flagged by go vet's printf check).
+		c.t.Fatal(errMsg)
+
+	case <-time.After(wait):
+	}
+}
+
+// CheckNoPacket verifies that no packet is received for 1 second. It is a
+// convenience wrapper around CheckNoPacketTimeout with a fixed 1s wait.
+func (c *Context) CheckNoPacket(errMsg string) {
+	c.CheckNoPacketTimeout(errMsg, 1*time.Second)
+}
+
+// GetPacket reads a packet from the link layer endpoint and verifies
+// that it is an IPv4 packet with the expected source and destination
+// addresses. It will fail with an error if no packet is received for
+// 2 seconds. The returned slice contains the full packet (IP header,
+// transport header and payload) in one contiguous buffer.
+func (c *Context) GetPacket() []byte {
+	select {
+	case p := <-c.linkEP.C:
+		if p.Proto != ipv4.ProtocolNumber {
+			c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+		}
+		// Flatten the separately delivered header and payload views
+		// into a single contiguous buffer.
+		b := make([]byte, len(p.Header)+len(p.Payload))
+		copy(b, p.Header)
+		copy(b[len(p.Header):], p.Payload)
+
+		checker.IPv4(c.t, b, checker.SrcAddr(StackAddr), checker.DstAddr(TestAddr))
+		return b
+
+	case <-time.After(2 * time.Second):
+		c.t.Fatalf("Packet wasn't written out")
+	}
+
+	// Unreachable: Fatalf above stops the test, but the compiler still
+	// requires a return on this path.
+	return nil
+}
+
+// SendICMPPacket builds and sends an ICMPv4 packet via the link layer
+// endpoint. p1 and p2 are concatenated after the ICMP header to form the
+// payload; the resulting packet is truncated to maxTotalSize if it would
+// exceed it.
+func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code uint8, p1, p2 []byte, maxTotalSize int) {
+	// Allocate a buffer for data and headers.
+	buf := buffer.NewView(header.IPv4MinimumSize + header.ICMPv4MinimumSize + len(p1) + len(p2))
+	if len(buf) > maxTotalSize {
+		buf = buf[:maxTotalSize]
+	}
+
+	ip := header.IPv4(buf)
+	ip.Encode(&header.IPv4Fields{
+		IHL:         header.IPv4MinimumSize,
+		TotalLength: uint16(len(buf)),
+		TTL:         65,
+		Protocol:    uint8(header.ICMPv4ProtocolNumber),
+		SrcAddr:     TestAddr,
+		DstAddr:     StackAddr,
+	})
+	ip.SetChecksum(^ip.CalculateChecksum())
+
+	icmp := header.ICMPv4(buf[header.IPv4MinimumSize:])
+	icmp.SetType(typ)
+	icmp.SetCode(code)
+
+	// Copy the two payload fragments directly after the ICMP header.
+	copy(icmp[header.ICMPv4MinimumSize:], p1)
+	copy(icmp[header.ICMPv4MinimumSize+len(p1):], p2)
+
+	// Inject packet.
+	var views [1]buffer.View
+	vv := buf.ToVectorisedView(views)
+	c.linkEP.Inject(ipv4.ProtocolNumber, &vv)
+}
+
+// SendPacket builds and sends a TCP segment(with the provided payload & TCP
+// headers) in an IPv4 packet via the link layer endpoint.
+//
+// NOTE(review): h.TCPOpts is assumed to already be padded to a 4-byte
+// multiple since DataOffset below is derived directly from its length —
+// confirm at call sites.
+func (c *Context) SendPacket(payload []byte, h *Headers) {
+	// Allocate a buffer for data and headers.
+	buf := buffer.NewView(header.TCPMinimumSize + header.IPv4MinimumSize + len(h.TCPOpts) + len(payload))
+	// Lay out [IP header][TCP header][options][payload] back to front.
+	copy(buf[len(buf)-len(payload):], payload)
+	copy(buf[len(buf)-len(payload)-len(h.TCPOpts):], h.TCPOpts)
+
+	// Initialize the IP header.
+	ip := header.IPv4(buf)
+	ip.Encode(&header.IPv4Fields{
+		IHL:         header.IPv4MinimumSize,
+		TotalLength: uint16(len(buf)),
+		TTL:         65,
+		Protocol:    uint8(tcp.ProtocolNumber),
+		SrcAddr:     TestAddr,
+		DstAddr:     StackAddr,
+	})
+	ip.SetChecksum(^ip.CalculateChecksum())
+
+	// Initialize the TCP header.
+	t := header.TCP(buf[header.IPv4MinimumSize:])
+	t.Encode(&header.TCPFields{
+		SrcPort:    h.SrcPort,
+		DstPort:    h.DstPort,
+		SeqNum:     uint32(h.SeqNum),
+		AckNum:     uint32(h.AckNum),
+		DataOffset: uint8(header.TCPMinimumSize + len(h.TCPOpts)),
+		Flags:      uint8(h.Flags),
+		WindowSize: uint16(h.RcvWnd),
+	})
+
+	// Calculate the TCP pseudo-header checksum.
+	xsum := header.Checksum([]byte(TestAddr), 0)
+	xsum = header.Checksum([]byte(StackAddr), xsum)
+	xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum)
+
+	// Calculate the TCP checksum and set it.
+	length := uint16(header.TCPMinimumSize + len(h.TCPOpts) + len(payload))
+	xsum = header.Checksum(payload, xsum)
+	t.SetChecksum(^t.CalculateChecksum(xsum, length))
+
+	// Inject packet.
+	var views [1]buffer.View
+	vv := buf.ToVectorisedView(views)
+	c.linkEP.Inject(ipv4.ProtocolNumber, &vv)
+}
+
+// SendAck sends an ACK packet acknowledging bytesReceived bytes past the
+// stack's SYN. The sequence number assumes only the initial SYN (sent with
+// testInitialSequenceNumber) has been consumed on our side.
+func (c *Context) SendAck(seq seqnum.Value, bytesReceived int) {
+	c.SendPacket(nil, &Headers{
+		SrcPort: TestPort,
+		DstPort: c.Port,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  seqnum.Value(testInitialSequenceNumber).Add(1),
+		AckNum:  c.IRS.Add(1 + seqnum.Size(bytesReceived)),
+		RcvWnd:  30000,
+	})
+}
+
+// ReceiveAndCheckPacket reads a packet from the link layer endpoint and
+// verifies that the packet's payload matches the slice of data indicated
+// by offset & size.
+func (c *Context) ReceiveAndCheckPacket(data []byte, offset, size int) {
+	b := c.GetPacket()
+	checker.IPv4(c.t, b,
+		checker.PayloadLen(size+header.TCPMinimumSize),
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.SeqNum(uint32(c.IRS.Add(seqnum.Size(1+offset)))),
+			checker.AckNum(uint32(seqnum.Value(testInitialSequenceNumber).Add(1))),
+			checker.TCPFlagsMatch(header.TCPFlagAck, ^uint8(header.TCPFlagPsh)),
+		),
+	)
+
+	// Compare the segment payload against the expected window of data.
+	pdata := data[offset:][:size]
+	if p := b[header.IPv4MinimumSize+header.TCPMinimumSize:]; !bytes.Equal(pdata, p) {
+		c.t.Fatalf("Data is different: expected %v, got %v", pdata, p)
+	}
+}
+
+// CreateV6Endpoint creates and initializes c.ep as a IPv6 Endpoint. If v6Only
+// is true then it sets the IP_V6ONLY option on the socket to make it a IPv6
+// only endpoint instead of a default dual stack socket.
+func (c *Context) CreateV6Endpoint(v6only bool) {
+	var err *tcpip.Error
+	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv6.ProtocolNumber, &c.WQ)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	// The zero value of V6OnlyOption leaves the endpoint dual-stack; a
+	// non-zero value restricts it to IPv6.
+	var v tcpip.V6OnlyOption
+	if v6only {
+		v = 1
+	}
+	if err := c.EP.SetSockOpt(v); err != nil {
+		c.t.Fatalf("SetSockOpt failed: %v", err)
+	}
+}
+
+// GetV6Packet reads a single packet from the link layer endpoint of the context
+// and asserts that it is an IPv6 Packet with the expected src/dest addresses.
+// It fails the test if no packet arrives within 2 seconds.
+func (c *Context) GetV6Packet() []byte {
+	select {
+	case p := <-c.linkEP.C:
+		if p.Proto != ipv6.ProtocolNumber {
+			c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
+		}
+		// Flatten the header and payload views into one buffer.
+		b := make([]byte, len(p.Header)+len(p.Payload))
+		copy(b, p.Header)
+		copy(b[len(p.Header):], p.Payload)
+
+		checker.IPv6(c.t, b, checker.SrcAddr(StackV6Addr), checker.DstAddr(TestV6Addr))
+		return b
+
+	case <-time.After(2 * time.Second):
+		c.t.Fatalf("Packet wasn't written out")
+	}
+
+	// Unreachable: Fatalf above stops the test, but the compiler still
+	// requires a return on this path.
+	return nil
+}
+
+// SendV6Packet builds and sends an IPv6 Packet via the link layer endpoint of
+// the context. Unlike SendPacket, no TCP options are carried.
+func (c *Context) SendV6Packet(payload []byte, h *Headers) {
+	// Allocate a buffer for data and headers.
+	buf := buffer.NewView(header.TCPMinimumSize + header.IPv6MinimumSize + len(payload))
+	copy(buf[len(buf)-len(payload):], payload)
+
+	// Initialize the IP header. IPv6 headers carry no checksum field, so
+	// none is computed for the network layer.
+	ip := header.IPv6(buf)
+	ip.Encode(&header.IPv6Fields{
+		PayloadLength: uint16(header.TCPMinimumSize + len(payload)),
+		NextHeader:    uint8(tcp.ProtocolNumber),
+		HopLimit:      65,
+		SrcAddr:       TestV6Addr,
+		DstAddr:       StackV6Addr,
+	})
+
+	// Initialize the TCP header.
+	t := header.TCP(buf[header.IPv6MinimumSize:])
+	t.Encode(&header.TCPFields{
+		SrcPort:    h.SrcPort,
+		DstPort:    h.DstPort,
+		SeqNum:     uint32(h.SeqNum),
+		AckNum:     uint32(h.AckNum),
+		DataOffset: header.TCPMinimumSize,
+		Flags:      uint8(h.Flags),
+		WindowSize: uint16(h.RcvWnd),
+	})
+
+	// Calculate the TCP pseudo-header checksum.
+	xsum := header.Checksum([]byte(TestV6Addr), 0)
+	xsum = header.Checksum([]byte(StackV6Addr), xsum)
+	xsum = header.Checksum([]byte{0, uint8(tcp.ProtocolNumber)}, xsum)
+
+	// Calculate the TCP checksum and set it.
+	length := uint16(header.TCPMinimumSize + len(payload))
+	xsum = header.Checksum(payload, xsum)
+	t.SetChecksum(^t.CalculateChecksum(xsum, length))
+
+	// Inject packet.
+	var views [1]buffer.View
+	vv := buf.ToVectorisedView(views)
+	c.linkEP.Inject(ipv6.ProtocolNumber, &vv)
+}
+
+// CreateConnected creates a connected TCP endpoint. It is a convenience
+// wrapper around CreateConnectedWithRawOptions with no raw SYN options.
+func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption) {
+	c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil)
+}
+
+// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends
+// the specified option bytes as the Option field in the initial SYN packet.
+//
+// It also sets the receive buffer for the endpoint to the specified
+// value in epRcvBuf.
+func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) {
+	// Create TCP endpoint.
+	var err *tcpip.Error
+	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	if epRcvBuf != nil {
+		if err := c.EP.SetSockOpt(*epRcvBuf); err != nil {
+			c.t.Fatalf("SetSockOpt failed: %v", err)
+		}
+	}
+
+	// Start connection attempt.
+	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+	defer c.WQ.EventUnregister(&waitEntry)
+
+	err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort})
+	if err != tcpip.ErrConnectStarted {
+		c.t.Fatalf("Unexpected return value from Connect: %v", err)
+	}
+
+	// Receive SYN packet.
+	b := c.GetPacket()
+	checker.IPv4(c.t, b,
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.TCPFlags(header.TCPFlagSyn),
+		),
+	)
+
+	// Record the endpoint's initial sequence number. Named tcpHdr rather
+	// than tcp to avoid shadowing the imported tcp package.
+	tcpHdr := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcpHdr.SequenceNumber())
+
+	// Reply with a SYN-ACK carrying the caller-provided raw options.
+	c.SendPacket(nil, &Headers{
+		SrcPort: tcpHdr.DestinationPort(),
+		DstPort: tcpHdr.SourcePort(),
+		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+		SeqNum:  iss,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  rcvWnd,
+		TCPOpts: options,
+	})
+
+	// Receive ACK packet.
+	checker.IPv4(c.t, c.GetPacket(),
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(c.IRS)+1),
+			checker.AckNum(uint32(iss)+1),
+		),
+	)
+
+	// Wait for connection to be established.
+	select {
+	case <-notifyCh:
+		err = c.EP.GetSockOpt(tcpip.ErrorOption{})
+		if err != nil {
+			c.t.Fatalf("Unexpected error when connecting: %v", err)
+		}
+	case <-time.After(1 * time.Second):
+		c.t.Fatalf("Timed out waiting for connection")
+	}
+
+	c.Port = tcpHdr.SourcePort()
+}
+
+// RawEndpoint is just a small wrapper around a TCP endpoint's state to make
+// sending data and ACK packets easy while being able to manipulate the sequence
+// numbers and timestamp values as needed. It models the remote (test-driven)
+// side of a connection established by one of the Context helpers.
+type RawEndpoint struct {
+	C          *Context
+	SrcPort    uint16
+	DstPort    uint16
+	Flags      int
+	NextSeqNum seqnum.Value
+	AckNum     seqnum.Value
+	WndSize    seqnum.Size
+	RecentTS   uint32 // Stores the latest timestamp to echo back.
+	TSVal      uint32 // TSVal stores the last timestamp sent by this endpoint.
+
+	// SackPermitted is true if SACKPermitted option was negotiated for this endpoint.
+	SACKPermitted bool
+}
+
+// SendPacketWithTS embeds the provided tsVal in the Timestamp option
+// for the packet to be sent out. The option is prefixed with two NOPs
+// so the 10-byte timestamp option is padded out to 12 bytes.
+func (r *RawEndpoint) SendPacketWithTS(payload []byte, tsVal uint32) {
+	r.TSVal = tsVal
+	tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+	header.EncodeTSOption(r.TSVal, r.RecentTS, tsOpt[2:])
+	r.SendPacket(payload, tsOpt[:])
+}
+
+// SendPacket is a small wrapper function to build and send packets. It
+// advances NextSeqNum by the payload length after the packet is injected.
+func (r *RawEndpoint) SendPacket(payload []byte, opts []byte) {
+	packetHeaders := &Headers{
+		SrcPort: r.SrcPort,
+		DstPort: r.DstPort,
+		Flags:   r.Flags,
+		SeqNum:  r.NextSeqNum,
+		AckNum:  r.AckNum,
+		RcvWnd:  r.WndSize,
+		TCPOpts: opts,
+	}
+	r.C.SendPacket(payload, packetHeaders)
+	r.NextSeqNum = r.NextSeqNum.Add(seqnum.Size(len(payload)))
+}
+
+// VerifyACKWithTS verifies that the tsEcr field in the ack matches the provided
+// tsVal. It also updates RecentTS with the TSVal parsed from the ACK so that
+// subsequent packets echo the stack's latest timestamp.
+func (r *RawEndpoint) VerifyACKWithTS(tsVal uint32) {
+	// Read the ACK and verify that its tsEcr matches tsVal.
+	ackPacket := r.C.GetPacket()
+	checker.IPv4(r.C.t, ackPacket,
+		checker.TCP(
+			checker.DstPort(r.SrcPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(r.AckNum)),
+			checker.AckNum(uint32(r.NextSeqNum)),
+			checker.TCPTimestampChecker(true, 0, tsVal),
+		),
+	)
+	// Store the parsed TSVal from the ack as recentTS.
+	tcpSeg := header.TCP(header.IPv4(ackPacket).Payload())
+	opts := tcpSeg.ParsedOptions()
+	r.RecentTS = opts.TSVal
+}
+
+// VerifyACKNoSACK verifies that the ACK does not contain a SACK block.
+// A nil block list makes VerifyACKHasSACK assert the option's absence.
+func (r *RawEndpoint) VerifyACKNoSACK() {
+	r.VerifyACKHasSACK(nil)
+}
+
+// VerifyACKHasSACK verifies that the ACK contains the specified SACKBlocks.
+func (r *RawEndpoint) VerifyACKHasSACK(sackBlocks []header.SACKBlock) {
+	// Read the ACK and verify that the TCP options in the segment carry
+	// exactly the expected SACK blocks (none, when sackBlocks is nil).
+	ackPacket := r.C.GetPacket()
+	checker.IPv4(r.C.t, ackPacket,
+		checker.TCP(
+			checker.DstPort(r.SrcPort),
+			checker.TCPFlags(header.TCPFlagAck),
+			checker.SeqNum(uint32(r.AckNum)),
+			checker.AckNum(uint32(r.NextSeqNum)),
+			checker.TCPSACKBlockChecker(sackBlocks),
+		),
+	)
+}
+
+// CreateConnectedWithOptions creates and connects c.ep with the specified TCP
+// options enabled and returns a RawEndpoint which represents the other end of
+// the connection.
+//
+// It also verifies where required(eg.Timestamp) that the ACK to the SYN-ACK
+// does not carry an option that was not requested.
+func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *RawEndpoint {
+	var err *tcpip.Error
+	c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ)
+	if err != nil {
+		c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err)
+	}
+
+	// Start connection attempt.
+	waitEntry, notifyCh := waiter.NewChannelEntry(nil)
+	c.WQ.EventRegister(&waitEntry, waiter.EventOut)
+	defer c.WQ.EventUnregister(&waitEntry)
+
+	testFullAddr := tcpip.FullAddress{Addr: TestAddr, Port: TestPort}
+	err = c.EP.Connect(testFullAddr)
+	if err != tcpip.ErrConnectStarted {
+		c.t.Fatalf("c.ep.Connect(%v) = %v", testFullAddr, err)
+	}
+	// Receive SYN packet.
+	b := c.GetPacket()
+	// Validate that the syn has the timestamp option and a valid
+	// TS value.
+	checker.IPv4(c.t, b,
+		checker.TCP(
+			checker.DstPort(TestPort),
+			checker.TCPFlags(header.TCPFlagSyn),
+			checker.TCPSynOptions(header.TCPSynOptions{
+				MSS:           uint16(c.linkEP.MTU() - header.IPv4MinimumSize - header.TCPMinimumSize),
+				TS:            true,
+				WS:            defaultWindowScale,
+				SACKPermitted: c.SACKEnabled(),
+			}),
+		),
+	)
+	tcpSeg := header.TCP(header.IPv4(b).Payload())
+	synOptions := header.ParseSynOptions(tcpSeg.Options(), false)
+
+	// Build options w/ tsVal to be sent in the SYN-ACK. 40 bytes is the
+	// maximum TCP option space.
+	synAckOptions := make([]byte, 40)
+	offset := 0
+	if wantOptions.TS {
+		offset += header.EncodeTSOption(wantOptions.TSVal, synOptions.TSVal, synAckOptions[offset:])
+	}
+	if wantOptions.SACKPermitted {
+		offset += header.EncodeSACKPermittedOption(synAckOptions[offset:])
+	}
+
+	offset += header.AddTCPOptionPadding(synAckOptions, offset)
+
+	// Build SYN-ACK.
+	c.IRS = seqnum.Value(tcpSeg.SequenceNumber())
+	iss := seqnum.Value(testInitialSequenceNumber)
+	c.SendPacket(nil, &Headers{
+		SrcPort: tcpSeg.DestinationPort(),
+		DstPort: tcpSeg.SourcePort(),
+		Flags:   header.TCPFlagSyn | header.TCPFlagAck,
+		SeqNum:  iss,
+		AckNum:  c.IRS.Add(1),
+		RcvWnd:  30000,
+		TCPOpts: synAckOptions[:offset],
+	})
+
+	// Read ACK.
+	ackPacket := c.GetPacket()
+
+	// Verify TCP header fields.
+	tcpCheckers := []checker.TransportChecker{
+		checker.DstPort(TestPort),
+		checker.TCPFlags(header.TCPFlagAck),
+		checker.SeqNum(uint32(c.IRS) + 1),
+		checker.AckNum(uint32(iss) + 1),
+	}
+
+	// Verify that tsEcr of ACK packet is wantOptions.TSVal if the
+	// timestamp option was enabled, if not then we verify that
+	// there is no timestamp in the ACK packet.
+	if wantOptions.TS {
+		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(true, 0, wantOptions.TSVal))
+	} else {
+		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
+	}
+
+	checker.IPv4(c.t, ackPacket, checker.TCP(tcpCheckers...))
+
+	ackSeg := header.TCP(header.IPv4(ackPacket).Payload())
+	ackOptions := ackSeg.ParsedOptions()
+
+	// Wait for connection to be established.
+	select {
+	case <-notifyCh:
+		err = c.EP.GetSockOpt(tcpip.ErrorOption{})
+		if err != nil {
+			c.t.Fatalf("Unexpected error when connecting: %v", err)
+		}
+	case <-time.After(1 * time.Second):
+		c.t.Fatalf("Timed out waiting for connection")
+	}
+
+	// Store the source port in use by the endpoint.
+	c.Port = tcpSeg.SourcePort()
+
+	// Mark in context that timestamp option is enabled for this endpoint.
+	// NOTE(review): this is set even when wantOptions.TS is false —
+	// confirm callers expect that.
+	c.TimeStampEnabled = true
+
+	return &RawEndpoint{
+		C:             c,
+		SrcPort:       tcpSeg.DestinationPort(),
+		DstPort:       tcpSeg.SourcePort(),
+		Flags:         header.TCPFlagAck | header.TCPFlagPsh,
+		NextSeqNum:    iss + 1,
+		AckNum:        c.IRS.Add(1),
+		WndSize:       30000,
+		RecentTS:      ackOptions.TSVal,
+		TSVal:         wantOptions.TSVal,
+		SACKPermitted: wantOptions.SACKPermitted,
+	}
+}
+
+// AcceptWithOptions initializes a listening endpoint and connects to it with the
+// provided options enabled. It also verifies that the SYN-ACK has the expected
+// values for the provided options.
+//
+// The function returns a RawEndpoint representing the other end of the accepted
+// endpoint.
+func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+	// Create EP and start listening.
+	wq := &waiter.Queue{}
+	ep, err := c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	defer ep.Close()
+
+	if err := ep.Bind(tcpip.FullAddress{Port: StackPort}, nil); err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+
+	if err := ep.Listen(10); err != nil {
+		c.t.Fatalf("Listen failed: %v", err)
+	}
+
+	rep := c.PassiveConnectWithOptions(100, wndScale, synOptions)
+
+	// Try to accept the connection.
+	we, ch := waiter.NewChannelEntry(nil)
+	wq.EventRegister(&we, waiter.EventIn)
+	defer wq.EventUnregister(&we)
+
+	c.EP, _, err = ep.Accept()
+	if err == tcpip.ErrWouldBlock {
+		// Wait for connection to be established.
+		select {
+		case <-ch:
+			c.EP, _, err = ep.Accept()
+			if err != nil {
+				c.t.Fatalf("Accept failed: %v", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			c.t.Fatalf("Timed out waiting for accept")
+		}
+	} else if err != nil {
+		// Fail fast on any error other than ErrWouldBlock; previously
+		// such errors were silently ignored, leaving c.EP nil.
+		c.t.Fatalf("Accept failed: %v", err)
+	}
+	return rep
+}
+
+// PassiveConnect just disables WindowScaling and delegates the call to
+// PassiveConnectWithOptions. Setting WS to -1 suppresses the window scale
+// option in the SYN (see PassiveConnectWithOptions).
+func (c *Context) PassiveConnect(maxPayload, wndScale int, synOptions header.TCPSynOptions) {
+	synOptions.WS = -1
+	c.PassiveConnectWithOptions(maxPayload, wndScale, synOptions)
+}
+
+// PassiveConnectWithOptions initiates a new connection (with the specified TCP
+// options enabled) to the port on which the Context.ep is listening for new
+// connections. It also validates that the SYN-ACK has the expected values for
+// the enabled options.
+//
+// NOTE: MSS is not a negotiated option and it can be asymmetric
+// in each direction. This function uses the maxPayload to set the MSS to be
+// sent to the peer on a connect and validates that the MSS in the SYN-ACK
+// response is equal to the MTU - (tcphdr len + iphdr len).
+//
+// wndScale is the expected window scale in the SYN-ACK and synOptions.WS is the
+// value of the window scaling option to be sent in the SYN. If synOptions.WS >
+// 0 then we send the WindowScale option.
+func (c *Context) PassiveConnectWithOptions(maxPayload, wndScale int, synOptions header.TCPSynOptions) *RawEndpoint {
+	// 40 bytes is the maximum TCP option space.
+	opts := make([]byte, 40)
+	offset := 0
+	offset += header.EncodeMSSOption(uint32(maxPayload), opts)
+
+	// The window scale option is always encoded with a value of 3 here;
+	// synOptions.WS only gates whether the option is present.
+	if synOptions.WS >= 0 {
+		offset += header.EncodeWSOption(3, opts[offset:])
+	}
+	if synOptions.TS {
+		offset += header.EncodeTSOption(synOptions.TSVal, synOptions.TSEcr, opts[offset:])
+	}
+
+	if synOptions.SACKPermitted {
+		offset += header.EncodeSACKPermittedOption(opts[offset:])
+	}
+
+	// NOTE(review): when offset is already 4-byte aligned this adds four
+	// NOP bytes rather than none (harmless padding); header's
+	// AddTCPOptionPadding helper is used elsewhere for the same job.
+	paddingToAdd := 4 - offset%4
+	// Now add any padding bytes that might be required to quad align the
+	// options.
+	for i := offset; i < offset+paddingToAdd; i++ {
+		opts[i] = header.TCPOptionNOP
+	}
+	offset += paddingToAdd
+
+	// Send a SYN request.
+	iss := seqnum.Value(testInitialSequenceNumber)
+	c.SendPacket(nil, &Headers{
+		SrcPort: TestPort,
+		DstPort: StackPort,
+		Flags:   header.TCPFlagSyn,
+		SeqNum:  iss,
+		RcvWnd:  30000,
+		TCPOpts: opts[:offset],
+	})
+
+	// Receive the SYN-ACK reply. Make sure MSS and other expected options
+	// are present.
+	b := c.GetPacket()
+	tcp := header.TCP(header.IPv4(b).Payload())
+	c.IRS = seqnum.Value(tcp.SequenceNumber())
+
+	tcpCheckers := []checker.TransportChecker{
+		checker.SrcPort(StackPort),
+		checker.DstPort(TestPort),
+		checker.TCPFlags(header.TCPFlagAck | header.TCPFlagSyn),
+		checker.AckNum(uint32(iss) + 1),
+		checker.TCPSynOptions(header.TCPSynOptions{MSS: synOptions.MSS, WS: wndScale, SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled()}),
+	}
+
+	// If TS option was enabled in the original SYN then add a checker to
+	// validate the Timestamp option in the SYN-ACK.
+	if synOptions.TS {
+		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(synOptions.TS, 0, synOptions.TSVal))
+	} else {
+		tcpCheckers = append(tcpCheckers, checker.TCPTimestampChecker(false, 0, 0))
+	}
+
+	checker.IPv4(c.t, b, checker.TCP(tcpCheckers...))
+	rcvWnd := seqnum.Size(30000)
+	ackHeaders := &Headers{
+		SrcPort: TestPort,
+		DstPort: StackPort,
+		Flags:   header.TCPFlagAck,
+		SeqNum:  iss + 1,
+		AckNum:  c.IRS + 1,
+		RcvWnd:  rcvWnd,
+	}
+
+	// If WS was expected to be in effect then scale the advertised window
+	// correspondingly.
+	if synOptions.WS > 0 {
+		ackHeaders.RcvWnd = rcvWnd >> byte(synOptions.WS)
+	}
+
+	parsedOpts := tcp.ParsedOptions()
+	if synOptions.TS {
+		// Echo the tsVal back to the peer in the tsEcr field of the
+		// timestamp option.
+		// Increment TSVal by 1 from the value sent in the SYN and echo
+		// the TSVal in the SYN-ACK in the TSEcr field.
+		opts := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP}
+		header.EncodeTSOption(synOptions.TSVal+1, parsedOpts.TSVal, opts[2:])
+		ackHeaders.TCPOpts = opts[:]
+	}
+
+	// Send ACK.
+	c.SendPacket(nil, ackHeaders)
+
+	c.Port = StackPort
+
+	return &RawEndpoint{
+		C:             c,
+		SrcPort:       TestPort,
+		DstPort:       StackPort,
+		Flags:         header.TCPFlagPsh | header.TCPFlagAck,
+		NextSeqNum:    iss + 1,
+		AckNum:        c.IRS + 1,
+		WndSize:       rcvWnd,
+		SACKPermitted: synOptions.SACKPermitted && c.SACKEnabled(),
+		RecentTS:      parsedOpts.TSVal,
+		TSVal:         synOptions.TSVal + 1,
+	}
+}
+
+// SACKEnabled returns true if the TCP Protocol option SACKEnabled is set to true
+// for the Stack in the context.
+func (c *Context) SACKEnabled() bool {
+	var v tcp.SACKEnabled
+	if err := c.Stack().TransportProtocolOption(tcp.ProtocolNumber, &v); err != nil {
+		// An error querying the option means the stack doesn't support
+		// SACK. So just return false.
+		return false
+	}
+	return bool(v)
+}
diff --git a/pkg/tcpip/transport/tcp/timer.go b/pkg/tcpip/transport/tcp/timer.go
new file mode 100644
index 000000000..7aa824d8f
--- /dev/null
+++ b/pkg/tcpip/transport/tcp/timer.go
@@ -0,0 +1,131 @@
+// Copyright 2017 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcp
+
+import (
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+)
+
+type timerState int
+
+const (
+ timerStateDisabled timerState = iota
+ timerStateEnabled
+ timerStateOrphaned
+)
+
+// timer is a timer implementation that reduces the interactions with the
+// runtime timer infrastructure by letting timers run (and potentially
+// eventually expire) even if they are stopped. It makes it cheaper to
+// disable/reenable timers at the expense of spurious wakes. This is useful for
+// cases when the same timer is disabled/reenabled repeatedly with relatively
+// long timeouts farther into the future.
+//
+// TCP retransmit timers benefit from this because the timeouts are long
+// (currently at least 200ms), and get disabled when acks are received, and
+// reenabled when new pending segments are sent.
+//
+// It is advantageous to avoid interacting with the runtime because it acquires
+// a global mutex and performs O(log n) operations, where n is the global number
+// of timers, whenever a timer is enabled or disabled, and may make a syscall.
+//
+// This struct is thread-compatible.
+type timer struct {
+	// state is the current state of the timer, it can be one of the
+	// following values:
+	//     disabled - the timer is disabled.
+	//     orphaned - the timer is disabled, but the runtime timer is
+	//                enabled, which means that it will eventually cause a
+	//                spurious wake (unless it gets enabled again before
+	//                then).
+	//     enabled  - the timer is enabled, but the runtime timer may be set
+	//                to an earlier expiration time due to a previous
+	//                orphaned state.
+	state timerState
+
+	// target is the expiration time of the current timer. It is only
+	// meaningful in the enabled state.
+	target time.Time
+
+	// runtimeTarget is the expiration time of the runtime timer. It is
+	// meaningful in the enabled and orphaned states.
+	runtimeTarget time.Time
+
+	// timer is the runtime timer used to wait on.
+	timer *time.Timer
+}
+
+// init initializes the timer. Once it expires, the given waker will be
+// asserted.
+func (t *timer) init(w *sleep.Waker) {
+	t.state = timerStateDisabled
+
+	// Initialize a runtime timer that will assert the waker, then
+	// immediately stop it. The one-hour duration is arbitrary; the timer
+	// is stopped before it can fire and is re-armed by enable.
+	t.timer = time.AfterFunc(time.Hour, func() {
+		w.Assert()
+	})
+	t.timer.Stop()
+}
+
+// cleanup frees all resources associated with the timer by stopping the
+// underlying runtime timer.
+func (t *timer) cleanup() {
+	t.timer.Stop()
+}
+
+// checkExpiration checks if the given timer has actually expired, it should be
+// called whenever a sleeper wakes up due to the waker being asserted, and is
+// used to check if it's a spurious wake (due to a previously orphaned timer)
+// or a legitimate one.
+func (t *timer) checkExpiration() bool {
+	// Transition to fully disabled state if we're just consuming an
+	// orphaned timer.
+	if t.state == timerStateOrphaned {
+		t.state = timerStateDisabled
+		return false
+	}
+
+	// The timer is enabled, but it may have expired early. Check if that's
+	// the case, and if so, reset the runtime timer to the correct time.
+	now := time.Now()
+	if now.Before(t.target) {
+		t.runtimeTarget = t.target
+		t.timer.Reset(t.target.Sub(now))
+		return false
+	}
+
+	// The timer has actually expired, disable it for now and inform the
+	// caller.
+	t.state = timerStateDisabled
+	return true
+}
+
+// disable disables the timer, leaving it in an orphaned state if it wasn't
+// already disabled. The runtime timer is deliberately left running; its
+// eventual firing is consumed as a spurious wake by checkExpiration.
+func (t *timer) disable() {
+	if t.state != timerStateDisabled {
+		t.state = timerStateOrphaned
+	}
+}
+
+// enabled returns true if the timer is currently enabled, false otherwise
+// (i.e. disabled or orphaned).
+func (t *timer) enabled() bool {
+	return t.state == timerStateEnabled
+}
+
+// enable enables the timer, programming the runtime timer if necessary.
+func (t *timer) enable(d time.Duration) {
+	t.target = time.Now().Add(d)
+
+	// Check if we need to set the runtime timer: either no runtime timer
+	// is pending, or the pending one (possibly from an orphaned state)
+	// fires later than the new target.
+	if t.state == timerStateDisabled || t.target.Before(t.runtimeTarget) {
+		t.runtimeTarget = t.target
+		t.timer.Reset(d)
+	}
+
+	t.state = timerStateEnabled
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/BUILD b/pkg/tcpip/transport/tcpconntrack/BUILD
new file mode 100644
index 000000000..cf83ca134
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/BUILD
@@ -0,0 +1,24 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+
+go_library(
+ name = "tcpconntrack",
+ srcs = ["tcp_conntrack.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcpconntrack",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/seqnum",
+ ],
+)
+
+go_test(
+ name = "tcpconntrack_test",
+ size = "small",
+ srcs = ["tcp_conntrack_test.go"],
+ deps = [
+ ":tcpconntrack",
+ "//pkg/tcpip/header",
+ ],
+)
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
new file mode 100644
index 000000000..487f2572d
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack.go
@@ -0,0 +1,333 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package tcpconntrack implements a TCP connection tracking object. It allows
+// users with access to a segment stream to figure out when a connection is
+// established, reset, and closed (and in the last case, who closed first).
+package tcpconntrack
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+)
+
+// Result is returned when the state of a TCB is updated in response to an
+// inbound or outbound segment (see UpdateStateInbound/UpdateStateOutbound).
+type Result int
+
+const (
+	// ResultDrop indicates that the segment should be dropped.
+	ResultDrop Result = iota
+
+	// ResultConnecting indicates that the connection remains in a
+	// connecting state.
+	ResultConnecting
+
+	// ResultAlive indicates that the connection remains alive (connected).
+	ResultAlive
+
+	// ResultReset indicates that the connection was reset.
+	ResultReset
+
+	// ResultClosedByPeer indicates that the connection was gracefully
+	// closed, and the inbound stream was closed first.
+	ResultClosedByPeer
+
+	// ResultClosedBySelf indicates that the connection was gracefully
+	// closed, and the outbound stream was closed first.
+	ResultClosedBySelf
+)
+
+// TCB is a TCP Control Block. It holds state necessary to keep track of a TCP
+// connection and inform the caller when the connection has been closed.
+type TCB struct {
+	inbound  stream
+	outbound stream
+
+	// State handlers.
+	handlerInbound  func(*TCB, header.TCP) Result
+	handlerOutbound func(*TCB, header.TCP) Result
+
+	// firstFin holds a pointer to the first stream to send a FIN.
+	firstFin *stream
+
+	// state is the current state of the connection.
+	state Result
+}
+
+// Init initializes the state of the TCB according to the initial outbound SYN.
+func (t *TCB) Init(initialSyn header.TCP) {
+	t.handlerInbound = synSentStateInbound
+	t.handlerOutbound = synSentStateOutbound
+
+	iss := seqnum.Value(initialSyn.SequenceNumber())
+	t.outbound.una = iss
+	t.outbound.nxt = iss.Add(logicalLen(initialSyn))
+	t.outbound.end = t.outbound.nxt
+
+	// Even though "end" is a sequence number, we don't know the initial
+	// receive sequence number yet, so we store the window size until we get
+	// a SYN from the peer.
+	t.inbound.una = 0
+	t.inbound.nxt = 0
+	t.inbound.end = seqnum.Value(initialSyn.WindowSize())
+	t.state = ResultConnecting
+}
+
+// UpdateStateInbound updates the state of the TCB based on the supplied inbound
+// segment; a ResultDrop return leaves the stored state unchanged.
+func (t *TCB) UpdateStateInbound(tcp header.TCP) Result {
+	st := t.handlerInbound(t, tcp)
+	if st != ResultDrop {
+		t.state = st
+	}
+	return st
+}
+
+// UpdateStateOutbound updates the state of the TCB based on the supplied
+// outbound segment; a ResultDrop return leaves the stored state unchanged.
+func (t *TCB) UpdateStateOutbound(tcp header.TCP) Result {
+	st := t.handlerOutbound(t, tcp)
+	if st != ResultDrop {
+		t.state = st
+	}
+	return st
+}
+
+// IsAlive returns true as long as the connection is in the established (Alive)
+// or connecting state.
+func (t *TCB) IsAlive() bool {
+	return !t.inbound.rstSeen && !t.outbound.rstSeen && (!t.inbound.closed() || !t.outbound.closed())
+}
+
+// OutboundSendSequenceNumber returns snd.NXT, the outbound stream's next send sequence number.
+func (t *TCB) OutboundSendSequenceNumber() seqnum.Value {
+	return t.outbound.nxt
+}
+
+// adaptResult modifies the supplied "Result" according to the state of the TCB;
+// if r is anything other than "Alive", or if one of the streams isn't closed
+// yet, it is returned unmodified. Otherwise it's converted to either
+// ClosedBySelf or ClosedByPeer depending on which stream was closed first.
+func (t *TCB) adaptResult(r Result) Result {
+	// Check the unmodified case.
+	if r != ResultAlive || !t.inbound.closed() || !t.outbound.closed() {
+		return r
+	}
+
+	// Find out which was closed first.
+	if t.firstFin == &t.outbound {
+		return ResultClosedBySelf
+	}
+
+	return ResultClosedByPeer
+}
+
+// synSentStateInbound is the state handler for inbound segments when the
+// connection is in SYN-SENT state.
+func synSentStateInbound(t *TCB, tcp header.TCP) Result {
+	flags := tcp.Flags()
+	ackPresent := flags&header.TCPFlagAck != 0
+	ack := seqnum.Value(tcp.AckNumber())
+
+	// Ignore segment if ack is present but not acceptable.
+	if ackPresent && !(ack-1).InRange(t.outbound.una, t.outbound.nxt) {
+		return ResultConnecting
+	}
+
+	// If reset is specified, we will let the packet through no matter what
+	// but we will also destroy the connection if the ACK is present (and
+	// implicitly acceptable).
+	if flags&header.TCPFlagRst != 0 {
+		if ackPresent {
+			t.inbound.rstSeen = true
+			return ResultReset
+		}
+		return ResultConnecting
+	}
+
+	// Ignore segment if SYN is not set.
+	if flags&header.TCPFlagSyn == 0 {
+		return ResultConnecting
+	}
+
+	// Update the inbound stream's state with the information in this SYN.
+	irs := seqnum.Value(tcp.SequenceNumber())
+	t.inbound.una = irs
+	t.inbound.nxt = irs.Add(logicalLen(tcp))
+	t.inbound.end += irs // end held the window size (see Init); now irs+window.
+
+	t.outbound.end = t.outbound.una.Add(seqnum.Size(tcp.WindowSize()))
+
+	// If the ACK was set (it is acceptable), update our unacknowledgement
+	// tracking.
+	if ackPresent {
+		// Advance the "una" and "end" indices of the outbound stream.
+		if t.outbound.una.LessThan(ack) {
+			t.outbound.una = ack
+		}
+
+		if end := ack.Add(seqnum.Size(tcp.WindowSize())); t.outbound.end.LessThan(end) {
+			t.outbound.end = end
+		}
+	}
+
+	// Update handlers so that new calls will be handled by new state.
+	t.handlerInbound = allOtherInbound
+	t.handlerOutbound = allOtherOutbound
+
+	return ResultAlive
+}
+
+// synSentStateOutbound is the state handler for outbound segments when the
+// connection is in SYN-SENT state.
+func synSentStateOutbound(t *TCB, tcp header.TCP) Result {
+	// Drop outbound segments that aren't exact retransmits of the original SYN.
+	if tcp.Flags() != header.TCPFlagSyn ||
+		tcp.SequenceNumber() != uint32(t.outbound.una) {
+		return ResultDrop
+	}
+
+	// Update the receive window. We only remember the largest value seen.
+	if wnd := seqnum.Value(tcp.WindowSize()); wnd > t.inbound.end {
+		t.inbound.end = wnd
+	}
+
+	return ResultConnecting
+}
+
+// update updates the state of inbound and outbound streams, given the supplied
+// inbound segment. For outbound segments, this same function can be called with
+// swapped inbound/outbound streams.
+func update(tcp header.TCP, inbound, outbound *stream, firstFin **stream) Result {
+	// Ignore segments that fall outside the receive window.
+	s := seqnum.Value(tcp.SequenceNumber())
+	if !inbound.acceptable(s, dataLen(tcp)) {
+		return ResultAlive
+	}
+
+	flags := tcp.Flags()
+	if flags&header.TCPFlagRst != 0 {
+		inbound.rstSeen = true
+		return ResultReset
+	}
+
+	// Ignore segments that don't have the ACK flag, and those with the SYN
+	// flag.
+	if flags&header.TCPFlagAck == 0 || flags&header.TCPFlagSyn != 0 {
+		return ResultAlive
+	}
+
+	// Ignore segments that acknowledge not yet sent data.
+	ack := seqnum.Value(tcp.AckNumber())
+	if outbound.nxt.LessThan(ack) {
+		return ResultAlive
+	}
+
+	// Advance the "una" and "end" indices of the outbound stream.
+	if outbound.una.LessThan(ack) {
+		outbound.una = ack
+	}
+
+	if end := ack.Add(seqnum.Size(tcp.WindowSize())); outbound.end.LessThan(end) {
+		outbound.end = end
+	}
+
+	// Advance the "nxt" index of the inbound stream.
+	end := s.Add(logicalLen(tcp))
+	if inbound.nxt.LessThan(end) {
+		inbound.nxt = end
+	}
+
+	// Note the index of the FIN segment. And stash away a pointer to the
+	// first stream to see a FIN.
+	if flags&header.TCPFlagFin != 0 && !inbound.finSeen {
+		inbound.finSeen = true
+		inbound.fin = end - 1
+
+		if *firstFin == nil {
+			*firstFin = inbound
+		}
+	}
+
+	return ResultAlive
+}
+
+// allOtherInbound is the state handler for inbound segments in all states
+// except SYN-SENT (i.e., once both directions' SYNs have been processed).
+func allOtherInbound(t *TCB, tcp header.TCP) Result {
+	return t.adaptResult(update(tcp, &t.inbound, &t.outbound, &t.firstFin))
+}
+
+// allOtherOutbound is the state handler for outbound segments in all states
+// except SYN-SENT; it calls update with the streams swapped.
+func allOtherOutbound(t *TCB, tcp header.TCP) Result {
+	return t.adaptResult(update(tcp, &t.outbound, &t.inbound, &t.firstFin))
+}
+
+// stream holds the state of a TCP unidirectional stream.
+type stream struct {
+	// The interval [una, end) is the allowed interval as defined by the
+	// receiver, i.e., anything less than una has already been acknowledged
+	// and anything greater than or equal to end is beyond the receiver
+	// window. The interval [una, nxt) is the acknowledgable range, whose
+	// right edge indicates the sequence number of the next byte to be sent
+	// by the sender, i.e., anything greater than or equal to nxt hasn't
+	// been sent yet.
+	una seqnum.Value
+	nxt seqnum.Value
+	end seqnum.Value
+
+	// finSeen indicates if a FIN has already been sent on this stream.
+	finSeen bool
+
+	// fin is the sequence number of the FIN. It is only valid after finSeen
+	// is set to true.
+	fin seqnum.Value
+
+	// rstSeen indicates if a RST has already been sent on this stream.
+	rstSeen bool
+}
+
+// acceptable determines if the segment with the given sequence number and data
+// length is acceptable, i.e., if it's within the [una, end) window or, in case
+// the window is zero, if it's a packet with no payload and sequence number
+// equal to una.
+func (s *stream) acceptable(segSeq seqnum.Value, segLen seqnum.Size) bool {
+	wnd := s.una.Size(s.end)
+	if wnd == 0 {
+		return segLen == 0 && segSeq == s.una
+	}
+
+	// Make sure [segSeq, segSeq+segLen) is non-empty.
+	if segLen == 0 {
+		segLen = 1
+	}
+
+	return seqnum.Overlap(s.una, wnd, segSeq, segLen)
+}
+
+// closed determines if the stream has already been closed. This happens when
+// a FIN has been sent by the sender and acknowledged by the receiver.
+func (s *stream) closed() bool {
+	return s.finSeen && s.fin.LessThan(s.una)
+}
+
+// dataLen returns the length of the TCP segment payload, in bytes.
+func dataLen(tcp header.TCP) seqnum.Size {
+	return seqnum.Size(len(tcp) - int(tcp.DataOffset()))
+}
+
+// logicalLen calculates the logical length of the TCP segment: the data length plus one for each of the SYN and FIN flags, which consume sequence space.
+func logicalLen(tcp header.TCP) seqnum.Size {
+	l := dataLen(tcp)
+	flags := tcp.Flags()
+	if flags&header.TCPFlagSyn != 0 {
+		l++
+	}
+	if flags&header.TCPFlagFin != 0 {
+		l++
+	}
+	return l
+}
diff --git a/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
new file mode 100644
index 000000000..50cab3132
--- /dev/null
+++ b/pkg/tcpip/transport/tcpconntrack/tcp_conntrack_test.go
@@ -0,0 +1,501 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package tcpconntrack_test
+
+import (
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcpconntrack"
+)
+
+// connected creates a connection tracker TCB and sets it to a connected state
+// by performing a 3-way handshake (SYN, SYN-ACK, ACK).
+func connected(t *testing.T, iss, irs uint32, isw, irw uint16) *tcpconntrack.TCB {
+	// Send SYN.
+	tcp := make(header.TCP, header.TCPMinimumSize)
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     iss,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: irw,
+	})
+
+	tcb := tcpconntrack.TCB{}
+	tcb.Init(tcp)
+
+	// Receive SYN-ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     irs,
+		AckNum:     iss + 1,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn | header.TCPFlagAck,
+		WindowSize: isw,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Send ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     iss + 1,
+		AckNum:     irs + 1,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck,
+		WindowSize: irw,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	return &tcb
+}
+
+func TestConnectionRefused(t *testing.T) {
+	// Send SYN.
+	tcp := make(header.TCP, header.TCPMinimumSize)
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1234,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 30000,
+	})
+
+	tcb := tcpconntrack.TCB{}
+	tcb.Init(tcp)
+
+	// Receive RST in response (connection refused).
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     789,
+		AckNum:     1235,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagRst | header.TCPFlagAck,
+		WindowSize: 50000,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
+	}
+}
+
+func TestConnectionRefusedInSynRcvd(t *testing.T) {
+	// Send SYN.
+	tcp := make(header.TCP, header.TCPMinimumSize)
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1234,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 30000,
+	})
+
+	tcb := tcpconntrack.TCB{}
+	tcb.Init(tcp)
+
+	// Receive a simultaneous SYN, moving the connection to SYN-RCVD.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     789,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 50000,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Receive RST with no ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     790,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagRst,
+		WindowSize: 50000,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultReset {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
+	}
+}
+
+func TestConnectionResetInSynRcvd(t *testing.T) {
+	// Send SYN.
+	tcp := make(header.TCP, header.TCPMinimumSize)
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1234,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 30000,
+	})
+
+	tcb := tcpconntrack.TCB{}
+	tcb.Init(tcp)
+
+	// Receive a simultaneous SYN, moving the connection to SYN-RCVD.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     789,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 50000,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Send RST with no ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1235,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagRst,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultReset {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultReset)
+	}
+}
+
+func TestRetransmitOnSynSent(t *testing.T) {
+	// Send initial SYN.
+	tcp := make(header.TCP, header.TCPMinimumSize)
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1234,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 30000,
+	})
+
+	tcb := tcpconntrack.TCB{}
+	tcb.Init(tcp)
+
+	// Retransmit the same SYN; the state must remain Connecting.
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultConnecting {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting)
+	}
+}
+
+func TestRetransmitOnSynRcvd(t *testing.T) {
+	// Send initial SYN.
+	tcp := make(header.TCP, header.TCPMinimumSize)
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1234,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 30000,
+	})
+
+	tcb := tcpconntrack.TCB{}
+	tcb.Init(tcp)
+
+	// Receive SYN. This will cause the state to go to SYN-RCVD.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     789,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 50000,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Retransmit the original SYN.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1234,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 30000,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Transmit a SYN-ACK, completing our side of the handshake.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1234,
+		AckNum:     790,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn | header.TCPFlagAck,
+		WindowSize: 30000,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+}
+
+func TestClosedBySelf(t *testing.T) {
+	tcb := connected(t, 1234, 789, 30000, 50000)
+
+	// Send FIN, closing the outbound stream first.
+	tcp := make(header.TCP, header.TCPMinimumSize)
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1235,
+		AckNum:     790,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck | header.TCPFlagFin,
+		WindowSize: 30000,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Receive FIN/ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     790,
+		AckNum:     1236,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck | header.TCPFlagFin,
+		WindowSize: 50000,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Send ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1236,
+		AckNum:     791,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck,
+		WindowSize: 30000,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
+	}
+}
+
+func TestClosedByPeer(t *testing.T) {
+	tcb := connected(t, 1234, 789, 30000, 50000)
+
+	// Receive FIN, closing the inbound stream first.
+	tcp := make(header.TCP, header.TCPMinimumSize)
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     790,
+		AckNum:     1235,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck | header.TCPFlagFin,
+		WindowSize: 50000,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Send FIN/ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1235,
+		AckNum:     791,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck | header.TCPFlagFin,
+		WindowSize: 30000,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Receive ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     791,
+		AckNum:     1236,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck,
+		WindowSize: 50000,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultClosedByPeer {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedByPeer)
+	}
+}
+
+func TestSendAndReceiveDataClosedBySelf(t *testing.T) {
+	sseq := uint32(1234)
+	rseq := uint32(789)
+	tcb := connected(t, sseq, rseq, 30000, 50000)
+	sseq++
+	rseq++
+
+	// Send ten 1024-byte data segments, each acknowledged by the peer.
+	tcp := make(header.TCP, header.TCPMinimumSize+1024)
+
+	for i := uint32(0); i < 10; i++ {
+		// Send some data.
+		tcp.Encode(&header.TCPFields{
+			SeqNum:     sseq,
+			AckNum:     rseq,
+			DataOffset: header.TCPMinimumSize,
+			Flags:      header.TCPFlagAck,
+			WindowSize: 30000,
+		})
+		sseq += uint32(len(tcp)) - header.TCPMinimumSize
+
+		if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+			t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+		}
+
+		// Receive ack for data.
+		tcp.Encode(&header.TCPFields{
+			SeqNum:     rseq,
+			AckNum:     sseq,
+			DataOffset: header.TCPMinimumSize,
+			Flags:      header.TCPFlagAck,
+			WindowSize: 50000,
+		})
+
+		if r := tcb.UpdateStateInbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
+			t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+		}
+	}
+
+	for i := uint32(0); i < 10; i++ {
+		// Receive some data.
+		tcp.Encode(&header.TCPFields{
+			SeqNum:     rseq,
+			AckNum:     sseq,
+			DataOffset: header.TCPMinimumSize,
+			Flags:      header.TCPFlagAck,
+			WindowSize: 50000,
+		})
+		rseq += uint32(len(tcp)) - header.TCPMinimumSize
+
+		if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+			t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+		}
+
+		// Send ack for data.
+		tcp.Encode(&header.TCPFields{
+			SeqNum:     sseq,
+			AckNum:     rseq,
+			DataOffset: header.TCPMinimumSize,
+			Flags:      header.TCPFlagAck,
+			WindowSize: 30000,
+		})
+
+		if r := tcb.UpdateStateOutbound(tcp[:header.TCPMinimumSize]); r != tcpconntrack.ResultAlive {
+			t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+		}
+	}
+
+	// Send FIN.
+	tcp = tcp[:header.TCPMinimumSize]
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     sseq,
+		AckNum:     rseq,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck | header.TCPFlagFin,
+		WindowSize: 30000,
+	})
+	sseq++
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Receive FIN/ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     rseq,
+		AckNum:     sseq,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck | header.TCPFlagFin,
+		WindowSize: 50000,
+	})
+	rseq++
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Send ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     sseq,
+		AckNum:     rseq,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck,
+		WindowSize: 30000,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultClosedBySelf {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultClosedBySelf)
+	}
+}
+
+func TestIgnoreBadResetOnSynSent(t *testing.T) {
+	// Send SYN.
+	tcp := make(header.TCP, header.TCPMinimumSize)
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1234,
+		AckNum:     0,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn,
+		WindowSize: 30000,
+	})
+
+	tcb := tcpconntrack.TCB{}
+	tcb.Init(tcp)
+
+	// Receive a RST with a bad ACK, it should not cause the connection to
+	// be reset.
+	acks := []uint32{1234, 1236, 1000, 5000}
+	flags := []uint8{header.TCPFlagRst, header.TCPFlagRst | header.TCPFlagAck}
+	for _, a := range acks {
+		for _, f := range flags {
+			tcp.Encode(&header.TCPFields{
+				SeqNum:     789,
+				AckNum:     a,
+				DataOffset: header.TCPMinimumSize,
+				Flags:      f,
+				WindowSize: 50000,
+			})
+
+			// The expected state here is Connecting (not Alive): the
+			// handshake hasn't completed, so the message must report it.
+			if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultConnecting {
+				t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultConnecting)
+			}
+		}
+	}
+
+	// Complete the handshake.
+	// Receive SYN-ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     789,
+		AckNum:     1235,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagSyn | header.TCPFlagAck,
+		WindowSize: 50000,
+	})
+
+	if r := tcb.UpdateStateInbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+
+	// Send ACK.
+	tcp.Encode(&header.TCPFields{
+		SeqNum:     1235,
+		AckNum:     790,
+		DataOffset: header.TCPMinimumSize,
+		Flags:      header.TCPFlagAck,
+		WindowSize: 30000,
+	})
+
+	if r := tcb.UpdateStateOutbound(tcp); r != tcpconntrack.ResultAlive {
+		t.Fatalf("Bad result: got %v, want %v", r, tcpconntrack.ResultAlive)
+	}
+}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
new file mode 100644
index 000000000..ac34a932e
--- /dev/null
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -0,0 +1,77 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "udp_state",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "udp_packet_list.go",
+ ],
+ out = "udp_state.go",
+ imports = ["gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"],
+ package = "udp",
+)
+
+go_template_instance(
+ name = "udp_packet_list",
+ out = "udp_packet_list.go",
+ package = "udp",
+ prefix = "udpPacket",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Linker": "*udpPacket",
+ },
+)
+
+go_library(
+ name = "udp",
+ srcs = [
+ "endpoint.go",
+ "endpoint_state.go",
+ "protocol.go",
+ "udp_packet_list.go",
+ "udp_state.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp",
+ visibility = ["//visibility:public"],
+ deps = [
+ "//pkg/sleep",
+ "//pkg/state",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "udp_x_test",
+ size = "small",
+ srcs = ["udp_test.go"],
+ deps = [
+ ":udp",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/checker",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/link/channel",
+ "//pkg/tcpip/link/sniffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/waiter",
+ ],
+)
+
+filegroup(
+ name = "autogen",
+ srcs = [
+ "udp_packet_list.go",
+ ],
+ visibility = ["//:sandbox"],
+)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
new file mode 100644
index 000000000..80fa88c4c
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -0,0 +1,746 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package udp
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// udpPacket holds a single received datagram while it is queued on an
+// endpoint's receive list, along with the sender's address for recvfrom-style
+// reads.
+type udpPacket struct {
+	udpPacketEntry
+	senderAddress tcpip.FullAddress
+	data buffer.VectorisedView `state:".(buffer.VectorisedView)"`
+	// views is a fixed backing array handed to data (see vv.Clone in
+	// HandlePacket) so small packets avoid a separate slice allocation.
+	views [8]buffer.View `state:"nosave"`
+}
+
+// endpointState tracks the lifecycle of a UDP endpoint.
+type endpointState int
+
+const (
+	stateInitial endpointState = iota // created, not yet bound
+	stateBound // bound to a local address/port
+	stateConnected // bound and associated with a remote
+	stateClosed // closed; no further operations allowed
+)
+
+// endpoint represents a UDP endpoint. This struct serves as the interface
+// between users of the endpoint and the protocol implementation; it is legal to
+// have concurrent goroutines make calls into the endpoint, they are properly
+// synchronized.
+//
+// Lock ordering: methods that hold both locks acquire mu before rcvMu.
+type endpoint struct {
+	// The following fields are initialized at creation time and do not
+	// change throughout the lifetime of the endpoint.
+	stack *stack.Stack `state:"manual"`
+	netProto tcpip.NetworkProtocolNumber
+	waiterQueue *waiter.Queue
+
+	// The following fields are used to manage the receive queue, and are
+	// protected by rcvMu.
+	rcvMu sync.Mutex `state:"nosave"`
+	rcvReady bool
+	rcvList udpPacketList
+	// rcvBufSizeMax is the receive buffer limit; rcvBufSize is the number
+	// of bytes currently queued. New packets are dropped once the limit is
+	// reached (see HandlePacket).
+	rcvBufSizeMax int `state:".(int)"`
+	rcvBufSize int
+	rcvClosed bool
+
+	// The following fields are protected by the mu mutex.
+	mu sync.RWMutex `state:"nosave"`
+	sndBufSize int
+	id stack.TransportEndpointID
+	state endpointState
+	bindNICID tcpip.NICID
+	bindAddr tcpip.Address
+	regNICID tcpip.NICID
+	route stack.Route `state:"manual"`
+	dstPort uint16
+	v6only bool
+
+	// effectiveNetProtos contains the network protocols actually in use. In
+	// most cases it will only contain "netProto", but in cases like IPv6
+	// endpoints with v6only set to false, this could include multiple
+	// protocols (e.g., IPv6 and IPv4) or a single different protocol (e.g.,
+	// IPv4 when IPv6 endpoint is bound or connected to an IPv4 mapped
+	// address).
+	effectiveNetProtos []tcpip.NetworkProtocolNumber
+}
+
+// newEndpoint creates a UDP endpoint in the initial state with default
+// 32 KiB send and receive buffer limits.
+func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) *endpoint {
+	return &endpoint{
+		stack: stack,
+		netProto: netProto,
+		waiterQueue: waiterQueue,
+		rcvBufSizeMax: 32 * 1024,
+		sndBufSize: 32 * 1024,
+	}
+}
+
+// NewConnectedEndpoint creates a new endpoint in the connected state using the
+// provided route. The route is cloned, so the caller retains ownership of r.
+// On registration failure the partially constructed endpoint is closed and the
+// error returned.
+func NewConnectedEndpoint(stack *stack.Stack, r *stack.Route, id stack.TransportEndpointID, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	ep := newEndpoint(stack, r.NetProto, waiterQueue)
+
+	// Register new endpoint so that packets are routed to it.
+	if err := stack.RegisterTransportEndpoint(r.NICID(), []tcpip.NetworkProtocolNumber{r.NetProto}, ProtocolNumber, id, ep); err != nil {
+		ep.Close()
+		return nil, err
+	}
+
+	ep.id = id
+	ep.route = r.Clone()
+	ep.dstPort = id.RemotePort
+	ep.regNICID = r.NICID()
+
+	ep.state = stateConnected
+
+	return ep, nil
+}
+
+// Close puts the endpoint in a closed state and frees all resources
+// associated with it: it unregisters from the stack, drains the receive
+// queue, and releases the cached route.
+func (e *endpoint) Close() {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	switch e.state {
+	case stateBound, stateConnected:
+		e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+	}
+
+	// Close the receive list and drain it.
+	e.rcvMu.Lock()
+	e.rcvClosed = true
+	e.rcvBufSize = 0
+	for !e.rcvList.Empty() {
+		p := e.rcvList.Front()
+		e.rcvList.Remove(p)
+	}
+	e.rcvMu.Unlock()
+
+	e.route.Release()
+
+	// Update the state.
+	e.state = stateClosed
+}
+
+// Read reads data from the endpoint. This method does not block if
+// there is no data pending: it returns ErrWouldBlock when the queue is
+// empty, or ErrClosedForReceive once the receive side has been shut down.
+// If addr is non-nil it is filled with the sender's address.
+func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, *tcpip.Error) {
+	e.rcvMu.Lock()
+
+	if e.rcvList.Empty() {
+		err := tcpip.ErrWouldBlock
+		if e.rcvClosed {
+			err = tcpip.ErrClosedForReceive
+		}
+		e.rcvMu.Unlock()
+		return buffer.View{}, err
+	}
+
+	// Dequeue exactly one datagram and account for its size.
+	p := e.rcvList.Front()
+	e.rcvList.Remove(p)
+	e.rcvBufSize -= p.data.Size()
+
+	e.rcvMu.Unlock()
+
+	if addr != nil {
+		*addr = p.senderAddress
+	}
+
+	return p.data.ToView(), nil
+}
+
+// prepareForWrite prepares the endpoint for sending data. In particular, it
+// binds it if it's still in the initial state. To do so, it must first
+// reacquire the mutex in exclusive mode.
+//
+// Precondition: e.mu is held in shared (read) mode by the caller; it is held
+// in the same mode again when this function returns.
+//
+// Returns true for retry if preparation should be retried.
+func (e *endpoint) prepareForWrite(to *tcpip.FullAddress) (retry bool, err *tcpip.Error) {
+	switch e.state {
+	case stateInitial:
+	case stateConnected:
+		return false, nil
+
+	case stateBound:
+		if to == nil {
+			return false, tcpip.ErrDestinationRequired
+		}
+		return false, nil
+	default:
+		return false, tcpip.ErrInvalidEndpointState
+	}
+
+	// Deferred calls run LIFO: the write lock is released first, then the
+	// read lock is re-acquired, restoring the caller's locking mode.
+	e.mu.RUnlock()
+	defer e.mu.RLock()
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	// The state changed when we released the shared locked and re-acquired
+	// it in exclusive mode. Try again.
+	if e.state != stateInitial {
+		return true, nil
+	}
+
+	// The state is still 'initial', so try to bind the endpoint to an
+	// ephemeral port on the wildcard address.
+	if err := e.bindLocked(tcpip.FullAddress{}, nil); err != nil {
+		return false, err
+	}
+
+	return true, nil
+}
+
+// Write writes data to the endpoint's peer. This method does not block
+// if the data cannot be written.
+//
+// If opts.To is nil the endpoint must be connected and the cached route is
+// used; otherwise a route to opts.To is looked up per call. Returns the
+// number of bytes sent, or an error (including any transmit error from the
+// network layer).
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+	// MSG_MORE is unimplemented. (This also means that MSG_EOR is a no-op.)
+	if opts.More {
+		return 0, tcpip.ErrInvalidOptionValue
+	}
+
+	to := opts.To
+
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	// Prepare for write: auto-bind if still in the initial state.
+	// prepareForWrite may drop and re-acquire the lock, so loop until the
+	// state is stable.
+	for {
+		retry, err := e.prepareForWrite(to)
+		if err != nil {
+			return 0, err
+		}
+
+		if !retry {
+			break
+		}
+	}
+
+	var route *stack.Route
+	var dstPort uint16
+	if to == nil {
+		route = &e.route
+		dstPort = e.dstPort
+
+		if route.IsResolutionRequired() {
+			// Promote lock to exclusive if using a shared route, given that it may need to
+			// change in Route.Resolve() call below.
+			e.mu.RUnlock()
+			defer e.mu.RLock()
+
+			e.mu.Lock()
+			defer e.mu.Unlock()
+
+			// Recheck state after lock was re-acquired.
+			if e.state != stateConnected {
+				return 0, tcpip.ErrInvalidEndpointState
+			}
+		}
+	} else {
+		// Reject destination address if it goes through a different
+		// NIC than the endpoint was bound to.
+		nicid := to.NIC
+		if e.bindNICID != 0 {
+			if nicid != 0 && nicid != e.bindNICID {
+				return 0, tcpip.ErrNoRoute
+			}
+
+			nicid = e.bindNICID
+		}
+
+		// Copy the destination so checkV4Mapped's in-place address
+		// rewrite never mutates the caller's struct.
+		toCopy := *to
+		to = &toCopy
+		netProto, err := e.checkV4Mapped(to, true)
+		if err != nil {
+			return 0, err
+		}
+
+		// Find a route to the endpoint.
+		r, err := e.stack.FindRoute(nicid, e.bindAddr, to.Addr, netProto)
+		if err != nil {
+			return 0, err
+		}
+		defer r.Release()
+
+		route = &r
+		dstPort = to.Port
+	}
+
+	if route.IsResolutionRequired() {
+		waker := &sleep.Waker{}
+		if err := route.Resolve(waker); err != nil {
+			if err == tcpip.ErrWouldBlock {
+				// Link address needs to be resolved. Resolution was triggered in the
+				// background. Better luck next time.
+				//
+				// TODO: queue up the request and send after link address
+				// is resolved.
+				route.RemoveWaker(waker)
+				return 0, tcpip.ErrNoLinkAddress
+			}
+			return 0, err
+		}
+	}
+
+	v, err := p.Get(p.Size())
+	if err != nil {
+		return 0, err
+	}
+
+	// Propagate transmit failures to the caller instead of silently
+	// discarding them (sendUDP's error was previously dropped).
+	if err := sendUDP(route, v, e.id.LocalPort, dstPort); err != nil {
+		return 0, err
+	}
+	return uintptr(len(v)), nil
+}
+
+// Peek only returns data from a single datagram, so do nothing here.
+// It always reports zero bytes and no error.
+func (e *endpoint) Peek([][]byte) (uintptr, *tcpip.Error) {
+	return 0, nil
+}
+
+// SetSockOpt sets a socket option. Currently only V6OnlyOption is
+// recognized; all other options are silently accepted as no-ops.
+func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+	// TODO: Actually implement this.
+	switch v := opt.(type) {
+	case tcpip.V6OnlyOption:
+		// We only recognize this option on v6 endpoints.
+		if e.netProto != header.IPv6ProtocolNumber {
+			return tcpip.ErrInvalidEndpointState
+		}
+
+		e.mu.Lock()
+		defer e.mu.Unlock()
+
+		// We only allow this to be set when we're in the initial state.
+		if e.state != stateInitial {
+			return tcpip.ErrInvalidEndpointState
+		}
+
+		e.v6only = v != 0
+	}
+	return nil
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt. Unrecognized options
+// yield ErrUnknownProtocolOption.
+func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+	switch o := opt.(type) {
+	case tcpip.ErrorOption:
+		// UDP has no pending asynchronous errors to report.
+		return nil
+
+	case *tcpip.SendBufferSizeOption:
+		e.mu.Lock()
+		*o = tcpip.SendBufferSizeOption(e.sndBufSize)
+		e.mu.Unlock()
+		return nil
+
+	case *tcpip.ReceiveBufferSizeOption:
+		e.rcvMu.Lock()
+		*o = tcpip.ReceiveBufferSizeOption(e.rcvBufSizeMax)
+		e.rcvMu.Unlock()
+		return nil
+
+	case *tcpip.V6OnlyOption:
+		// We only recognize this option on v6 endpoints.
+		if e.netProto != header.IPv6ProtocolNumber {
+			return tcpip.ErrUnknownProtocolOption
+		}
+
+		e.mu.Lock()
+		v := e.v6only
+		e.mu.Unlock()
+
+		*o = 0
+		if v {
+			*o = 1
+		}
+		return nil
+
+	case *tcpip.ReceiveQueueSizeOption:
+		// Reports the size of the frontmost queued datagram only, not
+		// the total number of buffered bytes.
+		e.rcvMu.Lock()
+		if e.rcvList.Empty() {
+			*o = 0
+		} else {
+			p := e.rcvList.Front()
+			*o = tcpip.ReceiveQueueSizeOption(p.data.Size())
+		}
+		e.rcvMu.Unlock()
+		return nil
+	}
+
+	return tcpip.ErrUnknownProtocolOption
+}
+
+// sendUDP sends a UDP segment via the provided network endpoint and under the
+// provided identity. It builds the UDP header (computing the checksum unless
+// the link supports checksum offload) and hands the packet to the route.
+func sendUDP(r *stack.Route, data buffer.View, localPort, remotePort uint16) *tcpip.Error {
+	// Allocate a buffer for the UDP header, leaving room for lower-layer
+	// headers to be prepended.
+	hdr := buffer.NewPrependable(header.UDPMinimumSize + int(r.MaxHeaderLength()))
+
+	// Initialize the header.
+	udp := header.UDP(hdr.Prepend(header.UDPMinimumSize))
+
+	length := uint16(hdr.UsedLength()) + uint16(len(data))
+	udp.Encode(&header.UDPFields{
+		SrcPort: localPort,
+		DstPort: remotePort,
+		Length: length,
+	})
+
+	// Only calculate the checksum if offloading isn't supported.
+	if r.Capabilities()&stack.CapabilityChecksumOffload == 0 {
+		xsum := r.PseudoHeaderChecksum(ProtocolNumber)
+		if data != nil {
+			xsum = header.Checksum(data, xsum)
+		}
+
+		udp.SetChecksum(^udp.CalculateChecksum(xsum, length))
+	}
+
+	return r.WritePacket(&hdr, data, ProtocolNumber)
+}
+
+// checkV4Mapped determines the effective network protocol for addr, rewriting
+// a v4-mapped IPv6 address to its plain IPv4 form in place (and mapping the
+// v4 wildcard to the empty address). Unless allowMismatch is set, it rejects
+// addresses whose length differs from the endpoint's bound local address.
+func (e *endpoint) checkV4Mapped(addr *tcpip.FullAddress, allowMismatch bool) (tcpip.NetworkProtocolNumber, *tcpip.Error) {
+	netProto := e.netProto
+	if header.IsV4MappedAddress(addr.Addr) {
+		// Fail if using a v4 mapped address on a v6only endpoint.
+		if e.v6only {
+			return 0, tcpip.ErrNoRoute
+		}
+
+		netProto = header.IPv4ProtocolNumber
+		addr.Addr = addr.Addr[header.IPv6AddressSize-header.IPv4AddressSize:]
+		if addr.Addr == "\x00\x00\x00\x00" {
+			addr.Addr = ""
+		}
+	}
+
+	// Fail if we're bound to an address length different from the one we're
+	// checking.
+	if l := len(e.id.LocalAddress); !allowMismatch && l != 0 && l != len(addr.Addr) {
+		return 0, tcpip.ErrInvalidEndpointState
+	}
+
+	return netProto, nil
+}
+
+// Connect connects the endpoint to its peer. Specifying a NIC is optional.
+// A bound or already-connected endpoint keeps its local port and (if set)
+// its bound NIC; re-connecting replaces the previous registration.
+func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+	if addr.Port == 0 {
+		// We don't support connecting to port zero.
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	nicid := addr.NIC
+	localPort := uint16(0)
+	switch e.state {
+	case stateInitial:
+	case stateBound, stateConnected:
+		localPort = e.id.LocalPort
+		if e.bindNICID == 0 {
+			break
+		}
+
+		// A NIC was fixed at bind time; the caller may not redirect
+		// the connection through a different one.
+		if nicid != 0 && nicid != e.bindNICID {
+			return tcpip.ErrInvalidEndpointState
+		}
+
+		nicid = e.bindNICID
+	default:
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	netProto, err := e.checkV4Mapped(&addr, false)
+	if err != nil {
+		return err
+	}
+
+	// Find a route to the desired destination.
+	r, err := e.stack.FindRoute(nicid, e.bindAddr, addr.Addr, netProto)
+	if err != nil {
+		return err
+	}
+	defer r.Release()
+
+	id := stack.TransportEndpointID{
+		LocalAddress: r.LocalAddress,
+		LocalPort: localPort,
+		RemotePort: addr.Port,
+		RemoteAddress: r.RemoteAddress,
+	}
+
+	// Even if we're connected, this endpoint can still be used to send
+	// packets on a different network protocol, so we register both even if
+	// v6only is set to false and this is an ipv6 endpoint.
+	netProtos := []tcpip.NetworkProtocolNumber{netProto}
+	if e.netProto == header.IPv6ProtocolNumber && !e.v6only {
+		netProtos = []tcpip.NetworkProtocolNumber{
+			header.IPv4ProtocolNumber,
+			header.IPv6ProtocolNumber,
+		}
+	}
+
+	id, err = e.registerWithStack(nicid, netProtos, id)
+	if err != nil {
+		return err
+	}
+
+	// Remove the old registration.
+	if e.id.LocalPort != 0 {
+		e.stack.UnregisterTransportEndpoint(e.regNICID, e.effectiveNetProtos, ProtocolNumber, e.id)
+	}
+
+	e.id = id
+	e.route = r.Clone()
+	e.dstPort = addr.Port
+	e.regNICID = nicid
+	e.effectiveNetProtos = netProtos
+
+	e.state = stateConnected
+
+	// Mark the endpoint ready to accept incoming packets.
+	e.rcvMu.Lock()
+	e.rcvReady = true
+	e.rcvMu.Unlock()
+
+	return nil
+}
+
+// ConnectEndpoint is not supported; UDP endpoints cannot be connected to
+// other in-process endpoints directly.
+func (*endpoint) ConnectEndpoint(tcpip.Endpoint) *tcpip.Error {
+	return tcpip.ErrInvalidEndpointState
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection
+// to its peer. Only ShutdownRead has an effect here (it closes the receive
+// queue and wakes readers); the write side of UDP needs no teardown.
+func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	if e.state != stateConnected {
+		return tcpip.ErrNotConnected
+	}
+
+	if flags&tcpip.ShutdownRead != 0 {
+		e.rcvMu.Lock()
+		wasClosed := e.rcvClosed
+		e.rcvClosed = true
+		e.rcvMu.Unlock()
+
+		// Only notify on the first transition to closed.
+		if !wasClosed {
+			e.waiterQueue.Notify(waiter.EventIn)
+		}
+	}
+
+	return nil
+}
+
+// Listen is not supported by UDP, it just fails.
+func (*endpoint) Listen(int) *tcpip.Error {
+	return tcpip.ErrNotSupported
+}
+
+// Accept is not supported by UDP, it just fails.
+func (*endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+	return nil, nil, tcpip.ErrNotSupported
+}
+
+// registerWithStack registers the endpoint with the stack under id. If
+// id.LocalPort is zero, an ephemeral port is picked; the returned id carries
+// the port actually used.
+func (e *endpoint) registerWithStack(nicid tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, id stack.TransportEndpointID) (stack.TransportEndpointID, *tcpip.Error) {
+	if id.LocalPort != 0 {
+		// The endpoint already has a local port, just attempt to
+		// register it.
+		err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
+		return id, err
+	}
+
+	// We need to find a port for the endpoint. The callback is retried by
+	// the stack while it returns (false, nil) on port collisions.
+	_, err := e.stack.PickEphemeralPort(func(p uint16) (bool, *tcpip.Error) {
+		id.LocalPort = p
+		err := e.stack.RegisterTransportEndpoint(nicid, netProtos, ProtocolNumber, id, e)
+		switch err {
+		case nil:
+			return true, nil
+		case tcpip.ErrPortInUse:
+			return false, nil
+		default:
+			return false, err
+		}
+	})
+
+	return id, err
+}
+
+// bindLocked binds the endpoint to addr, registering it with the stack and
+// moving it to the bound state. If commit is non-nil it is invoked after a
+// successful registration; a commit failure rolls the registration back.
+//
+// Precondition: e.mu must be held exclusively.
+func (e *endpoint) bindLocked(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+	// Don't allow binding once endpoint is not in the initial state
+	// anymore.
+	if e.state != stateInitial {
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	netProto, err := e.checkV4Mapped(&addr, false)
+	if err != nil {
+		return err
+	}
+
+	// Expand netProtos to include v4 and v6 if the caller is binding to a
+	// wildcard (empty) address, and this is an IPv6 endpoint with v6only
+	// set to false.
+	netProtos := []tcpip.NetworkProtocolNumber{netProto}
+	if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" {
+		netProtos = []tcpip.NetworkProtocolNumber{
+			header.IPv6ProtocolNumber,
+			header.IPv4ProtocolNumber,
+		}
+	}
+
+	if len(addr.Addr) != 0 {
+		// A local address was specified, verify that it's valid.
+		if e.stack.CheckLocalAddress(addr.NIC, netProto, addr.Addr) == 0 {
+			return tcpip.ErrBadLocalAddress
+		}
+	}
+
+	id := stack.TransportEndpointID{
+		LocalPort: addr.Port,
+		LocalAddress: addr.Addr,
+	}
+	id, err = e.registerWithStack(addr.NIC, netProtos, id)
+	if err != nil {
+		return err
+	}
+	if commit != nil {
+		if err := commit(); err != nil {
+			// Unregister, the commit failed.
+			e.stack.UnregisterTransportEndpoint(addr.NIC, netProtos, ProtocolNumber, id)
+			return err
+		}
+	}
+
+	e.id = id
+	e.regNICID = addr.NIC
+	e.effectiveNetProtos = netProtos
+
+	// Mark endpoint as bound.
+	e.state = stateBound
+
+	// Allow incoming packets to be queued from now on.
+	e.rcvMu.Lock()
+	e.rcvReady = true
+	e.rcvMu.Unlock()
+
+	return nil
+}
+
+// Bind binds the endpoint to a specific local address and port.
+// Specifying a NIC is optional. On success the bind NIC and address are
+// recorded for later route selection in Connect/Write.
+func (e *endpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+	e.mu.Lock()
+	defer e.mu.Unlock()
+
+	err := e.bindLocked(addr, commit)
+	if err != nil {
+		return err
+	}
+
+	e.bindNICID = addr.NIC
+	e.bindAddr = addr.Addr
+
+	return nil
+}
+
+// GetLocalAddress returns the address to which the endpoint is bound,
+// including any ephemeral port assigned during auto-bind.
+func (e *endpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	return tcpip.FullAddress{
+		NIC: e.regNICID,
+		Addr: e.id.LocalAddress,
+		Port: e.id.LocalPort,
+	}, nil
+}
+
+// GetRemoteAddress returns the address to which the endpoint is connected,
+// or ErrNotConnected if no Connect has succeeded.
+func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+	e.mu.RLock()
+	defer e.mu.RUnlock()
+
+	if e.state != stateConnected {
+		return tcpip.FullAddress{}, tcpip.ErrNotConnected
+	}
+
+	return tcpip.FullAddress{
+		NIC: e.regNICID,
+		Addr: e.id.RemoteAddress,
+		Port: e.id.RemotePort,
+	}, nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+	// The endpoint is always writable.
+	result := waiter.EventOut & mask
+
+	// Determine if the endpoint is readable if requested. A closed
+	// receive side also reports readable so blocked readers wake up.
+	if (mask & waiter.EventIn) != 0 {
+		e.rcvMu.Lock()
+		if !e.rcvList.Empty() || e.rcvClosed {
+			result |= waiter.EventIn
+		}
+		e.rcvMu.Unlock()
+	}
+
+	return result
+}
+
+// HandlePacket is called by the stack when new packets arrive to this transport
+// endpoint. It validates the UDP length field, then queues the payload on the
+// receive list (dropping it if the endpoint isn't ready, is closed, or the
+// receive buffer is full).
+//
+// NOTE(review): assumes the stack has already verified vv holds at least
+// UDPMinimumSize bytes (per protocol.MinimumPacketSize) — confirm.
+func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv *buffer.VectorisedView) {
+	// Get the header then trim it from the view.
+	hdr := header.UDP(vv.First())
+	if int(hdr.Length()) > vv.Size() {
+		// Malformed packet.
+		return
+	}
+
+	vv.TrimFront(header.UDPMinimumSize)
+
+	e.rcvMu.Lock()
+
+	// Drop the packet if our buffer is currently full.
+	if !e.rcvReady || e.rcvClosed || e.rcvBufSize >= e.rcvBufSizeMax {
+		e.rcvMu.Unlock()
+		return
+	}
+
+	wasEmpty := e.rcvBufSize == 0
+
+	// Push new packet into receive list and increment the buffer size.
+	pkt := &udpPacket{
+		senderAddress: tcpip.FullAddress{
+			NIC: r.NICID(),
+			Addr: id.RemoteAddress,
+			Port: hdr.SourcePort(),
+		},
+	}
+	// Clone into pkt.views to avoid allocating a fresh view slice for
+	// small packets.
+	pkt.data = vv.Clone(pkt.views[:])
+	e.rcvList.PushBack(pkt)
+	e.rcvBufSize += vv.Size()
+
+	e.rcvMu.Unlock()
+
+	// Notify any waiters that there's data to be read now.
+	if wasEmpty {
+		e.waiterQueue.Notify(waiter.EventIn)
+	}
+}
+
+// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket.
+// Control packets (e.g. ICMP errors) are currently ignored by UDP endpoints.
+func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv *buffer.VectorisedView) {
+}
diff --git a/pkg/tcpip/transport/udp/endpoint_state.go b/pkg/tcpip/transport/udp/endpoint_state.go
new file mode 100644
index 000000000..41b98424a
--- /dev/null
+++ b/pkg/tcpip/transport/udp/endpoint_state.go
@@ -0,0 +1,91 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+// saveData saves udpPacket.data field.
+func (u *udpPacket) saveData() buffer.VectorisedView {
+	// We cannot save u.data directly as u.data.views may alias to u.views,
+	// which is not allowed by state framework (in-struct pointer).
+	return u.data.Clone(nil)
+}
+
+// loadData loads udpPacket.data field.
+func (u *udpPacket) loadData(data buffer.VectorisedView) {
+	// NOTE: We cannot do the u.data = data.Clone(u.views[:]) optimization
+	// here because data.views is not guaranteed to be loaded by now. Plus,
+	// data.views will be allocated anyway so there really is little point
+	// of utilizing u.views for data.views.
+	u.data = data
+}
+
+// beforeSave is invoked by stateify.
+func (e *endpoint) beforeSave() {
+	// Stop incoming packets from being handled (and mutate endpoint state).
+	// The lock will be released after saveRcvBufSizeMax(), which would have
+	// saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming
+	// packets.
+	e.rcvMu.Lock()
+}
+
+// saveRcvBufSizeMax is invoked by stateify. It pairs with beforeSave: the
+// lock taken there is released here after zeroing rcvBufSizeMax, so
+// HandlePacket keeps dropping packets for the remainder of the save.
+func (e *endpoint) saveRcvBufSizeMax() int {
+	max := e.rcvBufSizeMax
+	// Make sure no new packets will be handled regardless of the lock.
+	e.rcvBufSizeMax = 0
+	// Release the lock acquired in beforeSave() so regular endpoint closing
+	// logic can proceed after save.
+	e.rcvMu.Unlock()
+	return max
+}
+
+// loadRcvBufSizeMax is invoked by stateify; it restores the receive buffer
+// limit zeroed by saveRcvBufSizeMax.
+func (e *endpoint) loadRcvBufSizeMax(max int) {
+	e.rcvBufSizeMax = max
+}
+
+// afterLoad is invoked by stateify. It re-attaches the endpoint to the
+// ambient stack and re-establishes routes and stack registrations for bound
+// or connected endpoints; an unrecoverable mismatch panics, aborting restore.
+func (e *endpoint) afterLoad() {
+	e.stack = stack.StackFromEnv
+
+	if e.state != stateBound && e.state != stateConnected {
+		return
+	}
+
+	netProto := e.effectiveNetProtos[0]
+	// Connect() and bindLocked() both assert
+	//
+	//     netProto == header.IPv6ProtocolNumber
+	//
+	// before creating a multi-entry effectiveNetProtos.
+	if len(e.effectiveNetProtos) > 1 {
+		netProto = header.IPv6ProtocolNumber
+	}
+
+	var err *tcpip.Error
+	if e.state == stateConnected {
+		e.route, err = e.stack.FindRoute(e.regNICID, e.bindAddr, e.id.RemoteAddress, netProto)
+		if err != nil {
+			panic(*err)
+		}
+
+		e.id.LocalAddress = e.route.LocalAddress
+	} else if len(e.id.LocalAddress) != 0 { // stateBound
+		// Verify the bound address still exists on the new stack.
+		if e.stack.CheckLocalAddress(e.regNICID, netProto, e.id.LocalAddress) == 0 {
+			panic(tcpip.ErrBadLocalAddress)
+		}
+	}
+
+	e.id, err = e.registerWithStack(e.regNICID, e.effectiveNetProtos, e.id)
+	if err != nil {
+		panic(*err)
+	}
+}
diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go
new file mode 100644
index 000000000..fa30e7201
--- /dev/null
+++ b/pkg/tcpip/transport/udp/protocol.go
@@ -0,0 +1,73 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package udp contains the implementation of the UDP transport protocol. To use
+// it in the networking stack, this package must be added to the project, and
+// activated on the stack by passing udp.ProtocolName (or "udp") as one of the
+// transport protocols when calling stack.New(). Then endpoints can be created
+// by passing udp.ProtocolNumber as the transport protocol number when calling
+// Stack.NewEndpoint().
+package udp
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+	// ProtocolName is the string representation of the udp protocol name.
+	ProtocolName = "udp"
+
+	// ProtocolNumber is the udp protocol number (IANA protocol 17).
+	ProtocolNumber = header.UDPProtocolNumber
+)
+
+// protocol implements stack.TransportProtocol for UDP; it is stateless.
+type protocol struct{}
+
+// Number returns the udp protocol number.
+func (*protocol) Number() tcpip.TransportProtocolNumber {
+	return ProtocolNumber
+}
+
+// NewEndpoint creates a new udp endpoint in the initial (unbound) state.
+func (*protocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	return newEndpoint(stack, netProto, waiterQueue), nil
+}
+
+// MinimumPacketSize returns the minimum valid udp packet size (the 8-byte
+// UDP header).
+func (*protocol) MinimumPacketSize() int {
+	return header.UDPMinimumSize
+}
+
+// ParsePorts returns the source and destination ports stored in the given udp
+// packet. The view must be at least MinimumPacketSize() bytes.
+func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) {
+	h := header.UDP(v)
+	return h.SourcePort(), h.DestinationPort(), nil
+}
+
+// HandleUnknownDestinationPacket handles packets targeted at this protocol but
+// that don't match any existing endpoint. The packet is dropped without a
+// reply; NOTE(review): returning true appears to tell the stack the packet
+// was consumed — no ICMP port-unreachable is generated here.
+func (p *protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
+	return true
+}
+
+// SetOption implements TransportProtocol.SetOption. UDP has no
+// protocol-level options.
+func (p *protocol) SetOption(option interface{}) *tcpip.Error {
+	return tcpip.ErrUnknownProtocolOption
+}
+
+// Option implements TransportProtocol.Option. UDP has no protocol-level
+// options.
+func (p *protocol) Option(option interface{}) *tcpip.Error {
+	return tcpip.ErrUnknownProtocolOption
+}
+
+// init registers the UDP protocol factory so stack.New() can activate it by
+// name ("udp").
+func init() {
+	stack.RegisterTransportProtocolFactory(ProtocolName, func() stack.TransportProtocol {
+		return &protocol{}
+	})
+}
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
new file mode 100644
index 000000000..65c567952
--- /dev/null
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -0,0 +1,625 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package udp_test
+
+import (
+ "bytes"
+ "math/rand"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/checker"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/sniffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// Test addresses are raw big-endian byte strings: stack* identifies the
+// local stack side, test* the simulated remote peer.
+const (
+	stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01"
+	testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02"
+	stackV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + stackAddr
+	testV4MappedAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff" + testAddr
+	V4MappedWildcardAddr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+
+	stackAddr = "\x0a\x00\x00\x01"
+	stackPort = 1234
+	testAddr = "\x0a\x00\x00\x02"
+	testPort = 4096
+
+	// defaultMTU is the MTU, in bytes, used throughout the tests, except
+	// where another value is explicitly used. It is chosen to match the MTU
+	// of loopback interfaces on linux systems.
+	defaultMTU = 65536
+)
+
+// testContext bundles the stack, link endpoint, and endpoint under test for
+// a single test case.
+type testContext struct {
+	t *testing.T
+	linkEP *channel.Endpoint
+	s *stack.Stack
+
+	ep tcpip.Endpoint
+	wq waiter.Queue
+}
+
+// headers carries the UDP ports used when crafting injected packets.
+type headers struct {
+	srcPort uint16
+	dstPort uint16
+}
+
+// newDualTestContext builds a dual-stack (IPv4+IPv6) test stack with a single
+// channel-based NIC (id 1), both stack addresses assigned, and default routes
+// for each protocol.
+func newDualTestContext(t *testing.T, mtu uint32) *testContext {
+	s := stack.New([]string{ipv4.ProtocolName, ipv6.ProtocolName}, []string{udp.ProtocolName})
+
+	id, linkEP := channel.New(256, mtu, "")
+	if testing.Verbose() {
+		// Wrap in a sniffer so -v runs log every packet.
+		id = sniffer.New(id)
+	}
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, ipv4.ProtocolNumber, stackAddr); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, ipv6.ProtocolNumber, stackV6Addr); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{
+		{
+			Destination: "\x00\x00\x00\x00",
+			Mask: "\x00\x00\x00\x00",
+			Gateway: "",
+			NIC: 1,
+		},
+		{
+			Destination: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+			Mask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+			Gateway: "",
+			NIC: 1,
+		},
+	})
+
+	return &testContext{
+		t: t,
+		s: s,
+		linkEP: linkEP,
+	}
+}
+
+// cleanup releases the endpoint under test, if one was created.
+func (c *testContext) cleanup() {
+	if c.ep != nil {
+		c.ep.Close()
+	}
+}
+
+// createV6Endpoint creates an IPv6 UDP endpoint on the test stack and sets
+// its V6Only option per v4only (a dual-stack endpoint when false).
+func (c *testContext) createV6Endpoint(v4only bool) {
+	var err *tcpip.Error
+	c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	var v tcpip.V6OnlyOption
+	if v4only {
+		v = 1
+	}
+	// Fix duplicated word in the failure message ("failed failed").
+	if err := c.ep.SetSockOpt(v); err != nil {
+		c.t.Fatalf("SetSockOpt failed: %v", err)
+	}
+}
+
+// getV6Packet waits up to two seconds for the stack to emit an IPv6 packet
+// on the link endpoint, sanity-checks its addresses, and returns the raw
+// header+payload bytes. It fails the test on timeout or wrong protocol.
+func (c *testContext) getV6Packet() []byte {
+	select {
+	case p := <-c.linkEP.C:
+		if p.Proto != ipv6.ProtocolNumber {
+			c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv6.ProtocolNumber)
+		}
+		b := make([]byte, len(p.Header)+len(p.Payload))
+		copy(b, p.Header)
+		copy(b[len(p.Header):], p.Payload)
+
+		checker.IPv6(c.t, b, checker.SrcAddr(stackV6Addr), checker.DstAddr(testV6Addr))
+		return b
+
+	case <-time.After(2 * time.Second):
+		c.t.Fatalf("Packet wasn't written out")
+	}
+
+	// Unreachable: Fatalf does not return.
+	return nil
+}
+
+// getPacket is the IPv4 counterpart of getV6Packet: it waits for an IPv4
+// packet, checks its addresses, and returns the raw bytes.
+func (c *testContext) getPacket() []byte {
+	select {
+	case p := <-c.linkEP.C:
+		if p.Proto != ipv4.ProtocolNumber {
+			c.t.Fatalf("Bad network protocol: got %v, wanted %v", p.Proto, ipv4.ProtocolNumber)
+		}
+		b := make([]byte, len(p.Header)+len(p.Payload))
+		copy(b, p.Header)
+		copy(b[len(p.Header):], p.Payload)
+
+		checker.IPv4(c.t, b, checker.SrcAddr(stackAddr), checker.DstAddr(testAddr))
+		return b
+
+	case <-time.After(2 * time.Second):
+		c.t.Fatalf("Packet wasn't written out")
+	}
+
+	// Unreachable: Fatalf does not return.
+	return nil
+}
+
+// sendV6Packet crafts a full IPv6+UDP packet carrying payload with the given
+// ports (remote -> stack) and injects it into the link endpoint as if it
+// arrived from the network.
+func (c *testContext) sendV6Packet(payload []byte, h *headers) {
+	// Allocate a buffer for data and headers.
+	buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload))
+	copy(buf[len(buf)-len(payload):], payload)
+
+	// Initialize the IP header.
+	ip := header.IPv6(buf)
+	ip.Encode(&header.IPv6Fields{
+		PayloadLength: uint16(header.UDPMinimumSize + len(payload)),
+		NextHeader: uint8(udp.ProtocolNumber),
+		HopLimit: 65,
+		SrcAddr: testV6Addr,
+		DstAddr: stackV6Addr,
+	})
+
+	// Initialize the UDP header.
+	u := header.UDP(buf[header.IPv6MinimumSize:])
+	u.Encode(&header.UDPFields{
+		SrcPort: h.srcPort,
+		DstPort: h.dstPort,
+		Length: uint16(header.UDPMinimumSize + len(payload)),
+	})
+
+	// Calculate the UDP pseudo-header checksum (addresses + protocol; the
+	// length contribution is folded in by CalculateChecksum below).
+	xsum := header.Checksum([]byte(testV6Addr), 0)
+	xsum = header.Checksum([]byte(stackV6Addr), xsum)
+	xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum)
+
+	// Calculate the UDP checksum and set it.
+	length := uint16(header.UDPMinimumSize + len(payload))
+	xsum = header.Checksum(payload, xsum)
+	u.SetChecksum(^u.CalculateChecksum(xsum, length))
+
+	// Inject packet.
+	var views [1]buffer.View
+	vv := buf.ToVectorisedView(views)
+	c.linkEP.Inject(ipv6.ProtocolNumber, &vv)
+}
+
+// sendPacket is the IPv4 counterpart of sendV6Packet: it crafts and injects
+// an IPv4+UDP packet carrying payload with the given ports.
+func (c *testContext) sendPacket(payload []byte, h *headers) {
+	// Allocate a buffer for data and headers.
+	buf := buffer.NewView(header.UDPMinimumSize + header.IPv4MinimumSize + len(payload))
+	copy(buf[len(buf)-len(payload):], payload)
+
+	// Initialize the IP header.
+	ip := header.IPv4(buf)
+	ip.Encode(&header.IPv4Fields{
+		IHL: header.IPv4MinimumSize,
+		TotalLength: uint16(len(buf)),
+		TTL: 65,
+		Protocol: uint8(udp.ProtocolNumber),
+		SrcAddr: testAddr,
+		DstAddr: stackAddr,
+	})
+	ip.SetChecksum(^ip.CalculateChecksum())
+
+	// Initialize the UDP header.
+	u := header.UDP(buf[header.IPv4MinimumSize:])
+	u.Encode(&header.UDPFields{
+		SrcPort: h.srcPort,
+		DstPort: h.dstPort,
+		Length: uint16(header.UDPMinimumSize + len(payload)),
+	})
+
+	// Calculate the UDP pseudo-header checksum (addresses + protocol; the
+	// length contribution is folded in by CalculateChecksum below).
+	xsum := header.Checksum([]byte(testAddr), 0)
+	xsum = header.Checksum([]byte(stackAddr), xsum)
+	xsum = header.Checksum([]byte{0, uint8(udp.ProtocolNumber)}, xsum)
+
+	// Calculate the UDP checksum and set it.
+	length := uint16(header.UDPMinimumSize + len(payload))
+	xsum = header.Checksum(payload, xsum)
+	u.SetChecksum(^u.CalculateChecksum(xsum, length))
+
+	// Inject packet.
+	var views [1]buffer.View
+	vv := buf.ToVectorisedView(views)
+	c.linkEP.Inject(ipv4.ProtocolNumber, &vv)
+}
+
+// newPayload returns a randomly sized (30 to 129 bytes) slice filled with
+// random byte values, for use as a UDP test payload.
+func newPayload() []byte {
+	payload := make([]byte, 30+rand.Intn(100))
+	for i := 0; i < len(payload); i++ {
+		payload[i] = byte(rand.Intn(256))
+	}
+	return payload
+}
+
+// testV4Read injects an IPv4 UDP packet and verifies that the endpoint in c
+// receives it: the blocking-read path is exercised (ErrWouldBlock followed by
+// a readable notification), and both the peer address and payload are checked.
+func testV4Read(c *testContext) {
+	// Send a packet.
+	payload := newPayload()
+	c.sendPacket(payload, &headers{
+		srcPort: testPort,
+		dstPort: stackPort,
+	})
+
+	// Try to receive the data.
+	we, ch := waiter.NewChannelEntry(nil)
+	c.wq.EventRegister(&we, waiter.EventIn)
+	defer c.wq.EventUnregister(&we)
+
+	var addr tcpip.FullAddress
+	v, err := c.ep.Read(&addr)
+	if err == tcpip.ErrWouldBlock {
+		// Wait for data to become available.
+		select {
+		case <-ch:
+			v, err = c.ep.Read(&addr)
+			if err != nil {
+				c.t.Fatalf("Read failed: %v", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			c.t.Fatalf("Timed out waiting for data")
+		}
+	}
+
+	// Check the peer address.
+	if addr.Addr != testAddr {
+		c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testAddr)
+	}
+
+	// Check the payload.
+	if !bytes.Equal(payload, v) {
+		c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
+	}
+}
+
+// TestV4ReadOnV6 checks that a dual-stack v6 endpoint bound to the wildcard
+// address accepts IPv4 datagrams.
+func TestV4ReadOnV6(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	// Bind to the wildcard address so v4 traffic is accepted as well.
+	err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil)
+	if err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test acceptance.
+	testV4Read(c)
+}
+
+// TestV4ReadOnBoundToV4MappedWildcard checks that a dual-stack v6 endpoint
+// bound to the v4-mapped wildcard address (::ffff:0.0.0.0) accepts IPv4
+// datagrams.
+func TestV4ReadOnBoundToV4MappedWildcard(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	// Bind to v4 mapped wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Addr: V4MappedWildcardAddr, Port: stackPort}, nil); err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test acceptance.
+	testV4Read(c)
+}
+
+// TestV4ReadOnBoundToV4Mapped checks that a dual-stack v6 endpoint bound to
+// a specific v4-mapped local address accepts IPv4 datagrams addressed to it.
+func TestV4ReadOnBoundToV4Mapped(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	// Bind to local address.
+	if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test acceptance.
+	testV4Read(c)
+}
+
+// TestV6ReadOnV6 checks that a v6 endpoint bound to the wildcard address
+// receives IPv6 datagrams, reporting the correct v6 peer address and payload.
+func TestV6ReadOnV6(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Send a packet.
+	payload := newPayload()
+	c.sendV6Packet(payload, &headers{
+		srcPort: testPort,
+		dstPort: stackPort,
+	})
+
+	// Try to receive the data.
+	we, ch := waiter.NewChannelEntry(nil)
+	c.wq.EventRegister(&we, waiter.EventIn)
+	defer c.wq.EventUnregister(&we)
+
+	var addr tcpip.FullAddress
+	v, err := c.ep.Read(&addr)
+	if err == tcpip.ErrWouldBlock {
+		// Wait for data to become available.
+		select {
+		case <-ch:
+			v, err = c.ep.Read(&addr)
+			if err != nil {
+				c.t.Fatalf("Read failed: %v", err)
+			}
+
+		case <-time.After(1 * time.Second):
+			c.t.Fatalf("Timed out waiting for data")
+		}
+	}
+
+	// Check the peer address. The "want" in the diagnostic must be the v6
+	// address actually compared against (was incorrectly testAddr).
+	if addr.Addr != testV6Addr {
+		c.t.Fatalf("Unexpected remote address: got %v, want %v", addr.Addr, testV6Addr)
+	}
+
+	// Check the payload.
+	if !bytes.Equal(payload, v) {
+		c.t.Fatalf("Bad payload: got %x, want %x", v, payload)
+	}
+}
+
+// TestV4ReadOnV4 checks that a plain (non-dual-stack) v4 endpoint bound to
+// the wildcard address receives IPv4 datagrams.
+func TestV4ReadOnV4(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	// Create v4 UDP endpoint.
+	var err *tcpip.Error
+	c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &c.wq)
+	if err != nil {
+		c.t.Fatalf("NewEndpoint failed: %v", err)
+	}
+
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Test acceptance.
+	testV4Read(c)
+}
+
+// testDualWrite writes through the endpoint in c first to a v4-mapped
+// destination and then to a v6 destination, verifying both packets arrive
+// with the expected destination port and payload, and that the source port
+// is the same for both protocols. It returns that source port so callers can
+// check it against an explicitly bound port.
+func testDualWrite(c *testContext) uint16 {
+	// Write to V4 mapped address.
+	payload := buffer.View(newPayload())
+	n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+		To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
+	})
+	if err != nil {
+		c.t.Fatalf("Write failed: %v", err)
+	}
+	if n != uintptr(len(payload)) {
+		c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+	}
+
+	// Check that we received the packet.
+	b := c.getPacket()
+	udp := header.UDP(header.IPv4(b).Payload())
+	checker.IPv4(c.t, b,
+		checker.UDP(
+			checker.DstPort(testPort),
+		),
+	)
+
+	// Remember the source port chosen for the v4 write; the v6 write below
+	// must reuse it.
+	port := udp.SourcePort()
+
+	// Check the payload.
+	if !bytes.Equal(payload, udp.Payload()) {
+		c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+	}
+
+	// Write to v6 address.
+	payload = buffer.View(newPayload())
+	n, err = c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+		To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+	})
+	if err != nil {
+		c.t.Fatalf("Write failed: %v", err)
+	}
+	if n != uintptr(len(payload)) {
+		c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+	}
+
+	// Check that we received the packet, and that the source port is the
+	// same as the one used in ipv4.
+	b = c.getV6Packet()
+	udp = header.UDP(header.IPv6(b).Payload())
+	checker.IPv6(c.t, b,
+		checker.UDP(
+			checker.DstPort(testPort),
+			checker.SrcPort(port),
+		),
+	)
+
+	// Check the payload.
+	if !bytes.Equal(payload, udp.Payload()) {
+		c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+	}
+
+	return port
+}
+
+// TestDualWriteUnbound checks that an unbound dual-stack endpoint can write
+// to both v4-mapped and v6 destinations (an ephemeral port is auto-assigned).
+func TestDualWriteUnbound(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	testDualWrite(c)
+}
+
+// TestDualWriteBoundToWildcard checks that a dual-stack endpoint bound to
+// the wildcard address writes to both protocols using the bound port.
+func TestDualWriteBoundToWildcard(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	// Bind to wildcard.
+	if err := c.ep.Bind(tcpip.FullAddress{Port: stackPort}, nil); err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+
+	// The explicitly bound port must be the source port of both writes.
+	p := testDualWrite(c)
+	if p != stackPort {
+		c.t.Fatalf("Bad port: got %v, want %v", p, stackPort)
+	}
+}
+
+// TestDualWriteConnectedToV6 checks that a dual-stack endpoint connected to
+// a v6 address can still write explicitly to both v4-mapped and v6
+// destinations.
+func TestDualWriteConnectedToV6(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	// Connect to v6 address.
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %v", err)
+	}
+
+	testDualWrite(c)
+}
+
+// TestDualWriteConnectedToV4Mapped checks that a dual-stack endpoint
+// connected to a v4-mapped address can still write explicitly to both
+// v4-mapped and v6 destinations.
+func TestDualWriteConnectedToV4Mapped(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	// Connect to v4 mapped address.
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %v", err)
+	}
+
+	testDualWrite(c)
+}
+
+// TestV4WriteOnV6Only checks that writing to a v4-mapped destination from a
+// v6-only endpoint fails with ErrNoRoute (v4 traffic is disallowed).
+func TestV4WriteOnV6Only(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	// true => v6-only endpoint (IPV6_V6ONLY semantics).
+	c.createV6Endpoint(true)
+
+	// Write to V4 mapped address.
+	payload := buffer.View(newPayload())
+	_, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+		To: &tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort},
+	})
+	if err != tcpip.ErrNoRoute {
+		c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
+	}
+}
+
+// TestV6WriteOnBoundToV4Mapped checks that once an endpoint is bound to a
+// v4-mapped local address, writing to a native v6 destination fails with
+// ErrNoRoute (the local address cannot be used for v6 traffic).
+func TestV6WriteOnBoundToV4Mapped(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	// Bind to v4 mapped address.
+	if err := c.ep.Bind(tcpip.FullAddress{Addr: stackV4MappedAddr, Port: stackPort}, nil); err != nil {
+		c.t.Fatalf("Bind failed: %v", err)
+	}
+
+	// Write to v6 address.
+	payload := buffer.View(newPayload())
+	_, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{
+		To: &tcpip.FullAddress{Addr: testV6Addr, Port: testPort},
+	})
+	if err != tcpip.ErrNoRoute {
+		c.t.Fatalf("Write returned unexpected error: got %v, want %v", err, tcpip.ErrNoRoute)
+	}
+}
+
+// TestV6WriteOnConnected checks that a connected endpoint can write without
+// specifying a destination, and that the packet goes to the connected v6
+// peer with the right port and payload.
+func TestV6WriteOnConnected(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	// Connect to v6 address.
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV6Addr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %v", err)
+	}
+
+	// Write without destination.
+	payload := buffer.View(newPayload())
+	n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
+	if err != nil {
+		c.t.Fatalf("Write failed: %v", err)
+	}
+	if n != uintptr(len(payload)) {
+		c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+	}
+
+	// Check that we received the packet.
+	b := c.getV6Packet()
+	udp := header.UDP(header.IPv6(b).Payload())
+	checker.IPv6(c.t, b,
+		checker.UDP(
+			checker.DstPort(testPort),
+		),
+	)
+
+	// Check the payload.
+	if !bytes.Equal(payload, udp.Payload()) {
+		c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+	}
+}
+
+// TestV4WriteOnConnected checks that an endpoint connected to a v4-mapped
+// address can write without specifying a destination, and that the packet
+// goes out as IPv4 with the right port and payload.
+func TestV4WriteOnConnected(t *testing.T) {
+	c := newDualTestContext(t, defaultMTU)
+	defer c.cleanup()
+
+	c.createV6Endpoint(false)
+
+	// Connect to v4 mapped address.
+	if err := c.ep.Connect(tcpip.FullAddress{Addr: testV4MappedAddr, Port: testPort}); err != nil {
+		c.t.Fatalf("Connect failed: %v", err)
+	}
+
+	// Write without destination.
+	payload := buffer.View(newPayload())
+	n, err := c.ep.Write(tcpip.SlicePayload(payload), tcpip.WriteOptions{})
+	if err != nil {
+		c.t.Fatalf("Write failed: %v", err)
+	}
+	if n != uintptr(len(payload)) {
+		c.t.Fatalf("Bad number of bytes written: got %v, want %v", n, len(payload))
+	}
+
+	// Check that we received the packet.
+	b := c.getPacket()
+	udp := header.UDP(header.IPv4(b).Payload())
+	checker.IPv4(c.t, b,
+		checker.UDP(
+			checker.DstPort(testPort),
+		),
+	)
+
+	// Check the payload.
+	if !bytes.Equal(payload, udp.Payload()) {
+		c.t.Fatalf("Bad payload: got %x, want %x", udp.Payload(), payload)
+	}
+}
diff --git a/pkg/tcpip/transport/unix/BUILD b/pkg/tcpip/transport/unix/BUILD
new file mode 100644
index 000000000..47bc7a649
--- /dev/null
+++ b/pkg/tcpip/transport/unix/BUILD
@@ -0,0 +1,37 @@
+package(licenses = ["notice"])  # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+# Generates unix_state.go, the save/restore (stateify) glue for the listed
+# sources; the output is compiled into the library below.
+go_stateify(
+    name = "unix_state",
+    srcs = [
+        "connectioned.go",
+        "connectionless.go",
+        "unix.go",
+    ],
+    out = "unix_state.go",
+    package = "unix",
+)
+
+# Unix-domain socket transport endpoints (stream, seqpacket and dgram).
+go_library(
+    name = "unix",
+    srcs = [
+        "connectioned.go",
+        "connectioned_state.go",
+        "connectionless.go",
+        "unix.go",
+        "unix_state.go",
+    ],
+    importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix",
+    visibility = ["//:sandbox"],
+    deps = [
+        "//pkg/ilist",
+        "//pkg/log",
+        "//pkg/state",
+        "//pkg/tcpip",
+        "//pkg/tcpip/buffer",
+        "//pkg/tcpip/transport/queue",
+        "//pkg/waiter",
+    ],
+)
diff --git a/pkg/tcpip/transport/unix/connectioned.go b/pkg/tcpip/transport/unix/connectioned.go
new file mode 100644
index 000000000..def1b2c99
--- /dev/null
+++ b/pkg/tcpip/transport/unix/connectioned.go
@@ -0,0 +1,431 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unix
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/queue"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// UniqueIDProvider generates a sequence of unique identifiers useful for,
+// among other things, lock ordering.
+type UniqueIDProvider interface {
+	// UniqueID returns a new unique identifier.
+	UniqueID() uint64
+}
+
+// A ConnectingEndpoint is a connectioned unix endpoint that is attempting to
+// establish a bidirectional connection with a BoundEndpoint.
+type ConnectingEndpoint interface {
+	// ID returns the endpoint's globally unique identifier. This identifier
+	// must be used to determine locking order if more than one endpoint is
+	// to be locked in the same codepath. The endpoint with the smaller
+	// identifier must be locked before endpoints with larger identifiers.
+	ID() uint64
+
+	// Passcred implements socket.Credentialer.Passcred.
+	Passcred() bool
+
+	// Type returns the socket type, typically either SockStream or
+	// SockSeqpacket. The connection attempt must be aborted if this
+	// value doesn't match the ConnectableEndpoint's type.
+	Type() SockType
+
+	// GetLocalAddress returns the bound path.
+	GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+	// Locker protects the following methods. While locked, only the holder of
+	// the lock can change the return value of the protected methods.
+	sync.Locker
+
+	// Connected returns true iff the ConnectingEndpoint is in the connected
+	// state. ConnectingEndpoints can only be connected to a single endpoint,
+	// so the connection attempt must be aborted if this returns true.
+	Connected() bool
+
+	// Listening returns true iff the ConnectingEndpoint is in the listening
+	// state. ConnectingEndpoints cannot make connections while listening, so
+	// the connection attempt must be aborted if this returns true.
+	Listening() bool
+
+	// WaiterQueue returns a pointer to the endpoint's waiter queue.
+	WaiterQueue() *waiter.Queue
+}
+
+// connectionedEndpoint is a Unix-domain connected or connectable endpoint and implements
+// ConnectingEndpoint, ConnectableEndpoint and tcpip.Endpoint.
+//
+// connectionedEndpoints must be in connected state in order to transfer data.
+//
+// This implementation includes STREAM and SEQPACKET Unix sockets created with
+// socket(2), accept(2) or socketpair(2) and dgram unix sockets created with
+// socketpair(2). See unix_connectionless.go for the implementation of DGRAM
+// Unix sockets created with socket(2).
+//
+// The state is much simpler than a TCP endpoint, so it is not encoded
+// explicitly. Instead we enforce the following invariants:
+//
+//	receiver != nil, connected != nil => connected.
+//	path != "" && acceptedChan == nil => bound, not listening.
+//	path != "" && acceptedChan != nil => bound and listening.
+//
+// Only one of these will be true at any moment.
+type connectionedEndpoint struct {
+	baseEndpoint
+
+	// id is the unique endpoint identifier. This is used exclusively for
+	// lock ordering within connect.
+	id uint64
+
+	// idGenerator is used to generate new unique endpoint identifiers.
+	idGenerator UniqueIDProvider
+
+	// stype is used by connecting sockets to ensure that they are the
+	// same type. The value is typically either tcpip.SockSeqpacket or
+	// tcpip.SockStream.
+	stype SockType
+
+	// acceptedChan is per the TCP endpoint implementation. Note that the
+	// sockets in this channel are _already in the connected state_, and
+	// have another associated connectionedEndpoint.
+	//
+	// If nil, then no listen call has been made. The state tag tells
+	// stateify to serialize it as a slice (channels cannot be saved).
+	acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"`
+}
+
+// NewConnectioned creates a new unbound connectionedEndpoint.
+func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint {
+	return &connectionedEndpoint{
+		baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+		id:           uid.UniqueID(),
+		idGenerator:  uid,
+		stype:        stype,
+	}
+}
+
+// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
+func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
+	a := &connectionedEndpoint{
+		baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+		id:           uid.UniqueID(),
+		idGenerator:  uid,
+		stype:        stype,
+	}
+	b := &connectionedEndpoint{
+		baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+		id:           uid.UniqueID(),
+		idGenerator:  uid,
+		stype:        stype,
+	}
+
+	// Two unidirectional queues: q1 carries a->b's readable data (written by
+	// b), q2 carries b->a's.
+	q1 := queue.New(a.Queue, b.Queue, initialLimit)
+	q2 := queue.New(b.Queue, a.Queue, initialLimit)
+
+	// Stream sockets need the byte-stream receiver wrapper; seqpacket/dgram
+	// preserve message boundaries with the plain queue receiver.
+	if stype == SockStream {
+		a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
+		b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}}
+	} else {
+		a.receiver = &queueReceiver{q1}
+		b.receiver = &queueReceiver{q2}
+	}
+
+	// Cross-link the endpoints: each one's connected side writes into the
+	// queue the other one reads from.
+	a.connected = &connectedEndpoint{
+		endpoint:   b,
+		writeQueue: q2,
+	}
+	b.connected = &connectedEndpoint{
+		endpoint:   a,
+		writeQueue: q1,
+	}
+
+	return a, b
+}
+
+// ID implements ConnectingEndpoint.ID.
+func (e *connectionedEndpoint) ID() uint64 {
+	return e.id
+}
+
+// Type implements ConnectingEndpoint.Type and Endpoint.Type.
+func (e *connectionedEndpoint) Type() SockType {
+	return e.stype
+}
+
+// WaiterQueue implements ConnectingEndpoint.WaiterQueue.
+func (e *connectionedEndpoint) WaiterQueue() *waiter.Queue {
+	return e.Queue
+}
+
+// isBound returns true iff the connectionedEndpoint is bound (but not
+// listening).
+func (e *connectionedEndpoint) isBound() bool {
+	return e.path != "" && e.acceptedChan == nil
+}
+
+// Listening implements ConnectingEndpoint.Listening.
+func (e *connectionedEndpoint) Listening() bool {
+	return e.acceptedChan != nil
+}
+
+// Close puts the connectionedEndpoint in a closed state and frees all
+// resources associated with it.
+//
+// The socket will be a fresh state after a call to close and may be reused.
+// That is, close may be used to "unbind" or "disconnect" the socket in error
+// paths.
+func (e *connectionedEndpoint) Close() {
+	e.Lock()
+	var c ConnectedEndpoint
+	var r Receiver
+	switch {
+	case e.Connected():
+		e.connected.CloseSend()
+		e.receiver.CloseRecv()
+		// Defer the notify/release until after the lock is dropped (see
+		// below) to avoid deadlocking with the peer.
+		c = e.connected
+		r = e.receiver
+		e.connected = nil
+		e.receiver = nil
+	case e.isBound():
+		e.path = ""
+	case e.Listening():
+		// Closing the channel first lets the drain loop below terminate;
+		// any queued-but-unaccepted connections are closed too.
+		close(e.acceptedChan)
+		for n := range e.acceptedChan {
+			n.Close()
+		}
+		e.acceptedChan = nil
+		e.path = ""
+	}
+	e.Unlock()
+	// Notify/release outside the lock: these may call back into the peer.
+	if c != nil {
+		c.CloseNotify()
+		c.Release()
+	}
+	if r != nil {
+		r.CloseNotify()
+		r.Release()
+	}
+}
+
+// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
+//
+// On success a freshly created server-side endpoint is queued on
+// e.acceptedChan and the client's receiver/connected halves are delivered
+// via returnConnect while both endpoint locks are still held.
+func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error {
+	if ce.Type() != e.stype {
+		return tcpip.ErrConnectionRefused
+	}
+
+	// Check if ce is e to avoid a deadlock.
+	if ce, ok := ce.(*connectionedEndpoint); ok && ce == e {
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	// Do a dance to safely acquire locks on both endpoints: always lock the
+	// endpoint with the smaller unique ID first (see ConnectingEndpoint.ID).
+	if e.id < ce.ID() {
+		e.Lock()
+		ce.Lock()
+	} else {
+		ce.Lock()
+		e.Lock()
+	}
+
+	// Check connecting state.
+	if ce.Connected() {
+		e.Unlock()
+		ce.Unlock()
+		return tcpip.ErrAlreadyConnected
+	}
+	if ce.Listening() {
+		e.Unlock()
+		ce.Unlock()
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	// Check bound state.
+	if !e.Listening() {
+		e.Unlock()
+		ce.Unlock()
+		return tcpip.ErrConnectionRefused
+	}
+
+	// Create a newly bound connectionedEndpoint (the server side of the
+	// accepted connection).
+	ne := &connectionedEndpoint{
+		baseEndpoint: baseEndpoint{
+			path:  e.path,
+			Queue: &waiter.Queue{},
+		},
+		id:          e.idGenerator.UniqueID(),
+		idGenerator: e.idGenerator,
+		stype:       e.stype,
+	}
+	// readQueue carries data written by ne (server) toward ce (client);
+	// writeQueue carries data written by ce toward ne.
+	readQueue := queue.New(ce.WaiterQueue(), ne.Queue, initialLimit)
+	writeQueue := queue.New(ne.Queue, ce.WaiterQueue(), initialLimit)
+	ne.connected = &connectedEndpoint{
+		endpoint:   ce,
+		writeQueue: readQueue,
+	}
+	if e.stype == SockStream {
+		ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
+	} else {
+		ne.receiver = &queueReceiver{readQueue: writeQueue}
+	}
+
+	// Non-blocking send: if the accept backlog is full we refuse instead of
+	// blocking with both locks held.
+	select {
+	case e.acceptedChan <- ne:
+		// Commit state.
+		connected := &connectedEndpoint{
+			endpoint:   ne,
+			writeQueue: writeQueue,
+		}
+		if e.stype == SockStream {
+			returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected)
+		} else {
+			returnConnect(&queueReceiver{readQueue: readQueue}, connected)
+		}
+
+		// Notify can deadlock if we are holding these locks.
+		e.Unlock()
+		ce.Unlock()
+
+		// Notify on both ends.
+		e.Notify(waiter.EventIn)
+		ce.WaiterQueue().Notify(waiter.EventOut)
+
+		return nil
+	default:
+		// Busy; return ECONNREFUSED per spec.
+		ne.Close()
+		e.Unlock()
+		ce.Unlock()
+		return tcpip.ErrConnectionRefused
+	}
+}
+
+// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
+// Connection-oriented sockets cannot be the target of one-way (dgram sendto)
+// connects, so this always refuses.
+func (e *connectionedEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) {
+	return nil, tcpip.ErrConnectionRefused
+}
+
+// Connect attempts to directly connect to another Endpoint.
+// Implements Endpoint.Connect.
+func (e *connectionedEndpoint) Connect(server BoundEndpoint) *tcpip.Error {
+	// The server invokes returnConnect with our halves of the connection
+	// while both endpoints are locked (see BidirectionalConnect).
+	returnConnect := func(r Receiver, ce ConnectedEndpoint) {
+		e.receiver = r
+		e.connected = ce
+	}
+
+	return server.BidirectionalConnect(e, returnConnect)
+}
+
+// Listen starts listening on the connection. Calling Listen on an already
+// listening endpoint resizes the accept backlog, which only succeeds if the
+// pending connections fit in the new backlog.
+func (e *connectionedEndpoint) Listen(backlog int) *tcpip.Error {
+	e.Lock()
+	defer e.Unlock()
+	if e.Listening() {
+		// Adjust the size of the channel iff we can fit the existing
+		// pending connections into the new one.
+		if len(e.acceptedChan) > backlog {
+			return tcpip.ErrInvalidEndpointState
+		}
+		origChan := e.acceptedChan
+		e.acceptedChan = make(chan *connectionedEndpoint, backlog)
+		// Move pending connections over; close first so the range ends.
+		close(origChan)
+		for ep := range origChan {
+			e.acceptedChan <- ep
+		}
+		return nil
+	}
+	if !e.isBound() {
+		return tcpip.ErrInvalidEndpointState
+	}
+
+	// Normal case.
+	e.acceptedChan = make(chan *connectionedEndpoint, backlog)
+	return nil
+}
+
+// Accept accepts a new connection. It never blocks: if no connection is
+// pending it returns ErrWouldBlock.
+func (e *connectionedEndpoint) Accept() (Endpoint, *tcpip.Error) {
+	e.Lock()
+	defer e.Unlock()
+
+	if !e.Listening() {
+		return nil, tcpip.ErrInvalidEndpointState
+	}
+
+	select {
+	case ne := <-e.acceptedChan:
+		return ne, nil
+
+	default:
+		// Nothing left.
+		return nil, tcpip.ErrWouldBlock
+	}
+}
+
+// Bind binds the connection.
+//
+// For Unix connectionedEndpoints, this _only sets the address associated with
+// the socket_. Work associated with sockets in the filesystem or finding those
+// sockets must be done by a higher level.
+//
+// Bind will fail only if the socket is connected, bound or the passed address
+// is invalid (the empty string).
+func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+	e.Lock()
+	defer e.Unlock()
+	if e.isBound() || e.Listening() {
+		return tcpip.ErrAlreadyBound
+	}
+	if addr.Addr == "" {
+		// The empty string is not permitted.
+		return tcpip.ErrBadLocalAddress
+	}
+	// commit lets the caller perform filesystem work (e.g. creating the
+	// socket file) atomically with the bind, under the endpoint lock.
+	if commit != nil {
+		if err := commit(); err != nil {
+			return err
+		}
+	}
+
+	// Save the bound address.
+	e.path = string(addr.Addr)
+	return nil
+}
+
+// SendMsg writes data and a control message to the endpoint's peer.
+// This method does not block if the data cannot be written.
+func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) {
+	// Stream sockets do not support specifying the endpoint. Seqpacket
+	// sockets ignore the passed endpoint.
+	if e.stype == SockStream && to != nil {
+		return 0, tcpip.ErrNotSupported
+	}
+	return e.baseEndpoint.SendMsg(data, c, to)
+}
+
+// Readiness returns the current readiness of the connectionedEndpoint. For
+// example, if waiter.EventIn is set, the connectionedEndpoint is immediately
+// readable.
+func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+	e.Lock()
+	defer e.Unlock()
+
+	ready := waiter.EventMask(0)
+	switch {
+	case e.Connected():
+		if mask&waiter.EventIn != 0 && e.receiver.Readable() {
+			ready |= waiter.EventIn
+		}
+		if mask&waiter.EventOut != 0 && e.connected.Writable() {
+			ready |= waiter.EventOut
+		}
+	case e.Listening():
+		// A listening endpoint is "readable" when a connection can be
+		// accepted without blocking.
+		if mask&waiter.EventIn != 0 && len(e.acceptedChan) > 0 {
+			ready |= waiter.EventIn
+		}
+	}
+
+	return ready
+}
diff --git a/pkg/tcpip/transport/unix/connectioned_state.go b/pkg/tcpip/transport/unix/connectioned_state.go
new file mode 100644
index 000000000..5d835c8b2
--- /dev/null
+++ b/pkg/tcpip/transport/unix/connectioned_state.go
@@ -0,0 +1,43 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unix
+
+// saveAcceptedChan is invoked by stateify. Channels cannot be serialized, so
+// the pending-accept channel is converted to a slice for saving.
+func (e *connectionedEndpoint) saveAcceptedChan() []*connectionedEndpoint {
+	// If acceptedChan is nil (i.e. we are not listening) then we will save nil.
+	// Otherwise we create a (possibly empty) slice of the values in acceptedChan and
+	// save that.
+	var acceptedSlice []*connectionedEndpoint
+	if e.acceptedChan != nil {
+		// Swap out acceptedChan with a new empty channel of the same capacity.
+		saveChan := e.acceptedChan
+		e.acceptedChan = make(chan *connectionedEndpoint, cap(saveChan))
+
+		// Create a new slice with the same len and capacity as the channel.
+		acceptedSlice = make([]*connectionedEndpoint, len(saveChan), cap(saveChan))
+		// Drain acceptedChan into acceptedSlice, and fill up the new acceptChan at the
+		// same time.
+		for i := range acceptedSlice {
+			ep := <-saveChan
+			acceptedSlice[i] = ep
+			e.acceptedChan <- ep
+		}
+		close(saveChan)
+	}
+	return acceptedSlice
+}
+
+// loadAcceptedChan is invoked by stateify. It rebuilds the pending-accept
+// channel from the slice produced by saveAcceptedChan.
+func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEndpoint) {
+	// If acceptedSlice is nil, then acceptedChan should also be nil.
+	if acceptedSlice != nil {
+		// Otherwise, create a new channel with the same capacity as acceptedSlice.
+		e.acceptedChan = make(chan *connectionedEndpoint, cap(acceptedSlice))
+		// Seed the channel with values from acceptedSlice.
+		for _, ep := range acceptedSlice {
+			e.acceptedChan <- ep
+		}
+	}
+}
diff --git a/pkg/tcpip/transport/unix/connectionless.go b/pkg/tcpip/transport/unix/connectionless.go
new file mode 100644
index 000000000..34d34f99a
--- /dev/null
+++ b/pkg/tcpip/transport/unix/connectionless.go
@@ -0,0 +1,176 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package unix
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/queue"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// connectionlessEndpoint is a unix endpoint for unix sockets that support operating in
+// a connectionless fashion.
+//
+// Specifically, this means datagram unix sockets not created with
+// socketpair(2).
+type connectionlessEndpoint struct {
+	baseEndpoint
+}
+
+// NewConnectionless creates a new unbound dgram endpoint. The endpoint's
+// receive queue is created eagerly so that peers can connect/send to it.
+func NewConnectionless() Endpoint {
+	ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
+	ep.receiver = &queueReceiver{readQueue: queue.New(&waiter.Queue{}, ep.Queue, initialLimit)}
+	return ep
+}
+
+// isBound returns true iff the endpoint is bound.
+func (e *connectionlessEndpoint) isBound() bool {
+	return e.path != ""
+}
+
+// Close puts the endpoint in a closed state and frees all resources associated
+// with it.
+//
+// The socket will be a fresh state after a call to close and may be reused.
+// That is, close may be used to "unbind" or "disconnect" the socket in error
+// paths.
+func (e *connectionlessEndpoint) Close() {
+	e.Lock()
+	var r Receiver
+	if e.Connected() {
+		e.receiver.CloseRecv()
+		// Defer the receiver's notify/release until the lock is dropped.
+		r = e.receiver
+		e.receiver = nil
+
+		e.connected.Release()
+		e.connected = nil
+	}
+	if e.isBound() {
+		e.path = ""
+	}
+	e.Unlock()
+	if r != nil {
+		r.CloseNotify()
+		r.Release()
+	}
+}
+
+// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
+// Connectionless (dgram) sockets cannot accept stream/seqpacket connections.
+func (e *connectionlessEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error {
+	return tcpip.ErrConnectionRefused
+}
+
+// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
+// The returned ConnectedEndpoint writes directly into this endpoint's
+// receive queue.
+func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error) {
+	return &connectedEndpoint{
+		endpoint:   e,
+		writeQueue: e.receiver.(*queueReceiver).readQueue,
+	}, nil
+}
+
+// SendMsg writes data and a control message to the specified endpoint.
+// This method does not block if the data cannot be written.
+func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) {
+	// No explicit destination: use the connected peer (or fail there).
+	if to == nil {
+		return e.baseEndpoint.SendMsg(data, c, nil)
+	}
+
+	connected, err := to.UnidirectionalConnect()
+	if err != nil {
+		// NOTE(review): the underlying error is deliberately replaced with
+		// ErrInvalidEndpointState here.
+		return 0, tcpip.ErrInvalidEndpointState
+	}
+	defer connected.Release()
+
+	// The lock protects e.path, which is sent as the source address.
+	e.Lock()
+	n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+	e.Unlock()
+	if err != nil {
+		return 0, err
+	}
+	if notify {
+		connected.SendNotify()
+	}
+
+	return n, nil
+}
+
+// Type implements Endpoint.Type.
+func (e *connectionlessEndpoint) Type() SockType {
+	return SockDgram
+}
+
+// Connect attempts to connect directly to server. For dgram sockets this
+// only records where future writes go; no handshake takes place.
+func (e *connectionlessEndpoint) Connect(server BoundEndpoint) *tcpip.Error {
+	connected, err := server.UnidirectionalConnect()
+	if err != nil {
+		return err
+	}
+
+	e.Lock()
+	e.connected = connected
+	e.Unlock()
+
+	return nil
+}
+
+// Listen starts listening on the connection. Not supported for dgram sockets.
+func (e *connectionlessEndpoint) Listen(int) *tcpip.Error {
+	return tcpip.ErrNotSupported
+}
+
+// Accept accepts a new connection. Not supported for dgram sockets.
+func (e *connectionlessEndpoint) Accept() (Endpoint, *tcpip.Error) {
+	return nil, tcpip.ErrNotSupported
+}
+
+// Bind binds the connection.
+//
+// For Unix endpoints, this _only sets the address associated with the socket_.
+// Work associated with sockets in the filesystem or finding those sockets must
+// be done by a higher level.
+//
+// Bind will fail only if the socket is connected, bound or the passed address
+// is invalid (the empty string).
+func (e *connectionlessEndpoint) Bind(addr tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+	e.Lock()
+	defer e.Unlock()
+	if e.isBound() {
+		return tcpip.ErrAlreadyBound
+	}
+	if addr.Addr == "" {
+		// The empty string is not permitted.
+		return tcpip.ErrBadLocalAddress
+	}
+	// commit lets the caller perform filesystem work atomically with the
+	// bind, under the endpoint lock.
+	if commit != nil {
+		if err := commit(); err != nil {
+			return err
+		}
+	}
+
+	// Save the bound address.
+	e.path = string(addr.Addr)
+	return nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+	e.Lock()
+	defer e.Unlock()
+
+	ready := waiter.EventMask(0)
+	// The receive queue exists even when unconnected (see NewConnectionless),
+	// so readability is checked unconditionally.
+	if mask&waiter.EventIn != 0 && e.receiver.Readable() {
+		ready |= waiter.EventIn
+	}
+
+	if e.Connected() {
+		if mask&waiter.EventOut != 0 && e.connected.Writable() {
+			ready |= waiter.EventOut
+		}
+	}
+
+	return ready
+}
diff --git a/pkg/tcpip/transport/unix/unix.go b/pkg/tcpip/transport/unix/unix.go
new file mode 100644
index 000000000..5fe37eb71
--- /dev/null
+++ b/pkg/tcpip/transport/unix/unix.go
@@ -0,0 +1,902 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package unix contains the implementation of Unix endpoints.
+package unix
+
+import (
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/ilist"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/queue"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
// initialLimit is the starting limit for the socket buffers.
const initialLimit = 16 * 1024

// A SockType is a type (as opposed to family) of sockets. These are enumerated
// in the syscall package as syscall.SOCK_* constants.
type SockType int

// The explicit values mirror the Linux SOCK_* numbering; note that 4
// (SOCK_RDM) is intentionally absent.
const (
	// SockStream corresponds to syscall.SOCK_STREAM.
	SockStream SockType = 1
	// SockDgram corresponds to syscall.SOCK_DGRAM.
	SockDgram SockType = 2
	// SockRaw corresponds to syscall.SOCK_RAW.
	SockRaw SockType = 3
	// SockSeqpacket corresponds to syscall.SOCK_SEQPACKET.
	SockSeqpacket SockType = 5
)
+
// A RightsControlMessage is a control message containing FDs (the payload of
// an SCM_RIGHTS message).
type RightsControlMessage interface {
	// Clone returns a copy of the RightsControlMessage.
	Clone() RightsControlMessage

	// Release releases any resources owned by the RightsControlMessage.
	Release()
}

// A CredentialsControlMessage is a control message containing Unix credentials.
type CredentialsControlMessage interface {
	// Equals returns true iff the two messages are equal.
	Equals(CredentialsControlMessage) bool
}

// A ControlMessages represents a collection of socket control messages.
type ControlMessages struct {
	// Rights is a control message containing FDs. May be nil.
	Rights RightsControlMessage

	// Credentials is a control message containing Unix credentials. May be
	// nil.
	Credentials CredentialsControlMessage
}
+
+// Empty returns true iff the ControlMessages does not contain either
+// credentials or rights.
+func (c *ControlMessages) Empty() bool {
+ return c.Rights == nil && c.Credentials == nil
+}
+
+// Clone clones both the credentials and the rights.
+func (c *ControlMessages) Clone() ControlMessages {
+ cm := ControlMessages{}
+ if c.Rights != nil {
+ cm.Rights = c.Rights.Clone()
+ }
+ cm.Credentials = c.Credentials
+ return cm
+}
+
+// Release releases both the credentials and the rights.
+func (c *ControlMessages) Release() {
+ if c.Rights != nil {
+ c.Rights.Release()
+ }
+ *c = ControlMessages{}
+}
+
// Endpoint is the interface implemented by Unix transport protocol
// implementations that expose functionality like sendmsg, recvmsg, connect,
// etc. to Unix socket implementations.
type Endpoint interface {
	Credentialer
	waiter.Waitable

	// Close puts the endpoint in a closed state and frees all resources
	// associated with it.
	Close()

	// RecvMsg reads data and a control message from the endpoint. This method
	// does not block if there is no data pending.
	//
	// creds indicates if credential control messages are requested by the
	// caller. This is useful for determining if control messages can be
	// coalesced. creds is a hint and can be safely ignored by the
	// implementation if no coalescing is possible. It is fine to return
	// credential control messages when none were requested or to not return
	// credential control messages when they were requested.
	//
	// numRights is the number of SCM_RIGHTS FDs requested by the caller. This
	// is useful if one must allocate a buffer to receive a SCM_RIGHTS message
	// or determine if control messages can be coalesced. numRights is a hint
	// and can be safely ignored by the implementation if the number of
	// available SCM_RIGHTS FDs is known and no coalescing is possible. It is
	// fine for the returned number of SCM_RIGHTS FDs to be either higher or
	// lower than the requested number.
	//
	// If peek is true, no data should be consumed from the Endpoint. Any and
	// all data returned from a peek should be available in the next call to
	// RecvMsg.
	//
	// recvLen is the number of bytes copied into data.
	//
	// msgLen is the length of the read message consumed for datagram Endpoints.
	// msgLen is always the same as recvLen for stream Endpoints.
	RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, err *tcpip.Error)

	// SendMsg writes data and a control message to the endpoint's peer.
	// This method does not block if the data cannot be written.
	//
	// SendMsg does not take ownership of any of its arguments on error.
	SendMsg([][]byte, ControlMessages, BoundEndpoint) (uintptr, *tcpip.Error)

	// Connect connects this endpoint directly to another.
	//
	// This should be called on the client endpoint, and the (bound)
	// endpoint passed in as a parameter.
	//
	// Errors from the server's connect method (see
	// BoundEndpoint.BidirectionalConnect and
	// BoundEndpoint.UnidirectionalConnect) are propagated to the caller.
	Connect(server BoundEndpoint) *tcpip.Error

	// Shutdown closes the read and/or write end of the endpoint connection
	// to its peer.
	Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error

	// Listen puts the endpoint in "listen" mode, which allows it to accept
	// new connections.
	Listen(backlog int) *tcpip.Error

	// Accept returns a new endpoint if a peer has established a connection
	// to an endpoint previously set to listen mode. This method does not
	// block if no new connections are available.
	Accept() (Endpoint, *tcpip.Error)

	// Bind binds the endpoint to a specific local address and port.
	// Specifying a NIC is optional.
	//
	// An optional commit function will be executed atomically with respect
	// to binding the endpoint. If this returns an error, the bind will not
	// occur and the error will be propagated back to the caller.
	Bind(address tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error

	// Type returns the socket type, typically either SockStream, SockDgram
	// or SockSeqpacket.
	Type() SockType

	// GetLocalAddress returns the address to which the endpoint is bound.
	GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)

	// GetRemoteAddress returns the address to which the endpoint is
	// connected.
	GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)

	// SetSockOpt sets a socket option. opt should be one of the tcpip.*Option
	// types.
	SetSockOpt(opt interface{}) *tcpip.Error

	// GetSockOpt gets a socket option. opt should be a pointer to one of the
	// tcpip.*Option types.
	GetSockOpt(opt interface{}) *tcpip.Error
}
+
// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
// option.
type Credentialer interface {
	// Passcred returns whether or not the SO_PASSCRED socket option is
	// enabled on this end.
	Passcred() bool

	// ConnectedPasscred returns whether or not the SO_PASSCRED socket option
	// is enabled on the connected end.
	ConnectedPasscred() bool
}

// A BoundEndpoint is a unix endpoint that can be connected to.
type BoundEndpoint interface {
	// BidirectionalConnect establishes a bi-directional connection between two
	// unix endpoints in an all-or-nothing manner. If an error occurs during
	// connecting, the state of neither endpoint should be modified.
	//
	// In order for an endpoint to establish such a bidirectional connection
	// with a BoundEndpoint, the endpoint calls the BidirectionalConnect method
	// on the BoundEndpoint and sends a representation of itself (the
	// ConnectingEndpoint) and a callback (returnConnect) to receive the
	// connection information (Receiver and ConnectedEndpoint) upon a
	// successful connect. The callback should only be called on a successful
	// connect.
	//
	// For a connection attempt to be successful, the ConnectingEndpoint must
	// be unconnected and not listening and the BoundEndpoint whose
	// BidirectionalConnect method is being called must be listening.
	//
	// This method will return tcpip.ErrConnectionRefused on endpoints with a
	// type that isn't SockStream or SockSeqpacket.
	BidirectionalConnect(ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *tcpip.Error

	// UnidirectionalConnect establishes a write-only connection to a unix
	// endpoint, as used by datagram (SOCK_DGRAM) sockets.
	//
	// This method will return tcpip.ErrConnectionRefused on a non-SockDgram
	// endpoint.
	UnidirectionalConnect() (ConnectedEndpoint, *tcpip.Error)

	// Release releases any resources held by the BoundEndpoint. It must be
	// called before dropping all references to a BoundEndpoint returned by a
	// function.
	Release()
}
+
// message represents a message passed over a Unix domain socket. It is an
// intrusive-list entry so it can be queued directly on a queue.Queue.
type message struct {
	ilist.Entry

	// Data is the Message payload.
	Data buffer.View

	// Control is auxiliary control message data that goes along with the
	// data.
	Control ControlMessages

	// Address is the bound address of the endpoint that sent the message.
	//
	// If the endpoint that sent the message is not bound, Address.Addr is
	// the empty string.
	Address tcpip.FullAddress
}

// Length returns number of bytes stored in the Message.
func (m *message) Length() int64 {
	return int64(len(m.Data))
}

// Release releases any resources held by the Message (i.e. its control
// messages).
func (m *message) Release() {
	m.Control.Release()
}

// Peek implements queue.Entry. It returns a shallow copy of the message: the
// payload bytes are shared, but the control messages are cloned so releasing
// the peeked copy does not strip rights from the queued original.
func (m *message) Peek() queue.Entry {
	return &message{Data: m.Data, Control: m.Control.Clone(), Address: m.Address}
}
+
// A Receiver can be used to receive Messages.
type Receiver interface {
	// Recv receives a single message. This method does not block.
	//
	// See Endpoint.RecvMsg for documentation on shared arguments.
	//
	// notify indicates if RecvNotify should be called.
	Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (recvLen, msgLen uintptr, cm ControlMessages, source tcpip.FullAddress, notify bool, err *tcpip.Error)

	// RecvNotify notifies the Receiver of a successful Recv. This must not be
	// called while holding any endpoint locks.
	RecvNotify()

	// CloseRecv prevents the receiving of additional Messages.
	//
	// After CloseRecv is called, CloseNotify must also be called.
	CloseRecv()

	// CloseNotify notifies the Receiver of recv being closed. This must not be
	// called while holding any endpoint locks.
	CloseNotify()

	// Readable returns if messages should be attempted to be received. This
	// includes when read has been shutdown.
	Readable() bool

	// RecvQueuedSize returns the total amount of data currently receivable.
	// RecvQueuedSize should return -1 if the operation isn't supported.
	RecvQueuedSize() int64

	// RecvMaxQueueSize returns maximum value for RecvQueuedSize.
	// RecvMaxQueueSize should return -1 if the operation isn't supported.
	RecvMaxQueueSize() int64

	// Release releases any resources owned by the Receiver. It should be
	// called before dropping all references to a Receiver.
	Release()
}
+
// queueReceiver implements Receiver for datagram sockets: each Recv consumes
// (or peeks) exactly one queued message.
type queueReceiver struct {
	readQueue *queue.Queue
}

// Recv implements Receiver.Recv.
//
// A datagram that is larger than data is truncated: recvLen is the number of
// bytes copied, while msgLen is the full message length.
func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
	var m queue.Entry
	var notify bool
	var err *tcpip.Error
	if peek {
		// Peek leaves the message on the queue; the returned entry is
		// presumably a clone produced via message.Peek — confirm that
		// queue.Queue.Peek calls Entry.Peek, so that releasing the returned
		// control messages cannot affect the queued original.
		m, err = q.readQueue.Peek()
	} else {
		m, notify, err = q.readQueue.Dequeue()
	}
	if err != nil {
		return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err
	}
	msg := m.(*message)
	src := []byte(msg.Data)
	var copied uintptr
	// Scatter the payload across the caller's buffers until either the
	// payload or the buffers are exhausted.
	for i := 0; i < len(data) && len(src) > 0; i++ {
		n := copy(data[i], src)
		copied += uintptr(n)
		src = src[n:]
	}
	return copied, uintptr(len(msg.Data)), msg.Control, msg.Address, notify, nil
}

// RecvNotify implements Receiver.RecvNotify. It wakes writers blocked on a
// previously-full queue.
func (q *queueReceiver) RecvNotify() {
	q.readQueue.WriterQueue.Notify(waiter.EventOut)
}

// CloseNotify implements Receiver.CloseNotify. Both readers and writers are
// woken so they can observe the closed state.
func (q *queueReceiver) CloseNotify() {
	q.readQueue.ReaderQueue.Notify(waiter.EventIn)
	q.readQueue.WriterQueue.Notify(waiter.EventOut)
}

// CloseRecv implements Receiver.CloseRecv.
func (q *queueReceiver) CloseRecv() {
	q.readQueue.Close()
}

// Readable implements Receiver.Readable.
func (q *queueReceiver) Readable() bool {
	return q.readQueue.IsReadable()
}

// RecvQueuedSize implements Receiver.RecvQueuedSize.
func (q *queueReceiver) RecvQueuedSize() int64 {
	return q.readQueue.QueuedSize()
}

// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
func (q *queueReceiver) RecvMaxQueueSize() int64 {
	return q.readQueue.MaxQueueSize()
}

// Release implements Receiver.Release. A queueReceiver owns no resources of
// its own.
func (*queueReceiver) Release() {}
+
// streamQueueReceiver implements Receiver for stream sockets. It wraps a
// queueReceiver and additionally buffers the unread remainder of the message
// at the head of the stream, since stream reads need not consume whole
// messages.
type streamQueueReceiver struct {
	queueReceiver

	// mu protects the fields below (see Recv).
	mu sync.Mutex `state:"nosave"`
	// buffer is the unread tail of the most recently dequeued message.
	buffer []byte
	// control holds the control messages associated with buffer.
	control ControlMessages
	// addr is the source address associated with buffer.
	addr tcpip.FullAddress
}
+
// vecCopy copies bytes from buf into the byte slices of data, in order,
// stopping when either buf is exhausted or every slice in data is full.
//
// It returns the number of bytes copied, the remaining (unfilled) tail of
// data, and the remaining (uncopied) tail of buf. The head slice of data is
// advanced in place past the bytes written into it.
func vecCopy(data [][]byte, buf []byte) (uintptr, [][]byte, []byte) {
	copied := uintptr(0)
	for len(buf) > 0 && len(data) > 0 {
		n := copy(data[0], buf)
		copied += uintptr(n)
		buf = buf[n:]
		data[0] = data[0][n:]
		if len(data[0]) == 0 {
			// The head slice is full; advance to the next one.
			data = data[1:]
		}
	}
	return copied, data, buf
}
+
// Readable implements Receiver.Readable.
//
// NOTE(review): q.buffer is read without holding q.mu here; presumably
// callers serialize via the endpoint lock — confirm against callers.
func (q *streamQueueReceiver) Readable() bool {
	// We're readable if we have data in our buffer or if the queue receiver is
	// readable.
	return len(q.buffer) > 0 || q.readQueue.IsReadable()
}

// RecvQueuedSize implements Receiver.RecvQueuedSize. It counts both the
// locally buffered tail and everything still queued.
func (q *streamQueueReceiver) RecvQueuedSize() int64 {
	return int64(len(q.buffer)) + q.readQueue.QueuedSize()
}

// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
	// The RecvMaxQueueSize() is the readQueue's MaxQueueSize() plus the largest
	// message we can buffer which is also the largest message we can receive.
	return 2 * q.readQueue.MaxQueueSize()
}
+
// Recv implements Receiver.Recv for stream semantics: reads may consume
// partial messages (the remainder is kept in q.buffer) and may coalesce data
// from several queued messages into one read, subject to control-message
// compatibility rules modeled on Linux.
func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights uintptr, peek bool) (uintptr, uintptr, ControlMessages, tcpip.FullAddress, bool, *tcpip.Error) {
	q.mu.Lock()
	defer q.mu.Unlock()

	var notify bool

	// If we have no data in the endpoint, we need to get some.
	if len(q.buffer) == 0 {
		// Load the next message into a buffer, even if we are peeking. Peeking
		// won't consume the message, so it will be still available to be read
		// the next time Recv() is called.
		m, n, err := q.readQueue.Dequeue()
		if err != nil {
			return 0, 0, ControlMessages{}, tcpip.FullAddress{}, false, err
		}
		notify = n
		msg := m.(*message)
		q.buffer = []byte(msg.Data)
		q.control = msg.Control
		q.addr = msg.Address
	}

	var copied uintptr
	if peek {
		// Don't consume control message if we are peeking.
		c := q.control.Clone()

		// Don't consume data since we are peeking.
		copied, data, _ = vecCopy(data, q.buffer)

		return copied, copied, c, q.addr, notify, nil
	}

	// Consume data and control message since we are not peeking.
	copied, data, q.buffer = vecCopy(data, q.buffer)

	// Save the original state of q.control.
	c := q.control

	// Remove rights from q.control and leave behind just the creds.
	q.control.Rights = nil
	if !wantCreds {
		c.Credentials = nil
	}

	// Rights the caller did not ask for (numRights == 0) are dropped, not
	// saved for a later read.
	if c.Rights != nil && numRights == 0 {
		c.Rights.Release()
		c.Rights = nil
	}

	haveRights := c.Rights != nil

	// If we have more capacity for data and haven't received any usable
	// rights.
	//
	// Linux never coalesces rights control messages.
	for !haveRights && len(data) > 0 {
		// Get a message from the readQueue.
		m, n, err := q.readQueue.Dequeue()
		if err != nil {
			// We already got some data, so ignore this error. This will
			// manifest as a short read to the user, which is what Linux
			// does.
			break
		}
		notify = notify || n
		msg := m.(*message)
		q.buffer = []byte(msg.Data)
		q.control = msg.Control
		q.addr = msg.Address

		// Stop coalescing if the next message's credentials are
		// incompatible with what we have already gathered.
		if wantCreds {
			if (q.control.Credentials == nil) != (c.Credentials == nil) {
				// One message has credentials, the other does not.
				break
			}

			if q.control.Credentials != nil && c.Credentials != nil && !q.control.Credentials.Equals(c.Credentials) {
				// Both messages have credentials, but they don't match.
				break
			}
		}

		if numRights != 0 && c.Rights != nil && q.control.Rights != nil {
			// Both messages have rights.
			break
		}

		var cpd uintptr
		cpd, data, q.buffer = vecCopy(data, q.buffer)
		copied += cpd

		if cpd == 0 {
			// data was actually full.
			break
		}

		if q.control.Rights != nil {
			// Consume rights.
			if numRights == 0 {
				q.control.Rights.Release()
			} else {
				c.Rights = q.control.Rights
				haveRights = true
			}
			q.control.Rights = nil
		}
	}
	// Stream sockets report msgLen == recvLen.
	return copied, copied, c, q.addr, notify, nil
}
+
// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
type ConnectedEndpoint interface {
	// Passcred implements Endpoint.Passcred.
	Passcred() bool

	// GetLocalAddress implements Endpoint.GetLocalAddress.
	GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)

	// Send sends a single message. This method does not block.
	//
	// notify indicates if SendNotify should be called.
	Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n uintptr, notify bool, err *tcpip.Error)

	// SendNotify notifies the ConnectedEndpoint of a successful Send. This
	// must not be called while holding any endpoint locks.
	SendNotify()

	// CloseSend prevents the sending of additional Messages.
	//
	// After CloseSend is called, CloseNotify must also be called.
	CloseSend()

	// CloseNotify notifies the ConnectedEndpoint of send being closed. This
	// must not be called while holding any endpoint locks.
	CloseNotify()

	// Writable returns if messages should be attempted to be sent. This
	// includes when write has been shutdown.
	Writable() bool

	// EventUpdate lets the ConnectedEndpoint know that event registrations
	// have changed.
	EventUpdate()

	// SendQueuedSize returns the total amount of data currently queued for
	// sending. SendQueuedSize should return -1 if the operation isn't
	// supported.
	SendQueuedSize() int64

	// SendMaxQueueSize returns maximum value for SendQueuedSize.
	// SendMaxQueueSize should return -1 if the operation isn't supported.
	SendMaxQueueSize() int64

	// Release releases any resources owned by the ConnectedEndpoint. It should
	// be called before dropping all references to a ConnectedEndpoint.
	Release()
}
+
// connectedEndpoint is the concrete ConnectedEndpoint: it pairs the peer's
// write queue with the subset of the peer endpoint's behavior it needs.
type connectedEndpoint struct {
	// endpoint represents the subset of the Endpoint functionality needed by
	// the connectedEndpoint. It is implemented by both connectionedEndpoint
	// and connectionlessEndpoint and allows the use of types which don't
	// fully implement Endpoint.
	endpoint interface {
		// Passcred implements Endpoint.Passcred.
		Passcred() bool

		// GetLocalAddress implements Endpoint.GetLocalAddress.
		GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)

		// Type implements Endpoint.Type.
		Type() SockType
	}

	// writeQueue is the queue that Send enqueues messages onto; the peer
	// reads from it.
	writeQueue *queue.Queue
}
+
// Passcred implements ConnectedEndpoint.Passcred by delegating to the
// wrapped endpoint.
func (e *connectedEndpoint) Passcred() bool {
	return e.endpoint.Passcred()
}

// GetLocalAddress implements ConnectedEndpoint.GetLocalAddress by delegating
// to the wrapped endpoint.
func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
	return e.endpoint.GetLocalAddress()
}
+
+// Send implements ConnectedEndpoint.Send.
+func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (uintptr, bool, *tcpip.Error) {
+ var l int
+ for _, d := range data {
+ l += len(d)
+ }
+ // Discard empty stream packets. Since stream sockets don't preserve
+ // message boundaries, sending zero bytes is a no-op. In Linux, the
+ // receiver actually uses a zero-length receive as an indication that the
+ // stream was closed.
+ if l == 0 && e.endpoint.Type() == SockStream {
+ controlMessages.Release()
+ return 0, false, nil
+ }
+ v := make([]byte, 0, l)
+ for _, d := range data {
+ v = append(v, d...)
+ }
+ notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from})
+ return uintptr(l), notify, err
+}
+
// SendNotify implements ConnectedEndpoint.SendNotify. It wakes the peer's
// readers after a successful Send.
func (e *connectedEndpoint) SendNotify() {
	e.writeQueue.ReaderQueue.Notify(waiter.EventIn)
}

// CloseNotify implements ConnectedEndpoint.CloseNotify. Both sides of the
// write queue are woken so they can observe the closed state.
func (e *connectedEndpoint) CloseNotify() {
	e.writeQueue.ReaderQueue.Notify(waiter.EventIn)
	e.writeQueue.WriterQueue.Notify(waiter.EventOut)
}

// CloseSend implements ConnectedEndpoint.CloseSend.
func (e *connectedEndpoint) CloseSend() {
	e.writeQueue.Close()
}

// Writable implements ConnectedEndpoint.Writable.
func (e *connectedEndpoint) Writable() bool {
	return e.writeQueue.IsWritable()
}

// EventUpdate implements ConnectedEndpoint.EventUpdate. No action is needed
// for queue-backed endpoints.
func (*connectedEndpoint) EventUpdate() {}

// SendQueuedSize implements ConnectedEndpoint.SendQueuedSize.
func (e *connectedEndpoint) SendQueuedSize() int64 {
	return e.writeQueue.QueuedSize()
}

// SendMaxQueueSize implements ConnectedEndpoint.SendMaxQueueSize.
func (e *connectedEndpoint) SendMaxQueueSize() int64 {
	return e.writeQueue.MaxQueueSize()
}

// Release implements ConnectedEndpoint.Release. A connectedEndpoint owns no
// resources of its own.
func (*connectedEndpoint) Release() {}
+
// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
// unix domain socket Endpoint implementations.
//
// Not to be used on its own.
type baseEndpoint struct {
	// Queue is the endpoint's wait queue; embedding it provides the
	// waiter.Waitable Notify behavior.
	*waiter.Queue

	// passcred specifies whether SCM_CREDENTIALS socket control messages are
	// enabled on this endpoint. Must be accessed atomically.
	passcred int32

	// Mutex protects the below fields.
	sync.Mutex `state:"nosave"`

	// receiver allows Messages to be received.
	receiver Receiver

	// connected allows messages to be sent and state information about the
	// connected endpoint to be read.
	connected ConnectedEndpoint

	// path is not empty if the endpoint has been bound,
	// or may be used if the endpoint is connected.
	path string
}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
+ e.Lock()
+ e.Queue.EventRegister(we, mask)
+ if e.connected != nil {
+ e.connected.EventUpdate()
+ }
+ e.Unlock()
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
+ e.Lock()
+ e.Queue.EventUnregister(we)
+ if e.connected != nil {
+ e.connected.EventUpdate()
+ }
+ e.Unlock()
+}
+
+// Passcred implements Credentialer.Passcred.
+func (e *baseEndpoint) Passcred() bool {
+ return atomic.LoadInt32(&e.passcred) != 0
+}
+
+// ConnectedPasscred implements Credentialer.ConnectedPasscred.
+func (e *baseEndpoint) ConnectedPasscred() bool {
+ e.Lock()
+ defer e.Unlock()
+ return e.connected != nil && e.connected.Passcred()
+}
+
+func (e *baseEndpoint) setPasscred(pc bool) {
+ if pc {
+ atomic.StoreInt32(&e.passcred, 1)
+ } else {
+ atomic.StoreInt32(&e.passcred, 0)
+ }
+}
+
+// Connected implements ConnectingEndpoint.Connected.
+func (e *baseEndpoint) Connected() bool {
+ return e.receiver != nil && e.connected != nil
+}
+
// RecvMsg reads data and a control message from the endpoint.
//
// The receive itself happens under the endpoint lock; RecvNotify is invoked
// after the lock is dropped, per the Receiver.RecvNotify contract that it
// must not be called while holding endpoint locks.
func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, *tcpip.Error) {
	e.Lock()

	if e.receiver == nil {
		e.Unlock()
		return 0, 0, ControlMessages{}, tcpip.ErrNotConnected
	}

	recvLen, msgLen, cms, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
	e.Unlock()
	if err != nil {
		return 0, 0, ControlMessages{}, err
	}

	if notify {
		// NOTE(review): e.receiver is re-read here outside the lock;
		// presumably it cannot change once set — confirm.
		e.receiver.RecvNotify()
	}

	if addr != nil {
		*addr = a
	}
	return recvLen, msgLen, cms, nil
}
+
// SendMsg writes data and a control message to the endpoint's peer.
// This method does not block if the data cannot be written.
//
// A non-nil destination (to) is rejected: an already-connected endpoint only
// sends to its peer. The bound path, if any, is supplied as the source
// address. SendNotify is invoked after the lock is dropped, per the
// ConnectedEndpoint.SendNotify contract.
func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *tcpip.Error) {
	e.Lock()
	if !e.Connected() {
		e.Unlock()
		return 0, tcpip.ErrNotConnected
	}
	if to != nil {
		e.Unlock()
		return 0, tcpip.ErrAlreadyConnected
	}

	n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
	e.Unlock()
	if err != nil {
		return 0, err
	}

	if notify {
		e.connected.SendNotify()
	}

	return n, nil
}
+
// SetSockOpt sets a socket option.
//
// Only tcpip.PasscredOption is handled; all other options are silently
// accepted and ignored.
// NOTE(review): this is asymmetric with GetSockOpt, which returns
// tcpip.ErrUnknownProtocolOption for unknown options — confirm the silent
// success here is intentional.
func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
	switch v := opt.(type) {
	case tcpip.PasscredOption:
		e.setPasscred(v != 0)
		return nil
	}
	return nil
}
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
//
// Queue-size and buffer-size options report tcpip.ErrQueueSizeNotSupported
// when the underlying receiver/peer returns a negative size. Unknown option
// types yield tcpip.ErrUnknownProtocolOption.
func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
	switch o := opt.(type) {
	case tcpip.ErrorOption:
		// There is no pending error to harvest.
		return nil
	case *tcpip.SendQueueSizeOption:
		// Bytes currently queued toward the peer; requires a connection.
		e.Lock()
		if !e.Connected() {
			e.Unlock()
			return tcpip.ErrNotConnected
		}
		qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
		e.Unlock()
		if qs < 0 {
			return tcpip.ErrQueueSizeNotSupported
		}
		*o = qs
		return nil
	case *tcpip.ReceiveQueueSizeOption:
		// Bytes currently receivable; requires a connection.
		e.Lock()
		if !e.Connected() {
			e.Unlock()
			return tcpip.ErrNotConnected
		}
		qs := tcpip.ReceiveQueueSizeOption(e.receiver.RecvQueuedSize())
		e.Unlock()
		if qs < 0 {
			return tcpip.ErrQueueSizeNotSupported
		}
		*o = qs
		return nil
	case *tcpip.PasscredOption:
		if e.Passcred() {
			*o = tcpip.PasscredOption(1)
		} else {
			*o = tcpip.PasscredOption(0)
		}
		return nil
	case *tcpip.SendBufferSizeOption:
		e.Lock()
		if !e.Connected() {
			e.Unlock()
			return tcpip.ErrNotConnected
		}
		qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
		e.Unlock()
		if qs < 0 {
			return tcpip.ErrQueueSizeNotSupported
		}
		*o = qs
		return nil
	case *tcpip.ReceiveBufferSizeOption:
		// Unlike ReceiveQueueSizeOption above, this only requires a
		// receiver, not a full connection.
		e.Lock()
		if e.receiver == nil {
			e.Unlock()
			return tcpip.ErrNotConnected
		}
		qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
		e.Unlock()
		if qs < 0 {
			return tcpip.ErrQueueSizeNotSupported
		}
		*o = qs
		return nil
	}
	return tcpip.ErrUnknownProtocolOption
}
+
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
//
// This runs in two phases: the close operations happen under the endpoint
// lock, then the corresponding CloseNotify calls happen after the lock is
// released, since both Receiver.CloseNotify and ConnectedEndpoint.CloseNotify
// must not be called while holding endpoint locks.
func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error {
	e.Lock()
	if !e.Connected() {
		e.Unlock()
		return tcpip.ErrNotConnected
	}

	if flags&tcpip.ShutdownRead != 0 {
		e.receiver.CloseRecv()
	}

	if flags&tcpip.ShutdownWrite != 0 {
		e.connected.CloseSend()
	}

	e.Unlock()

	// Phase two: notify waiters without holding the lock.
	if flags&tcpip.ShutdownRead != 0 {
		e.receiver.CloseNotify()
	}

	if flags&tcpip.ShutdownWrite != 0 {
		e.connected.CloseNotify()
	}

	return nil
}
+
+// GetLocalAddress returns the bound path.
+func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.Lock()
+ defer e.Unlock()
+ return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil
+}
+
+// GetRemoteAddress returns the local address of the connected endpoint (if
+// available).
+func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.Lock()
+ c := e.connected
+ e.Unlock()
+ if c != nil {
+ return c.GetLocalAddress()
+ }
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Release implements BoundEndpoint.Release.
+func (*baseEndpoint) Release() {}