summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go8
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go6
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go10
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go12
5 files changed, 32 insertions, 5 deletions
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 28038ce7f..5bc01e3c8 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -29,6 +29,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//pkg/ilist",
+ "//pkg/refs",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/waiter",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 4c913effc..83b50459f 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -145,10 +145,12 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
b.receiver = &queueReceiver{q2}
}
+ q2.IncRef()
a.connected = &connectedEndpoint{
endpoint: b,
writeQueue: q2,
}
+ q1.IncRef()
b.connected = &connectedEndpoint{
endpoint: a,
writeQueue: q1,
@@ -282,12 +284,14 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur
idGenerator: e.idGenerator,
stype: e.stype,
}
+
readQueue := newQueue(ce.WaiterQueue(), ne.Queue, initialLimit)
- writeQueue := newQueue(ne.Queue, ce.WaiterQueue(), initialLimit)
ne.connected = &connectedEndpoint{
endpoint: ce,
writeQueue: readQueue,
}
+
+ writeQueue := newQueue(ne.Queue, ce.WaiterQueue(), initialLimit)
if e.stype == SockStream {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
} else {
@@ -297,10 +301,12 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur
select {
case e.acceptedChan <- ne:
// Commit state.
+ writeQueue.IncRef()
connected := &connectedEndpoint{
endpoint: ne,
writeQueue: writeQueue,
}
+ readQueue.IncRef()
if e.stype == SockStream {
returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected)
} else {
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index cd4633106..376e4abb2 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -82,9 +82,13 @@ func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *tc
if r == nil {
return nil, tcpip.ErrConnectionRefused
}
+ q := r.(*queueReceiver).readQueue
+ if !q.TryIncRef() {
+ return nil, tcpip.ErrConnectionRefused
+ }
return &connectedEndpoint{
endpoint: e,
- writeQueue: r.(*queueReceiver).readQueue,
+ writeQueue: q,
}, nil
}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index 5b4dfab68..72aa409ab 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -17,6 +17,7 @@ package transport
import (
"sync"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
@@ -25,6 +26,8 @@ import (
//
// +stateify savable
type queue struct {
+ refs.AtomicRefCount
+
ReaderQueue *waiter.Queue
WriterQueue *waiter.Queue
@@ -67,6 +70,13 @@ func (q *queue) Reset() {
q.mu.Unlock()
}
+// DecRef implements RefCounter.DecRef with destructor q.Reset.
+func (q *queue) DecRef() {
+ q.DecRefWithDestructor(q.Reset)
+ // We don't need to notify after resetting because no one cares about
+ // this queue after all references have been dropped.
+}
+
// IsReadable determines if q is currently readable.
func (q *queue) IsReadable() bool {
q.mu.Lock()
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 157133b65..765cca27a 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -381,7 +381,9 @@ func (q *queueReceiver) RecvMaxQueueSize() int64 {
}
// Release implements Receiver.Release.
-func (*queueReceiver) Release() {}
+func (q *queueReceiver) Release() {
+ q.readQueue.DecRef()
+}
// streamQueueReceiver implements Receiver for stream sockets.
//
@@ -694,7 +696,9 @@ func (e *connectedEndpoint) SendMaxQueueSize() int64 {
}
// Release implements ConnectedEndpoint.Release.
-func (*connectedEndpoint) Release() {}
+func (e *connectedEndpoint) Release() {
+ e.writeQueue.DecRef()
+}
// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
// unix domain socket Endpoint implementations.
@@ -945,4 +949,6 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
}
// Release implements BoundEndpoint.Release.
-func (*baseEndpoint) Release() {}
+func (*baseEndpoint) Release() {
+ // Binding a baseEndpoint doesn't take a reference.
+}