summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go158
1 files changed, 77 insertions, 81 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 4a5858bdd..264d29c7a 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -534,96 +534,92 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
// HandlePacket implements stack.RawTransportEndpoint.HandlePacket.
func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) {
- e.mu.RLock()
- e.rcvMu.Lock()
+ notifyReadableEvents := func() bool {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
+ e.rcvMu.Lock()
+ defer e.rcvMu.Unlock()
+
+ // Drop the packet if our buffer is currently full or if this is an unassociated
+ // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
+ // See: https://man7.org/linux/man-pages/man7/raw.7.html
+ //
+ // An IPPROTO_RAW socket is send only. If you really want to receive
+ // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
+ // Note that packet sockets don't reassemble IP fragments, unlike raw
+ // sockets.
+ if e.rcvClosed || !e.associated {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ClosedReceiver.Increment()
+ return false
+ }
- // Drop the packet if our buffer is currently full or if this is an unassociated
- // endpoint (i.e endpoint created w/ IPPROTO_RAW). Such endpoints are send only
- // See: https://man7.org/linux/man-pages/man7/raw.7.html
- //
- // An IPPROTO_RAW socket is send only. If you really want to receive
- // all IP packets, use a packet(7) socket with the ETH_P_IP protocol.
- // Note that packet sockets don't reassemble IP fragments, unlike raw
- // sockets.
- if e.rcvClosed || !e.associated {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stack.Stats().DroppedPackets.Increment()
- e.stats.ReceiveErrors.ClosedReceiver.Increment()
- return
- }
+ rcvBufSize := e.ops.GetReceiveBufferSize()
+ if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
+ e.stack.Stats().DroppedPackets.Increment()
+ e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
+ return false
+ }
- rcvBufSize := e.ops.GetReceiveBufferSize()
- if e.frozen || e.rcvBufSize >= int(rcvBufSize) {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stack.Stats().DroppedPackets.Increment()
- e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment()
- return
- }
+ if e.bound {
+ // If bound to a NIC, only accept data for that NIC.
+ if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
+ return false
+ }
- if e.bound {
- // If bound to a NIC, only accept data for that NIC.
- if e.BindNICID != 0 && e.BindNICID != pkt.NICID {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
+ // If bound to an address, only accept data for that address.
+ if e.BindAddr != "" && e.BindAddr != pkt.Network().DestinationAddress() {
+ return false
+ }
}
- // If bound to an address, only accept data for that address.
- if e.BindAddr != "" && e.BindAddr != pkt.Network().DestinationAddress() {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
+
+ srcAddr := pkt.Network().SourceAddress()
+ // If connected, only accept packets from the remote address we
+ // connected to.
+ if e.connected && e.route.RemoteAddress() != srcAddr {
+ return false
}
- }
- srcAddr := pkt.Network().SourceAddress()
- // If connected, only accept packets from the remote address we
- // connected to.
- if e.connected && e.route.RemoteAddress() != srcAddr {
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- return
- }
+ wasEmpty := e.rcvBufSize == 0
- wasEmpty := e.rcvBufSize == 0
+ // Push new packet into receive list and increment the buffer size.
+ packet := &rawPacket{
+ senderAddr: tcpip.FullAddress{
+ NIC: pkt.NICID,
+ Addr: srcAddr,
+ },
+ }
- // Push new packet into receive list and increment the buffer size.
- packet := &rawPacket{
- senderAddr: tcpip.FullAddress{
- NIC: pkt.NICID,
- Addr: srcAddr,
- },
- }
+ // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
+ // We copy headers' underlying bytes because pkt.*Header may point to
+ // the middle of a slice, and another struct may point to the "outer"
+ // slice. Save/restore doesn't support overlapping slices and will fail.
+ //
+ // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports
+ // overlapping slices.
+ var combinedVV buffer.VectorisedView
+ if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
+ network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
+ headers := make(buffer.View, 0, len(network)+len(transport))
+ headers = append(headers, network...)
+ headers = append(headers, transport...)
+ combinedVV = headers.ToVectorisedView()
+ } else {
+ combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
+ }
+ combinedVV.Append(pkt.Data().ExtractVV())
+ packet.data = combinedVV
+ packet.receivedAt = e.stack.Clock().Now()
- // Raw IPv4 endpoints return the IP header, but IPv6 endpoints do not.
- // We copy headers' underlying bytes because pkt.*Header may point to
- // the middle of a slice, and another struct may point to the "outer"
- // slice. Save/restore doesn't support overlapping slices and will fail.
- //
- // TODO(https://gvisor.dev/issue/6517): Avoid the copy once S/R supports
- // overlapping slices.
- var combinedVV buffer.VectorisedView
- if e.TransportEndpointInfo.NetProto == header.IPv4ProtocolNumber {
- network, transport := pkt.NetworkHeader().View(), pkt.TransportHeader().View()
- headers := make(buffer.View, 0, len(network)+len(transport))
- headers = append(headers, network...)
- headers = append(headers, transport...)
- combinedVV = headers.ToVectorisedView()
- } else {
- combinedVV = append(buffer.View(nil), pkt.TransportHeader().View()...).ToVectorisedView()
- }
- combinedVV.Append(pkt.Data().ExtractVV())
- packet.data = combinedVV
- packet.receivedAt = e.stack.Clock().Now()
-
- e.rcvList.PushBack(packet)
- e.rcvBufSize += packet.data.Size()
- e.rcvMu.Unlock()
- e.mu.RUnlock()
- e.stats.PacketsReceived.Increment()
- // Notify waiters that there's data to be read.
- if wasEmpty {
+ e.rcvList.PushBack(packet)
+ e.rcvBufSize += packet.data.Size()
+ e.stats.PacketsReceived.Increment()
+
+ // Notify waiters that there is data to be read now.
+ return wasEmpty
+ }()
+
+ if notifyReadableEvents {
e.waiterQueue.Notify(waiter.ReadableEvents)
}
}