summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/tcpip/stack/transport_demuxer.go13
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go117
2 files changed, 63 insertions, 67 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go
index 61cbaf688..e799f9290 100644
--- a/pkg/tcpip/stack/transport_demuxer.go
+++ b/pkg/tcpip/stack/transport_demuxer.go
@@ -583,9 +583,14 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb
// raw endpoint first. If there are multiple raw endpoints, they all
// receive the packet.
eps.mu.RLock()
- // Copy the list of raw endpoints so we can release eps.mu earlier.
- rawEPs := make([]RawTransportEndpoint, len(eps.rawEndpoints))
- copy(rawEPs, eps.rawEndpoints)
+ // Copy the list of raw endpoints to avoid packet handling under lock.
+ var rawEPs []RawTransportEndpoint
+ if n := len(eps.rawEndpoints); n != 0 {
+ rawEPs = make([]RawTransportEndpoint, n)
+ if m := copy(rawEPs, eps.rawEndpoints); m != n {
+ panic(fmt.Sprintf("unexpected copy = %d, want %d", m, n))
+ }
+ }
eps.mu.RUnlock()
for _, rawEP := range rawEPs {
// Each endpoint gets its own copy of the packet for the sake
@@ -593,7 +598,7 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb
rawEP.HandlePacket(pkt.Clone())
}
- return len(rawEPs) > 0
+ return len(rawEPs) != 0
}
// deliverError attempts to deliver the given error to the appropriate transport
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index a65ea32db..fe8e9c751 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -271,83 +271,74 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp
if opts.More {
return 0, &tcpip.ErrInvalidOptionValue{}
}
+ payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) {
+ e.mu.RLock()
+ defer e.mu.RUnlock()
- e.mu.RLock()
-
- if e.closed {
- e.mu.RUnlock()
- return 0, &tcpip.ErrInvalidEndpointState{}
- }
-
- payloadBytes := make([]byte, p.Len())
- if _, err := io.ReadFull(p, payloadBytes); err != nil {
- e.mu.RUnlock()
- return 0, &tcpip.ErrBadBuffer{}
- }
+ if e.closed {
+ return nil, nil, nil, &tcpip.ErrInvalidEndpointState{}
+ }
- // If this is an unassociated socket and callee provided a nonzero
- // destination address, route using that address.
- if e.ops.GetHeaderIncluded() {
- ip := header.IPv4(payloadBytes)
- if !ip.IsValid(len(payloadBytes)) {
- e.mu.RUnlock()
- return 0, &tcpip.ErrInvalidOptionValue{}
+ payloadBytes := make([]byte, p.Len())
+ if _, err := io.ReadFull(p, payloadBytes); err != nil {
+ return nil, nil, nil, &tcpip.ErrBadBuffer{}
}
- dstAddr := ip.DestinationAddress()
- // Update dstAddr with the address in the IP header, unless
- // opts.To is set (e.g. if sendto specifies a specific
- // address).
- if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil {
- opts.To = &tcpip.FullAddress{
- NIC: 0, // NIC is unset.
- Addr: dstAddr, // The address from the payload.
- Port: 0, // There are no ports here.
+
+ // If this is an unassociated socket and callee provided a nonzero
+ // destination address, route using that address.
+ if e.ops.GetHeaderIncluded() {
+ ip := header.IPv4(payloadBytes)
+ if !ip.IsValid(len(payloadBytes)) {
+ return nil, nil, nil, &tcpip.ErrInvalidOptionValue{}
+ }
+ dstAddr := ip.DestinationAddress()
+ // Update dstAddr with the address in the IP header, unless
+ // opts.To is set (e.g. if sendto specifies a specific
+ // address).
+ if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil {
+ opts.To = &tcpip.FullAddress{
+ NIC: 0, // NIC is unset.
+ Addr: dstAddr, // The address from the payload.
+ Port: 0, // There are no ports here.
+ }
}
}
- }
- // Did the user caller provide a destination? If not, use the connected
- // destination.
- if opts.To == nil {
- // If the user doesn't specify a destination, they should have
- // connected to another address.
- if !e.connected {
- e.mu.RUnlock()
- return 0, &tcpip.ErrDestinationRequired{}
+ // Did the user caller provide a destination? If not, use the connected
+ // destination.
+ if opts.To == nil {
+ // If the user doesn't specify a destination, they should have
+ // connected to another address.
+ if !e.connected {
+ return nil, nil, nil, &tcpip.ErrDestinationRequired{}
+ }
+
+ e.route.Acquire()
+
+ return payloadBytes, e.route, e.owner, nil
}
- owner := e.owner
- route := e.route
- e.mu.RUnlock()
- return e.finishWrite(payloadBytes, route, owner)
- }
+ // The caller provided a destination. Reject destination address if it
+ // goes through a different NIC than the endpoint was bound to.
+ nic := opts.To.NIC
+ if e.bound && nic != 0 && nic != e.BindNICID {
+ return nil, nil, nil, &tcpip.ErrNoRoute{}
+ }
- // The caller provided a destination. Reject destination address if it
- // goes through a different NIC than the endpoint was bound to.
- nic := opts.To.NIC
- if e.bound && nic != 0 && nic != e.BindNICID {
- e.mu.RUnlock()
- return 0, &tcpip.ErrNoRoute{}
- }
+ // Find the route to the destination. If BindAddress is 0,
+ // FindRoute will choose an appropriate source address.
+ route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
+ if err != nil {
+ return nil, nil, nil, err
+ }
- // Find the route to the destination. If BindAddress is 0,
- // FindRoute will choose an appropriate source address.
- route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false)
+ return payloadBytes, route, e.owner, nil
+ }()
if err != nil {
- e.mu.RUnlock()
return 0, err
}
+ defer route.Release()
- owner := e.owner
- e.mu.RUnlock()
- n, err := e.finishWrite(payloadBytes, route, owner)
- route.Release()
- return n, err
-}
-
-// finishWrite writes the payload to a route. It resolves the route if
-// necessary.
-func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route, owner tcpip.PacketOwner) (int64, tcpip.Error) {
if e.ops.GetHeaderIncluded() {
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
Data: buffer.View(payloadBytes).ToVectorisedView(),