summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/unix
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/unix')
-rw-r--r--pkg/sentry/socket/unix/unix.go65
1 files changed, 39 insertions, 26 deletions
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 56ed63e21..b07e8d67b 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -17,6 +17,7 @@
package unix
import (
+ "fmt"
"strings"
"syscall"
@@ -55,22 +56,22 @@ type SocketOperations struct {
refs.AtomicRefCount
socket.SendReceiveTimeout
- ep transport.Endpoint
- isPacket bool
+ ep transport.Endpoint
+ stype linux.SockType
}
// New creates a new unix socket.
-func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
+func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File {
dirent := socket.NewDirent(ctx, unixSocketDevice)
defer dirent.DecRef()
- return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
+ return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true})
}
// NewWithDirent creates a new unix socket using an existing dirent.
-func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File {
return fs.NewFile(ctx, d, flags, &SocketOperations{
- ep: ep,
- isPacket: isPacket,
+ ep: ep,
+ stype: stype,
})
}
@@ -88,6 +89,18 @@ func (s *SocketOperations) Release() {
s.DecRef()
}
+func (s *SocketOperations) isPacket() bool {
+ switch s.stype {
+ case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ return true
+ case linux.SOCK_STREAM:
+ return false
+ default:
+ // We shouldn't have allowed any other socket types during creation.
+ panic(fmt.Sprintf("Invalid socket type %d", s.stype))
+ }
+}
+
// Endpoint extracts the transport.Endpoint.
func (s *SocketOperations) Endpoint() transport.Endpoint {
return s.ep
@@ -193,7 +206,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- ns := New(t, ep, s.isPacket)
+ ns := New(t, ep, s.stype)
defer ns.DecRef()
if flags&linux.SOCK_NONBLOCK != 0 {
@@ -221,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
return 0, nil, 0, syserr.FromError(e)
}
- t.Kernel().RecordSocket(ns, linux.AF_UNIX)
+ t.Kernel().RecordSocket(ns)
return fd, addr, addrLen, nil
}
@@ -487,6 +500,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
waitAll := flags&linux.MSG_WAITALL != 0
+ isPacket := s.isPacket()
// Calculate the number of FDs for which we have space and if we are
// requesting credentials.
@@ -528,8 +542,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
msgFlags |= linux.MSG_CTRUNC
}
- if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
- if s.isPacket && n < int64(r.MsgSize) {
+ if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() {
+ if isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
}
@@ -570,11 +584,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
total += n
}
- if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ if err != nil || !waitAll || isPacket || n >= dst.NumBytes() {
if total > 0 {
err = nil
}
- if s.isPacket && n < int64(r.MsgSize) {
+ if isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
}
return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
@@ -601,6 +615,12 @@ func (s *SocketOperations) State() uint32 {
return s.ep.State()
}
+// Type implements socket.Socket.Type.
+func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ // Unix domain sockets always have a protocol of 0.
+ return linux.AF_UNIX, s.stype, 0
+}
+
// provider is a unix domain socket provider.
type provider struct{}
@@ -613,21 +633,16 @@ func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs
// Create the endpoint and socket.
var ep transport.Endpoint
- var isPacket bool
switch stype {
case linux.SOCK_DGRAM:
- isPacket = true
ep = transport.NewConnectionless()
- case linux.SOCK_SEQPACKET:
- isPacket = true
- fallthrough
- case linux.SOCK_STREAM:
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
ep = transport.NewConnectioned(stype, t.Kernel())
default:
return nil, syserr.ErrInvalidArgument
}
- return New(t, ep, isPacket), nil
+ return New(t, ep, stype), nil
}
// Pair creates a new pair of AF_UNIX connected sockets.
@@ -637,19 +652,17 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F
return nil, nil, syserr.ErrProtocolNotSupported
}
- var isPacket bool
switch stype {
- case linux.SOCK_STREAM:
- case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
- isPacket = true
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ // Ok
default:
return nil, nil, syserr.ErrInvalidArgument
}
// Create the endpoints and sockets.
ep1, ep2 := transport.NewPair(stype, t.Kernel())
- s1 := New(t, ep1, isPacket)
- s2 := New(t, ep2, isPacket)
+ s1 := New(t, ep1, stype)
+ s2 := New(t, ep2, stype)
return s1, s2, nil
}