diff options
author | Ghanan Gowripalan <ghanan@google.com> | 2021-02-06 13:23:44 -0800 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-02-06 13:25:28 -0800 |
commit | c5afaf2854679fbb7470f9a615d3c0fbb2af0999 (patch) | |
tree | b25dd9eb5f12eded8d2137f35e6dfa2582dce15c | |
parent | 4943347137dd09cd47b22b5998f8823e0bb04de6 (diff) |
Remove (*stack.Stack).FindNetworkEndpoint
The network endpoints only look for other network endpoints of the
same kind. Since the network protocols keeps track of all endpoints,
go through the protocol to find an endpoint with an address instead
of the stack.
PiperOrigin-RevId: 356051498
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 18 |
3 files changed, 40 insertions, 30 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 14cf786d2..7de438fe3 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -130,6 +130,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran return e } +func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, e := range p.mu.eps { + if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return e + } + } + + return nil +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() @@ -347,10 +361,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { // Since we rewrote the packet but it is being routed back to us, we // can safely assume the checksum is valid. - ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -449,7 +463,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // The NAT-ed packets may now be destined for us. locallyDelivered := 0 for pkt := range natPkts { - ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv4(pkt.NetworkHeader().View()).DestinationAddress()) + ep := e.protocol.findEndpointWithAddress(header.IPv4(pkt.NetworkHeader().View()).DestinationAddress()) if ep == nil { // The NAT-ed packet is still destined for some remote node. continue @@ -459,7 +473,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkts.Remove(pkt) // Deliver the packet locally. - ep.(*endpoint).handleLocalPacket(pkt, true) + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) locallyDelivered++ } @@ -550,8 +564,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { - ep.(*endpoint).handlePacket(pkt) + if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + ep.handlePacket(pkt) return nil } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c21c587ba..5856c9d3c 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -648,10 +648,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { // Since we rewrote the packet but it is being routed back to us, we // can safely assume the checksum is valid. - ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -750,7 +750,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // The NAT-ed packets may now be destined for us. locallyDelivered := 0 for pkt := range natPkts { - ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv6(pkt.NetworkHeader().View()).DestinationAddress()) + ep := e.protocol.findEndpointWithAddress(header.IPv6(pkt.NetworkHeader().View()).DestinationAddress()) if ep == nil { // The NAT-ed packet is still destined for some remote node. continue @@ -760,7 +760,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkts.Remove(pkt) // Deliver the packet locally. - ep.(*endpoint).handleLocalPacket(pkt, true) + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) locallyDelivered++ } @@ -829,8 +829,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // Check if the destination is owned by the stack. - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { - ep.(*endpoint).handlePacket(pkt) + if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + ep.handlePacket(pkt) return nil } @@ -1760,6 +1760,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran return e } +func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, e := range p.mu.eps { + if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return e + } + } + + return nil +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3bb3f61b2..07cd88a6a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -2063,24 +2063,6 @@ func generateRandInt64() int64 { return v } -// FindNetworkEndpoint returns the network endpoint for the given address. -// -// Returns nil if the address is not associated with any network endpoint. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) NetworkEndpoint { - s.mu.RLock() - defer s.mu.RUnlock() - - for _, nic := range s.nics { - addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint) - if addressEndpoint == nil { - continue - } - addressEndpoint.DecRef() - return nic.getNetworkEndpoint(netProto) - } - return nil -} - // FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() |