summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-01-21 19:53:31 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-21 19:55:37 -0800
commit9f46328e1174be6b8b5442467050ad0b2f0b260f (patch)
tree3d641b3fe82f01ada7bf3c0f49d2efc81b00bcae /pkg/tcpip
parent8ecff1890277820972c5f5287539a824b22a1d60 (diff)
Only use callback for GetLinkAddress
GetLinkAddress's callback will be called immediately with a stack.LinkResolutionResult which will hold the link address so no need to also return the link address from the function. Fixes #5151. PiperOrigin-RevId: 353157857
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/network/ipv6/ndp_test.go99
-rw-r--r--pkg/tcpip/stack/linkaddrcache.go9
-rw-r--r--pkg/tcpip/stack/neighbor_cache.go4
-rw-r--r--pkg/tcpip/stack/neighbor_cache_test.go87
-rw-r--r--pkg/tcpip/stack/neighbor_entry.go5
-rw-r--r--pkg/tcpip/stack/nic.go2
-rw-r--r--pkg/tcpip/stack/route.go2
-rw-r--r--pkg/tcpip/stack/stack.go26
-rw-r--r--pkg/tcpip/stack/stack_test.go8
-rw-r--r--pkg/tcpip/tests/integration/link_resolution_test.go56
10 files changed, 134 insertions, 164 deletions
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go
index b1a5a5510..aed3042d1 100644
--- a/pkg/tcpip/network/ipv6/ndp_test.go
+++ b/pkg/tcpip/network/ipv6/ndp_test.go
@@ -162,6 +162,11 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) {
}
}
+type linkResolutionResult struct {
+ linkAddr tcpip.LinkAddress
+ ok bool
+}
+
// TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a
// valid NDP NS message with the Source Link Layer Address option results in a
// new entry in the link address cache for the sender of the message.
@@ -231,35 +236,28 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
}))
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if linkAddr != test.expectedLinkAddr {
- t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
- }
-
- if test.expectedLinkAddr != "" {
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
- } else {
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
- }
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
- }
+ wantInvalid := uint64(0)
+ wantErr := (*tcpip.Error)(nil)
+ wantSucccess := true
+ if len(test.expectedLinkAddr) == 0 {
+ wantInvalid = 1
+ wantErr = tcpip.ErrWouldBlock
+ wantSucccess = false
+ }
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
+ if err != wantErr {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, wantErr)
+ }
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" {
+ t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff)
+ }
+ if got := invalid.Value(); got != wantInvalid {
+ t.Errorf("got invalid = %d, want = %d", got, wantInvalid)
}
})
}
@@ -803,35 +801,28 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) {
Data: hdr.View().ToVectorisedView(),
}))
- linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil)
- if linkAddr != test.expectedLinkAddr {
- t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr)
- }
-
- if test.expectedLinkAddr != "" {
- if err != nil {
- t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err)
- }
- if c != nil {
- t.Errorf("got unexpected channel")
- }
+ ch := make(chan stack.LinkResolutionResult, 1)
+ err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) {
+ ch <- r
+ })
- // Invalid count should not have increased.
- if got := invalid.Value(); got != 0 {
- t.Errorf("got invalid = %d, want = 0", got)
- }
- } else {
- if err != tcpip.ErrWouldBlock {
- t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock)
- }
- if c == nil {
- t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber)
- }
+ wantInvalid := uint64(0)
+ wantErr := (*tcpip.Error)(nil)
+ wantSucccess := true
+ if len(test.expectedLinkAddr) == 0 {
+ wantInvalid = 1
+ wantErr = tcpip.ErrWouldBlock
+ wantSucccess = false
+ }
- // Invalid count should have increased.
- if got := invalid.Value(); got != 1 {
- t.Errorf("got invalid = %d, want = 1", got)
- }
+ if err != wantErr {
+ t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, wantErr)
+ }
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" {
+ t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff)
+ }
+ if got := invalid.Value(); got != wantInvalid {
+ t.Errorf("got invalid = %d, want = %d", got, wantInvalid)
}
})
}
diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go
index f116f8417..ba6d56a7d 100644
--- a/pkg/tcpip/stack/linkaddrcache.go
+++ b/pkg/tcpip/stack/linkaddrcache.go
@@ -97,12 +97,13 @@ type linkAddrEntry struct {
done chan struct{}
// onResolve is called with the result of address resolution.
- onResolve []func(tcpip.LinkAddress, bool)
+ onResolve []func(LinkResolutionResult)
}
func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) {
+ res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0}
for _, callback := range e.onResolve {
- callback(linkAddr, len(linkAddr) != 0)
+ callback(res)
}
e.onResolve = nil
if ch := e.done; ch != nil {
@@ -196,7 +197,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry {
}
// get reports any known link address for k.
-func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
c.cache.Lock()
defer c.cache.Unlock()
entry := c.getOrCreateEntryLocked(k)
@@ -208,7 +209,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA
if !time.Now().After(entry.expiration) {
// Not expired.
if onResolve != nil {
- onResolve(entry.linkAddr, true)
+ onResolve(LinkResolutionResult{LinkAddress: entry.linkAddr, Success: true})
}
return entry.linkAddr, nil, nil
}
diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go
index acee72572..204196d00 100644
--- a/pkg/tcpip/stack/neighbor_cache.go
+++ b/pkg/tcpip/stack/neighbor_cache.go
@@ -126,7 +126,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA
// packet prompting NUD/link address resolution.
//
// TODO(gvisor.dev/issue/5151): Don't return the neighbor entry.
-func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
+func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, *tcpip.Error) {
entry := n.getOrCreateEntry(remoteAddr, linkRes)
entry.mu.Lock()
defer entry.mu.Unlock()
@@ -142,7 +142,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA
// a node continues sending packets to that neighbor using the cached
// link-layer address."
if onResolve != nil {
- onResolve(entry.neigh.LinkAddr, true)
+ onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true})
}
return entry.neigh, nil, nil
case Unknown, Incomplete, Failed:
diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go
index b96a56612..6723aef9b 100644
--- a/pkg/tcpip/stack/neighbor_cache_test.go
+++ b/pkg/tcpip/stack/neighbor_cache_test.go
@@ -1188,12 +1188,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
if !ok {
t.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1247,12 +1244,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) {
t.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1423,12 +1417,9 @@ func TestNeighborCacheReplace(t *testing.T) {
t.Fatal("store.entry(0) not found")
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1539,12 +1530,9 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
// First, sanity check that resolution is working
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- t.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1576,15 +1564,9 @@ func TestNeighborCacheResolutionFailed(t *testing.T) {
entry.Addr += "2"
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1627,15 +1609,9 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) {
t.Fatal("store.entry(0) not found")
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1674,15 +1650,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
// Perform address resolution with a faulty link, which will fail.
{
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if ok {
- t.Error("expected unsuccessful address resolution")
- }
- if len(linkAddr) != 0 {
- t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr)
- }
- if t.Failed() {
- t.FailNow()
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1713,9 +1683,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) {
// Retry address resolution with a working link.
linkRes.dropReplies = false
{
- incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if linkAddr != entry.LinkAddr {
- t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
@@ -1772,12 +1742,9 @@ func BenchmarkCacheClear(b *testing.B) {
b.Fatalf("store.entry(%d) not found", i)
}
- _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) {
- if !ok {
- b.Fatal("expected successful address resolution")
- }
- if linkAddr != entry.LinkAddr {
- b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr)
+ _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) {
+ if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" {
+ b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff)
}
})
if err != tcpip.ErrWouldBlock {
diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go
index 697132689..53ac9bb6e 100644
--- a/pkg/tcpip/stack/neighbor_entry.go
+++ b/pkg/tcpip/stack/neighbor_entry.go
@@ -96,7 +96,7 @@ type neighborEntry struct {
done chan struct{}
// onResolve is called with the result of address resolution.
- onResolve []func(tcpip.LinkAddress, bool)
+ onResolve []func(LinkResolutionResult)
isRouter bool
job *tcpip.Job
@@ -143,8 +143,9 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd
//
// Precondition: e.mu MUST be locked.
func (e *neighborEntry) notifyCompletionLocked(succeeded bool) {
+ res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded}
for _, callback := range e.onResolve {
- callback(e.neigh.LinkAddr, succeeded)
+ callback(res)
}
e.onResolve = nil
if ch := e.done; ch != nil {
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go
index 7592cff75..1bbfe6213 100644
--- a/pkg/tcpip/stack/nic.go
+++ b/pkg/tcpip/stack/nic.go
@@ -561,7 +561,7 @@ func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error {
return tcpip.ErrBadLocalAddress
}
-func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
if n.neigh != nil {
entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve)
return entry.LinkAddr, ch, err
diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go
index 093b676aa..4523e4746 100644
--- a/pkg/tcpip/stack/route.go
+++ b/pkg/tcpip/stack/route.go
@@ -347,7 +347,7 @@ func (r *Route) ResolvedFields(afterResolve func()) (RouteInfo, <-chan struct{},
linkAddressResolutionRequestLocalAddr = r.LocalAddress
}
- linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, func(tcpip.LinkAddress, bool) {
+ linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, func(LinkResolutionResult) {
if afterResolve != nil {
afterResolve()
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 4685fa4cf..7885673fe 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1527,10 +1527,14 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAd
return nil
}
+// LinkResolutionResult is the result of a link address resolution attempt.
+type LinkResolutionResult struct {
+ LinkAddress tcpip.LinkAddress
+ Success bool
+}
+
// GetLinkAddress finds the link address corresponding to a neighbor's address.
//
-// Returns a link address for the remote address, if readily available.
-//
// Returns ErrNotSupported if the stack is not configured with a link address
// resolver for the specified network protocol.
//
@@ -1538,30 +1542,28 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAd
// with a notification channel for the caller to block on. Triggers address
// resolution asynchronously.
//
-// If onResolve is provided, it will be called either immediately, if
-// resolution is not required, or when address resolution is complete, with
-// the resolved link address and whether resolution succeeded. After any
-// callbacks have been called, the returned notification channel is closed.
+// onResolve will be called either immediately, if resolution is not required,
+// or when address resolution is complete, with the resolved link address and
+// whether resolution succeeded.
//
// If specified, the local address must be an address local to the interface
// the neighbor cache belongs to. The local address is the source address of
// a packet prompting NUD/link address resolution.
-//
-// TODO(gvisor.dev/issue/5151): Don't return the link address.
-func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) {
+func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) *tcpip.Error {
s.mu.RLock()
nic, ok := s.nics[nicID]
s.mu.RUnlock()
if !ok {
- return "", nil, tcpip.ErrUnknownNICID
+ return tcpip.ErrUnknownNICID
}
linkRes, ok := s.linkAddrResolvers[protocol]
if !ok {
- return "", nil, tcpip.ErrNotSupported
+ return tcpip.ErrNotSupported
}
- return nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
+ _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
+ return err
}
// Neighbors returns all IP to MAC address associations.
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index b9ef455e5..c44b3faf7 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -4384,10 +4384,10 @@ func TestGetLinkAddressErrors(t *testing.T) {
t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
}
- if addr, _, err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID {
- t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrUnknownNICID)
+ if err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID {
+ t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrUnknownNICID)
}
- if addr, _, err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported {
- t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrNotSupported)
+ if err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported {
+ t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrNotSupported)
}
}
diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go
index 7afc4702c..1e13fd6d6 100644
--- a/pkg/tcpip/tests/integration/link_resolution_test.go
+++ b/pkg/tcpip/tests/integration/link_resolution_test.go
@@ -406,20 +406,34 @@ func TestGetLinkAddress(t *testing.T) {
)
tests := []struct {
- name string
- netProto tcpip.NetworkProtocolNumber
- remoteAddr tcpip.Address
- expectedLinkAddr bool
+ name string
+ netProto tcpip.NetworkProtocolNumber
+ remoteAddr tcpip.Address
+ expectedOk bool
}{
{
- name: "IPv4",
+ name: "IPv4 resolvable",
netProto: ipv4.ProtocolNumber,
remoteAddr: ipv4Addr2.AddressWithPrefix.Address,
+ expectedOk: true,
},
{
- name: "IPv6",
+ name: "IPv6 resolvable",
netProto: ipv6.ProtocolNumber,
remoteAddr: ipv6Addr2.AddressWithPrefix.Address,
+ expectedOk: true,
+ },
+ {
+ name: "IPv4 not resolvable",
+ netProto: ipv4.ProtocolNumber,
+ remoteAddr: ipv4Addr3.AddressWithPrefix.Address,
+ expectedOk: false,
+ },
+ {
+ name: "IPv6 not resolvable",
+ netProto: ipv6.ProtocolNumber,
+ remoteAddr: ipv6Addr3.AddressWithPrefix.Address,
+ expectedOk: false,
},
}
@@ -434,24 +448,18 @@ func TestGetLinkAddress(t *testing.T) {
host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID)
- for i := 0; i < 2; i++ {
- addr, ch, err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(tcpip.LinkAddress, bool) {})
- var want *tcpip.Error
- if i == 0 {
- want = tcpip.ErrWouldBlock
- }
- if err != want {
- t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = (%s, _, %s), want = (_, _, %s)", host1NICID, test.remoteAddr, test.netProto, addr, err, want)
- }
-
- if i == 0 {
- <-ch
- continue
- }
-
- if addr != linkAddr2 {
- t.Fatalf("got addr = %s, want = %s", addr, linkAddr2)
- }
+ ch := make(chan stack.LinkResolutionResult, 1)
+ if err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ }); err != tcpip.ErrWouldBlock {
+ t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, tcpip.ErrWouldBlock)
+ }
+ wantRes := stack.LinkResolutionResult{Success: test.expectedOk}
+ if test.expectedOk {
+ wantRes.LinkAddress = linkAddr2
+ }
+ if diff := cmp.Diff(wantRes, <-ch); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
}
})
}