summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/network/ipv4/ipv4.go26
-rw-r--r--pkg/tcpip/network/ipv6/ipv6.go26
-rw-r--r--pkg/tcpip/stack/stack.go18
3 files changed, 40 insertions, 30 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go
index 14cf786d2..7de438fe3 100644
--- a/pkg/tcpip/network/ipv4/ipv4.go
+++ b/pkg/tcpip/network/ipv4/ipv4.go
@@ -130,6 +130,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
return e
}
+func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ for _, e := range p.mu.eps {
+ if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ return e
+ }
+ }
+
+ return nil
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
@@ -347,10 +361,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv4(pkt.NetworkHeader().View())
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
+ if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
// can safely assume the checksum is valid.
- ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
return nil
}
}
@@ -449,7 +463,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// The NAT-ed packets may now be destined for us.
locallyDelivered := 0
for pkt := range natPkts {
- ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv4(pkt.NetworkHeader().View()).DestinationAddress())
+ ep := e.protocol.findEndpointWithAddress(header.IPv4(pkt.NetworkHeader().View()).DestinationAddress())
if ep == nil {
// The NAT-ed packet is still destined for some remote node.
continue
@@ -459,7 +473,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkts.Remove(pkt)
// Deliver the packet locally.
- ep.(*endpoint).handleLocalPacket(pkt, true)
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
locallyDelivered++
}
@@ -550,8 +564,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
dstAddr := h.DestinationAddress()
// Check if the destination is owned by the stack.
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil {
- ep.(*endpoint).handlePacket(pkt)
+ if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ ep.handlePacket(pkt)
return nil
}
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go
index c21c587ba..5856c9d3c 100644
--- a/pkg/tcpip/network/ipv6/ipv6.go
+++ b/pkg/tcpip/network/ipv6/ipv6.go
@@ -648,10 +648,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw
// short circuits broadcasts before they are sent out to other hosts.
if pkt.NatDone {
netHeader := header.IPv6(pkt.NetworkHeader().View())
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil {
+ if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil {
// Since we rewrote the packet but it is being routed back to us, we
// can safely assume the checksum is valid.
- ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */)
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
return nil
}
}
@@ -750,7 +750,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
// The NAT-ed packets may now be destined for us.
locallyDelivered := 0
for pkt := range natPkts {
- ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv6(pkt.NetworkHeader().View()).DestinationAddress())
+ ep := e.protocol.findEndpointWithAddress(header.IPv6(pkt.NetworkHeader().View()).DestinationAddress())
if ep == nil {
// The NAT-ed packet is still destined for some remote node.
continue
@@ -760,7 +760,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe
pkts.Remove(pkt)
// Deliver the packet locally.
- ep.(*endpoint).handleLocalPacket(pkt, true)
+ ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */)
locallyDelivered++
}
@@ -829,8 +829,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error {
// Check if the destination is owned by the stack.
- if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil {
- ep.(*endpoint).handlePacket(pkt)
+ if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil {
+ ep.handlePacket(pkt)
return nil
}
@@ -1760,6 +1760,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran
return e
}
+func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint {
+ p.mu.RLock()
+ defer p.mu.RUnlock()
+
+ for _, e := range p.mu.eps {
+ if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil {
+ addressEndpoint.DecRef()
+ return e
+ }
+ }
+
+ return nil
+}
+
func (p *protocol) forgetEndpoint(nicID tcpip.NICID) {
p.mu.Lock()
defer p.mu.Unlock()
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 3bb3f61b2..07cd88a6a 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -2063,24 +2063,6 @@ func generateRandInt64() int64 {
return v
}
-// FindNetworkEndpoint returns the network endpoint for the given address.
-//
-// Returns nil if the address is not associated with any network endpoint.
-func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) NetworkEndpoint {
- s.mu.RLock()
- defer s.mu.RUnlock()
-
- for _, nic := range s.nics {
- addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint)
- if addressEndpoint == nil {
- continue
- }
- addressEndpoint.DecRef()
- return nic.getNetworkEndpoint(netProto)
- }
- return nil
-}
-
// FindNICNameFromID returns the name of the NIC for the given NICID.
func (s *Stack) FindNICNameFromID(id tcpip.NICID) string {
s.mu.RLock()