summary refs log tree commit diff homepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
author Googler <noreply@google.com> 2018-04-27 10:37:02 -0700
committer Adin Scannell <ascannell@google.com> 2018-04-28 01:44:26 -0400
commit d02b74a5dcfed4bfc8f2f8e545bca4d2afabb296 (patch)
tree 54f95eef73aee6bacbfc736fffc631be2605ed53 /pkg/tcpip/stack
parent f70210e742919f40aa2f0934a22f1c9ba6dada62 (diff)
Check in gVisor.
PiperOrigin-RevId: 194583126 Change-Id: Ica1d8821a90f74e7e745962d71801c598c652463
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- pkg/tcpip/stack/BUILD 70
-rw-r--r-- pkg/tcpip/stack/linkaddrcache.go 313
-rw-r--r-- pkg/tcpip/stack/linkaddrcache_test.go 256
-rw-r--r-- pkg/tcpip/stack/nic.go 453
-rw-r--r-- pkg/tcpip/stack/registration.go 322
-rw-r--r-- pkg/tcpip/stack/route.go 133
-rw-r--r-- pkg/tcpip/stack/stack.go 811
-rw-r--r-- pkg/tcpip/stack/stack_global_state.go 9
-rw-r--r-- pkg/tcpip/stack/stack_test.go 760
-rw-r--r-- pkg/tcpip/stack/transport_demuxer.go 166
-rw-r--r-- pkg/tcpip/stack/transport_test.go 420
11 files changed, 3713 insertions, 0 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD
new file mode 100644
index 000000000..079ade2c8
--- /dev/null
+++ b/pkg/tcpip/stack/BUILD
@@ -0,0 +1,70 @@
+package(licenses = ["notice"]) # BSD
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+# Produces stack_state.go from the listed sources. The generated file is
+# listed in the srcs of the go_library below.
+# NOTE(review): presumably this is save/restore support code -- confirm
+# against //tools/go_stateify.
+go_stateify(
+ name = "stack_state",
+ srcs = [
+ "registration.go",
+ "stack.go",
+ ],
+ out = "stack_state.go",
+ package = "stack",
+)
+
+# The stack package itself, including the generated stack_state.go.
+go_library(
+ name = "stack",
+ srcs = [
+ "linkaddrcache.go",
+ "nic.go",
+ "registration.go",
+ "route.go",
+ "stack.go",
+ "stack_global_state.go",
+ "stack_state.go",
+ "transport_demuxer.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/stack",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [
+ "//pkg/ilist",
+ "//pkg/sleep",
+ "//pkg/state",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/ports",
+ "//pkg/tcpip/seqnum",
+ "//pkg/waiter",
+ ],
+)
+
+# Tests that depend on ":stack" as an ordinary dependency rather than
+# embedding it.
+# NOTE(review): presumably these files use an external test package -- confirm.
+go_test(
+ name = "stack_x_test",
+ size = "small",
+ srcs = [
+ "stack_test.go",
+ "transport_test.go",
+ ],
+ deps = [
+ ":stack",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/link/channel",
+ "//pkg/waiter",
+ ],
+)
+
+# Tests that embed ":stack": linkaddrcache_test.go is declared in package
+# stack and exercises unexported identifiers such as linkAddrCache.
+go_test(
+ name = "stack_test",
+ size = "small",
+ srcs = ["linkaddrcache_test.go"],
+ embed = [":stack"],
+ deps = [
+ "//pkg/sleep",
+ "//pkg/tcpip",
+ ],
+)
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
new file mode 100644
index 000000000..789f97882
--- /dev/null
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -0,0 +1,313 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "fmt"
+ "sync"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+const linkAddrCacheSize = 512 // max cache entries
+
+// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses.
+//
+// The entries are stored in a ring buffer, oldest entry replaced first.
+//
+// This struct is safe for concurrent use.
+type linkAddrCache struct {
+ // ageLimit is how long a cache entry is valid for.
+ ageLimit time.Duration
+
+ // resolutionTimeout is the amount of time to wait for a link request to
+ // resolve an address.
+ resolutionTimeout time.Duration
+
+ // resolutionAttempts is the number of times an address is attempted to be
+ // resolved before failing.
+ resolutionAttempts int
+
+ // mu is held by the cache's methods while they read or modify the
+ // fields below.
+ mu sync.Mutex
+ // cache indexes the live entries by address; its values point into
+ // entries.
+ cache map[tcpip.FullAddress]*linkAddrEntry
+ next int // array index of next available entry
+ // entries is the backing ring buffer; a slot is reused (and its previous
+ // occupant dropped from cache) once next wraps around to it.
+ entries [linkAddrCacheSize]linkAddrEntry
+}
+
+// entryState controls the state of a single entry in the cache.
+type entryState int
+
+const (
+ // incomplete means that there is an outstanding request to resolve the
+ // address. This is the initial state. It is also the zero value, so a
+ // never-used slot in the ring buffer reads as incomplete.
+ incomplete entryState = iota
+ // ready means that the address has been resolved and can be used.
+ ready
+ // failed means that address resolution timed out and the address
+ // could not be resolved.
+ failed
+ // expired means that the cache entry has expired and the address must be
+ // resolved again. No transition out of this state is permitted.
+ expired
+)
+
+// String implements fmt.Stringer. It returns the lowercase name of a known
+// state, or a diagnostic string carrying the numeric value otherwise.
+func (s entryState) String() string {
+	names := [...]string{
+		incomplete: "incomplete",
+		ready:      "ready",
+		failed:     "failed",
+		expired:    "expired",
+	}
+	if s < 0 || int(s) >= len(names) {
+		return fmt.Sprintf("invalid entryState: %d", s)
+	}
+	return names[s]
+}
+
+// A linkAddrEntry is an entry in the linkAddrCache.
+// This struct is thread-compatible.
+type linkAddrEntry struct {
+ // addr is the key under which the entry is stored in the cache map.
+ addr tcpip.FullAddress
+ // linkAddr is the resolved link address; it is the empty string for an
+ // entry created to track an in-progress resolution.
+ linkAddr tcpip.LinkAddress
+ // expiration is the instant after which state() reports the entry as
+ // expired.
+ expiration time.Time
+ // s is the stored state; read it through state() so that expiry is
+ // taken into account.
+ s entryState
+
+ // wakers is a set of waiters for address resolution result. Anytime
+ // state transitions out of 'incomplete' these waiters are notified.
+ wakers map[*sleep.Waker]struct{}
+
+ // cancel receives a value when the entry's slot is taken over by a new
+ // entry; the resolution goroutine selects on it to stop retrying.
+ cancel chan struct{}
+}
+
+// state returns the entry's current state, first moving it to expired if
+// its expiration time has passed.
+func (e *linkAddrEntry) state() entryState {
+	if e.s == expired {
+		return e.s
+	}
+	if time.Now().After(e.expiration) {
+		// Go through changeState so that any waiters get notified.
+		e.changeState(expired)
+	}
+	return e.s
+}
+
+// changeState moves the entry to state ns. Requesting the current state is a
+// no-op. It panics on a transition that is not permitted: ready and failed
+// may only move to expired, and expired is terminal.
+func (e *linkAddrEntry) changeState(ns entryState) {
+	prev := e.s
+	if prev == ns {
+		return
+	}
+
+	switch prev {
+	case incomplete:
+		// Every target state is acceptable. Leaving 'incomplete' is the
+		// moment at which everyone waiting on resolution is notified.
+		for w := range e.wakers {
+			w.Assert()
+		}
+		e.wakers = nil
+	case ready, failed:
+		if ns != expired {
+			panic(fmt.Sprintf("invalid state transition from %v to %v", prev, ns))
+		}
+	case expired:
+		// Terminal state.
+		panic(fmt.Sprintf("invalid state transition from %v to %v", prev, ns))
+	default:
+		panic(fmt.Sprintf("invalid state: %v", prev))
+	}
+
+	e.s = ns
+}
+
+// addWaker registers w to be asserted when the entry leaves the
+// 'incomplete' state. A nil waker is ignored: callers that only poll the
+// cache pass nil, and storing it would make changeState call Assert on a
+// nil *sleep.Waker.
+func (e *linkAddrEntry) addWaker(w *sleep.Waker) {
+	if w == nil {
+		return
+	}
+	e.wakers[w] = struct{}{}
+}
+
+// removeWaker drops w from the set of waiters. It is safe to call after the
+// entry has left 'incomplete' (when wakers has been reset to nil), because
+// deleting from a nil map is a no-op.
+func (e *linkAddrEntry) removeWaker(w *sleep.Waker) {
+ delete(e.wakers, w)
+}
+
+// add records that k resolves to v and marks the resulting entry ready.
+//
+// An unexpired entry that already holds v is left untouched. An entry that
+// is still waiting for resolution is completed in place; in every other case
+// a fresh entry takes over the mapping.
+func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	existing := c.cache[k]
+	if existing == nil {
+		c.makeAndAddEntry(k, v).changeState(ready)
+		return
+	}
+
+	st := existing.state()
+	if st != expired && existing.linkAddr == v {
+		// Repeated call with the same mapping; nothing to do.
+		return
+	}
+
+	target := existing
+	if st == incomplete {
+		// Resolution was pending on this entry: fill in the answer.
+		target.linkAddr = v
+	} else {
+		// Stale or different mapping: replace it with a new entry.
+		target = c.makeAndAddEntry(k, v)
+	}
+	target.changeState(ready)
+}
+
+// makeAndAddEntry claims the next slot of the ring buffer for a new k -> v
+// entry, evicting whatever the slot held before, and returns the new entry.
+// The new entry starts in the 'incomplete' state (the zero entryState).
+func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry {
+	slot := &c.entries[c.next]
+	c.next = (c.next + 1) % len(c.entries)
+
+	// Drop the map reference to the old occupant, but only if the map still
+	// points at this very slot.
+	if old := slot.addr; c.cache[old] == slot {
+		delete(c.cache, old)
+	}
+
+	// Expire the old occupant in case someone is still waiting for address
+	// resolution on it, and signal its cancel channel if it has one.
+	slot.changeState(expired)
+	if ch := slot.cancel; ch != nil {
+		ch <- struct{}{}
+	}
+
+	*slot = linkAddrEntry{
+		addr:       k,
+		linkAddr:   v,
+		expiration: time.Now().Add(c.ageLimit),
+		wakers:     make(map[*sleep.Waker]struct{}),
+		cancel:     make(chan struct{}, 1),
+	}
+	c.cache[k] = slot
+	return slot
+}
+
+// get reports any known link address for k.
+//
+// A static mapping offered by linkRes is returned immediately without
+// consulting the cache. Otherwise:
+//   - a ready entry yields its link address;
+//   - a failed entry yields tcpip.ErrNoLinkAddress;
+//   - an incomplete entry registers waker and yields tcpip.ErrWouldBlock;
+//   - a missing or expired entry yields tcpip.ErrNoLinkAddress when linkRes
+//     is nil, and otherwise starts address resolution and yields
+//     tcpip.ErrWouldBlock.
+func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) {
+	if linkRes != nil {
+		if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok {
+			return addr, nil
+		}
+	}
+
+	c.mu.Lock()
+	entry := c.cache[k]
+
+	// Evaluate the state exactly once. state() is time-dependent, so asking
+	// twice could report an incomplete or failed entry as expired on the
+	// second call, and an expired entry must never be served as if it were
+	// ready (its linkAddr may be empty). A missing entry is treated like an
+	// expired one.
+	s := expired
+	if entry != nil {
+		s = entry.state()
+	}
+
+	if s == expired {
+		// startAddressResolution takes the lock itself.
+		c.mu.Unlock()
+		if linkRes == nil {
+			return "", tcpip.ErrNoLinkAddress
+		}
+		c.startAddressResolution(k, linkRes, localAddr, linkEP, waker)
+		return "", tcpip.ErrWouldBlock
+	}
+	defer c.mu.Unlock()
+
+	switch s {
+	case ready:
+		return entry.linkAddr, nil
+	case failed:
+		return "", tcpip.ErrNoLinkAddress
+	case incomplete:
+		// Address resolution is still in progress.
+		entry.addWaker(waker)
+		return "", tcpip.ErrWouldBlock
+	default:
+		panic(fmt.Sprintf("invalid cache entry state: %d", s))
+	}
+}
+
+// removeWaker unregisters a waker that an earlier get() call attached to the
+// entry for k. It does nothing if k has no entry in the cache.
+func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	entry := c.cache[k]
+	if entry == nil {
+		return
+	}
+	entry.removeWaker(waker)
+}
+
+// startAddressResolution begins resolving k unless an unexpired entry for it
+// already exists. It inserts an 'incomplete' entry, registers waker on it and
+// starts a goroutine that sends a link address request through linkRes, then
+// waits resolutionTimeout before checking the outcome, repeating until
+// checkLinkRequest says to stop or the entry is cancelled.
+func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	// Look up again with lock held to ensure entry wasn't added by someone else.
+	if e := c.cache[k]; e != nil && e.state() != expired {
+		return
+	}
+
+	// Add 'incomplete' entry in the cache to mark that resolution is in progress.
+	e := c.makeAndAddEntry(k, "")
+	e.addWaker(waker)
+
+	// Capture this entry's cancel channel now, while the lock is held. The
+	// entry lives in a ring-buffer slot that may later be reused for another
+	// address; reading e.cancel again from the goroutine could then pick up
+	// the new occupant's channel and consume a signal meant for it.
+	cancel := e.cancel
+
+	go func() { // S/R-FIXME
+		for i := 0; ; i++ {
+			// Send link request, then wait for the timeout limit and check
+			// whether the request succeeded. The request's return value is
+			// not inspected: a request that could not be sent is handled
+			// like one that got no answer, via the timeout below.
+			linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP)
+
+			select {
+			case <-time.After(c.resolutionTimeout):
+				if stop := c.checkLinkRequest(k, i); stop {
+					return
+				}
+			case <-cancel:
+				return
+			}
+		}
+	}()
+}
+
+// checkLinkRequest inspects the entry for k after a resolution attempt has
+// timed out. attempt is the zero-based index of the attempt just made. It
+// returns true when the caller should stop sending requests and false when
+// another request should be sent. When the entry is still incomplete and the
+// configured number of attempts has been used up, the entry is marked failed.
+func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool {
+	c.mu.Lock()
+	defer c.mu.Unlock()
+
+	entry, found := c.cache[k]
+	if !found {
+		// Entry was evicted from the cache.
+		return true
+	}
+
+	switch s := entry.state(); s {
+	case incomplete:
+		if attempt+1 < c.resolutionAttempts {
+			// No response yet, need to send another ARP request.
+			return false
+		}
+		// Max number of retries reached, mark entry as failed.
+		entry.changeState(failed)
+		return true
+	case ready, failed, expired:
+		// Entry was made ready by resolver or failed. Either way we're done.
+		return true
+	default:
+		panic(fmt.Sprintf("invalid cache entry state: %d", s))
+	}
+}
+
+// newLinkAddrCache returns an empty cache whose entries stay valid for
+// ageLimit and whose resolutions are tried resolutionAttempts times, waiting
+// resolutionTimeout after each request.
+func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
+	c := new(linkAddrCache)
+	c.ageLimit = ageLimit
+	c.resolutionTimeout = resolutionTimeout
+	c.resolutionAttempts = resolutionAttempts
+	c.cache = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize)
+	return c
+}
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
new file mode 100644
index 000000000..e9897b2bd
--- /dev/null
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -0,0 +1,256 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "fmt"
+ "sync"
+ "testing"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+)
+
+// testaddr pairs a network address with the link address the fake resolver
+// reports for it.
+type testaddr struct {
+ addr tcpip.FullAddress
+ linkAddr tcpip.LinkAddress
+}
+
+// testaddrs is the fixture table shared by the tests; it is filled in by
+// init().
+var testaddrs []testaddr
+
+// testLinkAddressResolver is a fake resolver that answers requests by adding
+// the matching testaddrs entry to cache.
+type testLinkAddressResolver struct {
+ // cache is the cache that receives resolved addresses.
+ cache *linkAddrCache
+ // delay, when positive, is slept before each request is answered.
+ delay time.Duration
+}
+
+// LinkAddressRequest answers the request for addr asynchronously, after the
+// configured delay if there is one. It always returns nil.
+func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error {
+	respond := func() {
+		if r.delay > 0 {
+			time.Sleep(r.delay)
+		}
+		r.fakeRequest(addr)
+	}
+	go respond()
+	return nil
+}
+
+// fakeRequest adds the first testaddrs entry whose address equals addr to
+// the cache. Unknown addresses are silently ignored.
+func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) {
+	for i := range testaddrs {
+		candidate := testaddrs[i]
+		if candidate.addr.Addr != addr {
+			continue
+		}
+		r.cache.add(candidate.addr, candidate.linkAddr)
+		return
+	}
+}
+
+// ResolveStaticAddress maps the address "broadcast" to "mac_broadcast" and
+// reports every other address as having no static mapping.
+func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) {
+	if addr != "broadcast" {
+		return "", false
+	}
+	return "mac_broadcast", true
+}
+
+// LinkAddressProtocol returns a fixed protocol number of 1.
+func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber {
+ return 1
+}
+
+func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) {
+ w := sleep.Waker{}
+ s := sleep.Sleeper{}
+ s.AddWaker(&w, 123)
+ defer s.Done()
+
+ for {
+ if got, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock {
+ return got, err
+ }
+ s.Fetch(true)
+ }
+}
+
+func init() {
+ for i := 0; i < 4*linkAddrCacheSize; i++ {
+ addr := fmt.Sprintf("Addr%06d", i)
+ testaddrs = append(testaddrs, testaddr{
+ addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)},
+ linkAddr: tcpip.LinkAddress("Link" + addr),
+ })
+ }
+}
+
+// TestCacheOverflow inserts more addresses than the cache can hold (highest
+// index first) and checks that recently inserted ones are still served while
+// the earliest insertions have been evicted.
+func TestCacheOverflow(t *testing.T) {
+	c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+
+	// Every insertion must be readable straight away.
+	for i := len(testaddrs) - 1; i >= 0; i-- {
+		e := testaddrs[i]
+		c.add(e.addr, e.linkAddr)
+		got, err := c.get(e.addr, nil, "", nil, nil)
+		if err != nil {
+			t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+		}
+		if got != e.linkAddr {
+			t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+		}
+	}
+
+	// Expect to find at least half of the most recent entries.
+	for i, e := range testaddrs[:linkAddrCacheSize/2] {
+		got, err := c.get(e.addr, nil, "", nil, nil)
+		if err != nil {
+			t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err)
+		}
+		if got != e.linkAddr {
+			t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr)
+		}
+	}
+
+	// The earliest entries should no longer be in the cache.
+	lowest := len(testaddrs) - linkAddrCacheSize
+	for i := len(testaddrs) - 1; i >= lowest; i-- {
+		e := testaddrs[i]
+		_, err := c.get(e.addr, nil, "", nil, nil)
+		if err != tcpip.ErrNoLinkAddress {
+			t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err)
+		}
+	}
+}
+
+// TestCacheConcurrent has 16 goroutines add and read every test address
+// concurrently, then checks that the last address is cached and the first
+// one has been evicted.
+func TestCacheConcurrent(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+
+ var wg sync.WaitGroup
+ for r := 0; r < 16; r++ {
+ wg.Add(1)
+ go func() {
+ for _, e := range testaddrs {
+ c.add(e.addr, e.linkAddr)
+ c.get(e.addr, nil, "", nil, nil) // make work for gotsan
+ }
+ wg.Done()
+ }()
+ }
+ wg.Wait()
+
+ // All goroutines add in the same order and add more values than
+ // can fit in the cache, so our eviction strategy requires that
+ // the last entry be present and the first be missing.
+ e := testaddrs[len(testaddrs)-1]
+ got, err := c.get(e.addr, nil, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ }
+
+ e = testaddrs[0]
+ if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ }
+}
+
+// TestCacheAgeLimit checks that an entry older than the cache's age limit is
+// no longer served.
+func TestCacheAgeLimit(t *testing.T) {
+ c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+ e := testaddrs[0]
+ c.add(e.addr, e.linkAddr)
+ // Sleep well past the 1ms age limit so the entry is expired on lookup.
+ time.Sleep(50 * time.Millisecond)
+ if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ }
+}
+
+// TestCacheReplace checks that adding a different link address for an
+// already cached address replaces the earlier mapping.
+func TestCacheReplace(t *testing.T) {
+	c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+	e := testaddrs[0]
+	l2 := e.linkAddr + "2"
+
+	// Add the original mapping, then the replacement; each must be visible
+	// immediately after it is added.
+	for _, want := range []tcpip.LinkAddress{e.linkAddr, l2} {
+		c.add(e.addr, want)
+		got, err := c.get(e.addr, nil, "", nil, nil)
+		if err != nil {
+			t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+		}
+		if got != want {
+			t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, want)
+		}
+	}
+}
+
+// TestCacheResolution resolves every test address through the fake resolver
+// and then checks that a resolved address keeps being served from the cache.
+func TestCacheResolution(t *testing.T) {
+	c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1)
+	linkRes := &testLinkAddressResolver{cache: c}
+
+	for i := range testaddrs {
+		ta := testaddrs[i]
+		got, err := getBlocking(c, ta.addr, linkRes)
+		if err != nil {
+			t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err)
+		}
+		if got != ta.linkAddr {
+			t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr)
+		}
+	}
+
+	// Check that after resolved, address stays in the cache and never returns WouldBlock.
+	last := testaddrs[len(testaddrs)-1]
+	for attempt := 0; attempt < 10; attempt++ {
+		got, err := c.get(last.addr, linkRes, "", nil, nil)
+		if err != nil {
+			t.Errorf("c.get(%q)=%q, got error: %v", string(last.addr.Addr), got, err)
+		}
+		if got != last.linkAddr {
+			t.Errorf("c.get(%q)=%q, want %q", string(last.addr.Addr), got, last.linkAddr)
+		}
+	}
+}
+
+// TestCacheResolutionFailed checks that resolving an address the fake
+// resolver does not know ends with ErrNoLinkAddress once the retries are
+// used up.
+func TestCacheResolutionFailed(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
+ linkRes := &testLinkAddressResolver{cache: c}
+
+ // First, sanity check that resolution is working...
+ e := testaddrs[0]
+ got, err := getBlocking(c, e.addr, linkRes)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err)
+ }
+ if got != e.linkAddr {
+ t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr)
+ }
+
+ // Appending "2" yields an address that is not in testaddrs, so the fake
+ // resolver never answers for it.
+ e.addr.Addr += "2"
+ if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ }
+}
+
+// TestCacheResolutionTimeout checks that resolution reports ErrNoLinkAddress
+// when the resolver answers more slowly than the retries and the entry
+// lifetime allow.
+func TestCacheResolutionTimeout(t *testing.T) {
+ resolverDelay := 50 * time.Millisecond
+ // Entries expire after half the resolver's delay, and each of the three
+ // attempts waits only 1ms.
+ expiration := resolverDelay / 2
+ c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
+ linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
+
+ e := testaddrs[0]
+ if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress {
+ t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err)
+ }
+}
+
+// TestStaticResolution checks that static link addresses are resolved immediately and don't
+// send resolution requests.
+func TestStaticResolution(t *testing.T) {
+ c := newLinkAddrCache(1<<63-1, time.Millisecond, 1)
+ // The resolver's dynamic path is delayed by a minute, so only the static
+ // mapping can produce an answer without blocking.
+ linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute}
+
+ addr := tcpip.Address("broadcast")
+ want := tcpip.LinkAddress("mac_broadcast")
+ got, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil)
+ if err != nil {
+ t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err)
+ }
+ if got != want {
+ t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want))
+ }
+}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
new file mode 100644
index 000000000..8ff4310d5
--- /dev/null
+++ b/pkg/tcpip/stack/nic.go
@@ -0,0 +1,453 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "strings"
+ "sync"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/ilist"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// NIC represents a "network interface card" to which the networking stack is
+// attached.
+type NIC struct {
+ stack *Stack
+ id tcpip.NICID
+ name string
+ linkEP LinkEndpoint
+
+ // demux is this NIC's own transport demuxer; DeliverTransportPacket
+ // consults it before the stack-wide one.
+ demux *transportDemuxer
+
+ // mu is held by the NIC's methods while they read or modify the fields
+ // below.
+ mu sync.RWMutex
+ spoofing bool
+ promiscuous bool
+ // primary lists, per network protocol, the endpoints in the order they
+ // were added.
+ primary map[tcpip.NetworkProtocolNumber]*ilist.List
+ // endpoints indexes the network endpoints by their ID.
+ endpoints map[NetworkEndpointID]*referencedNetworkEndpoint
+ subnets []tcpip.Subnet
+}
+
+// newNIC creates a NIC with the given id and name on top of link endpoint
+// ep. The NIC gets its own transport demuxer and starts with no addresses.
+func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC {
+	n := &NIC{stack: stack, id: id, name: name, linkEP: ep}
+	n.demux = newTransportDemuxer(stack)
+	n.primary = make(map[tcpip.NetworkProtocolNumber]*ilist.List)
+	n.endpoints = make(map[NetworkEndpointID]*referencedNetworkEndpoint)
+	return n
+}
+
+// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it
+// to start delivering packets. The NIC passes itself to the link endpoint's
+// Attach method.
+func (n *NIC) attachLinkEndpoint() {
+ n.linkEP.Attach(n)
+}
+
+// setPromiscuousMode turns promiscuous mode on or off for this NIC.
+func (n *NIC) setPromiscuousMode(enable bool) {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	n.promiscuous = enable
+}
+
+// setSpoofing turns address spoofing on or off for this NIC.
+func (n *NIC) setSpoofing(enable bool) {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	n.spoofing = enable
+}
+
+// primaryEndpoint returns the first endpoint registered for the given
+// network protocol on which a reference can still be taken, or nil if there
+// is none. The returned endpoint carries a reference owned by the caller.
+func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint {
+	n.mu.RLock()
+	defer n.mu.RUnlock()
+
+	if list := n.primary[protocol]; list != nil {
+		for e := list.Front(); e != nil; e = e.Next() {
+			if r := e.(*referencedNetworkEndpoint); r.tryIncRef() {
+				return r
+			}
+		}
+	}
+	return nil
+}
+
+// findEndpoint returns the endpoint with the given address, holding a new
+// reference, or nil if there is none. When spoofing is enabled and no live
+// endpoint exists, a temporary endpoint is created for the address.
+func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address) *referencedNetworkEndpoint {
+	id := NetworkEndpointID{address}
+
+	// Fast path: look the endpoint up under the read lock.
+	n.mu.RLock()
+	ref, spoofing := n.endpoints[id], n.spoofing
+	if ref != nil && !ref.tryIncRef() {
+		ref = nil
+	}
+	n.mu.RUnlock()
+
+	if ref != nil || !spoofing {
+		return ref
+	}
+
+	// Try again with the lock in exclusive mode. If we still can't get the
+	// endpoint, create a new "temporary" endpoint. It will only exist while
+	// there's a route through it.
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	if existing := n.endpoints[id]; existing != nil && existing.tryIncRef() {
+		return existing
+	}
+
+	temp, _ := n.addAddressLocked(protocol, address, true)
+	if temp != nil {
+		// A temporary endpoint is not kept alive by an insertion reference.
+		temp.holdsInsertRef = false
+	}
+	return temp
+}
+
+// addAddressLocked creates a network endpoint for addr and registers it with
+// the NIC. The caller must hold n.mu for writing.
+//
+// If an endpoint with the same ID is already registered, it is removed and
+// replaced when replace is true; otherwise tcpip.ErrDuplicateAddress is
+// returned. tcpip.ErrUnknownProtocol is returned when the stack has no such
+// network protocol. The returned reference starts with a count of one and
+// holdsInsertRef set.
+func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) {
+	netProto, ok := n.stack.networkProtocols[protocol]
+	if !ok {
+		return nil, tcpip.ErrUnknownProtocol
+	}
+
+	// Create the new network endpoint.
+	ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP)
+	if err != nil {
+		return nil, err
+	}
+
+	id := *ep.ID()
+	if ref, ok := n.endpoints[id]; ok {
+		if !replace {
+			// The endpoint created above will never be installed; close it
+			// so that whatever it acquired is released instead of leaked.
+			ep.Close()
+			return nil, tcpip.ErrDuplicateAddress
+		}
+
+		n.removeEndpointLocked(ref)
+	}
+
+	ref := &referencedNetworkEndpoint{
+		refs:           1,
+		ep:             ep,
+		nic:            n,
+		protocol:       protocol,
+		holdsInsertRef: true,
+	}
+
+	// Set up cache if link address resolution exists for this protocol.
+	if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 {
+		if linkRes := n.stack.linkAddrResolvers[protocol]; linkRes != nil {
+			ref.linkCache = n.stack
+		}
+	}
+
+	n.endpoints[id] = ref
+
+	l, ok := n.primary[protocol]
+	if !ok {
+		l = &ilist.List{}
+		n.primary[protocol] = l
+	}
+
+	l.PushBack(ref)
+
+	return ref, nil
+}
+
+// AddAddress registers addr on n for the given network protocol, so that n
+// starts accepting packets targeted at it. An address that is already
+// registered is not replaced; the error from the registration is returned.
+func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	_, err := n.addAddressLocked(protocol, addr, false)
+	return err
+}
+
+// Addresses returns one protocol/address pair for every endpoint currently
+// registered on this NIC, in no particular order.
+func (n *NIC) Addresses() []tcpip.ProtocolAddress {
+	n.mu.RLock()
+	defer n.mu.RUnlock()
+
+	result := make([]tcpip.ProtocolAddress, 0, len(n.endpoints))
+	for id, ref := range n.endpoints {
+		pa := tcpip.ProtocolAddress{
+			Protocol: ref.protocol,
+			Address:  id.LocalAddress,
+		}
+		result = append(result, pa)
+	}
+	return result
+}
+
+// AddSubnet appends subnet to the subnets of n, so that n starts accepting
+// packets whose destination falls inside it. The protocol argument is not
+// used by this method.
+func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	n.subnets = append(n.subnets, subnet)
+}
+
+// Subnets returns the subnets associated with this NIC: a single-address
+// subnet (all-ones mask) for each registered endpoint, followed by the
+// subnets added through AddSubnet.
+func (n *NIC) Subnets() []tcpip.Subnet {
+	n.mu.RLock()
+	defer n.mu.RUnlock()
+
+	result := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints))
+	for id := range n.endpoints {
+		addr := id.LocalAddress
+		mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr)))
+		sn, err := tcpip.NewSubnet(addr, mask)
+		if err != nil {
+			// This should never happen as the mask has been carefully
+			// crafted to match the address.
+			panic("Invalid endpoint subnet: " + err.Error())
+		}
+		result = append(result, sn)
+	}
+	result = append(result, n.subnets...)
+	return result
+}
+
+// removeEndpointLocked unregisters r from the NIC and closes its endpoint.
+// The caller must hold n.mu for writing. It panics if r still holds its
+// insertion reference.
+func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) {
+ id := *r.ep.ID()
+
+ // Nothing to do if the reference has already been replaced with a
+ // different one.
+ if n.endpoints[id] != r {
+ return
+ }
+
+ // The insertion reference is only given up by RemoveAddress (or never
+ // held by temporary endpoints), so it must be gone by the time the
+ // endpoint is torn down.
+ if r.holdsInsertRef {
+ panic("Reference count dropped to zero before being removed")
+ }
+
+ delete(n.endpoints, id)
+ n.primary[r.protocol].Remove(r)
+ r.ep.Close()
+}
+
+// removeEndpoint takes the NIC's write lock and unregisters r.
+func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+	n.removeEndpointLocked(r)
+}
+
+// RemoveAddress removes an address from n. It returns
+// tcpip.ErrBadLocalAddress if addr is not registered or its insertion
+// reference has already been given up. The endpoint itself is torn down once
+// its reference count reaches zero.
+func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error {
+ n.mu.Lock()
+ r := n.endpoints[NetworkEndpointID{addr}]
+ if r == nil || !r.holdsInsertRef {
+ n.mu.Unlock()
+ return tcpip.ErrBadLocalAddress
+ }
+
+ r.holdsInsertRef = false
+ n.mu.Unlock()
+
+ // Drop the insertion reference only after releasing n.mu: decRef calls
+ // removeEndpoint, which acquires n.mu itself, when the count hits zero.
+ r.decRef()
+
+ return nil
+}
+
+// DeliverNetworkPacket finds the appropriate network protocol endpoint and
+// hands the packet over for further processing. This function is called when
+// the NIC receives a packet from the physical interface.
+// Note that the ownership of the slice backing vv is retained by the caller.
+// This rule applies only to the slice itself, not to the items of the slice;
+// the ownership of the items is not retained by the caller.
+//
+// Packets for an unknown protocol, packets shorter than the protocol's
+// minimum size, and packets for which no endpoint can be found are dropped
+// after incrementing the corresponding stack statistic.
+func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) {
+ netProto, ok := n.stack.networkProtocols[protocol]
+ if !ok {
+ atomic.AddUint64(&n.stack.stats.UnknownProtocolRcvdPackets, 1)
+ return
+ }
+
+ if len(vv.First()) < netProto.MinimumPacketSize() {
+ atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1)
+ return
+ }
+
+ src, dst := netProto.ParseAddresses(vv.First())
+ id := NetworkEndpointID{dst}
+
+ // Look the destination endpoint up under the read lock, snapshotting the
+ // promiscuous flag and subnet list at the same time.
+ n.mu.RLock()
+ ref := n.endpoints[id]
+ if ref != nil && !ref.tryIncRef() {
+ ref = nil
+ }
+ promiscuous := n.promiscuous
+ subnets := n.subnets
+ n.mu.RUnlock()
+
+ if ref == nil {
+ // Check if the packet is for a subnet this NIC cares about.
+ if !promiscuous {
+ for _, sn := range subnets {
+ if sn.Contains(dst) {
+ // A matching subnet is handled exactly like promiscuous mode.
+ promiscuous = true
+ break
+ }
+ }
+ }
+ if promiscuous {
+ // Try again with the lock in exclusive mode. If we still can't
+ // get the endpoint, create a new "temporary" one. It will only
+ // exist while there's a route through it.
+ n.mu.Lock()
+ ref = n.endpoints[id]
+ if ref == nil || !ref.tryIncRef() {
+ ref, _ = n.addAddressLocked(protocol, dst, true)
+ if ref != nil {
+ ref.holdsInsertRef = false
+ }
+ }
+ n.mu.Unlock()
+ }
+ }
+
+ if ref == nil {
+ atomic.AddUint64(&n.stack.stats.UnknownNetworkEndpointRcvdPackets, 1)
+ return
+ }
+
+ // The route is built from the receiver's point of view: the packet's
+ // destination is passed before its source.
+ r := makeRoute(protocol, dst, src, ref)
+ r.LocalLinkAddress = linkEP.LinkAddress()
+ r.RemoteLinkAddress = remoteLinkAddr
+ ref.ep.HandlePacket(&r, vv)
+ // Release the reference taken (or created) above.
+ ref.decRef()
+}
+
+// DeliverTransportPacket hands a received transport-layer packet to the
+// matching transport endpoint. It tries, in order, the NIC's demuxer, the
+// stack's demuxer, the protocol's default handler (if any) and finally the
+// protocol's unknown-destination handler. Packets that cannot be parsed or
+// that nobody accepts are counted in the stack statistics.
+func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) {
+	stats := &n.stack.stats
+
+	state, ok := n.stack.transportProtocols[protocol]
+	if !ok {
+		atomic.AddUint64(&stats.UnknownProtocolRcvdPackets, 1)
+		return
+	}
+	transProto := state.proto
+
+	if len(vv.First()) < transProto.MinimumPacketSize() {
+		atomic.AddUint64(&stats.MalformedRcvdPackets, 1)
+		return
+	}
+
+	srcPort, dstPort, err := transProto.ParsePorts(vv.First())
+	if err != nil {
+		atomic.AddUint64(&stats.MalformedRcvdPackets, 1)
+		return
+	}
+
+	id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress}
+
+	// NIC-specific endpoints take precedence over stack-wide ones.
+	if n.demux.deliverPacket(r, protocol, vv, id) || n.stack.demux.deliverPacket(r, protocol, vv, id) {
+		return
+	}
+
+	// Try to deliver to per-stack default handler.
+	if h := state.defaultHandler; h != nil && h(r, id, vv) {
+		return
+	}
+
+	// We could not find an appropriate destination for this packet, so
+	// deliver it to the global handler.
+	if transProto.HandleUnknownDestinationPacket(r, id, vv) {
+		return
+	}
+	atomic.AddUint64(&stats.MalformedRcvdPackets, 1)
+}
+
+// DeliverTransportControlPacket hands a control packet to the matching
+// transport endpoint, trying the NIC's demuxer first and the stack's demuxer
+// second. Packets for an unknown transport protocol, or whose payload is too
+// short or cannot be parsed, are dropped silently.
+func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView) {
+	state, known := n.stack.transportProtocols[trans]
+	if !known {
+		return
+	}
+
+	// ICMPv4 only guarantees that 8 bytes of the transport protocol will
+	// be present in the payload. We know that the ports are within the
+	// first 8 bytes for all known transport protocols.
+	payload := vv.First()
+	if len(payload) < 8 {
+		return
+	}
+
+	srcPort, dstPort, err := state.proto.ParsePorts(payload)
+	if err != nil {
+		return
+	}
+
+	id := TransportEndpointID{srcPort, local, dstPort, remote}
+	if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) {
+		return
+	}
+	n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id)
+}
+
+// ID returns the identifier of n, as given to newNIC.
+func (n *NIC) ID() tcpip.NICID {
+ return n.id
+}
+
+// referencedNetworkEndpoint is a reference-counted network endpoint
+// registered on a NIC. It embeds ilist.Entry so that it can be linked into
+// the NIC's per-protocol primary list.
+type referencedNetworkEndpoint struct {
+ ilist.Entry
+ // refs is the reference count; it is manipulated atomically by incRef,
+ // tryIncRef and decRef.
+ refs int32
+ ep NetworkEndpoint
+ nic *NIC
+ protocol tcpip.NetworkProtocolNumber
+
+ // linkCache is set if link address resolution is enabled for this
+ // protocol. Set to nil otherwise.
+ linkCache LinkAddressCache
+
+ // holdsInsertRef is protected by the NIC's mutex. It indicates whether
+ // the reference count is biased by 1 due to the insertion of the
+ // endpoint. It is reset to false when RemoveAddress is called on the
+ // NIC.
+ holdsInsertRef bool
+}
+
+// decRef drops one reference. Whoever drops the last reference has the
+// endpoint removed from its NIC.
+func (r *referencedNetworkEndpoint) decRef() {
+	if atomic.AddInt32(&r.refs, -1) != 0 {
+		return
+	}
+	r.nic.removeEndpoint(r)
+}
+
+// incRef increments the ref count. It must only be called when the caller is
+// known to be holding a reference to the endpoint, otherwise tryIncRef should
+// be used. Unlike tryIncRef it does not check whether the count is zero.
+func (r *referencedNetworkEndpoint) incRef() {
+ atomic.AddInt32(&r.refs, 1)
+}
+
+// tryIncRef attempts to increment the ref count from n to n+1, but only if n
+// is not zero. That is, it will increment the count if the endpoint is still
+// alive, and do nothing if it has already been cleaned up. It reports whether
+// a reference was taken.
+func (r *referencedNetworkEndpoint) tryIncRef() bool {
+	// Retry the compare-and-swap until it succeeds or the count is seen to
+	// be zero.
+	for v := atomic.LoadInt32(&r.refs); v != 0; v = atomic.LoadInt32(&r.refs) {
+		if atomic.CompareAndSwapInt32(&r.refs, v, v+1) {
+			return true
+		}
+	}
+	return false
+}
diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go
new file mode 100644
index 000000000..e7e6381ac
--- /dev/null
+++ b/pkg/tcpip/stack/registration.go
@@ -0,0 +1,322 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// NetworkEndpointID is the identifier of a network layer protocol endpoint.
+// Currently the local address is sufficient because all supported protocols
+// (i.e., IPv4 and IPv6) have different sizes for their addresses.
+type NetworkEndpointID struct {
+ LocalAddress tcpip.Address
+}
+
+// TransportEndpointID is the identifier of a transport layer protocol endpoint.
+type TransportEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress is the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// ControlType is the type of network control message.
+type ControlType int
+
+// The following are the allowed values for ControlType values.
+const (
+ ControlPacketTooBig ControlType = iota
+ ControlPortUnreachable
+ ControlUnknown
+)
+
+// TransportEndpoint is the interface that needs to be implemented by transport
+// protocol (e.g., tcp, udp) endpoints that can handle packets.
+type TransportEndpoint interface {
+ // HandlePacket is called by the stack when new packets arrive to
+ // this transport endpoint.
+ HandlePacket(r *Route, id TransportEndpointID, vv *buffer.VectorisedView)
+
+ // HandleControlPacket is called by the stack when new control (e.g.,
+ // ICMP) packets arrive to this transport endpoint.
+ HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv *buffer.VectorisedView)
+}
+
+// TransportProtocol is the interface that needs to be implemented by transport
+// protocols (e.g., tcp, udp) that want to be part of the networking stack.
+type TransportProtocol interface {
+ // Number returns the transport protocol number.
+ Number() tcpip.TransportProtocolNumber
+
+ // NewEndpoint creates a new endpoint of the transport protocol.
+ NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error)
+
+ // MinimumPacketSize returns the minimum valid packet size of this
+ // transport protocol. The stack automatically drops any packets smaller
+ // than this targeted at this protocol.
+ MinimumPacketSize() int
+
+ // ParsePorts returns the source and destination ports stored in a
+ // packet of this protocol.
+ ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error)
+
+ // HandleUnknownDestinationPacket handles packets targeted at this
+ // protocol but that don't match any existing endpoint. For example,
+ // it is targeted at a port that has no listeners.
+ //
+ // The return value indicates whether the packet was well-formed (for
+ // stats purposes only).
+ HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv *buffer.VectorisedView) bool
+
+ // SetOption allows enabling/disabling protocol specific features.
+ // SetOption returns an error if the option is not supported or the
+ // provided option value is invalid.
+ SetOption(option interface{}) *tcpip.Error
+
+ // Option allows retrieving protocol specific option values.
+ // Option returns an error if the option is not supported or the
+ // provided option value is invalid.
+ Option(option interface{}) *tcpip.Error
+}
+
+// TransportDispatcher contains the methods used by the network stack to deliver
+// packets to the appropriate transport endpoint after it has been handled by
+// the network layer.
+type TransportDispatcher interface {
+ // DeliverTransportPacket delivers packets to the appropriate
+ // transport protocol endpoint.
+ DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView)
+
+ // DeliverTransportControlPacket delivers control packets to the
+ // appropriate transport protocol endpoint.
+ DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView)
+}
+
+// NetworkEndpoint is the interface that needs to be implemented by endpoints
+// of network layer protocols (e.g., ipv4, ipv6).
+type NetworkEndpoint interface {
+ // MTU is the maximum transmission unit for this endpoint. This is
+ // generally calculated as the MTU of the underlying data link endpoint
+ // minus the network endpoint max header length.
+ MTU() uint32
+
+ // Capabilities returns the set of capabilities supported by the
+ // underlying link-layer endpoint.
+ Capabilities() LinkEndpointCapabilities
+
+ // MaxHeaderLength returns the maximum size the network (and lower
+ // level layers combined) headers can have. Higher levels use this
+ // information to reserve space in the front of the packets they're
+ // building.
+ MaxHeaderLength() uint16
+
+ // WritePacket writes a packet to the given destination address and
+ // protocol.
+ WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error
+
+ // ID returns the network protocol endpoint ID.
+ ID() *NetworkEndpointID
+
+ // NICID returns the id of the NIC this endpoint belongs to.
+ NICID() tcpip.NICID
+
+ // HandlePacket is called by the link layer when new packets arrive to
+ // this network endpoint.
+ HandlePacket(r *Route, vv *buffer.VectorisedView)
+
+ // Close is called when the endpoint is removed from a stack.
+ Close()
+}
+
+// NetworkProtocol is the interface that needs to be implemented by network
+// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack.
+type NetworkProtocol interface {
+ // Number returns the network protocol number.
+ Number() tcpip.NetworkProtocolNumber
+
+ // MinimumPacketSize returns the minimum valid packet size of this
+ // network protocol. The stack automatically drops any packets smaller
+ // than this targeted at this protocol.
+ MinimumPacketSize() int
+
+ // ParseAddresses returns the source and destination addresses stored in a
+ // packet of this protocol.
+ ParseAddresses(v buffer.View) (src, dst tcpip.Address)
+
+ // NewEndpoint creates a new endpoint of this protocol.
+ NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error)
+
+ // SetOption allows enabling/disabling protocol specific features.
+ // SetOption returns an error if the option is not supported or the
+ // provided option value is invalid.
+ SetOption(option interface{}) *tcpip.Error
+
+ // Option allows retrieving protocol specific option values.
+ // Option returns an error if the option is not supported or the
+ // provided option value is invalid.
+ Option(option interface{}) *tcpip.Error
+}
+
+// NetworkDispatcher contains the methods used by the network stack to deliver
+// packets to the appropriate network endpoint after it has been handled by
+// the data link layer.
+type NetworkDispatcher interface {
+ // DeliverNetworkPacket finds the appropriate network protocol
+ // endpoint and hands the packet over for further processing.
+ DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView)
+}
+
+// LinkEndpointCapabilities is the type associated with the capabilities
+// supported by a link-layer endpoint. It is a set of bitfields.
+type LinkEndpointCapabilities uint
+
+// The following are the supported link endpoint capabilities.
+const (
+ CapabilityChecksumOffload LinkEndpointCapabilities = 1 << iota
+ CapabilityResolutionRequired
+)
+
+// LinkEndpoint is the interface implemented by data link layer protocols (e.g.,
+// ethernet, loopback, raw) and used by network layer protocols to send packets
+// out through the implementer's data link endpoint.
+type LinkEndpoint interface {
+ // MTU is the maximum transmission unit for this endpoint. This is
+ // usually dictated by the backing physical network; when such a
+ // physical network doesn't exist, the limit is generally 64k, which
+ // includes the maximum size of an IP packet.
+ MTU() uint32
+
+ // Capabilities returns the set of capabilities supported by the
+ // endpoint.
+ Capabilities() LinkEndpointCapabilities
+
+ // MaxHeaderLength returns the maximum size the data link (and
+ // lower level layers combined) headers can have. Higher levels use this
+ // information to reserve space in the front of the packets they're
+ // building.
+ MaxHeaderLength() uint16
+
+ // LinkAddress returns the link address (typically a MAC) of the
+ // link endpoint.
+ LinkAddress() tcpip.LinkAddress
+
+ // WritePacket writes a packet with the given protocol through the given
+ // route.
+ WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error
+
+ // Attach attaches the data link layer endpoint to the network-layer
+ // dispatcher of the stack.
+ Attach(dispatcher NetworkDispatcher)
+}
+
+// A LinkAddressResolver is an extension to a NetworkProtocol that
+// can resolve link addresses.
+type LinkAddressResolver interface {
+ // LinkAddressRequest sends a request for the LinkAddress of addr.
+ // The request is sent on linkEP with localAddr as the source.
+ //
+ // A valid response will cause the discovery protocol's network
+ // endpoint to call AddLinkAddress.
+ LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error
+
+ // ResolveStaticAddress attempts to resolve address without sending
+ // requests. It either resolves the name immediately or returns the
+ // empty LinkAddress.
+ //
+ // It can be used to resolve broadcast addresses for example.
+ ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool)
+
+ // LinkAddressProtocol returns the network protocol of the
+ // addresses this resolver can resolve.
+ LinkAddressProtocol() tcpip.NetworkProtocolNumber
+}
+
+// A LinkAddressCache caches link addresses.
+type LinkAddressCache interface {
+ // CheckLocalAddress determines if the given local address exists; if it
+ // does, it presumably returns the owning NIC's ID (TODO: confirm).
+ CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID
+
+ // AddLinkAddress adds a link address to the cache.
+ AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress)
+
+ // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC).
+ // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver
+ // registered with the network protocol, the cache attempts to resolve the address
+ // and returns ErrWouldBlock. Waker is notified when address resolution is
+ // complete (success or not).
+ GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error)
+
+ // RemoveWaker removes a waker that has been added in GetLinkAddress().
+ RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker)
+}
+
+// TransportProtocolFactory functions are used by the stack to instantiate
+// transport protocols.
+type TransportProtocolFactory func() TransportProtocol
+
+// NetworkProtocolFactory provides methods to be used by the stack to
+// instantiate network protocols.
+type NetworkProtocolFactory func() NetworkProtocol
+
+var (
+ transportProtocols = make(map[string]TransportProtocolFactory)
+ networkProtocols = make(map[string]NetworkProtocolFactory)
+
+ linkEPMu sync.RWMutex
+ nextLinkEndpointID tcpip.LinkEndpointID = 1
+ linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint)
+)
+
+// RegisterTransportProtocolFactory registers a new transport protocol factory
+// with the stack so that it becomes available to users of the stack. This
+// function is intended to be called by init() functions of the protocols.
+func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) {
+ transportProtocols[name] = p
+}
+
+// RegisterNetworkProtocolFactory registers a new network protocol factory with
+// the stack so that it becomes available to users of the stack. This function
+// is intended to be called by init() functions of the protocols.
+func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) {
+ networkProtocols[name] = p
+}
+
+// RegisterLinkEndpoint registers a link-layer protocol endpoint and returns an
+// ID that can be used to refer to it.
+func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID {
+ linkEPMu.Lock()
+ defer linkEPMu.Unlock()
+
+ v := nextLinkEndpointID
+ nextLinkEndpointID++
+
+ linkEndpoints[v] = linkEP
+
+ return v
+}
+
+// FindLinkEndpoint finds the link endpoint associated with the given ID.
+func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint {
+ linkEPMu.RLock()
+ defer linkEPMu.RUnlock()
+
+ return linkEndpoints[id]
+}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
new file mode 100644
index 000000000..12f5efba5
--- /dev/null
+++ b/pkg/tcpip/stack/route.go
@@ -0,0 +1,133 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+)
+
+// Route represents a route through the networking stack to a given destination.
+type Route struct {
+ // RemoteAddress is the final destination of the route.
+ RemoteAddress tcpip.Address
+
+ // RemoteLinkAddress is the link-layer (MAC) address of the
+ // final destination of the route.
+ RemoteLinkAddress tcpip.LinkAddress
+
+ // LocalAddress is the local address where the route starts.
+ LocalAddress tcpip.Address
+
+ // LocalLinkAddress is the link-layer (MAC) address of the
+ // point where the route starts.
+ LocalLinkAddress tcpip.LinkAddress
+
+ // NextHop is the next node in the path to the destination.
+ NextHop tcpip.Address
+
+ // NetProto is the network-layer protocol.
+ NetProto tcpip.NetworkProtocolNumber
+
+ // ref is a reference to the network endpoint through which the route
+ // starts.
+ ref *referencedNetworkEndpoint
+}
+
+// makeRoute initializes a new route. It takes ownership of the provided
+// reference to a network endpoint.
+func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, ref *referencedNetworkEndpoint) Route {
+ return Route{
+ NetProto: netProto,
+ LocalAddress: localAddr,
+ RemoteAddress: remoteAddr,
+ ref: ref,
+ }
+}
+
+// NICID returns the id of the NIC from which this route originates.
+func (r *Route) NICID() tcpip.NICID {
+ return r.ref.ep.NICID()
+}
+
+// MaxHeaderLength forwards the call to the network endpoint's implementation.
+func (r *Route) MaxHeaderLength() uint16 {
+ return r.ref.ep.MaxHeaderLength()
+}
+
+// PseudoHeaderChecksum computes the pseudo-header checksum for the given
+// protocol using the route's local and remote addresses.
+func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber) uint16 {
+ return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress)
+}
+
+// Capabilities returns the link-layer capabilities of the route.
+func (r *Route) Capabilities() LinkEndpointCapabilities {
+ return r.ref.ep.Capabilities()
+}
+
+// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in
+// case address resolution requires blocking, e.g. wait for ARP reply. Waker is
+// notified when address resolution is complete (success or not).
+func (r *Route) Resolve(waker *sleep.Waker) *tcpip.Error {
+ if !r.IsResolutionRequired() {
+ // Nothing to do if there is no cache (which does the resolution on cache miss) or
+ // link address is already known.
+ return nil
+ }
+
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ nextAddr = r.RemoteAddress
+ }
+ linkAddr, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker)
+ if err != nil {
+ return err
+ }
+ r.RemoteLinkAddress = linkAddr
+ return nil
+}
+
+// RemoveWaker removes a waker that has been added in Resolve().
+func (r *Route) RemoveWaker(waker *sleep.Waker) {
+ nextAddr := r.NextHop
+ if nextAddr == "" {
+ nextAddr = r.RemoteAddress
+ }
+ r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker)
+}
+
+// IsResolutionRequired returns true if Resolve() must be called to resolve
+// the link address before this route can be written to.
+func (r *Route) IsResolutionRequired() bool {
+ return r.ref.linkCache != nil && r.RemoteLinkAddress == ""
+}
+
+// WritePacket writes the packet through the given route.
+func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+ return r.ref.ep.WritePacket(r, hdr, payload, protocol)
+}
+
+// MTU returns the MTU of the underlying network endpoint.
+func (r *Route) MTU() uint32 {
+ return r.ref.ep.MTU()
+}
+
+// Release frees all resources associated with the route.
+func (r *Route) Release() {
+ if r.ref != nil {
+ r.ref.decRef()
+ r.ref = nil
+ }
+}
+
+// Clone clones a route (taking an extra endpoint reference) such that the
+// original one can be released and the new one will remain valid.
+func (r *Route) Clone() Route {
+ r.ref.incRef()
+ return *r
+}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
new file mode 100644
index 000000000..558ecdb72
--- /dev/null
+++ b/pkg/tcpip/stack/stack.go
@@ -0,0 +1,811 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package stack provides the glue between networking protocols and the
+// consumers of the networking stack.
+//
+// For consumers, the only function of interest is New(), everything else is
+// provided by the tcpip/public package.
+//
+// For protocol implementers, RegisterTransportProtocolFactory() and
+// RegisterNetworkProtocolFactory() are used to register protocol factories with
+// the stack, which will then be used to instantiate protocol objects when
+// consumers interact with the stack.
+package stack
+
+import (
+ "sync"
+ "sync/atomic"
+ "time"
+
+ "gvisor.googlesource.com/gvisor/pkg/sleep"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/header"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/ports"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // ageLimit is set to the same cache stale time used in Linux.
+ ageLimit = 1 * time.Minute
+ // resolutionTimeout is set to the same ARP timeout used in Linux.
+ resolutionTimeout = 1 * time.Second
+ // resolutionAttempts is set to the same ARP retries used in Linux.
+ resolutionAttempts = 3
+)
+
+type transportProtocolState struct {
+ proto TransportProtocol
+ defaultHandler func(*Route, TransportEndpointID, *buffer.VectorisedView) bool
+}
+
+// TCPProbeFunc is the expected function type for a TCP probe function to be
+// passed to stack.AddTCPProbe.
+type TCPProbeFunc func(s TCPEndpointState)
+
+// TCPEndpointID is the unique 4 tuple that identifies a given endpoint.
+type TCPEndpointID struct {
+ // LocalPort is the local port associated with the endpoint.
+ LocalPort uint16
+
+ // LocalAddress is the local [network layer] address associated with
+ // the endpoint.
+ LocalAddress tcpip.Address
+
+ // RemotePort is the remote port associated with the endpoint.
+ RemotePort uint16
+
+ // RemoteAddress is the remote [network layer] address associated with
+ // the endpoint.
+ RemoteAddress tcpip.Address
+}
+
+// TCPFastRecoveryState holds a copy of the internal fast recovery state of a
+// TCP endpoint.
+type TCPFastRecoveryState struct {
+ // Active if true indicates the endpoint is in fast recovery.
+ Active bool
+
+ // First is the first unacknowledged sequence number being recovered.
+ First seqnum.Value
+
+ // Last is the 'recover' sequence number that indicates the point at
+ // which we should exit recovery barring any timeouts etc.
+ Last seqnum.Value
+
+ // MaxCwnd is the maximum value we are permitted to grow the congestion
+ // window during recovery. This is set at the time we enter recovery.
+ MaxCwnd int
+}
+
+// TCPReceiverState holds a copy of the internal state of the receiver for
+// a given TCP endpoint.
+type TCPReceiverState struct {
+ // RcvNxt is the TCP variable RCV.NXT.
+ RcvNxt seqnum.Value
+
+ // RcvAcc is the TCP variable RCV.ACC.
+ RcvAcc seqnum.Value
+
+ // RcvWndScale is the window scaling to use for inbound segments.
+ RcvWndScale uint8
+
+ // PendingBufUsed is the number of bytes pending in the receive
+ // queue.
+ PendingBufUsed seqnum.Size
+
+ // PendingBufSize is the size of the socket receive buffer.
+ PendingBufSize seqnum.Size
+}
+
+// TCPSenderState holds a copy of the internal state of the sender for
+// a given TCP Endpoint.
+type TCPSenderState struct {
+ // LastSendTime is the time at which we sent the last segment.
+ LastSendTime time.Time
+
+ // DupAckCount is the number of Duplicate ACK's received.
+ DupAckCount int
+
+ // SndCwnd is the size of the sending congestion window in packets.
+ SndCwnd int
+
+ // Ssthresh is the slow start threshold in packets.
+ Ssthresh int
+
+ // SndCAAckCount is the number of packets consumed in congestion
+ // avoidance mode.
+ SndCAAckCount int
+
+ // Outstanding is the number of packets in flight.
+ Outstanding int
+
+ // SndWnd is the send window size in bytes.
+ SndWnd seqnum.Size
+
+ // SndUna is the next unacknowledged sequence number.
+ SndUna seqnum.Value
+
+ // SndNxt is the sequence number of the next segment to be sent.
+ SndNxt seqnum.Value
+
+ // RTTMeasureSeqNum is the sequence number being used for the latest RTT
+ // measurement.
+ RTTMeasureSeqNum seqnum.Value
+
+ // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent.
+ RTTMeasureTime time.Time
+
+ // Closed indicates that the caller has closed the endpoint for sending.
+ Closed bool
+
+ // SRTT is the smoothed round-trip time as defined in section 2 of
+ // RFC 6298.
+ SRTT time.Duration
+
+ // RTO is the retransmit timeout as defined in section 2 of RFC 6298.
+ RTO time.Duration
+
+ // RTTVar is the round-trip time variation as defined in section 2 of
+ // RFC 6298.
+ RTTVar time.Duration
+
+ // SRTTInited if true indicates that a valid RTT measurement has been
+ // completed.
+ SRTTInited bool
+
+ // MaxPayloadSize is the maximum size of the payload of a given segment.
+ // It is initialized on demand.
+ MaxPayloadSize int
+
+ // SndWndScale is the number of bits to shift left when reading the send
+ // window size from a segment.
+ SndWndScale uint8
+
+ // MaxSentAck is the highest acknowledgement number sent till now.
+ MaxSentAck seqnum.Value
+
+ // FastRecovery holds the fast recovery state for the endpoint.
+ FastRecovery TCPFastRecoveryState
+}
+
+// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint.
+type TCPSACKInfo struct {
+ // Blocks is the list of SACK blocks currently received by the
+ // TCP endpoint.
+ Blocks []header.SACKBlock
+}
+
+// TCPEndpointState is a copy of the internal state of a TCP endpoint.
+type TCPEndpointState struct {
+ // ID is a copy of the TransportEndpointID for the endpoint.
+ ID TCPEndpointID
+
+ // SegTime denotes the absolute time when this segment was received.
+ SegTime time.Time
+
+ // RcvBufSize is the size of the receive socket buffer for the endpoint.
+ RcvBufSize int
+
+ // RcvBufUsed is the amount of bytes actually held in the receive socket
+ // buffer for the endpoint.
+ RcvBufUsed int
+
+ // RcvClosed if true, indicates the endpoint has been closed for reading.
+ RcvClosed bool
+
+ // SendTSOk is used to indicate when the TS Option has been negotiated.
+ // When SendTSOk is true every non-RST segment should carry a TS as per
+ // RFC7323#section-1.1.
+ SendTSOk bool
+
+ // RecentTS is the timestamp that should be sent in the TSEcr field of
+ // the timestamp for future segments sent by the endpoint. This field is
+ // updated if required when a new segment is received by this endpoint.
+ RecentTS uint32
+
+ // TSOffset is a randomized offset added to the value of the TSVal field
+ // in the timestamp option.
+ TSOffset uint32
+
+ // SACKPermitted is set to true if the peer sends the TCPSACKPermitted
+ // option in the SYN/SYN-ACK.
+ SACKPermitted bool
+
+ // SACK holds TCP SACK related information for this endpoint.
+ SACK TCPSACKInfo
+
+ // SndBufSize is the size of the socket send buffer.
+ SndBufSize int
+
+ // SndBufUsed is the number of bytes held in the socket send buffer.
+ SndBufUsed int
+
+ // SndClosed indicates that the endpoint has been closed for sends.
+ SndClosed bool
+
+ // SndBufInQueue is the number of bytes in the send queue.
+ SndBufInQueue seqnum.Size
+
+ // PacketTooBigCount is used to notify the main protocol routine how
+ // many times a "packet too big" control packet is received.
+ PacketTooBigCount int
+
+ // SndMTU is the smallest MTU seen in the control packets received.
+ SndMTU int
+
+ // Receiver holds variables related to the TCP receiver for the endpoint.
+ Receiver TCPReceiverState
+
+ // Sender holds state related to the TCP Sender for the endpoint.
+ Sender TCPSenderState
+}
+
+// Stack is a networking stack, with all supported protocols, NICs, and route
+// table.
+type Stack struct {
+ transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState
+ networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol
+ linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver
+
+ demux *transportDemuxer
+
+ stats tcpip.Stats
+
+ linkAddrCache *linkAddrCache
+
+ mu sync.RWMutex
+ nics map[tcpip.NICID]*NIC
+
+ // routeTable is the route table passed in by the user via SetRouteTable(),
+ // it is used by FindRoute() to build a route for a specific
+ // destination.
+ routeTable []tcpip.Route
+
+ *ports.PortManager
+
+ // If not nil, then any new endpoints will have this probe function
+ // invoked every time they receive a TCP segment.
+ tcpProbeFunc TCPProbeFunc
+}
+
+// New allocates a new networking stack with only the requested networking and
+// transport protocols configured with default options.
+//
+// Protocol options can be changed by calling the
+// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the
+// stack. Please refer to individual protocol implementations as to what options
+// are supported.
+func New(network []string, transport []string) *Stack {
+ s := &Stack{
+ transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState),
+ networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol),
+ linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver),
+ nics: make(map[tcpip.NICID]*NIC),
+ linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
+ PortManager: ports.NewPortManager(),
+ }
+
+ // Add specified network protocols.
+ for _, name := range network {
+ netProtoFactory, ok := networkProtocols[name]
+ if !ok {
+ continue
+ }
+ netProto := netProtoFactory()
+ s.networkProtocols[netProto.Number()] = netProto
+ if r, ok := netProto.(LinkAddressResolver); ok {
+ s.linkAddrResolvers[r.LinkAddressProtocol()] = r
+ }
+ }
+
+ // Add specified transport protocols.
+ for _, name := range transport {
+ transProtoFactory, ok := transportProtocols[name]
+ if !ok {
+ continue
+ }
+ transProto := transProtoFactory()
+ s.transportProtocols[transProto.Number()] = &transportProtocolState{
+ proto: transProto,
+ }
+ }
+
+ // Create the global transport demuxer.
+ s.demux = newTransportDemuxer(s)
+
+ return s
+}
+
+// SetNetworkProtocolOption allows configuring individual protocol level
+// options. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation or the provided value
+// is incorrect.
+func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+ netProto, ok := s.networkProtocols[network]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return netProto.SetOption(option)
+}
+
+// NetworkProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation.
+// e.g.
+// var v ipv4.MyOption
+// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v)
+// if err != nil {
+// ...
+// }
+func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error {
+ netProto, ok := s.networkProtocols[network]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return netProto.Option(option)
+}
+
+// SetTransportProtocolOption allows configuring individual protocol level
+// options. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation or the provided value
+// is incorrect.
+func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+ transProtoState, ok := s.transportProtocols[transport]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return transProtoState.proto.SetOption(option)
+}
+
+// TransportProtocolOption allows retrieving individual protocol level option
+// values. This method returns an error if the protocol is not supported or
+// option is not supported by the protocol implementation.
+// var v tcp.SACKEnabled
+// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil {
+// ...
+// }
+func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error {
+ transProtoState, ok := s.transportProtocols[transport]
+ if !ok {
+ return tcpip.ErrUnknownProtocol
+ }
+ return transProtoState.proto.Option(option)
+}
+
+// SetTransportProtocolHandler sets the per-stack default handler for the given
+// protocol.
+//
+// It must be called only during initialization of the stack. Changing it as the
+// stack is operating is not supported.
+func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *buffer.VectorisedView) bool) {
+ state := s.transportProtocols[p]
+ if state != nil {
+ state.defaultHandler = h
+ }
+}
+
+// Stats returns a snapshot of the current stats.
+//
+// NOTE: Each counter is read with its own atomic load; as a result,
+// the snapshot returned does not represent the value of all the stats at any
+// single given point of time.
+// TODO: Make stats available in sentry for debugging/diag.
+func (s *Stack) Stats() tcpip.Stats {
+ return tcpip.Stats{
+ UnknownProtocolRcvdPackets: atomic.LoadUint64(&s.stats.UnknownProtocolRcvdPackets),
+ UnknownNetworkEndpointRcvdPackets: atomic.LoadUint64(&s.stats.UnknownNetworkEndpointRcvdPackets),
+ MalformedRcvdPackets: atomic.LoadUint64(&s.stats.MalformedRcvdPackets),
+ DroppedPackets: atomic.LoadUint64(&s.stats.DroppedPackets),
+ }
+}
+
+// MutableStats returns a mutable reference to the current stats (not a copy).
+//
+// This is not generally exported via the public interface, but is available
+// internally.
+func (s *Stack) MutableStats() *tcpip.Stats {
+ return &s.stats
+}
+
+// SetRouteTable replaces the stack's route table with table. The table
+// specifies which NIC to use for given destination address ranges. The slice
+// is stored as-is, not copied.
+func (s *Stack) SetRouteTable(table []tcpip.Route) {
+	s.mu.Lock()
+	s.routeTable = table
+	s.mu.Unlock()
+}
+
+// NewEndpoint creates a new transport layer endpoint of the given protocol.
+// It returns tcpip.ErrUnknownProtocol if the transport protocol is not
+// registered with the stack.
+func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	if state, found := s.transportProtocols[transport]; found {
+		// Endpoint construction is delegated to the protocol.
+		return state.proto.NewEndpoint(s, network, waiterQueue)
+	}
+	return nil, tcpip.ErrUnknownProtocol
+}
+
+// createNIC creates a NIC with the provided id, name and link-layer endpoint,
+// and optionally enables it by attaching the link endpoint.
+//
+// It returns tcpip.ErrBadLinkEndpoint if linkEP does not resolve to a
+// registered link endpoint, and tcpip.ErrDuplicateNICID if id is already in
+// use.
+func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error {
+	// Resolve the endpoint before taking the stack lock.
+	linkEndpoint := FindLinkEndpoint(linkEP)
+	if linkEndpoint == nil {
+		return tcpip.ErrBadLinkEndpoint
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// NIC IDs must be unique within the stack.
+	if _, exists := s.nics[id]; exists {
+		return tcpip.ErrDuplicateNICID
+	}
+
+	nic := newNIC(s, id, name, linkEndpoint)
+	s.nics[id] = nic
+
+	if enabled {
+		nic.attachLinkEndpoint()
+	}
+	return nil
+}
+
+// CreateNIC creates a NIC with the provided id and link-layer endpoint.
+func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, "", linkEP, true)
+}
+
+// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint,
+// and a human-readable name.
+func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, name, linkEP, true)
+}
+
+// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint,
+// but leave it disable. Stack.EnableNIC must be called before the link-layer
+// endpoint starts delivering packets to it.
+func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error {
+ return s.createNIC(id, "", linkEP, false)
+}
+
+// EnableNIC attaches the link-layer endpoint of the given NIC so that it can
+// start delivering packets to it. It returns tcpip.ErrUnknownNICID if no NIC
+// with that id exists.
+func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	if nic := s.nics[id]; nic != nil {
+		nic.attachLinkEndpoint()
+		return nil
+	}
+	return tcpip.ErrUnknownNICID
+}
+
+// NICSubnets returns, for every NIC in the stack, the subnets reported by that
+// NIC, keyed by NIC ID.
+func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	subnetsByNIC := make(map[tcpip.NICID][]tcpip.Subnet)
+	for id, nic := range s.nics {
+		// Appending to a nil slice yields a slice owned by the result map.
+		subnetsByNIC[id] = append([]tcpip.Subnet(nil), nic.Subnets()...)
+	}
+	return subnetsByNIC
+}
+
+// NICInfo captures the name and addresses assigned to a NIC.
+type NICInfo struct {
+ Name string // human-readable NIC name; may be empty
+ LinkAddress tcpip.LinkAddress // link-layer address of the NIC's link endpoint
+ ProtocolAddresses []tcpip.ProtocolAddress // network-layer addresses assigned to the NIC
+}
+
+// NICInfo returns a map of NICIDs to their associated information.
+func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo {
+ s.mu.RLock()
+ defer s.mu.RUnlock()
+
+ nics := make(map[tcpip.NICID]NICInfo)
+ for id, nic := range s.nics {
+ nics[id] = NICInfo{
+ Name: nic.name,
+ LinkAddress: nic.linkEP.LinkAddress(),
+ ProtocolAddresses: nic.Addresses(),
+ }
+ }
+ return nics
+}
+
+// AddAddress adds a new network-layer address to the specified NIC. It returns
+// tcpip.ErrUnknownNICID if the NIC does not exist; otherwise it returns
+// whatever the NIC reports.
+func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	if nic := s.nics[id]; nic != nil {
+		return nic.AddAddress(protocol, addr)
+	}
+	return tcpip.ErrUnknownNICID
+}
+
+// AddSubnet adds a subnet range to the specified NIC. It returns
+// tcpip.ErrUnknownNICID if the NIC does not exist.
+func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	if nic := s.nics[id]; nic != nil {
+		nic.AddSubnet(protocol, subnet)
+		return nil
+	}
+	return tcpip.ErrUnknownNICID
+}
+
+// RemoveAddress removes an existing network-layer address from the specified
+// NIC. It returns tcpip.ErrUnknownNICID if the NIC does not exist; otherwise it
+// returns whatever the NIC reports.
+func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	if nic := s.nics[id]; nic != nil {
+		return nic.RemoveAddress(addr)
+	}
+	return tcpip.ErrUnknownNICID
+}
+
+// FindRoute creates a route to the given destination address, leaving through
+// the given nic and local address (if provided).
+//
+// Route table entries are tried in order; the first entry that matches the
+// NIC and destination and whose NIC yields a usable local endpoint wins. If no
+// entry qualifies, tcpip.ErrNoRoute is returned.
+func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	for i := range s.routeTable {
+		entry := &s.routeTable[i]
+
+		// Skip entries bound to a NIC other than the requested one.
+		if id != 0 && id != entry.NIC {
+			continue
+		}
+		// Skip entries that do not cover the requested destination.
+		if len(remoteAddr) != 0 && !entry.Match(remoteAddr) {
+			continue
+		}
+
+		nic := s.nics[entry.NIC]
+		if nic == nil {
+			continue
+		}
+
+		// Use the requested local address if there is one, otherwise
+		// the NIC's primary endpoint for this protocol.
+		var ref *referencedNetworkEndpoint
+		if len(localAddr) == 0 {
+			ref = nic.primaryEndpoint(netProto)
+		} else {
+			ref = nic.findEndpoint(netProto, localAddr)
+		}
+		if ref == nil {
+			continue
+		}
+
+		local := ref.ep.ID().LocalAddress
+		if len(remoteAddr) == 0 {
+			// If no remote address was provided, then the route
+			// provided will refer to the local address itself.
+			remoteAddr = local
+		}
+
+		route := makeRoute(netProto, local, remoteAddr, ref)
+		route.NextHop = entry.Gateway
+		return route, nil
+	}
+
+	return Route{}, tcpip.ErrNoRoute
+}
+
+// CheckNetworkProtocol reports whether the given network protocol is
+// registered with the stack.
+func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool {
+	_, enabled := s.networkProtocols[protocol]
+	return enabled
+}
+
+// CheckLocalAddress determines if the given local address exists, and if it
+// does, returns the id of the NIC it's bound to. Returns 0 if the address
+// does not exist.
+//
+// When nicid is non-zero only that NIC is consulted; otherwise all NICs are
+// searched.
+func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	if nicid == 0 {
+		// No NIC specified: go through all of them.
+		for _, nic := range s.nics {
+			if ref := nic.findEndpoint(protocol, addr); ref != nil {
+				// Only existence matters; drop the reference.
+				ref.decRef()
+				return nic.id
+			}
+		}
+		return 0
+	}
+
+	// A NIC was specified, so look for the address there only.
+	nic := s.nics[nicid]
+	if nic == nil {
+		return 0
+	}
+	if ref := nic.findEndpoint(protocol, addr); ref != nil {
+		ref.decRef()
+		return nic.id
+	}
+	return 0
+}
+
+// SetPromiscuousMode enables or disables promiscuous mode in the given NIC.
+// It returns tcpip.ErrUnknownNICID if the NIC does not exist.
+func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	if nic := s.nics[nicID]; nic != nil {
+		nic.setPromiscuousMode(enable)
+		return nil
+	}
+	return tcpip.ErrUnknownNICID
+}
+
+// SetSpoofing enables or disables address spoofing in the given NIC, allowing
+// endpoints to bind to any address in the NIC. It returns
+// tcpip.ErrUnknownNICID if the NIC does not exist.
+func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	if nic := s.nics[nicID]; nic != nil {
+		nic.setSpoofing(enable)
+		return nil
+	}
+	return tcpip.ErrUnknownNICID
+}
+
+// AddLinkAddress records linkAddr as the link address of addr on the given
+// NIC in the stack's link address cache.
+func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) {
+	s.linkAddrCache.add(tcpip.FullAddress{NIC: nicid, Addr: addr}, linkAddr)
+	// TODO: provide a way for a
+	// transport endpoint to receive a signal that AddLinkAddress
+	// for a particular address has been called.
+}
+
+// GetLinkAddress implements LinkAddressCache.GetLinkAddress.
+//
+// The stack lock is held only while looking up the NIC; the cache lookup
+// itself runs without it.
+func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) {
+	s.mu.RLock()
+	nic := s.nics[nicid]
+	s.mu.RUnlock()
+
+	if nic == nil {
+		return "", tcpip.ErrUnknownNICID
+	}
+
+	key := tcpip.FullAddress{NIC: nicid, Addr: addr}
+	resolver := s.linkAddrResolvers[protocol]
+	return s.linkAddrCache.get(key, resolver, localAddr, nic.linkEP, waker)
+}
+
+// RemoveWaker implements LinkAddressCache.RemoveWaker.
+//
+// It removes waker from the link address cache entry for addr on the given
+// NIC. GetLinkAddress only registers wakers for NICs that exist, so the
+// removal is likewise performed only when the NIC is known; unknown NIC IDs
+// are ignored.
+func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) {
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	// The check used to be inverted (nic == nil), which meant the waker was
+	// never removed for any NIC on which it could have been registered.
+	if nic := s.nics[nicid]; nic != nil {
+		fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr}
+		s.linkAddrCache.removeWaker(fullAddr, waker)
+	}
+}
+
+// RegisterTransportEndpoint registers the given endpoint with the stack
+// transport dispatcher. Received packets that match the provided id will be
+// delivered to the given endpoint; specifying a nic is optional, but
+// nic-specific IDs have precedence over global ones.
+//
+// It returns tcpip.ErrUnknownNICID if a non-zero nicID does not name an
+// existing NIC.
+func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+	// A zero NIC ID selects the stack-wide demuxer; no NIC lookup needed.
+	if nicID == 0 {
+		return s.demux.registerEndpoint(netProtos, protocol, id, ep)
+	}
+
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	if nic := s.nics[nicID]; nic != nil {
+		return nic.demux.registerEndpoint(netProtos, protocol, id, ep)
+	}
+	return tcpip.ErrUnknownNICID
+}
+
+// UnregisterTransportEndpoint removes the endpoint with the given id from the
+// stack transport dispatcher. A non-zero nicID that does not name an existing
+// NIC is ignored.
+func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
+	// A zero NIC ID selects the stack-wide demuxer; no NIC lookup needed.
+	if nicID == 0 {
+		s.demux.unregisterEndpoint(netProtos, protocol, id)
+		return
+	}
+
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	nic := s.nics[nicID]
+	if nic == nil {
+		return
+	}
+	nic.demux.unregisterEndpoint(netProtos, protocol, id)
+}
+
+// NetworkProtocolInstance returns the protocol instance in the stack for the
+// specified network protocol, or nil if the protocol is not registered. This
+// method is public for protocol implementers and tests to use.
+func (s *Stack) NetworkProtocolInstance(num tcpip.NetworkProtocolNumber) NetworkProtocol {
+	proto, found := s.networkProtocols[num]
+	if !found {
+		return nil
+	}
+	return proto
+}
+
+// TransportProtocolInstance returns the protocol instance in the stack for the
+// specified transport protocol, or nil if the protocol is not registered. This
+// method is public for protocol implementers and tests to use.
+func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) TransportProtocol {
+	state, found := s.transportProtocols[num]
+	if !found {
+		return nil
+	}
+	return state.proto
+}
+
+// AddTCPProbe installs a probe function that will be invoked on every segment
+// received by a given TCP endpoint. The probe function is passed a copy of the
+// TCP endpoint state.
+//
+// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints
+// created prior to this call will not call the probe function.
+//
+// Further, installing two different probes back to back can result in some
+// endpoints calling the first one and some the second one. There is no
+// guarantee provided on which probe will be invoked. Ideally this should only
+// be called once per stack.
+func (s *Stack) AddTCPProbe(probe TCPProbeFunc) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.tcpProbeFunc = probe
+}
+
+// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil
+// otherwise.
+func (s *Stack) GetTCPProbe() TCPProbeFunc {
+	// This is a pure read, so a shared lock is sufficient; taking the
+	// exclusive lock would needlessly block every other reader of s.mu.
+	s.mu.RLock()
+	defer s.mu.RUnlock()
+
+	return s.tcpProbeFunc
+}
+
+// RemoveTCPProbe clears the installed TCP probe, if any.
+//
+// NOTE: This only ensures that endpoints created after this call do not
+// have a probe attached. Endpoints already created will continue to invoke
+// TCP probe.
+func (s *Stack) RemoveTCPProbe() {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	s.tcpProbeFunc = nil
+}
diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go
new file mode 100644
index 000000000..030ae98d1
--- /dev/null
+++ b/pkg/tcpip/stack/stack_global_state.go
@@ -0,0 +1,9 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+// StackFromEnv is the global stack created during a restore run.
+// FIXME: remove this variable once tcpip S/R is fully supported.
+var StackFromEnv *Stack // nil until assigned; nothing in this file sets it
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
new file mode 100644
index 000000000..b416065d7
--- /dev/null
+++ b/pkg/tcpip/stack/stack_test.go
@@ -0,0 +1,760 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package stack_test contains tests for the stack. It is in its own package so
+// that the tests can also validate that all definitions needed to implement
+// transport and network protocols are properly exported by the stack package.
+package stack_test
+
+import (
+ "math"
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+const (
+ fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 // protocol number of the fake network protocol
+ fakeNetHeaderLen = 12 // size, in bytes, of the fake network header
+
+ // fakeControlProtocol is used for control packets that represent
+ // destination port unreachable.
+ fakeControlProtocol tcpip.TransportProtocolNumber = 2
+
+ // defaultMTU is the MTU, in bytes, used throughout the tests, except
+ // where another value is explicitly used. It is chosen to match the MTU
+ // of loopback interfaces on linux systems.
+ defaultMTU = 65536
+)
+
+// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
+// received packets; the counts of all endpoints are aggregated in the protocol
+// descriptor.
+//
+// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only
+// use the first three: destination address, source address, and transport
+// protocol. They're all one byte fields to simplify parsing.
+type fakeNetworkEndpoint struct {
+ nicid tcpip.NICID // returned by NICID
+ id stack.NetworkEndpointID // holds the endpoint's local address; returned by ID
+ proto *fakeNetworkProtocol // protocol descriptor whose packet counters this endpoint updates
+ dispatcher stack.TransportDispatcher // receives transport and control packets from HandlePacket
+ linkEP stack.LinkEndpoint // link endpoint used for MTU, capabilities and outbound packets
+}
+
+// MTU returns the link endpoint's MTU less this endpoint's maximum header
+// length.
+func (f *fakeNetworkEndpoint) MTU() uint32 {
+	linkMTU := f.linkEP.MTU()
+	return linkMTU - uint32(f.MaxHeaderLength())
+}
+
+// NICID returns the id of the NIC this endpoint was created for.
+func (f *fakeNetworkEndpoint) NICID() tcpip.NICID {
+	nicID := f.nicid
+	return nicID
+}
+
+// ID returns a pointer to the endpoint's own id (not a copy).
+func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID {
+	endpointID := &f.id
+	return endpointID
+}
+
+// HandlePacket counts the inbound packet in the protocol descriptor, strips
+// the fake network header and hands the remainder to the transport
+// dispatcher. Packets whose transport protocol is fakeControlProtocol carry a
+// second fake header and are delivered as port-unreachable control packets.
+func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) {
+	// Bump the received packet count in the protocol descriptor, bucketed
+	// by this endpoint's local address.
+	bucket := int(f.id.LocalAddress[0]) % len(f.proto.packetCount)
+	f.proto.packetCount[bucket]++
+
+	// Consume the network header.
+	hdr := vv.First()
+	vv.TrimFront(fakeNetHeaderLen)
+
+	transProto := tcpip.TransportProtocolNumber(hdr[2])
+	if transProto != fakeControlProtocol {
+		// Regular packet: dispatch it to the transport protocol.
+		f.dispatcher.DeliverTransportPacket(r, transProto, vv)
+		return
+	}
+
+	// Control packet: the payload starts with the header of the packet
+	// that triggered it. Drop it if that header is truncated.
+	inner := vv.First()
+	if len(inner) < fakeNetHeaderLen {
+		return
+	}
+	vv.TrimFront(fakeNetHeaderLen)
+
+	f.dispatcher.DeliverTransportControlPacket(tcpip.Address(inner[1:2]), tcpip.Address(inner[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(inner[2]), stack.ControlPortUnreachable, 0, vv)
+}
+
+// MaxHeaderLength returns the link endpoint's maximum header length plus the
+// size of the fake network header.
+func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 {
+	linkHdrLen := f.linkEP.MaxHeaderLength()
+	return linkHdrLen + fakeNetHeaderLen
+}
+
+func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 {
+ return 0
+}
+
+// Capabilities returns the capabilities of the underlying link endpoint.
+func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities {
+	linkCaps := f.linkEP.Capabilities()
+	return linkCaps
+}
+
+// WritePacket counts the outbound packet in the protocol descriptor, prepends
+// the fake network header (destination, source, transport protocol) and
+// passes the packet to the link endpoint.
+func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error {
+	dst := r.RemoteAddress[0]
+
+	// Bump the sent packet count in the protocol descriptor, bucketed by
+	// the destination address.
+	f.proto.sendPacketCount[int(dst)%len(f.proto.sendPacketCount)]++
+
+	// Fill in the protocol's header and send the packet to the link
+	// endpoint.
+	netHdr := hdr.Prepend(fakeNetHeaderLen)
+	netHdr[0] = dst
+	netHdr[1] = f.id.LocalAddress[0]
+	netHdr[2] = byte(protocol)
+	return f.linkEP.WritePacket(r, hdr, payload, fakeNetNumber)
+}
+
+// Close is a no-op: the fake endpoint holds nothing that needs releasing.
+func (*fakeNetworkEndpoint) Close() {
+}
+
+type fakeNetGoodOption bool // option accepted by both SetOption and Option of the fake network protocol
+
+type fakeNetBadOption bool // option the fake network protocol does not recognise
+
+type fakeNetInvalidValueOption int // option that SetOption rejects with tcpip.ErrInvalidOptionValue
+
+type fakeNetOptions struct {
+ good bool // last value stored via fakeNetGoodOption
+}
+
+// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the
+// number of packets sent and received via endpoints of this protocol. The index
+// where packets are added is given by the packet's destination address MOD 10.
+type fakeNetworkProtocol struct {
+ packetCount [10]int // received packets, bucketed by the receiving endpoint's local address
+ sendPacketCount [10]int // sent packets, bucketed by the route's remote address
+ opts fakeNetOptions // option values stored by SetOption
+}
+
+func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber {
+ return fakeNetNumber
+}
+
+func (f *fakeNetworkProtocol) MinimumPacketSize() int {
+ return fakeNetHeaderLen
+}
+
+// ParseAddresses extracts the one-byte source and destination addresses from
+// a fake network header: byte 0 is the destination, byte 1 the source.
+func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) {
+	src = tcpip.Address(v[1:2])
+	dst = tcpip.Address(v[0:1])
+	return src, dst
+}
+
+// NewEndpoint creates a fake network endpoint for addr on the given NIC. The
+// endpoint reports its packet counts back to f. linkAddrCache is unused.
+func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) {
+	ep := &fakeNetworkEndpoint{
+		nicid:      nicid,
+		id:         stack.NetworkEndpointID{LocalAddress: addr},
+		proto:      f,
+		dispatcher: dispatcher,
+		linkEP:     linkEP,
+	}
+	return ep, nil
+}
+
+func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case fakeNetGoodOption:
+ f.opts.good = bool(v)
+ return nil
+ case fakeNetInvalidValueOption:
+ return tcpip.ErrInvalidOptionValue
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *fakeNetGoodOption:
+ *v = fakeNetGoodOption(f.opts.good)
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// TestNetworkReceive checks that inbound packets are delivered only to the
+// endpoint whose address matches the packet's destination, and that packets
+// with an unknown destination, a wrong protocol number, or a truncated header
+// are not delivered at all.
+func TestNetworkReceive(t *testing.T) {
+	// Create a stack with the fake network protocol, one nic, and two
+	// addresses attached to it: 1 & 2.
+	id, linkEP := channel.New(10, defaultMTU, "")
+	s := stack.New([]string{"fakeNet"}, nil)
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+	for _, addr := range []tcpip.Address{"\x01", "\x02"} {
+		if err := s.AddAddress(1, fakeNetNumber, addr); err != nil {
+			t.Fatalf("AddAddress failed: %v", err)
+		}
+	}
+
+	fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+	var views [1]buffer.View
+	// Allocate the buffer containing the packet that will be injected into
+	// the stack.
+	buf := buffer.NewView(30)
+
+	// inject feeds the current contents of buf into the link endpoint.
+	inject := func(proto tcpip.NetworkProtocolNumber) {
+		vv := buf.ToVectorisedView(views)
+		linkEP.Inject(proto, &vv)
+	}
+	// expectCounts verifies the receive counters of addresses 1 and 2.
+	expectCounts := func(want1, want2 int) {
+		t.Helper()
+		if got := fakeNet.packetCount[1]; got != want1 {
+			t.Errorf("packetCount[1] = %d, want %d", got, want1)
+		}
+		if got := fakeNet.packetCount[2]; got != want2 {
+			t.Errorf("packetCount[2] = %d, want %d", got, want2)
+		}
+	}
+
+	// Make sure packet with wrong address is not delivered.
+	buf[0] = 3
+	inject(fakeNetNumber)
+	expectCounts(0, 0)
+
+	// Make sure packet is delivered to first endpoint.
+	buf[0] = 1
+	inject(fakeNetNumber)
+	expectCounts(1, 0)
+
+	// Make sure packet is delivered to second endpoint.
+	buf[0] = 2
+	inject(fakeNetNumber)
+	expectCounts(1, 1)
+
+	// Make sure packet is not delivered if protocol number is wrong.
+	inject(fakeNetNumber - 1)
+	expectCounts(1, 1)
+
+	// Make sure packet that is too small is dropped.
+	buf.CapLength(2)
+	inject(fakeNetNumber)
+	expectCounts(1, 1)
+}
+
+// sendTo finds a route to addr on any NIC and writes one payload-less packet
+// over it. A missing route is fatal; a failed write is reported as an error.
+func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) {
+	route, err := s.FindRoute(0, "", addr, fakeNetNumber)
+	if err != nil {
+		t.Fatalf("FindRoute failed: %v", err)
+	}
+	defer route.Release()
+
+	hdr := buffer.NewPrependable(int(route.MaxHeaderLength()))
+	if err := route.WritePacket(&hdr, nil, fakeTransNumber); err != nil {
+		t.Errorf("WritePacket failed: %v", err)
+	}
+}
+
+// TestNetworkSend checks that a packet written through a route found by the
+// stack reaches the link-layer endpoint of the NIC named in the route table.
+func TestNetworkSend(t *testing.T) {
+	// Create a stack with the fake network protocol, one nic, and one
+	// address: 1. The route table sends all packets through the only
+	// existing nic.
+	id, linkEP := channel.New(10, defaultMTU, "")
+	s := stack.New([]string{"fakeNet"}, nil)
+	if err := s.CreateNIC(1, id); err != nil {
+		// Name the call that actually failed.
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	// Make sure that the link-layer endpoint received the outbound packet.
+	sendTo(t, s, "\x03")
+	if c := linkEP.Drain(); c != 1 {
+		t.Errorf("packetCount = %d, want %d", c, 1)
+	}
+}
+
+// TestNetworkSendMultiRoute checks that outbound packets leave through the
+// NIC selected by the route table when several NICs are present.
+func TestNetworkSendMultiRoute(t *testing.T) {
+	// Create a stack with the fake network protocol, two nics, and two
+	// addresses per nic, the first nic has odd address, the second one has
+	// even addresses.
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	// addAddrs assigns each of addrs to the given NIC, in order.
+	addAddrs := func(nic tcpip.NICID, addrs ...tcpip.Address) {
+		t.Helper()
+		for _, addr := range addrs {
+			if err := s.AddAddress(nic, fakeNetNumber, addr); err != nil {
+				t.Fatalf("AddAddress failed: %v", err)
+			}
+		}
+	}
+
+	id1, linkEP1 := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id1); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+	addAddrs(1, "\x01", "\x03")
+
+	id2, linkEP2 := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(2, id2); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+	addAddrs(2, "\x02", "\x04")
+
+	// Set a route table that sends all packets with odd destination
+	// addresses through the first NIC, and all even destination address
+	// through the second one.
+	s.SetRouteTable([]tcpip.Route{
+		{"\x01", "\x01", "\x00", 1},
+		{"\x00", "\x01", "\x00", 2},
+	})
+
+	// Send a packet to an odd destination.
+	sendTo(t, s, "\x05")
+	if got := linkEP1.Drain(); got != 1 {
+		t.Errorf("packetCount = %d, want %d", got, 1)
+	}
+
+	// Send a packet to an even destination.
+	sendTo(t, s, "\x06")
+	if got := linkEP2.Drain(); got != 1 {
+		t.Errorf("packetCount = %d, want %d", got, 1)
+	}
+}
+
+// testRoute requires that FindRoute succeeds for the given NIC, source and
+// destination, and that the resulting route uses expectedSrcAddr as its local
+// address and dstAddr as its remote address.
+func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) {
+	route, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber)
+	if err != nil {
+		t.Fatalf("FindRoute failed: %v", err)
+	}
+	defer route.Release()
+
+	if got := route.LocalAddress; got != expectedSrcAddr {
+		t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, got)
+	}
+	if got := route.RemoteAddress; got != dstAddr {
+		t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, got)
+	}
+}
+
+// testNoRoute requires that FindRoute fails with tcpip.ErrNoRoute for the
+// given NIC, source and destination.
+func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) {
+	if _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber); err != tcpip.ErrNoRoute {
+		t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err)
+	}
+}
+
+// TestRoutes checks route lookup across two NICs: which source address is
+// chosen for a destination, and that a source address on one NIC cannot be
+// used to reach destinations routed through the other.
+func TestRoutes(t *testing.T) {
+	// Create a stack with the fake network protocol, two nics, and two
+	// addresses per nic, the first nic has odd address, the second one has
+	// even addresses.
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	// addAddrs assigns each of addrs to the given NIC, in order.
+	addAddrs := func(nic tcpip.NICID, addrs ...tcpip.Address) {
+		t.Helper()
+		for _, addr := range addrs {
+			if err := s.AddAddress(nic, fakeNetNumber, addr); err != nil {
+				t.Fatalf("AddAddress failed: %v", err)
+			}
+		}
+	}
+
+	id1, _ := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id1); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+	addAddrs(1, "\x01", "\x03")
+
+	id2, _ := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(2, id2); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+	addAddrs(2, "\x02", "\x04")
+
+	// Set a route table that sends all packets with odd destination
+	// addresses through the first NIC, and all even destination address
+	// through the second one.
+	s.SetRouteTable([]tcpip.Route{
+		{"\x01", "\x01", "\x00", 1},
+		{"\x00", "\x01", "\x00", 2},
+	})
+
+	routable := []struct {
+		nic               tcpip.NICID
+		src, dst, wantSrc tcpip.Address
+	}{
+		// Routes to odd address.
+		{0, "", "\x05", "\x01"},
+		{0, "\x01", "\x05", "\x01"},
+		{1, "\x01", "\x05", "\x01"},
+		{0, "\x03", "\x05", "\x03"},
+		{1, "\x03", "\x05", "\x03"},
+		// Routes to even address.
+		{0, "", "\x06", "\x02"},
+		{0, "\x02", "\x06", "\x02"},
+		{2, "\x02", "\x06", "\x02"},
+		{0, "\x04", "\x06", "\x04"},
+		{2, "\x04", "\x06", "\x04"},
+	}
+	for _, tc := range routable {
+		testRoute(t, s, tc.nic, tc.src, tc.dst, tc.wantSrc)
+	}
+
+	unroutable := []struct {
+		nic      tcpip.NICID
+		src, dst tcpip.Address
+	}{
+		// Odd numbered destination from even numbered sources.
+		{0, "\x02", "\x05"},
+		{2, "\x02", "\x05"},
+		{0, "\x04", "\x05"},
+		{2, "\x04", "\x05"},
+		// Even numbered destination from odd numbered sources.
+		{0, "\x01", "\x06"},
+		{1, "\x01", "\x06"},
+		{0, "\x03", "\x06"},
+		{1, "\x03", "\x06"},
+	}
+	for _, tc := range unroutable {
+		testNoRoute(t, s, tc.nic, tc.src, tc.dst)
+	}
+}
+
+// TestAddressRemoval checks that packets stop being delivered to an address
+// once it has been removed, and that removing it a second time fails with
+// tcpip.ErrBadLocalAddress.
+func TestAddressRemoval(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id, linkEP := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	var views [1]buffer.View
+	buf := buffer.NewView(30)
+
+	fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+	// injectAndExpect injects buf and verifies address 1's receive count.
+	injectAndExpect := func(want int) {
+		t.Helper()
+		vv := buf.ToVectorisedView(views)
+		linkEP.Inject(fakeNetNumber, &vv)
+		if got := fakeNet.packetCount[1]; got != want {
+			t.Errorf("packetCount[1] = %d, want %d", got, want)
+		}
+	}
+
+	// Write a packet, and check that it gets delivered.
+	fakeNet.packetCount[1] = 0
+	buf[0] = 1
+	injectAndExpect(1)
+
+	// Remove the address, then check that packet doesn't get delivered
+	// anymore.
+	if err := s.RemoveAddress(1, "\x01"); err != nil {
+		t.Fatalf("RemoveAddress failed: %v", err)
+	}
+	injectAndExpect(1)
+
+	// Check that removing the same address fails.
+	if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+		t.Fatalf("RemoveAddress failed: %v", err)
+	}
+}
+
+// TestDelayedRemovalDueToRoute checks that a removed address keeps receiving
+// packets for as long as a route still references it, and stops once that
+// route is released.
+func TestDelayedRemovalDueToRoute(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id, linkEP := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{
+		{"\x00", "\x00", "\x00", 1},
+	})
+
+	fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+	var views [1]buffer.View
+	buf := buffer.NewView(30)
+
+	// injectAndExpect injects buf and verifies address 1's receive count.
+	injectAndExpect := func(want int) {
+		t.Helper()
+		vv := buf.ToVectorisedView(views)
+		linkEP.Inject(fakeNetNumber, &vv)
+		if got := fakeNet.packetCount[1]; got != want {
+			t.Errorf("packetCount[1] = %d, want %d", got, want)
+		}
+	}
+
+	// Write a packet, and check that it gets delivered.
+	fakeNet.packetCount[1] = 0
+	buf[0] = 1
+	injectAndExpect(1)
+
+	// Get a route, check that packet is still deliverable.
+	route, err := s.FindRoute(0, "", "\x02", fakeNetNumber)
+	if err != nil {
+		t.Fatalf("FindRoute failed: %v", err)
+	}
+	injectAndExpect(2)
+
+	// Remove the address, then check that packet is still deliverable
+	// because the route is keeping the address alive.
+	if err := s.RemoveAddress(1, "\x01"); err != nil {
+		t.Fatalf("RemoveAddress failed: %v", err)
+	}
+	injectAndExpect(3)
+
+	// Check that removing the same address fails.
+	if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress {
+		t.Fatalf("RemoveAddress failed: %v", err)
+	}
+
+	// Release the route, then check that packet is not deliverable anymore.
+	route.Release()
+	injectAndExpect(3)
+}
+
+// TestPromiscuousMode checks that a NIC with no addresses only accepts packets
+// while promiscuous mode is on, and that promiscuous mode alone does not make
+// a route available.
+func TestPromiscuousMode(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id, linkEP := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{
+		{"\x00", "\x00", "\x00", 1},
+	})
+
+	fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+	var views [1]buffer.View
+	buf := buffer.NewView(30)
+
+	// injectAndExpect injects buf and verifies address 1's receive count.
+	injectAndExpect := func(want int) {
+		t.Helper()
+		vv := buf.ToVectorisedView(views)
+		linkEP.Inject(fakeNetNumber, &vv)
+		if got := fakeNet.packetCount[1]; got != want {
+			t.Errorf("packetCount[1] = %d, want %d", got, want)
+		}
+	}
+
+	// Write a packet, and check that it doesn't get delivered as we don't
+	// have a matching endpoint.
+	fakeNet.packetCount[1] = 0
+	buf[0] = 1
+	injectAndExpect(0)
+
+	// Set promiscuous mode, then check that packet is delivered.
+	if err := s.SetPromiscuousMode(1, true); err != nil {
+		t.Fatalf("SetPromiscuousMode failed: %v", err)
+	}
+	injectAndExpect(1)
+
+	// Check that we can't get a route as there is no local address.
+	if _, err := s.FindRoute(0, "", "\x02", fakeNetNumber); err != tcpip.ErrNoRoute {
+		t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err)
+	}
+
+	// Set promiscuous mode to false, then check that packet can't be
+	// delivered anymore.
+	if err := s.SetPromiscuousMode(1, false); err != nil {
+		t.Fatalf("SetPromiscuousMode failed: %v", err)
+	}
+	injectAndExpect(1)
+}
+
+// TestAddressSpoofing checks that FindRoute refuses a source address that was
+// never added to the NIC unless spoofing is enabled on that NIC, in which case
+// the route uses the requested source address.
+func TestAddressSpoofing(t *testing.T) {
+	srcAddr := tcpip.Address("\x01")
+	dstAddr := tcpip.Address("\x02")
+
+	s := stack.New([]string{"fakeNet"}, nil)
+
+	id, _ := channel.New(10, defaultMTU, "")
+	if err := s.CreateNIC(1, id); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	s.SetRouteTable([]tcpip.Route{
+		{"\x00", "\x00", "\x00", 1},
+	})
+
+	// With address spoofing disabled, FindRoute does not permit an address
+	// that was not added to the NIC to be used as the source.
+	r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
+	if err == nil {
+		t.Errorf("FindRoute succeeded with route %+v when it should have failed", r)
+		// Release the unexpected route so it is not left holding a
+		// reference to its endpoint.
+		r.Release()
+	}
+
+	// With address spoofing enabled, FindRoute permits any address to be used
+	// as the source.
+	if err := s.SetSpoofing(1, true); err != nil {
+		t.Fatalf("SetSpoofing failed: %v", err)
+	}
+	r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber)
+	if err != nil {
+		t.Fatalf("FindRoute failed: %v", err)
+	}
+	// Routes must be released once they are no longer needed.
+	defer r.Release()
+
+	if r.LocalAddress != srcAddr {
+		t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr)
+	}
+	if r.RemoteAddress != dstAddr {
+		t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr)
+	}
+}
+
+// TestSubnetAcceptsMatchingPacket adds a subnet (address 0x00, mask 0xF0) to
+// the NIC and checks that an injected packet whose first byte is 1 is counted
+// as delivered.
+func TestSubnetAcceptsMatchingPacket(t *testing.T) {
+	stk := stack.New([]string{"fakeNet"}, nil)
+
+	linkID, linkEP := channel.New(10, defaultMTU, "")
+	if err := stk.CreateNIC(1, linkID); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	stk.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+	fakeNet := stk.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+	// Build the packet and reset the counter it is expected to bump.
+	var views [1]buffer.View
+	pkt := buffer.NewView(30)
+	pkt[0] = 1
+	fakeNet.packetCount[1] = 0
+
+	subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0"))
+	if err != nil {
+		t.Fatalf("NewSubnet failed: %v", err)
+	}
+	if err := stk.AddSubnet(1, fakeNetNumber, subnet); err != nil {
+		t.Fatalf("AddSubnet failed: %v", err)
+	}
+
+	vv := pkt.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if got := fakeNet.packetCount[1]; got != 1 {
+		t.Errorf("packetCount[1] = %d, want %d", got, 1)
+	}
+}
+
+// TestSubnetRejectsNonmatchingPacket adds a subnet (address 0x10, mask 0xF0)
+// to the NIC and checks that an injected packet whose first byte is 1, which
+// lies outside that subnet, is not counted as delivered.
+func TestSubnetRejectsNonmatchingPacket(t *testing.T) {
+	stk := stack.New([]string{"fakeNet"}, nil)
+
+	linkID, linkEP := channel.New(10, defaultMTU, "")
+	if err := stk.CreateNIC(1, linkID); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	stk.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+	fakeNet := stk.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+	// Build the packet and reset the counter that must stay untouched.
+	var views [1]buffer.View
+	pkt := buffer.NewView(30)
+	pkt[0] = 1
+	fakeNet.packetCount[1] = 0
+
+	subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0"))
+	if err != nil {
+		t.Fatalf("NewSubnet failed: %v", err)
+	}
+	if err := stk.AddSubnet(1, fakeNetNumber, subnet); err != nil {
+		t.Fatalf("AddSubnet failed: %v", err)
+	}
+
+	vv := pkt.ToVectorisedView(views)
+	linkEP.Inject(fakeNetNumber, &vv)
+	if got := fakeNet.packetCount[1]; got != 0 {
+		t.Errorf("packetCount[1] = %d, want %d", got, 0)
+	}
+}
+
+// TestNetworkOptions exercises SetNetworkProtocolOption and
+// NetworkProtocolOption against the fake network protocol. It covers an
+// unknown protocol number, an accepted option (which is read back), an
+// unknown option type, and an option carrying an invalid value.
+func TestNetworkOptions(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, []string{})
+
+	// Try an unsupported network protocol.
+	if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol {
+		t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
+	}
+
+	testCases := []struct {
+		option   interface{}
+		wantErr  *tcpip.Error
+		verifier func(t *testing.T, p stack.NetworkProtocol)
+	}{
+		{fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) {
+			t.Helper()
+			fakeNet := p.(*fakeNetworkProtocol)
+			if !fakeNet.opts.good {
+				t.Fatalf("fakeNet.opts.good = false, want = true")
+			}
+			var v fakeNetGoodOption
+			if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil {
+				// The format expects the returned error first and the
+				// option (for %T) second.
+				t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", err, v)
+			}
+			if !v {
+				t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v)
+			}
+		}},
+		{fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
+		{fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
+	}
+	for _, tc := range testCases {
+		if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr {
+			t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr)
+		}
+		// Only cases that are expected to succeed carry a verifier.
+		if tc.verifier != nil {
+			tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber))
+		}
+	}
+}
+
+// init registers a factory for the fake network protocol under the name
+// "fakeNet" so that tests can request it by name from stack.New.
+func init() {
+	newFakeNet := func() stack.NetworkProtocol {
+		return &fakeNetworkProtocol{}
+	}
+	stack.RegisterNetworkProtocolFactory("fakeNet", newFakeNet)
+}
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
new file mode 100644
index 000000000..3c0d7aa31
--- /dev/null
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -0,0 +1,166 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+)
+
+// protocolIDs pairs a network protocol number with a transport protocol
+// number. It is the key type of transportDemuxer.protocol. Callers build it
+// with positional literals, so the field order must not change.
+type protocolIDs struct {
+ network tcpip.NetworkProtocolNumber
+ transport tcpip.TransportProtocolNumber
+}
+
+// transportEndpoints manages all endpoints of a given protocol. It has its own
+// mutex so as to reduce interference between protocols.
+type transportEndpoints struct {
+ mu sync.RWMutex // held for every read and write of endpoints in this file
+ endpoints map[TransportEndpointID]TransportEndpoint // registered endpoints keyed by their ID
+}
+
+// transportDemuxer demultiplexes packets targeted at a transport endpoint
+// (i.e., after they've been parsed by the network layer). It does two levels
+// of demultiplexing: first based on the network and transport protocols, then
+// based on endpoints IDs.
+type transportDemuxer struct {
+ protocol map[protocolIDs]*transportEndpoints // filled by newTransportDemuxer; the methods in this file only read it
+}
+
+// newTransportDemuxer builds a demuxer that holds one empty endpoint table for
+// every combination of network and transport protocol known to the stack.
+func newTransportDemuxer(stack *Stack) *transportDemuxer {
+	demux := &transportDemuxer{
+		protocol: make(map[protocolIDs]*transportEndpoints),
+	}
+
+	for netProto := range stack.networkProtocols {
+		for transProto := range stack.transportProtocols {
+			key := protocolIDs{network: netProto, transport: transProto}
+			demux.protocol[key] = &transportEndpoints{
+				endpoints: make(map[TransportEndpointID]TransportEndpoint),
+			}
+		}
+	}
+
+	return demux
+}
+
+// registerEndpoint registers ep under id for the given transport protocol on
+// each of the listed network protocols, so that matching packets are
+// delivered to it. If any registration fails, the ones already made by this
+// call are undone and the error is returned.
+func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+	for idx := range netProtos {
+		err := d.singleRegisterEndpoint(netProtos[idx], protocol, id, ep)
+		if err == nil {
+			continue
+		}
+
+		// Roll back the protocols registered before the failing one.
+		d.unregisterEndpoint(netProtos[:idx], protocol, id)
+		return err
+	}
+
+	return nil
+}
+
+// singleRegisterEndpoint registers ep under id for a single network and
+// transport protocol pair. It returns tcpip.ErrPortInUse if the id is already
+// taken. A pair that the demuxer does not know is ignored and nil is returned.
+func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error {
+	key := protocolIDs{network: netProto, transport: protocol}
+	table, known := d.protocol[key]
+	if !known {
+		return nil
+	}
+
+	table.mu.Lock()
+	defer table.mu.Unlock()
+
+	if _, taken := table.endpoints[id]; taken {
+		return tcpip.ErrPortInUse
+	}
+	table.endpoints[id] = ep
+
+	return nil
+}
+
+// unregisterEndpoint removes the endpoint registered under id for the given
+// transport protocol on each of the listed network protocols, so that it no
+// longer receives packets. Protocol pairs unknown to the demuxer are skipped.
+func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) {
+	for _, netProto := range netProtos {
+		table, known := d.protocol[protocolIDs{network: netProto, transport: protocol}]
+		if !known {
+			continue
+		}
+
+		table.mu.Lock()
+		delete(table.endpoints, id)
+		table.mu.Unlock()
+	}
+}
+
+// deliverPacket looks up the endpoint registered for the route's network
+// protocol, the given transport protocol and id, and passes the packet to its
+// HandlePacket method. It returns true if an endpoint was found and false
+// otherwise.
+func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, id TransportEndpointID) bool {
+	table, known := d.protocol[protocolIDs{network: r.NetProto, transport: protocol}]
+	if !known {
+		return false
+	}
+
+	// The read lock covers the lookup only; the endpoint is invoked after the
+	// lock has been released.
+	table.mu.RLock()
+	target := d.findEndpointLocked(table, vv, id)
+	table.mu.RUnlock()
+
+	if target != nil {
+		target.HandlePacket(r, id, vv)
+		return true
+	}
+
+	return false
+}
+
+// deliverControlPacket looks up the endpoint registered for the given network
+// and transport protocols and id, and passes the control packet to its
+// HandleControlPacket method. It returns true if an endpoint was found and
+// false otherwise.
+func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView, id TransportEndpointID) bool {
+	table, known := d.protocol[protocolIDs{network: net, transport: trans}]
+	if !known {
+		return false
+	}
+
+	// The read lock covers the lookup only; the endpoint is invoked after the
+	// lock has been released.
+	table.mu.RLock()
+	target := d.findEndpointLocked(table, vv, id)
+	table.mu.RUnlock()
+
+	if target != nil {
+		target.HandleControlPacket(id, typ, extra, vv)
+		return true
+	}
+
+	return false
+}
+
+// findEndpointLocked returns the endpoint in eps that best matches id, or nil
+// if there is none. Candidates are tried from most to least specific: the id
+// as given, the id without its local address, the id without its remote
+// address and port, and finally the id with neither local address nor remote
+// part. Both callers in this file hold eps.mu while calling it.
+func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv *buffer.VectorisedView, id TransportEndpointID) TransportEndpoint {
+	// The id minus the local address.
+	noLocal := id
+	noLocal.LocalAddress = ""
+
+	// The id minus the remote part.
+	noRemote := id
+	noRemote.RemoteAddress = ""
+	noRemote.RemotePort = 0
+
+	// Only what remains once both of the above are cleared.
+	portOnly := noRemote
+	portOnly.LocalAddress = ""
+
+	for _, candidate := range [...]TransportEndpointID{id, noLocal, noRemote, portOnly} {
+		if ep := eps.endpoints[candidate]; ep != nil {
+			return ep
+		}
+	}
+
+	return nil
+}
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
new file mode 100644
index 000000000..7e072e96e
--- /dev/null
+++ b/pkg/tcpip/stack/transport_test.go
@@ -0,0 +1,420 @@
+// Copyright 2016 The Netstack Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package stack_test
+
+import (
+ "testing"
+
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+const (
+ // fakeTransNumber is the protocol number reported by fakeTransportProtocol.
+ fakeTransNumber tcpip.TransportProtocolNumber = 1
+ // fakeTransHeaderLen is the value returned by
+ // fakeTransportProtocol.MinimumPacketSize.
+ fakeTransHeaderLen = 3
+)
+
+// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts
+// received packets; the counts of all endpoints are aggregated in the protocol
+// descriptor.
+//
+// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't
+// use it.
+type fakeTransportEndpoint struct {
+ id stack.TransportEndpointID // ID used to register with the stack; Connect sets its RemoteAddress
+ stack *stack.Stack // stack used by Connect to find routes and register the endpoint
+ netProto tcpip.NetworkProtocolNumber // network protocol given at construction
+ proto *fakeTransportProtocol // protocol whose counters HandlePacket and HandleControlPacket increment
+ peerAddr tcpip.Address // remote address recorded by Connect
+ route stack.Route // set by Connect, used by Write, released by Close
+}
+
+// newFakeTransportEndpoint returns an unconnected fake endpoint bound to the
+// given stack, protocol descriptor and network protocol number.
+func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint {
+	ep := &fakeTransportEndpoint{
+		proto:    proto,
+		netProto: netProto,
+		stack:    stack,
+	}
+	return ep
+}
+
+// Close releases the route held by the endpoint.
+func (f *fakeTransportEndpoint) Close() {
+	route := &f.route
+	route.Release()
+}
+
+// Readiness reports every requested event as ready by returning mask
+// unchanged.
+func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+	ready := mask
+	return ready
+}
+
+// Read is a stub: it always returns an empty view and no error.
+func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, *tcpip.Error) {
+	empty := buffer.View{}
+	return empty, nil
+}
+
+// Write sends the whole payload over the endpoint's route using the fake
+// transport protocol number. It returns tcpip.ErrNoRoute if the route has no
+// remote address (Connect is what sets the route), and otherwise the number of
+// payload bytes handed to the route or the error that occurred.
+func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) {
+	if len(f.route.RemoteAddress) == 0 {
+		return 0, tcpip.ErrNoRoute
+	}
+
+	// Reserve room for the headers the route may prepend.
+	hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength()))
+
+	payload, err := p.Get(p.Size())
+	if err != nil {
+		return 0, err
+	}
+
+	if sendErr := f.route.WritePacket(&hdr, payload, fakeTransNumber); sendErr != nil {
+		return 0, sendErr
+	}
+
+	return uintptr(len(payload)), nil
+}
+
+func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, *tcpip.Error) {
+ return 0, nil
+}
+
+// SetSockOpt is not supported by the fake endpoint; it always returns
+// tcpip.ErrInvalidEndpointState.
+func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error {
+	unsupported := tcpip.ErrInvalidEndpointState
+	return unsupported
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+ }
+ return tcpip.ErrInvalidEndpointState
+}
+
+// Connect records addr as the peer, finds a route to it over the fake network
+// protocol, registers the endpoint with the stack so it can receive packets,
+// and keeps a clone of the route for later writes. It returns
+// tcpip.ErrNoRoute if no route is found, or the registration error.
+func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
+	f.peerAddr = addr.Addr
+
+	route, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber)
+	if err != nil {
+		return tcpip.ErrNoRoute
+	}
+	// The looked-up route is released on return; the endpoint keeps a clone.
+	defer route.Release()
+
+	// Register so that we can start receiving packets.
+	f.id.RemoteAddress = addr.Addr
+	netProtos := []tcpip.NetworkProtocolNumber{fakeNetNumber}
+	if regErr := f.stack.RegisterTransportEndpoint(0, netProtos, fakeTransNumber, f.id, f); regErr != nil {
+		return regErr
+	}
+
+	f.route = route.Clone()
+
+	return nil
+}
+
+// ConnectEndpoint is a stub: it ignores e and always returns nil.
+func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
+	_ = e
+	return nil
+}
+
+// Shutdown is a stub: it ignores the flags and always returns nil.
+func (*fakeTransportEndpoint) Shutdown(_ tcpip.ShutdownFlags) *tcpip.Error {
+	return nil
+}
+
+// Reset is a stub: it does nothing.
+func (*fakeTransportEndpoint) Reset() {}
+
+// Listen is a stub: it ignores the backlog and always returns nil.
+func (*fakeTransportEndpoint) Listen(_ int) *tcpip.Error {
+	return nil
+}
+
+func (*fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) {
+ return nil, nil, nil
+}
+
+// Bind ignores the address and returns whatever the commit callback returns.
+func (*fakeTransportEndpoint) Bind(_ tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error {
+	err := commit()
+	return err
+}
+
+func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, nil
+}
+
+func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return tcpip.FullAddress{}, nil
+}
+
+// HandlePacket counts a received packet in the protocol descriptor. The
+// route, endpoint ID and packet contents are ignored.
+func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) {
+	f.proto.packetCount += 1
+}
+
+// HandleControlPacket counts a received control packet in the protocol
+// descriptor. All arguments are ignored.
+func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *buffer.VectorisedView) {
+	f.proto.controlCount += 1
+}
+
+// fakeTransportGoodOption is the one option fakeTransportProtocol accepts:
+// SetOption stores it and Option reads it back.
+type fakeTransportGoodOption bool
+
+// fakeTransportBadOption is an option type that SetOption does not handle, so
+// setting it yields tcpip.ErrUnknownProtocolOption.
+type fakeTransportBadOption bool
+
+// fakeTransportInvalidValueOption is an option type for which SetOption always
+// returns tcpip.ErrInvalidOptionValue.
+type fakeTransportInvalidValueOption int
+
+// fakeTransportProtocolOptions holds the option state of
+// fakeTransportProtocol.
+type fakeTransportProtocolOptions struct {
+ good bool // last value set through fakeTransportGoodOption
+}
+
+// fakeTransportProtocol is a transport-layer protocol descriptor. It
+// aggregates the number of packets received via endpoints of this protocol.
+type fakeTransportProtocol struct {
+ packetCount int // incremented by fakeTransportEndpoint.HandlePacket
+ controlCount int // incremented by fakeTransportEndpoint.HandleControlPacket
+ opts fakeTransportProtocolOptions // written by SetOption, read by Option
+}
+
+func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
+ return fakeTransNumber
+}
+
+// NewEndpoint returns a new fake endpoint tied to this protocol descriptor.
+// The wait queue is ignored and the error is always nil.
+func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
+	ep := newFakeTransportEndpoint(stack, f, netProto)
+	return ep, nil
+}
+
+func (*fakeTransportProtocol) MinimumPacketSize() int {
+ return fakeTransHeaderLen
+}
+
+// ParsePorts is a stub: it ignores the view and reports source and
+// destination ports of zero with no error.
+func (*fakeTransportProtocol) ParsePorts(_ buffer.View) (src, dst uint16, err *tcpip.Error) {
+	return 0, 0, nil
+}
+
+func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool {
+ return true
+}
+
+func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case fakeTransportGoodOption:
+ f.opts.good = bool(v)
+ return nil
+ case fakeTransportInvalidValueOption:
+ return tcpip.ErrInvalidOptionValue
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error {
+ switch v := option.(type) {
+ case *fakeTransportGoodOption:
+ *v = fakeTransportGoodOption(f.opts.good)
+ return nil
+ default:
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// TestTransportReceive connects a fake transport endpoint to remote address
+// "\x02" and injects three packets: one with the wrong protocol, one from the
+// wrong source, and one that matches. Only the last may be counted.
+func TestTransportReceive(t *testing.T) {
+	linkID, linkEP := channel.New(10, defaultMTU, "")
+	stk := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+	if err := stk.CreateNIC(1, linkID); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	stk.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+	if err := stk.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	// Create endpoint and connect to remote address.
+	var wq waiter.Queue
+	ep, err := stk.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
+		t.Fatalf("Connect failed: %v", err)
+	}
+
+	fakeTrans := stk.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
+
+	// Buffer that holds the packet; it is rewritten before each injection.
+	var views [1]buffer.View
+	pkt := buffer.NewView(30)
+
+	// injectAndCheck injects the current contents of pkt and verifies the
+	// protocol's received-packet counter.
+	injectAndCheck := func(want int) {
+		t.Helper()
+		vv := pkt.ToVectorisedView(views)
+		linkEP.Inject(fakeNetNumber, &vv)
+		if got := fakeTrans.packetCount; got != want {
+			t.Errorf("packetCount = %d, want %d", got, want)
+		}
+	}
+
+	// Make sure packet with wrong protocol is not delivered.
+	pkt[0] = 1
+	pkt[2] = 0
+	injectAndCheck(0)
+
+	// Make sure packet from the wrong source is not delivered.
+	pkt[0] = 1
+	pkt[1] = 3
+	pkt[2] = byte(fakeTransNumber)
+	injectAndCheck(0)
+
+	// Make sure packet is delivered.
+	pkt[0] = 1
+	pkt[1] = 2
+	pkt[2] = byte(fakeTransNumber)
+	injectAndCheck(1)
+}
+
+// TestTransportControlReceive connects a fake transport endpoint to remote
+// address "\x02" and injects three control packets, each wrapping an inner
+// header placed fakeNetHeaderLen bytes into the buffer: one with the wrong
+// protocol, one from the wrong source, and one that matches. Only the last
+// may be counted as a control packet.
+func TestTransportControlReceive(t *testing.T) {
+	linkID, linkEP := channel.New(10, defaultMTU, "")
+	stk := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+	if err := stk.CreateNIC(1, linkID); err != nil {
+		t.Fatalf("CreateNIC failed: %v", err)
+	}
+
+	stk.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+	if err := stk.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+		t.Fatalf("AddAddress failed: %v", err)
+	}
+
+	// Create endpoint and connect to remote address.
+	var wq waiter.Queue
+	ep, err := stk.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+	if err != nil {
+		t.Fatalf("NewEndpoint failed: %v", err)
+	}
+	if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
+		t.Fatalf("Connect failed: %v", err)
+	}
+
+	fakeTrans := stk.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol)
+
+	// Buffer that holds the control packet; the inner header is rewritten
+	// before each injection.
+	var views [1]buffer.View
+	pkt := buffer.NewView(2*fakeNetHeaderLen + 30)
+	inner := pkt[fakeNetHeaderLen:]
+
+	// injectAndCheck injects the current contents of pkt and verifies the
+	// protocol's control-packet counter.
+	injectAndCheck := func(want int) {
+		t.Helper()
+		vv := pkt.ToVectorisedView(views)
+		linkEP.Inject(fakeNetNumber, &vv)
+		if got := fakeTrans.controlCount; got != want {
+			t.Errorf("controlCount = %d, want %d", got, want)
+		}
+	}
+
+	// Outer packet contains the control protocol number.
+	pkt[0] = 1
+	pkt[1] = 0xfe
+	pkt[2] = uint8(fakeControlProtocol)
+
+	// Make sure packet with wrong protocol is not delivered.
+	inner[0] = 0
+	inner[1] = 1
+	inner[2] = 0
+	injectAndCheck(0)
+
+	// Make sure packet from the wrong source is not delivered.
+	inner[0] = 3
+	inner[1] = 1
+	inner[2] = byte(fakeTransNumber)
+	injectAndCheck(0)
+
+	// Make sure packet is delivered.
+	inner[0] = 2
+	inner[1] = 1
+	inner[2] = byte(fakeTransNumber)
+	injectAndCheck(1)
+}
+
+func TestTransportSend(t *testing.T) {
+ id, _ := channel.New(10, defaultMTU, "")
+ s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+ if err := s.CreateNIC(1, id); err != nil {
+ t.Fatalf("CreateNIC failed: %v", err)
+ }
+
+ if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil {
+ t.Fatalf("AddAddress failed: %v", err)
+ }
+
+ s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}})
+
+ // Create endpoint and bind it.
+ wq := waiter.Queue{}
+ ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq)
+ if err != nil {
+ t.Fatalf("NewEndpoint failed: %v", err)
+ }
+
+ if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil {
+ t.Fatalf("Connect failed: %v", err)
+ }
+
+ // Create buffer that will hold the payload.
+ view := buffer.NewView(30)
+ _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{})
+ if err != nil {
+ t.Fatalf("write failed: %v", err)
+ }
+
+ fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol)
+
+ if fakeNet.sendPacketCount[2] != 1 {
+ t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1)
+ }
+}
+
+// TestTransportOptions exercises SetTransportProtocolOption and
+// TransportProtocolOption against the fake transport protocol. It covers an
+// unknown protocol number, an accepted option (which is read back), an
+// unknown option type, and an option carrying an invalid value.
+func TestTransportOptions(t *testing.T) {
+	s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"})
+
+	// Try an unsupported transport protocol.
+	if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol {
+		t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err)
+	}
+
+	testCases := []struct {
+		option   interface{}
+		wantErr  *tcpip.Error
+		verifier func(t *testing.T, p stack.TransportProtocol)
+	}{
+		{fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) {
+			t.Helper()
+			fakeTrans := p.(*fakeTransportProtocol)
+			if !fakeTrans.opts.good {
+				t.Fatalf("fakeTrans.opts.good = false, want = true")
+			}
+			var v fakeTransportGoodOption
+			if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil {
+				// The format expects the returned error first and the
+				// option (for %T) second.
+				t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", err, v)
+			}
+			if !v {
+				t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v)
+			}
+		}},
+		{fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil},
+		{fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil},
+	}
+	for _, tc := range testCases {
+		if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr {
+			t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr)
+		}
+		// Only cases that are expected to succeed carry a verifier.
+		if tc.verifier != nil {
+			tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber))
+		}
+	}
+}
+
+// init registers a factory for the fake transport protocol under the name
+// "fakeTrans" so that tests can request it by name from stack.New.
+func init() {
+	newFakeTrans := func() stack.TransportProtocol {
+		return &fakeTransportProtocol{}
+	}
+	stack.RegisterTransportProtocolFactory("fakeTrans", newFakeTrans)
+}