diff options
Diffstat (limited to 'pkg')
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 11 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 25 |
2 files changed, 24 insertions, 12 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 292e51d20..61cbaf688 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -582,17 +582,18 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb // As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via // raw endpoint first. If there are multiple raw endpoints, they all // receive the packet. - foundRaw := false eps.mu.RLock() - for _, rawEP := range eps.rawEndpoints { + // Copy the list of raw endpoints so we can release eps.mu earlier. + rawEPs := make([]RawTransportEndpoint, len(eps.rawEndpoints)) + copy(rawEPs, eps.rawEndpoints) + eps.mu.RUnlock() + for _, rawEP := range rawEPs { // Each endpoint gets its own copy of the packet for the sake // of save/restore. rawEP.HandlePacket(pkt.Clone()) - foundRaw = true } - eps.mu.RUnlock() - return foundRaw + return len(rawEPs) > 0 } // deliverError attempts to deliver the given error to the appropriate transport diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 9c9ccc0ff..a65ea32db 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -84,7 +84,6 @@ type endpoint struct { // Connect(), and is valid only when conneted is true. route *stack.Route `state:"manual"` stats tcpip.TransportEndpointStats `state:"nosave"` - // owner is used to get uid and gid of the packet. owner tcpip.PacketOwner @@ -185,6 +184,8 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { + e.mu.Lock() + defer e.mu.Unlock() e.owner = owner } @@ -272,14 +273,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp } e.mu.RLock() - defer e.mu.RUnlock() if e.closed { + e.mu.RUnlock() return 0, &tcpip.ErrInvalidEndpointState{} } payloadBytes := make([]byte, p.Len()) if _, err := io.ReadFull(p, payloadBytes); err != nil { + e.mu.RUnlock() return 0, &tcpip.ErrBadBuffer{} } @@ -288,6 +290,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { + e.mu.RUnlock() return 0, &tcpip.ErrInvalidOptionValue{} } dstAddr := ip.DestinationAddress() @@ -309,16 +312,21 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { + e.mu.RUnlock() return 0, &tcpip.ErrDestinationRequired{} } - return e.finishWrite(payloadBytes, e.route) + owner := e.owner + route := e.route + e.mu.RUnlock() + return e.finishWrite(payloadBytes, route, owner) } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { + e.mu.RUnlock() return 0, &tcpip.ErrNoRoute{} } @@ -326,17 +334,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { + e.mu.RUnlock() return 0, err } - n, err := e.finishWrite(payloadBytes, route) + owner := e.owner + e.mu.RUnlock() + n, err := e.finishWrite(payloadBytes, route, owner) route.Release() return n, err } // finishWrite writes the payload to a route. It resolves the route if -// necessary. It's really just a helper to make defer unnecessary in Write. -func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, tcpip.Error) { +// necessary. +func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route, owner tcpip.PacketOwner) (int64, tcpip.Error) { if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), @@ -349,7 +360,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, ReserveHeaderBytes: int(route.MaxHeaderLength()), Data: buffer.View(payloadBytes).ToVectorisedView(), }) - pkt.Owner = e.owner + pkt.Owner = owner if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ Protocol: e.TransProto, TTL: route.DefaultTTL(), |