diff options
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r-- | pkg/tcpip/stack/BUILD | 28 | ||||
-rw-r--r-- | pkg/tcpip/stack/icmp_rate_limit.go | 41 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 253 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache_test.go | 79 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 257 | ||||
-rw-r--r-- | pkg/tcpip/stack/registration.go | 75 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 178 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 574 | ||||
-rw-r--r-- | pkg/tcpip/stack/transport_test.go | 70 |
10 files changed, 898 insertions, 662 deletions
diff --git a/pkg/tcpip/stack/BUILD b/pkg/tcpip/stack/BUILD index 9986b4be3..28c49e8ff 100644 --- a/pkg/tcpip/stack/BUILD +++ b/pkg/tcpip/stack/BUILD @@ -1,11 +1,28 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_stateify:defs.bzl", "go_library") + +go_template_instance( + name = "linkaddrentry_list", + out = "linkaddrentry_list.go", + package = "stack", + prefix = "linkAddrEntry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*linkAddrEntry", + "Linker": "*linkAddrEntry", + }, +) go_library( name = "stack", srcs = [ + "icmp_rate_limit.go", "linkaddrcache.go", + "linkaddrentry_list.go", "nic.go", "registration.go", "route.go", @@ -28,6 +45,7 @@ go_library( "//pkg/tcpip/ports", "//pkg/tcpip/seqnum", "//pkg/waiter", + "@org_golang_x_time//rate:go_default_library", ], ) @@ -60,3 +78,11 @@ go_test( "//pkg/tcpip", ], ) + +filegroup( + name = "autogen", + srcs = [ + "linkaddrentry_list.go", + ], + visibility = ["//:sandbox"], +) diff --git a/pkg/tcpip/stack/icmp_rate_limit.go b/pkg/tcpip/stack/icmp_rate_limit.go new file mode 100644 index 000000000..3a20839da --- /dev/null +++ b/pkg/tcpip/stack/icmp_rate_limit.go @@ -0,0 +1,41 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package stack + +import ( + "golang.org/x/time/rate" +) + +const ( + // icmpLimit is the default maximum number of ICMP messages permitted by this + // rate limiter. + icmpLimit = 1000 + + // icmpBurst is the default number of ICMP messages that can be sent in a single + // burst. + icmpBurst = 50 +) + +// ICMPRateLimiter is a global rate limiter that controls the generation of +// ICMP messages generated by the stack. +type ICMPRateLimiter struct { + *rate.Limiter +} + +// NewICMPRateLimiter returns a global rate limiter for controlling the rate +// at which ICMP messages are generated by the stack. +func NewICMPRateLimiter() *ICMPRateLimiter { + return &ICMPRateLimiter{Limiter: rate.NewLimiter(icmpLimit, icmpBurst)} +} diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 77bb0ccb9..267df60d1 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -42,10 +42,11 @@ type linkAddrCache struct { // resolved before failing. resolutionAttempts int - mu sync.Mutex - cache map[tcpip.FullAddress]*linkAddrEntry - next int // array index of next available entry - entries [linkAddrCacheSize]linkAddrEntry + cache struct { + sync.Mutex + table map[tcpip.FullAddress]*linkAddrEntry + lru linkAddrEntryList + } } // entryState controls the state of a single entry in the cache. @@ -60,9 +61,6 @@ const ( // failed means that address resolution timed out and the address // could not be resolved. failed - // expired means that the cache entry has expired and the address must be - // resolved again. - expired ) // String implements Stringer. @@ -74,8 +72,6 @@ func (s entryState) String() string { return "ready" case failed: return "failed" - case expired: - return "expired" default: return fmt.Sprintf("unknown(%d)", s) } @@ -84,64 +80,46 @@ func (s entryState) String() string { // A linkAddrEntry is an entry in the linkAddrCache. // This struct is thread-compatible. type linkAddrEntry struct { + linkAddrEntryEntry + addr tcpip.FullAddress linkAddr tcpip.LinkAddress expiration time.Time s entryState // wakers is a set of waiters for address resolution result. Anytime - // state transitions out of 'incomplete' these waiters are notified. + // state transitions out of incomplete these waiters are notified. wakers map[*sleep.Waker]struct{} + // done is used to allow callers to wait on address resolution. It is nil iff + // s is incomplete and resolution is not yet in progress. done chan struct{} } -func (e *linkAddrEntry) state() entryState { - if e.s != expired && time.Now().After(e.expiration) { - // Force the transition to ensure waiters are notified. - e.changeState(expired) - } - return e.s -} - -func (e *linkAddrEntry) changeState(ns entryState) { - if e.s == ns { - return - } - - // Validate state transition. - switch e.s { - case incomplete: - // All transitions are valid. - case ready, failed: - if ns != expired { - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - } - case expired: - // Terminal state. - panic(fmt.Sprintf("invalid state transition from %s to %s", e.s, ns)) - default: - panic(fmt.Sprintf("invalid state: %s", e.s)) - } - +// changeState sets the entry's state to ns, notifying any waiters. +// +// The entry's expiration is bumped up to the greater of itself and the passed +// expiration; the zero value indicates immediate expiration, and is set +// unconditionally - this is an implementation detail that allows for entries +// to be reused. +func (e *linkAddrEntry) changeState(ns entryState, expiration time.Time) { // Notify whoever is waiting on address resolution when transitioning - // out of 'incomplete'. - if e.s == incomplete { + // out of incomplete. + if e.s == incomplete && ns != incomplete { for w := range e.wakers { w.Assert() } e.wakers = nil - if e.done != nil { - close(e.done) + if ch := e.done; ch != nil { + close(ch) } + e.done = nil } - e.s = ns -} -func (e *linkAddrEntry) maybeAddWaker(w *sleep.Waker) { - if w != nil { - e.wakers[w] = struct{}{} + if expiration.IsZero() || expiration.After(e.expiration) { + e.expiration = expiration } + e.s = ns } func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { @@ -150,53 +128,54 @@ func (e *linkAddrEntry) removeWaker(w *sleep.Waker) { // add adds a k -> v mapping to the cache. func (c *linkAddrCache) add(k tcpip.FullAddress, v tcpip.LinkAddress) { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] - if ok { - s := entry.state() - if s != expired && entry.linkAddr == v { - // Disregard repeated calls. - return - } - // Check if entry is waiting for address resolution. - if s == incomplete { - entry.linkAddr = v - } else { - // Otherwise create a new entry to replace it. - entry = c.makeAndAddEntry(k, v) - } - } else { - entry = c.makeAndAddEntry(k, v) - } + // Calculate expiration time before acquiring the lock, since expiration is + // relative to the time when information was learned, rather than when it + // happened to be inserted into the cache. + expiration := time.Now().Add(c.ageLimit) - entry.changeState(ready) + c.cache.Lock() + entry := c.getOrCreateEntryLocked(k) + entry.linkAddr = v + + entry.changeState(ready, expiration) + c.cache.Unlock() } -// makeAndAddEntry is a helper function to create and add a new -// entry to the cache map and evict older entry as needed. -func (c *linkAddrCache) makeAndAddEntry(k tcpip.FullAddress, v tcpip.LinkAddress) *linkAddrEntry { - // Take over the next entry. - entry := &c.entries[c.next] - if c.cache[entry.addr] == entry { - delete(c.cache, entry.addr) +// getOrCreateEntryLocked retrieves a cache entry associated with k. The +// returned entry is always refreshed in the cache (it is reachable via the +// map, and its place is bumped in LRU). +// +// If a matching entry exists in the cache, it is returned. If no matching +// entry exists and the cache is full, an existing entry is evicted via LRU, +// reset to state incomplete, and returned. If no matching entry exists and the +// cache is not full, a new entry with state incomplete is allocated and +// returned. +func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEntry { + if entry, ok := c.cache.table[k]; ok { + c.cache.lru.Remove(entry) + c.cache.lru.PushFront(entry) + return entry } + var entry *linkAddrEntry + if len(c.cache.table) == linkAddrCacheSize { + entry = c.cache.lru.Back() - // Mark the soon-to-be-replaced entry as expired, just in case there is - // someone waiting for address resolution on it. - entry.changeState(expired) + delete(c.cache.table, entry.addr) + c.cache.lru.Remove(entry) - *entry = linkAddrEntry{ - addr: k, - linkAddr: v, - expiration: time.Now().Add(c.ageLimit), - wakers: make(map[*sleep.Waker]struct{}), - done: make(chan struct{}), + // Wake waiters and mark the soon-to-be-reused entry as expired. Note + // that the state passed doesn't matter when the zero time is passed. + entry.changeState(failed, time.Time{}) + } else { + entry = new(linkAddrEntry) } - c.cache[k] = entry - c.next = (c.next + 1) % len(c.entries) + *entry = linkAddrEntry{ + addr: k, + s: incomplete, + } + c.cache.table[k] = entry + c.cache.lru.PushFront(entry) return entry } @@ -208,43 +187,55 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } } - c.mu.Lock() - defer c.mu.Unlock() - if entry, ok := c.cache[k]; ok { - switch s := entry.state(); s { - case expired: - case ready: - return entry.linkAddr, nil, nil - case failed: - return "", nil, tcpip.ErrNoLinkAddress - case incomplete: - // Address resolution is still in progress. - entry.maybeAddWaker(waker) - return "", entry.done, tcpip.ErrWouldBlock - default: - panic(fmt.Sprintf("invalid cache entry state: %s", s)) + c.cache.Lock() + defer c.cache.Unlock() + entry := c.getOrCreateEntryLocked(k) + switch s := entry.s; s { + case ready, failed: + if !time.Now().After(entry.expiration) { + // Not expired. + switch s { + case ready: + return entry.linkAddr, nil, nil + case failed: + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } - } - if linkRes == nil { - return "", nil, tcpip.ErrNoLinkAddress - } + entry.changeState(incomplete, time.Time{}) + fallthrough + case incomplete: + if waker != nil { + if entry.wakers == nil { + entry.wakers = make(map[*sleep.Waker]struct{}) + } + entry.wakers[waker] = struct{}{} + } - // Add 'incomplete' entry in the cache to mark that resolution is in progress. - e := c.makeAndAddEntry(k, "") - e.maybeAddWaker(waker) + if entry.done == nil { + // Address resolution needs to be initiated. + if linkRes == nil { + return entry.linkAddr, nil, tcpip.ErrNoLinkAddress + } - go c.startAddressResolution(k, linkRes, localAddr, linkEP, e.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + entry.done = make(chan struct{}) + go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + } - return "", e.done, tcpip.ErrWouldBlock + return entry.linkAddr, entry.done, tcpip.ErrWouldBlock + default: + panic(fmt.Sprintf("invalid cache entry state: %s", s)) + } } // removeWaker removes a waker previously added through get(). func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { - c.mu.Lock() - defer c.mu.Unlock() + c.cache.Lock() + defer c.cache.Unlock() - if entry, ok := c.cache[k]; ok { + if entry, ok := c.cache.table[k]; ok { entry.removeWaker(waker) } } @@ -256,8 +247,8 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link linkRes.LinkAddressRequest(k.Addr, localAddr, linkEP) select { - case <-time.After(c.resolutionTimeout): - if stop := c.checkLinkRequest(k, i); stop { + case now := <-time.After(c.resolutionTimeout): + if stop := c.checkLinkRequest(now, k, i); stop { return } case <-done: @@ -269,38 +260,36 @@ func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes Link // checkLinkRequest checks whether previous attempt to resolve address has succeeded // and mark the entry accordingly, e.g. ready, failed, etc. Return true if request // can stop, false if another request should be sent. -func (c *linkAddrCache) checkLinkRequest(k tcpip.FullAddress, attempt int) bool { - c.mu.Lock() - defer c.mu.Unlock() - - entry, ok := c.cache[k] +func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.FullAddress, attempt int) bool { + c.cache.Lock() + defer c.cache.Unlock() + entry, ok := c.cache.table[k] if !ok { // Entry was evicted from the cache. return true } - - switch s := entry.state(); s { - case ready, failed, expired: + switch s := entry.s; s { + case ready, failed: // Entry was made ready by resolver or failed. Either way we're done. - return true case incomplete: - if attempt+1 >= c.resolutionAttempts { - // Max number of retries reached, mark entry as failed. - entry.changeState(failed) - return true + if attempt+1 < c.resolutionAttempts { + // No response yet, need to send another ARP request. + return false } - // No response yet, need to send another ARP request. - return false + // Max number of retries reached, mark entry as failed. + entry.changeState(failed, now.Add(c.ageLimit)) default: panic(fmt.Sprintf("invalid cache entry state: %s", s)) } + return true } func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache { - return &linkAddrCache{ + c := &linkAddrCache{ ageLimit: ageLimit, resolutionTimeout: resolutionTimeout, resolutionAttempts: resolutionAttempts, - cache: make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize), } + c.cache.table = make(map[tcpip.FullAddress]*linkAddrEntry, linkAddrCacheSize) + return c } diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 924f4d240..9946b8fe8 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -17,6 +17,7 @@ package stack import ( "fmt" "sync" + "sync/atomic" "testing" "time" @@ -29,25 +30,34 @@ type testaddr struct { linkAddr tcpip.LinkAddress } -var testaddrs []testaddr +var testAddrs = func() []testaddr { + var addrs []testaddr + for i := 0; i < 4*linkAddrCacheSize; i++ { + addr := fmt.Sprintf("Addr%06d", i) + addrs = append(addrs, testaddr{ + addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, + linkAddr: tcpip.LinkAddress("Link" + addr), + }) + } + return addrs +}() type testLinkAddressResolver struct { - cache *linkAddrCache - delay time.Duration + cache *linkAddrCache + delay time.Duration + onLinkAddressRequest func() } func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ LinkEndpoint) *tcpip.Error { - go func() { - if r.delay > 0 { - time.Sleep(r.delay) - } - r.fakeRequest(addr) - }() + time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) + if f := r.onLinkAddressRequest; f != nil { + f() + } return nil } func (r *testLinkAddressResolver) fakeRequest(addr tcpip.Address) { - for _, ta := range testaddrs { + for _, ta := range testAddrs { if ta.addr.Addr == addr { r.cache.add(ta.addr, ta.linkAddr) break @@ -80,20 +90,10 @@ func getBlocking(c *linkAddrCache, addr tcpip.FullAddress, linkRes LinkAddressRe } } -func init() { - for i := 0; i < 4*linkAddrCacheSize; i++ { - addr := fmt.Sprintf("Addr%06d", i) - testaddrs = append(testaddrs, testaddr{ - addr: tcpip.FullAddress{NIC: 1, Addr: tcpip.Address(addr)}, - linkAddr: tcpip.LinkAddress("Link" + addr), - }) - } -} - func TestCacheOverflow(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - for i := len(testaddrs) - 1; i >= 0; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= 0; i-- { + e := testAddrs[i] c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { @@ -105,7 +105,7 @@ func TestCacheOverflow(t *testing.T) { } // Expect to find at least half of the most recent entries. for i := 0; i < linkAddrCacheSize/2; i++ { - e := testaddrs[i] + e := testAddrs[i] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(e.addr.Addr), got, err) @@ -115,8 +115,8 @@ func TestCacheOverflow(t *testing.T) { } } // The earliest entries should no longer be in the cache. - for i := len(testaddrs) - 1; i >= len(testaddrs)-linkAddrCacheSize; i-- { - e := testaddrs[i] + for i := len(testAddrs) - 1; i >= len(testAddrs)-linkAddrCacheSize; i-- { + e := testAddrs[i] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("check %d, c.get(%q), got error: %v, want: error ErrNoLinkAddress", i, string(e.addr.Addr), err) } @@ -130,7 +130,7 @@ func TestCacheConcurrent(t *testing.T) { for r := 0; r < 16; r++ { wg.Add(1) go func() { - for _, e := range testaddrs { + for _, e := range testAddrs { c.add(e.addr, e.linkAddr) c.get(e.addr, nil, "", nil, nil) // make work for gotsan } @@ -142,7 +142,7 @@ func TestCacheConcurrent(t *testing.T) { // All goroutines add in the same order and add more values than // can fit in the cache, so our eviction strategy requires that // the last entry be present and the first be missing. - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, nil, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -151,7 +151,7 @@ func TestCacheConcurrent(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } - e = testaddrs[0] + e = testAddrs[0] if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } @@ -159,7 +159,7 @@ func TestCacheConcurrent(t *testing.T) { func TestCacheAgeLimit(t *testing.T) { c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] c.add(e.addr, e.linkAddr) time.Sleep(50 * time.Millisecond) if _, _, err := c.get(e.addr, nil, "", nil, nil); err != tcpip.ErrNoLinkAddress { @@ -169,7 +169,7 @@ func TestCacheAgeLimit(t *testing.T) { func TestCacheReplace(t *testing.T) { c := newLinkAddrCache(1<<63-1, 1*time.Second, 3) - e := testaddrs[0] + e := testAddrs[0] l2 := e.linkAddr + "2" c.add(e.addr, e.linkAddr) got, _, err := c.get(e.addr, nil, "", nil, nil) @@ -193,7 +193,7 @@ func TestCacheReplace(t *testing.T) { func TestCacheResolution(t *testing.T) { c := newLinkAddrCache(1<<63-1, 250*time.Millisecond, 1) linkRes := &testLinkAddressResolver{cache: c} - for i, ta := range testaddrs { + for i, ta := range testAddrs { got, err := getBlocking(c, ta.addr, linkRes) if err != nil { t.Errorf("check %d, c.get(%q)=%q, got error: %v", i, string(ta.addr.Addr), got, err) @@ -205,7 +205,7 @@ func TestCacheResolution(t *testing.T) { // Check that after resolved, address stays in the cache and never returns WouldBlock. for i := 0; i < 10; i++ { - e := testaddrs[len(testaddrs)-1] + e := testAddrs[len(testAddrs)-1] got, _, err := c.get(e.addr, linkRes, "", nil, nil) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -220,8 +220,13 @@ func TestCacheResolutionFailed(t *testing.T) { c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5) linkRes := &testLinkAddressResolver{cache: c} + var requestCount uint32 + linkRes.onLinkAddressRequest = func() { + atomic.AddUint32(&requestCount, 1) + } + // First, sanity check that resolution is working... - e := testaddrs[0] + e := testAddrs[0] got, err := getBlocking(c, e.addr, linkRes) if err != nil { t.Errorf("c.get(%q)=%q, got error: %v", string(e.addr.Addr), got, err) @@ -230,10 +235,16 @@ func TestCacheResolutionFailed(t *testing.T) { t.Errorf("c.get(%q)=%q, want %q", string(e.addr.Addr), got, e.linkAddr) } + before := atomic.LoadUint32(&requestCount) + e.addr.Addr += "2" if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } + + if got, want := int(atomic.LoadUint32(&requestCount)-before), c.resolutionAttempts; got != want { + t.Errorf("got link address request count = %d, want = %d", got, want) + } } func TestCacheResolutionTimeout(t *testing.T) { @@ -242,7 +253,7 @@ func TestCacheResolutionTimeout(t *testing.T) { c := newLinkAddrCache(expiration, 1*time.Millisecond, 3) linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay} - e := testaddrs[0] + e := testAddrs[0] if _, err := getBlocking(c, e.addr, linkRes); err != tcpip.ErrNoLinkAddress { t.Errorf("c.get(%q), got error: %v, want: error ErrNoLinkAddress", string(e.addr.Addr), err) } diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 04b63d783..0e8a23f00 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -36,13 +36,13 @@ type NIC struct { demux *transportDemuxer - mu sync.RWMutex - spoofing bool - promiscuous bool - primary map[tcpip.NetworkProtocolNumber]*ilist.List - endpoints map[NetworkEndpointID]*referencedNetworkEndpoint - subnets []tcpip.Subnet - mcastJoins map[NetworkEndpointID]int32 + mu sync.RWMutex + spoofing bool + promiscuous bool + primary map[tcpip.NetworkProtocolNumber]*ilist.List + endpoints map[NetworkEndpointID]*referencedNetworkEndpoint + addressRanges []tcpip.Subnet + mcastJoins map[NetworkEndpointID]int32 stats NICStats } @@ -102,6 +102,25 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, loopback } } +// enable enables the NIC. enable will attach the link to its LinkEndpoint and +// join the IPv6 All-Nodes Multicast address (ff02::1). +func (n *NIC) enable() *tcpip.Error { + n.attachLinkEndpoint() + + // Join the IPv6 All-Nodes Multicast group if the stack is configured to + // use IPv6. This is required to ensure that this node properly receives + // and responds to the various NDP messages that are destined to the + // all-nodes multicast address. An example is the Neighbor Advertisement + // when we perform Duplicate Address Detection, or Router Advertisement + // when we do Router Discovery. See RFC 4862, section 5.4.2 and RFC 4861 + // section 4.2 for more information. + if _, ok := n.stack.networkProtocols[header.IPv6ProtocolNumber]; ok { + return n.joinGroup(header.IPv6ProtocolNumber, header.IPv6AllNodesMulticastAddress) + } + + return nil +} + // attachLinkEndpoint attaches the NIC to the endpoint, which will enable it // to start delivering packets. func (n *NIC) attachLinkEndpoint() { @@ -129,37 +148,6 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { - n.mu.RLock() - defer n.mu.RUnlock() - - var r *referencedNetworkEndpoint - - // Check for a primary endpoint. - if list, ok := n.primary[protocol]; ok { - for e := list.Front(); e != nil; e = e.Next() { - ref := e.(*referencedNetworkEndpoint) - if ref.kind == permanent && ref.tryIncRef() { - r = ref - break - } - } - - } - - if r == nil { - return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress - } - - addressWithPrefix := tcpip.AddressWithPrefix{ - Address: r.ep.ID().LocalAddress, - PrefixLen: r.ep.PrefixLen(), - } - r.decRef() - - return addressWithPrefix, nil -} - // primaryEndpoint returns the primary endpoint of n for the given network // protocol. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { @@ -205,7 +193,7 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t if ref, ok := n.endpoints[id]; ok { // An endpoint with this id exists, check if it can be used and return it. - switch ref.kind { + switch ref.getKind() { case permanentExpired: if !spoofingOrPromiscuous { n.mu.RUnlock() @@ -224,7 +212,17 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t // the caller or if the address is found in the NIC's subnets. createTempEP := spoofingOrPromiscuous if !createTempEP { - for _, sn := range n.subnets { + for _, sn := range n.addressRanges { + // Skip the subnet address. + if address == sn.ID() { + continue + } + // For now just skip the broadcast address, until we support it. + // FIXME(b/137608825): Add support for sending/receiving directed + // (subnet) broadcast. + if address == sn.Broadcast() { + continue + } if sn.Contains(address) { createTempEP = true break @@ -276,14 +274,14 @@ func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address t func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior) (*referencedNetworkEndpoint, *tcpip.Error) { id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if ref, ok := n.endpoints[id]; ok { - switch ref.kind { + switch ref.getKind() { case permanent: // The NIC already have a permanent endpoint with that address. return nil, tcpip.ErrDuplicateAddress case permanentExpired, temporary: // Promote the endpoint to become permanent. if ref.tryIncRef() { - ref.kind = permanent + ref.setKind(permanent) return ref, nil } // tryIncRef failing means the endpoint is scheduled to be removed once @@ -297,6 +295,8 @@ func (n *NIC) addPermanentAddressLocked(protocolAddress tcpip.ProtocolAddress, p } func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpointBehavior, kind networkEndpointKind) (*referencedNetworkEndpoint, *tcpip.Error) { + // TODO(b/141022673): Validate IP address before adding them. + // Sanity check. id := NetworkEndpointID{protocolAddress.AddressWithPrefix.Address} if _, ok := n.endpoints[id]; ok { @@ -329,6 +329,15 @@ func (n *NIC) addAddressLocked(protocolAddress tcpip.ProtocolAddress, peb Primar } } + // If we are adding an IPv6 unicast address, join the solicited-node + // multicast address. + if protocolAddress.Protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(protocolAddress.AddressWithPrefix.Address) { + snmc := header.SolicitedNodeAddr(protocolAddress.AddressWithPrefix.Address) + if err := n.joinGroupLocked(protocolAddress.Protocol, snmc); err != nil { + return nil, err + } + } + n.endpoints[id] = ref l, ok := n.primary[protocolAddress.Protocol] @@ -358,15 +367,17 @@ func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return err } -// Addresses returns the addresses associated with this NIC. -func (n *NIC) Addresses() []tcpip.ProtocolAddress { +// AllAddresses returns all addresses (primary and non-primary) associated with +// this NIC. +func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() + addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { // Don't include expired or tempory endpoints to avoid confusion and // prevent the caller from using those. - switch ref.kind { + switch ref.getKind() { case permanentExpired, temporary: continue } @@ -381,45 +392,66 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { return addrs } -// AddSubnet adds a new subnet to n, so that it starts accepting packets -// targeted at the given address and network protocol. -func (n *NIC) AddSubnet(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { +// PrimaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { + n.mu.RLock() + defer n.mu.RUnlock() + + var addrs []tcpip.ProtocolAddress + for proto, list := range n.primary { + for e := list.Front(); e != nil; e = e.Next() { + ref := e.(*referencedNetworkEndpoint) + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.getKind() { + case permanentExpired, temporary: + continue + } + + addrs = append(addrs, tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + }, + }) + } + } + return addrs +} + +// AddAddressRange adds a range of addresses to n, so that it starts accepting +// packets targeted at the given addresses and network protocol. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (n *NIC) AddAddressRange(protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) { n.mu.Lock() - n.subnets = append(n.subnets, subnet) + n.addressRanges = append(n.addressRanges, subnet) n.mu.Unlock() } -// RemoveSubnet removes the given subnet from n. -func (n *NIC) RemoveSubnet(subnet tcpip.Subnet) { +// RemoveAddressRange removes the given address range from n. +func (n *NIC) RemoveAddressRange(subnet tcpip.Subnet) { n.mu.Lock() // Use the same underlying array. - tmp := n.subnets[:0] - for _, sub := range n.subnets { + tmp := n.addressRanges[:0] + for _, sub := range n.addressRanges { if sub != subnet { tmp = append(tmp, sub) } } - n.subnets = tmp + n.addressRanges = tmp n.mu.Unlock() } -// ContainsSubnet reports whether this NIC contains the given subnet. -func (n *NIC) ContainsSubnet(subnet tcpip.Subnet) bool { - for _, s := range n.Subnets() { - if s == subnet { - return true - } - } - return false -} - // Subnets returns the Subnets associated with this NIC. -func (n *NIC) Subnets() []tcpip.Subnet { +func (n *NIC) AddressRanges() []tcpip.Subnet { n.mu.RLock() defer n.mu.RUnlock() - sns := make([]tcpip.Subnet, 0, len(n.subnets)+len(n.endpoints)) + sns := make([]tcpip.Subnet, 0, len(n.addressRanges)+len(n.endpoints)) for nid := range n.endpoints { sn, err := tcpip.NewSubnet(nid.LocalAddress, tcpip.AddressMask(strings.Repeat("\xff", len(nid.LocalAddress)))) if err != nil { @@ -429,7 +461,7 @@ func (n *NIC) Subnets() []tcpip.Subnet { } sns = append(sns, sn) } - return append(sns, n.subnets...) + return append(sns, n.addressRanges...) } func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { @@ -444,7 +476,7 @@ func (n *NIC) removeEndpointLocked(r *referencedNetworkEndpoint) { return } - if r.kind == permanent { + if r.getKind() == permanent { panic("Reference count dropped to zero before being removed") } @@ -464,13 +496,27 @@ func (n *NIC) removeEndpoint(r *referencedNetworkEndpoint) { } func (n *NIC) removePermanentAddressLocked(addr tcpip.Address) *tcpip.Error { - r := n.endpoints[NetworkEndpointID{addr}] - if r == nil || r.kind != permanent { + r, ok := n.endpoints[NetworkEndpointID{addr}] + if !ok || r.getKind() != permanent { return tcpip.ErrBadLocalAddress } - r.kind = permanentExpired - r.decRefLocked() + r.setKind(permanentExpired) + if !r.decRefLocked() { + // The endpoint still has references to it. + return nil + } + + // At this point the endpoint is deleted. + + // If we are removing an IPv6 unicast address, leave the solicited-node + // multicast address. + if r.protocol == header.IPv6ProtocolNumber && header.IsV6UnicastAddress(addr) { + snmc := header.SolicitedNodeAddr(addr) + if err := n.leaveGroupLocked(snmc); err != nil { + return err + } + } return nil } @@ -488,6 +534,13 @@ func (n *NIC) joinGroup(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address n.mu.Lock() defer n.mu.Unlock() + return n.joinGroupLocked(protocol, addr) +} + +// joinGroupLocked adds a new endpoint for the given multicast address, if none +// exists yet. Otherwise it just increments its count. n MUST be locked before +// joinGroupLocked is called. +func (n *NIC) joinGroupLocked(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] if joins == 0 { @@ -515,6 +568,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { n.mu.Lock() defer n.mu.Unlock() + return n.leaveGroupLocked(addr) +} + +// leaveGroupLocked decrements the count for the given multicast address, and +// when it reaches zero removes the endpoint for this address. n MUST be locked +// before leaveGroupLocked is called. +func (n *NIC) leaveGroupLocked(addr tcpip.Address) *tcpip.Error { id := NetworkEndpointID{addr} joins := n.mcastJoins[id] switch joins { @@ -531,6 +591,13 @@ func (n *NIC) leaveGroup(addr tcpip.Address) *tcpip.Error { return nil } +func handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, localLinkAddr, remotelinkAddr tcpip.LinkAddress, ref *referencedNetworkEndpoint, vv buffer.VectorisedView) { + r := makeRoute(protocol, dst, src, localLinkAddr, ref, false /* handleLocal */, false /* multicastLoop */) + r.RemoteLinkAddress = remotelinkAddr + ref.ep.HandlePacket(&r, vv) + ref.decRef() +} + // DeliverNetworkPacket finds the appropriate network protocol endpoint and // hands the packet over for further processing. This function is called when // the NIC receives a packet from the physical interface. @@ -558,6 +625,8 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr src, dst := netProto.ParseAddresses(vv.First()) + n.stack.AddLinkAddress(n.id, src, remote) + // If the packet is destined to the IPv4 Broadcast address, then make a // route to each IPv4 network endpoint and let each endpoint handle the // packet. @@ -566,10 +635,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.mu.RLock() for _, ref := range n.endpoints { if ref.isValidForIncoming() && ref.protocol == header.IPv4ProtocolNumber && ref.tryIncRef() { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) } } n.mu.RUnlock() @@ -577,10 +643,7 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr } if ref := n.getRef(protocol, dst); ref != nil { - r := makeRoute(protocol, dst, src, linkEP.LinkAddress(), ref, false /* handleLocal */, false /* multicastLoop */) - r.RemoteLinkAddress = remote - ref.ep.HandlePacket(&r, vv) - ref.decRef() + handlePacket(protocol, dst, src, linkEP.LinkAddress(), remote, ref, vv) return } @@ -676,7 +739,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet, so // deliver it to the global handler. - if !transProto.HandleUnknownDestinationPacket(r, id, vv) { + if !transProto.HandleUnknownDestinationPacket(r, id, netHeader, vv) { n.stack.stats.MalformedRcvdPackets.Increment() } } @@ -717,7 +780,12 @@ func (n *NIC) ID() tcpip.NICID { return n.id } -type networkEndpointKind int +// Stack returns the instance of the Stack that owns this NIC. +func (n *NIC) Stack() *Stack { + return n.stack +} + +type networkEndpointKind int32 const ( // A permanent endpoint is created by adding a permanent address (vs. a @@ -756,21 +824,30 @@ type referencedNetworkEndpoint struct { // triggers the automatic removal of the endpoint from the NIC. refs int32 + // networkEndpointKind must only be accessed using {get,set}Kind(). kind networkEndpointKind } +func (r *referencedNetworkEndpoint) getKind() networkEndpointKind { + return networkEndpointKind(atomic.LoadInt32((*int32)(&r.kind))) +} + +func (r *referencedNetworkEndpoint) setKind(kind networkEndpointKind) { + atomic.StoreInt32((*int32)(&r.kind), int32(kind)) +} + // isValidForOutgoing returns true if the endpoint can be used to send out a // packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed), or the NIC to be in spoofing mode. func (r *referencedNetworkEndpoint) isValidForOutgoing() bool { - return r.kind != permanentExpired || r.nic.spoofing + return r.getKind() != permanentExpired || r.nic.spoofing } // isValidForIncoming returns true if the endpoint can accept an incoming // packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed), or the NIC to be in promiscuous mode. func (r *referencedNetworkEndpoint) isValidForIncoming() bool { - return r.kind != permanentExpired || r.nic.promiscuous + return r.getKind() != permanentExpired || r.nic.promiscuous } // decRef decrements the ref count and cleans up the endpoint once it reaches @@ -782,11 +859,14 @@ func (r *referencedNetworkEndpoint) decRef() { } // decRefLocked is the same as decRef but assumes that the NIC.mu mutex is -// locked. -func (r *referencedNetworkEndpoint) decRefLocked() { +// locked. Returns true if the endpoint was removed. +func (r *referencedNetworkEndpoint) decRefLocked() bool { if atomic.AddInt32(&r.refs, -1) == 0 { r.nic.removeEndpointLocked(r) + return true } + + return false } // incRef increments the ref count. It must only be called when the caller is @@ -811,3 +891,8 @@ func (r *referencedNetworkEndpoint) tryIncRef() bool { } } } + +// stack returns the Stack instance that owns the underlying endpoint. +func (r *referencedNetworkEndpoint) stack() *Stack { + return r.nic.stack +} diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index 2037eef9f..80101d4bb 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -15,8 +15,6 @@ package stack import ( - "sync" - "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -109,7 +107,7 @@ type TransportProtocol interface { // // The return value indicates whether the packet was well-formed (for // stats purposes only). - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, vv buffer.VectorisedView) bool + HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, netHeader buffer.View, vv buffer.VectorisedView) bool // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -297,6 +295,15 @@ type LinkEndpoint interface { // IsAttached returns whether a NetworkDispatcher is attached to the // endpoint. IsAttached() bool + + // Wait waits for any worker goroutines owned by the endpoint to stop. + // + // For now, requesting that an endpoint's worker goroutine(s) stop is + // implementation specific. + // + // Wait will not block if the endpoint hasn't started any goroutines + // yet, even if it might later. + Wait() } // InjectableLinkEndpoint is a LinkEndpoint where inbound packets are @@ -359,14 +366,6 @@ type LinkAddressCache interface { RemoveWaker(nicid tcpip.NICID, addr tcpip.Address, waker *sleep.Waker) } -// TransportProtocolFactory functions are used by the stack to instantiate -// transport protocols. -type TransportProtocolFactory func() TransportProtocol - -// NetworkProtocolFactory provides methods to be used by the stack to -// instantiate network protocols. -type NetworkProtocolFactory func() NetworkProtocol - // UnassociatedEndpointFactory produces endpoints for writing packets not // associated with a particular transport protocol. Such endpoints can be used // to write arbitrary packets that include the IP header. @@ -374,60 +373,6 @@ type UnassociatedEndpointFactory interface { NewUnassociatedRawEndpoint(stack *Stack, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) } -var ( - transportProtocols = make(map[string]TransportProtocolFactory) - networkProtocols = make(map[string]NetworkProtocolFactory) - - unassociatedFactory UnassociatedEndpointFactory - - linkEPMu sync.RWMutex - nextLinkEndpointID tcpip.LinkEndpointID = 1 - linkEndpoints = make(map[tcpip.LinkEndpointID]LinkEndpoint) -) - -// RegisterTransportProtocolFactory registers a new transport protocol factory -// with the stack so that it becomes available to users of the stack. This -// function is intended to be called by init() functions of the protocols. -func RegisterTransportProtocolFactory(name string, p TransportProtocolFactory) { - transportProtocols[name] = p -} - -// RegisterNetworkProtocolFactory registers a new network protocol factory with -// the stack so that it becomes available to users of the stack. This function -// is intended to be called by init() functions of the protocols. -func RegisterNetworkProtocolFactory(name string, p NetworkProtocolFactory) { - networkProtocols[name] = p -} - -// RegisterUnassociatedFactory registers a factory to produce endpoints not -// associated with any particular transport protocol. This function is intended -// to be called by init() functions of the protocols. -func RegisterUnassociatedFactory(f UnassociatedEndpointFactory) { - unassociatedFactory = f -} - -// RegisterLinkEndpoint register a link-layer protocol endpoint and returns an -// ID that can be used to refer to it. -func RegisterLinkEndpoint(linkEP LinkEndpoint) tcpip.LinkEndpointID { - linkEPMu.Lock() - defer linkEPMu.Unlock() - - v := nextLinkEndpointID - nextLinkEndpointID++ - - linkEndpoints[v] = linkEP - - return v -} - -// FindLinkEndpoint finds the link endpoint associated with the given ID. -func FindLinkEndpoint(id tcpip.LinkEndpointID) LinkEndpoint { - linkEPMu.RLock() - defer linkEPMu.RUnlock() - - return linkEndpoints[id] -} - // GSOType is the type of GSO segments. // // +stateify savable diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index e52cdd674..5c8b7977a 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -217,3 +217,8 @@ func (r *Route) MakeLoopedRoute() Route { } return l } + +// Stack returns the instance of the Stack that owns this route. +func (r *Route) Stack() *Stack { + return r.ref.stack() +} diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index d69162ba1..18d1704a5 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -17,17 +17,13 @@ // // For consumers, the only function of interest is New(), everything else is // provided by the tcpip/public package. -// -// For protocol implementers, RegisterTransportProtocolFactory() and -// RegisterNetworkProtocolFactory() are used to register protocol factories with -// the stack, which will then be used to instantiate protocol objects when -// consumers interact with the stack. package stack import ( "sync" "time" + "golang.org/x/time/rate" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -350,6 +346,9 @@ type Stack struct { networkProtocols map[tcpip.NetworkProtocolNumber]NetworkProtocol linkAddrResolvers map[tcpip.NetworkProtocolNumber]LinkAddressResolver + // unassociatedFactory creates unassociated endpoints. If nil, raw + // endpoints are disabled. It is set during Stack creation and is + // immutable. unassociatedFactory UnassociatedEndpointFactory demux *transportDemuxer @@ -358,10 +357,6 @@ type Stack struct { linkAddrCache *linkAddrCache - // raw indicates whether raw sockets may be created. It is set during - // Stack creation and is immutable. - raw bool - mu sync.RWMutex nics map[tcpip.NICID]*NIC forwarding bool @@ -389,10 +384,20 @@ type Stack struct { // resumableEndpoints is a list of endpoints that need to be resumed if the // stack is being restored. resumableEndpoints []ResumableEndpoint + + // icmpRateLimiter is a global rate limiter for all ICMP messages generated + // by the stack. + icmpRateLimiter *ICMPRateLimiter } // Options contains optional Stack configuration. type Options struct { + // NetworkProtocols lists the network protocols to enable. + NetworkProtocols []NetworkProtocol + + // TransportProtocols lists the transport protocols to enable. + TransportProtocols []TransportProtocol + // Clock is an optional clock source used for timestampping packets. // // If no Clock is specified, the clock source will be time.Now. @@ -406,8 +411,9 @@ type Options struct { // stack (false). HandleLocal bool - // Raw indicates whether raw sockets may be created. - Raw bool + // UnassociatedFactory produces unassociated endpoints raw endpoints. + // Raw endpoints are enabled only if this is non-nil. + UnassociatedFactory UnassociatedEndpointFactory } // New allocates a new networking stack with only the requested networking and @@ -417,7 +423,7 @@ type Options struct { // SetNetworkProtocolOption/SetTransportProtocolOption methods provided by the // stack. Please refer to individual protocol implementations as to what options // are supported. -func New(network []string, transport []string, opts Options) *Stack { +func New(opts Options) *Stack { clock := opts.Clock if clock == nil { clock = &tcpip.StdClock{} @@ -433,16 +439,11 @@ func New(network []string, transport []string, opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, - raw: opts.Raw, + icmpRateLimiter: NewICMPRateLimiter(), } // Add specified network protocols. - for _, name := range network { - netProtoFactory, ok := networkProtocols[name] - if !ok { - continue - } - netProto := netProtoFactory() + for _, netProto := range opts.NetworkProtocols { s.networkProtocols[netProto.Number()] = netProto if r, ok := netProto.(LinkAddressResolver); ok { s.linkAddrResolvers[r.LinkAddressProtocol()] = r @@ -450,18 +451,14 @@ func New(network []string, transport []string, opts Options) *Stack { } // Add specified transport protocols. - for _, name := range transport { - transProtoFactory, ok := transportProtocols[name] - if !ok { - continue - } - transProto := transProtoFactory() + for _, transProto := range opts.TransportProtocols { s.transportProtocols[transProto.Number()] = &transportProtocolState{ proto: transProto, } } - s.unassociatedFactory = unassociatedFactory + // Add the factory for unassociated endpoints, if present. + s.unassociatedFactory = opts.UnassociatedFactory // Create the global transport demuxer. s.demux = newTransportDemuxer(s) @@ -596,7 +593,7 @@ func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcp // protocol. Raw endpoints receive all traffic for a given protocol regardless // of address. func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue, associated bool) (tcpip.Endpoint, *tcpip.Error) { - if !s.raw { + if s.unassociatedFactory == nil { return nil, tcpip.ErrNotPermitted } @@ -614,12 +611,7 @@ func (s *Stack) NewRawEndpoint(transport tcpip.TransportProtocolNumber, network // createNIC creates a NIC with the provided id and link-layer endpoint, and // optionally enable it. -func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID, enabled, loopback bool) *tcpip.Error { - ep := FindLinkEndpoint(linkEP) - if ep == nil { - return tcpip.ErrBadLinkEndpoint - } - +func (s *Stack) createNIC(id tcpip.NICID, name string, ep LinkEndpoint, enabled, loopback bool) *tcpip.Error { s.mu.Lock() defer s.mu.Unlock() @@ -632,40 +624,40 @@ func (s *Stack) createNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpoint s.nics[id] = n if enabled { - n.attachLinkEndpoint() + return n.enable() } return nil } // CreateNIC creates a NIC with the provided id and link-layer endpoint. -func (s *Stack) CreateNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, true, false) +func (s *Stack) CreateNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, true, false) } // CreateNamedNIC creates a NIC with the provided id and link-layer endpoint, // and a human-readable name. -func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, false) +func (s *Stack) CreateNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, false) } // CreateNamedLoopbackNIC creates a NIC with the provided id and link-layer // endpoint, and a human-readable name. -func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, true, true) +func (s *Stack) CreateNamedLoopbackNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, true, true) } // CreateDisabledNIC creates a NIC with the provided id and link-layer endpoint, // but leave it disable. Stack.EnableNIC must be called before the link-layer // endpoint starts delivering packets to it. -func (s *Stack) CreateDisabledNIC(id tcpip.NICID, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, "", linkEP, false, false) +func (s *Stack) CreateDisabledNIC(id tcpip.NICID, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, "", ep, false, false) } // CreateDisabledNamedNIC is a combination of CreateNamedNIC and // CreateDisabledNIC. -func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, linkEP tcpip.LinkEndpointID) *tcpip.Error { - return s.createNIC(id, name, linkEP, false, false) +func (s *Stack) CreateDisabledNamedNIC(id tcpip.NICID, name string, ep LinkEndpoint) *tcpip.Error { + return s.createNIC(id, name, ep, false, false) } // EnableNIC enables the given NIC so that the link-layer endpoint can start @@ -679,9 +671,7 @@ func (s *Stack) EnableNIC(id tcpip.NICID) *tcpip.Error { return tcpip.ErrUnknownNICID } - nic.attachLinkEndpoint() - - return nil + return nic.enable() } // CheckNIC checks if a NIC is usable. @@ -696,14 +686,14 @@ func (s *Stack) CheckNIC(id tcpip.NICID) bool { } // NICSubnets returns a map of NICIDs to their associated subnets. -func (s *Stack) NICSubnets() map[tcpip.NICID][]tcpip.Subnet { +func (s *Stack) NICAddressRanges() map[tcpip.NICID][]tcpip.Subnet { s.mu.RLock() defer s.mu.RUnlock() nics := map[tcpip.NICID][]tcpip.Subnet{} for id, nic := range s.nics { - nics[id] = append(nics[id], nic.Subnets()...) + nics[id] = append(nics[id], nic.AddressRanges()...) } return nics } @@ -739,7 +729,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.linkEP.LinkAddress(), - ProtocolAddresses: nic.Addresses(), + ProtocolAddresses: nic.PrimaryAddresses(), Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, @@ -804,71 +794,79 @@ func (s *Stack) AddProtocolAddressWithOptions(id tcpip.NICID, protocolAddress tc return nic.AddAddress(protocolAddress, peb) } -// AddSubnet adds a subnet range to the specified NIC. -func (s *Stack) AddSubnet(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { +// AddAddressRange adds a range of addresses to the specified NIC. The range is +// given by a subnet address, and all addresses contained in the subnet are +// used except for the subnet address itself and the subnet's broadcast +// address. +func (s *Stack) AddAddressRange(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber, subnet tcpip.Subnet) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - nic.AddSubnet(protocol, subnet) + nic.AddAddressRange(protocol, subnet) return nil } return tcpip.ErrUnknownNICID } -// RemoveSubnet removes the subnet range from the specified NIC. -func (s *Stack) RemoveSubnet(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { +// RemoveAddressRange removes the range of addresses from the specified NIC. +func (s *Stack) RemoveAddressRange(id tcpip.NICID, subnet tcpip.Subnet) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - nic.RemoveSubnet(subnet) + nic.RemoveAddressRange(subnet) return nil } return tcpip.ErrUnknownNICID } -// ContainsSubnet reports whether the specified NIC contains the specified -// subnet. -func (s *Stack) ContainsSubnet(id tcpip.NICID, subnet tcpip.Subnet) (bool, *tcpip.Error) { +// RemoveAddress removes an existing network-layer address from the specified +// NIC. +func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { s.mu.RLock() defer s.mu.RUnlock() if nic, ok := s.nics[id]; ok { - return nic.ContainsSubnet(subnet), nil + return nic.RemoveAddress(addr) } - return false, tcpip.ErrUnknownNICID + return tcpip.ErrUnknownNICID } -// RemoveAddress removes an existing network-layer address from the specified -// NIC. -func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { +// AllAddresses returns a map of NICIDs to their protocol addresses (primary +// and non-primary). +func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { s.mu.RLock() defer s.mu.RUnlock() - if nic, ok := s.nics[id]; ok { - return nic.RemoveAddress(addr) + nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress) + for id, nic := range s.nics { + nics[id] = nic.AllAddresses() } - - return tcpip.ErrUnknownNICID + return nics } -// GetMainNICAddress returns the first primary address (and the subnet that -// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's -// address if no primary addresses exist. Returns an error if the NIC doesn't -// exist or has no endpoints. +// GetMainNICAddress returns the first primary address and prefix for the given +// NIC and protocol. Returns an error if the NIC doesn't exist and an empty +// value if the NIC doesn't have a primary address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() - if nic, ok := s.nics[id]; ok { - return nic.getMainNICAddress(protocol) + nic, ok := s.nics[id] + if !ok { + return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID + for _, a := range nic.PrimaryAddresses() { + if a.Protocol == protocol { + return a.AddressWithPrefix, nil + } + } + return tcpip.AddressWithPrefix{}, nil } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { @@ -1215,3 +1213,33 @@ func (s *Stack) IPTables() iptables.IPTables { func (s *Stack) SetIPTables(ipt iptables.IPTables) { s.tables = ipt } + +// ICMPLimit returns the maximum number of ICMP messages that can be sent +// in one second. +func (s *Stack) ICMPLimit() rate.Limit { + return s.icmpRateLimiter.Limit() +} + +// SetICMPLimit sets the maximum number of ICMP messages that be sent +// in one second. +func (s *Stack) SetICMPLimit(newLimit rate.Limit) { + s.icmpRateLimiter.SetLimit(newLimit) +} + +// ICMPBurst returns the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) ICMPBurst() int { + return s.icmpRateLimiter.Burst() +} + +// SetICMPBurst sets the maximum number of ICMP messages that can be sent +// in a single burst. +func (s *Stack) SetICMPBurst(burst int) { + s.icmpRateLimiter.SetBurst(burst) +} + +// AllowICMPMessage returns true if we the rate limiter allows at least one +// ICMP message to be sent at this instant. +func (s *Stack) AllowICMPMessage() bool { + return s.icmpRateLimiter.Allow() +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 4debd1eec..d2dede8a9 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -60,11 +60,11 @@ type fakeNetworkEndpoint struct { prefixLen int proto *fakeNetworkProtocol dispatcher stack.TransportDispatcher - linkEP stack.LinkEndpoint + ep stack.LinkEndpoint } func (f *fakeNetworkEndpoint) MTU() uint32 { - return f.linkEP.MTU() - uint32(f.MaxHeaderLength()) + return f.ep.MTU() - uint32(f.MaxHeaderLength()) } func (f *fakeNetworkEndpoint) NICID() tcpip.NICID { @@ -108,7 +108,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, vv buffer.VectorisedV } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { - return f.linkEP.MaxHeaderLength() + fakeNetHeaderLen + return f.ep.MaxHeaderLength() + fakeNetHeaderLen } func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, dstAddr tcpip.Address) uint16 { @@ -116,7 +116,7 @@ func (f *fakeNetworkEndpoint) PseudoHeaderChecksum(protocol tcpip.TransportProto } func (f *fakeNetworkEndpoint) Capabilities() stack.LinkEndpointCapabilities { - return f.linkEP.Capabilities() + return f.ep.Capabilities() } func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr buffer.Prependable, payload buffer.VectorisedView, protocol tcpip.TransportProtocolNumber, _ uint8, loop stack.PacketLooping) *tcpip.Error { @@ -141,7 +141,7 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, hdr bu return nil } - return f.linkEP.WritePacket(r, gso, hdr, payload, fakeNetNumber) + return f.ep.WritePacket(r, gso, hdr, payload, fakeNetNumber) } func (*fakeNetworkEndpoint) WriteHeaderIncludedPacket(r *stack.Route, payload buffer.VectorisedView, loop stack.PacketLooping) *tcpip.Error { @@ -189,14 +189,14 @@ func (*fakeNetworkProtocol) ParseAddresses(v buffer.View) (src, dst tcpip.Addres return tcpip.Address(v[1:2]), tcpip.Address(v[0:1]) } -func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, linkEP stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { +func (f *fakeNetworkProtocol) NewEndpoint(nicid tcpip.NICID, addrWithPrefix tcpip.AddressWithPrefix, linkAddrCache stack.LinkAddressCache, dispatcher stack.TransportDispatcher, ep stack.LinkEndpoint) (stack.NetworkEndpoint, *tcpip.Error) { return &fakeNetworkEndpoint{ nicid: nicid, id: stack.NetworkEndpointID{LocalAddress: addrWithPrefix.Address}, prefixLen: addrWithPrefix.PrefixLen, proto: f, dispatcher: dispatcher, - linkEP: linkEP, + ep: ep, }, nil } @@ -222,12 +222,18 @@ func (f *fakeNetworkProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeNetFactory() stack.NetworkProtocol { + return &fakeNetworkProtocol{} +} + func TestNetworkReceive(t *testing.T) { // Create a stack with the fake network protocol, one nic, and two // addresses attached to it: 1 & 2. - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -245,7 +251,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet with wrong address is not delivered. buf[0] = 3 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 0 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 0) } @@ -255,7 +261,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to first endpoint. buf[0] = 1 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -265,7 +271,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet is delivered to second endpoint. buf[0] = 2 - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -274,7 +280,7 @@ func TestNetworkReceive(t *testing.T) { } // Make sure packet is not delivered if protocol number is wrong. - linkEP.Inject(fakeNetNumber-1, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber-1, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -284,7 +290,7 @@ func TestNetworkReceive(t *testing.T) { // Make sure packet that is too small is dropped. buf.CapLength(2) - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if fakeNet.packetCount[1] != 1 { t.Errorf("packetCount[1] = %d, want %d", fakeNet.packetCount[1], 1) } @@ -307,59 +313,59 @@ func send(r stack.Route, payload buffer.View) *tcpip.Error { return r.WritePacket(nil /* gso */, hdr, payload.ToVectorisedView(), fakeTransNumber, 123) } -func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View) { +func testSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := sendTo(s, addr, payload); err != nil { t.Error("sendTo failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("sendTo packet count: got = %d, want %d", got, want) } } -func testSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View) { +func testSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View) { t.Helper() - linkEP.Drain() + ep.Drain() if err := send(r, payload); err != nil { t.Error("send failed:", err) } - if got, want := linkEP.Drain(), 1; got != want { + if got, want := ep.Drain(), 1; got != want { t.Errorf("send packet count: got = %d, want %d", got, want) } } -func testFailingSend(t *testing.T, r stack.Route, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSend(t *testing.T, r stack.Route, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := send(r, payload); gotErr != wantErr { t.Errorf("send failed: got = %s, want = %s ", gotErr, wantErr) } } -func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, linkEP *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { +func testFailingSendTo(t *testing.T, s *stack.Stack, addr tcpip.Address, ep *channel.Endpoint, payload buffer.View, wantErr *tcpip.Error) { t.Helper() if gotErr := sendTo(s, addr, payload); gotErr != wantErr { t.Errorf("sendto failed: got = %s, want = %s ", gotErr, wantErr) } } -func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we expect to receive it. want := fakeNet.PacketCount(localAddrByte) + 1 - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View) { +func testFailingRecv(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View) { t.Helper() // testRecvInternal injects one packet, and we do NOT expect to receive it. want := fakeNet.PacketCount(localAddrByte) - testRecvInternal(t, fakeNet, localAddrByte, linkEP, buf, want) + testRecvInternal(t, fakeNet, localAddrByte, ep, buf, want) } -func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, linkEP *channel.Endpoint, buf buffer.View, want int) { +func testRecvInternal(t *testing.T, fakeNet *fakeNetworkProtocol, localAddrByte byte, ep *channel.Endpoint, buf buffer.View, want int) { t.Helper() - linkEP.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep.Inject(fakeNetNumber, buf.ToVectorisedView()) if got := fakeNet.PacketCount(localAddrByte); got != want { t.Errorf("receive packet count: got = %d, want %d", got, want) } @@ -369,9 +375,11 @@ func TestNetworkSend(t *testing.T) { // Create a stack with the fake network protocol, one nic, and one // address: 1. The route table sends all packets through the only // existing nic. - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("NewNIC failed:", err) } @@ -388,17 +396,19 @@ func TestNetworkSend(t *testing.T) { } // Make sure that the link-layer endpoint received the outbound packet. - testSendTo(t, s, "\x03", linkEP, nil) + testSendTo(t, s, "\x03", ep, nil) } func TestNetworkSendMultiRoute(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -410,8 +420,8 @@ func TestNetworkSendMultiRoute(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -442,10 +452,10 @@ func TestNetworkSendMultiRoute(t *testing.T) { } // Send a packet to an odd destination. - testSendTo(t, s, "\x05", linkEP1, nil) + testSendTo(t, s, "\x05", ep1, nil) // Send a packet to an even destination. - testSendTo(t, s, "\x06", linkEP2, nil) + testSendTo(t, s, "\x06", ep2, nil) } func testRoute(t *testing.T, s *stack.Stack, nic tcpip.NICID, srcAddr, dstAddr, expectedSrcAddr tcpip.Address) { @@ -476,10 +486,12 @@ func TestRoutes(t *testing.T) { // Create a stack with the fake network protocol, two nics, and two // addresses per nic, the first nic has odd address, the second one has // even addresses. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id1, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -491,8 +503,8 @@ func TestRoutes(t *testing.T) { t.Fatal("AddAddress failed:", err) } - id2, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -554,10 +566,12 @@ func TestAddressRemoval(t *testing.T) { localAddr := tcpip.Address([]byte{localAddrByte}) remoteAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -578,15 +592,15 @@ func TestAddressRemoval(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -599,11 +613,13 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { localAddr := tcpip.Address([]byte{localAddrByte}) remoteAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { - t.Fatal("CreateNIC failed:", err) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { + t.Fatalf("CreateNIC failed: %v", err) } fakeNet := s.NetworkProtocolInstance(fakeNetNumber).(*fakeNetworkProtocol) buf := buffer.NewView(30) @@ -626,17 +642,17 @@ func TestAddressRemovalWithRouteHeld(t *testing.T) { // Send and receive packets, and verify they are received. buf[0] = localAddrByte - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) // Remove the address, then check that send/receive doesn't work anymore. if err := s.RemoveAddress(1, localAddr); err != nil { t.Fatal("RemoveAddress failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) // Check that removing the same address fails. if err := s.RemoveAddress(1, localAddr); err != tcpip.ErrBadLocalAddress { @@ -688,10 +704,12 @@ func TestEndpointExpiration(t *testing.T) { for _, promiscuous := range []bool{true, false} { for _, spoofing := range []bool{true, false} { t.Run(fmt.Sprintf("promiscuous=%t spoofing=%t", promiscuous, spoofing), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -724,15 +742,15 @@ func TestEndpointExpiration(t *testing.T) { //----------------------- verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 2. Add Address, everything should work. @@ -741,8 +759,8 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 3. Remove the address, send should only work for spoofing, receive // for promiscuous mode. @@ -752,15 +770,15 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 4. Add Address back, everything should work again. @@ -769,8 +787,8 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 5. Take a reference to the endpoint by getting a route. Verify that // we can still send/receive, including sending using the route. @@ -779,9 +797,9 @@ func TestEndpointExpiration(t *testing.T) { if err != nil { t.Fatal("FindRoute failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 6. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. @@ -791,16 +809,16 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { - testSend(t, r, linkEP, nil) - testSendTo(t, s, remoteAddr, linkEP, nil) + testSend(t, r, ep, nil) + testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSend(t, r, linkEP, nil, tcpip.ErrInvalidEndpointState) - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSend(t, r, ep, nil, tcpip.ErrInvalidEndpointState) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } // 7. Add Address back, everything should work again. @@ -809,16 +827,16 @@ func TestEndpointExpiration(t *testing.T) { t.Fatal("AddAddress failed:", err) } verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) + testSend(t, r, ep, nil) // 8. Remove the route, sendTo/recv should still work. //----------------------- r.Release() verifyAddress(t, s, nicid, localAddr) - testRecv(t, fakeNet, localAddrByte, linkEP, buf) - testSendTo(t, s, remoteAddr, linkEP, nil) + testRecv(t, fakeNet, localAddrByte, ep, buf) + testSendTo(t, s, remoteAddr, ep, nil) // 9. Remove the address. Send should only work for spoofing, receive // for promiscuous mode. @@ -828,15 +846,15 @@ func TestEndpointExpiration(t *testing.T) { } verifyAddress(t, s, nicid, noAddr) if promiscuous { - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } else { - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } if spoofing { // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } else { - testFailingSendTo(t, s, remoteAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, remoteAddr, ep, nil, tcpip.ErrNoRoute) } }) } @@ -844,10 +862,12 @@ func TestEndpointExpiration(t *testing.T) { } func TestPromiscuousMode(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -867,13 +887,13 @@ func TestPromiscuousMode(t *testing.T) { // have a matching endpoint. const localAddrByte byte = 0x01 buf[0] = localAddrByte - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) // Set promiscuous mode, then check that packet is delivered. if err := s.SetPromiscuousMode(1, true); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) // Check that we can't get a route as there is no local address. _, err := s.FindRoute(0, "", "\x02", fakeNetNumber, false /* multicastLoop */) @@ -886,7 +906,7 @@ func TestPromiscuousMode(t *testing.T) { if err := s.SetPromiscuousMode(1, false); err != nil { t.Fatal("SetPromiscuousMode failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestSpoofingWithAddress(t *testing.T) { @@ -894,10 +914,12 @@ func TestSpoofingWithAddress(t *testing.T) { nonExistentLocalAddr := tcpip.Address("\x02") dstAddr := tcpip.Address("\x03") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -930,14 +952,14 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) } // Sending a packet works. - testSendTo(t, s, dstAddr, linkEP, nil) - testSend(t, r, linkEP, nil) + testSendTo(t, s, dstAddr, ep, nil) + testSend(t, r, ep, nil) // FindRoute should also work with a local address that exists on the NIC. r, err = s.FindRoute(0, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) @@ -945,23 +967,25 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != localAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. - testSend(t, r, linkEP, nil) + testSend(t, r, ep, nil) } func TestSpoofingNoAddress(t *testing.T) { nonExistentLocalAddr := tcpip.Address("\x01") dstAddr := tcpip.Address("\x02") - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -980,7 +1004,7 @@ func TestSpoofingNoAddress(t *testing.T) { t.Errorf("FindRoute succeeded with route %+v when it should have failed", r) } // Sending a packet fails. - testFailingSendTo(t, s, dstAddr, linkEP, nil, tcpip.ErrNoRoute) + testFailingSendTo(t, s, dstAddr, ep, nil, tcpip.ErrNoRoute) // With address spoofing enabled, FindRoute permits any address to be used // as the source. @@ -992,21 +1016,23 @@ func TestSpoofingNoAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. - // testSendTo(t, s, remoteAddr, linkEP, nil) + // testSendTo(t, s, remoteAddr, ep, nil) } func TestBroadcastNeedsNoRoute(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } s.SetRouteTable([]tcpip.Route{}) @@ -1074,10 +1100,12 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { {"IPv6 Unicast Not Link-Local 7", true, "\xf0\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00"}, } { t.Run(tc.name, func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1128,12 +1156,14 @@ func TestMulticastOrIPv6LinkLocalNeedsNoRoute(t *testing.T) { } } -// Set the subnet, then check that packet is delivered. -func TestSubnetAcceptsMatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) +// Add a range of addresses, then check that a packet is delivered. +func TestAddressRangeAcceptsMatchingPacket(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1155,20 +1185,53 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } - testRecv(t, fakeNet, localAddrByte, linkEP, buf) + testRecv(t, fakeNet, localAddrByte, ep, buf) } -// Set the subnet, then check that CheckLocalAddress returns the correct NIC. +func testNicForAddressRange(t *testing.T, nicID tcpip.NICID, s *stack.Stack, subnet tcpip.Subnet, rangeExists bool) { + t.Helper() + + // Loop over all addresses and check them. + numOfAddresses := 1 << uint(8-subnet.Prefix()) + if numOfAddresses < 1 || numOfAddresses > 255 { + t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) + } + + addrBytes := []byte(subnet.ID()) + for i := 0; i < numOfAddresses; i++ { + addr := tcpip.Address(addrBytes) + wantNicID := nicID + // The subnet and broadcast addresses are skipped. + if !rangeExists || addr == subnet.ID() || addr == subnet.Broadcast() { + wantNicID = 0 + } + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, addr); gotNicID != wantNicID { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, addr, gotNicID, wantNicID) + } + addrBytes[0]++ + } + + // Trying the next address should always fail since it is outside the range. + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addrBytes)); gotNicID != 0 { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addrBytes), gotNicID, 0) + } +} + +// Set a range of addresses, then remove it again, and check at each step that +// CheckLocalAddress returns the correct NIC for each address or zero if not +// existent. func TestCheckLocalAddressForSubnet(t *testing.T) { const nicID tcpip.NICID = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicID, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1181,39 +1244,34 @@ func TestCheckLocalAddressForSubnet(t *testing.T) { } subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) - if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) - } - // Loop over all subnet addresses and check them. - numOfAddresses := 1 << uint(8-subnet.Prefix()) - if numOfAddresses < 1 || numOfAddresses > 255 { - t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) - } - addr := []byte(subnet.ID()) - for i := 0; i < numOfAddresses; i++ { - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID) - } - addr[0]++ + testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) + + if err := s.AddAddressRange(nicID, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } - // Trying the next address should fail since it is outside the subnet range. - if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 { - t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0) + testNicForAddressRange(t, nicID, s, subnet, true /* rangeExists */) + + if err := s.RemoveAddressRange(nicID, subnet); err != nil { + t.Fatal("RemoveAddressRange failed:", err) } + + testNicForAddressRange(t, nicID, s, subnet, false /* rangeExists */) } -// Set destination outside the subnet, then check it doesn't get delivered. -func TestSubnetRejectsNonmatchingPacket(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) +// Set a range of addresses, then send a packet to a destination outside the +// range and then check it doesn't get delivered. +func TestAddressRangeRejectsNonmatchingPacket(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) - id, linkEP := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1235,14 +1293,17 @@ func TestSubnetRejectsNonmatchingPacket(t *testing.T) { if err != nil { t.Fatal("NewSubnet failed:", err) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, subnet); err != nil { + t.Fatal("AddAddressRange failed:", err) } - testFailingRecv(t, fakeNet, localAddrByte, linkEP, buf) + testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } func TestNetworkOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{}, + }) // Try an unsupported network protocol. if err := s.SetNetworkProtocolOption(tcpip.NetworkProtocolNumber(99999), fakeNetGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -1281,44 +1342,53 @@ func TestNetworkOptions(t *testing.T) { } } -func TestSubnetAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { +func stackContainsAddressRange(s *stack.Stack, id tcpip.NICID, addrRange tcpip.Subnet) bool { + ranges, ok := s.NICAddressRanges()[id] + if !ok { + return false + } + for _, r := range ranges { + if r == addrRange { + return true + } + } + return false +} + +func TestAddresRangeAddRemove(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } addr := tcpip.Address("\x01\x01\x01\x01") mask := tcpip.AddressMask(strings.Repeat("\xff", len(addr))) - subnet, err := tcpip.NewSubnet(addr, mask) + addrRange, err := tcpip.NewSubnet(addr, mask) if err != nil { t.Fatal("NewSubnet failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") + if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } - if err := s.AddSubnet(1, fakeNetNumber, subnet); err != nil { - t.Fatal("AddSubnet failed:", err) + if err := s.AddAddressRange(1, fakeNetNumber, addrRange); err != nil { + t.Fatal("AddAddressRange failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if !contained { - t.Fatal("got s.ContainsSubnet(...) = false, want = true") + if got, want := stackContainsAddressRange(s, 1, addrRange), true; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } - if err := s.RemoveSubnet(1, subnet); err != nil { - t.Fatal("RemoveSubnet failed:", err) + if err := s.RemoveAddressRange(1, addrRange); err != nil { + t.Fatal("RemoveAddressRange failed:", err) } - if contained, err := s.ContainsSubnet(1, subnet); err != nil { - t.Fatal("ContainsSubnet failed:", err) - } else if contained { - t.Fatal("got s.ContainsSubnet(...) = true, want = false") + if got, want := stackContainsAddressRange(s, 1, addrRange), false; got != want { + t.Fatalf("got stackContainsAddressRange(...) = %t, want = %t", got, want) } } @@ -1329,9 +1399,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { t.Run(fmt.Sprintf("canBe=%d", canBe), func(t *testing.T) { for never := 0; never < 3; never++ { t.Run(fmt.Sprintf("never=%d", never), func(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } // Insert <canBe> primary and <never> never-primary addresses. @@ -1369,20 +1441,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } if len(primaryAddrAdded) == 0 { - // No primary addresses present, expect an error. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress) + // No primary addresses present. + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) } } else { - // At least one primary address was added, expect a valid - // address and prefixLen. - gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok { - t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded) + // At least one primary address was added, verify the returned + // address is in the list of primary addresses we added. + if _, ok := primaryAddrAdded[gotAddr]; !ok { + t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) } } }) @@ -1394,9 +1466,11 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { } func TestGetMainNICAddressAddRemove(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1421,19 +1495,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get the right initial address and prefix length. - if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { t.Fatal("GetMainNICAddress failed:", err) - } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix) + } + if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { + t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { t.Fatal("RemoveAddress failed:", err) } - // Check that we get an error after removal. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress) + // Check that we get no address after removal. + gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } }) } @@ -1448,8 +1528,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address { } func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { + t.Helper() + if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses)) + t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) } sort.Slice(gotAddresses, func(i, j int) bool { @@ -1469,9 +1551,11 @@ func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.Proto func TestAddAddress(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1488,15 +1572,17 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddress(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1520,15 +1606,17 @@ func TestAddProtocolAddress(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddAddressWithOptions(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1549,15 +1637,17 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestAddProtocolAddressWithOptions(t *testing.T) { const nicid = 1 - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id, _ := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(nicid, id); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicid, ep); err != nil { t.Fatal("CreateNIC failed:", err) } @@ -1584,14 +1674,16 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } func TestNICStats(t *testing.T) { - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC failed: ", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { @@ -1608,7 +1700,7 @@ func TestNICStats(t *testing.T) { // Send a packet to address 1. buf := buffer.NewView(30) - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) if got, want := s.NICInfo()[1].Stats.Rx.Packets.Value(), uint64(1); got != want { t.Errorf("got Rx.Packets.Value() = %d, want = %d", got, want) } @@ -1622,9 +1714,9 @@ func TestNICStats(t *testing.T) { if err := sendTo(s, "\x01", payload); err != nil { t.Fatal("sendTo failed: ", err) } - want := uint64(linkEP1.Drain()) + want := uint64(ep1.Drain()) if got := s.NICInfo()[1].Stats.Tx.Packets.Value(); got != want { - t.Errorf("got Tx.Packets.Value() = %d, linkEP1.Drain() = %d", got, want) + t.Errorf("got Tx.Packets.Value() = %d, ep1.Drain() = %d", got, want) } if got, want := s.NICInfo()[1].Stats.Tx.Bytes.Value(), uint64(len(payload)); got != want { @@ -1635,19 +1727,21 @@ func TestNICStats(t *testing.T) { func TestNICForwarding(t *testing.T) { // Create a stack with the fake network protocol, two NICs, each with // an address. - s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + }) s.SetForwarding(true) - id1, linkEP1 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(1, id1); err != nil { + ep1 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(1, ep1); err != nil { t.Fatal("CreateNIC #1 failed:", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatal("AddAddress #1 failed:", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatal("CreateNIC #2 failed:", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -1666,10 +1760,10 @@ func TestNICForwarding(t *testing.T) { // Send a packet to address 3. buf := buffer.NewView(30) buf[0] = 3 - linkEP1.Inject(fakeNetNumber, buf.ToVectorisedView()) + ep1.Inject(fakeNetNumber, buf.ToVectorisedView()) select { - case <-linkEP2.C: + case <-ep2.C: default: t.Fatal("Packet not forwarded") } @@ -1683,9 +1777,3 @@ func TestNICForwarding(t *testing.T) { t.Errorf("got Tx.Bytes.Value() = %d, want = %d", got, want) } } - -func init() { - stack.RegisterNetworkProtocolFactory("fakeNet", func() stack.NetworkProtocol { - return &fakeNetworkProtocol{} - }) -} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 5335897f5..56e8a5d9b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -65,13 +65,13 @@ func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.Contr return buffer.View{}, tcpip.ControlMessages{}, nil } -func (f *fakeTransportEndpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { +func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { if len(f.route.RemoteAddress) == 0 { return 0, nil, tcpip.ErrNoRoute } hdr := buffer.NewPrependable(int(f.route.MaxHeaderLength())) - v, err := p.Get(p.Size()) + v, err := p.FullPayload() if err != nil { return 0, nil, err } @@ -91,6 +91,11 @@ func (*fakeTransportEndpoint) SetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrInvalidEndpointState } +// SetSockOptInt sets a socket option. Currently not supported. +func (*fakeTransportEndpoint) SetSockOptInt(tcpip.SockOpt, int) *tcpip.Error { + return tcpip.ErrInvalidEndpointState +} + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrUnknownProtocolOption @@ -251,7 +256,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.VectorisedView) bool { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, buffer.View, buffer.VectorisedView) bool { return true } @@ -277,10 +282,17 @@ func (f *fakeTransportProtocol) Option(option interface{}) *tcpip.Error { } } +func fakeTransFactory() stack.TransportProtocol { + return &fakeTransportProtocol{} +} + func TestTransportReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -340,9 +352,12 @@ func TestTransportReceive(t *testing.T) { } func TestTransportControlReceive(t *testing.T) { - id, linkEP := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -408,9 +423,12 @@ func TestTransportControlReceive(t *testing.T) { } func TestTransportSend(t *testing.T) { - id, _ := channel.New(10, defaultMTU, "") - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) - if err := s.CreateNIC(1, id); err != nil { + linkEP := channel.New(10, defaultMTU, "") + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) + if err := s.CreateNIC(1, linkEP); err != nil { t.Fatalf("CreateNIC failed: %v", err) } @@ -452,7 +470,10 @@ func TestTransportSend(t *testing.T) { } func TestTransportOptions(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) // Try an unsupported transport protocol. if err := s.SetTransportProtocolOption(tcpip.TransportProtocolNumber(99999), fakeTransportGoodOption(false)); err != tcpip.ErrUnknownProtocol { @@ -493,20 +514,23 @@ func TestTransportOptions(t *testing.T) { } func TestTransportForwarding(t *testing.T) { - s := stack.New([]string{"fakeNet"}, []string{"fakeTrans"}, stack.Options{}) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocol{fakeNetFactory()}, + TransportProtocols: []stack.TransportProtocol{fakeTransFactory()}, + }) s.SetForwarding(true) // TODO(b/123449044): Change this to a channel NIC. - id1 := loopback.New() - if err := s.CreateNIC(1, id1); err != nil { + ep1 := loopback.New() + if err := s.CreateNIC(1, ep1); err != nil { t.Fatalf("CreateNIC #1 failed: %v", err) } if err := s.AddAddress(1, fakeNetNumber, "\x01"); err != nil { t.Fatalf("AddAddress #1 failed: %v", err) } - id2, linkEP2 := channel.New(10, defaultMTU, "") - if err := s.CreateNIC(2, id2); err != nil { + ep2 := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(2, ep2); err != nil { t.Fatalf("CreateNIC #2 failed: %v", err) } if err := s.AddAddress(2, fakeNetNumber, "\x02"); err != nil { @@ -545,7 +569,7 @@ func TestTransportForwarding(t *testing.T) { req[0] = 1 req[1] = 3 req[2] = byte(fakeTransNumber) - linkEP2.Inject(fakeNetNumber, req.ToVectorisedView()) + ep2.Inject(fakeNetNumber, req.ToVectorisedView()) aep, _, err := ep.Accept() if err != nil || aep == nil { @@ -559,7 +583,7 @@ func TestTransportForwarding(t *testing.T) { var p channel.PacketInfo select { - case p = <-linkEP2.C: + case p = <-ep2.C: default: t.Fatal("Response packet not forwarded") } @@ -571,9 +595,3 @@ func TestTransportForwarding(t *testing.T) { t.Errorf("Response packet has incorrect source addresss: got = %d, want = 3", src) } } - -func init() { - stack.RegisterTransportProtocolFactory("fakeTrans", func() stack.TransportProtocol { - return &fakeTransportProtocol{} - }) -} |