diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 33 | ||||
-rw-r--r-- | pkg/tcpip/stack/icmp_rate_limit.go | 41 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 253 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 79 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 441 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 75 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 15 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 261 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 921 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 227 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer_test.go | 352 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 75 |
12 files changed, 1894 insertions, 879 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 9986b4be3..baf88bfab 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,11 +1,28 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_library") + +go_template_instance( + name = "linkaddrentry_list", + out = "linkaddrentry_list.go", + package = "stack", + prefix = "linkAddrEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*linkAddrEntry", + "Linker": "*linkAddrEntry", + }, +) go_library( name = "stack", srcs = [ + "icmp_rate_limit.go", "linkaddrcache.go", + "linkaddrentry_list.go", "nic.go", "registration.go", "route.go", @@ -19,6 +36,7 @@ go_library( ], deps = [ "//pkg/ilist", + "//pkg/rand", "//pkg/sleep", "//pkg/tcpip", "//pkg/tcpip/buffer", @@ -28,6 +46,7 @@ go_library( "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", + "@org_golang_x_time//rate:go_default_library", ], ) @@ -36,6 +55,7 @@ go_test( size = "small", srcs = [ "stack_test.go", + "transport_demuxer_test.go", "transport_test.go", ], deps = [ @@ -46,6 +66,9 @@ go_test( "//pkg/tcpip/iptables", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/loopback", + "//pkg/tcpip/network/ipv4", + "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/transport/udp", "//pkg/waiter", ], ) @@ -60,3 +83,11 @@ go_test( "//pkg/tcpip", ], ) + +filegroup( + name = "autogen", + srcs = [ + "linkaddrentry_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go new file mode 100644 index 000000000..3a20839da --- /dev/null +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -0,0 +1,41 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "golang.org/x/time/rate" +) + +const ( + // icmpLimit is the default maximum number of ICMP messages permitted by this + // rate limiter. + icmpLimit = 1000 + + // icmpBurst is the default number of ICMP messages that can be sent in a single + // burst. + icmpBurst = 50 +) + +// ICMPRateLimiter is a global rate limiter that controls the generation of +// ICMP messages generated by the stack. +type ICMPRateLimiter struct { + *rate.Limiter +} + +// NewICMPRateLimiter returns a global rate limiter for controlling the rate +// at which ICMP messages are generated by the stack. +func NewICMPRateLimiter() *ICMPRateLimiter { + return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)} +} diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 77bb0ccb9..267df60d1 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -42,10 +42,11 @@ type linkAddrCache struct { // resolved before failing. resolutionAttempts int - mu sync.Mutex - cache map[tcpip.FullAddress]*linkAddrEntry - next int // array index of next available entry - entries [linkAddrCacheSize]linkAddrEntry + cache struct { + sync.Mutex + table map[tcpip.FullAddress]*linkAddrEntry + lru linkAddrEntryList + } } // entryState controls the state of a single entry in the cache. @@ -60,9 +61,6 @@ const ( // failed means that address resolution timed out and the address // could not be resolved. failed - // expired means that the cache entry has expired and the address must be - // resolved again. - expired ) // String implements Stringer. @@ -74,8 +72,6 @@ func (s entryState) String() string { return "ready" case failed: return "failed" - case expired: - return "expired" default: return fmt.Sprintf("unknown(%d)", s) } @@ -84,64 +80,46 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + linkAddrEntryEntry + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of 'incomplete' these waiters are notified. + // state transitions out of incomplete these waiters are notified. wakers map[*sleep.Waker]struct{} + // done is used to allow callers to wait on address resolution. It is nil iff + // s is incomplete and resolution is not yet in progress. done chan struct{} } -func (e *linkAddrEntry) state() entryState { - if e.s != expired && time.Now().After(e.expiration) { - // Force the transition to ensure waiters are notified. - e.changeState(expired) - } - return e.s -} - -func (e *linkAddrEntry) changeState(ns entryState) { - if e.s == ns { - return - } - - // Validate state transition. - switch e.s { - case incomplete: - // All transitions are valid. - case ready, failed: - if ns != expired { - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - } - case expired: - // Terminal state. - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - default: - panic(fmt.Sprintf("invalid state: %s", e.s)) - } - +// changeState sets the entry's state to ns, notifying any waiters. +// +// The entry's expiration is bumped up to the greater of itself and the passed +// expiration; the zero value indicates immediate expiration, and is set +// unconditionally - this is an implementation detail that allows for entries +// to be reused. +func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { // Notify whoever is waiting on address resolution when transitioning - // out of 'incomplete'. - if e.s == incomplete { + // out of incomplete. + if e.s == incomplete && ns != incomplete { for w := range e.wakers { w.Assert() } e.wakers = nil - if e.done != nil { - close(e.done) + if ch := e.done; ch != nil { + close(ch) } + e.done = nil } - e.s = ns -} -func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) { - if w != nil { - e.wakers[w] = struct{}{} + if expiration.IsZero() || expiration.After(e.expiration) { + e.expiration = expiration } + e.s = ns } func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { @@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] - if ok { - s := entry.state() - if s != expired && entry.linkAddr == v { - // Disregard repeated calls. - return - } - // Check if entry is waiting for address resolution. - if s == incomplete { - entry.linkAddr = v - } else { - // Otherwise create a new entry to replace it. - entry = c.makeAndAddEntry(k, v) - } - } else { - entry = c.makeAndAddEntry(k, v) - } + // Calculate expiration time before acquiring the lock, since expiration is + // relative to the time when information was learned, rather than when it + // happened to be inserted into the cache. + expiration := time.Now().Add(c.ageLimit) - entry.changeState(ready) + c.cache.Lock() + entry := c.getOrCreateEntryLocked(k) + entry.linkAddr = v + + entry.changeState(ready, expiration) + c.cache.Unlock() } -// makeAndAddEntry is a helper function to create and add a new -// entry to the cache map and evict older entry as needed. -func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry { - // Take over the next entry. - entry := &c.entries[c.next] - if c.cache[entry.addr] == entry { - delete(c.cache, entry.addr) +// getOrCreateEntryLocked retrieves a cache entry associated with k. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. +func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry { + if entry, ok := c.cache.table[k]; ok { + c.cache.lru.Remove(entry) + c.cache.lru.PushFront(entry) + return entry } + var entry *linkAddrEntry + if len(c.cache.table) == linkAddrCacheSize { + entry = c.cache.lru.Back() - // Mark the soon-to-be-replaced entry as expired, just in case there is - // someone waiting for address resolution on it. - entry.changeState(expired) + delete(c.cache.table, entry.addr) + c.cache.lru.Remove(entry) - *entry = linkAddrEntry{ - addr: k, - linkAddr: v, - expiration: time.Now().Add(c.ageLimit), - wakers: make(map[*sleep.Waker]struct{}), - done: make(chan struct{}), + // Wake waiters and mark the soon-to-be-reused entry as expired. Note + // that the state passed doesn't matter when the zero time is passed. + entry.changeState(failed, time.Time{}) + } else { + entry = new(linkAddrEntry) } - c.cache[k] = entry - c.next = (c.next + 1) % len(c.entries) + *entry = linkAddrEntry{ + addr: k, + s: incomplete, + } + c.cache.table[k] = entry + c.cache.lru.PushFront(entry) return entry } @@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } } - c.mu.Lock() - defer c.mu.Unlock() - if entry, ok := c.cache[k]; ok { - switch s := entry.state(); s { - case expired: - case ready: - return entry.linkAddr, nil, nil - case failed: - return "", nil, tcpip.ErrNoLinkAddress - case incomplete: - // Address resolution is still in progress. - entry.maybeAddWaker(waker) - return "", entry.done, tcpip.ErrWouldBlock - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + c.cache.Lock() + defer c.cache.Unlock() + entry := c.getOrCreateEntryLocked(k) + switch s := entry.s; s { + case ready, failed: + if !time.Now().After(entry.expiration) { + // Not expired. + switch s { + case ready: + return entry.linkAddr, nil, nil + case failed: + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } - } - if linkRes == nil { - return "", nil, tcpip.ErrNoLinkAddress - } + entry.changeState(incomplete, time.Time{}) + fallthrough + case incomplete: + if waker != nil { + if entry.wakers == nil { + entry.wakers = make(map[*sleep.Waker]struct{}) + } + entry.wakers[waker] = struct{}{} + } - // Add 'incomplete' entry in the cache to mark that resolution is in progress. - e := c.makeAndAddEntry(k, "") - e.maybeAddWaker(waker) + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + } - go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + entry.done = make(chan struct{}) + go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + } - return "", e.done, tcpip.ErrWouldBlock + return entry.linkAddr, entry.done, tcpip.ErrWouldBlock + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } // removeWaker removes a waker previously added through get(). func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.mu.Lock() - defer c.mu.Unlock() + c.cache.Lock() + defer c.cache.Unlock() - if entry, ok := c.cache[k]; ok { + if entry, ok := c.cache.table[k]; ok { entry.removeWaker(waker) } } @@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) select { - case <-time.After(c.resolutionTimeout): - if stop := c.checkLinkRequest(k, i); stop { + case now := <-time.After(c.resolutionTimeout): + if stop := c.checkLinkRequest(now, k, i); stop { return } case <-done: @@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link // checkLinkRequest checks whether previous attempt to resolve address has succeeded // and mark the entry accordingly, e.g. ready, failed, etc. Return true if request // can stop, false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] +func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { + c.cache.Lock() + defer c.cache.Unlock() + entry, ok := c.cache.table[k] if !ok { // Entry was evicted from the cache. return true } - - switch s := entry.state(); s { - case ready, failed, expired: + switch s := entry.s; s { + case ready, failed: // Entry was made ready by resolver or failed. Either way we're done. - return true case incomplete: - if attempt+1 >= c.resolutionAttempts { - // Max number of retries reached, mark entry as failed. - entry.changeState(failed) - return true + if attempt+1 < c.resolutionAttempts { + // No response yet, need to send another ARP request. + return false } - // No response yet, need to send another ARP request. - return false + // Max number of retries reached, mark entry as failed. + entry.changeState(failed, now.Add(c.ageLimit)) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } + return true } func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - return &linkAddrCache{ + c := &linkAddrCache{ ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, - cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize), } + c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize) + return c } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 924f4d240..9946b8fe8 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "sync" + "sync/atomic" "testing" "time" @@ -29,25 +30,34 @@ type testaddr struct { linkAddr tcpip.LinkAddress } -var testaddrs []testaddr +var testAddrs = func() []testaddr { + var addrs []testaddr + for i := 0; i < 4*linkAddrCacheSize; i++ { + addr := fmt.Sprintf("Addr%06d", i) + addrs = append(addrs, testaddr{ + addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + linkAddr: tcpip.LinkAddress("Link" + addr), + }) + } + return addrs +}() type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration + cache *linkAddrCache + delay time.Duration + onLinkAddressRequest func() } func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { - go func() { - if r.delay > 0 { - time.Sleep(r.delay) - } - r.fakeRequest(addr) - }() + time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) + if f := r.onLinkAddressRequest; f != nil { + f() + } return nil } func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testaddrs { + for _, ta := range testAddrs { if ta.addr.Addr == addr { r.cache.add(ta.addr, ta.linkAddr) break @@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe } } -func init() { - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - testaddrs = append(testaddrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } -} - func TestCacheOverflow(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - for i := len(testaddrs) - 1; i >= 0; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= 0; i-- { + e := testAddrs[i] c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { @@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) { } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { - e := testaddrs[i] + e := testAddrs[i] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) @@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. - for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { + e := testAddrs[i] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) } @@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) { for r := 0; r < 16; r++ { wg.Add(1) go func() { - for _, e := range testaddrs { + for _, e := range testAddrs { c.add(e.addr, e.linkAddr) c.get(e.addr, nil, "", nil, nil) // make work for gotsan } @@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } - e = testaddrs[0] + e = testAddrs[0] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } @@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) { func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { @@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) { func TestCacheReplace(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] l2 := e.linkAddr + "2" c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) @@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) { func TestCacheResolution(t *testing.T) { c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) linkRes := &testLinkAddressResolver{cache: c} - for i, ta := range testaddrs { + for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) @@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) { // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) { c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) linkRes := &testLinkAddressResolver{cache: c} + var requestCount uint32 + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) + } + // First, sanity check that resolution is working... - e := testaddrs[0] + e := testAddrs[0] got, err := getBlocking(c, e.addr, linkRes) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } + before := atomic.LoadUint32(&requestCount) + e.addr.Addr += "2" if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } + + if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } } func TestCacheResolutionTimeout(t *testing.T) { @@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) { c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} - e := testaddrs[0] + e := testAddrs[0] if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 4ef85bdfb..f6106f762 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -34,15 +34,13 @@ type NIC struct { linkEP LinkEndpoint loopback bool - demux *transportDemuxer - - mu sync.RWMutex - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber]*ilist.List - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - subnets []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 + mu sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber]*ilist.List + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + addressRanges []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 stats NICStats } @@ -85,7 +83,6 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback name: name, linkEP: ep, loopback: loopback, - demux: newTransportDemuxer(stack), primary: make(map[tcpip.NetworkProtocolNumber]*ilist.List), endpoints: make(map[NetworkEndpointID]*referencedNetworkEndpoint), mcastJoins: make(map[NetworkEndpointID]int32), @@ -102,6 +99,25 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback } } +// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// join the IPv6 All-Nodes Multicast address (ff02::1). +func (n *NIC) enable() *tcpip.Error { + n.attachLinkEndpoint() + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + } + + return nil +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -129,37 +145,6 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { - n.mu.RLock() - defer n.mu.RUnlock() - - var r *referencedNetworkEndpoint - - // Check for a primary endpoint. - if list, ok := n.primary[protocol]; ok { - for e := list.Front(); e != nil; e = e.Next() { - ref := e.(*referencedNetworkEndpoint) - if ref.holdsInsertRef && ref.tryIncRef() { - r = ref - break - } - } - - } - - if r == nil { - return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress - } - - addressWithPrefix := tcpip.AddressWithPrefix{ - Address: r.ep.ID().LocalAddress, - PrefixLen: r.ep.PrefixLen(), - } - r.decRef() - - return addressWithPrefix, nil -} - // primaryEndpoint returns the primary endpoint of n for the given network // protocol. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { @@ -178,7 +163,7 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN case header.IPv4Broadcast, header.IPv4Any: continue } - if r.tryIncRef() { + if r.isValidForOutgoing() && r.tryIncRef() { return r } } @@ -197,22 +182,44 @@ func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.A // getRefEpOrCreateTemp returns the referenced network endpoint for the given // protocol and address. If none exists a temporary one may be created if -// requested. -func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint { +// we are in promiscuous mode or spoofing. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, spoofingOrPromiscuous bool) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.RUnlock() - return ref + if ref, ok := n.endpoints[id]; ok { + // An endpoint with this id exists, check if it can be used and return it. + switch ref.getKind() { + case permanentExpired: + if !spoofingOrPromiscuous { + n.mu.RUnlock() + return nil + } + fallthrough + case temporary, permanent: + if ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + } } - // The address was not found, create a temporary one if requested by the - // caller or if the address is found in the NIC's subnets. - createTempEP := allowTemp + // A usable reference was not found, create a temporary one if requested by + // the caller or if the address is found in the NIC's subnets. + createTempEP := spoofingOrPromiscuous if !createTempEP { - for _, sn := range n.subnets { + for _, sn := range n.addressRanges { + // Skip the subnet address. + if address == sn.ID() { + continue + } + // For now just skip the broadcast address, until we support it. + // FIXME(b/137608825): Add support for sending/receiving directed + // (subnet) broadcast. + if address == sn.Broadcast() { + continue + } if sn.Contains(address) { createTempEP = true break @@ -230,34 +237,70 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.Unlock() - return ref + if ref, ok := n.endpoints[id]; ok { + // No need to check the type as we are ok with expired endpoints at this + // point. + if ref.tryIncRef() { + n.mu.Unlock() + return ref + } + // tryIncRef failing means the endpoint is scheduled to be removed once the + // lock is released. Remove it here so we can create a new (temporary) one. + // The removal logic waiting for the lock handles this case. + n.removeEndpointLocked(ref) } + // Add a new temporary endpoint. netProto, ok := n.stack.networkProtocols[protocol] if !ok { n.mu.Unlock() return nil } - ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: address, PrefixLen: netProto.DefaultPrefixLen(), }, - }, peb, true) - - if ref != nil { - ref.holdsInsertRef = false - } + }, peb, temporary) n.mu.Unlock() return ref } -func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, replace bool) (*referencedNetworkEndpoint, *tcpip.Error) { +func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if ref, ok := n.endpoints[id]; ok { + switch ref.getKind() { + case permanent: + // The NIC already have a permanent endpoint with that address. + return nil, tcpip.ErrDuplicateAddress + case permanentExpired, temporary: + // Promote the endpoint to become permanent. + if ref.tryIncRef() { + ref.setKind(permanent) + return ref, nil + } + // tryIncRef failing means the endpoint is scheduled to be removed once + // the lock is released. Remove it here so we can create a new + // (permanent) one. The removal logic waiting for the lock handles this + // case. + n.removeEndpointLocked(ref) + } + } + return n.addAddressLocked(protocolAddress, peb, permanent) +} + +func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // TODO(b/141022673): Validate IP address before adding them. + + // Sanity check. + id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} + if _, ok := n.endpoints[id]; ok { + // Endpoint already exists. + return nil, tcpip.ErrDuplicateAddress + } + netProto, ok := n.stack.networkProtocols[protocolAddress.Protocol] if !ok { return nil, tcpip.ErrUnknownProtocol @@ -268,22 +311,12 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar if err != nil { return nil, err } - - id := *ep.ID() - if ref, ok := n.endpoints[id]; ok { - if !replace { - return nil, tcpip.ErrDuplicateAddress - } - - n.removeEndpointLocked(ref) - } - ref := &referencedNetworkEndpoint{ - refs: 1, - ep: ep, - nic: n, - protocol: protocolAddress.Protocol, - holdsInsertRef: true, + refs: 1, + ep: ep, + nic: n, + protocol: protocolAddress.Protocol, + kind: kind, } // Set up cache if link address resolution exists for this protocol. @@ -293,6 +326,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } + // If we are adding an IPv6 unicast address, join the solicited-node + // multicast address. + if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) { + snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) + if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { + return nil, err + } + } + n.endpoints[id] = ref l, ok := n.primary[protocolAddress.Protocol] @@ -316,18 +358,26 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) *tcpip.Error { // Add the endpoint. n.mu.Lock() - _, err := n.addAddressLocked(protocolAddress, peb, false) + _, err := n.addPermanentAddressLocked(protocolAddress, peb) n.mu.Unlock() return err } -// Addresses returns the addresses associated with this NIC. -func (n *NIC) Addresses() []tcpip.ProtocolAddress { +// AllAddresses returns all addresses (primary and non-primary) associated with +// this NIC. +func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() + addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.getKind() { + case permanentExpired, temporary: + continue + } addrs = append(addrs, tcpip.ProtocolAddress{ Protocol: ref.protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ @@ -339,45 +389,66 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { return addrs } -// AddSubnet adds a new subnet to n, so that it starts accepting packets -// targeted at the given address and network protocol. -func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { +// PrimaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { + n.mu.RLock() + defer n.mu.RUnlock() + + var addrs []tcpip.ProtocolAddress + for proto, list := range n.primary { + for e := list.Front(); e != nil; e = e.Next() { + ref := e.(*referencedNetworkEndpoint) + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.getKind() { + case permanentExpired, temporary: + continue + } + + addrs = append(addrs, tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + }, + }) + } + } + return addrs +} + +// AddAddressRange adds a range of addresses to n, so that it starts accepting +// packets targeted at the given addresses and network protocol. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { n.mu.Lock() - n.subnets = append(n.subnets, subnet) + n.addressRanges = append(n.addressRanges, subnet) n.mu.Unlock() } -// RemoveSubnet removes the given subnet from n. -func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) { +// RemoveAddressRange removes the given address range from n. +func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Lock() // Use the same underlying array. - tmp := n.subnets[:0] - for _, sub := range n.subnets { + tmp := n.addressRanges[:0] + for _, sub := range n.addressRanges { if sub != subnet { tmp = append(tmp, sub) } } - n.subnets = tmp + n.addressRanges = tmp n.mu.Unlock() } -// ContainsSubnet reports whether this NIC contains the given subnet. -func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool { - for _, s := range n.Subnets() { - if s == subnet { - return true - } - } - return false -} - // Subnets returns the Subnets associated with this NIC. -func (n *NIC) Subnets() []tcpip.Subnet { +func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints)) + sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints)) for nid := range n.endpoints { sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) if err != nil { @@ -387,19 +458,22 @@ func (n *NIC) Subnets() []tcpip.Subnet { } sns = append(sns, sn) } - return append(sns, n.subnets...) + return append(sns, n.addressRanges...) } func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { id := *r.ep.ID() - // Nothing to do if the reference has already been replaced with a - // different one. + // Nothing to do if the reference has already been replaced with a different + // one. This happens in the case where 1) this endpoint's ref count hit zero + // and was waiting (on the lock) to be removed and 2) the same address was + // re-added in the meantime by removing this endpoint from the list and + // adding a new one. if n.endpoints[id] != r { return } - if r.holdsInsertRef { + if r.getKind() == permanent { panic("Reference count dropped to zero before being removed") } @@ -418,15 +492,28 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { n.mu.Unlock() } -func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { - r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || !r.holdsInsertRef { +func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { + r, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok || r.getKind() != permanent { return tcpip.ErrBadLocalAddress } - r.holdsInsertRef = false + r.setKind(permanentExpired) + if !r.decRefLocked() { + // The endpoint still has references to it. + return nil + } - r.decRefLocked() + // At this point the endpoint is deleted. + + // If we are removing an IPv6 unicast address, leave the solicited-node + // multicast address. + if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) { + snmc := header.SolicitedNodeAddr(addr) + if err := n.leaveGroupLocked(snmc); err != nil { + return err + } + } return nil } @@ -435,7 +522,7 @@ func (n *NIC) removeAddressLocked(addr tcpip.Address) *tcpip.Error { func (n *NIC) RemoveAddress(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() - return n.removeAddressLocked(addr) + return n.removePermanentAddressLocked(addr) } // joinGroup adds a new endpoint for the given multicast address, if none @@ -444,6 +531,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address n.mu.Lock() defer n.mu.Unlock() + return n.joinGroupLocked(protocol, addr) +} + +// joinGroupLocked adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. n MUST be locked before +// joinGroupLocked is called. +func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -451,13 +545,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address if !ok { return tcpip.ErrUnknownProtocol } - if _, err := n.addAddressLocked(tcpip.ProtocolAddress{ + if _, err := n.addPermanentAddressLocked(tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: addr, PrefixLen: netProto.DefaultPrefixLen(), }, - }, NeverPrimaryEndpoint, false); err != nil { + }, NeverPrimaryEndpoint); err != nil { return err } } @@ -471,6 +565,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() + return n.leaveGroupLocked(addr) +} + +// leaveGroupLocked decrements the count for the given multicast address, and +// when it reaches zero removes the endpoint for this address. n MUST be locked +// before leaveGroupLocked is called. +func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] switch joins { @@ -479,7 +580,7 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress case 1: // This is the last one, clean up. - if err := n.removeAddressLocked(addr); err != nil { + if err := n.removePermanentAddressLocked(addr); err != nil { return err } } @@ -487,6 +588,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return nil } +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { + r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) + r.RemoteLinkAddress = remotelinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() +} + // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the physical interface. @@ -514,6 +622,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr src, dst := netProto.ParseAddresses(vv.First()) + n.stack.AddLinkAddress(n.id, src, remote) + // If the packet is destined to the IPv4 Broadcast address, then make a // route to each IPv4 network endpoint and let each endpoint handle the // packet. @@ -521,11 +631,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr // n.endpoints is mutex protected so acquire lock. n.mu.RLock() for _, ref := range n.endpoints { - if ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) } } n.mu.RUnlock() @@ -533,10 +640,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr } if ref := n.getRef(protocol, dst); ref != nil { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return } @@ -559,8 +663,9 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n := r.ref.nic n.mu.RLock() ref, ok := n.endpoints[NetworkEndpointID{dst}] + ok = ok && ref.isValidForOutgoing() && ref.tryIncRef() n.mu.RUnlock() - if ok && ref.tryIncRef() { + if ok { r.RemoteAddress = src // TODO(b/123449044): Update the source NIC as well. ref.ep.HandlePacket(&r, vv) @@ -599,9 +704,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - if !n.demux.deliverRawPacket(r, protocol, netHeader, vv) { - n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) - } + n.stack.demux.deliverRawPacket(r, protocol, netHeader, vv) if len(vv.First()) < transProto.MinimumPacketSize() { n.stack.stats.MalformedRcvdPackets.Increment() @@ -615,9 +718,6 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN } id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.demux.deliverPacket(r, protocol, netHeader, vv, id) { - return - } if n.stack.demux.deliverPacket(r, protocol, netHeader, vv, id) { return } @@ -631,7 +731,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { n.stack.stats.MalformedRcvdPackets.Increment() } } @@ -659,10 +759,7 @@ func (n *NIC) DeliverTransportControlPacket(local, remote tcpip.Address, net tcp } id := TransportEndpointID{srcPort, local, dstPort, remote} - if n.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { - return - } - if n.stack.demux.deliverControlPacket(net, trans, typ, extra, vv, id) { + if n.stack.demux.deliverControlPacket(n, net, trans, typ, extra, vv, id) { return } } @@ -672,9 +769,38 @@ func (n *NIC) ID() tcpip.NICID { return n.id } +// Stack returns the instance of the Stack that owns this NIC. +func (n *NIC) Stack() *Stack { + return n.stack +} + +type networkEndpointKind int32 + +const ( + // A permanent endpoint is created by adding a permanent address (vs. a + // temporary one) to the NIC. Its reference count is biased by 1 to avoid + // removal when no route holds a reference to it. It is removed by explicitly + // removing the permanent address from the NIC. + permanent networkEndpointKind = iota + + // An expired permanent endoint is a permanent endoint that had its address + // removed from the NIC, and it is waiting to be removed once no more routes + // hold a reference to it. This is achieved by decreasing its reference count + // by 1. If its address is re-added before the endpoint is removed, its type + // changes back to permanent and its reference count increases by 1 again. + permanentExpired + + // A temporary endpoint is created for spoofing outgoing packets, or when in + // promiscuous mode and accepting incoming packets that don't match any + // permanent endpoint. Its reference count is not biased by 1 and the + // endpoint is removed immediately when no more route holds a reference to + // it. A temporary endpoint can be promoted to permanent if its address + // is added permanently. + temporary +) + type referencedNetworkEndpoint struct { ilist.Entry - refs int32 ep NetworkEndpoint nic *NIC protocol tcpip.NetworkProtocolNumber @@ -683,11 +809,34 @@ type referencedNetworkEndpoint struct { // protocol. Set to nil otherwise. linkCache LinkAddressCache - // holdsInsertRef is protected by the NIC's mutex. It indicates whether - // the reference count is biased by 1 due to the insertion of the - // endpoint. It is reset to false when RemoveAddress is called on the - // NIC. - holdsInsertRef bool + // refs is counting references held for this endpoint. When refs hits zero it + // triggers the automatic removal of the endpoint from the NIC. + refs int32 + + // networkEndpointKind must only be accessed using {get,set}Kind(). + kind networkEndpointKind +} + +func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { + return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) +} + +func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { + atomic.StoreInt32((*int32)(&r.kind), int32(kind)) +} + +// isValidForOutgoing returns true if the endpoint can be used to send out a +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in spoofing mode. +func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { + return r.getKind() != permanentExpired || r.nic.spoofing +} + +// isValidForIncoming returns true if the endpoint can accept an incoming +// packet. It requires the endpoint to not be marked expired (i.e., its address +// has been removed), or the NIC to be in promiscuous mode. +func (r *referencedNetworkEndpoint) isValidForIncoming() bool { + return r.getKind() != permanentExpired || r.nic.promiscuous } // decRef decrements the ref count and cleans up the endpoint once it reaches @@ -699,11 +848,14 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { +// locked. Returns true if the endpoint was removed. +func (r *referencedNetworkEndpoint) decRefLocked() bool { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) + return true } + + return false } // incRef increments the ref count. It must only be called when the caller is @@ -728,3 +880,8 @@ func (r *referencedNetworkEndpoint) tryIncRef() bool { } } } + +// stack returns the Stack instance that owns the underlying endpoint. +func (r *referencedNetworkEndpoint) stack() *Stack { + return r.nic.stack +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 2037eef9f..80101d4bb 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -109,7 +107,7 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -297,6 +295,15 @@ type LinkEndpoint interface { // IsAttached returns whether a NetworkDispatcher is attached to the // endpoint. IsAttached() bool + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // For now, requesting that an endpoint's worker goroutine(s) stop is + // implementation specific. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -359,14 +366,6 @@ type LinkAddressCache interface { RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } -// TransportProtocolFactory functions are used by the stack to instantiate -// transport protocols. -type TransportProtocolFactory func() TransportProtocol - -// NetworkProtocolFactory provides methods to be used by the stack to -// instantiate network protocols. -type NetworkProtocolFactory func() NetworkProtocol - // UnassociatedEndpointFactory produces endpoints for writing packets not // associated with a particular transport protocol. Such endpoints can be used // to write arbitrary packets that include the IP header. @@ -374,60 +373,6 @@ type UnassociatedEndpointFactory interface { NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) } -var ( - transportProtocols = make(map[string]TransportProtocolFactory) - networkProtocols = make(map[string]NetworkProtocolFactory) - - unassociatedFactory UnassociatedEndpointFactory - - linkEPMu sync.RWMutex - nextLinkEndpointID tcpip.LinkEndpointID = 1 - linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) -) - -// RegisterTransportProtocolFactory registers a new transport protocol factory -// with the stack so that it becomes available to users of the stack. This -// function is intended to be called by init() functions of the protocols. -func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) { - transportProtocols[name] = p -} - -// RegisterNetworkProtocolFactory registers a new network protocol factory with -// the stack so that it becomes available to users of the stack. This function -// is intended to be called by init() functions of the protocols. -func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { - networkProtocols[name] = p -} - -// RegisterUnassociatedFactory registers a factory to produce endpoints not -// associated with any particular transport protocol. This function is intended -// to be called by init() functions of the protocols. -func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { - unassociatedFactory = f -} - -// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an -// ID that can be used to refer to it. -func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { - linkEPMu.Lock() - defer linkEPMu.Unlock() - - v := nextLinkEndpointID - nextLinkEndpointID++ - - linkEndpoints[v] = linkEP - - return v -} - -// FindLinkEndpoint finds the link endpoint associated with the given ID. -func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { - linkEPMu.RLock() - defer linkEPMu.RUnlock() - - return linkEndpoints[id] -} - // GSOType is the type of GSO segments. // // +stateify savable diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 391ab4344..5c8b7977a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -148,11 +148,15 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { // IsResolutionRequired returns true if Resolve() must be called to resolve // the link address before the this route can be written to. func (r *Route) IsResolutionRequired() bool { - return r.ref.linkCache != nil && r.RemoteLinkAddress == "" + return r.ref.isValidForOutgoing() && r.ref.linkCache != nil && r.RemoteLinkAddress == "" } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, ttl uint8) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + err := r.ref.ep.WritePacket(r, gso, hdr, payload, protocol, ttl, r.loop) if err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() @@ -166,6 +170,10 @@ func (r *Route) WritePacket(gso *GSO, hdr buffer.Prependable, payload buffer.Vec // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(payload buffer.VectorisedView) *tcpip.Error { + if !r.ref.isValidForOutgoing() { + return tcpip.ErrInvalidEndpointState + } + if err := r.ref.ep.WriteHeaderIncludedPacket(r, payload, r.loop); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err @@ -209,3 +217,8 @@ func (r *Route) MakeLoopedRoute() Route { } return l } + +// Stack returns the instance of the Stack that owns this route. +func (r *Route) Stack() *Stack { + return r.ref.stack() +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index d69162ba1..90c2cf1be 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -17,17 +17,15 @@ // // For consumers, the only function of interest is New(), everything else is // provided by the tcpip/public package. -// -// For protocol implementers, RegisterTransportProtocolFactory() and -// RegisterNetworkProtocolFactory() are used to register protocol factories with -// the stack, which will then be used to instantiate protocol objects when -// consumers interact with the stack. package stack import ( + "encoding/binary" "sync" "time" + "golang.org/x/time/rate" + "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -350,6 +348,9 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + // unassociatedFactory creates unassociated endpoints. If nil, raw + // endpoints are disabled. It is set during Stack creation and is + // immutable. unassociatedFactory UnassociatedEndpointFactory demux *transportDemuxer @@ -358,10 +359,6 @@ type Stack struct { linkAddrCache *linkAddrCache - // raw indicates whether raw sockets may be created. It is set during - // Stack creation and is immutable. - raw bool - mu sync.RWMutex nics map[tcpip.NICID]*NIC forwarding bool @@ -389,10 +386,26 @@ type Stack struct { // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. resumableEndpoints []ResumableEndpoint + + // icmpRateLimiter is a global rate limiter for all ICMP messages generated + // by the stack. + icmpRateLimiter *ICMPRateLimiter + + // portSeed is a one-time random value initialized at stack startup + // and is used to seed the TCP port picking on active connections + // + // TODO(gvisor.dev/issues/940): S/R this field. + portSeed uint32 } // Options contains optional Stack configuration. type Options struct { + // NetworkProtocols lists the network protocols to enable. + NetworkProtocols []NetworkProtocol + + // TransportProtocols lists the transport protocols to enable. + TransportProtocols []TransportProtocol + // Clock is an optional clock source used for timestampping packets. // // If no Clock is specified, the clock source will be time.Now. @@ -406,8 +419,9 @@ type Options struct { // stack (false). HandleLocal bool - // Raw indicates whether raw sockets may be created. - Raw bool + // UnassociatedFactory produces unassociated endpoints raw endpoints. + // Raw endpoints are enabled only if this is non-nil. + UnassociatedFactory UnassociatedEndpointFactory } // New allocates a new networking stack with only the requested networking and @@ -417,7 +431,7 @@ type Options struct { // SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the // stack. Please refer to individual protocol implementations as to what options // are supported. -func New(network []string, transport []string, opts Options) *Stack { +func New(opts Options) *Stack { clock := opts.Clock if clock == nil { clock = &tcpip.StdClock{} @@ -433,16 +447,12 @@ func New(network []string, transport []string, opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, - raw: opts.Raw, + icmpRateLimiter: NewICMPRateLimiter(), + portSeed: generateRandUint32(), } // Add specified network protocols. - for _, name := range network { - netProtoFactory, ok := networkProtocols[name] - if !ok { - continue - } - netProto := netProtoFactory() + for _, netProto := range opts.NetworkProtocols { s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -450,18 +460,14 @@ func New(network []string, transport []string, opts Options) *Stack { } // Add specified transport protocols. - for _, name := range transport { - transProtoFactory, ok := transportProtocols[name] - if !ok { - continue - } - transProto := transProtoFactory() + for _, transProto := range opts.TransportProtocols { s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } } - s.unassociatedFactory = unassociatedFactory + // Add the factory for unassociated endpoints, if present. + s.unassociatedFactory = opts.UnassociatedFactory // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -596,7 +602,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if !s.raw { + if s.unassociatedFactory == nil { return nil, tcpip.ErrNotPermitted } @@ -614,12 +620,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error { - ep := FindLinkEndpoint(linkEP) - if ep == nil { - return tcpip.ErrBadLinkEndpoint - } - +func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -632,40 +633,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint s.nics[id] = n if enabled { - n.attachLinkEndpoint() + return n.enable() } return nil } // CreateNIC creates a NIC with the provided id and link-layer endpoint. -func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, true, false) +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, true, false) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, false) +func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, false) } // CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer // endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, true) +func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, false, false) +func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, false, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, false, false) +func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, false, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -679,9 +680,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - nic.attachLinkEndpoint() - - return nil + return nic.enable() } // CheckNIC checks if a NIC is usable. @@ -696,14 +695,14 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { } // NICSubnets returns a map of NICIDs to their associated subnets. -func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet { +func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() nics := map[tcpip.NICID][]tcpip.Subnet{} for id, nic := range s.nics { - nics[id] = append(nics[id], nic.Subnets()...) + nics[id] = append(nics[id], nic.AddressRanges()...) } return nics } @@ -739,7 +738,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.linkEP.LinkAddress(), - ProtocolAddresses: nic.Addresses(), + ProtocolAddresses: nic.PrimaryAddresses(), Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, @@ -804,71 +803,79 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return nic.AddAddress(protocolAddress, peb) } -// AddSubnet adds a subnet range to the specified NIC. -func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { +// AddAddressRange adds a range of addresses to the specified NIC. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - nic.AddSubnet(protocol, subnet) + nic.AddAddressRange(protocol, subnet) return nil } return tcpip.ErrUnknownNICID } -// RemoveSubnet removes the subnet range from the specified NIC. -func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { +// RemoveAddressRange removes the range of addresses from the specified NIC. +func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - nic.RemoveSubnet(subnet) + nic.RemoveAddressRange(subnet) return nil } return tcpip.ErrUnknownNICID } -// ContainsSubnet reports whether the specified NIC contains the specified -// subnet. -func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) { +// RemoveAddress removes an existing network-layer address from the specified +// NIC. +func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - return nic.ContainsSubnet(subnet), nil + return nic.RemoveAddress(addr) } - return false, tcpip.ErrUnknownNICID + return tcpip.ErrUnknownNICID } -// RemoveAddress removes an existing network-layer address from the specified -// NIC. -func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { +// AllAddresses returns a map of NICIDs to their protocol addresses (primary +// and non-primary). +func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { s.mu.RLock() defer s.mu.RUnlock() - if nic, ok := s.nics[id]; ok { - return nic.RemoveAddress(addr) + nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress) + for id, nic := range s.nics { + nics[id] = nic.AllAddresses() } - - return tcpip.ErrUnknownNICID + return nics } -// GetMainNICAddress returns the first primary address (and the subnet that -// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's -// address if no primary addresses exist. Returns an error if the NIC doesn't -// exist or has no endpoints. +// GetMainNICAddress returns the first primary address and prefix for the given +// NIC and protocol. Returns an error if the NIC doesn't exist and an empty +// value if the NIC doesn't have a primary address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() - if nic, ok := s.nics[id]; ok { - return nic.getMainNICAddress(protocol) + nic, ok := s.nics[id] + if !ok { + return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID + for _, a := range nic.PrimaryAddresses() { + if a.Protocol == protocol { + return a.AddressWithPrefix, nil + } + } + return tcpip.AddressWithPrefix{}, nil } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { @@ -1035,73 +1042,27 @@ func (s *Stack) RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep. // transport dispatcher. Received packets that match the provided id will be // delivered to the given endpoint; specifying a nic is optional, but // nic-specific IDs have precedence over global ones. -func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { - if nicID == 0 { - return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort) +func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + return s.demux.registerEndpoint(netProtos, protocol, id, ep, reusePort, bindToDevice) } // UnregisterTransportEndpoint removes the endpoint with the given id from the // stack transport dispatcher. -func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { - if nicID == 0 { - s.demux.unregisterEndpoint(netProtos, protocol, id, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterEndpoint(netProtos, protocol, id, ep) - } +func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { + s.demux.unregisterEndpoint(netProtos, protocol, id, ep, bindToDevice) } // RegisterRawTransportEndpoint registers the given endpoint with the stack // transport dispatcher. Received packets that match the provided transport // protocol will be delivered to the given endpoint. func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error { - if nicID == 0 { - return s.demux.registerRawEndpoint(netProto, transProto, ep) - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic == nil { - return tcpip.ErrUnknownNICID - } - - return nic.demux.registerRawEndpoint(netProto, transProto, ep) + return s.demux.registerRawEndpoint(netProto, transProto, ep) } // UnregisterRawTransportEndpoint removes the endpoint for the transport // protocol from the stack transport dispatcher. func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) { - if nicID == 0 { - s.demux.unregisterRawEndpoint(netProto, transProto, ep) - return - } - - s.mu.RLock() - defer s.mu.RUnlock() - - nic := s.nics[nicID] - if nic != nil { - nic.demux.unregisterRawEndpoint(netProto, transProto, ep) - } + s.demux.unregisterRawEndpoint(netProto, transProto, ep) } // RegisterRestoredEndpoint records e as an endpoint that has been restored on @@ -1215,3 +1176,49 @@ func (s *Stack) IPTables() iptables.IPTables { func (s *Stack) SetIPTables(ipt iptables.IPTables) { s.tables = ipt } + +// ICMPLimit returns the maximum number of ICMP messages that can be sent +// in one second. +func (s *Stack) ICMPLimit() rate.Limit { + return s.icmpRateLimiter.Limit() +} + +// SetICMPLimit sets the maximum number of ICMP messages that be sent +// in one second. +func (s *Stack) SetICMPLimit(newLimit rate.Limit) { + s.icmpRateLimiter.SetLimit(newLimit) +} + +// ICMPBurst returns the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) ICMPBurst() int { + return s.icmpRateLimiter.Burst() +} + +// SetICMPBurst sets the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) SetICMPBurst(burst int) { + s.icmpRateLimiter.SetBurst(burst) +} + +// AllowICMPMessage returns true if we the rate limiter allows at least one +// ICMP message to be sent at this instant. +func (s *Stack) AllowICMPMessage() bool { + return s.icmpRateLimiter.Allow() +} + +// PortSeed returns a 32 bit value that can be used as a seed value for port +// picking. +// +// NOTE: The seed is generated once during stack initialization only. +func (s *Stack) PortSeed() uint32 { + return s.portSeed +} + +func generateRandUint32() uint32 { + b := make([]byte, 4) + if _, err := rand.Read(b); err != nil { + panic(err) + } + return binary.LittleEndian.Uint32(b) +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 137c6183e..d2dede8a9 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct { prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - linkEP stack.LinkEndpoint + ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { @@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen + return f.ep.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -116,7 +116,7 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto } func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.linkEP.Capabilities() + return f.ep.Capabilities() } func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { @@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return nil } - return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) } func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { @@ -181,18 +181,22 @@ func (f *fakeNetworkProtocol) DefaultPrefixLen() int { return fakeDefaultPrefixLen } +func (f *fakeNetworkProtocol) PacketCount(intfAddr byte) int { + return f.packetCount[int(intfAddr)%len(f.packetCount)] +} + func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Address) { return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, - linkEP: linkEP, + ep: ep, }, nil } @@ -218,12 +222,18 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeNetFactory() stack.NetworkProtocol { + return &fakeNetworkProtocol{} +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -241,7 +251,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -251,7 +261,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -261,7 +271,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -270,7 +280,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -280,7 +290,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -289,16 +299,75 @@ func TestNetworkReceive(t *testing.T) { } } -func sendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, payload buffer.View) { +func sendTo(s *stack.Stack, addr tcpip.Address, payload buffer.View) *tcpip.Error { r, err := s.FindRoute(0, "", addr, fakeNetNumber, false /* multicastLoop */) if err != nil { - t.Fatal("FindRoute failed:", err) + return err } defer r.Release() + return send(r, payload) +} +func send(r stack.Route, payload buffer.View) *tcpip.Error { hdr := buffer.NewPrependable(int(r.MaxHeaderLength())) - if err := r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123); err != nil { - t.Error("WritePacket failed:", err) + return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) +} + +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { + t.Helper() + ep.Drain() + if err := sendTo(s, addr, payload); err != nil { + t.Error("sendTo failed:", err) + } + if got, want := ep.Drain(), 1; got != want { + t.Errorf("sendTo packet count: got = %d, want %d", got, want) + } +} + +func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { + t.Helper() + ep.Drain() + if err := send(r, payload); err != nil { + t.Error("send failed:", err) + } + if got, want := ep.Drain(), 1; got != want { + t.Errorf("send packet count: got = %d, want %d", got, want) + } +} + +func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := send(r, payload); gotErr != wantErr { + t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { + t.Helper() + if gotErr := sendTo(s, addr, payload); gotErr != wantErr { + t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) + } +} + +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + 1 + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) +} + +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { + t.Helper() + // testRecvInternal injects one packet, and we do NOT expect to receive it. + want := fakeNet.PacketCount(localAddrByte) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) +} + +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { + t.Helper() + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) + if got := fakeNet.PacketCount(localAddrByte); got != want { + t.Errorf("receive packet count: got = %d, want %d", got, want) } } @@ -306,9 +375,11 @@ func TestNetworkSend(t *testing.T) { // Create a stack with the fake network protocol, one nic, and one // address: 1. The route table sends all packets through the only // existing nic. - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) } @@ -325,20 +396,19 @@ func TestNetworkSend(t *testing.T) { } // Make sure that the link-layer endpoint received the outbound packet. - sendTo(t, s, "\x03", nil) - if c := linkEP.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x03", ep, nil) } func TestNetworkSendMultiRoute(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -350,8 +420,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -382,18 +452,10 @@ func TestNetworkSendMultiRoute(t *testing.T) { } // Send a packet to an odd destination. - sendTo(t, s, "\x05", nil) - - if c := linkEP1.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x05", ep1, nil) // Send a packet to an even destination. - sendTo(t, s, "\x06", nil) - - if c := linkEP2.Drain(); c != 1 { - t.Errorf("packetCount = %d, want %d", c, 1) - } + testSendTo(t, s, "\x06", ep2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -424,10 +486,12 @@ func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id1, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -439,8 +503,8 @@ func TestRoutes(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -498,58 +562,71 @@ func TestRoutes(t *testing.T) { } func TestAddressRemoval(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) - // Remove the address, then check that packet doesn't get delivered - // anymore. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } } -func TestDelayedRemovalDueToRoute(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) +func TestAddressRemovalWithRouteHeld(t *testing.T) { + const localAddrByte byte = 0x01 + localAddr := tcpip.Address([]byte{localAddrByte}) + remoteAddr := tcpip.Address("\x02") - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatal("CreateNIC failed:", err) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + buf := buffer.NewView(30) - if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } - { subnet, err := tcpip.NewSubnet("\x00", "\x00") if err != nil { @@ -558,58 +635,239 @@ func TestDelayedRemovalDueToRoute(t *testing.T) { s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) } - fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) - - buf := buffer.NewView(30) - - // Write a packet, and check that it gets delivered. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } - - // Get a route, check that packet is still deliverable. - r, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 2 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 2) - } + // Send and receive packets, and verify they are received. + buf[0] = localAddrByte + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) - // Remove the address, then check that packet is still deliverable - // because the route is keeping the address alive. - if err := s.RemoveAddress(1, "\x01"); err != nil { + // Remove the address, then check that send/receive doesn't work anymore. + if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) - } + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. - if err := s.RemoveAddress(1, "\x01"); err != tcpip.ErrBadLocalAddress { + if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { t.Fatalf("RemoveAddress returned unexpected error, got = %v, want = %s", err, tcpip.ErrBadLocalAddress) } +} + +func verifyAddress(t *testing.T, s *stack.Stack, nicid tcpip.NICID, addr tcpip.Address) { + t.Helper() + info, ok := s.NICInfo()[nicid] + if !ok { + t.Fatalf("NICInfo() failed to find nicid=%d", nicid) + } + if len(addr) == 0 { + // No address given, verify that there is no address assigned to the NIC. + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber && a.AddressWithPrefix != (tcpip.AddressWithPrefix{}) { + t.Errorf("verify no-address: got = %s, want = %s", a.AddressWithPrefix, (tcpip.AddressWithPrefix{})) + } + } + return + } + // Address given, verify the address is assigned to the NIC and no other + // address is. + found := false + for _, a := range info.ProtocolAddresses { + if a.Protocol == fakeNetNumber { + if a.AddressWithPrefix.Address == addr { + found = true + } else { + t.Errorf("verify address: got = %s, want = %s", a.AddressWithPrefix.Address, addr) + } + } + } + if !found { + t.Errorf("verify address: couldn't find %s on the NIC", addr) + } +} + +func TestEndpointExpiration(t *testing.T) { + const ( + localAddrByte byte = 0x01 + remoteAddr tcpip.Address = "\x03" + noAddr tcpip.Address = "" + nicid tcpip.NICID = 1 + ) + localAddr := tcpip.Address([]byte{localAddrByte}) + + for _, promiscuous := range []bool{true, false} { + for _, spoofing := range []bool{true, false} { + t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - // Release the route, then check that packet is not deliverable anymore. - r.Release() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 3 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 3) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) + buf := buffer.NewView(30) + buf[0] = localAddrByte + + if promiscuous { + if err := s.SetPromiscuousMode(nicid, true); err != nil { + t.Fatal("SetPromiscuousMode failed:", err) + } + } + + if spoofing { + if err := s.SetSpoofing(nicid, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + } + + // 1. No Address yet, send should only work for spoofing, receive for + // promiscuous mode. + //----------------------- + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, ep, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, ep, nil) + } else { + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + } + + // 2. Add Address, everything should work. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + + // 3. Remove the address, send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, ep, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, ep, nil) + } else { + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + } + + // 4. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + + // 5. Take a reference to the endpoint by getting a route. Verify that + // we can still send/receive, including sending using the route. + //----------------------- + r, err := s.FindRoute(0, "", remoteAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) + + // 6. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, ep, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + } + if spoofing { + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) + } else { + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + } + + // 7. Add Address back, everything should work again. + //----------------------- + if err := s.AddAddress(nicid, fakeNetNumber, localAddr); err != nil { + t.Fatal("AddAddress failed:", err) + } + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) + + // 8. Remove the route, sendTo/recv should still work. + //----------------------- + r.Release() + verifyAddress(t, s, nicid, localAddr) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + + // 9. Remove the address. Send should only work for spoofing, receive + // for promiscuous mode. + //----------------------- + if err := s.RemoveAddress(nicid, localAddr); err != nil { + t.Fatal("RemoveAddress failed:", err) + } + verifyAddress(t, s, nicid, noAddr) + if promiscuous { + testRecv(t, fakeNet, localAddrByte, ep, buf) + } else { + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + } + if spoofing { + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, ep, nil) + } else { + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) + } + }) + } } } func TestPromiscuousMode(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -627,22 +885,15 @@ func TestPromiscuousMode(t *testing.T) { // Write a packet, and check that it doesn't get delivered as we don't // have a matching endpoint. - fakeNet.packetCount[1] = 0 - buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) - } + const localAddrByte byte = 0x01 + buf[0] = localAddrByte + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testRecv(t, fakeNet, localAddrByte, ep, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -655,25 +906,24 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) - } + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } -func TestAddressSpoofing(t *testing.T) { - srcAddr := tcpip.Address("\x01") - dstAddr := tcpip.Address("\x02") +func TestSpoofingWithAddress(t *testing.T) { + localAddr := tcpip.Address("\x01") + nonExistentLocalAddr := tcpip.Address("\x02") + dstAddr := tcpip.Address("\x03") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } - if err := s.AddAddress(1, fakeNetNumber, dstAddr); err != nil { + if err := s.AddAddress(1, fakeNetNumber, localAddr); err != nil { t.Fatal("AddAddress failed:", err) } @@ -687,7 +937,7 @@ func TestAddressSpoofing(t *testing.T) { // With address spoofing disabled, FindRoute does not permit an address // that was not added to the NIC to be used as the source. - r, err := s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err == nil { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } @@ -697,23 +947,92 @@ func TestAddressSpoofing(t *testing.T) { if err := s.SetSpoofing(1, true); err != nil { t.Fatal("SetSpoofing failed:", err) } - r, err = s.FindRoute(0, srcAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + } + // Sending a packet works. + testSendTo(t, s, dstAddr, ep, nil) + testSend(t, r, ep, nil) + + // FindRoute should also work with a local address that exists on the NIC. + r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) if err != nil { t.Fatal("FindRoute failed:", err) } - if r.LocalAddress != srcAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, srcAddr) + if r.LocalAddress != localAddr { + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) } + // Sending a packet using the route works. + testSend(t, r, ep, nil) +} + +func TestSpoofingNoAddress(t *testing.T) { + nonExistentLocalAddr := tcpip.Address("\x01") + dstAddr := tcpip.Address("\x02") + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + { + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + s.SetRouteTable([]tcpip.Route{{Destination: subnet, Gateway: "\x00", NIC: 1}}) + } + + // With address spoofing disabled, FindRoute does not permit an address + // that was not added to the NIC to be used as the source. + r, err := s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err == nil { + t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) + } + // Sending a packet fails. + testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) + + // With address spoofing enabled, FindRoute permits any address to be used + // as the source. + if err := s.SetSpoofing(1, true); err != nil { + t.Fatal("SetSpoofing failed:", err) + } + r, err = s.FindRoute(0, nonExistentLocalAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatal("FindRoute failed:", err) + } + if r.LocalAddress != nonExistentLocalAddr { + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) + } + // Sending a packet works. + // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. + // testSendTo(t, s, remoteAddr, ep, nil) } func TestBroadcastNeedsNoRoute(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) @@ -781,10 +1100,12 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, } { t.Run(tc.name, func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -835,12 +1156,14 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -// Set the subnet, then check that packet is delivered. -func TestSubnetAcceptsMatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) +// Add a range of addresses, then check that a packet is delivered. +func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -856,29 +1179,59 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x00"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 1 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) + testRecv(t, fakeNet, localAddrByte, ep, buf) +} + +func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { + t.Helper() + + // Loop over all addresses and check them. + numOfAddresses := 1 << uint(8-subnet.Prefix()) + if numOfAddresses < 1 || numOfAddresses > 255 { + t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) + } + + addrBytes := []byte(subnet.ID()) + for i := 0; i < numOfAddresses; i++ { + addr := tcpip.Address(addrBytes) + wantNicID := nicID + // The subnet and broadcast addresses are skipped. + if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { + wantNicID = 0 + } + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) + } + addrBytes[0]++ + } + + // Trying the next address should always fail since it is outside the range. + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) } } -// Set the subnet, then check that CheckLocalAddress returns the correct NIC. +// Set a range of addresses, then remove it again, and check at each step that +// CheckLocalAddress returns the correct NIC for each address or zero if not +// existent. func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -891,39 +1244,34 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { } subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) - } - // Loop over all subnet addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - addr := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID) - } - addr[0]++ + testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) + + if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } - // Trying the next address should fail since it is outside the subnet range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0) + testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) + + if err := s.RemoveAddressRange(nicID, subnet); err != nil { + t.Fatal("RemoveAddressRange failed:", err) } + + testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) } -// Set destination outside the subnet, then check it doesn't get delivered. -func TestSubnetRejectsNonmatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) +// Set a range of addresses, then send a packet to a destination outside the +// range and then check it doesn't get delivered. +func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -939,23 +1287,23 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { buf := buffer.NewView(30) - buf[0] = 1 - fakeNet.packetCount[1] = 0 + const localAddrByte byte = 0x01 + buf[0] = localAddrByte subnet, err := tcpip.NewSubnet(tcpip.Address("\x10"), tcpip.AddressMask("\xF0")) if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) - } - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) - if fakeNet.packetCount[1] != 0 { - t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) + if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestNetworkOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{}, + }) // Try an unsupported network protocol. if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -994,44 +1342,53 @@ func TestNetworkOptions(t *testing.T) { } } -func TestSubnetAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { +func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { + ranges, ok := s.NICAddressRanges()[id] + if !ok { + return false + } + for _, r := range ranges { + if r == addrRange { + return true + } + } + return false +} + +func TestAddresRangeAddRemove(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } addr := tcpip.Address("\x01\x01\x01\x01") mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - subnet, err := tcpip.NewSubnet(addr, mask) + addrRange, err := tcpip.NewSubnet(addr, mask) if err != nil { t.Fatal("NewSubnet failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") + if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { + t.Fatal("AddAddressRange failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if !contained { - t.Fatal("got s.ContainsSubnet(...) = false, want = true") + if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } - if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatal("RemoveSubnet failed:", err) + if err := s.RemoveAddressRange(1, addrRange); err != nil { + t.Fatal("RemoveAddressRange failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") + if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } } @@ -1042,9 +1399,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } // Insert <canBe> primary and <never> never-primary addresses. @@ -1082,20 +1441,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } if len(primaryAddrAdded) == 0 { - // No primary addresses present, expect an error. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress) + // No primary addresses present. + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) } } else { - // At least one primary address was added, expect a valid - // address and prefixLen. - gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok { - t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded) + // At least one primary address was added, verify the returned + // address is in the list of primary addresses we added. + if _, ok := primaryAddrAdded[gotAddr]; !ok { + t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) } } }) @@ -1107,9 +1466,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1134,19 +1495,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get the right initial address and prefix length. - if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { t.Fatal("GetMainNICAddress failed:", err) - } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix) + } + if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { + t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { t.Fatal("RemoveAddress failed:", err) } - // Check that we get an error after removal. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress) + // Check that we get no address after removal. + gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } }) } @@ -1161,8 +1528,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address { } func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { + t.Helper() + if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses)) + t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) } sort.Slice(gotAddresses, func(i, j int) bool { @@ -1182,9 +1551,11 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1201,15 +1572,17 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddress(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1233,15 +1606,17 @@ func TestAddProtocolAddress(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddAddressWithOptions(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1262,15 +1637,17 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1297,15 +1674,17 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestNICStats(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { - t.Fatal("CreateNIC failed:", err) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { + t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress failed:", err) @@ -1321,7 +1700,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1332,10 +1711,12 @@ func TestNICStats(t *testing.T) { payload := buffer.NewView(10) // Write a packet out via the address for NIC 1 - sendTo(t, s, "\x01", payload) - want := uint64(linkEP1.Drain()) + if err := sendTo(s, "\x01", payload); err != nil { + t.Fatal("sendTo failed: ", err) + } + want := uint64(ep1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) + t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { @@ -1346,19 +1727,21 @@ func TestNICStats(t *testing.T) { func TestNICForwarding(t *testing.T) { // Create a stack with the fake network protocol, two NICs, each with // an address. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) s.SetForwarding(true) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress #1 failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -1377,10 +1760,10 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) select { - case <-linkEP2.C: + case <-ep2.C: default: t.Fatal("Packet not forwarded") } @@ -1394,9 +1777,3 @@ func TestNICForwarding(t *testing.T) { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } - -func init() { - stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { - return &fakeNetworkProtocol{} - }) -} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index cf8a6d129..8c768c299 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -35,25 +35,109 @@ type protocolIDs struct { type transportEndpoints struct { // mu protects all fields of the transportEndpoints. mu sync.RWMutex - endpoints map[TransportEndpointID]TransportEndpoint + endpoints map[TransportEndpointID]*endpointsByNic // rawEndpoints contains endpoints for raw sockets, which receive all // traffic of a given protocol regardless of port. rawEndpoints []RawTransportEndpoint } +type endpointsByNic struct { + mu sync.RWMutex + endpoints map[tcpip.NICID]*multiPortEndpoint + // seed is a random secret for a jenkins hash. + seed uint32 +} + +// HandlePacket is called by the stack when new packets arrive to this transport +// endpoint. +func (epsByNic *endpointsByNic) handlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + epsByNic.mu.RLock() + + mpep, ok := epsByNic.endpoints[r.ref.nic.ID()] + if !ok { + if mpep, ok = epsByNic.endpoints[0]; !ok { + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + } + + // If this is a broadcast or multicast datagram, deliver the datagram to all + // endpoints bound to the right device. + if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { + mpep.handlePacketAll(r, id, vv) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. + return + } + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandlePacket(r, id, vv) + epsByNic.mu.RUnlock() // Don't use defer for performance reasons. +} + +// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. +func (epsByNic *endpointsByNic) handleControlPacket(n *NIC, id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { + epsByNic.mu.RLock() + defer epsByNic.mu.RUnlock() + + mpep, ok := epsByNic.endpoints[n.ID()] + if !ok { + mpep, ok = epsByNic.endpoints[0] + } + if !ok { + return + } + + // TODO(eyalsoha): Why don't we look at id to see if this packet needs to + // broadcast like we are doing with handlePacket above? + + // multiPortEndpoints are guaranteed to have at least one element. + selectEndpoint(id, mpep, epsByNic.seed).HandleControlPacket(id, typ, extra, vv) +} + +// registerEndpoint returns true if it succeeds. It fails and returns +// false if ep already has an element with the same key. +func (epsByNic *endpointsByNic) registerEndpoint(t TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + + if multiPortEp, ok := epsByNic.endpoints[bindToDevice]; ok { + // There was already a bind. + return multiPortEp.singleRegisterEndpoint(t, reusePort) + } + + // This is a new binding. + multiPortEp := &multiPortEndpoint{} + multiPortEp.endpointsMap = make(map[TransportEndpoint]int) + multiPortEp.reuse = reusePort + epsByNic.endpoints[bindToDevice] = multiPortEp + return multiPortEp.singleRegisterEndpoint(t, reusePort) +} + +// unregisterEndpoint returns true if endpointsByNic has to be unregistered. +func (epsByNic *endpointsByNic) unregisterEndpoint(bindToDevice tcpip.NICID, t TransportEndpoint) bool { + epsByNic.mu.Lock() + defer epsByNic.mu.Unlock() + multiPortEp, ok := epsByNic.endpoints[bindToDevice] + if !ok { + return false + } + if multiPortEp.unregisterEndpoint(t) { + delete(epsByNic.endpoints, bindToDevice) + } + return len(epsByNic.endpoints) == 0 +} + // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint) { +func (eps *transportEndpoints) unregisterEndpoint(id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { eps.mu.Lock() defer eps.mu.Unlock() - e, ok := eps.endpoints[id] + epsByNic, ok := eps.endpoints[id] if !ok { return } - if multiPortEp, ok := e.(*multiPortEndpoint); ok { - if !multiPortEp.unregisterEndpoint(ep) { - return - } + if !epsByNic.unregisterEndpoint(bindToDevice, ep) { + return } delete(eps.endpoints, id) } @@ -75,7 +159,7 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { for netProto := range stack.networkProtocols { for proto := range stack.transportProtocols { d.protocol[protocolIDs{netProto, proto}] = &transportEndpoints{ - endpoints: make(map[TransportEndpointID]TransportEndpoint), + endpoints: make(map[TransportEndpointID]*endpointsByNic), } } } @@ -85,10 +169,10 @@ func newTransportDemuxer(stack *Stack) *transportDemuxer { // registerEndpoint registers the given endpoint with the dispatcher such that // packets that match the endpoint ID are delivered to it. -func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { for i, n := range netProtos { - if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort); err != nil { - d.unregisterEndpoint(netProtos[:i], protocol, id, ep) + if err := d.singleRegisterEndpoint(n, protocol, id, ep, reusePort, bindToDevice); err != nil { + d.unregisterEndpoint(netProtos[:i], protocol, id, ep, bindToDevice) return err } } @@ -97,13 +181,14 @@ func (d *transportDemuxer) registerEndpoint(netProtos []tcpip.NetworkProtocolNum } // multiPortEndpoint is a container for TransportEndpoints which are bound to -// the same pair of address and port. +// the same pair of address and port. endpointsArr always has at least one +// element. type multiPortEndpoint struct { mu sync.RWMutex endpointsArr []TransportEndpoint endpointsMap map[TransportEndpoint]int - // seed is a random secret for a jenkins hash. - seed uint32 + // reuse indicates if more than one endpoint is allowed. + reuse bool } // reciprocalScale scales a value into range [0, n). @@ -117,9 +202,10 @@ func reciprocalScale(val, n uint32) uint32 { // selectEndpoint calculates a hash of destination and source addresses and // ports then uses it to select a socket. In this case, all packets from one // address will be sent to same endpoint. -func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEndpoint { - ep.mu.RLock() - defer ep.mu.RUnlock() +func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32) TransportEndpoint { + if len(mpep.endpointsArr) == 1 { + return mpep.endpointsArr[0] + } payload := []byte{ byte(id.LocalPort), @@ -128,51 +214,50 @@ func (ep *multiPortEndpoint) selectEndpoint(id TransportEndpointID) TransportEnd byte(id.RemotePort >> 8), } - h := jenkins.Sum32(ep.seed) + h := jenkins.Sum32(seed) h.Write(payload) h.Write([]byte(id.LocalAddress)) h.Write([]byte(id.RemoteAddress)) hash := h.Sum32() - idx := reciprocalScale(hash, uint32(len(ep.endpointsArr))) - return ep.endpointsArr[idx] + idx := reciprocalScale(hash, uint32(len(mpep.endpointsArr))) + return mpep.endpointsArr[idx] } -// HandlePacket is called by the stack when new packets arrive to this transport -// endpoint. -func (ep *multiPortEndpoint) HandlePacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { - // If this is a broadcast or multicast datagram, deliver the datagram to all - // endpoints managed by ep. - if id.LocalAddress == header.IPv4Broadcast || header.IsV4MulticastAddress(id.LocalAddress) || header.IsV6MulticastAddress(id.LocalAddress) { - for i, endpoint := range ep.endpointsArr { - // HandlePacket modifies vv, so each endpoint needs its own copy. - if i == len(ep.endpointsArr)-1 { - endpoint.HandlePacket(r, id, vv) - break - } - vvCopy := buffer.NewView(vv.Size()) - copy(vvCopy, vv.ToView()) - endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) +func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, vv buffer.VectorisedView) { + ep.mu.RLock() + for i, endpoint := range ep.endpointsArr { + // HandlePacket modifies vv, so each endpoint needs its own copy except for + // the final one. + if i == len(ep.endpointsArr)-1 { + endpoint.HandlePacket(r, id, vv) + break } - } else { - ep.selectEndpoint(id).HandlePacket(r, id, vv) + vvCopy := buffer.NewView(vv.Size()) + copy(vvCopy, vv.ToView()) + endpoint.HandlePacket(r, id, vvCopy.ToVectorisedView()) } + ep.mu.RUnlock() // Don't use defer for performance reasons. } -// HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. -func (ep *multiPortEndpoint) HandleControlPacket(id TransportEndpointID, typ ControlType, extra uint32, vv buffer.VectorisedView) { - ep.selectEndpoint(id).HandleControlPacket(id, typ, extra, vv) -} - -func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint) { +// singleRegisterEndpoint tries to add an endpoint to the multiPortEndpoint +// list. The list might be empty already. +func (ep *multiPortEndpoint) singleRegisterEndpoint(t TransportEndpoint, reusePort bool) *tcpip.Error { ep.mu.Lock() defer ep.mu.Unlock() - // A new endpoint is added into endpointsArr and its index there is - // saved in endpointsMap. This will allows to remove endpoint from - // the array fast. + if len(ep.endpointsArr) > 0 { + // If it was previously bound, we need to check if we can bind again. + if !ep.reuse || !reusePort { + return tcpip.ErrPortInUse + } + } + + // A new endpoint is added into endpointsArr and its index there is saved in + // endpointsMap. This will allow us to remove endpoint from the array fast. ep.endpointsMap[t] = len(ep.endpointsArr) ep.endpointsArr = append(ep.endpointsArr, t) + return nil } // unregisterEndpoint returns true if multiPortEndpoint has to be unregistered. @@ -197,53 +282,41 @@ func (ep *multiPortEndpoint) unregisterEndpoint(t TransportEndpoint) bool { return true } -func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool) *tcpip.Error { +func (d *transportDemuxer) singleRegisterEndpoint(netProto tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, reusePort bool, bindToDevice tcpip.NICID) *tcpip.Error { if id.RemotePort != 0 { + // TODO(eyalsoha): Why? reusePort = false } eps, ok := d.protocol[protocolIDs{netProto, protocol}] if !ok { - return nil + return tcpip.ErrUnknownProtocol } eps.mu.Lock() defer eps.mu.Unlock() - var multiPortEp *multiPortEndpoint - if _, ok := eps.endpoints[id]; ok { - if !reusePort { - return tcpip.ErrPortInUse - } - multiPortEp, ok = eps.endpoints[id].(*multiPortEndpoint) - if !ok { - return tcpip.ErrPortInUse - } + if epsByNic, ok := eps.endpoints[id]; ok { + // There was already a binding. + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } - if reusePort { - if multiPortEp == nil { - multiPortEp = &multiPortEndpoint{} - multiPortEp.endpointsMap = make(map[TransportEndpoint]int) - multiPortEp.seed = rand.Uint32() - eps.endpoints[id] = multiPortEp - } - - multiPortEp.singleRegisterEndpoint(ep) - - return nil + // This is a new binding. + epsByNic := &endpointsByNic{ + endpoints: make(map[tcpip.NICID]*multiPortEndpoint), + seed: rand.Uint32(), } - eps.endpoints[id] = ep + eps.endpoints[id] = epsByNic - return nil + return epsByNic.registerEndpoint(ep, reusePort, bindToDevice) } // unregisterEndpoint unregisters the endpoint with the given id such that it // won't receive any more packets. -func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint) { +func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, bindToDevice tcpip.NICID) { for _, n := range netProtos { if eps, ok := d.protocol[protocolIDs{n, protocol}]; ok { - eps.unregisterEndpoint(id, ep) + eps.unregisterEndpoint(id, ep, bindToDevice) } } } @@ -273,7 +346,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // If the packet is a broadcast, then find all matching transport endpoints. // Otherwise, try to find a single matching transport endpoint. - destEps := make([]TransportEndpoint, 0, 1) + destEps := make([]*endpointsByNic, 0, 1) eps.mu.RLock() if protocol == header.UDPProtocolNumber && id.LocalAddress == header.IPv4Broadcast { @@ -299,7 +372,7 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // Deliver the packet. for _, ep := range destEps { - ep.HandlePacket(r, id, vv) + ep.handlePacket(r, id, vv) } return true @@ -331,7 +404,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr // deliverControlPacket attempts to deliver the given control packet. Returns // true if it found an endpoint, false otherwise. -func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { +func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtocolNumber, trans tcpip.TransportProtocolNumber, typ ControlType, extra uint32, vv buffer.VectorisedView, id TransportEndpointID) bool { eps, ok := d.protocol[protocolIDs{net, trans}] if !ok { return false @@ -348,12 +421,12 @@ func (d *transportDemuxer) deliverControlPacket(net tcpip.NetworkProtocolNumber, } // Deliver the packet. - ep.HandleControlPacket(id, typ, extra, vv) + ep.handleControlPacket(n, id, typ, extra, vv) return true } -func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) TransportEndpoint { +func (d *transportDemuxer) findEndpointLocked(eps *transportEndpoints, vv buffer.VectorisedView, id TransportEndpointID) *endpointsByNic { // Try to find a match with the id as provided. if ep, ok := eps.endpoints[id]; ok { return ep diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go new file mode 100644 index 000000000..210233dc0 --- /dev/null +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -0,0 +1,352 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack_test + +import ( + "math" + "math/rand" + "testing" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +const ( + stackV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + testV6Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + + stackAddr = "\x0a\x00\x00\x01" + stackPort = 1234 + testPort = 4096 +) + +type testContext struct { + t *testing.T + linkEPs map[string]*channel.Endpoint + s *stack.Stack + + ep tcpip.Endpoint + wq waiter.Queue +} + +func (c *testContext) cleanup() { + if c.ep != nil { + c.ep.Close() + } +} + +func (c *testContext) createV6Endpoint(v6only bool) { + var err *tcpip.Error + c.ep, err = c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &c.wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + var v tcpip.V6OnlyOption + if v6only { + v = 1 + } + if err := c.ep.SetSockOpt(v); err != nil { + c.t.Fatalf("SetSockOpt failed: %v", err) + } +} + +// newDualTestContextMultiNic creates the testing context and also linkEpNames +// named NICs. +func newDualTestContextMultiNic(t *testing.T, mtu uint32, linkEpNames []string) *testContext { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol(), ipv6.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + linkEPs := make(map[string]*channel.Endpoint) + for i, linkEpName := range linkEpNames { + channelEP := channel.New(256, mtu, "") + nicid := tcpip.NICID(i + 1) + if err := s.CreateNamedNIC(nicid, linkEpName, channelEP); err != nil { + t.Fatalf("CreateNIC failed: %v", err) + } + linkEPs[linkEpName] = channelEP + + if err := s.AddAddress(nicid, ipv4.ProtocolNumber, stackAddr); err != nil { + t.Fatalf("AddAddress IPv4 failed: %v", err) + } + + if err := s.AddAddress(nicid, ipv6.ProtocolNumber, stackV6Addr); err != nil { + t.Fatalf("AddAddress IPv6 failed: %v", err) + } + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: 1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + + return &testContext{ + t: t, + s: s, + linkEPs: linkEPs, + } +} + +type headers struct { + srcPort uint16 + dstPort uint16 +} + +func newPayload() []byte { + b := make([]byte, 30+rand.Intn(100)) + for i := range b { + b[i] = byte(rand.Intn(256)) + } + return b +} + +func (c *testContext) sendV6Packet(payload []byte, h *headers, linkEpName string) { + // Allocate a buffer for data and headers. + buf := buffer.NewView(header.UDPMinimumSize + header.IPv6MinimumSize + len(payload)) + copy(buf[len(buf)-len(payload):], payload) + + // Initialize the IP header. + ip := header.IPv6(buf) + ip.Encode(&header.IPv6Fields{ + PayloadLength: uint16(header.UDPMinimumSize + len(payload)), + NextHeader: uint8(udp.ProtocolNumber), + HopLimit: 65, + SrcAddr: testV6Addr, + DstAddr: stackV6Addr, + }) + + // Initialize the UDP header. + u := header.UDP(buf[header.IPv6MinimumSize:]) + u.Encode(&header.UDPFields{ + SrcPort: h.srcPort, + DstPort: h.dstPort, + Length: uint16(header.UDPMinimumSize + len(payload)), + }) + + // Calculate the UDP pseudo-header checksum. + xsum := header.PseudoHeaderChecksum(udp.ProtocolNumber, testV6Addr, stackV6Addr, uint16(len(u))) + + // Calculate the UDP checksum and set it. + xsum = header.Checksum(payload, xsum) + u.SetChecksum(^u.CalculateChecksum(xsum)) + + // Inject packet. + c.linkEPs[linkEpName].Inject(ipv6.ProtocolNumber, buf.ToVectorisedView()) +} + +func TestTransportDemuxerRegister(t *testing.T) { + for _, test := range []struct { + name string + proto tcpip.NetworkProtocolNumber + want *tcpip.Error + }{ + {"failure", ipv6.ProtocolNumber, tcpip.ErrUnknownProtocol}, + {"success", ipv4.ProtocolNumber, nil}, + } { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{ipv4.NewProtocol()}, + TransportProtocols: []stack.TransportProtocol{udp.NewProtocol()}}) + if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, nil, false, 0), test.want; got != want { + t.Fatalf("s.RegisterTransportEndpoint(...) = %v, want %v", got, want) + } + }) + } +} + +// TestReuseBindToDevice injects varied packets on input devices and checks that +// the distribution of packets received matches expectations. +func TestDistribution(t *testing.T) { + type endpointSockopts struct { + reuse int + bindToDevice string + } + for _, test := range []struct { + name string + // endpoints will received the inject packets. + endpoints []endpointSockopts + // wantedDistribution is the wanted ratio of packets received on each + // endpoint for each NIC on which packets are injected. + wantedDistributions map[string][]float64 + }{ + { + "BindPortReuse", + // 5 endpoints that all have reuse set. + []endpointSockopts{ + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed evenly. + "dev0": []float64{0.2, 0.2, 0.2, 0.2, 0.2}, + }, + }, + { + "BindToDevice", + // 3 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{0, "dev0"}, + endpointSockopts{0, "dev1"}, + endpointSockopts{0, "dev2"}, + }, + map[string][]float64{ + // Injected packets on dev0 go only to the endpoint bound to dev0. + "dev0": []float64{1, 0, 0}, + // Injected packets on dev1 go only to the endpoint bound to dev1. + "dev1": []float64{0, 1, 0}, + // Injected packets on dev2 go only to the endpoint bound to dev2. + "dev2": []float64{0, 0, 1}, + }, + }, + { + "ReuseAndBindToDevice", + // 6 endpoints with various bindings. + []endpointSockopts{ + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev0"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, "dev1"}, + endpointSockopts{1, ""}, + }, + map[string][]float64{ + // Injected packets on dev0 get distributed among endpoints bound to + // dev0. + "dev0": []float64{0.5, 0.5, 0, 0, 0, 0}, + // Injected packets on dev1 get distributed among endpoints bound to + // dev1 or unbound. + "dev1": []float64{0, 0, 1. / 3, 1. / 3, 1. / 3, 0}, + // Injected packets on dev999 go only to the unbound. + "dev999": []float64{0, 0, 0, 0, 0, 1}, + }, + }, + } { + t.Run(test.name, func(t *testing.T) { + for device, wantedDistribution := range test.wantedDistributions { + t.Run(device, func(t *testing.T) { + var devices []string + for d := range test.wantedDistributions { + devices = append(devices, d) + } + c := newDualTestContextMultiNic(t, defaultMTU, devices) + defer c.cleanup() + + c.createV6Endpoint(false) + + eps := make(map[tcpip.Endpoint]int) + + pollChannel := make(chan tcpip.Endpoint) + for i, endpoint := range test.endpoints { + // Try to receive the data. + wq := waiter.Queue{} + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + defer close(ch) + + var err *tcpip.Error + ep, err := c.s.NewEndpoint(udp.ProtocolNumber, ipv6.ProtocolNumber, &wq) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + eps[ep] = i + + go func(ep tcpip.Endpoint) { + for range ch { + pollChannel <- ep + } + }(ep) + + defer ep.Close() + reusePortOption := tcpip.ReusePortOption(endpoint.reuse) + if err := ep.SetSockOpt(reusePortOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", reusePortOption, i, err) + } + bindToDeviceOption := tcpip.BindToDeviceOption(endpoint.bindToDevice) + if err := ep.SetSockOpt(bindToDeviceOption); err != nil { + c.t.Fatalf("SetSockOpt(%#v) on endpoint %d failed: %v", bindToDeviceOption, i, err) + } + if err := ep.Bind(tcpip.FullAddress{Addr: stackV6Addr, Port: stackPort}); err != nil { + t.Fatalf("ep.Bind(...) on endpoint %d failed: %v", i, err) + } + } + + npackets := 100000 + nports := 10000 + if got, want := len(test.endpoints), len(wantedDistribution); got != want { + t.Fatalf("got len(test.endpoints) = %d, want %d", got, want) + } + ports := make(map[uint16]tcpip.Endpoint) + stats := make(map[tcpip.Endpoint]int) + for i := 0; i < npackets; i++ { + // Send a packet. + port := uint16(i % nports) + payload := newPayload() + c.sendV6Packet(payload, + &headers{ + srcPort: testPort + port, + dstPort: stackPort}, + device) + + var addr tcpip.FullAddress + ep := <-pollChannel + _, _, err := ep.Read(&addr) + if err != nil { + c.t.Fatalf("Read on endpoint %d failed: %v", eps[ep], err) + } + stats[ep]++ + if i < nports { + ports[uint16(i)] = ep + } else { + // Check that all packets from one client are handled by the same + // socket. + if want, got := ports[port], ep; want != got { + t.Fatalf("Packet sent on port %d expected on endpoint %d but received on endpoint %d", port, eps[want], eps[got]) + } + } + } + + // Check that a packet distribution is as expected. + for ep, i := range eps { + wantedRatio := wantedDistribution[i] + wantedRecv := wantedRatio * float64(npackets) + actualRecv := stats[ep] + actualRatio := float64(stats[ep]) / float64(npackets) + // The deviation is less than 10%. + if math.Abs(actualRatio-wantedRatio) > 0.05 { + t.Errorf("wanted about %.0f%% (%.0f of %d) packets to arrive on endpoint %d, got %.0f%% (%d of %d)", wantedRatio*100, wantedRecv, npackets, i, actualRatio*100, actualRecv, npackets) + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5335897f5..842a16277 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -65,13 +65,13 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } @@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptInt sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption @@ -122,7 +127,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Try to register so that we can start receiving packets. f.id.RemoteAddress = addr.Addr - err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false) + err = f.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.id, f, false /* reuse */, 0 /* bindToDevice */) if err != nil { return err } @@ -163,7 +168,8 @@ func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error { fakeTransNumber, stack.TransportEndpointID{LocalAddress: a.Addr}, f, - false, + false, /* reuse */ + 0, /* bindtoDevice */ ); err != nil { return err } @@ -251,7 +257,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { return true } @@ -277,10 +283,17 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeTransFactory() stack.TransportProtocol { + return &fakeTransportProtocol{} +} + func TestTransportReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -340,9 +353,12 @@ func TestTransportReceive(t *testing.T) { } func TestTransportControlReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -408,9 +424,12 @@ func TestTransportControlReceive(t *testing.T) { } func TestTransportSend(t *testing.T) { - id, _ := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -452,7 +471,10 @@ func TestTransportSend(t *testing.T) { } func TestTransportOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) // Try an unsupported transport protocol. if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -493,20 +515,23 @@ func TestTransportOptions(t *testing.T) { } func TestTransportForwarding(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) s.SetForwarding(true) // TODO(b/123449044): Change this to a channel NIC. - id1 := loopback.New() - if err := s.CreateNIC(1, id1); err != nil { + ep1 := loopback.New() + if err := s.CreateNIC(1, ep1); err != nil { t.Fatalf("CreateNIC #1 failed: %v", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress #1 failed: %v", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatalf("CreateNIC #2 failed: %v", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -545,7 +570,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - linkEP2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.Inject(fakeNetNumber, req.ToVectorisedView()) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -559,7 +584,7 @@ func TestTransportForwarding(t *testing.T) { var p channel.PacketInfo select { - case p = <-linkEP2.C: + case p = <-ep2.C: default: t.Fatal("Response packet not forwarded") } @@ -571,9 +596,3 @@ func TestTransportForwarding(t *testing.T) { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } - -func init() { - stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol { - return &fakeTransportProtocol{} - }) -} |