diff options
author | gVisor bot <gvisor-bot@google.com> | 2020-12-28 22:05:49 +0000 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2020-12-28 22:05:49 +0000 |
commit | 5c21c7c3bd1552f4d5f87ef588fc213e2a2278ef (patch) | |
tree | b62b3f2c71f46e145c15d7740262f7d59c91c87f /pkg/tcpip/transport/raw | |
parent | b0f23fb7e0cf908622bc6b8c90e2819de6de6ccb (diff) | |
parent | 3ff7324dfa7c096a50b628189d5c3f2d4d5ec2f6 (diff) |
Merge release-20201208.0-89-g3ff7324df (automated)
Diffstat (limited to 'pkg/tcpip/transport/raw')
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 44 |
1 files changed, 12 insertions, 32 deletions
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index eee3f11c1..7befcfc9b 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -261,15 +261,14 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } e.mu.RLock() + defer e.mu.RUnlock() if e.closed { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidEndpointState } payloadBytes, err := p.FullPayload() if err != nil { - e.mu.RUnlock() return 0, nil, err } @@ -278,7 +277,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c if e.ops.GetHeaderIncluded() { ip := header.IPv4(payloadBytes) if !ip.IsValid(len(payloadBytes)) { - e.mu.RUnlock() return 0, nil, tcpip.ErrInvalidOptionValue } dstAddr := ip.DestinationAddress() @@ -300,39 +298,16 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // If the user doesn't specify a destination, they should have // connected to another address. if !e.connected { - e.mu.RUnlock() return 0, nil, tcpip.ErrDestinationRequired } - if e.route.IsResolutionRequired() { - savedRoute := e.route - // Promote lock to exclusive if using a shared route, - // given that it may need to change in finishWrite. - e.mu.RUnlock() - e.mu.Lock() - - // Make sure that the route didn't change during the - // time we didn't hold the lock. - if !e.connected || savedRoute != e.route { - e.mu.Unlock() - return 0, nil, tcpip.ErrInvalidEndpointState - } - - n, ch, err := e.finishWrite(payloadBytes, savedRoute) - e.mu.Unlock() - return n, ch, err - } - - n, ch, err := e.finishWrite(payloadBytes, e.route) - e.mu.RUnlock() - return n, ch, err + return e.finishWrite(payloadBytes, e.route) } // The caller provided a destination. Reject destination address if it // goes through a different NIC than the endpoint was bound to. nic := opts.To.NIC if e.bound && nic != 0 && nic != e.BindNICID { - e.mu.RUnlock() return 0, nil, tcpip.ErrNoRoute } @@ -340,13 +315,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c // FindRoute will choose an appropriate source address. route, err := e.stack.FindRoute(nic, e.BindAddr, opts.To.Addr, e.NetProto, false) if err != nil { - e.mu.RUnlock() return 0, nil, err } n, ch, err := e.finishWrite(payloadBytes, route) route.Release() - e.mu.RUnlock() return n, ch, err } @@ -404,7 +377,7 @@ func (*endpoint) Disconnect() *tcpip.Error { func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { // Raw sockets do not support connecting to a IPv4 address on a IPv6 endpoint. if e.TransportEndpointInfo.NetProto == header.IPv6ProtocolNumber && len(addr.Addr) != header.IPv6AddressSize { - return tcpip.ErrInvalidOptionValue + return tcpip.ErrAddressFamilyNotSupported } e.mu.Lock() @@ -435,11 +408,11 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { if err != nil { return err } - defer route.Release() if e.associated { // Re-register the endpoint with the appropriate NIC. if err := e.stack.RegisterRawTransportEndpoint(addr.NIC, e.NetProto, e.TransProto, e); err != nil { + route.Release() return err } e.stack.UnregisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e) @@ -447,7 +420,7 @@ func (e *endpoint) Connect(addr tcpip.FullAddress) *tcpip.Error { } // Save the route we've connected via. - e.route = route.Clone() + e.route = route e.connected = true return nil @@ -620,6 +593,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { + e.mu.RLock() e.rcvMu.Lock() // Drop the packet if our buffer is currently full or if this is an unassociated @@ -632,6 +606,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // sockets. if e.rcvClosed || !e.associated { e.rcvMu.Unlock() + e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ClosedReceiver.Increment() return @@ -639,6 +614,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if e.rcvBufSize >= e.rcvBufSizeMax { e.rcvMu.Unlock() + e.mu.RUnlock() e.stack.Stats().DroppedPackets.Increment() e.stats.ReceiveErrors.ReceiveBufferOverflow.Increment() return @@ -650,11 +626,13 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // If bound to a NIC, only accept data for that NIC. if e.BindNICID != 0 && e.BindNICID != pkt.NICID { e.rcvMu.Unlock() + e.mu.RUnlock() return } // If bound to an address, only accept data for that address. if e.BindAddr != "" && e.BindAddr != remoteAddr { e.rcvMu.Unlock() + e.mu.RUnlock() return } } @@ -663,6 +641,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { // connected to. if e.connected && e.route.RemoteAddress != remoteAddr { e.rcvMu.Unlock() + e.mu.RUnlock() return } @@ -697,6 +676,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.rcvList.PushBack(packet) e.rcvBufSize += packet.data.Size() e.rcvMu.Unlock() + e.mu.RUnlock() e.stats.PacketsReceived.Increment() // Notify waiters that there's data to be read. if wasEmpty { |