summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go11
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go25
2 files changed, 24 insertions, 12 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 292e51d20..61cbaf688 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -582,17 +582,18 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb
// As in net/ipv4/ip_input.c:ip_local_deliver, attempt to deliver via
// raw endpoint first. If there are multiple raw endpoints, they all
// receive the packet.
- foundRaw := false
eps.mu.RLock()
- for _, rawEP := range eps.rawEndpoints {
+ // Copy the list of raw endpoints so we can release eps.mu earlier.
+ rawEPs := make([]RawTransportEndpoint, len(eps.rawEndpoints))
+ copy(rawEPs, eps.rawEndpoints)
+ eps.mu.RUnlock()
+ for _, rawEP := range rawEPs {
// Each endpoint gets its own copy of the packet for the sake
// of save/restore.
rawEP.HandlePacket(pkt.Clone())
- foundRaw = true
}
- eps.mu.RUnlock()
- return foundRaw
+ return len(rawEPs) > 0
}
// deliverError attempts to deliver the given error to the appropriate transport
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 9c9ccc0ff..a65ea32db 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -84,7 +84,6 @@ type endpoint struct {
// Connect(), and is valid only when conneted is true.
route *stack.Route `state:"manual"`
stats tcpip.TransportEndpointStats `state:"nosave"`
-
// owner is used to get uid and gid of the packet.
owner tcpip.PacketOwner
@@ -185,6 +184,8 @@ func (e *endpoint) Close() {
func (e *endpoint) ModerateRecvBuf(copied int) {}
func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
+ e.mu.Lock()
+ defer e.mu.Unlock()
e.owner = owner
}
@@ -272,14 +273,15 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
}
e.mu.RLock()
- defer e.mu.RUnlock()
if e.closed {
+ e.mu.RUnlock()
return 0, &tcpip.ErrInvalidEndpointState{}
}
payloadBytes := make([]byte, p.Len())
if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ e.mu.RUnlock()
return 0, &tcpip.ErrBadBuffer{}
}
@@ -288,6 +290,7 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
if e.ops.GetHeaderIncluded() {
ip := header.IPv4(payloadBytes)
if !ip.IsValid(len(payloadBytes)) {
+ e.mu.RUnlock()
return 0, &tcpip.ErrInvalidOptionValue{}
}
dstAddr := ip.DestinationAddress()
@@ -309,16 +312,21 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// If the user doesn't specify a destination, they should have
// connected to another address.
if !e.connected {
+ e.mu.RUnlock()
return 0, &tcpip.ErrDestinationRequired{}
}
- return e.finishWrite(payloadBytes, e.route)
+ owner := e.owner
+ route := e.route
+ e.mu.RUnlock()
+ return e.finishWrite(payloadBytes, route, owner)
}
// The caller provided a destination. Reject destination address if it
// goes through a different NIC than the endpoint was bound to.
nic := opts.To.NIC
if e.bound && nic != 0 && nic != e.BindNICID {
+ e.mu.RUnlock()
return 0, &tcpip.ErrNoRoute{}
}
@@ -326,17 +334,20 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
// FindRoute will choose an appropriate source address.
route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
if err != nil {
+ e.mu.RUnlock()
return 0, err
}
- n, err := e.finishWrite(payloadBytes, route)
+ owner := e.owner
+ e.mu.RUnlock()
+ n, err := e.finishWrite(payloadBytes, route, owner)
route.Release()
return n, err
}
// finishWrite writes the payload to a route. It resolves the route if
-// necessary. It's really just a helper to make defer unnecessary in Write.
-func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, tcpip.Error) {
+// necessary.
+func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route, owner tcpip.PacketOwner) (int64, tcpip.Error) {
if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),
@@ -349,7 +360,7 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
ReserveHeaderBytes: int(route.MaxHeaderLength()),
Data: buffer.View(payloadBytes).ToVectorisedView(),
})
- pkt.Owner = e.owner
+ pkt.Owner = owner
if err := route.WritePacket(nil /* gso */, stack.NetworkHeaderParams{
Protocol: e.TransProto,
TTL: route.DefaultTTL(),