summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
authorGhanan Gowripalan <ghanan@google.com>2021-02-06 13:23:44 -0800
committergVisor bot <gvisor-bot@google.com>2021-02-06 13:25:28 -0800
commitc5afaf2854679fbb7470f9a615d3c0fbb2af0999 (patch)
treeb25dd9eb5f12eded8d2137f35e6dfa2582dce15c /pkg
parent4943347137dd09cd47b22b5998f8823e0bb04de6 (diff)
Remove (*stack.Stack).FindNetworkEndpoint
The network endpoints only look for other network endpoints of the same kind. Since the network protocols keeps track of all endpoints, go through the protocol to find an endpoint with an address instead of the stack. PiperOrigin-RevId: 356051498
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go26
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go26
-rw-r--r--pkg/tcpip/stack/stack.go18
3 files changed, 40 insertions, 30 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 14cf786d2..7de438fe3 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -130,6 +130,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
return e
}
+func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ for _, e := range p.mu.eps {
+ if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ return e
+ }
+ }
+
+ return nil
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -347,10 +361,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
+ if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
// can safely assume the checksum is valid.
- ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
return nil
}
}
@@ -449,7 +463,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// The NAT-ed packets may now be destined for us.
locallyDelivered := 0
for pkt := range natPkts {
- ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv4(pkt.NetworkHeader().View()).DestinationAddress())
+ ep := e.protocol.findEndpointWithAddress(header.IPv4(pkt.NetworkHeader().View()).DestinationAddress())
if ep == nil {
// The NAT-ed packet is still destined for some remote node.
continue
@@ -459,7 +473,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkts.Remove(pkt)
// Deliver the packet locally.
- ep.(*endpoint).handleLocalPacket(pkt, true)
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
locallyDelivered++
}
@@ -550,8 +564,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil {
- ep.(*endpoint).handlePacket(pkt)
+ if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ ep.handlePacket(pkt)
return nil
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index c21c587ba..5856c9d3c 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -648,10 +648,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
+ if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
// can safely assume the checksum is valid.
- ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
return nil
}
}
@@ -750,7 +750,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// The NAT-ed packets may now be destined for us.
locallyDelivered := 0
for pkt := range natPkts {
- ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv6(pkt.NetworkHeader().View()).DestinationAddress())
+ ep := e.protocol.findEndpointWithAddress(header.IPv6(pkt.NetworkHeader().View()).DestinationAddress())
if ep == nil {
// The NAT-ed packet is still destined for some remote node.
continue
@@ -760,7 +760,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkts.Remove(pkt)
// Deliver the packet locally.
- ep.(*endpoint).handleLocalPacket(pkt, true)
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
locallyDelivered++
}
@@ -829,8 +829,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// Check if the destination is owned by the stack.
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil {
- ep.(*endpoint).handlePacket(pkt)
+ if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ ep.handlePacket(pkt)
return nil
}
@@ -1760,6 +1760,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
return e
}
+func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ for _, e := range p.mu.eps {
+ if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ return e
+ }
+ }
+
+ return nil
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 3bb3f61b2..07cd88a6a 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -2063,24 +2063,6 @@ func generateRandInt64() int64 {
return v
}
-// FindNetworkEndpoint returns the network endpoint for the given address.
-//
-// Returns nil if the address is not associated with any network endpoint.
-func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) NetworkEndpoint {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- for _, nic := range s.nics {
- addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint)
- if addressEndpoint == nil {
- continue
- }
- addressEndpoint.DecRef()
- return nic.getNetworkEndpoint(netProto)
- }
- return nil
-}
-
// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
s.mu.RLock()