From f7114e0a27db788af512af8678595954996bcd7f Mon Sep 17 00:00:00 2001 From: Chris Kuiper Date: Fri, 16 Aug 2019 15:57:46 -0700 Subject: Add subnet checking to NIC.findEndpoint and consolidate with NIC.getRef This adds the same logic to NIC.findEndpoint that is already done in NIC.getRef. Since this makes the two functions very similar they were combined into one with the originals being wrappers. PiperOrigin-RevId: 263864708 --- pkg/tcpip/stack/nic.go | 123 ++++++++++++++++++------------------------ pkg/tcpip/stack/stack_test.go | 42 +++++++++++++++ pkg/tcpip/tcpip.go | 5 ++ 3 files changed, 99 insertions(+), 71 deletions(-) (limited to 'pkg') diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dc28dc970..4ef85bdfb 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -186,41 +186,73 @@ func (n *NIC) primaryEndpoint(protocol tcpip.NetworkProtocolNumber) *referencedN return nil } +func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, n.promiscuous) +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) *referencedNetworkEndpoint { + return n.getRefOrCreateTemp(protocol, address, peb, n.spoofing) +} + +// getRefEpOrCreateTemp returns the referenced network endpoint for the given +// protocol and address. If none exists a temporary one may be created if +// requested. +func (n *NIC) getRefOrCreateTemp(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior, allowTemp bool) *referencedNetworkEndpoint { id := NetworkEndpointID{address} n.mu.RLock() - ref := n.endpoints[id] - if ref != nil && !ref.tryIncRef() { - ref = nil + + if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { + n.mu.RUnlock() + return ref + } + + // The address was not found, create a temporary one if requested by the + // caller or if the address is found in the NIC's subnets. + createTempEP := allowTemp + if !createTempEP { + for _, sn := range n.subnets { + if sn.Contains(address) { + createTempEP = true + break + } + } } - spoofing := n.spoofing + n.mu.RUnlock() - if ref != nil || !spoofing { - return ref + if !createTempEP { + return nil } // Try again with the lock in exclusive mode. If we still can't get the // endpoint, create a new "temporary" endpoint. It will only exist while // there's a route through it. n.mu.Lock() - ref = n.endpoints[id] - if ref == nil || !ref.tryIncRef() { - if netProto, ok := n.stack.networkProtocols[protocol]; ok { - ref, _ = n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: address, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, peb, true) - if ref != nil { - ref.holdsInsertRef = false - } - } + if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { + n.mu.Unlock() + return ref + } + + netProto, ok := n.stack.networkProtocols[protocol] + if !ok { + n.mu.Unlock() + return nil } + + ref, _ := n.addAddressLocked(tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: netProto.DefaultPrefixLen(), + }, + }, peb, true) + + if ref != nil { + ref.holdsInsertRef = false + } + n.mu.Unlock() return ref } @@ -553,57 +585,6 @@ func (n *NIC) DeliverNetworkPacket(linkEP LinkEndpoint, remote, _ tcpip.LinkAddr n.stack.stats.IP.InvalidAddressesReceived.Increment() } -func (n *NIC) getRef(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address) *referencedNetworkEndpoint { - id := NetworkEndpointID{dst} - - n.mu.RLock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.RUnlock() - return ref - } - - promiscuous := n.promiscuous - // Check if the packet is for a subnet this NIC cares about. - if !promiscuous { - for _, sn := range n.subnets { - if sn.Contains(dst) { - promiscuous = true - break - } - } - } - n.mu.RUnlock() - if promiscuous { - // Try again with the lock in exclusive mode. If we still can't - // get the endpoint, create a new "temporary" one. It will only - // exist while there's a route through it. - n.mu.Lock() - if ref, ok := n.endpoints[id]; ok && ref.tryIncRef() { - n.mu.Unlock() - return ref - } - netProto, ok := n.stack.networkProtocols[protocol] - if !ok { - n.mu.Unlock() - return nil - } - ref, err := n.addAddressLocked(tcpip.ProtocolAddress{ - Protocol: protocol, - AddressWithPrefix: tcpip.AddressWithPrefix{ - Address: dst, - PrefixLen: netProto.DefaultPrefixLen(), - }, - }, CanBePrimaryEndpoint, true) - n.mu.Unlock() - if err == nil { - ref.holdsInsertRef = false - return ref - } - } - - return nil -} - // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, netHeader buffer.View, vv buffer.VectorisedView) { diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 1ab9c575b..966d03007 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -830,6 +830,48 @@ func TestSubnetAcceptsMatchingPacket(t *testing.T) { } } +// Set the subnet, then check that CheckLocalAddress returns the correct NIC. +func TestCheckLocalAddressForSubnet(t *testing.T) { + const nicID tcpip.NICID = 1 + s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) + + id, _ := channel.New(10, defaultMTU, "") + if err := s.CreateNIC(nicID, id); err != nil { + t.Fatal("CreateNIC failed:", err) + } + + s.SetRouteTable([]tcpip.Route{ + {"\x00", "\x00", "\x00", nicID}, // default route + }) + + subnet, err := tcpip.NewSubnet(tcpip.Address("\xa0"), tcpip.AddressMask("\xf0")) + + if err != nil { + t.Fatal("NewSubnet failed:", err) + } + if err := s.AddSubnet(nicID, fakeNetNumber, subnet); err != nil { + t.Fatal("AddSubnet failed:", err) + } + + // Loop over all subnet addresses and check them. + numOfAddresses := (1 << uint((8 - subnet.Prefix()))) + if numOfAddresses < 1 || numOfAddresses > 255 { + t.Fatalf("got numOfAddresses = %d, want = [1 .. 255] (subnet=%s)", numOfAddresses, subnet) + } + addr := []byte(subnet.ID()) + for i := 0; i < numOfAddresses; i++ { + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != nicID { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, nicID) + } + addr[0]++ + } + + // Trying the next address should fail since it is outside the subnet range. + if gotNicID := s.CheckLocalAddress(0, fakeNetNumber, tcpip.Address(addr)); gotNicID != 0 { + t.Errorf("got CheckLocalAddress(0, %d, %s) = %d, want = %d", fakeNetNumber, tcpip.Address(addr), gotNicID, 0) + } +} + // Set destination outside the subnet, then check it doesn't get delivered. func TestSubnetRejectsNonmatchingPacket(t *testing.T) { s := stack.New([]string{"fakeNet"}, nil, stack.Options{}) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f664673b3..ba2dd85b8 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -168,6 +168,11 @@ func NewSubnet(a Address, m AddressMask) (Subnet, error) { return Subnet{a, m}, nil } +// String implements Stringer. +func (s Subnet) String() string { + return fmt.Sprintf("%s/%d", s.ID(), s.Prefix()) +} + // Contains returns true iff the address is of the same length and matches the // subnet address and mask. func (s *Subnet) Contains(a Address) bool { -- cgit v1.2.3