From 2b46e2d19e642ed8a0a58ea1c62a38b53e7dfb2d Mon Sep 17 00:00:00 2001 From: Ghanan Gowripalan Date: Tue, 14 Sep 2021 13:35:21 -0700 Subject: Defer mutex unlocking PiperOrigin-RevId: 396670516 --- pkg/tcpip/transport/raw/endpoint.go | 158 ++++++++++++++++++------------------ 1 file changed, 77 insertions(+), 81 deletions(-) diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 4a5858bdd..264d29c7a 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -534,96 +534,92 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) { // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { - e.mu.RLock() - e.rcvMu.Lock() + notifyReadableEvents := func() bool { + e.mu.RLock() + defer e.mu.RUnlock() + e.rcvMu.Lock() + defer e.rcvMu.Unlock() + + // Drop the packet if our buffer is currently full or if this is an unassociated + // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only + // See: https://man7.org/linux/man-pages/man7/raw.7.html + // + // An IPPROTO_RAW socket is send only. If you really want to receive + // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. + // Note that packet sockets don't reassemble IP fragments, unlike raw + // sockets. + if e.rcvClosed || !e.associated { + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ClosedReceiver.Increment() + return false + } - // Drop the packet if our buffer is currently full or if this is an unassociated - // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only - // See: https://man7.org/linux/man-pages/man7/raw.7.html - // - // An IPPROTO_RAW socket is send only. If you really want to receive - // all IP packets, use a packet(7) socket with the ETH_P_IP protocol. - // Note that packet sockets don't reassemble IP fragments, unlike raw - // sockets. - if e.rcvClosed || !e.associated { - e.rcvMu.Unlock() - e.mu.RUnlock() - e.stack.Stats().DroppedPackets.Increment() - e.stats.ReceiveErrors.ClosedReceiver.Increment() - return - } + rcvBufSize := e.ops.GetReceiveBufferSize() + if e.frozen || e.rcvBufSize >= int(rcvBufSize) { + e.stack.Stats().DroppedPackets.Increment() + e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() + return false + } - rcvBufSize := e.ops.GetReceiveBufferSize() - if e.frozen || e.rcvBufSize >= int(rcvBufSize) { - e.rcvMu.Unlock() - e.mu.RUnlock() - e.stack.Stats().DroppedPackets.Increment() - e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() - return - } + if e.bound { + // If bound to a NIC, only accept data for that NIC. + if e.BindNICID != 0 && e.BindNICID != pkt.NICID { + return false + } - if e.bound { - // If bound to a NIC, only accept data for that NIC. - if e.BindNICID != 0 && e.BindNICID != pkt.NICID { - e.rcvMu.Unlock() - e.mu.RUnlock() - return + // If bound to an address, only accept data for that address. + if e.BindAddr != "" && e.BindAddr != pkt.Network().DestinationAddress() { + return false + } } - // If bound to an address, only accept data for that address. - if e.BindAddr != "" && e.BindAddr != pkt.Network().DestinationAddress() { - e.rcvMu.Unlock() - e.mu.RUnlock() - return + + srcAddr := pkt.Network().SourceAddress() + // If connected, only accept packets from the remote address we + // connected to. + if e.connected && e.route.RemoteAddress() != srcAddr { + return false } - } - srcAddr := pkt.Network().SourceAddress() - // If connected, only accept packets from the remote address we - // connected to. - if e.connected && e.route.RemoteAddress() != srcAddr { - e.rcvMu.Unlock() - e.mu.RUnlock() - return - } + wasEmpty := e.rcvBufSize == 0 - wasEmpty := e.rcvBufSize == 0 + // Push new packet into receive list and increment the buffer size. + packet := &rawPacket{ + senderAddr: tcpip.FullAddress{ + NIC: pkt.NICID, + Addr: srcAddr, + }, + } - // Push new packet into receive list and increment the buffer size. - packet := &rawPacket{ - senderAddr: tcpip.FullAddress{ - NIC: pkt.NICID, - Addr: srcAddr, - }, - } + // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. + // We copy headers' underlying bytes because pkt.*Header may point to + // the middle of a slice, and another struct may point to the "outer" + // slice. Save/restore doesn't support overlapping slices and will fail. + // + // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports + // overlapping slices. + var combinedVV buffer.VectorisedView + if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { + network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() + headers := make(buffer.View, 0, len(network)+len(transport)) + headers = append(headers, network...) + headers = append(headers, transport...) + combinedVV = headers.ToVectorisedView() + } else { + combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() + } + combinedVV.Append(pkt.Data().ExtractVV()) + packet.data = combinedVV + packet.receivedAt = e.stack.Clock().Now() - // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not. - // We copy headers' underlying bytes because pkt.*Header may point to - // the middle of a slice, and another struct may point to the "outer" - // slice. Save/restore doesn't support overlapping slices and will fail. - // - // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports - // overlapping slices. - var combinedVV buffer.VectorisedView - if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber { - network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View() - headers := make(buffer.View, 0, len(network)+len(transport)) - headers = append(headers, network...) - headers = append(headers, transport...) - combinedVV = headers.ToVectorisedView() - } else { - combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView() - } - combinedVV.Append(pkt.Data().ExtractVV()) - packet.data = combinedVV - packet.receivedAt = e.stack.Clock().Now() - - e.rcvList.PushBack(packet) - e.rcvBufSize += packet.data.Size() - e.rcvMu.Unlock() - e.mu.RUnlock() - e.stats.PacketsReceived.Increment() - // Notify waiters that there's data to be read. - if wasEmpty { + e.rcvList.PushBack(packet) + e.rcvBufSize += packet.data.Size() + e.stats.PacketsReceived.Increment() + + // Notify waiters that there is data to be read now. + return wasEmpty + }() + + if notifyReadableEvents { e.waiterQueue.Notify(waiter.ReadableEvents) } } -- cgit v1.2.3