summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/stack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/stack')
-rw-r--r--pkg/tcpip/stack/iptables.go24
-rw-r--r--pkg/tcpip/stack/iptables_types.go64
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go29
-rw-r--r--pkg/tcpip/stack/linkaddrcache_test.go20
-rw-r--r--pkg/tcpip/stack/ndp_test.go155
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go4
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go87
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go13
-rw-r--r--pkg/tcpip/stack/nic.go38
-rw-r--r--pkg/tcpip/stack/pending_packets.go256
-rw-r--r--pkg/tcpip/stack/route.go93
-rw-r--r--pkg/tcpip/stack/stack.go57
-rw-r--r--pkg/tcpip/stack/stack_options.go16
-rw-r--r--pkg/tcpip/stack/stack_test.go212
-rw-r--r--pkg/tcpip/stack/transport_demuxer_test.go2
-rw-r--r--pkg/tcpip/stack/transport_test.go27
16 files changed, 611 insertions, 486 deletions
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 09c7811fa..04af933a6 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -267,11 +267,11 @@ const (
// dropped.
//
// TODO(gvisor.dev/issue/170): PacketBuffer should hold the GSO and route, from
-// which address and nicName can be gathered. Currently, address is only
-// needed for prerouting and nicName is only needed for output.
+// which address can be gathered. Currently, address is only needed for
+// prerouting.
//
// Precondition: pkt.NetworkHeader is set.
-func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) bool {
+func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) bool {
if pkt.NetworkProtocolNumber != header.IPv4ProtocolNumber && pkt.NetworkProtocolNumber != header.IPv6ProtocolNumber {
return true
}
@@ -302,7 +302,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer
table = it.v4Tables[tableID]
}
ruleIdx := table.BuiltinChains[hook]
- switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
// If the table returns Accept, move on to the next table.
case chainAccept:
continue
@@ -385,10 +385,10 @@ func (it *IPTables) startReaper(interval time.Duration) {
//
// NOTE: unlike the Check API the returned map contains packets that should be
// dropped.
-func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, nicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
+func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *Route, inNicName, outNicName string) (drop map[*PacketBuffer]struct{}, natPkts map[*PacketBuffer]struct{}) {
for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
if !pkt.NatDone {
- if ok := it.Check(hook, pkt, gso, r, "", nicName); !ok {
+ if ok := it.Check(hook, pkt, gso, r, "", inNicName, outNicName); !ok {
if drop == nil {
drop = make(map[*PacketBuffer]struct{})
}
@@ -408,11 +408,11 @@ func (it *IPTables) CheckPackets(hook Hook, pkts PacketBufferList, gso *GSO, r *
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) chainVerdict {
+func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) chainVerdict {
// Start from ruleIdx and walk the list of rules until a rule gives us
// a verdict.
for ruleIdx < len(table.Rules) {
- switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict, jumpTo := it.checkRule(hook, pkt, table, ruleIdx, gso, r, preroutingAddr, inNicName, outNicName); verdict {
case RuleAccept:
return chainAccept
@@ -429,7 +429,7 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
ruleIdx++
continue
}
- switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, nicName); verdict {
+ switch verdict := it.checkChain(hook, pkt, table, jumpTo, gso, r, preroutingAddr, inNicName, outNicName); verdict {
case chainAccept:
return chainAccept
case chainDrop:
@@ -455,11 +455,11 @@ func (it *IPTables) checkChain(hook Hook, pkt *PacketBuffer, table Table, ruleId
// Preconditions:
// * pkt is a IPv4 packet of at least length header.IPv4MinimumSize.
// * pkt.NetworkHeader is not nil.
-func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, nicName string) (RuleVerdict, int) {
+func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx int, gso *GSO, r *Route, preroutingAddr tcpip.Address, inNicName, outNicName string) (RuleVerdict, int) {
rule := table.Rules[ruleIdx]
// Check whether the packet matches the IP header filter.
- if !rule.Filter.match(pkt, hook, nicName) {
+ if !rule.Filter.match(pkt, hook, inNicName, outNicName) {
// Continue on to the next rule.
return RuleJump, ruleIdx + 1
}
@@ -467,7 +467,7 @@ func (it *IPTables) checkRule(hook Hook, pkt *PacketBuffer, table Table, ruleIdx
// Go through each rule matcher. If they all match, run
// the rule target.
for _, matcher := range rule.Matchers {
- matches, hotdrop := matcher.Match(hook, pkt, "")
+ matches, hotdrop := matcher.Match(hook, pkt, inNicName, outNicName)
if hotdrop {
return RuleDrop, 0
}
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 56a3e7861..fd9d61e39 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -210,8 +210,19 @@ type IPHeaderFilter struct {
// filter will match packets that fail the source comparison.
SrcInvert bool
- // OutputInterface matches the name of the outgoing interface for the
- // packet.
+ // InputInterface matches the name of the incoming interface for the packet.
+ InputInterface string
+
+ // InputInterfaceMask masks the characters of the interface name when
+ // comparing with InputInterface.
+ InputInterfaceMask string
+
+ // InputInterfaceInvert inverts the meaning of incoming interface check,
+ // i.e. when true the filter will match packets that fail the incoming
+ // interface comparison.
+ InputInterfaceInvert bool
+
+ // OutputInterface matches the name of the outgoing interface for the packet.
OutputInterface string
// OutputInterfaceMask masks the characters of the interface name when
@@ -228,7 +239,7 @@ type IPHeaderFilter struct {
//
// Preconditions: pkt.NetworkHeader is set and is at least of the minimal IPv4
// or IPv6 header length.
-func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) bool {
+func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, inNicName, outNicName string) bool {
// Extract header fields.
var (
// TODO(gvisor.dev/issue/170): Support other filter fields.
@@ -264,26 +275,35 @@ func (fl IPHeaderFilter) match(pkt *PacketBuffer, hook Hook, nicName string) boo
return false
}
- // Check the output interface.
- // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
- // hooks after supported.
- if hook == Output {
- n := len(fl.OutputInterface)
- if n == 0 {
- return true
- }
-
- // If the interface name ends with '+', any interface which
- // begins with the name should be matched.
- ifName := fl.OutputInterface
- matches := nicName == ifName
- if strings.HasSuffix(ifName, "+") {
- matches = strings.HasPrefix(nicName, ifName[:n-1])
- }
- return fl.OutputInterfaceInvert != matches
+ switch hook {
+ case Prerouting, Input:
+ return matchIfName(inNicName, fl.InputInterface, fl.InputInterfaceInvert)
+ case Output:
+ return matchIfName(outNicName, fl.OutputInterface, fl.OutputInterfaceInvert)
+ case Forward, Postrouting:
+ // TODO(gvisor.dev/issue/170): Add the check for FORWARD and POSTROUTING
+ // hooks after supported.
+ return true
+ default:
+ panic(fmt.Sprintf("unknown hook: %d", hook))
}
+}
- return true
+func matchIfName(nicName string, ifName string, invert bool) bool {
+ n := len(ifName)
+ if n == 0 {
+ // If the interface name is omitted in the filter, any interface will match.
+ return true
+ }
+ // If the interface name ends with '+', any interface which begins with the
+ // name should be matched.
+ var matches bool
+ if strings.HasSuffix(ifName, "+") {
+ matches = strings.HasPrefix(nicName, ifName[:n-1])
+ } else {
+ matches = nicName == ifName
+ }
+ return matches != invert
}
// NetworkProtocol returns the protocol (IPv4 or IPv6) on to which the header
@@ -320,7 +340,7 @@ type Matcher interface {
// used for suspicious packets.
//
// Precondition: packet.NetworkHeader is set.
- Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool)
+ Match(hook Hook, packet *PacketBuffer, inputInterfaceName, outputInterfaceName string) (matches bool, hotdrop bool)
}
// A Target is the interface for taking an action for a packet.
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index 3c4fa341e..ba6d56a7d 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -32,6 +32,8 @@ var _ LinkAddressCache = (*linkAddrCache)(nil)
//
// This struct is safe for concurrent use.
type linkAddrCache struct {
+ nic *NIC
+
// ageLimit is how long a cache entry is valid for.
ageLimit time.Duration
@@ -79,6 +81,8 @@ type linkAddrEntry struct {
// linkAddrEntryEntry access is synchronized by the linkAddrCache lock.
linkAddrEntryEntry
+ cache *linkAddrCache
+
// TODO(gvisor.dev/issue/5150): move these fields under mu.
// mu protects the fields below.
mu sync.RWMutex
@@ -93,17 +97,26 @@ type linkAddrEntry struct {
done chan struct{}
// onResolve is called with the result of address resolution.
- onResolve []func(tcpip.LinkAddress, bool)
+ onResolve []func(LinkResolutionResult)
}
func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
+ res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0}
for _, callback := range e.onResolve {
- callback(linkAddr, len(linkAddr) != 0)
+ callback(res)
}
e.onResolve = nil
if ch := e.done; ch != nil {
close(ch)
e.done = nil
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as writing packets may be a costly operation.
+ //
+ // At the time of writing, when writing packets, a neighbor's link address
+ // is resolved (which ends up obtaining the entry's lock) while holding the
+ // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
+ // a lock ordering violation.
+ go e.cache.nic.linkResQueue.dequeue(ch, linkAddr, len(linkAddr) != 0)
}
}
@@ -174,8 +187,9 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry {
}
*entry = linkAddrEntry{
- addr: k,
- s: incomplete,
+ cache: c,
+ addr: k,
+ s: incomplete,
}
c.cache.table[k] = entry
c.cache.lru.PushFront(entry)
@@ -183,7 +197,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry {
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
c.cache.Lock()
defer c.cache.Unlock()
entry := c.getOrCreateEntryLocked(k)
@@ -195,7 +209,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA
if !time.Now().After(entry.expiration) {
// Not expired.
if onResolve != nil {
- onResolve(entry.linkAddr, true)
+ onResolve(LinkResolutionResult{LinkAddress: entry.linkAddr, Success: true})
}
return entry.linkAddr, nil, nil
}
@@ -264,8 +278,9 @@ func (c *linkAddrCache) checkLinkRequest(now time.Time, k tcpip.Address, attempt
return true
}
-func newLinkAddrCache(ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
+func newLinkAddrCache(nic *NIC, ageLimit, resolutionTimeout time.Duration, resolutionAttempts int) *linkAddrCache {
c := &linkAddrCache{
+ nic: nic,
ageLimit: ageLimit,
resolutionTimeout: resolutionTimeout,
resolutionAttempts: resolutionAttempts,
diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go
index 8c35067c6..88fbbf3fe 100644
--- a/pkg/tcpip/stack/linkaddrcache_test.go
+++ b/pkg/tcpip/stack/linkaddrcache_test.go
@@ -93,8 +93,14 @@ func getBlocking(c *linkAddrCache, addr tcpip.Address, linkRes LinkAddressResolv
}
}
+func newEmptyNIC() *NIC {
+ n := &NIC{}
+ n.linkResQueue.init(n)
+ return n
+}
+
func TestCacheOverflow(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
for i := len(testAddrs) - 1; i >= 0; i-- {
e := testAddrs[i]
c.AddLinkAddress(e.addr, e.linkAddr)
@@ -129,7 +135,7 @@ func TestCacheOverflow(t *testing.T) {
}
func TestCacheConcurrent(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
linkRes := &testLinkAddressResolver{cache: c}
var wg sync.WaitGroup
@@ -165,7 +171,7 @@ func TestCacheConcurrent(t *testing.T) {
}
func TestCacheAgeLimit(t *testing.T) {
- c := newLinkAddrCache(1*time.Millisecond, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1*time.Millisecond, 1*time.Second, 3)
linkRes := &testLinkAddressResolver{cache: c}
e := testAddrs[0]
@@ -177,7 +183,7 @@ func TestCacheAgeLimit(t *testing.T) {
}
func TestCacheReplace(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 1*time.Second, 3)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 1*time.Second, 3)
e := testAddrs[0]
l2 := e.linkAddr + "2"
c.AddLinkAddress(e.addr, e.linkAddr)
@@ -206,7 +212,7 @@ func TestCacheResolution(t *testing.T) {
//
// Using a large resolution timeout decreases the probability of experiencing
// this race condition and does not affect how long this test takes to run.
- c := newLinkAddrCache(1<<63-1, math.MaxInt64, 1)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, math.MaxInt64, 1)
linkRes := &testLinkAddressResolver{cache: c}
for i, ta := range testAddrs {
got, err := getBlocking(c, ta.addr, linkRes)
@@ -232,7 +238,7 @@ func TestCacheResolution(t *testing.T) {
}
func TestCacheResolutionFailed(t *testing.T) {
- c := newLinkAddrCache(1<<63-1, 10*time.Millisecond, 5)
+ c := newLinkAddrCache(newEmptyNIC(), 1<<63-1, 10*time.Millisecond, 5)
linkRes := &testLinkAddressResolver{cache: c}
var requestCount uint32
@@ -265,7 +271,7 @@ func TestCacheResolutionFailed(t *testing.T) {
func TestCacheResolutionTimeout(t *testing.T) {
resolverDelay := 500 * time.Millisecond
expiration := resolverDelay / 10
- c := newLinkAddrCache(expiration, 1*time.Millisecond, 3)
+ c := newLinkAddrCache(newEmptyNIC(), expiration, 1*time.Millisecond, 3)
linkRes := &testLinkAddressResolver{cache: c, delay: resolverDelay}
e := testAddrs[0]
diff --git a/pkg/tcpip/stack/ndp_test.go b/pkg/tcpip/stack/ndp_test.go
index 270f5fb1a..d7bbb25ea 100644
--- a/pkg/tcpip/stack/ndp_test.go
+++ b/pkg/tcpip/stack/ndp_test.go
@@ -45,6 +45,8 @@ const (
linkAddr3 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x08")
linkAddr4 = tcpip.LinkAddress("\x02\x02\x03\x04\x05\x09")
+ defaultPrefixLen = 128
+
// Extra time to use when waiting for an async event to occur.
defaultAsyncPositiveEventTimeout = 10 * time.Second
@@ -330,8 +332,12 @@ func TestDADDisabled(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ }
+ if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
}
// Should get the address immediately since we should not have performed
@@ -344,12 +350,8 @@ func TestDADDisabled(t *testing.T) {
default:
t.Fatal("expected DAD event")
}
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, %d) err = %s", nicID, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatal(err)
}
// We should not have sent any NDP NS messages.
@@ -440,24 +442,24 @@ func TestDADResolve(t *testing.T) {
NIC: nicID,
}})
- if err := s.AddAddress(nicID, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr1, err)
+ addrWithPrefix := tcpip.AddressWithPrefix{
+ Address: addr1,
+ PrefixLen: defaultPrefixLen,
+ }
+ if err := s.AddAddressWithPrefix(nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addrWithPrefix, err)
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Make sure the address does not resolve before the resolution time has
// passed.
time.Sleep(test.expectedRetransmitTimer*time.Duration(test.dupAddrDetectTransmits) - defaultAsyncNegativeEventTimeout)
- if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Error(err)
}
// Should not get a route even if we specify the local address as the
// tentative address.
@@ -493,10 +495,8 @@ func TestDADResolve(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- if addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, %s), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- } else if addr.Address != addr1 {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, addr, addr1)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addrWithPrefix); err != nil {
+ t.Error(err)
}
// Should get a route using the address now that it is resolved.
{
@@ -662,12 +662,8 @@ func TestDADFail(t *testing.T) {
// Address should not be considered bound to the NIC yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Receive a packet to simulate an address conflict.
@@ -691,12 +687,8 @@ func TestDADFail(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Attempting to add the address again should not fail if the address's
@@ -777,12 +769,8 @@ func TestDADStop(t *testing.T) {
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
test.stopFn(t, s)
@@ -800,12 +788,8 @@ func TestDADStop(t *testing.T) {
}
if !test.skipFinalAddrCheck {
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
}
@@ -901,26 +885,25 @@ func TestSetNDPConfigurations(t *testing.T) {
}
// Add addresses for each NIC.
- if err := s.AddAddress(nicID1, header.IPv6ProtocolNumber, addr1); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addr1, err)
+ addrWithPrefix1 := tcpip.AddressWithPrefix{Address: addr1, PrefixLen: defaultPrefixLen}
+ if err := s.AddAddressWithPrefix(nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID1, header.IPv6ProtocolNumber, addrWithPrefix1, err)
}
- if err := s.AddAddress(nicID2, header.IPv6ProtocolNumber, addr2); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addr2, err)
+ addrWithPrefix2 := tcpip.AddressWithPrefix{Address: addr2, PrefixLen: defaultPrefixLen}
+ if err := s.AddAddressWithPrefix(nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID2, header.IPv6ProtocolNumber, addrWithPrefix2, err)
}
expectDADEvent(nicID2, addr2)
- if err := s.AddAddress(nicID3, header.IPv6ProtocolNumber, addr3); err != nil {
- t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addr3, err)
+ addrWithPrefix3 := tcpip.AddressWithPrefix{Address: addr3, PrefixLen: defaultPrefixLen}
+ if err := s.AddAddressWithPrefix(nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
+ t.Fatalf("AddAddressWithPrefix(%d, %d, %s) = %s", nicID3, header.IPv6ProtocolNumber, addrWithPrefix3, err)
}
expectDADEvent(nicID3, addr3)
// Address should not be considered bound to NIC(1) yet
// (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Should get the address on NIC(2) and NIC(3)
@@ -928,31 +911,19 @@ func TestSetNDPConfigurations(t *testing.T) {
// it as the stack was configured to not do DAD by
// default and we only updated the NDP configurations on
// NIC(1).
- addr, err = s.GetMainNICAddress(nicID2, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID2, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr2 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID2, header.IPv6ProtocolNumber, addr, addr2)
+ if err := checkGetMainNICAddress(s, nicID2, header.IPv6ProtocolNumber, addrWithPrefix2); err != nil {
+ t.Fatal(err)
}
- addr, err = s.GetMainNICAddress(nicID3, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID3, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr3 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID3, header.IPv6ProtocolNumber, addr, addr3)
+ if err := checkGetMainNICAddress(s, nicID3, header.IPv6ProtocolNumber, addrWithPrefix3); err != nil {
+ t.Fatal(err)
}
// Sleep until right (500ms before) before resolution to
// make sure the address didn't resolve on NIC(1) yet.
const delta = 500 * time.Millisecond
time.Sleep(time.Duration(test.dupAddrDetectTransmits)*test.expectedRetransmitTimer - delta)
- addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID1, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Wait for DAD to resolve.
@@ -970,12 +941,8 @@ func TestSetNDPConfigurations(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID1, header.IPv6ProtocolNumber, err)
- }
- if addr.Address != addr1 {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID1, header.IPv6ProtocolNumber, addr, addr1)
+ if err := checkGetMainNICAddress(s, nicID1, header.IPv6ProtocolNumber, addrWithPrefix1); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2946,10 +2913,8 @@ func TestAutoGenAddrDeprecateFromPI(t *testing.T) {
expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
}
if got := addrForNewConnection(t, s); got != addr.Address {
@@ -3094,10 +3059,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
}
if got := addrForNewConnection(t, s); got != addr.Address {
@@ -3244,10 +3207,8 @@ func TestAutoGenAddrJobDeprecation(t *testing.T) {
t.Fatalf("should not have %s in the list of addresses", addr2)
}
// Should not have any primary endpoints.
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
wq := waiter.Queue{}
we, ch := waiter.NewChannelEntry(nil)
@@ -3621,10 +3582,8 @@ func TestAutoGenAddrAfterRemoval(t *testing.T) {
expectPrimaryAddr := func(addr tcpip.AddressWithPrefix) {
t.Helper()
- if got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber); err != nil {
- t.Fatalf("s.GetMainNICAddress(%d, %d): %s", nicID, header.IPv6ProtocolNumber, err)
- } else if got != addr {
- t.Errorf("got s.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr); err != nil {
+ t.Fatal(err)
}
if got := addrForNewConnection(t, s); got != addr.Address {
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index acee72572..204196d00 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -126,7 +126,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
// packet prompting NUD/link address resolution.
//
// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry.
-func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
defer entry.mu.Unlock()
@@ -142,7 +142,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
// a node continues sending packets to that neighbor using the cached
// link-layer address."
if onResolve != nil {
- onResolve(entry.neigh.LinkAddr, true)
+ onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true})
}
return entry.neigh, nil, nil
case Unknown, Incomplete, Failed:
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index db27cbc73..dbdb51bb4 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -1188,12 +1188,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1247,12 +1244,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
t.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1423,12 +1417,9 @@ func TestNeighborCacheReplace(t *testing.T) {
t.Fatal("store.entry(0) not found")
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1539,12 +1530,9 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
// First, sanity check that resolution is working
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1576,15 +1564,9 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
entry.Addr += "2"
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1627,15 +1609,9 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
t.Fatal("store.entry(0) not found")
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1674,15 +1650,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
// Perform address resolution with a faulty link, which will fail.
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1713,9 +1683,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
// Retry address resolution with a working link.
linkRes.dropReplies = false
{
- incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1772,12 +1742,9 @@ func BenchmarkCacheClear(b *testing.B) {
b.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- b.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 75afb3001..53ac9bb6e 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -96,7 +96,7 @@ type neighborEntry struct {
done chan struct{}
// onResolve is called with the result of address resolution.
- onResolve []func(tcpip.LinkAddress, bool)
+ onResolve []func(LinkResolutionResult)
isRouter bool
job *tcpip.Job
@@ -143,13 +143,22 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
+ res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded}
for _, callback := range e.onResolve {
- callback(e.neigh.LinkAddr, succeeded)
+ callback(res)
}
e.onResolve = nil
if ch := e.done; ch != nil {
close(ch)
e.done = nil
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as writing packets may be a costly operation.
+ //
+ // At the time of writing, when writing packets, a neighbor's link address
+ // is resolved (which ends up obtaining the entry's lock) while holding the
+ // link resolution queue's lock. Dequeuing packets in a new goroutine avoids
+ // a lock ordering violation.
+ go e.nic.linkResQueue.dequeue(ch, e.neigh.LinkAddr, succeeded)
}
}
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index f2bca93d3..1bbfe6213 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -139,9 +139,9 @@ func newNIC(stack *Stack, id tcpip.NICID, name string, ep LinkEndpoint, ctx NICC
context: ctx,
stats: makeNICStats(),
networkEndpoints: make(map[tcpip.NetworkProtocolNumber]NetworkEndpoint),
- linkAddrCache: newLinkAddrCache(ageLimit, resolutionTimeout, resolutionAttempts),
}
- nic.linkResQueue.init()
+ nic.linkResQueue.init(nic)
+ nic.linkAddrCache = newLinkAddrCache(nic, ageLimit, resolutionTimeout, resolutionAttempts)
nic.mu.packetEPs = make(map[tcpip.NetworkProtocolNumber]*packetEndpointList)
// Check for Neighbor Unreachability Detection support.
@@ -303,6 +303,10 @@ func (n *NIC) IsLoopback() bool {
// WritePacket implements NetworkLinkEndpoint.
func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error {
+ _, err := n.enqueuePacketBuffer(r, gso, protocol, pkt)
+ return err
+}
+func (n *NIC) enqueuePacketBuffer(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) {
// As per relevant RFCs, we should queue packets while we wait for link
// resolution to complete.
//
@@ -320,16 +324,7 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb
// be limited to some small value. When a queue overflows, the new arrival
// SHOULD replace the oldest entry. Once address resolution completes, the
// node transmits any queued packets.
- if ch, err := r.Resolve(nil); err != nil {
- if err == tcpip.ErrWouldBlock {
- r.Acquire()
- n.linkResQueue.enqueue(ch, r, protocol, pkt)
- return nil
- }
- return err
- }
-
- return n.writePacket(r.Fields(), gso, protocol, pkt)
+ return n.linkResQueue.enqueue(r, gso, protocol, pkt)
}
// WritePacketToRemote implements NetworkInterface.
@@ -344,6 +339,9 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN
// WritePacket takes ownership of pkt, calculate numBytes first.
numBytes := pkt.Size()
+ pkt.EgressRoute = r
+ pkt.GSOOptions = gso
+ pkt.NetworkProtocolNumber = protocol
if err := n.LinkEndpoint.WritePacket(r, gso, protocol, pkt); err != nil {
return err
}
@@ -355,9 +353,17 @@ func (n *NIC) writePacket(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolN
// WritePackets implements NetworkLinkEndpoint.
func (n *NIC) WritePackets(r *Route, gso *GSO, pkts PacketBufferList, protocol tcpip.NetworkProtocolNumber) (int, *tcpip.Error) {
- // TODO(gvisor.dev/issue/4458): Queue packets whie link address resolution
- // is being peformed like WritePacket.
- writtenPackets, err := n.LinkEndpoint.WritePackets(r.Fields(), gso, pkts, protocol)
+ return n.enqueuePacketBuffer(r, gso, protocol, &pkts)
+}
+
+func (n *NIC) writePackets(r RouteInfo, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkts PacketBufferList) (int, *tcpip.Error) {
+ for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() {
+ pkt.EgressRoute = r
+ pkt.GSOOptions = gso
+ pkt.NetworkProtocolNumber = protocol
+ }
+
+ writtenPackets, err := n.LinkEndpoint.WritePackets(r, gso, pkts, protocol)
n.stats.Tx.Packets.IncrementBy(uint64(writtenPackets))
writtenBytes := 0
for i, pb := 0, pkts.Front(); i < writtenPackets && pb != nil; i, pb = i+1, pb.Next() {
@@ -555,7 +561,7 @@ func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
}
-func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if n.neigh != nil {
entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve)
return entry.LinkAddr, ch, err
diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go
index 81d8ff6e8..c4769b17e 100644
--- a/pkg/tcpip/stack/pending_packets.go
+++ b/pkg/tcpip/stack/pending_packets.go
@@ -28,119 +28,219 @@ const (
maxPendingPacketsPerResolution = 256
)
+// pendingPacketBuffer is a pending packet buffer.
+//
+// TODO(gvisor.dev/issue/5331): Drop this when we drop WritePacket and only use
+// WritePackets so we can use a PacketBufferList everywhere.
+type pendingPacketBuffer interface {
+ len() int
+}
+
+func (*PacketBuffer) len() int {
+ return 1
+}
+
+func (p *PacketBufferList) len() int {
+ return p.Len()
+}
+
type pendingPacket struct {
- route *Route
- proto tcpip.NetworkProtocolNumber
- pkt *PacketBuffer
+ routeInfo RouteInfo
+ gso *GSO
+ proto tcpip.NetworkProtocolNumber
+ pkt pendingPacketBuffer
}
// packetsPendingLinkResolution is a queue of packets pending link resolution.
//
// Once link resolution completes successfully, the packets will be written.
type packetsPendingLinkResolution struct {
- sync.Mutex
+ nic *NIC
- // The packets to send once the resolver completes.
- packets map[<-chan struct{}][]pendingPacket
+ mu struct {
+ sync.Mutex
- // FIFO of channels used to cancel the oldest goroutine waiting for
- // link-address resolution.
- cancelChans []chan struct{}
-}
+ // The packets to send once the resolver completes.
+ //
+ // The link resolution channel is used as the key for this map.
+ packets map[<-chan struct{}][]pendingPacket
-func (f *packetsPendingLinkResolution) init() {
- f.Lock()
- defer f.Unlock()
- f.packets = make(map[<-chan struct{}][]pendingPacket)
+ // FIFO of channels used to cancel the oldest goroutine waiting for
+ // link-address resolution.
+ //
+ // cancelChans holds the same channels that are used as keys to packets.
+ cancelChans []<-chan struct{}
+ }
}
-func incrementOutgoingPacketErrors(r *Route, proto tcpip.NetworkProtocolNumber) {
- r.Stats().IP.OutgoingPacketErrors.Increment()
+func (f *packetsPendingLinkResolution) incrementOutgoingPacketErrors(proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) {
+ n := uint64(pkt.len())
+ f.nic.stack.stats.IP.OutgoingPacketErrors.IncrementBy(n)
- // ok may be false if the endpoint's stats do not collect IP-related data.
- if ipEndpointStats, ok := r.outgoingNIC.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok {
- ipEndpointStats.IPStats().OutgoingPacketErrors.Increment()
+ if ipEndpointStats, ok := f.nic.getNetworkEndpoint(proto).Stats().(IPNetworkEndpointStats); ok {
+ ipEndpointStats.IPStats().OutgoingPacketErrors.IncrementBy(n)
}
}
-func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, proto tcpip.NetworkProtocolNumber, pkt *PacketBuffer) {
- f.Lock()
- defer f.Unlock()
+func (f *packetsPendingLinkResolution) init(nic *NIC) {
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ f.nic = nic
+ f.mu.packets = make(map[<-chan struct{}][]pendingPacket)
+}
- packets, ok := f.packets[ch]
- if len(packets) == maxPendingPacketsPerResolution {
- p := packets[0]
- packets[0] = pendingPacket{}
- packets = packets[1:]
+// dequeue any pending packets associated with ch.
+//
+// If success is true, packets will be written and sent to the given remote link
+// address.
+func (f *packetsPendingLinkResolution) dequeue(ch <-chan struct{}, linkAddr tcpip.LinkAddress, success bool) {
+ f.mu.Lock()
+ packets, ok := f.mu.packets[ch]
+ delete(f.mu.packets, ch)
- incrementOutgoingPacketErrors(r, proto)
+ if ok {
+ for i, cancelChan := range f.mu.cancelChans {
+ if cancelChan == ch {
+ f.mu.cancelChans = append(f.mu.cancelChans[:i], f.mu.cancelChans[i+1:]...)
+ break
+ }
+ }
+ }
+
+ f.mu.Unlock()
+
+ if ok {
+ f.dequeuePackets(packets, linkAddr, success)
+ }
+}
- p.route.Release()
+func (f *packetsPendingLinkResolution) writePacketBuffer(r RouteInfo, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) {
+ switch pkt := pkt.(type) {
+ case *PacketBuffer:
+ if err := f.nic.writePacket(r, gso, proto, pkt); err != nil {
+ return 0, err
+ }
+ return 1, nil
+ case *PacketBufferList:
+ return f.nic.writePackets(r, gso, proto, *pkt)
+ default:
+ panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", pkt))
}
+}
- if l := len(packets); l >= maxPendingPacketsPerResolution {
- panic(fmt.Sprintf("max pending packets for resolution reached; got %d packets, max = %d", l, maxPendingPacketsPerResolution))
+// enqueue a packet to be sent once link resolution completes.
+//
+// If the maximum number of pending resolutions is reached, the packets
+// associated with the oldest link resolution will be dequeued as if they failed
+// link resolution.
+func (f *packetsPendingLinkResolution) enqueue(r *Route, gso *GSO, proto tcpip.NetworkProtocolNumber, pkt pendingPacketBuffer) (int, *tcpip.Error) {
+ f.mu.Lock()
+ // Make sure we attempt resolution while holding f's lock so that we avoid
+ // a race where link resolution completes before we enqueue the packets.
+ //
+ // A @ T1: Call ResolvedFields (get link resolution channel)
+ // B @ T2: Complete link resolution, dequeue pending packets
+ // C @ T1: Enqueue packet that already completed link resolution (which will
+ // never dequeue)
+ //
+ // To make sure B does not interleave with A and C, we make sure A and C are
+ // done while holding the lock.
+ routeInfo, ch, err := r.resolvedFields(nil)
+ switch err {
+ case nil:
+ // The route resolved immediately, so we don't need to wait for link
+ // resolution to send the packet.
+ f.mu.Unlock()
+ return f.writePacketBuffer(routeInfo, gso, proto, pkt)
+ case tcpip.ErrWouldBlock:
+ // We need to wait for link resolution to complete.
+ default:
+ f.mu.Unlock()
+ return 0, err
}
- f.packets[ch] = append(packets, pendingPacket{
- route: r,
- proto: proto,
- pkt: pkt,
+ defer f.mu.Unlock()
+
+ packets, ok := f.mu.packets[ch]
+ packets = append(packets, pendingPacket{
+ routeInfo: routeInfo,
+ gso: gso,
+ proto: proto,
+ pkt: pkt,
})
- if ok {
- return
- }
+ if len(packets) > maxPendingPacketsPerResolution {
+ f.incrementOutgoingPacketErrors(packets[0].proto, packets[0].pkt)
+ packets[0] = pendingPacket{}
+ packets = packets[1:]
- // Wait for the link-address resolution to complete.
- cancel := f.newCancelChannelLocked()
- go func() {
- cancelled := false
- select {
- case <-ch:
- case <-cancel:
- cancelled = true
+ if numPackets := len(packets); numPackets != maxPendingPacketsPerResolution {
+ panic(fmt.Sprintf("holding more queued packets than expected; got = %d, want <= %d", numPackets, maxPendingPacketsPerResolution))
}
+ }
- f.Lock()
- packets, ok := f.packets[ch]
- delete(f.packets, ch)
- f.Unlock()
+ f.mu.packets[ch] = packets
- if !ok {
- panic(fmt.Sprintf("link-resolution goroutine woke up but no entry exists in the queue of packets"))
- }
+ if ok {
+ return pkt.len(), nil
+ }
- for _, p := range packets {
- if cancelled || p.route.IsResolutionRequired() {
- incrementOutgoingPacketErrors(r, proto)
+ cancelledPackets := f.newCancelChannelLocked(ch)
- if linkResolvableEP, ok := p.route.outgoingNIC.getNetworkEndpoint(p.route.NetProto).(LinkResolvableNetworkEndpoint); ok {
- linkResolvableEP.HandleLinkResolutionFailure(pkt)
- }
- } else {
- p.route.outgoingNIC.writePacket(p.route.Fields(), nil /* gso */, p.proto, p.pkt)
- }
- p.route.Release()
- }
- }()
+ if len(cancelledPackets) != 0 {
+ // Dequeue the pending packets in a new goroutine to not hold up the current
+ // goroutine as handing link resolution failures may be a costly operation.
+ go f.dequeuePackets(cancelledPackets, "" /* linkAddr */, false /* success */)
+ }
+
+ return pkt.len(), nil
}
-// newCancelChannel creates a channel that can cancel a pending forwarding
-// activity. The oldest channel is closed if the number of open channels would
-// exceed maxPendingResolutions.
-func (f *packetsPendingLinkResolution) newCancelChannelLocked() chan struct{} {
- if len(f.cancelChans) == maxPendingResolutions {
- ch := f.cancelChans[0]
- f.cancelChans[0] = nil
- f.cancelChans = f.cancelChans[1:]
- close(ch)
+// newCancelChannelLocked appends the link resolution channel to a FIFO. If the
+// maximum number of pending resolutions is reached, the oldest channel will be
+// removed and its associated pending packets will be returned.
+func (f *packetsPendingLinkResolution) newCancelChannelLocked(newCH <-chan struct{}) []pendingPacket {
+ f.mu.cancelChans = append(f.mu.cancelChans, newCH)
+ if len(f.mu.cancelChans) <= maxPendingResolutions {
+ return nil
}
- if l := len(f.cancelChans); l >= maxPendingResolutions {
+
+ ch := f.mu.cancelChans[0]
+ f.mu.cancelChans[0] = nil
+ f.mu.cancelChans = f.mu.cancelChans[1:]
+ if l := len(f.mu.cancelChans); l > maxPendingResolutions {
panic(fmt.Sprintf("max pending resolutions reached; got %d active resolutions, max = %d", l, maxPendingResolutions))
}
- ch := make(chan struct{})
- f.cancelChans = append(f.cancelChans, ch)
- return ch
+ packets, ok := f.mu.packets[ch]
+ if !ok {
+ panic("must have a packet queue for an uncancelled channel")
+ }
+ delete(f.mu.packets, ch)
+
+ return packets
+}
+
+func (f *packetsPendingLinkResolution) dequeuePackets(packets []pendingPacket, linkAddr tcpip.LinkAddress, success bool) {
+ for _, p := range packets {
+ if success {
+ p.routeInfo.RemoteLinkAddress = linkAddr
+ _, _ = f.writePacketBuffer(p.routeInfo, p.gso, p.proto, p.pkt)
+ } else {
+ f.incrementOutgoingPacketErrors(p.proto, p.pkt)
+
+ if linkResolvableEP, ok := f.nic.getNetworkEndpoint(p.proto).(LinkResolvableNetworkEndpoint); ok {
+ switch pkt := p.pkt.(type) {
+ case *PacketBuffer:
+ linkResolvableEP.HandleLinkResolutionFailure(pkt)
+ case *PacketBufferList:
+ for pb := pkt.Front(); pb != nil; pb = pb.Next() {
+ linkResolvableEP.HandleLinkResolutionFailure(pb)
+ }
+ default:
+ panic(fmt.Sprintf("unrecognized pending packet buffer type = %T", p.pkt))
+ }
+ }
+ }
+ }
}
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 1ff7b3a37..d9a8554e2 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -86,12 +86,21 @@ type RouteInfo struct {
RemoteLinkAddress tcpip.LinkAddress
}
-// Fields returns a RouteInfo with all of r's exported fields. This allows
-// callers to store the route's fields without retaining a reference to it.
+// Fields returns a RouteInfo with all of the known values for the route's
+// fields.
+//
+// If any fields are unknown (e.g. remote link address when it is waiting for
+// link address resolution), they will be unset.
func (r *Route) Fields() RouteInfo {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
+ return r.fieldsLocked()
+}
+
+func (r *Route) fieldsLocked() RouteInfo {
return RouteInfo{
routeInfo: r.routeInfo,
- RemoteLinkAddress: r.RemoteLinkAddress(),
+ RemoteLinkAddress: r.mu.remoteLinkAddress,
}
}
@@ -306,29 +315,45 @@ func (r *Route) ResolveWith(addr tcpip.LinkAddress) {
r.mu.remoteLinkAddress = addr
}
-// Resolve attempts to resolve the link address if necessary.
+// ResolvedFieldsResult is the result of a route resolution attempt.
+type ResolvedFieldsResult struct {
+ RouteInfo RouteInfo
+ Success bool
+}
+
+// ResolvedFields attempts to resolve the remote link address if it is not
+// known.
//
-// Returns tcpip.ErrWouldBlock if address resolution requires blocking (e.g.
-// waiting for ARP reply). If address resolution is required, a notification
-// channel is also returned for the caller to block on. The channel is closed
-// once address resolution is complete (successful or not). If a callback is
-// provided, it will be called when address resolution is complete, regardless
+// If a callback is provided, it will be called before ResolvedFields returns
+// when address resolution is not required. If address resolution is required,
+// the callback will be called once address resolution is complete, regardless
// of success or failure.
-func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
- r.mu.Lock()
+//
+// Note, the route will not cache the remote link address when address
+// resolution completes.
+func (r *Route) ResolvedFields(afterResolve func(ResolvedFieldsResult)) *tcpip.Error {
+ _, _, err := r.resolvedFields(afterResolve)
+ return err
+}
- if !r.isResolutionRequiredRLocked() {
- // Nothing to do if there is no cache (which does the resolution on cache miss) or
- // link address is already known.
- r.mu.Unlock()
- return nil, nil
+// resolvedFields is like ResolvedFields but also returns a notification channel
+// when address resolution is required. This channel will become readable once
+// address resolution is complete.
+//
+// The route's fields will also be returned, regardless of whether address
+// resolution is required or not.
+func (r *Route) resolvedFields(afterResolve func(ResolvedFieldsResult)) (RouteInfo, <-chan struct{}, *tcpip.Error) {
+ r.mu.RLock()
+ fields := r.fieldsLocked()
+ resolutionRequired := r.isResolutionRequiredRLocked()
+ r.mu.RUnlock()
+ if !resolutionRequired {
+ if afterResolve != nil {
+ afterResolve(ResolvedFieldsResult{RouteInfo: fields, Success: true})
+ }
+ return fields, nil, nil
}
- // Increment the route's reference count because finishResolution retains a
- // reference to the route and releases it when called.
- r.acquireLocked()
- r.mu.Unlock()
-
nextAddr := r.NextHop
if nextAddr == "" {
nextAddr = r.RemoteAddress
@@ -341,18 +366,20 @@ func (r *Route) Resolve(afterResolve func()) (<-chan struct{}, *tcpip.Error) {
linkAddressResolutionRequestLocalAddr = r.LocalAddress
}
- finishResolution := func(linkAddress tcpip.LinkAddress, ok bool) {
- if ok {
- r.ResolveWith(linkAddress)
- }
+ afterResolveFields := fields
+ linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, func(r LinkResolutionResult) {
if afterResolve != nil {
- afterResolve()
+ if r.Success {
+ afterResolveFields.RemoteLinkAddress = r.LinkAddress
+ }
+
+ afterResolve(ResolvedFieldsResult{RouteInfo: afterResolveFields, Success: r.Success})
}
- r.Release()
+ })
+ if err == nil {
+ fields.RemoteLinkAddress = linkAddr
}
-
- _, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, finishResolution)
- return ch, err
+ return fields, ch, err
}
// local returns true if the route is a local route.
@@ -371,11 +398,7 @@ func (r *Route) IsResolutionRequired() bool {
}
func (r *Route) isResolutionRequiredRLocked() bool {
- if !r.isValidForOutgoingRLocked() || r.mu.remoteLinkAddress != "" || r.local() {
- return false
- }
-
- return r.linkRes != nil
+ return len(r.mu.remoteLinkAddress) == 0 && r.linkRes != nil && r.isValidForOutgoingRLocked() && !r.local()
}
func (r *Route) isValidForOutgoing() bool {
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 4685fa4cf..e9c5db4c3 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -444,7 +444,7 @@ type Stack struct {
// sendBufferSize holds the min/default/max send buffer sizes for
// endpoints other than TCP.
- sendBufferSize SendBufferSizeOption
+ sendBufferSize tcpip.SendBufferSizeOption
// receiveBufferSize holds the min/default/max receive buffer sizes for
// endpoints other than TCP.
@@ -646,7 +646,7 @@ func New(opts Options) *Stack {
uniqueIDGenerator: opts.UniqueID,
nudDisp: opts.NUDDisp,
randomGenerator: mathrand.New(randSrc),
- sendBufferSize: SendBufferSizeOption{
+ sendBufferSize: tcpip.SendBufferSizeOption{
Min: MinBufferSize,
Default: DefaultBufferSize,
Max: DefaultMaxBufferSize,
@@ -1196,19 +1196,19 @@ func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress {
// GetMainNICAddress returns the first non-deprecated primary address and prefix
// for the given NIC and protocol. If no non-deprecated primary address exists,
-// a deprecated primary address and prefix will be returned. Returns an error if
+// a deprecated primary address and prefix will be returned. Returns false if
// the NIC doesn't exist and an empty value if the NIC doesn't have a primary
// address for the given protocol.
-func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) {
+func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, bool) {
s.mu.RLock()
defer s.mu.RUnlock()
nic, ok := s.nics[id]
if !ok {
- return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID
+ return tcpip.AddressWithPrefix{}, false
}
- return nic.primaryAddress(protocol), nil
+ return nic.primaryAddress(protocol), true
}
func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) AssignableAddressEndpoint {
@@ -1527,9 +1527,13 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAd
return nil
}
-// GetLinkAddress finds the link address corresponding to a neighbor's address.
-//
-// Returns a link address for the remote address, if readily available.
+// LinkResolutionResult is the result of a link address resolution attempt.
+type LinkResolutionResult struct {
+ LinkAddress tcpip.LinkAddress
+ Success bool
+}
+
+// GetLinkAddress finds the link address corresponding to a network address.
//
// Returns ErrNotSupported if the stack is not configured with a link address
// resolver for the specified network protocol.
@@ -1538,30 +1542,33 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAd
// with a notification channel for the caller to block on. Triggers address
// resolution asynchronously.
//
-// If onResolve is provided, it will be called either immediately, if
-// resolution is not required, or when address resolution is complete, with
-// the resolved link address and whether resolution succeeded. After any
-// callbacks have been called, the returned notification channel is closed.
+// onResolve will be called either immediately, if resolution is not required,
+// or when address resolution is complete, with the resolved link address and
+// whether resolution succeeded.
//
// If specified, the local address must be an address local to the interface
// the neighbor cache belongs to. The local address is the source address of
// a packet prompting NUD/link address resolution.
-//
-// TODO(gvisor.dev/issue/5151): Don't return the link address.
-func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) *tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return "", nil, tcpip.ErrUnknownNICID
+ return tcpip.ErrUnknownNICID
}
linkRes, ok := s.linkAddrResolvers[protocol]
if !ok {
- return "", nil, tcpip.ErrNotSupported
+ return tcpip.ErrNotSupported
+ }
+
+ if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok {
+ onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true})
+ return nil
}
- return nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
+ _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
+ return err
}
// Neighbors returns all IP to MAC address associations.
@@ -1622,25 +1629,25 @@ func (s *Stack) ClearNeighbors(nicID tcpip.NICID) *tcpip.Error {
// transport dispatcher. Received packets that match the provided id will be
// delivered to the given endpoint; specifying a nic is optional, but
// nic-specific IDs have precedence over global ones.
-func (s *Stack) RegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (s *Stack) RegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
return s.demux.registerEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// CheckRegisterTransportEndpoint checks if an endpoint can be registered with
// the stack transport dispatcher.
-func (s *Stack) CheckRegisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
+func (s *Stack) CheckRegisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, flags ports.Flags, bindToDevice tcpip.NICID) *tcpip.Error {
return s.demux.checkEndpoint(netProtos, protocol, id, flags, bindToDevice)
}
// UnregisterTransportEndpoint removes the endpoint with the given id from the
// stack transport dispatcher.
-func (s *Stack) UnregisterTransportEndpoint(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+func (s *Stack) UnregisterTransportEndpoint(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
s.demux.unregisterEndpoint(netProtos, protocol, id, ep, flags, bindToDevice)
}
// StartTransportEndpointCleanup removes the endpoint with the given id from
// the stack transport dispatcher. It also transitions it to the cleanup stage.
-func (s *Stack) StartTransportEndpointCleanup(nicID tcpip.NICID, netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
+func (s *Stack) StartTransportEndpointCleanup(netProtos []tcpip.NetworkProtocolNumber, protocol tcpip.TransportProtocolNumber, id TransportEndpointID, ep TransportEndpoint, flags ports.Flags, bindToDevice tcpip.NICID) {
s.cleanupEndpointsMu.Lock()
s.cleanupEndpoints[ep] = struct{}{}
s.cleanupEndpointsMu.Unlock()
@@ -1665,13 +1672,13 @@ func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, tran
// RegisterRawTransportEndpoint registers the given endpoint with the stack
// transport dispatcher. Received packets that match the provided transport
// protocol will be delivered to the given endpoint.
-func (s *Stack) RegisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
+func (s *Stack) RegisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) *tcpip.Error {
return s.demux.registerRawEndpoint(netProto, transProto, ep)
}
// UnregisterRawTransportEndpoint removes the endpoint for the transport
// protocol from the stack transport dispatcher.
-func (s *Stack) UnregisterRawTransportEndpoint(nicID tcpip.NICID, netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
+func (s *Stack) UnregisterRawTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, ep RawTransportEndpoint) {
s.demux.unregisterRawEndpoint(netProto, transProto, ep)
}
diff --git a/pkg/tcpip/stack/stack_options.go b/pkg/tcpip/stack/stack_options.go
index 0b093e6c5..92e70f94e 100644
--- a/pkg/tcpip/stack/stack_options.go
+++ b/pkg/tcpip/stack/stack_options.go
@@ -14,7 +14,9 @@
package stack
-import "gvisor.dev/gvisor/pkg/tcpip"
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
const (
// MinBufferSize is the smallest size of a receive or send buffer.
@@ -29,14 +31,6 @@ const (
DefaultMaxBufferSize = 4 << 20 // 4 MiB
)
-// SendBufferSizeOption is used by stack.(Stack*).Option/SetOption to
-// get/set the default, min and max send buffer sizes.
-type SendBufferSizeOption struct {
- Min int
- Default int
- Max int
-}
-
// ReceiveBufferSizeOption is used by stack.(Stack*).Option/SetOption to
// get/set the default, min and max receive buffer sizes.
type ReceiveBufferSizeOption struct {
@@ -48,7 +42,7 @@ type ReceiveBufferSizeOption struct {
// SetOption allows setting stack wide options.
func (s *Stack) SetOption(option interface{}) *tcpip.Error {
switch v := option.(type) {
- case SendBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
// Make sure we don't allow lowering the buffer below minimum
// required for stack to work.
if v.Min < MinBufferSize {
@@ -88,7 +82,7 @@ func (s *Stack) SetOption(option interface{}) *tcpip.Error {
// Option allows retrieving stack wide options.
func (s *Stack) Option(option interface{}) *tcpip.Error {
switch v := option.(type) {
- case *SendBufferSizeOption:
+ case *tcpip.SendBufferSizeOption:
s.mu.RLock()
*v = s.sendBufferSize
s.mu.RUnlock()
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index b9ef455e5..0f02f1d53 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -60,6 +60,15 @@ const (
protocolNumberOffset = 2
)
+func checkGetMainNICAddress(s *stack.Stack, nicID tcpip.NICID, proto tcpip.NetworkProtocolNumber, want tcpip.AddressWithPrefix) error {
+ if addr, ok := s.GetMainNICAddress(nicID, proto); !ok {
+ return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (_, false), want = (_, true)", nicID, proto)
+ } else if addr != want {
+ return fmt.Errorf("got stack.GetMainNICAddress(%d, %d) = (%s, true), want = (%s, true)", nicID, proto, addr, want)
+ }
+ return nil
+}
+
// fakeNetworkEndpoint is a network-layer protocol endpoint. It counts sent and
// received packets; the counts of all endpoints are aggregated in the protocol
// descriptor.
@@ -1873,20 +1882,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) {
// Check that GetMainNICAddress returns an address if at least
// one primary address was added. In that case make sure the
// address/prefixLen matches what we added.
- gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
+ gotAddr, ok := s.GetMainNICAddress(1, fakeNetNumber)
+ if !ok {
+ t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
}
if len(primaryAddrAdded) == 0 {
// No primary addresses present.
if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, wantAddr)
}
} else {
// At least one primary address was added, verify the returned
// address is in the list of primary addresses we added.
if _, ok := primaryAddrAdded[gotAddr]; !ok {
- t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, gotAddr, primaryAddrAdded)
}
}
})
@@ -1927,12 +1936,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr {
- t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
+ if err := checkGetMainNICAddress(s, 1, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil {
@@ -1940,12 +1945,8 @@ func TestGetMainNICAddressAddRemove(t *testing.T) {
}
// Check that we get no address after removal.
- gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("GetMainNICAddress failed:", err)
- }
- if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr {
- t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr)
+ if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2486,12 +2487,12 @@ func TestNICAutoGenLinkLocalAddr(t *testing.T) {
}
}
- gotMainAddr, err := s.GetMainNICAddress(1, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(_, _) err = %s", err)
+ // Check that we get no address after removal.
+ if err := checkGetMainNICAddress(s, 1, fakeNetNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
- if gotMainAddr != expectedMainAddr {
- t.Fatalf("got stack.GetMainNICAddress(_, _) = %s, want = %s", gotMainAddr, expectedMainAddr)
+ if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, expectedMainAddr); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2537,12 +2538,8 @@ func TestNoLinkLocalAutoGenForLoopbackNIC(t *testing.T) {
t.Fatalf("CreateNICWithOptions(%d, _, %+v) = %s", nicID, nicOpts, err)
}
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("stack.GetMainNICAddress(%d, _) err = %s", nicID, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Errorf("got stack.GetMainNICAddress(%d, _) = %s, want = %s", nicID, addr, want)
+ if err := checkGetMainNICAddress(s, 1, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
})
}
@@ -2573,12 +2570,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
// Address should not be considered bound to the
// NIC yet (DAD ongoing).
- addr, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
linkLocalAddr := header.LinkLocalAddr(linkAddr1)
@@ -2596,12 +2589,8 @@ func TestNICAutoGenAddrDoesDAD(t *testing.T) {
t.Errorf("dad event mismatch (-want +got):\n%s", diff)
}
}
- addr, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); addr != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, addr, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{Address: linkLocalAddr, PrefixLen: header.IPv6LinkLocalPrefix.PrefixLen}); err != nil {
+ t.Fatal(err)
}
}
@@ -2633,17 +2622,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
if err := s.AddAddressWithOptions(1, fakeNetNumber, "\x01", pi); err != nil {
t.Fatal("AddAddressWithOptions failed:", err)
}
- addr, err := s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatal("s.GetMainNICAddress failed:", err)
+ addr, ok := s.GetMainNICAddress(1, fakeNetNumber)
+ if !ok {
+ t.Fatalf("GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
}
if pi == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want)
}
} else if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address)
}
{
@@ -2722,18 +2711,17 @@ func TestNewPEBOnPromotionToPermanent(t *testing.T) {
if err := s.RemoveAddress(1, "\x03"); err != nil {
t.Fatalf("RemoveAddress failed: %v", err)
}
- addr, err = s.GetMainNICAddress(1, fakeNetNumber)
- if err != nil {
- t.Fatalf("s.GetMainNICAddress failed: %v", err)
+ addr, ok = s.GetMainNICAddress(1, fakeNetNumber)
+ if !ok {
+ t.Fatalf("got GetMainNICAddress(1, %d) = (_, false), want = (_, true)", fakeNetNumber)
}
if ps == stack.NeverPrimaryEndpoint {
if want := (tcpip.AddressWithPrefix{}); addr != want {
- t.Fatalf("got GetMainNICAddress = %s, want = %s", addr, want)
-
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (%s, true)", fakeNetNumber, addr, want)
}
} else {
if addr.Address != "\x01" {
- t.Fatalf("got GetMainNICAddress = %s, want = 1", addr.Address)
+ t.Fatalf("got GetMainNICAddress(1, %d) = (%s, true), want = (1, true)", fakeNetNumber, addr.Address)
}
}
})
@@ -3259,12 +3247,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
}
// Address should be tentative so it should not be a main address.
- got, err := s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Enabling the NIC should start DAD for the address.
@@ -3276,12 +3260,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
}
// Address should not be considered bound to the NIC yet (DAD ongoing).
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if want := (tcpip.AddressWithPrefix{}); got != want {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, want)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, tcpip.AddressWithPrefix{}); err != nil {
+ t.Fatal(err)
}
// Wait for DAD to resolve.
@@ -3296,12 +3276,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
}
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if got != addr.AddressWithPrefix {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = %s, want = %s", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
// Enabling the NIC again should be a no-op.
@@ -3311,12 +3287,8 @@ func TestDoDADWhenNICEnabled(t *testing.T) {
if addrs := s.AllAddresses()[nicID]; !containsV6Addr(addrs, addr.AddressWithPrefix) {
t.Fatalf("got s.AllAddresses()[%d] = %+v, want = %+v", nicID, addrs, addr)
}
- got, err = s.GetMainNICAddress(nicID, header.IPv6ProtocolNumber)
- if err != nil {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (_, %v), want = (_, nil)", nicID, header.IPv6ProtocolNumber, err)
- }
- if got != addr.AddressWithPrefix {
- t.Fatalf("got stack.GetMainNICAddress(%d, %d) = (%s, nil), want = (%s, nil)", nicID, header.IPv6ProtocolNumber, got, addr.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, header.IPv6ProtocolNumber, addr.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
}
@@ -3364,21 +3336,21 @@ func TestStackSendBufferSizeOption(t *testing.T) {
const sMin = stack.MinBufferSize
testCases := []struct {
name string
- ss stack.SendBufferSizeOption
+ ss tcpip.SendBufferSizeOption
err *tcpip.Error
}{
// Invalid configurations.
- {"min_below_zero", stack.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"min_zero", stack.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"default_below_min", stack.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
- {"default_above_max", stack.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
- {"max_below_min", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"min_below_zero", tcpip.SendBufferSizeOption{Min: -1, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"min_zero", tcpip.SendBufferSizeOption{Min: 0, Default: sMin, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"default_below_min", tcpip.SendBufferSizeOption{Min: 0, Default: sMin - 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
+ {"default_above_max", tcpip.SendBufferSizeOption{Min: 0, Default: sMin + 1, Max: sMin}, tcpip.ErrInvalidOptionValue},
+ {"max_below_min", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin - 1}, tcpip.ErrInvalidOptionValue},
// Valid Configurations
- {"in_ascending_order", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
- {"all_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
- {"min_default_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
- {"default_max_equal", stack.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
+ {"in_ascending_order", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 2}, nil},
+ {"all_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin}, nil},
+ {"min_default_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin, Max: sMin + 1}, nil},
+ {"default_max_equal", tcpip.SendBufferSizeOption{Min: sMin, Default: sMin + 1, Max: sMin + 1}, nil},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
@@ -3387,7 +3359,7 @@ func TestStackSendBufferSizeOption(t *testing.T) {
if err := s.SetOption(tc.ss); err != tc.err {
t.Fatalf("s.SetOption(%+v) = %v, want: %v", tc.ss, err, tc.err)
}
- var ss stack.SendBufferSizeOption
+ var ss tcpip.SendBufferSizeOption
if tc.err == nil {
if err := s.Option(&ss); err != nil {
t.Fatalf("s.Option(%+v) = %v, want: nil", ss, err)
@@ -3790,20 +3762,16 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) {
}
// Check that we get the right initial address and prefix length.
- if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
- t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
- } else if gotAddr != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
// Should still get the address when the NIC is diabled.
if err := s.DisableNIC(nicID); err != nil {
t.Fatalf("DisableNIC(%d): %s", nicID, err)
}
- if gotAddr, err := s.GetMainNICAddress(nicID, fakeNetNumber); err != nil {
- t.Fatalf("GetMainNICAddress(%d, %d): %s", nicID, fakeNetNumber, err)
- } else if gotAddr != protocolAddress.AddressWithPrefix {
- t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix)
+ if err := checkGetMainNICAddress(s, nicID, fakeNetNumber, protocolAddress.AddressWithPrefix); err != nil {
+ t.Fatal(err)
}
}
@@ -4384,10 +4352,58 @@ func TestGetLinkAddressErrors(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if addr, _, err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID {
- t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrUnknownNICID)
+ if err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID {
+ t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrUnknownNICID)
+ }
+ if err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported {
+ t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrNotSupported)
+ }
+}
+
+func TestStaticGetLinkAddress(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.Address
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4",
+ proto: ipv4.ProtocolNumber,
+ addr: header.IPv4Broadcast,
+ expectedLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "IPv6",
+ proto: ipv6.ProtocolNumber,
+ addr: header.IPv6AllNodesMulticastAddress,
+ expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
+ },
}
- if addr, _, err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported {
- t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrNotSupported)
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ch := make(chan stack.LinkResolutionResult, 1)
+ if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ }); err != nil {
+ t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err)
+ }
+
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: true}, <-ch); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+ }
+ })
}
}
diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go
index 57e1f8354..de4b5fbdc 100644
--- a/pkg/tcpip/stack/transport_demuxer_test.go
+++ b/pkg/tcpip/stack/transport_demuxer_test.go
@@ -194,7 +194,7 @@ func TestTransportDemuxerRegister(t *testing.T) {
if !ok {
t.Fatalf("%T does not implement stack.TransportEndpoint", ep)
}
- if got, want := s.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
+ if got, want := s.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{test.proto}, udp.ProtocolNumber, stack.TransportEndpointID{}, tEP, ports.Flags{}, 0), test.want; got != want {
t.Fatalf("s.RegisterTransportEndpoint(...) = %s, want %s", got, want)
}
})
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 9d39533a1..c49427c4c 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -15,6 +15,7 @@
package stack_test
import (
+ "bytes"
"io"
"testing"
@@ -67,9 +68,9 @@ func (f *fakeTransportEndpoint) SocketOptions() *tcpip.SocketOptions {
return &f.ops
}
-func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
- ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
- ep.ops.InitHandler(ep)
+func newFakeTransportEndpoint(proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, s *stack.Stack) tcpip.Endpoint {
+ ep := &fakeTransportEndpoint{TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: s.UniqueID()}
+ ep.ops.InitHandler(ep, s, tcpip.GetStackSendBufferLimits)
return ep
}
@@ -95,10 +96,11 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return 0, tcpip.ErrNoRoute
}
- v, err := p.FullPayload()
- if err != nil {
- return 0, err
+ v := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, v); err != nil {
+ return 0, tcpip.ErrBadBuffer
}
+
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
ReserveHeaderBytes: int(f.route.MaxHeaderLength()) + fakeTransHeaderLen,
Data: buffer.View(v).ToVectorisedView(),
@@ -147,7 +149,7 @@ func (f *fakeTransportEndpoint) Connect(addr tcpip.FullAddress) *tcpip.Error {
// Try to register so that we can start receiving packets.
f.ID.RemoteAddress = addr.Addr
- err = f.proto.stack.RegisterTransportEndpoint(0, []tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
+ err = f.proto.stack.RegisterTransportEndpoint([]tcpip.NetworkProtocolNumber{fakeNetNumber}, fakeTransNumber, f.ID, f, ports.Flags{}, 0 /* bindToDevice */)
if err != nil {
r.Release()
return err
@@ -188,7 +190,6 @@ func (f *fakeTransportEndpoint) Accept(*tcpip.FullAddress) (tcpip.Endpoint, *wai
func (f *fakeTransportEndpoint) Bind(a tcpip.FullAddress) *tcpip.Error {
if err := f.proto.stack.RegisterTransportEndpoint(
- a.NIC,
[]tcpip.NetworkProtocolNumber{fakeNetNumber},
fakeTransNumber,
stack.TransportEndpointID{LocalAddress: a.Addr},
@@ -232,7 +233,7 @@ func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *
peerAddr: route.RemoteAddress,
route: route,
}
- ep.ops.InitHandler(ep)
+ ep.ops.InitHandler(ep, f.proto.stack, tcpip.GetStackSendBufferLimits)
f.acceptQueue = append(f.acceptQueue, ep)
}
@@ -280,7 +281,7 @@ func (*fakeTransportProtocol) Number() tcpip.TransportProtocolNumber {
}
func (f *fakeTransportProtocol) NewEndpoint(netProto tcpip.NetworkProtocolNumber, _ *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
- return newFakeTransportEndpoint(f, netProto, f.stack.UniqueID()), nil
+ return newFakeTransportEndpoint(f, netProto, f.stack), nil
}
func (*fakeTransportProtocol) NewRawEndpoint(tcpip.NetworkProtocolNumber, *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) {
@@ -520,8 +521,10 @@ func TestTransportSend(t *testing.T) {
}
// Create buffer that will hold the payload.
- view := buffer.NewView(30)
- if _, err := ep.Write(tcpip.SlicePayload(view), tcpip.WriteOptions{}); err != nil {
+ b := make([]byte, 30)
+ var r bytes.Reader
+ r.Reset(b)
+ if _, err := ep.Write(&r, tcpip.WriteOptions{}); err != nil {
t.Fatalf("write failed: %v", err)
}