summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authorRahat Mahmood <rahat@google.com>2019-06-10 15:16:42 -0700
committerShentubot <shentubot@google.com>2019-06-10 15:17:43 -0700
commita00157cc0e216a9829f2659ce35c856a22aa5ba2 (patch)
treedda4bdb3e03b719661fc263a91ae8aa4e46b5ae3
parent48961d27a8bcc76b3783a7cc4a4a5ebcd5532d25 (diff)
Store more information in the kernel socket table.
Store enough information in the kernel socket table to distinguish between different types of sockets. Previously we were only storing the socket family, but this isn't enough to classify sockets. For example, TCPv4 and UDPv4 sockets are both AF_INET, and ICMP sockets are SOCK_DGRAM sockets with a particular protocol. Instead of creating more sub-tables, flatten the socket table and provide a filtering mechanism based on the socket entry. Also generate and store a socket entry index ("sl" in linux) which allows us to output entries in a stable order from procfs. PiperOrigin-RevId: 252495895
-rw-r--r--pkg/sentry/fs/host/socket.go4
-rw-r--r--pkg/sentry/fs/proc/BUILD1
-rw-r--r--pkg/sentry/fs/proc/net.go14
-rw-r--r--pkg/sentry/kernel/BUILD13
-rw-r--r--pkg/sentry/kernel/kernel.go55
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go13
-rw-r--r--pkg/sentry/socket/epsocket/provider.go2
-rw-r--r--pkg/sentry/socket/hostinet/socket.go32
-rw-r--r--pkg/sentry/socket/netlink/provider.go2
-rw-r--r--pkg/sentry/socket/netlink/socket.go12
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go16
-rw-r--r--pkg/sentry/socket/socket.go9
-rw-r--r--pkg/sentry/socket/unix/unix.go65
13 files changed, 152 insertions, 86 deletions
diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go
index 6423ad938..305eea718 100644
--- a/pkg/sentry/fs/host/socket.go
+++ b/pkg/sentry/fs/host/socket.go
@@ -164,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
- return unixsocket.NewWithDirent(ctx, d, ep, e.stype != linux.SOCK_STREAM, flags), nil
+ return unixsocket.NewWithDirent(ctx, d, ep, e.stype, flags), nil
}
// newSocket allocates a new unix socket with host endpoint.
@@ -196,7 +196,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error)
ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e)
- return unixsocket.New(ctx, ep, e.stype != linux.SOCK_STREAM), nil
+ return unixsocket.New(ctx, ep, e.stype), nil
}
// Send implements transport.ConnectedEndpoint.Send.
diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD
index d19c360e0..1728fe0b5 100644
--- a/pkg/sentry/fs/proc/BUILD
+++ b/pkg/sentry/fs/proc/BUILD
@@ -45,6 +45,7 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/limits",
"//pkg/sentry/mm",
+ "//pkg/sentry/socket",
"//pkg/sentry/socket/rpcinet",
"//pkg/sentry/socket/unix",
"//pkg/sentry/socket/unix/transport",
diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go
index 3daaa962c..034950158 100644
--- a/pkg/sentry/fs/proc/net.go
+++ b/pkg/sentry/fs/proc/net.go
@@ -27,6 +27,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs"
"gvisor.googlesource.com/gvisor/pkg/sentry/inet"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
)
@@ -213,17 +214,18 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s
fmt.Fprintf(&buf, "Num RefCount Protocol Flags Type St Inode Path\n")
// Entries
- for _, sref := range n.k.ListSockets(linux.AF_UNIX) {
- s := sref.Get()
+ for _, se := range n.k.ListSockets() {
+ s := se.Sock.Get()
if s == nil {
- log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", sref)
+ log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock)
continue
}
sfile := s.(*fs.File)
- sops, ok := sfile.FileOperations.(*unix.SocketOperations)
- if !ok {
- panic(fmt.Sprintf("Found non-unix socket file in unix socket table: %+v", sfile))
+ if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX {
+ // Not a unix socket.
+ continue
}
+ sops := sfile.FileOperations.(*unix.SocketOperations)
addr, err := sops.Endpoint().GetLocalAddress()
if err != nil {
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index 99a2fd964..04e375910 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -64,6 +64,18 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "socket_list",
+ out = "socket_list.go",
+ package = "kernel",
+ prefix = "socket",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*SocketEntry",
+ "Linker": "*SocketEntry",
+ },
+)
+
proto_library(
name = "uncaught_signal_proto",
srcs = ["uncaught_signal.proto"],
@@ -104,6 +116,7 @@ go_library(
"sessions.go",
"signal.go",
"signal_handlers.go",
+ "socket_list.go",
"syscalls.go",
"syscalls_state.go",
"syslog.go",
diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go
index 85d73ace2..f253a81d9 100644
--- a/pkg/sentry/kernel/kernel.go
+++ b/pkg/sentry/kernel/kernel.go
@@ -182,9 +182,13 @@ type Kernel struct {
// danglingEndpoints is used to save / restore tcpip.DanglingEndpoints.
danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"`
- // socketTable is used to track all sockets on the system. Protected by
+ // sockets is the list of all network sockets the system. Protected by
// extMu.
- socketTable map[int]map[*refs.WeakRef]struct{}
+ sockets socketList
+
+ // nextSocketEntry is the next entry number to use in sockets. Protected
+ // by extMu.
+ nextSocketEntry uint64
// deviceRegistry is used to save/restore device.SimpleDevices.
deviceRegistry struct{} `state:".(*device.Registry)"`
@@ -283,7 +287,6 @@ func (k *Kernel) Init(args InitKernelArgs) error {
k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic}
k.futexes = futex.NewManager()
k.netlinkPorts = port.New()
- k.socketTable = make(map[int]map[*refs.WeakRef]struct{})
return nil
}
@@ -1137,51 +1140,43 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) {
})
}
-// socketEntry represents a socket recorded in Kernel.socketTable. It implements
+// SocketEntry represents a socket recorded in Kernel.sockets. It implements
// refs.WeakRefUser for sockets stored in the socket table.
//
// +stateify savable
-type socketEntry struct {
- k *Kernel
- sock *refs.WeakRef
- family int
+type SocketEntry struct {
+ socketEntry
+ k *Kernel
+ Sock *refs.WeakRef
+ ID uint64 // Socket table entry number.
}
// WeakRefGone implements refs.WeakRefUser.WeakRefGone.
-func (s *socketEntry) WeakRefGone() {
+func (s *SocketEntry) WeakRefGone() {
s.k.extMu.Lock()
- // k.socketTable is guaranteed to point to a valid socket table for s.family
- // at this point, since we made sure of the fact when we created this
- // socketEntry, and we never delete socket tables.
- delete(s.k.socketTable[s.family], s.sock)
+ s.k.sockets.Remove(s)
s.k.extMu.Unlock()
}
// RecordSocket adds a socket to the system-wide socket table for tracking.
//
// Precondition: Caller must hold a reference to sock.
-func (k *Kernel) RecordSocket(sock *fs.File, family int) {
+func (k *Kernel) RecordSocket(sock *fs.File) {
k.extMu.Lock()
- table, ok := k.socketTable[family]
- if !ok {
- table = make(map[*refs.WeakRef]struct{})
- k.socketTable[family] = table
- }
- se := socketEntry{k: k, family: family}
- se.sock = refs.NewWeakRef(sock, &se)
- table[se.sock] = struct{}{}
+ id := k.nextSocketEntry
+ k.nextSocketEntry++
+ s := &SocketEntry{k: k, ID: id}
+ s.Sock = refs.NewWeakRef(sock, s)
+ k.sockets.PushBack(s)
k.extMu.Unlock()
}
-// ListSockets returns a snapshot of all sockets of a given family.
-func (k *Kernel) ListSockets(family int) []*refs.WeakRef {
+// ListSockets returns a snapshot of all sockets.
+func (k *Kernel) ListSockets() []*SocketEntry {
k.extMu.Lock()
- socks := []*refs.WeakRef{}
- if table, ok := k.socketTable[family]; ok {
- socks = make([]*refs.WeakRef, 0, len(table))
- for s := range table {
- socks = append(socks, s)
- }
+ var socks []*SocketEntry
+ for s := k.sockets.Front(); s != nil; s = s.Next() {
+ socks = append(socks, s)
}
k.extMu.Unlock()
return socks
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index e1e29de35..f67451179 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -228,6 +228,7 @@ type SocketOperations struct {
family int
Endpoint tcpip.Endpoint
skType linux.SockType
+ protocol int
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
@@ -252,7 +253,7 @@ type SocketOperations struct {
}
// New creates a new endpoint socket.
-func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
+func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil {
return nil, syserr.TranslateNetstackError(err)
@@ -266,6 +267,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue,
family: family,
Endpoint: endpoint,
skType: skType,
+ protocol: protocol,
}), nil
}
@@ -550,7 +552,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- ns, err := New(t, s.family, s.skType, wq, ep)
+ ns, err := New(t, s.family, s.skType, s.protocol, wq, ep)
if err != nil {
return 0, nil, 0, err
}
@@ -578,7 +580,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
- t.Kernel().RecordSocket(ns, s.family)
+ t.Kernel().RecordSocket(ns)
return fd, addr, addrLen, syserr.FromError(e)
}
@@ -2324,3 +2326,8 @@ func (s *SocketOperations) State() uint32 {
// TODO(b/112063468): Export states for UDP, ICMP, and raw sockets.
return 0
}
+
+// Type implements socket.Socket.Type.
+func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.skType, s.protocol
+}
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index e48a106ea..516582828 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -111,7 +111,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
return nil, syserr.TranslateNetstackError(e)
}
- return New(t, p.family, stype, wq, ep)
+ return New(t, p.family, stype, protocol, wq, ep)
}
// Pair just returns nil sockets (not supported).
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 4517951a0..c62c8d8f1 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -56,15 +56,22 @@ type socketOperations struct {
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
- family int // Read-only.
- fd int // must be O_NONBLOCK
- queue waiter.Queue
+ family int // Read-only.
+ stype linux.SockType // Read-only.
+ protocol int // Read-only.
+ fd int // must be O_NONBLOCK
+ queue waiter.Queue
}
var _ = socket.Socket(&socketOperations{})
-func newSocketFile(ctx context.Context, family int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
- s := &socketOperations{family: family, fd: fd}
+func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
+ s := &socketOperations{
+ family: family,
+ stype: stype,
+ protocol: protocol,
+ fd: fd,
+ }
if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
return nil, syserr.FromError(err)
}
@@ -222,7 +229,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
}
- f, err := newSocketFile(t, s.family, fd, flags&syscall.SOCK_NONBLOCK != 0)
+ f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0)
if err != nil {
syscall.Close(fd)
return 0, nil, 0, err
@@ -233,7 +240,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
}
kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits())
- t.Kernel().RecordSocket(f, s.family)
+ t.Kernel().RecordSocket(f)
return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
}
@@ -542,6 +549,11 @@ func (s *socketOperations) State() uint32 {
return uint32(info.State)
}
+// Type implements socket.Socket.Type.
+func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.stype, s.protocol
+}
+
type socketProvider struct {
family int
}
@@ -558,7 +570,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto
}
// Only accept TCP and UDP.
- stype := int(stypeflags) & linux.SOCK_TYPE_MASK
+ stype := stypeflags & linux.SOCK_TYPE_MASK
switch stype {
case syscall.SOCK_STREAM:
switch protocol {
@@ -581,11 +593,11 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto
// Conservatively ignore all flags specified by the application and add
// SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
// to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
- fd, err := syscall.Socket(p.family, stype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
return nil, syserr.FromError(err)
}
- return newSocketFile(t, p.family, fd, stypeflags&syscall.SOCK_NONBLOCK != 0)
+ return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&syscall.SOCK_NONBLOCK != 0)
}
// Pair implements socket.Provider.Pair.
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
index 863edc241..5dc103877 100644
--- a/pkg/sentry/socket/netlink/provider.go
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -82,7 +82,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int
return nil, err
}
- s, err := NewSocket(t, p)
+ s, err := NewSocket(t, stype, p)
if err != nil {
return nil, err
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 16c79aa33..62659784a 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -80,6 +80,10 @@ type Socket struct {
// protocol is the netlink protocol implementation.
protocol Protocol
+ // skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for
+ // netlink sockets.
+ skType linux.SockType
+
// ep is a datagram unix endpoint used to buffer messages sent from the
// kernel to userspace. RecvMsg reads messages from this endpoint.
ep transport.Endpoint
@@ -105,7 +109,7 @@ type Socket struct {
var _ socket.Socket = (*Socket)(nil)
// NewSocket creates a new Socket.
-func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) {
+func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
// Datagram endpoint used to buffer kernel -> user messages.
ep := transport.NewConnectionless()
@@ -126,6 +130,7 @@ func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) {
return &Socket{
ports: t.Kernel().NetlinkPorts(),
protocol: protocol,
+ skType: skType,
ep: ep,
connection: connection,
sendBufferSize: defaultSendBufferSize,
@@ -621,3 +626,8 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence,
func (s *Socket) State() uint32 {
return s.ep.State()
}
+
+// Type implements socket.Socket.Type.
+func (s *Socket) Type() (family int, skType linux.SockType, protocol int) {
+ return linux.AF_NETLINK, s.skType, s.protocol.Protocol()
+}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 2d5b5b58f..c22ff1ff0 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -53,7 +53,10 @@ type socketOperations struct {
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
- family int // Read-only.
+ family int // Read-only.
+ stype linux.SockType // Read-only.
+ protocol int // Read-only.
+
fd uint32 // must be O_NONBLOCK
wq *waiter.Queue
rpcConn *conn.RPCConnection
@@ -86,6 +89,8 @@ func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.S
defer dirent.DecRef()
return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{
family: family,
+ stype: skType,
+ protocol: protocol,
wq: &wq,
fd: fd,
rpcConn: stack.rpcConn,
@@ -332,7 +337,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
if err != nil {
return 0, nil, 0, syserr.FromError(err)
}
- t.Kernel().RecordSocket(file, s.family)
+ t.Kernel().RecordSocket(file)
if peerRequested {
return fd, payload.Address.Address, payload.Address.Length, nil
@@ -835,6 +840,11 @@ func (s *socketOperations) State() uint32 {
return 0
}
+// Type implements socket.Socket.Type.
+func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.stype, s.protocol
+}
+
type socketProvider struct {
family int
}
@@ -876,7 +886,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto
return nil, nil
}
- return newSocketFile(t, s, p.family, stype, 0)
+ return newSocketFile(t, s, p.family, stype, protocol)
}
// Pair implements socket.Provider.Pair.
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index f1021ec67..d60944b6b 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -120,6 +120,9 @@ type Socket interface {
// State returns the current state of the socket, as represented by Linux in
// procfs. The returned state value is protocol-specific.
State() uint32
+
+ // Type returns the family, socket type and protocol of the socket.
+ Type() (family int, skType linux.SockType, protocol int)
}
// Provider is the interface implemented by providers of sockets for specific
@@ -156,7 +159,7 @@ func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.Fi
return nil, err
}
if s != nil {
- t.Kernel().RecordSocket(s, family)
+ t.Kernel().RecordSocket(s)
return s, nil
}
}
@@ -179,8 +182,8 @@ func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.F
}
if s1 != nil && s2 != nil {
k := t.Kernel()
- k.RecordSocket(s1, family)
- k.RecordSocket(s2, family)
+ k.RecordSocket(s1)
+ k.RecordSocket(s2)
return s1, s2, nil
}
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 56ed63e21..b07e8d67b 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -17,6 +17,7 @@
package unix
import (
+ "fmt"
"strings"
"syscall"
@@ -55,22 +56,22 @@ type SocketOperations struct {
refs.AtomicRefCount
socket.SendReceiveTimeout
- ep transport.Endpoint
- isPacket bool
+ ep transport.Endpoint
+ stype linux.SockType
}
// New creates a new unix socket.
-func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
+func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File {
dirent := socket.NewDirent(ctx, unixSocketDevice)
defer dirent.DecRef()
- return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
+ return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true})
}
// NewWithDirent creates a new unix socket using an existing dirent.
-func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File {
return fs.NewFile(ctx, d, flags, &SocketOperations{
- ep: ep,
- isPacket: isPacket,
+ ep: ep,
+ stype: stype,
})
}
@@ -88,6 +89,18 @@ func (s *SocketOperations) Release() {
s.DecRef()
}
+func (s *SocketOperations) isPacket() bool {
+ switch s.stype {
+ case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ return true
+ case linux.SOCK_STREAM:
+ return false
+ default:
+ // We shouldn't have allowed any other socket types during creation.
+ panic(fmt.Sprintf("Invalid socket type %d", s.stype))
+ }
+}
+
// Endpoint extracts the transport.Endpoint.
func (s *SocketOperations) Endpoint() transport.Endpoint {
return s.ep
@@ -193,7 +206,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- ns := New(t, ep, s.isPacket)
+ ns := New(t, ep, s.stype)
defer ns.DecRef()
if flags&linux.SOCK_NONBLOCK != 0 {
@@ -221,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
return 0, nil, 0, syserr.FromError(e)
}
- t.Kernel().RecordSocket(ns, linux.AF_UNIX)
+ t.Kernel().RecordSocket(ns)
return fd, addr, addrLen, nil
}
@@ -487,6 +500,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
waitAll := flags&linux.MSG_WAITALL != 0
+ isPacket := s.isPacket()
// Calculate the number of FDs for which we have space and if we are
// requesting credentials.
@@ -528,8 +542,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
msgFlags |= linux.MSG_CTRUNC
}
- if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
- if s.isPacket && n < int64(r.MsgSize) {
+ if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() {
+ if isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
}
@@ -570,11 +584,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
total += n
}
- if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ if err != nil || !waitAll || isPacket || n >= dst.NumBytes() {
if total > 0 {
err = nil
}
- if s.isPacket && n < int64(r.MsgSize) {
+ if isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
}
return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
@@ -601,6 +615,12 @@ func (s *SocketOperations) State() uint32 {
return s.ep.State()
}
+// Type implements socket.Socket.Type.
+func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ // Unix domain sockets always have a protocol of 0.
+ return linux.AF_UNIX, s.stype, 0
+}
+
// provider is a unix domain socket provider.
type provider struct{}
@@ -613,21 +633,16 @@ func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs
// Create the endpoint and socket.
var ep transport.Endpoint
- var isPacket bool
switch stype {
case linux.SOCK_DGRAM:
- isPacket = true
ep = transport.NewConnectionless()
- case linux.SOCK_SEQPACKET:
- isPacket = true
- fallthrough
- case linux.SOCK_STREAM:
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
ep = transport.NewConnectioned(stype, t.Kernel())
default:
return nil, syserr.ErrInvalidArgument
}
- return New(t, ep, isPacket), nil
+ return New(t, ep, stype), nil
}
// Pair creates a new pair of AF_UNIX connected sockets.
@@ -637,19 +652,17 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F
return nil, nil, syserr.ErrProtocolNotSupported
}
- var isPacket bool
switch stype {
- case linux.SOCK_STREAM:
- case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
- isPacket = true
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ // Ok
default:
return nil, nil, syserr.ErrInvalidArgument
}
// Create the endpoints and sockets.
ep1, ep2 := transport.NewPair(stype, t.Kernel())
- s1 := New(t, ep1, isPacket)
- s2 := New(t, ep2, isPacket)
+ s1 := New(t, ep1, stype)
+ s2 := New(t, ep2, stype)
return s1, s2, nil
}