summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-01-21 23:19:38 -0800
committergVisor bot <gvisor-bot@google.com>2021-01-21 23:26:40 -0800
commite0f4e46e340f2f5e666332ac3ff14f113239400a (patch)
tree96a23367441ee7adbfad107d938250a1c2175fe5 /pkg
parentcfbf209173e34561c5d80072997159486966edc1 (diff)
Resolve static link addresses in GetLinkAddress
If a network address has a static mapping to a link address, calculate it in GetLinkAddress. Test: stack_test.TestStaticGetLinkAddress PiperOrigin-RevId: 353179616
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/stack.go7
-rw-r--r--pkg/tcpip/stack/stack_test.go48
2 files changed, 54 insertions, 1 deletions
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 7885673fe..9a22554e5 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -1533,7 +1533,7 @@ type LinkResolutionResult struct {
Success bool
}
-// GetLinkAddress finds the link address corresponding to a neighbor's address.
+// GetLinkAddress finds the link address corresponding to a network address.
//
// Returns ErrNotSupported if the stack is not configured with a link address
// resolver for the specified network protocol.
@@ -1562,6 +1562,11 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address,
return tcpip.ErrNotSupported
}
+ if linkAddr, ok := linkRes.ResolveStaticAddress(addr); ok {
+ onResolve(LinkResolutionResult{LinkAddress: linkAddr, Success: true})
+ return nil
+ }
+
_, _, err := nic.getNeighborLinkAddress(addr, localAddr, linkRes, onResolve)
return err
}
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index c44b3faf7..511bd7e7b 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -4391,3 +4391,51 @@ func TestGetLinkAddressErrors(t *testing.T) {
t.Errorf("got s.GetLinkAddress(%d, '', '', %d, nil) = %s, want = %s", unknownNICID, ipv4.ProtocolNumber, err, tcpip.ErrNotSupported)
}
}
+
+func TestStaticGetLinkAddress(t *testing.T) {
+ const (
+ nicID = 1
+ )
+
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol, ipv6.NewProtocol},
+ })
+ if err := s.CreateNIC(nicID, channel.New(0, 0, "")); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
+
+ tests := []struct {
+ name string
+ proto tcpip.NetworkProtocolNumber
+ addr tcpip.Address
+ expectedLinkAddr tcpip.LinkAddress
+ }{
+ {
+ name: "IPv4",
+ proto: ipv4.ProtocolNumber,
+ addr: header.IPv4Broadcast,
+ expectedLinkAddr: header.EthernetBroadcastAddress,
+ },
+ {
+ name: "IPv6",
+ proto: ipv6.ProtocolNumber,
+ addr: header.IPv6AllNodesMulticastAddress,
+ expectedLinkAddr: header.EthernetAddressFromMulticastIPv6Address(header.IPv6AllNodesMulticastAddress),
+ },
+ }
+
+ for _, test := range tests {
+ t.Run(test.name, func(t *testing.T) {
+ ch := make(chan stack.LinkResolutionResult, 1)
+ if err := s.GetLinkAddress(nicID, test.addr, "", test.proto, func(r stack.LinkResolutionResult) {
+ ch <- r
+ }); err != nil {
+ t.Fatalf("s.GetLinkAddress(%d, %s, '', %d, _): %s", nicID, test.addr, test.proto, err)
+ }
+
+ if diff := cmp.Diff(stack.LinkResolutionResult{LinkAddress: test.expectedLinkAddr, Success: true}, <-ch); diff != "" {
+ t.Fatalf("link resolution result mismatch (-want +got):\n%s", diff)
+ }
+ })
+ }
+}