diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/stack/stack.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 48 |
2 files changed, 54 insertions, 1 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 7885673fe..9a22554e5 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -1533,7 +1533,7 @@ type LinkResolutionResult struct { Success bool } -// GetLinkAddress finds the link address corresponding to a neighbor's address. +// GetLinkAddress finds the link address corresponding to a network address. // // Returns ErrNotSupported if the stack is not configured with a link address // resolver for the specified network protocol. @@ -1562,6 +1562,11 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, return tcpip.ErrNotSupported } + if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok { + onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true}) + return nil + } + _, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve) return err } diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index c44b3faf7..511bd7e7b 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -4391,3 +4391,51 @@ func TestGetLinkAddressErrors(t *testing.T) { t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrNotSupported) } } + +func TestStaticGetLinkAddress(t *testing.T) { + const ( + nicID = 1 + ) + + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol}, + }) + if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + + tests := []struct { + name string + proto tcpip.NetworkProtocolNumber + addr tcpip.Address + expectedLinkAddr tcpip.LinkAddress + }{ + { + name: "IPv4", + proto: ipv4.ProtocolNumber, + addr: header.IPv4Broadcast, + expectedLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "IPv6", + proto: ipv6.ProtocolNumber, + addr: header.IPv6AllNodesMulticastAddress, + expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ch := make(chan stack.LinkResolutionResult, 1) + if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) { + ch <- r + }); err != nil { + t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err) + } + + if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: true}, <-ch); diff != "" { + t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff) + } + }) + } +} |