diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/stack/transport_demuxer.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 117 |
2 files changed, 63 insertions, 67 deletions
diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 61cbaf688..e799f9290 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -583,9 +583,14 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb // raw endpoint first. If there are multiple raw endpoints, they all // receive the packet. eps.mu.RLock() - // Copy the list of raw endpoints so we can release eps.mu earlier. - rawEPs := make([]RawTransportEndpoint, len(eps.rawEndpoints)) - copy(rawEPs, eps.rawEndpoints) + // Copy the list of raw endpoints to avoid packet handling under lock. + var rawEPs []RawTransportEndpoint + if n := len(eps.rawEndpoints); n != 0 { + rawEPs = make([]RawTransportEndpoint, n) + if m := copy(rawEPs, eps.rawEndpoints); m != n { + panic(fmt.Sprintf("unexpected copy = %d, want %d", m, n)) + } + } eps.mu.RUnlock() for _, rawEP := range rawEPs { // Each endpoint gets its own copy of the packet for the sake @@ -593,7 +598,7 @@ func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumb rawEP.HandlePacket(pkt.Clone()) } - return len(rawEPs) > 0 + return len(rawEPs) != 0 } // deliverError attempts to deliver the given error to the appropriate transport diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index a65ea32db..fe8e9c751 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -271,83 +271,74 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, tcp if opts.More { return 0, &tcpip.ErrInvalidOptionValue{} } + payloadBytes, route, owner, err := func() ([]byte, *stack.Route, tcpip.PacketOwner, tcpip.Error) { + e.mu.RLock() + defer e.mu.RUnlock() - e.mu.RLock() - - if e.closed { - e.mu.RUnlock() - return 0, &tcpip.ErrInvalidEndpointState{} - } - - payloadBytes := make([]byte, p.Len()) - if _, err := io.ReadFull(p, payloadBytes); err != nil { - e.mu.RUnlock() - return 0, &tcpip.ErrBadBuffer{} - } + if e.closed { + return nil, nil, nil, &tcpip.ErrInvalidEndpointState{} + } - // If this is an unassociated socket and callee provided a nonzero - // destination address, route using that address. - if e.ops.GetHeaderIncluded() { - ip := header.IPv4(payloadBytes) - if !ip.IsValid(len(payloadBytes)) { - e.mu.RUnlock() - return 0, &tcpip.ErrInvalidOptionValue{} + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return nil, nil, nil, &tcpip.ErrBadBuffer{} } - dstAddr := ip.DestinationAddress() - // Update dstAddr with the address in the IP header, unless - // opts.To is set (e.g. if sendto specifies a specific - // address). - if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil { - opts.To = &tcpip.FullAddress{ - NIC: 0, // NIC is unset. - Addr: dstAddr, // The address from the payload. - Port: 0, // There are no ports here. + + // If this is an unassociated socket and callee provided a nonzero + // destination address, route using that address. + if e.ops.GetHeaderIncluded() { + ip := header.IPv4(payloadBytes) + if !ip.IsValid(len(payloadBytes)) { + return nil, nil, nil, &tcpip.ErrInvalidOptionValue{} + } + dstAddr := ip.DestinationAddress() + // Update dstAddr with the address in the IP header, unless + // opts.To is set (e.g. if sendto specifies a specific + // address). + if dstAddr != tcpip.Address([]byte{0, 0, 0, 0}) && opts.To == nil { + opts.To = &tcpip.FullAddress{ + NIC: 0, // NIC is unset. + Addr: dstAddr, // The address from the payload. + Port: 0, // There are no ports here. + } } } - } - // Did the user caller provide a destination? If not, use the connected - // destination. - if opts.To == nil { - // If the user doesn't specify a destination, they should have - // connected to another address. - if !e.connected { - e.mu.RUnlock() - return 0, &tcpip.ErrDestinationRequired{} + // Did the user caller provide a destination? If not, use the connected + // destination. + if opts.To == nil { + // If the user doesn't specify a destination, they should have + // connected to another address. + if !e.connected { + return nil, nil, nil, &tcpip.ErrDestinationRequired{} + } + + e.route.Acquire() + + return payloadBytes, e.route, e.owner, nil } - owner := e.owner - route := e.route - e.mu.RUnlock() - return e.finishWrite(payloadBytes, route, owner) - } + // The caller provided a destination. Reject destination address if it + // goes through a different NIC than the endpoint was bound to. + nic := opts.To.NIC + if e.bound && nic != 0 && nic != e.BindNICID { + return nil, nil, nil, &tcpip.ErrNoRoute{} + } - // The caller provided a destination. Reject destination address if it - // goes through a different NIC than the endpoint was bound to. - nic := opts.To.NIC - if e.bound && nic != 0 && nic != e.BindNICID { - e.mu.RUnlock() - return 0, &tcpip.ErrNoRoute{} - } + // Find the route to the destination. If BindAddress is 0, + // FindRoute will choose an appropriate source address. + route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) + if err != nil { + return nil, nil, nil, err + } - // Find the route to the destination. If BindAddress is 0, - // FindRoute will choose an appropriate source address. - route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) + return payloadBytes, route, e.owner, nil + }() if err != nil { - e.mu.RUnlock() return 0, err } + defer route.Release() - owner := e.owner - e.mu.RUnlock() - n, err := e.finishWrite(payloadBytes, route, owner) - route.Release() - return n, err -} - -// finishWrite writes the payload to a route. It resolves the route if -// necessary. -func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route, owner tcpip.PacketOwner) (int64, tcpip.Error) { if e.ops.GetHeaderIncluded() { pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: buffer.View(payloadBytes).ToVectorisedView(), |