diff options
-rw-r--r-- | pkg/sentry/socket/unix/transport/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectioned.go | 8 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectionless.go | 6 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/queue.go | 10 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 12 |
5 files changed, 32 insertions, 5 deletions
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 28038ce7f..5bc01e3c8 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -29,6 +29,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//pkg/ilist", + "//pkg/refs", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/waiter", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 4c913effc..83b50459f 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -145,10 +145,12 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { b.receiver = &queueReceiver{q2} } + q2.IncRef() a.connected = &connectedEndpoint{ endpoint: b, writeQueue: q2, } + q1.IncRef() b.connected = &connectedEndpoint{ endpoint: a, writeQueue: q1, @@ -282,12 +284,14 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur idGenerator: e.idGenerator, stype: e.stype, } + readQueue := newQueue(ce.WaiterQueue(), ne.Queue, initialLimit) - writeQueue := newQueue(ne.Queue, ce.WaiterQueue(), initialLimit) ne.connected = &connectedEndpoint{ endpoint: ce, writeQueue: readQueue, } + + writeQueue := newQueue(ne.Queue, ce.WaiterQueue(), initialLimit) if e.stype == SockStream { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { @@ -297,10 +301,12 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur select { case e.acceptedChan <- ne: // Commit state. + writeQueue.IncRef() connected := &connectedEndpoint{ endpoint: ne, writeQueue: writeQueue, } + readQueue.IncRef() if e.stype == SockStream { returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected) } else { diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index cd4633106..376e4abb2 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -82,9 +82,13 @@ func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tc if r == nil { return nil, tcpip.ErrConnectionRefused } + q := r.(*queueReceiver).readQueue + if !q.TryIncRef() { + return nil, tcpip.ErrConnectionRefused + } return &connectedEndpoint{ endpoint: e, - writeQueue: r.(*queueReceiver).readQueue, + writeQueue: q, }, nil } diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index 5b4dfab68..72aa409ab 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -17,6 +17,7 @@ package transport import ( "sync" + "gvisor.googlesource.com/gvisor/pkg/refs" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -25,6 +26,8 @@ import ( // // +stateify savable type queue struct { + refs.AtomicRefCount + ReaderQueue *waiter.Queue WriterQueue *waiter.Queue @@ -67,6 +70,13 @@ func (q *queue) Reset() { q.mu.Unlock() } +// DecRef implements RefCounter.DecRef with destructor q.Reset. +func (q *queue) DecRef() { + q.DecRefWithDestructor(q.Reset) + // We don't need to notify after resetting because no one cares about + // this queue after all references have been dropped. +} + // IsReadable determines if q is currently readable. func (q *queue) IsReadable() bool { q.mu.Lock() diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 157133b65..765cca27a 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -381,7 +381,9 @@ func (q *queueReceiver) RecvMaxQueueSize() int64 { } // Release implements Receiver.Release. -func (*queueReceiver) Release() {} +func (q *queueReceiver) Release() { + q.readQueue.DecRef() +} // streamQueueReceiver implements Receiver for stream sockets. // @@ -694,7 +696,9 @@ func (e *connectedEndpoint) SendMaxQueueSize() int64 { } // Release implements ConnectedEndpoint.Release. -func (*connectedEndpoint) Release() {} +func (e *connectedEndpoint) Release() { + e.writeQueue.DecRef() +} // baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless // unix domain socket Endpoint implementations. @@ -945,4 +949,6 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } // Release implements BoundEndpoint.Release. -func (*baseEndpoint) Release() {} +func (*baseEndpoint) Release() { + // Binding a baseEndpoint doesn't take a reference. +} |