diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 70 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 313 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 256 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 453 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 322 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 133 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 811 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_global_state.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 760 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 166 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 420 |
11 files changed, 3713 insertions, 0 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD new file mode 100644 index 000000000..079ade2c8 --- /dev/null +++ b/pkg/tcpip/stack/BUILD @@ -0,0 +1,70 @@ +package(licenses = ["notice"]) # BSD + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_stateify") + +go_stateify( + name = "stack_state", + srcs = [ + "registration.go", + "stack.go", + ], + out = "stack_state.go", + package = "stack", +) + +go_library( + name = "stack", + srcs = [ + "linkaddrcache.go", + "nic.go", + "registration.go", + "route.go", + "stack.go", + "stack_global_state.go", + "stack_state.go", + "transport_demuxer.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/tcpip/stack", + visibility = [ + "//visibility:public", + ], + deps = [ + "//pkg/ilist", + "//pkg/sleep", + "//pkg/state", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/header", + "//pkg/tcpip/ports", + "//pkg/tcpip/seqnum", + "//pkg/waiter", + ], +) + +go_test( + name = "stack_x_test", + size = "small", + srcs = [ + "stack_test.go", + "transport_test.go", + ], + deps = [ + ":stack", + "//pkg/tcpip", + "//pkg/tcpip/buffer", + "//pkg/tcpip/link/channel", + "//pkg/waiter", + ], +) + +go_test( + name = "stack_test", + size = "small", + srcs = ["linkaddrcache_test.go"], + embed = [":stack"], + deps = [ + "//pkg/sleep", + "//pkg/tcpip", + ], +) diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go new file mode 100644 index 000000000..789f97882 --- /dev/null +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -0,0 +1,313 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stack + +import ( + "fmt" + "sync" + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" +) + +const linkAddrCacheSize = 512 // max cache entries + +// linkAddrCache is a fixed-sized cache mapping IP addresses to link addresses. +// +// The entries are stored in a ring buffer, oldest entry replaced first. +// +// This struct is safe for concurrent use. +type linkAddrCache struct { + // ageLimit is how long a cache entry is valid for. + ageLimit time.Duration + + // resolutionTimeout is the amount of time to wait for a link request to + // resolve an address. + resolutionTimeout time.Duration + + // resolutionAttempts is the number of times an address is attempted to be + // resolved before failing. + resolutionAttempts int + + mu sync.Mutex + cache map[tcpip.FullAddress]*linkAddrEntry + next int // array index of next available entry + entries [linkAddrCacheSize]linkAddrEntry +} + +// entryState controls the state of a single entry in the cache. +type entryState int + +const ( + // incomplete means that there is an outstanding request to resolve the + // address. This is the initial state. + incomplete entryState = iota + // ready means that the address has been resolved and can be used. + ready + // failed means that address resolution timed out and the address + // could not be resolved. + failed + // expired means that the cache entry has expired and the address must be + // resolved again. + expired +) + +// String implements Stringer. +func (s entryState) String() string { + switch s { + case incomplete: + return "incomplete" + case ready: + return "ready" + case failed: + return "failed" + case expired: + return "expired" + default: + return fmt.Sprintf("invalid entryState: %d", s) + } +} + +// A linkAddrEntry is an entry in the linkAddrCache. +// This struct is thread-compatible. +type linkAddrEntry struct { + addr tcpip.FullAddress + linkAddr tcpip.LinkAddress + expiration time.Time + s entryState + + // wakers is a set of waiters for address resolution result. Anytime + // state transitions out of 'incomplete' these waiters are notified. + wakers map[*sleep.Waker]struct{} + + cancel chan struct{} +} + +func (e *linkAddrEntry) state() entryState { + if e.s != expired && time.Now().After(e.expiration) { + // Force the transition to ensure waiters are notified. + e.changeState(expired) + } + return e.s +} + +func (e *linkAddrEntry) changeState(ns entryState) { + if e.s == ns { + return + } + + // Validate state transition. + switch e.s { + case incomplete: + // All transitions are valid. + case ready, failed: + if ns != expired { + panic(fmt.Sprintf("invalid state transition from %v to %v", e.s, ns)) + } + case expired: + // Terminal state. + panic(fmt.Sprintf("invalid state transition from %v to %v", e.s, ns)) + default: + panic(fmt.Sprintf("invalid state: %v", e.s)) + } + + // Notify whoever is waiting on address resolution when transitioning + // out of 'incomplete'. + if e.s == incomplete { + for w := range e.wakers { + w.Assert() + } + e.wakers = nil + } + e.s = ns +} + +func (e *linkAddrEntry) addWaker(w *sleep.Waker) { + e.wakers[w] = struct{}{} +} + +func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { + delete(e.wakers, w) +} + +// add adds a k -> v mapping to the cache. +func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { + c.mu.Lock() + defer c.mu.Unlock() + + entry := c.cache[k] + if entry != nil { + s := entry.state() + if s != expired && entry.linkAddr == v { + // Disregard repeated calls. + return + } + // Check if entry is waiting for address resolution. + if s == incomplete { + entry.linkAddr = v + } else { + // Otherwise create a new entry to replace it. + entry = c.makeAndAddEntry(k, v) + } + } else { + entry = c.makeAndAddEntry(k, v) + } + + entry.changeState(ready) +} + +// makeAndAddEntry is a helper function to create and add a new +// entry to the cache map and evict older entry as needed. +func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry { + // Take over the next entry. + entry := &c.entries[c.next] + if c.cache[entry.addr] == entry { + delete(c.cache, entry.addr) + } + + // Mark the soon-to-be-replaced entry as expired, just in case there is + // someone waiting for address resolution on it. + entry.changeState(expired) + if entry.cancel != nil { + entry.cancel <- struct{}{} + } + + *entry = linkAddrEntry{ + addr: k, + linkAddr: v, + expiration: time.Now().Add(c.ageLimit), + wakers: make(map[*sleep.Waker]struct{}), + cancel: make(chan struct{}, 1), + } + + c.cache[k] = entry + c.next++ + if c.next == len(c.entries) { + c.next = 0 + } + return entry +} + +// get reports any known link address for k. +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) { + if linkRes != nil { + if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { + return addr, nil + } + } + + c.mu.Lock() + entry := c.cache[k] + if entry == nil || entry.state() == expired { + c.mu.Unlock() + if linkRes == nil { + return "", tcpip.ErrNoLinkAddress + } + c.startAddressResolution(k, linkRes, localAddr, linkEP, waker) + return "", tcpip.ErrWouldBlock + } + defer c.mu.Unlock() + + switch s := entry.state(); s { + case expired: + // It's possible that entry expired between state() call above and here + // in that case it's safe to consider it ready. + fallthrough + case ready: + return entry.linkAddr, nil + case failed: + return "", tcpip.ErrNoLinkAddress + case incomplete: + // Address resolution is still in progress. + entry.addWaker(waker) + return "", tcpip.ErrWouldBlock + default: + panic(fmt.Sprintf("invalid cache entry state: %d", s)) + } +} + +// removeWaker removes a waker previously added through get(). +func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { + c.mu.Lock() + defer c.mu.Unlock() + + if entry := c.cache[k]; entry != nil { + entry.removeWaker(waker) + } +} + +func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) { + c.mu.Lock() + defer c.mu.Unlock() + + // Look up again with lock held to ensure entry wasn't added by someone else. + if e := c.cache[k]; e != nil && e.state() != expired { + return + } + + // Add 'incomplete' entry in the cache to mark that resolution is in progress. + e := c.makeAndAddEntry(k, "") + e.addWaker(waker) + + go func() { // S/R-FIXME + for i := 0; ; i++ { + // Send link request, then wait for the timeout limit and check + // whether the request succeeded. + linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) + c.mu.Lock() + cancel := e.cancel + c.mu.Unlock() + + select { + case <-time.After(c.resolutionTimeout): + if stop := c.checkLinkRequest(k, i); stop { + return + } + case <-cancel: + return + } + } + }() +} + +// checkLinkRequest checks whether previous attempt to resolve address has succeeded +// and mark the entry accordingly, e.g. ready, failed, etc. Return true if request +// can stop, false if another request should be sent. +func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool { + c.mu.Lock() + defer c.mu.Unlock() + + entry, ok := c.cache[k] + if !ok { + // Entry was evicted from the cache. + return true + } + + switch s := entry.state(); s { + case ready, failed, expired: + // Entry was made ready by resolver or failed. Either way we're done. + return true + case incomplete: + if attempt+1 >= c.resolutionAttempts { + // Max number of retries reached, mark entry as failed. + entry.changeState(failed) + return true + } + // No response yet, need to send another ARP request. + return false + default: + panic(fmt.Sprintf("invalid cache entry state: %d", s)) + } +} + +func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { + return &linkAddrCache{ + ageLimit: ageLimit, + resolutionTimeout: resolutionTimeout, + resolutionAttempts: resolutionAttempts, + cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize), + } +} diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go new file mode 100644 index 000000000..e9897b2bd --- /dev/null +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -0,0 +1,256 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stack + +import ( + "fmt" + "sync" + "testing" + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" +) + +type testaddr struct { + addr tcpip.FullAddress + linkAddr tcpip.LinkAddress +} + +var testaddrs []testaddr + +type testLinkAddressResolver struct { + cache *linkAddrCache + delay time.Duration +} + +func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { + go func() { + if r.delay > 0 { + time.Sleep(r.delay) + } + r.fakeRequest(addr) + }() + return nil +} + +func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { + for _, ta := range testaddrs { + if ta.addr.Addr == addr { + r.cache.add(ta.addr, ta.linkAddr) + break + } + } +} + +func (*testLinkAddressResolver) ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) { + if addr == "broadcast" { + return "mac_broadcast", true + } + return "", false +} + +func (*testLinkAddressResolver) LinkAddressProtocol() tcpip.NetworkProtocolNumber { + return 1 +} + +func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressResolver) (tcpip.LinkAddress, *tcpip.Error) { + w := sleep.Waker{} + s := sleep.Sleeper{} + s.AddWaker(&w, 123) + defer s.Done() + + for { + if got, err := c.get(addr, linkRes, "", nil, &w); err != tcpip.ErrWouldBlock { + return got, err + } + s.Fetch(true) + } +} + +func init() { + for i := 0; i < 4*linkAddrCacheSize; i++ { + addr := fmt.Sprintf("Addr%06d", i) + testaddrs = append(testaddrs, testaddr{ + addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + linkAddr: tcpip.LinkAddress("Link" + addr), + }) + } +} + +func TestCacheOverflow(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + for i := len(testaddrs) - 1; i >= 0; i-- { + e := testaddrs[i] + c.add(e.addr, e.linkAddr) + got, err := c.get(e.addr, nil, "", nil, nil) + if err != nil { + t.Errorf("insert %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("insert %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + } + } + // Expect to find at least half of the most recent entries. + for i := 0; i < linkAddrCacheSize/2; i++ { + e := testaddrs[i] + got, err := c.get(e.addr, nil, "", nil, nil) + if err != nil { + t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(e.addr.Addr), got, e.linkAddr) + } + } + // The earliest entries should no longer be in the cache. + for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { + e := testaddrs[i] + if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) + } + } +} + +func TestCacheConcurrent(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + + var wg sync.WaitGroup + for r := 0; r < 16; r++ { + wg.Add(1) + go func() { + for _, e := range testaddrs { + c.add(e.addr, e.linkAddr) + c.get(e.addr, nil, "", nil, nil) // make work for gotsan + } + wg.Done() + }() + } + wg.Wait() + + // All goroutines add in the same order and add more values than + // can fit in the cache, so our eviction strategy requires that + // the last entry be present and the first be missing. + e := testaddrs[len(testaddrs)-1] + got, err := c.get(e.addr, nil, "", nil, nil) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + } + + e = testaddrs[0] + if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + } +} + +func TestCacheAgeLimit(t *testing.T) { + c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) + e := testaddrs[0] + c.add(e.addr, e.linkAddr) + time.Sleep(50 * time.Millisecond) + if _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { + t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + } +} + +func TestCacheReplace(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) + e := testaddrs[0] + l2 := e.linkAddr + "2" + c.add(e.addr, e.linkAddr) + got, err := c.get(e.addr, nil, "", nil, nil) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + } + + c.add(e.addr, l2) + got, err = c.get(e.addr, nil, "", nil, nil) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + } + if got != l2 { + t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, l2) + } +} + +func TestCacheResolution(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) + linkRes := &testLinkAddressResolver{cache: c} + for i, ta := range testaddrs { + got, err := getBlocking(c, ta.addr, linkRes) + if err != nil { + t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) + } + if got != ta.linkAddr { + t.Errorf("check %d, c.get(%q)=%q, want %q", i, string(ta.addr.Addr), got, ta.linkAddr) + } + } + + // Check that after resolved, address stays in the cache and never returns WouldBlock. + for i := 0; i < 10; i++ { + e := testaddrs[len(testaddrs)-1] + got, err := c.get(e.addr, linkRes, "", nil, nil) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + } + } +} + +func TestCacheResolutionFailed(t *testing.T) { + c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) + linkRes := &testLinkAddressResolver{cache: c} + + // First, sanity check that resolution is working... + e := testaddrs[0] + got, err := getBlocking(c, e.addr, linkRes) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) + } + if got != e.linkAddr { + t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) + } + + e.addr.Addr += "2" + if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { + t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + } +} + +func TestCacheResolutionTimeout(t *testing.T) { + resolverDelay := 50 * time.Millisecond + expiration := resolverDelay / 2 + c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) + linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} + + e := testaddrs[0] + if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { + t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) + } +} + +// TestStaticResolution checks that static link addresses are resolved immediately and don't +// send resolution requests. +func TestStaticResolution(t *testing.T) { + c := newLinkAddrCache(1<<63-1, time.Millisecond, 1) + linkRes := &testLinkAddressResolver{cache: c, delay: time.Minute} + + addr := tcpip.Address("broadcast") + want := tcpip.LinkAddress("mac_broadcast") + got, err := c.get(tcpip.FullAddress{Addr: addr}, linkRes, "", nil, nil) + if err != nil { + t.Errorf("c.get(%q)=%q, got error: %v", string(addr), string(got), err) + } + if got != want { + t.Errorf("c.get(%q)=%q, want %q", string(addr), string(got), string(want)) + } +} diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go new file mode 100644 index 000000000..8ff4310d5 --- /dev/null +++ b/pkg/tcpip/stack/nic.go @@ -0,0 +1,453 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stack + +import ( + "strings" + "sync" + "sync/atomic" + + "gvisor.googlesource.com/gvisor/pkg/ilist" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +// NIC represents a "network interface card" to which the networking stack is +// attached. +type NIC struct { + stack *Stack + id tcpip.NICID + name string + linkEP LinkEndpoint + + demux *transportDemuxer + + mu sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber]*ilist.List + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + subnets []tcpip.Subnet +} + +func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint) *NIC { + return &NIC{ + stack: stack, + id: id, + name: name, + linkEP: ep, + demux: newTransportDemuxer(stack), + primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), + endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), + } +} + +// attachLinkEndpoint attaches the NIC to the endpoint, which will enable it +// to start delivering packets. +func (n *NIC) attachLinkEndpoint() { + n.linkEP.Attach(n) +} + +// setPromiscuousMode enables or disables promiscuous mode. +func (n *NIC) setPromiscuousMode(enable bool) { + n.mu.Lock() + n.promiscuous = enable + n.mu.Unlock() +} + +// setSpoofing enables or disables address spoofing. +func (n *NIC) setSpoofing(enable bool) { + n.mu.Lock() + n.spoofing = enable + n.mu.Unlock() +} + +// primaryEndpoint returns the primary endpoint of n for the given network +// protocol. +func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { + n.mu.RLock() + defer n.mu.RUnlock() + + list := n.primary[protocol] + if list == nil { + return nil + } + + for e := list.Front(); e != nil; e = e.Next() { + r := e.(*referencedNetworkEndpoint) + if r.tryIncRef() { + return r + } + } + + return nil +} + +// findEndpoint finds the endpoint, if any, with the given address. +func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address) *referencedNetworkEndpoint { + id := NetworkEndpointID{address} + + n.mu.RLock() + ref := n.endpoints[id] + if ref != nil && !ref.tryIncRef() { + ref = nil + } + spoofing := n.spoofing + n.mu.RUnlock() + + if ref != nil || !spoofing { + return ref + } + + // Try again with the lock in exclusive mode. If we still can't get the + // endpoint, create a new "temporary" endpoint. It will only exist while + // there's a route through it. + n.mu.Lock() + ref = n.endpoints[id] + if ref == nil || !ref.tryIncRef() { + ref, _ = n.addAddressLocked(protocol, address, true) + if ref != nil { + ref.holdsInsertRef = false + } + } + n.mu.Unlock() + return ref +} + +func (n *NIC) addAddressLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + return nil, tcpip.ErrUnknownProtocol + } + + // Create the new network endpoint. + ep, err := netProto.NewEndpoint(n.id, addr, n.stack, n, n.linkEP) + if err != nil { + return nil, err + } + + id := *ep.ID() + if ref, ok := n.endpoints[id]; ok { + if !replace { + return nil, tcpip.ErrDuplicateAddress + } + + n.removeEndpointLocked(ref) + } + + ref := &referencedNetworkEndpoint{ + refs: 1, + ep: ep, + nic: n, + protocol: protocol, + holdsInsertRef: true, + } + + // Set up cache if link address resolution exists for this protocol. + if n.linkEP.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes := n.stack.linkAddrResolvers[protocol]; linkRes != nil { + ref.linkCache = n.stack + } + } + + n.endpoints[id] = ref + + l, ok := n.primary[protocol] + if !ok { + l = &ilist.List{} + n.primary[protocol] = l + } + + l.PushBack(ref) + + return ref, nil +} + +// AddAddress adds a new address to n, so that it starts accepting packets +// targeted at the given address (and network protocol). +func (n *NIC) AddAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + // Add the endpoint. + n.mu.Lock() + _, err := n.addAddressLocked(protocol, addr, false) + n.mu.Unlock() + + return err +} + +// Addresses returns the addresses associated with this NIC. +func (n *NIC) Addresses() []tcpip.ProtocolAddress { + n.mu.RLock() + defer n.mu.RUnlock() + addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) + for nid, ep := range n.endpoints { + addrs = append(addrs, tcpip.ProtocolAddress{ + Protocol: ep.protocol, + Address: nid.LocalAddress, + }) + } + return addrs +} + +// AddSubnet adds a new subnet to n, so that it starts accepting packets +// targeted at the given address and network protocol. +func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { + n.mu.Lock() + n.subnets = append(n.subnets, subnet) + n.mu.Unlock() +} + +// Subnets returns the Subnets associated with this NIC. +func (n *NIC) Subnets() []tcpip.Subnet { + n.mu.RLock() + defer n.mu.RUnlock() + sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints)) + for nid := range n.endpoints { + sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) + if err != nil { + // This should never happen as the mask has been carefully crafted to + // match the address. + panic("Invalid endpoint subnet: " + err.Error()) + } + sns = append(sns, sn) + } + return append(sns, n.subnets...) +} + +func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { + id := *r.ep.ID() + + // Nothing to do if the reference has already been replaced with a + // different one. + if n.endpoints[id] != r { + return + } + + if r.holdsInsertRef { + panic("Reference count dropped to zero before being removed") + } + + delete(n.endpoints, id) + n.primary[r.protocol].Remove(r) + r.ep.Close() +} + +func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { + n.mu.Lock() + n.removeEndpointLocked(r) + n.mu.Unlock() +} + +// RemoveAddress removes an address from n. +func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { + n.mu.Lock() + r := n.endpoints[NetworkEndpointID{addr}] + if r == nil || !r.holdsInsertRef { + n.mu.Unlock() + return tcpip.ErrBadLocalAddress + } + + r.holdsInsertRef = false + n.mu.Unlock() + + r.decRef() + + return nil +} + +// DeliverNetworkPacket finds the appropriate network protocol endpoint and +// hands the packet over for further processing. This function is called when +// the NIC receives a packet from the physical interface. +// Note that the ownership of the slice backing vv is retained by the caller. +// This rule applies only to the slice itself, not to the items of the slice; +// the ownership of the items is not retained by the caller. +func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) { + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + atomic.AddUint64(&n.stack.stats.UnknownProtocolRcvdPackets, 1) + return + } + + if len(vv.First()) < netProto.MinimumPacketSize() { + atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1) + return + } + + src, dst := netProto.ParseAddresses(vv.First()) + id := NetworkEndpointID{dst} + + n.mu.RLock() + ref := n.endpoints[id] + if ref != nil && !ref.tryIncRef() { + ref = nil + } + promiscuous := n.promiscuous + subnets := n.subnets + n.mu.RUnlock() + + if ref == nil { + // Check if the packet is for a subnet this NIC cares about. + if !promiscuous { + for _, sn := range subnets { + if sn.Contains(dst) { + promiscuous = true + break + } + } + } + if promiscuous { + // Try again with the lock in exclusive mode. If we still can't + // get the endpoint, create a new "temporary" one. It will only + // exist while there's a route through it. + n.mu.Lock() + ref = n.endpoints[id] + if ref == nil || !ref.tryIncRef() { + ref, _ = n.addAddressLocked(protocol, dst, true) + if ref != nil { + ref.holdsInsertRef = false + } + } + n.mu.Unlock() + } + } + + if ref == nil { + atomic.AddUint64(&n.stack.stats.UnknownNetworkEndpointRcvdPackets, 1) + return + } + + r := makeRoute(protocol, dst, src, ref) + r.LocalLinkAddress = linkEP.LinkAddress() + r.RemoteLinkAddress = remoteLinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() +} + +// DeliverTransportPacket delivers the packets to the appropriate transport +// protocol endpoint. +func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) { + state, ok := n.stack.transportProtocols[protocol] + if !ok { + atomic.AddUint64(&n.stack.stats.UnknownProtocolRcvdPackets, 1) + return + } + + transProto := state.proto + if len(vv.First()) < transProto.MinimumPacketSize() { + atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1) + return + } + + srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + if err != nil { + atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1) + return + } + + id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} + if n.demux.deliverPacket(r, protocol, vv, id) { + return + } + if n.stack.demux.deliverPacket(r, protocol, vv, id) { + return + } + + // Try to deliver to per-stack default handler. + if state.defaultHandler != nil { + if state.defaultHandler(r, id, vv) { + return + } + } + + // We could not find an appropriate destination for this packet, so + // deliver it to the global handler. + if !transProto.HandleUnknownDestinationPacket(r, id, vv) { + atomic.AddUint64(&n.stack.stats.MalformedRcvdPackets, 1) + } +} + +// DeliverTransportControlPacket delivers control packets to the appropriate +// transport protocol endpoint. +func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView) { + state, ok := n.stack.transportProtocols[trans] + if !ok { + return + } + + transProto := state.proto + + // ICMPv4 only guarantees that 8 bytes of the transport protocol will + // be present in the payload. We know that the ports are within the + // first 8 bytes for all known transport protocols. + if len(vv.First()) < 8 { + return + } + + srcPort, dstPort, err := transProto.ParsePorts(vv.First()) + if err != nil { + return + } + + id := TransportEndpointID{srcPort, local, dstPort, remote} + if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + return + } + if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + return + } +} + +// ID returns the identifier of n. +func (n *NIC) ID() tcpip.NICID { + return n.id +} + +type referencedNetworkEndpoint struct { + ilist.Entry + refs int32 + ep NetworkEndpoint + nic *NIC + protocol tcpip.NetworkProtocolNumber + + // linkCache is set if link address resolution is enabled for this + // protocol. Set to nil otherwise. + linkCache LinkAddressCache + + // holdsInsertRef is protected by the NIC's mutex. It indicates whether + // the reference count is biased by 1 due to the insertion of the + // endpoint. It is reset to false when RemoveAddress is called on the + // NIC. + holdsInsertRef bool +} + +// decRef decrements the ref count and cleans up the endpoint once it reaches +// zero. +func (r *referencedNetworkEndpoint) decRef() { + if atomic.AddInt32(&r.refs, -1) == 0 { + r.nic.removeEndpoint(r) + } +} + +// incRef increments the ref count. It must only be called when the caller is +// known to be holding a reference to the endpoint, otherwise tryIncRef should +// be used. +func (r *referencedNetworkEndpoint) incRef() { + atomic.AddInt32(&r.refs, 1) +} + +// tryIncRef attempts to increment the ref count from n to n+1, but only if n is +// not zero. That is, it will increment the count if the endpoint is still +// alive, and do nothing if it has already been clean up. +func (r *referencedNetworkEndpoint) tryIncRef() bool { + for { + v := atomic.LoadInt32(&r.refs) + if v == 0 { + return false + } + + if atomic.CompareAndSwapInt32(&r.refs, v, v+1) { + return true + } + } +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go new file mode 100644 index 000000000..e7e6381ac --- /dev/null +++ b/pkg/tcpip/stack/registration.go @@ -0,0 +1,322 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stack + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +// NetworkEndpointID is the identifier of a network layer protocol endpoint. +// Currently the local address is sufficient because all supported protocols +// (i.e., IPv4 and IPv6) have different sizes for their addresses. +type NetworkEndpointID struct { + LocalAddress tcpip.Address +} + +// TransportEndpointID is the identifier of a transport layer protocol endpoint. +type TransportEndpointID struct { + // LocalPort is the local port associated with the endpoint. + LocalPort uint16 + + // LocalAddress is the local [network layer] address associated with + // the endpoint. + LocalAddress tcpip.Address + + // RemotePort is the remote port associated with the endpoint. + RemotePort uint16 + + // RemoteAddress it the remote [network layer] address associated with + // the endpoint. + RemoteAddress tcpip.Address +} + +// ControlType is the type of network control message. +type ControlType int + +// The following are the allowed values for ControlType values. +const ( + ControlPacketTooBig ControlType = iota + ControlPortUnreachable + ControlUnknown +) + +// TransportEndpoint is the interface that needs to be implemented by transport +// protocol (e.g., tcp, udp) endpoints that can handle packets. +type TransportEndpoint interface { + // HandlePacket is called by the stack when new packets arrive to + // this transport endpoint. + HandlePacket(r *Route, id TransportEndpointID, vv *buffer.VectorisedView) + + // HandleControlPacket is called by the stack when new control (e.g., + // ICMP) packets arrive to this transport endpoint. + HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv *buffer.VectorisedView) +} + +// TransportProtocol is the interface that needs to be implemented by transport +// protocols (e.g., tcp, udp) that want to be part of the networking stack. +type TransportProtocol interface { + // Number returns the transport protocol number. + Number() tcpip.TransportProtocolNumber + + // NewEndpoint creates a new endpoint of the transport protocol. + NewEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, waitQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) + + // MinimumPacketSize returns the minimum valid packet size of this + // transport protocol. The stack automatically drops any packets smaller + // than this targeted at this protocol. + MinimumPacketSize() int + + // ParsePorts returns the source and destination ports stored in a + // packet of this protocol. + ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) + + // HandleUnknownDestinationPacket handles packets targeted at this + // protocol but that don't match any existing endpoint. For example, + // it is targeted at a port that have no listeners. + // + // The return value indicates whether the packet was well-formed (for + // stats purposes only). + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv *buffer.VectorisedView) bool + + // SetOption allows enabling/disabling protocol specific features. + // SetOption returns an error if the option is not supported or the + // provided option value is invalid. + SetOption(option interface{}) *tcpip.Error + + // Option allows retrieving protocol specific option values. + // Option returns an error if the option is not supported or the + // provided option value is invalid. + Option(option interface{}) *tcpip.Error +} + +// TransportDispatcher contains the methods used by the network stack to deliver +// packets to the appropriate transport endpoint after it has been handled by +// the network layer. +type TransportDispatcher interface { + // DeliverTransportPacket delivers packets to the appropriate + // transport protocol endpoint. + DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView) + + // DeliverTransportControlPacket delivers control packets to the + // appropriate transport protocol endpoint. + DeliverTransportControlPacket(local, remote tcpip.Address, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView) +} + +// NetworkEndpoint is the interface that needs to be implemented by endpoints +// of network layer protocols (e.g., ipv4, ipv6). +type NetworkEndpoint interface { + // MTU is the maximum transmission unit for this endpoint. This is + // generally calculated as the MTU of the underlying data link endpoint + // minus the network endpoint max header length. + MTU() uint32 + + // Capabilities returns the set of capabilities supported by the + // underlying link-layer endpoint. + Capabilities() LinkEndpointCapabilities + + // MaxHeaderLength returns the maximum size the network (and lower + // level layers combined) headers can have. Higher levels use this + // information to reserve space in the front of the packets they're + // building. + MaxHeaderLength() uint16 + + // WritePacket writes a packet to the given destination address and + // protocol. + WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error + + // ID returns the network protocol endpoint ID. + ID() *NetworkEndpointID + + // NICID returns the id of the NIC this endpoint belongs to. + NICID() tcpip.NICID + + // HandlePacket is called by the link layer when new packets arrive to + // this network endpoint. + HandlePacket(r *Route, vv *buffer.VectorisedView) + + // Close is called when the endpoint is reomved from a stack. + Close() +} + +// NetworkProtocol is the interface that needs to be implemented by network +// protocols (e.g., ipv4, ipv6) that want to be part of the networking stack. +type NetworkProtocol interface { + // Number returns the network protocol number. + Number() tcpip.NetworkProtocolNumber + + // MinimumPacketSize returns the minimum valid packet size of this + // network protocol. The stack automatically drops any packets smaller + // than this targeted at this protocol. + MinimumPacketSize() int + + // ParsePorts returns the source and destination addresses stored in a + // packet of this protocol. + ParseAddresses(v buffer.View) (src, dst tcpip.Address) + + // NewEndpoint creates a new endpoint of this protocol. + NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache LinkAddressCache, dispatcher TransportDispatcher, sender LinkEndpoint) (NetworkEndpoint, *tcpip.Error) + + // SetOption allows enabling/disabling protocol specific features. + // SetOption returns an error if the option is not supported or the + // provided option value is invalid. + SetOption(option interface{}) *tcpip.Error + + // Option allows retrieving protocol specific option values. + // Option returns an error if the option is not supported or the + // provided option value is invalid. + Option(option interface{}) *tcpip.Error +} + +// NetworkDispatcher contains the methods used by the network stack to deliver +// packets to the appropriate network endpoint after it has been handled by +// the data link layer. +type NetworkDispatcher interface { + // DeliverNetworkPacket finds the appropriate network protocol + // endpoint and hands the packet over for further processing. + DeliverNetworkPacket(linkEP LinkEndpoint, remoteLinkAddr tcpip.LinkAddress, protocol tcpip.NetworkProtocolNumber, vv *buffer.VectorisedView) +} + +// LinkEndpointCapabilities is the type associated with the capabilities +// supported by a link-layer endpoint. It is a set of bitfields. +type LinkEndpointCapabilities uint + +// The following are the supported link endpoint capabilities. +const ( + CapabilityChecksumOffload LinkEndpointCapabilities = 1 << iota + CapabilityResolutionRequired +) + +// LinkEndpoint is the interface implemented by data link layer protocols (e.g., +// ethernet, loopback, raw) and used by network layer protocols to send packets +// out through the implementer's data link endpoint. +type LinkEndpoint interface { + // MTU is the maximum transmission unit for this endpoint. This is + // usually dictated by the backing physical network; when such a + // physical network doesn't exist, the limit is generally 64k, which + // includes the maximum size of an IP packet. + MTU() uint32 + + // Capabilities returns the set of capabilities supported by the + // endpoint. + Capabilities() LinkEndpointCapabilities + + // MaxHeaderLength returns the maximum size the data link (and + // lower level layers combined) headers can have. Higher levels use this + // information to reserve space in the front of the packets they're + // building. + MaxHeaderLength() uint16 + + // LinkAddress returns the link address (typically a MAC) of the + // link endpoint. + LinkAddress() tcpip.LinkAddress + + // WritePacket writes a packet with the given protocol through the given + // route. + WritePacket(r *Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.NetworkProtocolNumber) *tcpip.Error + + // Attach attaches the data link layer endpoint to the network-layer + // dispatcher of the stack. + Attach(dispatcher NetworkDispatcher) +} + +// A LinkAddressResolver is an extension to a NetworkProtocol that +// can resolve link addresses. +type LinkAddressResolver interface { + // LinkAddressRequest sends a request for the LinkAddress of addr. + // The request is sent on linkEP with localAddr as the source. + // + // A valid response will cause the discovery protocol's network + // endpoint to call AddLinkAddress. + LinkAddressRequest(addr, localAddr tcpip.Address, linkEP LinkEndpoint) *tcpip.Error + + // ResolveStaticAddress attempts to resolve address without sending + // requests. It either resolves the name immediately or returns the + // empty LinkAddress. + // + // It can be used to resolve broadcast addresses for example. + ResolveStaticAddress(addr tcpip.Address) (tcpip.LinkAddress, bool) + + // LinkAddressProtocol returns the network protocol of the + // addresses this this resolver can resolve. + LinkAddressProtocol() tcpip.NetworkProtocolNumber +} + +// A LinkAddressCache caches link addresses. +type LinkAddressCache interface { + // CheckLocalAddress determines if the given local address exists, and if it + // does not exist. + CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID + + // AddLinkAddress adds a link address to the cache. + AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) + + // GetLinkAddress looks up the cache to translate address to link address (e.g. IP -> MAC). + // If the LinkEndpoint requests address resolution and there is a LinkAddressResolver + // registered with the network protocol, the cache attempts to resolve the address + // and returns ErrWouldBlock. Waker is notified when address resolution is + // complete (success or not). + GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, w *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) + + // RemoveWaker removes a waker that has been added in GetLinkAddress(). + RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) +} + +// TransportProtocolFactory functions are used by the stack to instantiate +// transport protocols. +type TransportProtocolFactory func() TransportProtocol + +// NetworkProtocolFactory provides methods to be used by the stack to +// instantiate network protocols. +type NetworkProtocolFactory func() NetworkProtocol + +var ( + transportProtocols = make(map[string]TransportProtocolFactory) + networkProtocols = make(map[string]NetworkProtocolFactory) + + linkEPMu sync.RWMutex + nextLinkEndpointID tcpip.LinkEndpointID = 1 + linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) +) + +// RegisterTransportProtocolFactory registers a new transport protocol factory +// with the stack so that it becomes available to users of the stack. This +// function is intended to be called by init() functions of the protocols. +func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) { + transportProtocols[name] = p +} + +// RegisterNetworkProtocolFactory registers a new network protocol factory with +// the stack so that it becomes available to users of the stack. This function +// is intended to be called by init() functions of the protocols. +func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { + networkProtocols[name] = p +} + +// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an +// ID that can be used to refer to it. +func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { + linkEPMu.Lock() + defer linkEPMu.Unlock() + + v := nextLinkEndpointID + nextLinkEndpointID++ + + linkEndpoints[v] = linkEP + + return v +} + +// FindLinkEndpoint finds the link endpoint associated with the given ID. +func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { + linkEPMu.RLock() + defer linkEPMu.RUnlock() + + return linkEndpoints[id] +} diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go new file mode 100644 index 000000000..12f5efba5 --- /dev/null +++ b/pkg/tcpip/stack/route.go @@ -0,0 +1,133 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stack + +import ( + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" +) + +// Route represents a route through the networking stack to a given destination. +type Route struct { + // RemoteAddress is the final destination of the route. + RemoteAddress tcpip.Address + + // RemoteLinkAddress is the link-layer (MAC) address of the + // final destination of the route. + RemoteLinkAddress tcpip.LinkAddress + + // LocalAddress is the local address where the route starts. + LocalAddress tcpip.Address + + // LocalLinkAddress is the link-layer (MAC) address of the + // where the route starts. + LocalLinkAddress tcpip.LinkAddress + + // NextHop is the next node in the path to the destination. + NextHop tcpip.Address + + // NetProto is the network-layer protocol. + NetProto tcpip.NetworkProtocolNumber + + // ref a reference to the network endpoint through which the route + // starts. + ref *referencedNetworkEndpoint +} + +// makeRoute initializes a new route. It takes ownership of the provided +// reference to a network endpoint. +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, ref *referencedNetworkEndpoint) Route { + return Route{ + NetProto: netProto, + LocalAddress: localAddr, + RemoteAddress: remoteAddr, + ref: ref, + } +} + +// NICID returns the id of the NIC from which this route originates. +func (r *Route) NICID() tcpip.NICID { + return r.ref.ep.NICID() +} + +// MaxHeaderLength forwards the call to the network endpoint's implementation. +func (r *Route) MaxHeaderLength() uint16 { + return r.ref.ep.MaxHeaderLength() +} + +// PseudoHeaderChecksum forwards the call to the network endpoint's +// implementation. +func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber) uint16 { + return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress) +} + +// Capabilities returns the link-layer capabilities of the route. +func (r *Route) Capabilities() LinkEndpointCapabilities { + return r.ref.ep.Capabilities() +} + +// Resolve attempts to resolve the link address if necessary. Returns ErrWouldBlock in +// case address resolution requires blocking, e.g. wait for ARP reply. Waker is +// notified when address resolution is complete (success or not). +func (r *Route) Resolve(waker *sleep.Waker) *tcpip.Error { + if !r.IsResolutionRequired() { + // Nothing to do if there is no cache (which does the resolution on cache miss) or + // link address is already known. + return nil + } + + nextAddr := r.NextHop + if nextAddr == "" { + nextAddr = r.RemoteAddress + } + linkAddr, err := r.ref.linkCache.GetLinkAddress(r.ref.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + if err != nil { + return err + } + r.RemoteLinkAddress = linkAddr + return nil +} + +// RemoveWaker removes a waker that has been added in Resolve(). +func (r *Route) RemoveWaker(waker *sleep.Waker) { + nextAddr := r.NextHop + if nextAddr == "" { + nextAddr = r.RemoteAddress + } + r.ref.linkCache.RemoveWaker(r.ref.nic.ID(), nextAddr, waker) +} + +// IsResolutionRequired returns true if Resolve() must be called to resolve +// the link address before the this route can be written to. +func (r *Route) IsResolutionRequired() bool { + return r.ref.linkCache != nil && r.RemoteLinkAddress == "" +} + +// WritePacket writes the packet through the given route. +func (r *Route) WritePacket(hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error { + return r.ref.ep.WritePacket(r, hdr, payload, protocol) +} + +// MTU returns the MTU of the underlying network endpoint. +func (r *Route) MTU() uint32 { + return r.ref.ep.MTU() +} + +// Release frees all resources associated with the route. +func (r *Route) Release() { + if r.ref != nil { + r.ref.decRef() + r.ref = nil + } +} + +// Clone Clone a route such that the original one can be released and the new +// one will remain valid. +func (r *Route) Clone() Route { + r.ref.incRef() + return *r +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go new file mode 100644 index 000000000..558ecdb72 --- /dev/null +++ b/pkg/tcpip/stack/stack.go @@ -0,0 +1,811 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package stack provides the glue between networking protocols and the +// consumers of the networking stack. +// +// For consumers, the only function of interest is New(), everything else is +// provided by the tcpip/public package. +// +// For protocol implementers, RegisterTransportProtocolFactory() and +// RegisterNetworkProtocolFactory() are used to register protocol factories with +// the stack, which will then be used to instantiate protocol objects when +// consumers interact with the stack. +package stack + +import ( + "sync" + "sync/atomic" + "time" + + "gvisor.googlesource.com/gvisor/pkg/sleep" + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/header" + "gvisor.googlesource.com/gvisor/pkg/tcpip/ports" + "gvisor.googlesource.com/gvisor/pkg/tcpip/seqnum" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + // ageLimit is set to the same cache stale time used in Linux. + ageLimit = 1 * time.Minute + // resolutionTimeout is set to the same ARP timeout used in Linux. + resolutionTimeout = 1 * time.Second + // resolutionAttempts is set to the same ARP retries used in Linux. + resolutionAttempts = 3 +) + +type transportProtocolState struct { + proto TransportProtocol + defaultHandler func(*Route, TransportEndpointID, *buffer.VectorisedView) bool +} + +// TCPProbeFunc is the expected function type for a TCP probe function to be +// passed to stack.AddTCPProbe. +type TCPProbeFunc func(s TCPEndpointState) + +// TCPEndpointID is the unique 4 tuple that identifies a given endpoint. +type TCPEndpointID struct { + // LocalPort is the local port associated with the endpoint. + LocalPort uint16 + + // LocalAddress is the local [network layer] address associated with + // the endpoint. + LocalAddress tcpip.Address + + // RemotePort is the remote port associated with the endpoint. + RemotePort uint16 + + // RemoteAddress it the remote [network layer] address associated with + // the endpoint. + RemoteAddress tcpip.Address +} + +// TCPFastRecoveryState holds a copy of the internal fast recovery state of a +// TCP endpoint. +type TCPFastRecoveryState struct { + // Active if true indicates the endpoint is in fast recovery. + Active bool + + // First is the first unacknowledged sequence number being recovered. + First seqnum.Value + + // Last is the 'recover' sequence number that indicates the point at + // which we should exit recovery barring any timeouts etc. + Last seqnum.Value + + // MaxCwnd is the maximum value we are permitted to grow the congestion + // window during recovery. This is set at the time we enter recovery. + MaxCwnd int +} + +// TCPReceiverState holds a copy of the internal state of the receiver for +// a given TCP endpoint. +type TCPReceiverState struct { + // RcvNxt is the TCP variable RCV.NXT. + RcvNxt seqnum.Value + + // RcvAcc is the TCP variable RCV.ACC. + RcvAcc seqnum.Value + + // RcvWndScale is the window scaling to use for inbound segments. + RcvWndScale uint8 + + // PendingBufUsed is the number of bytes pending in the receive + // queue. + PendingBufUsed seqnum.Size + + // PendingBufSize is the size of the socket receive buffer. + PendingBufSize seqnum.Size +} + +// TCPSenderState holds a copy of the internal state of the sender for +// a given TCP Endpoint. +type TCPSenderState struct { + // LastSendTime is the time at which we sent the last segment. + LastSendTime time.Time + + // DupAckCount is the number of Duplicate ACK's received. + DupAckCount int + + // SndCwnd is the size of the sending congestion window in packets. + SndCwnd int + + // Ssthresh is the slow start threshold in packets. + Ssthresh int + + // SndCAAckCount is the number of packets consumed in congestion + // avoidance mode. + SndCAAckCount int + + // Outstanding is the number of packets in flight. + Outstanding int + + // SndWnd is the send window size in bytes. + SndWnd seqnum.Size + + // SndUna is the next unacknowledged sequence number. + SndUna seqnum.Value + + // SndNxt is the sequence number of the next segment to be sent. + SndNxt seqnum.Value + + // RTTMeasureSeqNum is the sequence number being used for the latest RTT + // measurement. + RTTMeasureSeqNum seqnum.Value + + // RTTMeasureTime is the time when the RTTMeasureSeqNum was sent. + RTTMeasureTime time.Time + + // Closed indicates that the caller has closed the endpoint for sending. + Closed bool + + // SRTT is the smoothed round-trip time as defined in section 2 of + // RFC 6298. + SRTT time.Duration + + // RTO is the retransmit timeout as defined in section of 2 of RFC 6298. + RTO time.Duration + + // RTTVar is the round-trip time variation as defined in section 2 of + // RFC 6298. + RTTVar time.Duration + + // SRTTInited if true indicates take a valid RTT measurement has been + // completed. + SRTTInited bool + + // MaxPayloadSize is the maximum size of the payload of a given segment. + // It is initialized on demand. + MaxPayloadSize int + + // SndWndScale is the number of bits to shift left when reading the send + // window size from a segment. + SndWndScale uint8 + + // MaxSentAck is the highest acknowledgemnt number sent till now. + MaxSentAck seqnum.Value + + // FastRecovery holds the fast recovery state for the endpoint. + FastRecovery TCPFastRecoveryState +} + +// TCPSACKInfo holds TCP SACK related information for a given TCP endpoint. +type TCPSACKInfo struct { + // Blocks is the list of SACK block currently received by the + // TCP endpoint. + Blocks []header.SACKBlock +} + +// TCPEndpointState is a copy of the internal state of a TCP endpoint. +type TCPEndpointState struct { + // ID is a copy of the TransportEndpointID for the endpoint. + ID TCPEndpointID + + // SegTime denotes the absolute time when this segment was received. + SegTime time.Time + + // RcvBufSize is the size of the receive socket buffer for the endpoint. + RcvBufSize int + + // RcvBufUsed is the amount of bytes actually held in the receive socket + // buffer for the endpoint. + RcvBufUsed int + + // RcvClosed if true, indicates the endpoint has been closed for reading. + RcvClosed bool + + // SendTSOk is used to indicate when the TS Option has been negotiated. + // When sendTSOk is true every non-RST segment should carry a TS as per + // RFC7323#section-1.1. + SendTSOk bool + + // RecentTS is the timestamp that should be sent in the TSEcr field of + // the timestamp for future segments sent by the endpoint. This field is + // updated if required when a new segment is received by this endpoint. + RecentTS uint32 + + // TSOffset is a randomized offset added to the value of the TSVal field + // in the timestamp option. + TSOffset uint32 + + // SACKPermitted is set to true if the peer sends the TCPSACKPermitted + // option in the SYN/SYN-ACK. + SACKPermitted bool + + // SACK holds TCP SACK related information for this endpoint. + SACK TCPSACKInfo + + // SndBufSize is the size of the socket send buffer. + SndBufSize int + + // SndBufUsed is the number of bytes held in the socket send buffer. + SndBufUsed int + + // SndClosed indicates that the endpoint has been closed for sends. + SndClosed bool + + // SndBufInQueue is the number of bytes in the send queue. + SndBufInQueue seqnum.Size + + // PacketTooBigCount is used to notify the main protocol routine how + // many times a "packet too big" control packet is received. + PacketTooBigCount int + + // SndMTU is the smallest MTU seen in the control packets received. + SndMTU int + + // Receiver holds variables related to the TCP receiver for the endpoint. + Receiver TCPReceiverState + + // Sender holds state related to the TCP Sender for the endpoint. + Sender TCPSenderState +} + +// Stack is a networking stack, with all supported protocols, NICs, and route +// table. +type Stack struct { + transportProtocols map[tcpip.TransportProtocolNumber]*transportProtocolState + networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol + linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + + demux *transportDemuxer + + stats tcpip.Stats + + linkAddrCache *linkAddrCache + + mu sync.RWMutex + nics map[tcpip.NICID]*NIC + + // route is the route table passed in by the user via SetRouteTable(), + // it is used by FindRoute() to build a route for a specific + // destination. + routeTable []tcpip.Route + + *ports.PortManager + + // If not nil, then any new endpoints will have this probe function + // invoked everytime they receive a TCP segment. + tcpProbeFunc TCPProbeFunc +} + +// New allocates a new networking stack with only the requested networking and +// transport protocols configured with default options. +// +// Protocol options can be changed by calling the +// SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the +// stack. Please refer to individual protocol implementations as to what options +// are supported. +func New(network []string, transport []string) *Stack { + s := &Stack{ + transportProtocols: make(map[tcpip.TransportProtocolNumber]*transportProtocolState), + networkProtocols: make(map[tcpip.NetworkProtocolNumber]NetworkProtocol), + linkAddrResolvers: make(map[tcpip.NetworkProtocolNumber]LinkAddressResolver), + nics: make(map[tcpip.NICID]*NIC), + linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts), + PortManager: ports.NewPortManager(), + } + + // Add specified network protocols. + for _, name := range network { + netProtoFactory, ok := networkProtocols[name] + if !ok { + continue + } + netProto := netProtoFactory() + s.networkProtocols[netProto.Number()] = netProto + if r, ok := netProto.(LinkAddressResolver); ok { + s.linkAddrResolvers[r.LinkAddressProtocol()] = r + } + } + + // Add specified transport protocols. + for _, name := range transport { + transProtoFactory, ok := transportProtocols[name] + if !ok { + continue + } + transProto := transProtoFactory() + s.transportProtocols[transProto.Number()] = &transportProtocolState{ + proto: transProto, + } + } + + // Create the global transport demuxer. + s.demux = newTransportDemuxer(s) + + return s +} + +// SetNetworkProtocolOption allows configuring individual protocol level +// options. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation or the provided value +// is incorrect. +func (s *Stack) SetNetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { + netProto, ok := s.networkProtocols[network] + if !ok { + return tcpip.ErrUnknownProtocol + } + return netProto.SetOption(option) +} + +// NetworkProtocolOption allows retrieving individual protocol level option +// values. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation. +// e.g. +// var v ipv4.MyOption +// err := s.NetworkProtocolOption(tcpip.IPv4ProtocolNumber, &v) +// if err != nil { +// ... +// } +func (s *Stack) NetworkProtocolOption(network tcpip.NetworkProtocolNumber, option interface{}) *tcpip.Error { + netProto, ok := s.networkProtocols[network] + if !ok { + return tcpip.ErrUnknownProtocol + } + return netProto.Option(option) +} + +// SetTransportProtocolOption allows configuring individual protocol level +// options. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation or the provided value +// is incorrect. +func (s *Stack) SetTransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { + transProtoState, ok := s.transportProtocols[transport] + if !ok { + return tcpip.ErrUnknownProtocol + } + return transProtoState.proto.SetOption(option) +} + +// TransportProtocolOption allows retrieving individual protocol level option +// values. This method returns an error if the protocol is not supported or +// option is not supported by the protocol implementation. +// var v tcp.SACKEnabled +// if err := s.TransportProtocolOption(tcpip.TCPProtocolNumber, &v); err != nil { +// ... +// } +func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, option interface{}) *tcpip.Error { + transProtoState, ok := s.transportProtocols[transport] + if !ok { + return tcpip.ErrUnknownProtocol + } + return transProtoState.proto.Option(option) +} + +// SetTransportProtocolHandler sets the per-stack default handler for the given +// protocol. +// +// It must be called only during initialization of the stack. Changing it as the +// stack is operating is not supported. +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *buffer.VectorisedView) bool) { + state := s.transportProtocols[p] + if state != nil { + state.defaultHandler = h + } +} + +// Stats returns a snapshot of the current stats. +// +// NOTE: The underlying stats are updated using atomic instructions as a result +// the snapshot returned does not represent the value of all the stats at any +// single given point of time. +// TODO: Make stats available in sentry for debugging/diag. +func (s *Stack) Stats() tcpip.Stats { + return tcpip.Stats{ + UnknownProtocolRcvdPackets: atomic.LoadUint64(&s.stats.UnknownProtocolRcvdPackets), + UnknownNetworkEndpointRcvdPackets: atomic.LoadUint64(&s.stats.UnknownNetworkEndpointRcvdPackets), + MalformedRcvdPackets: atomic.LoadUint64(&s.stats.MalformedRcvdPackets), + DroppedPackets: atomic.LoadUint64(&s.stats.DroppedPackets), + } +} + +// MutableStats returns a mutable copy of the current stats. +// +// This is not generally exported via the public interface, but is available +// internally. +func (s *Stack) MutableStats() *tcpip.Stats { + return &s.stats +} + +// SetRouteTable assigns the route table to be used by this stack. It +// specifies which NIC to use for given destination address ranges. +func (s *Stack) SetRouteTable(table []tcpip.Route) { + s.mu.Lock() + defer s.mu.Unlock() + + s.routeTable = table +} + +// NewEndpoint creates a new transport layer endpoint of the given protocol. +func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + t, ok := s.transportProtocols[transport] + if !ok { + return nil, tcpip.ErrUnknownProtocol + } + + return t.proto.NewEndpoint(s, network, waiterQueue) +} + +// createNIC creates a NIC with the provided id and link-layer endpoint, and +// optionally enable it. +func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled bool) *tcpip.Error { + ep := FindLinkEndpoint(linkEP) + if ep == nil { + return tcpip.ErrBadLinkEndpoint + } + + s.mu.Lock() + defer s.mu.Unlock() + + // Make sure id is unique. + if _, ok := s.nics[id]; ok { + return tcpip.ErrDuplicateNICID + } + + n := newNIC(s, id, name, ep) + + s.nics[id] = n + if enabled { + n.attachLinkEndpoint() + } + + return nil +} + +// CreateNIC creates a NIC with the provided id and link-layer endpoint. +func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { + return s.createNIC(id, "", linkEP, true) +} + +// CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, +// and a human-readable name. +func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { + return s.createNIC(id, name, linkEP, true) +} + +// CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, +// but leave it disable. Stack.EnableNIC must be called before the link-layer +// endpoint starts delivering packets to it. +func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { + return s.createNIC(id, "", linkEP, false) +} + +// EnableNIC enables the given NIC so that the link-layer endpoint can start +// delivering packets to it. +func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[id] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + nic.attachLinkEndpoint() + + return nil +} + +// NICSubnets returns a map of NICIDs to their associated subnets. +func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet { + s.mu.RLock() + defer s.mu.RUnlock() + + nics := map[tcpip.NICID][]tcpip.Subnet{} + + for id, nic := range s.nics { + nics[id] = append(nics[id], nic.Subnets()...) + } + return nics +} + +// NICInfo captures the name and addresses assigned to a NIC. +type NICInfo struct { + Name string + LinkAddress tcpip.LinkAddress + ProtocolAddresses []tcpip.ProtocolAddress +} + +// NICInfo returns a map of NICIDs to their associated information. +func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { + s.mu.RLock() + defer s.mu.RUnlock() + + nics := make(map[tcpip.NICID]NICInfo) + for id, nic := range s.nics { + nics[id] = NICInfo{ + Name: nic.name, + LinkAddress: nic.linkEP.LinkAddress(), + ProtocolAddresses: nic.Addresses(), + } + } + return nics +} + +// AddAddress adds a new network-layer address to the specified NIC. +func (s *Stack) AddAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[id] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + return nic.AddAddress(protocol, addr) +} + +// AddSubnet adds a subnet range to the specified NIC. +func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[id] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + nic.AddSubnet(protocol, subnet) + return nil +} + +// RemoveAddress removes an existing network-layer address from the specified +// NIC. +func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[id] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + return nic.RemoveAddress(addr) +} + +// FindRoute creates a route to the given destination address, leaving through +// the given nic and local address (if provided). +func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (Route, *tcpip.Error) { + s.mu.RLock() + defer s.mu.RUnlock() + + for i := range s.routeTable { + if (id != 0 && id != s.routeTable[i].NIC) || (len(remoteAddr) != 0 && !s.routeTable[i].Match(remoteAddr)) { + continue + } + + nic := s.nics[s.routeTable[i].NIC] + if nic == nil { + continue + } + + var ref *referencedNetworkEndpoint + if len(localAddr) != 0 { + ref = nic.findEndpoint(netProto, localAddr) + } else { + ref = nic.primaryEndpoint(netProto) + } + if ref == nil { + continue + } + + if len(remoteAddr) == 0 { + // If no remote address was provided, then the route + // provided will refer to the link local address. + remoteAddr = ref.ep.ID().LocalAddress + } + + r := makeRoute(netProto, ref.ep.ID().LocalAddress, remoteAddr, ref) + r.NextHop = s.routeTable[i].Gateway + return r, nil + } + + return Route{}, tcpip.ErrNoRoute +} + +// CheckNetworkProtocol checks if a given network protocol is enabled in the +// stack. +func (s *Stack) CheckNetworkProtocol(protocol tcpip.NetworkProtocolNumber) bool { + _, ok := s.networkProtocols[protocol] + return ok +} + +// CheckLocalAddress determines if the given local address exists, and if it +// does, returns the id of the NIC it's bound to. Returns 0 if the address +// does not exist. +func (s *Stack) CheckLocalAddress(nicid tcpip.NICID, protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) tcpip.NICID { + s.mu.RLock() + defer s.mu.RUnlock() + + // If a NIC is specified, we try to find the address there only. + if nicid != 0 { + nic := s.nics[nicid] + if nic == nil { + return 0 + } + + ref := nic.findEndpoint(protocol, addr) + if ref == nil { + return 0 + } + + ref.decRef() + + return nic.id + } + + // Go through all the NICs. + for _, nic := range s.nics { + ref := nic.findEndpoint(protocol, addr) + if ref != nil { + ref.decRef() + return nic.id + } + } + + return 0 +} + +// SetPromiscuousMode enables or disables promiscuous mode in the given NIC. +func (s *Stack) SetPromiscuousMode(nicID tcpip.NICID, enable bool) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + nic.setPromiscuousMode(enable) + + return nil +} + +// SetSpoofing enables or disables address spoofing in the given NIC, allowing +// endpoints to bind to any address in the NIC. +func (s *Stack) SetSpoofing(nicID tcpip.NICID, enable bool) *tcpip.Error { + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + nic.setSpoofing(enable) + + return nil +} + +// AddLinkAddress adds a link address to the stack link cache. +func (s *Stack) AddLinkAddress(nicid tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress) { + fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + s.linkAddrCache.add(fullAddr, linkAddr) + // TODO: provide a way for a + // transport endpoint to receive a signal that AddLinkAddress + // for a particular address has been called. +} + +// GetLinkAddress implements LinkAddressCache.GetLinkAddress. +func (s *Stack) GetLinkAddress(nicid tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, waker *sleep.Waker) (tcpip.LinkAddress, *tcpip.Error) { + s.mu.RLock() + nic := s.nics[nicid] + if nic == nil { + s.mu.RUnlock() + return "", tcpip.ErrUnknownNICID + } + s.mu.RUnlock() + + fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + linkRes := s.linkAddrResolvers[protocol] + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.linkEP, waker) +} + +// RemoveWaker implements LinkAddressCache.RemoveWaker. +func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) { + s.mu.RLock() + defer s.mu.RUnlock() + + if nic := s.nics[nicid]; nic == nil { + fullAddr := tcpip.FullAddress{NIC: nicid, Addr: addr} + s.linkAddrCache.removeWaker(fullAddr, waker) + } +} + +// RegisterTransportEndpoint registers the given endpoint with the stack +// transport dispatcher. Received packets that match the provided id will be +// delivered to the given endpoint; specifying a nic is optional, but +// nic-specific IDs have precedence over global ones. +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { + if nicID == 0 { + return s.demux.registerEndpoint(netProtos, protocol, id, ep) + } + + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic == nil { + return tcpip.ErrUnknownNICID + } + + return nic.demux.registerEndpoint(netProtos, protocol, id, ep) +} + +// UnregisterTransportEndpoint removes the endpoint with the given id from the +// stack transport dispatcher. +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { + if nicID == 0 { + s.demux.unregisterEndpoint(netProtos, protocol, id) + return + } + + s.mu.RLock() + defer s.mu.RUnlock() + + nic := s.nics[nicID] + if nic != nil { + nic.demux.unregisterEndpoint(netProtos, protocol, id) + } +} + +// NetworkProtocolInstance returns the protocol instance in the stack for the +// specified network protocol. This method is public for protocol implementers +// and tests to use. +func (s *Stack) NetworkProtocolInstance(num tcpip.NetworkProtocolNumber) NetworkProtocol { + if p, ok := s.networkProtocols[num]; ok { + return p + } + return nil +} + +// TransportProtocolInstance returns the protocol instance in the stack for the +// specified transport protocol. This method is public for protocol implementers +// and tests to use. +func (s *Stack) TransportProtocolInstance(num tcpip.TransportProtocolNumber) TransportProtocol { + if pState, ok := s.transportProtocols[num]; ok { + return pState.proto + } + return nil +} + +// AddTCPProbe installs a probe function that will be invoked on every segment +// received by a given TCP endpoint. The probe function is passed a copy of the +// TCP endpoint state. +// +// NOTE: TCPProbe is added only to endpoints created after this call. Endpoints +// created prior to this call will not call the probe function. +// +// Further, installing two different probes back to back can result in some +// endpoints calling the first one and some the second one. There is no +// guarantee provided on which probe will be invoked. Ideally this should only +// be called once per stack. +func (s *Stack) AddTCPProbe(probe TCPProbeFunc) { + s.mu.Lock() + s.tcpProbeFunc = probe + s.mu.Unlock() +} + +// GetTCPProbe returns the TCPProbeFunc if installed with AddTCPProbe, nil +// otherwise. +func (s *Stack) GetTCPProbe() TCPProbeFunc { + s.mu.Lock() + p := s.tcpProbeFunc + s.mu.Unlock() + return p +} + +// RemoveTCPProbe removes an installed TCP probe. +// +// NOTE: This only ensures that endpoints created after this call do not +// have a probe attached. Endpoints already created will continue to invoke +// TCP probe. +func (s *Stack) RemoveTCPProbe() { + s.mu.Lock() + s.tcpProbeFunc = nil + s.mu.Unlock() +} diff --git a/pkg/tcpip/stack/stack_global_state.go b/pkg/tcpip/stack/stack_global_state.go new file mode 100644 index 000000000..030ae98d1 --- /dev/null +++ b/pkg/tcpip/stack/stack_global_state.go @@ -0,0 +1,9 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stack + +// StackFromEnv is the global stack created in restore run. +// FIXME: remove this variable once tcpip S/R is fully supported. +var StackFromEnv *Stack diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go new file mode 100644 index 000000000..b416065d7 --- /dev/null +++ b/pkg/tcpip/stack/stack_test.go @@ -0,0 +1,760 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package stack_test contains tests for the stack. It is in its own package so +// that the tests can also validate that all definitions needed to implement +// transport and network protocols are properly exported by the stack package. +package stack_test + +import ( + "math" + "testing" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" +) + +const ( + fakeNetNumber tcpip.NetworkProtocolNumber = math.MaxUint32 + fakeNetHeaderLen = 12 + + // fakeControlProtocol is used for control packets that represent + // destination port unreachable. + fakeControlProtocol tcpip.TransportProtocolNumber = 2 + + // defaultMTU is the MTU, in bytes, used throughout the tests, except + // where another value is explicitly used. It is chosen to match the MTU + // of loopback interfaces on linux systems. + defaultMTU = 65536 +) + +// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and +// received packets; the counts of all endpoints are aggregated in the protocol +// descriptor. +// +// Headers of this protocol are fakeNetHeaderLen bytes, but we currently only +// use the first three: destination address, source address, and transport +// protocol. They're all one byte fields to simplify parsing. +type fakeNetworkEndpoint struct { + nicid tcpip.NICID + id stack.NetworkEndpointID + proto *fakeNetworkProtocol + dispatcher stack.TransportDispatcher + linkEP stack.LinkEndpoint +} + +func (f *fakeNetworkEndpoint) MTU() uint32 { + return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) +} + +func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { + return f.nicid +} + +func (f *fakeNetworkEndpoint) ID() *stack.NetworkEndpointID { + return &f.id +} + +func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv *buffer.VectorisedView) { + // Increment the received packet count in the protocol descriptor. + f.proto.packetCount[int(f.id.LocalAddress[0])%len(f.proto.packetCount)]++ + + // Consume the network header. + b := vv.First() + vv.TrimFront(fakeNetHeaderLen) + + // Handle control packets. + if b[2] == uint8(fakeControlProtocol) { + nb := vv.First() + if len(nb) < fakeNetHeaderLen { + return + } + + vv.TrimFront(fakeNetHeaderLen) + f.dispatcher.DeliverTransportControlPacket(tcpip.Address(nb[1:2]), tcpip.Address(nb[0:1]), fakeNetNumber, tcpip.TransportProtocolNumber(nb[2]), stack.ControlPortUnreachable, 0, vv) + return + } + + // Dispatch the packet to the transport protocol. + f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(b[2]), vv) +} + +func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { + return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen +} + +func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { + return 0 +} + +func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { + return f.linkEP.Capabilities() +} + +func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, hdr *buffer.Prependable, payload buffer.View, protocol tcpip.TransportProtocolNumber) *tcpip.Error { + // Increment the sent packet count in the protocol descriptor. + f.proto.sendPacketCount[int(r.RemoteAddress[0])%len(f.proto.sendPacketCount)]++ + + // Add the protocol's header to the packet and send it to the link + // endpoint. + b := hdr.Prepend(fakeNetHeaderLen) + b[0] = r.RemoteAddress[0] + b[1] = f.id.LocalAddress[0] + b[2] = byte(protocol) + return f.linkEP.WritePacket(r, hdr, payload, fakeNetNumber) +} + +func (*fakeNetworkEndpoint) Close() {} + +type fakeNetGoodOption bool + +type fakeNetBadOption bool + +type fakeNetInvalidValueOption int + +type fakeNetOptions struct { + good bool +} + +// fakeNetworkProtocol is a network-layer protocol descriptor. It aggregates the +// number of packets sent and received via endpoints of this protocol. The index +// where packets are added is given by the packet's destination address MOD 10. +type fakeNetworkProtocol struct { + packetCount [10]int + sendPacketCount [10]int + opts fakeNetOptions +} + +func (f *fakeNetworkProtocol) Number() tcpip.NetworkProtocolNumber { + return fakeNetNumber +} + +func (f *fakeNetworkProtocol) MinimumPacketSize() int { + return fakeNetHeaderLen +} + +func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { + return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) +} + +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addr tcpip.Address, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { + return &fakeNetworkEndpoint{ + nicid: nicid, + id: stack.NetworkEndpointID{addr}, + proto: f, + dispatcher: dispatcher, + linkEP: linkEP, + }, nil +} + +func (f *fakeNetworkProtocol) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case fakeNetGoodOption: + f.opts.good = bool(v) + return nil + case fakeNetInvalidValueOption: + return tcpip.ErrInvalidOptionValue + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *fakeNetGoodOption: + *v = fakeNetGoodOption(f.opts.good) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func TestNetworkReceive(t *testing.T) { + // Create a stack with the fake network protocol, one nic, and two + // addresses attached to it: 1 & 2. + id, linkEP := channel.New(10, defaultMTU, "") + s := stack.New([]string{"fakeNet"}, nil) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + var views [1]buffer.View + // Allocate the buffer containing the packet that will be injected into + // the stack. + buf := buffer.NewView(30) + + // Make sure packet with wrong address is not delivered. + buf[0] = 3 + vv := buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 0 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) + } + if fakeNet.packetCount[2] != 0 { + t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) + } + + // Make sure packet is delivered to first endpoint. + buf[0] = 1 + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + if fakeNet.packetCount[2] != 0 { + t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 0) + } + + // Make sure packet is delivered to second endpoint. + buf[0] = 2 + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + if fakeNet.packetCount[2] != 1 { + t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) + } + + // Make sure packet is not delivered if protocol number is wrong. + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber-1, &vv) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + if fakeNet.packetCount[2] != 1 { + t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) + } + + // Make sure packet that is too small is dropped. + buf.CapLength(2) + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + if fakeNet.packetCount[2] != 1 { + t.Errorf("packetCount[2] = %d, want %d", fakeNet.packetCount[2], 1) + } +} + +func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address) { + r, err := s.FindRoute(0, "", addr, fakeNetNumber) + if err != nil { + t.Fatalf("FindRoute failed: %v", err) + } + defer r.Release() + + hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) + err = r.WritePacket(&hdr, nil, fakeTransNumber) + if err != nil { + t.Errorf("WritePacket failed: %v", err) + return + } +} + +func TestNetworkSend(t *testing.T) { + // Create a stack with the fake network protocol, one nic, and one + // address: 1. The route table sends all packets through the only + // existing nic. + id, linkEP := channel.New(10, defaultMTU, "") + s := stack.New([]string{"fakeNet"}, nil) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("NewNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + // Make sure that the link-layer endpoint received the outbound packet. + sendTo(t, s, "\x03") + if c := linkEP.Drain(); c != 1 { + t.Errorf("packetCount = %d, want %d", c, 1) + } +} + +func TestNetworkSendMultiRoute(t *testing.T) { + // Create a stack with the fake network protocol, two nics, and two + // addresses per nic, the first nic has odd address, the second one has + // even addresses. + s := stack.New([]string{"fakeNet"}, nil) + + id1, linkEP1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id1); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + id2, linkEP2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, id2); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + s.SetRouteTable([]tcpip.Route{ + {"\x01", "\x01", "\x00", 1}, + {"\x00", "\x01", "\x00", 2}, + }) + + // Send a packet to an odd destination. + sendTo(t, s, "\x05") + + if c := linkEP1.Drain(); c != 1 { + t.Errorf("packetCount = %d, want %d", c, 1) + } + + // Send a packet to an even destination. + sendTo(t, s, "\x06") + + if c := linkEP2.Drain(); c != 1 { + t.Errorf("packetCount = %d, want %d", c, 1) + } +} + +func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { + r, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber) + if err != nil { + t.Fatalf("FindRoute failed: %v", err) + } + + defer r.Release() + + if r.LocalAddress != expectedSrcAddr { + t.Fatalf("Bad source address: expected %v, got %v", expectedSrcAddr, r.LocalAddress) + } + + if r.RemoteAddress != dstAddr { + t.Fatalf("Bad destination address: expected %v, got %v", dstAddr, r.RemoteAddress) + } +} + +func testNoRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr tcpip.Address) { + _, err := s.FindRoute(nic, srcAddr, dstAddr, fakeNetNumber) + if err != tcpip.ErrNoRoute { + t.Fatalf("FindRoute returned unexpected error, expected tcpip.ErrNoRoute, got %v", err) + } +} + +func TestRoutes(t *testing.T) { + // Create a stack with the fake network protocol, two nics, and two + // addresses per nic, the first nic has odd address, the second one has + // even addresses. + s := stack.New([]string{"fakeNet"}, nil) + + id1, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id1); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x03"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + id2, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, id2); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + if err := s.AddAddress(2, fakeNetNumber, "\x04"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + // Set a route table that sends all packets with odd destination + // addresses through the first NIC, and all even destination address + // through the second one. + s.SetRouteTable([]tcpip.Route{ + {"\x01", "\x01", "\x00", 1}, + {"\x00", "\x01", "\x00", 2}, + }) + + // Test routes to odd address. + testRoute(t, s, 0, "", "\x05", "\x01") + testRoute(t, s, 0, "\x01", "\x05", "\x01") + testRoute(t, s, 1, "\x01", "\x05", "\x01") + testRoute(t, s, 0, "\x03", "\x05", "\x03") + testRoute(t, s, 1, "\x03", "\x05", "\x03") + + // Test routes to even address. + testRoute(t, s, 0, "", "\x06", "\x02") + testRoute(t, s, 0, "\x02", "\x06", "\x02") + testRoute(t, s, 2, "\x02", "\x06", "\x02") + testRoute(t, s, 0, "\x04", "\x06", "\x04") + testRoute(t, s, 2, "\x04", "\x06", "\x04") + + // Try to send to odd numbered address from even numbered ones, then + // vice-versa. + testNoRoute(t, s, 0, "\x02", "\x05") + testNoRoute(t, s, 2, "\x02", "\x05") + testNoRoute(t, s, 0, "\x04", "\x05") + testNoRoute(t, s, 2, "\x04", "\x05") + + testNoRoute(t, s, 0, "\x01", "\x06") + testNoRoute(t, s, 1, "\x01", "\x06") + testNoRoute(t, s, 0, "\x03", "\x06") + testNoRoute(t, s, 1, "\x03", "\x06") +} + +func TestAddressRemoval(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + var views [1]buffer.View + buf := buffer.NewView(30) + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + // Write a packet, and check that it gets delivered. + fakeNet.packetCount[1] = 0 + buf[0] = 1 + vv := buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + + // Remove the address, then check that packet doesn't get delivered + // anymore. + if err := s.RemoveAddress(1, "\x01"); err != nil { + t.Fatalf("RemoveAddress failed: %v", err) + } + + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + + // Check that removing the same address fails. + if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + t.Fatalf("RemoveAddress failed: %v", err) + } +} + +func TestDelayedRemovalDueToRoute(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", 1}, + }) + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + var views [1]buffer.View + buf := buffer.NewView(30) + + // Write a packet, and check that it gets delivered. + fakeNet.packetCount[1] = 0 + buf[0] = 1 + vv := buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + + // Get a route, check that packet is still deliverable. + r, err := s.FindRoute(0, "", "\x02", fakeNetNumber) + if err != nil { + t.Fatalf("FindRoute failed: %v", err) + } + + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 2 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2) + } + + // Remove the address, then check that packet is still deliverable + // because the route is keeping the address alive. + if err := s.RemoveAddress(1, "\x01"); err != nil { + t.Fatalf("RemoveAddress failed: %v", err) + } + + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 3 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) + } + + // Check that removing the same address fails. + if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + t.Fatalf("RemoveAddress failed: %v", err) + } + + // Release the route, then check that packet is not deliverable anymore. + r.Release() + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 3 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) + } +} + +func TestPromiscuousMode(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", 1}, + }) + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + var views [1]buffer.View + buf := buffer.NewView(30) + + // Write a packet, and check that it doesn't get delivered as we don't + // have a matching endpoint. + fakeNet.packetCount[1] = 0 + buf[0] = 1 + vv := buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 0 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) + } + + // Set promiscuous mode, then check that packet is delivered. + if err := s.SetPromiscuousMode(1, true); err != nil { + t.Fatalf("SetPromiscuousMode failed: %v", err) + } + + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } + + // Check that we can't get a route as there is no local address. + _, err := s.FindRoute(0, "", "\x02", fakeNetNumber) + if err != tcpip.ErrNoRoute { + t.Fatalf("FindRoute returned unexpected status: expected %v, got %v", tcpip.ErrNoRoute, err) + } + + // Set promiscuous mode to false, then check that packet can't be + // delivered anymore. + if err := s.SetPromiscuousMode(1, false); err != nil { + t.Fatalf("SetPromiscuousMode failed: %v", err) + } + + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } +} + +func TestAddressSpoofing(t *testing.T) { + srcAddr := tcpip.Address("\x01") + dstAddr := tcpip.Address("\x02") + + s := stack.New([]string{"fakeNet"}, nil) + + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", 1}, + }) + + // With address spoofing disabled, FindRoute does not permit an address + // that was not added to the NIC to be used as the source. + r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber) + if err == nil { + t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) + } + + // With address spoofing enabled, FindRoute permits any address to be used + // as the source. + if err := s.SetSpoofing(1, true); err != nil { + t.Fatalf("SetSpoofing failed: %v", err) + } + r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber) + if err != nil { + t.Fatalf("FindRoute failed: %v", err) + } + if r.LocalAddress != srcAddr { + t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + } +} + +// Set the subnet, then check that packet is delivered. +func TestSubnetAcceptsMatchingPacket(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", 1}, + }) + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + var views [1]buffer.View + buf := buffer.NewView(30) + buf[0] = 1 + fakeNet.packetCount[1] = 0 + subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) + if err != nil { + t.Fatalf("NewSubnet failed: %v", err) + } + if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { + t.Fatalf("AddSubnet failed: %v", err) + } + + vv := buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 1 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + } +} + +// Set destination outside the subnet, then check it doesn't get delivered. +func TestSubnetRejectsNonmatchingPacket(t *testing.T) { + s := stack.New([]string{"fakeNet"}, nil) + + id, linkEP := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", 1}, + }) + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + var views [1]buffer.View + buf := buffer.NewView(30) + buf[0] = 1 + fakeNet.packetCount[1] = 0 + subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) + if err != nil { + t.Fatalf("NewSubnet failed: %v", err) + } + if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { + t.Fatalf("AddSubnet failed: %v", err) + } + vv := buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeNet.packetCount[1] != 0 { + t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) + } +} + +func TestNetworkOptions(t *testing.T) { + s := stack.New([]string{"fakeNet"}, []string{}) + + // Try an unsupported network protocol. + if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { + t.Fatalf("SetNetworkProtocolOption(fakeNet2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) + } + + testCases := []struct { + option interface{} + wantErr *tcpip.Error + verifier func(t *testing.T, p stack.NetworkProtocol) + }{ + {fakeNetGoodOption(true), nil, func(t *testing.T, p stack.NetworkProtocol) { + t.Helper() + fakeNet := p.(*fakeNetworkProtocol) + if fakeNet.opts.good != true { + t.Fatalf("fakeNet.opts.good = false, want = true") + } + var v fakeNetGoodOption + if err := s.NetworkProtocolOption(fakeNetNumber, &v); err != nil { + t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) = %v, want = nil, where v is option %T", v, err) + } + if v != true { + t.Fatalf("s.NetworkProtocolOption(fakeNetNumber, &v) returned v = %v, want = true", v) + } + }}, + {fakeNetBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, + {fakeNetInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, + } + for _, tc := range testCases { + if got := s.SetNetworkProtocolOption(fakeNetNumber, tc.option); got != tc.wantErr { + t.Errorf("s.SetNetworkProtocolOption(fakeNet, %v) = %v, want = %v", tc.option, got, tc.wantErr) + } + if tc.verifier != nil { + tc.verifier(t, s.NetworkProtocolInstance(fakeNetNumber)) + } + } +} + +func init() { + stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { + return &fakeNetworkProtocol{} + }) +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go new file mode 100644 index 000000000..3c0d7aa31 --- /dev/null +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -0,0 +1,166 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stack + +import ( + "sync" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" +) + +type protocolIDs struct { + network tcpip.NetworkProtocolNumber + transport tcpip.TransportProtocolNumber +} + +// transportEndpoints manages all endpoints of a given protocol. It has its own +// mutex so as to reduce interference between protocols. +type transportEndpoints struct { + mu sync.RWMutex + endpoints map[TransportEndpointID]TransportEndpoint +} + +// transportDemuxer demultiplexes packets targeted at a transport endpoint +// (i.e., after they've been parsed by the network layer). It does two levels +// of demultiplexing: first based on the network and transport protocols, then +// based on endpoints IDs. +type transportDemuxer struct { + protocol map[protocolIDs]*transportEndpoints +} + +func newTransportDemuxer(stack *Stack) *transportDemuxer { + d := &transportDemuxer{protocol: make(map[protocolIDs]*transportEndpoints)} + + // Add each network and transport pair to the demuxer. + for netProto := range stack.networkProtocols { + for proto := range stack.transportProtocols { + d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{endpoints: make(map[TransportEndpointID]TransportEndpoint)} + } + } + + return d +} + +// registerEndpoint registers the given endpoint with the dispatcher such that +// packets that match the endpoint ID are delivered to it. +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { + for i, n := range netProtos { + if err := d.singleRegisterEndpoint(n, protocol, id, ep); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id) + return err + } + } + + return nil +} + +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) *tcpip.Error { + eps, ok := d.protocol[protocolIDs{netProto, protocol}] + if !ok { + return nil + } + + eps.mu.Lock() + defer eps.mu.Unlock() + + if _, ok := eps.endpoints[id]; ok { + return tcpip.ErrPortInUse + } + + eps.endpoints[id] = ep + + return nil +} + +// unregisterEndpoint unregisters the endpoint with the given id such that it +// won't receive any more packets. +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID) { + for _, n := range netProtos { + if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { + eps.mu.Lock() + delete(eps.endpoints, id) + eps.mu.Unlock() + } + } +} + +// deliverPacket attempts to deliver the given packet. Returns true if it found +// an endpoint, false otherwise. +func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, vv *buffer.VectorisedView, id TransportEndpointID) bool { + eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] + if !ok { + return false + } + + eps.mu.RLock() + ep := d.findEndpointLocked(eps, vv, id) + eps.mu.RUnlock() + + // Fail if we didn't find one. + if ep == nil { + return false + } + + // Deliver the packet. + ep.HandlePacket(r, id, vv) + + return true +} + +// deliverControlPacket attempts to deliver the given control packet. Returns +// true if it found an endpoint, false otherwise. +func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv *buffer.VectorisedView, id TransportEndpointID) bool { + eps, ok := d.protocol[protocolIDs{net, trans}] + if !ok { + return false + } + + // Try to find the endpoint. + eps.mu.RLock() + ep := d.findEndpointLocked(eps, vv, id) + eps.mu.RUnlock() + + // Fail if we didn't find one. + if ep == nil { + return false + } + + // Deliver the packet. + ep.HandleControlPacket(id, typ, extra, vv) + + return true +} + +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv *buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { + // Try to find a match with the id as provided. + if ep := eps.endpoints[id]; ep != nil { + return ep + } + + // Try to find a match with the id minus the local address. + nid := id + + nid.LocalAddress = "" + if ep := eps.endpoints[nid]; ep != nil { + return ep + } + + // Try to find a match with the id minus the remote part. + nid.LocalAddress = id.LocalAddress + nid.RemoteAddress = "" + nid.RemotePort = 0 + if ep := eps.endpoints[nid]; ep != nil { + return ep + } + + // Try to find a match with only the local port. + nid.LocalAddress = "" + if ep := eps.endpoints[nid]; ep != nil { + return ep + } + + return nil +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go new file mode 100644 index 000000000..7e072e96e --- /dev/null +++ b/pkg/tcpip/stack/transport_test.go @@ -0,0 +1,420 @@ +// Copyright 2016 The Netstack Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package stack_test + +import ( + "testing" + + "gvisor.googlesource.com/gvisor/pkg/tcpip" + "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" + "gvisor.googlesource.com/gvisor/pkg/tcpip/link/channel" + "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/waiter" +) + +const ( + fakeTransNumber tcpip.TransportProtocolNumber = 1 + fakeTransHeaderLen = 3 +) + +// fakeTransportEndpoint is a transport-layer protocol endpoint. It counts +// received packets; the counts of all endpoints are aggregated in the protocol +// descriptor. +// +// Headers of this protocol are fakeTransHeaderLen bytes, but we currently don't +// use it. +type fakeTransportEndpoint struct { + id stack.TransportEndpointID + stack *stack.Stack + netProto tcpip.NetworkProtocolNumber + proto *fakeTransportProtocol + peerAddr tcpip.Address + route stack.Route +} + +func newFakeTransportEndpoint(stack *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber) tcpip.Endpoint { + return &fakeTransportEndpoint{stack: stack, netProto: netProto, proto: proto} +} + +func (f *fakeTransportEndpoint) Close() { + f.route.Release() +} + +func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask { + return mask +} + +func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, *tcpip.Error) { + return buffer.View{}, nil +} + +func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, *tcpip.Error) { + if len(f.route.RemoteAddress) == 0 { + return 0, tcpip.ErrNoRoute + } + + hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) + v, err := p.Get(p.Size()) + if err != nil { + return 0, err + } + if err := f.route.WritePacket(&hdr, v, fakeTransNumber); err != nil { + return 0, err + } + + return uintptr(len(v)), nil +} + +func (f *fakeTransportEndpoint) Peek([][]byte) (uintptr, *tcpip.Error) { + return 0, nil +} + +// SetSockOpt sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch opt.(type) { + case tcpip.ErrorOption: + return nil + } + return tcpip.ErrInvalidEndpointState +} + +func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { + f.peerAddr = addr.Addr + + // Find the route. + r, err := f.stack.FindRoute(addr.NIC, "", addr.Addr, fakeNetNumber) + if err != nil { + return tcpip.ErrNoRoute + } + defer r.Release() + + // Try to register so that we can start receiving packets. + f.id.RemoteAddress = addr.Addr + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f) + if err != nil { + return err + } + + f.route = r.Clone() + + return nil +} + +func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { + return nil +} + +func (*fakeTransportEndpoint) Shutdown(tcpip.ShutdownFlags) *tcpip.Error { + return nil +} + +func (*fakeTransportEndpoint) Reset() { +} + +func (*fakeTransportEndpoint) Listen(int) *tcpip.Error { + return nil +} + +func (*fakeTransportEndpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { + return nil, nil, nil +} + +func (*fakeTransportEndpoint) Bind(_ tcpip.FullAddress, commit func() *tcpip.Error) *tcpip.Error { + return commit() +} + +func (*fakeTransportEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, nil +} + +func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { + return tcpip.FullAddress{}, nil +} + +func (f *fakeTransportEndpoint) HandlePacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) { + // Increment the number of received packets. + f.proto.packetCount++ +} + +func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *buffer.VectorisedView) { + // Increment the number of received control packets. + f.proto.controlCount++ +} + +type fakeTransportGoodOption bool + +type fakeTransportBadOption bool + +type fakeTransportInvalidValueOption int + +type fakeTransportProtocolOptions struct { + good bool +} + +// fakeTransportProtocol is a transport-layer protocol descriptor. It +// aggregates the number of packets received via endpoints of this protocol. +type fakeTransportProtocol struct { + packetCount int + controlCount int + opts fakeTransportProtocolOptions +} + +func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber { + return fakeTransNumber +} + +func (f *fakeTransportProtocol) NewEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { + return newFakeTransportEndpoint(stack, f, netProto), nil +} + +func (*fakeTransportProtocol) MinimumPacketSize() int { + return fakeTransHeaderLen +} + +func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcpip.Error) { + return 0, 0, nil +} + +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *buffer.VectorisedView) bool { + return true +} + +func (f *fakeTransportProtocol) SetOption(option interface{}) *tcpip.Error { + switch v := option.(type) { + case fakeTransportGoodOption: + f.opts.good = bool(v) + return nil + case fakeTransportInvalidValueOption: + return tcpip.ErrInvalidOptionValue + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { + switch v := option.(type) { + case *fakeTransportGoodOption: + *v = fakeTransportGoodOption(f.opts.good) + return nil + default: + return tcpip.ErrUnknownProtocolOption + } +} + +func TestTransportReceive(t *testing.T) { + id, linkEP := channel.New(10, defaultMTU, "") + s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + // Create endpoint and connect to remote address. + wq := waiter.Queue{} + ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { + t.Fatalf("Connect failed: %v", err) + } + + fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) + + var views [1]buffer.View + // Create buffer that will hold the packet. + buf := buffer.NewView(30) + + // Make sure packet with wrong protocol is not delivered. + buf[0] = 1 + buf[2] = 0 + vv := buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeTrans.packetCount != 0 { + t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) + } + + // Make sure packet from the wrong source is not delivered. + buf[0] = 1 + buf[1] = 3 + buf[2] = byte(fakeTransNumber) + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeTrans.packetCount != 0 { + t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 0) + } + + // Make sure packet is delivered. + buf[0] = 1 + buf[1] = 2 + buf[2] = byte(fakeTransNumber) + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeTrans.packetCount != 1 { + t.Errorf("packetCount = %d, want %d", fakeTrans.packetCount, 1) + } +} + +func TestTransportControlReceive(t *testing.T) { + id, linkEP := channel.New(10, defaultMTU, "") + s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + // Create endpoint and connect to remote address. + wq := waiter.Queue{} + ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { + t.Fatalf("Connect failed: %v", err) + } + + fakeTrans := s.TransportProtocolInstance(fakeTransNumber).(*fakeTransportProtocol) + + var views [1]buffer.View + // Create buffer that will hold the control packet. + buf := buffer.NewView(2*fakeNetHeaderLen + 30) + + // Outer packet contains the control protocol number. + buf[0] = 1 + buf[1] = 0xfe + buf[2] = uint8(fakeControlProtocol) + + // Make sure packet with wrong protocol is not delivered. + buf[fakeNetHeaderLen+0] = 0 + buf[fakeNetHeaderLen+1] = 1 + buf[fakeNetHeaderLen+2] = 0 + vv := buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeTrans.controlCount != 0 { + t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) + } + + // Make sure packet from the wrong source is not delivered. + buf[fakeNetHeaderLen+0] = 3 + buf[fakeNetHeaderLen+1] = 1 + buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeTrans.controlCount != 0 { + t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 0) + } + + // Make sure packet is delivered. + buf[fakeNetHeaderLen+0] = 2 + buf[fakeNetHeaderLen+1] = 1 + buf[fakeNetHeaderLen+2] = byte(fakeTransNumber) + vv = buf.ToVectorisedView(views) + linkEP.Inject(fakeNetNumber, &vv) + if fakeTrans.controlCount != 1 { + t.Errorf("controlCount = %d, want %d", fakeTrans.controlCount, 1) + } +} + +func TestTransportSend(t *testing.T) { + id, _ := channel.New(10, defaultMTU, "") + s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}) + if err := s.CreateNIC(1, id); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + + if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + t.Fatalf("AddAddress failed: %v", err) + } + + s.SetRouteTable([]tcpip.Route{{"\x00", "\x00", "\x00", 1}}) + + // Create endpoint and bind it. + wq := waiter.Queue{} + ep, err := s.NewEndpoint(fakeTransNumber, fakeNetNumber, &wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Connect(tcpip.FullAddress{0, "\x02", 0}); err != nil { + t.Fatalf("Connect failed: %v", err) + } + + // Create buffer that will hold the payload. + view := buffer.NewView(30) + _, err = ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}) + if err != nil { + t.Fatalf("write failed: %v", err) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + + if fakeNet.sendPacketCount[2] != 1 { + t.Errorf("sendPacketCount = %d, want %d", fakeNet.sendPacketCount[2], 1) + } +} + +func TestTransportOptions(t *testing.T) { + s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}) + + // Try an unsupported transport protocol. + if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { + t.Fatalf("SetTransportProtocolOption(fakeTrans2, blah, false) = %v, want = tcpip.ErrUnknownProtocol", err) + } + + testCases := []struct { + option interface{} + wantErr *tcpip.Error + verifier func(t *testing.T, p stack.TransportProtocol) + }{ + {fakeTransportGoodOption(true), nil, func(t *testing.T, p stack.TransportProtocol) { + t.Helper() + fakeTrans := p.(*fakeTransportProtocol) + if fakeTrans.opts.good != true { + t.Fatalf("fakeTrans.opts.good = false, want = true") + } + var v fakeTransportGoodOption + if err := s.TransportProtocolOption(fakeTransNumber, &v); err != nil { + t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) = %v, want = nil, where v is option %T", v, err) + } + if v != true { + t.Fatalf("s.TransportProtocolOption(fakeTransNumber, &v) returned v = %v, want = true", v) + } + + }}, + {fakeTransportBadOption(true), tcpip.ErrUnknownProtocolOption, nil}, + {fakeTransportInvalidValueOption(1), tcpip.ErrInvalidOptionValue, nil}, + } + for _, tc := range testCases { + if got := s.SetTransportProtocolOption(fakeTransNumber, tc.option); got != tc.wantErr { + t.Errorf("s.SetTransportProtocolOption(fakeTrans, %v) = %v, want = %v", tc.option, got, tc.wantErr) + } + if tc.verifier != nil { + tc.verifier(t, s.TransportProtocolInstance(fakeTransNumber)) + } + } +} + +func init() { + stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol { + return &fakeTransportProtocol{} + }) +} |