summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/unix
diff options
context:
space:
mode:
authorRahat Mahmood <rahat@google.com>2019-06-10 15:16:42 -0700
committerShentubot <shentubot@google.com>2019-06-10 15:17:43 -0700
commita00157cc0e216a9829f2659ce35c856a22aa5ba2 (patch)
treedda4bdb3e03b719661fc263a91ae8aa4e46b5ae3 /pkg/sentry/socket/unix
parent48961d27a8bcc76b3783a7cc4a4a5ebcd5532d25 (diff)
Store more information in the kernel socket table.
Store enough information in the kernel socket table to distinguish between different types of sockets. Previously we were only storing the socket family, but this isn't enough to classify sockets. For example, TCPv4 and UDPv4 sockets are both AF_INET, and ICMP sockets are SOCK_DGRAM sockets with a particular protocol. Instead of creating more sub-tables, flatten the socket table and provide a filtering mechanism based on the socket entry. Also generate and store a socket entry index ("sl" in linux) which allows us to output entries in a stable order from procfs. PiperOrigin-RevId: 252495895
Diffstat (limited to 'pkg/sentry/socket/unix')
-rw-r--r--pkg/sentry/socket/unix/unix.go65
1 files changed, 39 insertions, 26 deletions
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 56ed63e21..b07e8d67b 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -17,6 +17,7 @@
package unix
import (
+ "fmt"
"strings"
"syscall"
@@ -55,22 +56,22 @@ type SocketOperations struct {
refs.AtomicRefCount
socket.SendReceiveTimeout
- ep transport.Endpoint
- isPacket bool
+ ep transport.Endpoint
+ stype linux.SockType
}
// New creates a new unix socket.
-func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
+func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File {
dirent := socket.NewDirent(ctx, unixSocketDevice)
defer dirent.DecRef()
- return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
+ return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true})
}
// NewWithDirent creates a new unix socket using an existing dirent.
-func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File {
return fs.NewFile(ctx, d, flags, &SocketOperations{
- ep: ep,
- isPacket: isPacket,
+ ep: ep,
+ stype: stype,
})
}
@@ -88,6 +89,18 @@ func (s *SocketOperations) Release() {
s.DecRef()
}
+func (s *SocketOperations) isPacket() bool {
+ switch s.stype {
+ case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ return true
+ case linux.SOCK_STREAM:
+ return false
+ default:
+ // We shouldn't have allowed any other socket types during creation.
+ panic(fmt.Sprintf("Invalid socket type %d", s.stype))
+ }
+}
+
// Endpoint extracts the transport.Endpoint.
func (s *SocketOperations) Endpoint() transport.Endpoint {
return s.ep
@@ -193,7 +206,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- ns := New(t, ep, s.isPacket)
+ ns := New(t, ep, s.stype)
defer ns.DecRef()
if flags&linux.SOCK_NONBLOCK != 0 {
@@ -221,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
return 0, nil, 0, syserr.FromError(e)
}
- t.Kernel().RecordSocket(ns, linux.AF_UNIX)
+ t.Kernel().RecordSocket(ns)
return fd, addr, addrLen, nil
}
@@ -487,6 +500,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
waitAll := flags&linux.MSG_WAITALL != 0
+ isPacket := s.isPacket()
// Calculate the number of FDs for which we have space and if we are
// requesting credentials.
@@ -528,8 +542,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
msgFlags |= linux.MSG_CTRUNC
}
- if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
- if s.isPacket && n < int64(r.MsgSize) {
+ if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() {
+ if isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
}
@@ -570,11 +584,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
total += n
}
- if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ if err != nil || !waitAll || isPacket || n >= dst.NumBytes() {
if total > 0 {
err = nil
}
- if s.isPacket && n < int64(r.MsgSize) {
+ if isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
}
return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
@@ -601,6 +615,12 @@ func (s *SocketOperations) State() uint32 {
return s.ep.State()
}
+// Type implements socket.Socket.Type.
+func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ // Unix domain sockets always have a protocol of 0.
+ return linux.AF_UNIX, s.stype, 0
+}
+
// provider is a unix domain socket provider.
type provider struct{}
@@ -613,21 +633,16 @@ func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs
// Create the endpoint and socket.
var ep transport.Endpoint
- var isPacket bool
switch stype {
case linux.SOCK_DGRAM:
- isPacket = true
ep = transport.NewConnectionless()
- case linux.SOCK_SEQPACKET:
- isPacket = true
- fallthrough
- case linux.SOCK_STREAM:
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
ep = transport.NewConnectioned(stype, t.Kernel())
default:
return nil, syserr.ErrInvalidArgument
}
- return New(t, ep, isPacket), nil
+ return New(t, ep, stype), nil
}
// Pair creates a new pair of AF_UNIX connected sockets.
@@ -637,19 +652,17 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F
return nil, nil, syserr.ErrProtocolNotSupported
}
- var isPacket bool
switch stype {
- case linux.SOCK_STREAM:
- case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
- isPacket = true
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ // Ok
default:
return nil, nil, syserr.ErrInvalidArgument
}
// Create the endpoints and sockets.
ep1, ep2 := transport.NewPair(stype, t.Kernel())
- s1 := New(t, ep1, isPacket)
- s2 := New(t, ep2, isPacket)
+ s1 := New(t, ep1, stype)
+ s2 := New(t, ep2, stype)
return s1, s2, nil
}