summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/unix
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/unix')
-rw-r--r--pkg/sentry/socket/unix/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go9
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go5
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go10
-rw-r--r--pkg/sentry/socket/unix/unix.go11
5 files changed, 22 insertions, 14 deletions
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index 5c3cdef6a..7b546c04d 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -62,7 +62,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/syserr",
- "//pkg/syserror",
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 33f9aeb06..b3f0cf563 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -129,9 +129,9 @@ func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProv
stype: stype,
}
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
- ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
return ep
}
@@ -406,14 +406,15 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error {
// Accept accepts a new connection.
func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) {
e.Lock()
- defer e.Unlock()
if !e.Listening() {
+ e.Unlock()
return nil, syserr.ErrInvalidEndpointState
}
select {
case ne := <-e.acceptedChan:
+ e.Unlock()
if peerAddr != nil {
ne.Lock()
c := ne.connected
@@ -429,6 +430,7 @@ func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *s
return ne, nil
default:
+ e.Unlock()
// Nothing left.
return nil, syserr.ErrWouldBlock
}
@@ -517,3 +519,6 @@ func (e *connectionedEndpoint) OnSetSendBufferSize(v int64) (newSz int64) {
}
return v
}
+
+// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters.
+func (e *connectionedEndpoint) WakeupWriters() {}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 61338728a..61311718e 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -44,9 +44,9 @@ func NewConnectionless(ctx context.Context) Endpoint {
q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: defaultBufferSize}
q.InitRefs()
ep.receiver = &queueReceiver{readQueue: &q}
+ ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
ep.ops.SetSendBufferSize(defaultBufferSize, false /* notify */)
ep.ops.SetReceiveBufferSize(defaultBufferSize, false /* notify */)
- ep.ops.InitHandler(ep, &stackHandler{}, getSendBufferLimits, getReceiveBufferLimits)
return ep
}
@@ -227,3 +227,6 @@ func (e *connectionlessEndpoint) OnSetSendBufferSize(v int64) (newSz int64) {
}
return v
}
+
+// WakeupWriters implements tcpip.SocketOptionsHandler.WakeupWriters.
+func (e *connectionlessEndpoint) WakeupWriters() {}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index e4de44498..188ad3bd9 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -59,12 +59,14 @@ func (q *queue) Close() {
// q.WriterQueue.Notify(waiter.WritableEvents)
func (q *queue) Reset(ctx context.Context) {
q.mu.Lock()
- for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
- cur.Release(ctx)
- }
+ dataList := q.dataList
q.dataList.Reset()
q.used = 0
q.mu.Unlock()
+
+ for cur := dataList.Front(); cur != nil; cur = cur.Next() {
+ cur.Release(ctx)
+ }
}
// DecRef implements RefCounter.DecRef.
@@ -133,7 +135,7 @@ func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, f
free := q.limit - q.used
if l > free && truncate {
- if free == 0 {
+ if free <= 0 {
// Message can't fit right now.
q.mu.Unlock()
return 0, false, syserr.ErrWouldBlock
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 8ccdadae9..e9e482017 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -38,7 +38,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
@@ -494,7 +493,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
}
n, err := src.CopyInTo(t, &w)
- if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ if err != linuxerr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
return int(n), syserr.FromError(err)
}
@@ -514,13 +513,13 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b
n, err = src.CopyInTo(t, &w)
total += n
- if err != syserror.ErrWouldBlock {
+ if err != linuxerr.ErrWouldBlock {
break
}
if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
if linuxerr.Equals(linuxerr.ETIMEDOUT, err) {
- err = syserror.ErrWouldBlock
+ err = linuxerr.ErrWouldBlock
}
break
}
@@ -648,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
var total int64
- if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait {
+ if n, err := doRead(); err != linuxerr.ErrWouldBlock || dontWait {
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
@@ -683,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
defer s.EventUnregister(&e)
for {
- if n, err := doRead(); err != syserror.ErrWouldBlock {
+ if n, err := doRead(); err != linuxerr.ErrWouldBlock {
var from linux.SockAddr
var fromLen uint32
if r.From != nil {