diff options
-rw-r--r-- | pkg/tcpip/network/ipv4/ipv4.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 26 | ||||
-rw-r--r-- | pkg/tcpip/stack/stack.go | 18 |
3 files changed, 40 insertions, 30 deletions
diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 14cf786d2..7de438fe3 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -130,6 +130,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran return e } +func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, e := range p.mu.eps { + if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return e + } + } + + return nil +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() @@ -347,10 +361,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv4(pkt.NetworkHeader().View()) - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { // Since we rewrote the packet but it is being routed back to us, we // can safely assume the checksum is valid. - ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -449,7 +463,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // The NAT-ed packets may now be destined for us. locallyDelivered := 0 for pkt := range natPkts { - ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv4(pkt.NetworkHeader().View()).DestinationAddress()) + ep := e.protocol.findEndpointWithAddress(header.IPv4(pkt.NetworkHeader().View()).DestinationAddress()) if ep == nil { // The NAT-ed packet is still destined for some remote node. continue @@ -459,7 +473,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkts.Remove(pkt) // Deliver the packet locally. - ep.(*endpoint).handleLocalPacket(pkt, true) + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) locallyDelivered++ } @@ -550,8 +564,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { dstAddr := h.DestinationAddress() // Check if the destination is owned by the stack. - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { - ep.(*endpoint).handlePacket(pkt) + if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + ep.handlePacket(pkt) return nil } diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c21c587ba..5856c9d3c 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -648,10 +648,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { // Since we rewrote the packet but it is being routed back to us, we // can safely assume the checksum is valid. - ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -750,7 +750,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // The NAT-ed packets may now be destined for us. locallyDelivered := 0 for pkt := range natPkts { - ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv6(pkt.NetworkHeader().View()).DestinationAddress()) + ep := e.protocol.findEndpointWithAddress(header.IPv6(pkt.NetworkHeader().View()).DestinationAddress()) if ep == nil { // The NAT-ed packet is still destined for some remote node. continue @@ -760,7 +760,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkts.Remove(pkt) // Deliver the packet locally. - ep.(*endpoint).handleLocalPacket(pkt, true) + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) locallyDelivered++ } @@ -829,8 +829,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // Check if the destination is owned by the stack. - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { - ep.(*endpoint).handlePacket(pkt) + if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + ep.handlePacket(pkt) return nil } @@ -1760,6 +1760,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran return e } +func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, e := range p.mu.eps { + if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return e + } + } + + return nil +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3bb3f61b2..07cd88a6a 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -2063,24 +2063,6 @@ func generateRandInt64() int64 { return v } -// FindNetworkEndpoint returns the network endpoint for the given address. -// -// Returns nil if the address is not associated with any network endpoint. -func (s *Stack) FindNetworkEndpoint(netProto tcpip.NetworkProtocolNumber, address tcpip.Address) NetworkEndpoint { - s.mu.RLock() - defer s.mu.RUnlock() - - for _, nic := range s.nics { - addressEndpoint := nic.getAddressOrCreateTempInner(netProto, address, false /* createTemp */, NeverPrimaryEndpoint) - if addressEndpoint == nil { - continue - } - addressEndpoint.DecRef() - return nic.getNetworkEndpoint(netProto) - } - return nil -} - // FindNICNameFromID returns the name of the NIC for the given NICID. func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { s.mu.RLock() |