diff options
Diffstat (limited to 'pkg/tcpip/network/ipv6')
-rw-r--r-- | pkg/tcpip/network/ipv6/ipv6.go | 26 |
1 files changed, 20 insertions, 6 deletions
diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index c21c587ba..5856c9d3c 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -648,10 +648,10 @@ func (e *endpoint) WritePacket(r *stack.Route, gso *stack.GSO, params stack.Netw // short circuits broadcasts before they are sent out to other hosts. if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); ep != nil { + if ep := e.protocol.findEndpointWithAddress(netHeader.DestinationAddress()); ep != nil { // Since we rewrote the packet but it is being routed back to us, we // can safely assume the checksum is valid. - ep.(*endpoint).handleLocalPacket(pkt, true /* canSkipRXChecksum */) + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) return nil } } @@ -750,7 +750,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe // The NAT-ed packets may now be destined for us. locallyDelivered := 0 for pkt := range natPkts { - ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, header.IPv6(pkt.NetworkHeader().View()).DestinationAddress()) + ep := e.protocol.findEndpointWithAddress(header.IPv6(pkt.NetworkHeader().View()).DestinationAddress()) if ep == nil { // The NAT-ed packet is still destined for some remote node. continue @@ -760,7 +760,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe pkts.Remove(pkt) // Deliver the packet locally. - ep.(*endpoint).handleLocalPacket(pkt, true) + ep.handleLocalPacket(pkt, true /* canSkipRXChecksum */) locallyDelivered++ } @@ -829,8 +829,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) tcpip.Error { // Check if the destination is owned by the stack. - if ep := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, dstAddr); ep != nil { - ep.(*endpoint).handlePacket(pkt) + if ep := e.protocol.findEndpointWithAddress(dstAddr); ep != nil { + ep.handlePacket(pkt) return nil } @@ -1760,6 +1760,20 @@ func (p *protocol) NewEndpoint(nic stack.NetworkInterface, dispatcher stack.Tran return e } +func (p *protocol) findEndpointWithAddress(addr tcpip.Address) *endpoint { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, e := range p.mu.eps { + if addressEndpoint := e.AcquireAssignedAddress(addr, false /* allowTemp */, stack.NeverPrimaryEndpoint); addressEndpoint != nil { + addressEndpoint.DecRef() + return e + } + } + + return nil +} + func (p *protocol) forgetEndpoint(nicID tcpip.NICID) { p.mu.Lock() defer p.mu.Unlock() |