diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-01-21 19:53:31 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-01-21 19:55:37 -0800 |
commit | 9f46328e1174be6b8b5442467050ad0b2f0b260f (patch) | |
tree | 3d641b3fe82f01ada7bf3c0f49d2efc81b00bcae /pkg/tcpip | |
parent | 8ecff1890277820972c5f5287539a824b22a1d60 (diff) |
Only use callback for GetLinkAddress
GetLinkAddress's callback will be called immediately with a
stack.LinkResolutionResult which will hold the link address
so no need to also return the link address from the function.
Fixes #5151.
PiperOrigin-RevId: 353157857
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/network/ipv6/ndp_test.go | 99 | ||||
-rw-r--r-- | pkg/tcpip/stack/linkaddrcache.go | 9 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_cache_test.go | 87 | ||||
-rw-r--r-- | pkg/tcpip/stack/neighbor_entry.go | 5 | ||||
-rw-r--r-- | pkg/tcpip/stack/nic.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/route.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 8 | ||||
-rw-r--r-- | pkg/tcpip/tests/integration/link_resolution_test.go | 56 |
10 files changed, 134 insertions, 164 deletions
diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index b1a5a5510..aed3042d1 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -162,6 +162,11 @@ func TestStackNDPEndpointInvalidateDefaultRouter(t *testing.T) { } } +type linkResolutionResult struct { + linkAddr tcpip.LinkAddress + ok bool +} + // TestNeighorSolicitationWithSourceLinkLayerOption tests that receiving a // valid NDP NS message with the Source Link Layer Address option results in a // new entry in the link address cache for the sender of the message. @@ -231,35 +236,28 @@ func TestNeighorSolicitationWithSourceLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if linkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) - } - - if test.expectedLinkAddr != "" { - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } + ch := make(chan stack.LinkResolutionResult, 1) + err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { + ch <- r + }) - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - } else { - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) - } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) - } + wantInvalid := uint64(0) + wantErr := (*tcpip.Error)(nil) + wantSucccess := true + if len(test.expectedLinkAddr) == 0 { + wantInvalid = 1 + wantErr = tcpip.ErrWouldBlock + wantSucccess = false + } - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } + if err != wantErr { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, wantErr) + } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { + t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + } + if got := invalid.Value(); got != wantInvalid { + t.Errorf("got invalid = %d, want = %d", got, wantInvalid) } }) } @@ -803,35 +801,28 @@ func TestNeighorAdvertisementWithTargetLinkLayerOption(t *testing.T) { Data: hdr.View().ToVectorisedView(), })) - linkAddr, c, err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, nil) - if linkAddr != test.expectedLinkAddr { - t.Errorf("got link address = %s, want = %s", linkAddr, test.expectedLinkAddr) - } - - if test.expectedLinkAddr != "" { - if err != nil { - t.Errorf("s.GetLinkAddress(%d, %s, %s, %d, nil): %s", nicID, lladdr1, lladdr0, ProtocolNumber, err) - } - if c != nil { - t.Errorf("got unexpected channel") - } + ch := make(chan stack.LinkResolutionResult, 1) + err := s.GetLinkAddress(nicID, lladdr1, lladdr0, ProtocolNumber, func(r stack.LinkResolutionResult) { + ch <- r + }) - // Invalid count should not have increased. - if got := invalid.Value(); got != 0 { - t.Errorf("got invalid = %d, want = 0", got) - } - } else { - if err != tcpip.ErrWouldBlock { - t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, nil) = (_, _, %v), want = (_, _, %s)", nicID, lladdr1, lladdr0, ProtocolNumber, err, tcpip.ErrWouldBlock) - } - if c == nil { - t.Errorf("expected channel from call to s.GetLinkAddress(%d, %s, %s, %d, nil)", nicID, lladdr1, lladdr0, ProtocolNumber) - } + wantInvalid := uint64(0) + wantErr := (*tcpip.Error)(nil) + wantSucccess := true + if len(test.expectedLinkAddr) == 0 { + wantInvalid = 1 + wantErr = tcpip.ErrWouldBlock + wantSucccess = false + } - // Invalid count should have increased. - if got := invalid.Value(); got != 1 { - t.Errorf("got invalid = %d, want = 1", got) - } + if err != wantErr { + t.Errorf("got s.GetLinkAddress(%d, %s, %s, %d, _) = %s, want = %s", nicID, lladdr1, lladdr0, ProtocolNumber, err, wantErr) + } + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: wantSucccess}, <-ch); diff != "" { + t.Errorf("linkResolutionResult mismatch (-want +got):\n%s", diff) + } + if got := invalid.Value(); got != wantInvalid { + t.Errorf("got invalid = %d, want = %d", got, wantInvalid) } }) } diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index f116f8417..ba6d56a7d 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -97,12 +97,13 @@ type linkAddrEntry struct { done chan struct{} // onResolve is called with the result of address resolution. - onResolve []func(tcpip.LinkAddress, bool) + onResolve []func(LinkResolutionResult) } func (e *linkAddrEntry) notifyCompletionLocked(linkAddr tcpip.LinkAddress) { + res := LinkResolutionResult{LinkAddress: linkAddr, Success: len(linkAddr) != 0} for _, callback := range e.onResolve { - callback(linkAddr, len(linkAddr) != 0) + callback(res) } e.onResolve = nil if ch := e.done; ch != nil { @@ -196,7 +197,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.Address) *linkAddrEntry { } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { c.cache.Lock() defer c.cache.Unlock() entry := c.getOrCreateEntryLocked(k) @@ -208,7 +209,7 @@ func (c *linkAddrCache) get(k tcpip.Address, linkRes LinkAddressResolver, localA if !time.Now().After(entry.expiration) { // Not expired. if onResolve != nil { - onResolve(entry.linkAddr, true) + onResolve(LinkResolutionResult{LinkAddress: entry.linkAddr, Success: true}) } return entry.linkAddr, nil, nil } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index acee72572..204196d00 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -126,7 +126,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkA // packet prompting NUD/link address resolution. // // TODO(gvisor.dev/issue/5151): Don't return the neighbor entry. -func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { +func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (NeighborEntry, <-chan struct{}, *tcpip.Error) { entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() @@ -142,7 +142,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA // a node continues sending packets to that neighbor using the cached // link-layer address." if onResolve != nil { - onResolve(entry.neigh.LinkAddr, true) + onResolve(LinkResolutionResult{LinkAddress: entry.neigh.LinkAddr, Success: true}) } return entry.neigh, nil, nil case Unknown, Incomplete, Failed: diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index b96a56612..6723aef9b 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -1188,12 +1188,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1247,12 +1244,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1423,12 +1417,9 @@ func TestNeighborCacheReplace(t *testing.T) { t.Fatal("store.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1539,12 +1530,9 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { // First, sanity check that resolution is working { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - t.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1576,15 +1564,9 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { entry.Addr += "2" { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1627,15 +1609,9 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { t.Fatal("store.entry(0) not found") } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1674,15 +1650,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) { // Perform address resolution with a faulty link, which will fail. { - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if ok { - t.Error("expected unsuccessful address resolution") - } - if len(linkAddr) != 0 { - t.Fatalf("got linkAddr = %s, want = \"\"", linkAddr) - } - if t.Failed() { - t.FailNow() + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{Success: false}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1713,9 +1683,9 @@ func TestNeighborCacheRetryResolution(t *testing.T) { // Retry address resolution with a working link. linkRes.dropReplies = false { - incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if linkAddr != entry.LinkAddr { - t.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + incompleteEntry, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + t.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { @@ -1772,12 +1742,9 @@ func BenchmarkCacheClear(b *testing.B) { b.Fatalf("store.entry(%d) not found", i) } - _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(linkAddr tcpip.LinkAddress, ok bool) { - if !ok { - b.Fatal("expected successful address resolution") - } - if linkAddr != entry.LinkAddr { - b.Fatalf("got linkAddr = %s, want = %s", linkAddr, entry.LinkAddr) + _, ch, err := neigh.entry(entry.Addr, "", linkRes, func(r LinkResolutionResult) { + if diff := cmp.Diff(LinkResolutionResult{LinkAddress: entry.LinkAddr, Success: true}, r); diff != "" { + b.Fatalf("got link resolution result mismatch (-want +got):\n%s", diff) } }) if err != tcpip.ErrWouldBlock { diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index 697132689..53ac9bb6e 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -96,7 +96,7 @@ type neighborEntry struct { done chan struct{} // onResolve is called with the result of address resolution. - onResolve []func(tcpip.LinkAddress, bool) + onResolve []func(LinkResolutionResult) isRouter bool job *tcpip.Job @@ -143,8 +143,9 @@ func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAdd // // Precondition: e.mu MUST be locked. func (e *neighborEntry) notifyCompletionLocked(succeeded bool) { + res := LinkResolutionResult{LinkAddress: e.neigh.LinkAddr, Success: succeeded} for _, callback := range e.onResolve { - callback(e.neigh.LinkAddr, succeeded) + callback(res) } e.onResolve = nil if ch := e.done; ch != nil { diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index 7592cff75..1bbfe6213 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -561,7 +561,7 @@ func (n *NIC) removeAddress(addr tcpip.Address) *tcpip.Error { return tcpip.ErrBadLocalAddress } -func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (n *NIC) getNeighborLinkAddress(addr, localAddr tcpip.Address, linkRes LinkAddressResolver, onResolve func(LinkResolutionResult)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if n.neigh != nil { entry, ch, err := n.neigh.entry(addr, localAddr, linkRes, onResolve) return entry.LinkAddr, ch, err diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index 093b676aa..4523e4746 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -347,7 +347,7 @@ func (r *Route) ResolvedFields(afterResolve func()) (RouteInfo, <-chan struct{}, linkAddressResolutionRequestLocalAddr = r.LocalAddress } - linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, func(tcpip.LinkAddress, bool) { + linkAddr, ch, err := r.outgoingNIC.getNeighborLinkAddress(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, func(LinkResolutionResult) { if afterResolve != nil { afterResolve() } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 4685fa4cf..7885673fe 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1527,10 +1527,14 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAd return nil } +// LinkResolutionResult is the result of a link address resolution attempt. +type LinkResolutionResult struct { + LinkAddress tcpip.LinkAddress + Success bool +} + // GetLinkAddress finds the link address corresponding to a neighbor's address. // -// Returns a link address for the remote address, if readily available. -// // Returns ErrNotSupported if the stack is not configured with a link address // resolver for the specified network protocol. // @@ -1538,30 +1542,28 @@ func (s *Stack) AddLinkAddress(nicID tcpip.NICID, neighbor tcpip.Address, linkAd // with a notification channel for the caller to block on. Triggers address // resolution asynchronously. // -// If onResolve is provided, it will be called either immediately, if -// resolution is not required, or when address resolution is complete, with -// the resolved link address and whether resolution succeeded. After any -// callbacks have been called, the returned notification channel is closed. +// onResolve will be called either immediately, if resolution is not required, +// or when address resolution is complete, with the resolved link address and +// whether resolution succeeded. // // If specified, the local address must be an address local to the interface // the neighbor cache belongs to. The local address is the source address of // a packet prompting NUD/link address resolution. -// -// TODO(gvisor.dev/issue/5151): Don't return the link address. -func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(tcpip.LinkAddress, bool)) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, onResolve func(LinkResolutionResult)) *tcpip.Error { s.mu.RLock() nic, ok := s.nics[nicID] s.mu.RUnlock() if !ok { - return "", nil, tcpip.ErrUnknownNICID + return tcpip.ErrUnknownNICID } linkRes, ok := s.linkAddrResolvers[protocol] if !ok { - return "", nil, tcpip.ErrNotSupported + return tcpip.ErrNotSupported } - return nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) + return err } // Neighbors returns all IP to MAC address associations. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index b9ef455e5..c44b3faf7 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -4384,10 +4384,10 @@ func TestGetLinkAddressErrors(t *testing.T) { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } - if addr, _, err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrUnknownNICID) + if err := s.GetLinkAddress(unknownNICID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrUnknownNICID { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrUnknownNICID) } - if addr, _, err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported { - t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = (%s, _, %s), want = (_, _, %s)", unknownNICID, ipv4.ProtocolNumber, addr, err, tcpip.ErrNotSupported) + if err := s.GetLinkAddress(nicID, "", "", ipv4.ProtocolNumber, nil); err != tcpip.ErrNotSupported { + t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrNotSupported) } } diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 7afc4702c..1e13fd6d6 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -406,20 +406,34 @@ func TestGetLinkAddress(t *testing.T) { ) tests := []struct { - name string - netProto tcpip.NetworkProtocolNumber - remoteAddr tcpip.Address - expectedLinkAddr bool + name string + netProto tcpip.NetworkProtocolNumber + remoteAddr tcpip.Address + expectedOk bool }{ { - name: "IPv4", + name: "IPv4 resolvable", netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr2.AddressWithPrefix.Address, + expectedOk: true, }, { - name: "IPv6", + name: "IPv6 resolvable", netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr2.AddressWithPrefix.Address, + expectedOk: true, + }, + { + name: "IPv4 not resolvable", + netProto: ipv4.ProtocolNumber, + remoteAddr: ipv4Addr3.AddressWithPrefix.Address, + expectedOk: false, + }, + { + name: "IPv6 not resolvable", + netProto: ipv6.ProtocolNumber, + remoteAddr: ipv6Addr3.AddressWithPrefix.Address, + expectedOk: false, }, } @@ -434,24 +448,18 @@ func TestGetLinkAddress(t *testing.T) { host1Stack, _ := setupStack(t, stackOpts, host1NICID, host2NICID) - for i := 0; i < 2; i++ { - addr, ch, err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(tcpip.LinkAddress, bool) {}) - var want *tcpip.Error - if i == 0 { - want = tcpip.ErrWouldBlock - } - if err != want { - t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = (%s, _, %s), want = (_, _, %s)", host1NICID, test.remoteAddr, test.netProto, addr, err, want) - } - - if i == 0 { - <-ch - continue - } - - if addr != linkAddr2 { - t.Fatalf("got addr = %s, want = %s", addr, linkAddr2) - } + ch := make(chan stack.LinkResolutionResult, 1) + if err := host1Stack.GetLinkAddress(host1NICID, test.remoteAddr, "", test.netProto, func(r stack.LinkResolutionResult) { + ch <- r + }); err != tcpip.ErrWouldBlock { + t.Fatalf("got host1Stack.GetLinkAddress(%d, %s, '', %d, _) = %s, want = %s", host1NICID, test.remoteAddr, test.netProto, err, tcpip.ErrWouldBlock) + } + wantRes := stack.LinkResolutionResult{Success: test.expectedOk} + if test.expectedOk { + wantRes.LinkAddress = linkAddr2 + } + if diff := cmp.Diff(wantRes, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) } }) } |