diff options
author | Chris Kuiper <ckuiper@google.com> | 2019-09-24 13:18:19 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2019-09-24 13:21:20 -0700 |
commit | 6704d625ef2726a54078f4179d0eaa2c820cc401 (patch) | |
tree | f087fd9889d41a23d9ee31be6573811644f4ac68 | |
parent | 91abeb1dbc83d7c55fdbbdeec07c1c6c41d80d7a (diff) |
Return only primary addresses in Stack.NICInfo()
Non-primary addresses are used for endpoints created to accept multicast and
broadcast packets, as well as "helper" endpoints (0.0.0.0) that allow sending
packets when no proper address has been assigned yet (e.g., for DHCP). These
addresses are not real addresses from a user point of view and should not be
part of the NICInfo() value. Also see b/127321246 for more info.
This switches NICInfo() to call a new NIC.PrimaryAddresses() function. To still
allow an option to get all addresses (mostly for testing) I added
Stack.GetAllAddresses() and NIC.AllAddresses().
In addition, the return value for GetMainNICAddress() was changed for the case
where the NIC has no primary address. Instead of returning an error here,
it now returns an empty AddressWithPrefix() value. The rational for this
change is that it is a valid case for a NIC to have no primary addresses.
Lastly, I refactored the code based on the new additions.
PiperOrigin-RevId: 270971764
-rw-r--r-- | pkg/tcpip/stack/nic.go | 65 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack_test.go | 64 |
3 files changed, 94 insertions, 69 deletions
diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index a719058b4..0e8a23f00 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -148,37 +148,6 @@ func (n *NIC) setSpoofing(enable bool) { n.mu.Unlock() } -func (n *NIC) getMainNICAddress(protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { - n.mu.RLock() - defer n.mu.RUnlock() - - var r *referencedNetworkEndpoint - - // Check for a primary endpoint. - if list, ok := n.primary[protocol]; ok { - for e := list.Front(); e != nil; e = e.Next() { - ref := e.(*referencedNetworkEndpoint) - if ref.getKind() == permanent && ref.tryIncRef() { - r = ref - break - } - } - - } - - if r == nil { - return tcpip.AddressWithPrefix{}, tcpip.ErrNoLinkAddress - } - - addressWithPrefix := tcpip.AddressWithPrefix{ - Address: r.ep.ID().LocalAddress, - PrefixLen: r.ep.PrefixLen(), - } - r.decRef() - - return addressWithPrefix, nil -} - // primaryEndpoint returns the primary endpoint of n for the given network // protocol. func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedNetworkEndpoint { @@ -398,10 +367,12 @@ func (n *NIC) AddAddress(protocolAddress tcpip.ProtocolAddress, peb PrimaryEndpo return err } -// Addresses returns the addresses associated with this NIC. -func (n *NIC) Addresses() []tcpip.ProtocolAddress { +// AllAddresses returns all addresses (primary and non-primary) associated with +// this NIC. +func (n *NIC) AllAddresses() []tcpip.ProtocolAddress { n.mu.RLock() defer n.mu.RUnlock() + addrs := make([]tcpip.ProtocolAddress, 0, len(n.endpoints)) for nid, ref := range n.endpoints { // Don't include expired or tempory endpoints to avoid confusion and @@ -421,6 +392,34 @@ func (n *NIC) Addresses() []tcpip.ProtocolAddress { return addrs } +// PrimaryAddresses returns the primary addresses associated with this NIC. +func (n *NIC) PrimaryAddresses() []tcpip.ProtocolAddress { + n.mu.RLock() + defer n.mu.RUnlock() + + var addrs []tcpip.ProtocolAddress + for proto, list := range n.primary { + for e := list.Front(); e != nil; e = e.Next() { + ref := e.(*referencedNetworkEndpoint) + // Don't include expired or tempory endpoints to avoid confusion and + // prevent the caller from using those. + switch ref.getKind() { + case permanentExpired, temporary: + continue + } + + addrs = append(addrs, tcpip.ProtocolAddress{ + Protocol: proto, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: ref.ep.ID().LocalAddress, + PrefixLen: ref.ep.PrefixLen(), + }, + }) + } + } + return addrs +} + // AddAddressRange adds a range of addresses to n, so that it starts accepting // packets targeted at the given addresses and network protocol. The range is // given by a subnet address, and all addresses contained in the subnet are diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 1fe21b68e..f7ba3cb0f 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -738,7 +738,7 @@ func (s *Stack) NICInfo() map[tcpip.NICID]NICInfo { nics[id] = NICInfo{ Name: nic.name, LinkAddress: nic.linkEP.LinkAddress(), - ProtocolAddresses: nic.Addresses(), + ProtocolAddresses: nic.PrimaryAddresses(), Flags: flags, MTU: nic.linkEP.MTU(), Stats: nic.stats, @@ -845,19 +845,37 @@ func (s *Stack) RemoveAddress(id tcpip.NICID, addr tcpip.Address) *tcpip.Error { return tcpip.ErrUnknownNICID } -// GetMainNICAddress returns the first primary address (and the subnet that -// contains it) for the given NIC and protocol. Returns an arbitrary endpoint's -// address if no primary addresses exist. Returns an error if the NIC doesn't -// exist or has no endpoints. +// AllAddresses returns a map of NICIDs to their protocol addresses (primary +// and non-primary). +func (s *Stack) AllAddresses() map[tcpip.NICID][]tcpip.ProtocolAddress { + s.mu.RLock() + defer s.mu.RUnlock() + + nics := make(map[tcpip.NICID][]tcpip.ProtocolAddress) + for id, nic := range s.nics { + nics[id] = nic.AllAddresses() + } + return nics +} + +// GetMainNICAddress returns the first primary address and prefix for the given +// NIC and protocol. Returns an error if the NIC doesn't exist and an empty +// value if the NIC doesn't have a primary address for the given protocol. func (s *Stack) GetMainNICAddress(id tcpip.NICID, protocol tcpip.NetworkProtocolNumber) (tcpip.AddressWithPrefix, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() - if nic, ok := s.nics[id]; ok { - return nic.getMainNICAddress(protocol) + nic, ok := s.nics[id] + if !ok { + return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID } - return tcpip.AddressWithPrefix{}, tcpip.ErrUnknownNICID + for _, a := range nic.PrimaryAddresses() { + if a.Protocol == protocol { + return a.AddressWithPrefix, nil + } + } + return tcpip.AddressWithPrefix{}, nil } func (s *Stack) getRefEP(nic *NIC, localAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (ref *referencedNetworkEndpoint) { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 0c26c9911..7aa10bce9 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -930,10 +930,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) } // Sending a packet works. testSendTo(t, s, dstAddr, ep, nil) @@ -945,10 +945,10 @@ func TestSpoofingWithAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != localAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) } // Sending a packet using the route works. testSend(t, r, ep, nil) @@ -992,10 +992,10 @@ func TestSpoofingNoAddress(t *testing.T) { t.Fatal("FindRoute failed:", err) } if r.LocalAddress != nonExistentLocalAddr { - t.Errorf("Route has wrong local address: got %v, wanted %v", r.LocalAddress, nonExistentLocalAddr) + t.Errorf("Route has wrong local address: got %s, want %s", r.LocalAddress, nonExistentLocalAddr) } if r.RemoteAddress != dstAddr { - t.Errorf("Route has wrong remote address: got %v, wanted %v", r.RemoteAddress, dstAddr) + t.Errorf("Route has wrong remote address: got %s, want %s", r.RemoteAddress, dstAddr) } // Sending a packet works. // FIXME(b/139841518):Spoofing doesn't work if there is no primary address. @@ -1400,20 +1400,20 @@ func TestGetMainNICAddressAddPrimaryNonPrimary(t *testing.T) { // Check that GetMainNICAddress returns an address if at least // one primary address was added. In that case make sure the // address/prefixLen matches what we added. + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } if len(primaryAddrAdded) == 0 { - // No primary addresses present, expect an error. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, wanted = %s", err, tcpip.ErrNoLinkAddress) + // No primary addresses present. + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("GetMainNICAddress: got addr = %s, want = %s", gotAddr, wantAddr) } } else { - // At least one primary address was added, expect a valid - // address and prefixLen. - gotAddressWithPefix, err := s.GetMainNICAddress(1, fakeNetNumber) - if err != nil { - t.Fatal("GetMainNICAddress failed:", err) - } - if _, ok := primaryAddrAdded[gotAddressWithPefix]; !ok { - t.Fatalf("GetMainNICAddress: got addressWithPrefix = %v, wanted any in {%v}", gotAddressWithPefix, primaryAddrAdded) + // At least one primary address was added, verify the returned + // address is in the list of primary addresses we added. + if _, ok := primaryAddrAdded[gotAddr]; !ok { + t.Fatalf("GetMainNICAddress: got = %s, want any in {%v}", gotAddr, primaryAddrAdded) } } }) @@ -1452,19 +1452,25 @@ func TestGetMainNICAddressAddRemove(t *testing.T) { } // Check that we get the right initial address and prefix length. - if gotAddressWithPrefix, err := s.GetMainNICAddress(1, fakeNetNumber); err != nil { + gotAddr, err := s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { t.Fatal("GetMainNICAddress failed:", err) - } else if gotAddressWithPrefix != protocolAddress.AddressWithPrefix { - t.Fatalf("got GetMainNICAddress = %+v, want = %+v", gotAddressWithPrefix, protocolAddress.AddressWithPrefix) + } + if wantAddr := protocolAddress.AddressWithPrefix; gotAddr != wantAddr { + t.Fatalf("got s.GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } if err := s.RemoveAddress(1, protocolAddress.AddressWithPrefix.Address); err != nil { t.Fatal("RemoveAddress failed:", err) } - // Check that we get an error after removal. - if _, err := s.GetMainNICAddress(1, fakeNetNumber); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got s.GetMainNICAddress(...) = %v, want = %s", err, tcpip.ErrNoLinkAddress) + // Check that we get no address after removal. + gotAddr, err = s.GetMainNICAddress(1, fakeNetNumber) + if err != nil { + t.Fatal("GetMainNICAddress failed:", err) + } + if wantAddr := (tcpip.AddressWithPrefix{}); gotAddr != wantAddr { + t.Fatalf("got GetMainNICAddress(...) = %s, want = %s", gotAddr, wantAddr) } }) } @@ -1479,8 +1485,10 @@ func (g *addressGenerator) next(addrLen int) tcpip.Address { } func verifyAddresses(t *testing.T, expectedAddresses, gotAddresses []tcpip.ProtocolAddress) { + t.Helper() + if len(gotAddresses) != len(expectedAddresses) { - t.Fatalf("got len(addresses) = %d, wanted = %d", len(gotAddresses), len(expectedAddresses)) + t.Fatalf("got len(addresses) = %d, want = %d", len(gotAddresses), len(expectedAddresses)) } sort.Slice(gotAddresses, func(i, j int) bool { @@ -1519,7 +1527,7 @@ func TestAddAddress(t *testing.T) { }) } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } @@ -1551,7 +1559,7 @@ func TestAddProtocolAddress(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } @@ -1580,7 +1588,7 @@ func TestAddAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } @@ -1615,7 +1623,7 @@ func TestAddProtocolAddressWithOptions(t *testing.T) { } } - gotAddresses := s.NICInfo()[nicid].ProtocolAddresses + gotAddresses := s.AllAddresses()[nicid] verifyAddresses(t, expectedAddresses, gotAddresses) } |