From a00157cc0e216a9829f2659ce35c856a22aa5ba2 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Mon, 10 Jun 2019 15:16:42 -0700 Subject: Store more information in the kernel socket table. Store enough information in the kernel socket table to distinguish between different types of sockets. Previously we were only storing the socket family, but this isn't enough to classify sockets. For example, TCPv4 and UDPv4 sockets are both AF_INET, and ICMP sockets are SOCK_DGRAM sockets with a particular protocol. Instead of creating more sub-tables, flatten the socket table and provide a filtering mechanism based on the socket entry. Also generate and store a socket entry index ("sl" in linux) which allows us to output entries in a stable order from procfs. PiperOrigin-RevId: 252495895 --- pkg/sentry/socket/epsocket/epsocket.go | 13 +++++-- pkg/sentry/socket/epsocket/provider.go | 2 +- pkg/sentry/socket/hostinet/socket.go | 32 +++++++++++------ pkg/sentry/socket/netlink/provider.go | 2 +- pkg/sentry/socket/netlink/socket.go | 12 ++++++- pkg/sentry/socket/rpcinet/socket.go | 16 +++++++-- pkg/sentry/socket/socket.go | 9 +++-- pkg/sentry/socket/unix/unix.go | 65 ++++++++++++++++++++-------------- 8 files changed, 103 insertions(+), 48 deletions(-) (limited to 'pkg/sentry/socket') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index e1e29de35..f67451179 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -228,6 +228,7 @@ type SocketOperations struct { family int Endpoint tcpip.Endpoint skType linux.SockType + protocol int // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` @@ -252,7 +253,7 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { +func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { return nil, syserr.TranslateNetstackError(err) @@ -266,6 +267,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, family: family, Endpoint: endpoint, skType: skType, + protocol: protocol, }), nil } @@ -550,7 +552,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns, err := New(t, s.family, s.skType, wq, ep) + ns, err := New(t, s.family, s.skType, s.protocol, wq, ep) if err != nil { return 0, nil, 0, err } @@ -578,7 +580,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits()) - t.Kernel().RecordSocket(ns, s.family) + t.Kernel().RecordSocket(ns) return fd, addr, addrLen, syserr.FromError(e) } @@ -2324,3 +2326,8 @@ func (s *SocketOperations) State() uint32 { // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. return 0 } + +// Type implements socket.Socket.Type. +func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.skType, s.protocol +} diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index e48a106ea..516582828 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -111,7 +111,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return nil, syserr.TranslateNetstackError(e) } - return New(t, p.family, stype, wq, ep) + return New(t, p.family, stype, protocol, wq, ep) } // Pair just returns nil sockets (not supported). diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 4517951a0..c62c8d8f1 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -56,15 +56,22 @@ type socketOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout - family int // Read-only. - fd int // must be O_NONBLOCK - queue waiter.Queue + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + fd int // must be O_NONBLOCK + queue waiter.Queue } var _ = socket.Socket(&socketOperations{}) -func newSocketFile(ctx context.Context, family int, fd int, nonblock bool) (*fs.File, *syserr.Error) { - s := &socketOperations{family: family, fd: fd} +func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) { + s := &socketOperations{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + } if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } @@ -222,7 +229,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) } - f, err := newSocketFile(t, s.family, fd, flags&syscall.SOCK_NONBLOCK != 0) + f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) if err != nil { syscall.Close(fd) return 0, nil, 0, err @@ -233,7 +240,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, } kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits()) - t.Kernel().RecordSocket(f, s.family) + t.Kernel().RecordSocket(f) return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr) } @@ -542,6 +549,11 @@ func (s *socketOperations) State() uint32 { return uint32(info.State) } +// Type implements socket.Socket.Type. +func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.stype, s.protocol +} + type socketProvider struct { family int } @@ -558,7 +570,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto } // Only accept TCP and UDP. - stype := int(stypeflags) & linux.SOCK_TYPE_MASK + stype := stypeflags & linux.SOCK_TYPE_MASK switch stype { case syscall.SOCK_STREAM: switch protocol { @@ -581,11 +593,11 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto // Conservatively ignore all flags specified by the application and add // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. - fd, err := syscall.Socket(p.family, stype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) + fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { return nil, syserr.FromError(err) } - return newSocketFile(t, p.family, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) + return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) } // Pair implements socket.Provider.Pair. diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 863edc241..5dc103877 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -82,7 +82,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int return nil, err } - s, err := NewSocket(t, p) + s, err := NewSocket(t, stype, p) if err != nil { return nil, err } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 16c79aa33..62659784a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -80,6 +80,10 @@ type Socket struct { // protocol is the netlink protocol implementation. protocol Protocol + // skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for + // netlink sockets. + skType linux.SockType + // ep is a datagram unix endpoint used to buffer messages sent from the // kernel to userspace. RecvMsg reads messages from this endpoint. ep transport.Endpoint @@ -105,7 +109,7 @@ type Socket struct { var _ socket.Socket = (*Socket)(nil) // NewSocket creates a new Socket. -func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { +func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { // Datagram endpoint used to buffer kernel -> user messages. ep := transport.NewConnectionless() @@ -126,6 +130,7 @@ func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { return &Socket{ ports: t.Kernel().NetlinkPorts(), protocol: protocol, + skType: skType, ep: ep, connection: connection, sendBufferSize: defaultSendBufferSize, @@ -621,3 +626,8 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, func (s *Socket) State() uint32 { return s.ep.State() } + +// Type implements socket.Socket.Type. +func (s *Socket) Type() (family int, skType linux.SockType, protocol int) { + return linux.AF_NETLINK, s.skType, s.protocol.Protocol() +} diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 2d5b5b58f..c22ff1ff0 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -53,7 +53,10 @@ type socketOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout - family int // Read-only. + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + fd uint32 // must be O_NONBLOCK wq *waiter.Queue rpcConn *conn.RPCConnection @@ -86,6 +89,8 @@ func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.S defer dirent.DecRef() return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{ family: family, + stype: skType, + protocol: protocol, wq: &wq, fd: fd, rpcConn: stack.rpcConn, @@ -332,7 +337,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if err != nil { return 0, nil, 0, syserr.FromError(err) } - t.Kernel().RecordSocket(file, s.family) + t.Kernel().RecordSocket(file) if peerRequested { return fd, payload.Address.Address, payload.Address.Length, nil @@ -835,6 +840,11 @@ func (s *socketOperations) State() uint32 { return 0 } +// Type implements socket.Socket.Type. +func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.stype, s.protocol +} + type socketProvider struct { family int } @@ -876,7 +886,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto return nil, nil } - return newSocketFile(t, s, p.family, stype, 0) + return newSocketFile(t, s, p.family, stype, protocol) } // Pair implements socket.Provider.Pair. diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index f1021ec67..d60944b6b 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -120,6 +120,9 @@ type Socket interface { // State returns the current state of the socket, as represented by Linux in // procfs. The returned state value is protocol-specific. State() uint32 + + // Type returns the family, socket type and protocol of the socket. + Type() (family int, skType linux.SockType, protocol int) } // Provider is the interface implemented by providers of sockets for specific @@ -156,7 +159,7 @@ func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.Fi return nil, err } if s != nil { - t.Kernel().RecordSocket(s, family) + t.Kernel().RecordSocket(s) return s, nil } } @@ -179,8 +182,8 @@ func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.F } if s1 != nil && s2 != nil { k := t.Kernel() - k.RecordSocket(s1, family) - k.RecordSocket(s2, family) + k.RecordSocket(s1) + k.RecordSocket(s2) return s1, s2, nil } } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 56ed63e21..b07e8d67b 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -17,6 +17,7 @@ package unix import ( + "fmt" "strings" "syscall" @@ -55,22 +56,22 @@ type SocketOperations struct { refs.AtomicRefCount socket.SendReceiveTimeout - ep transport.Endpoint - isPacket bool + ep transport.Endpoint + stype linux.SockType } // New creates a new unix socket. -func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File { +func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) defer dirent.DecRef() - return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true}) + return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true}) } // NewWithDirent creates a new unix socket using an existing dirent. -func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File { +func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File { return fs.NewFile(ctx, d, flags, &SocketOperations{ - ep: ep, - isPacket: isPacket, + ep: ep, + stype: stype, }) } @@ -88,6 +89,18 @@ func (s *SocketOperations) Release() { s.DecRef() } +func (s *SocketOperations) isPacket() bool { + switch s.stype { + case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + return true + case linux.SOCK_STREAM: + return false + default: + // We shouldn't have allowed any other socket types during creation. + panic(fmt.Sprintf("Invalid socket type %d", s.stype)) + } +} + // Endpoint extracts the transport.Endpoint. func (s *SocketOperations) Endpoint() transport.Endpoint { return s.ep @@ -193,7 +206,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns := New(t, ep, s.isPacket) + ns := New(t, ep, s.stype) defer ns.DecRef() if flags&linux.SOCK_NONBLOCK != 0 { @@ -221,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, nil, 0, syserr.FromError(e) } - t.Kernel().RecordSocket(ns, linux.AF_UNIX) + t.Kernel().RecordSocket(ns) return fd, addr, addrLen, nil } @@ -487,6 +500,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 waitAll := flags&linux.MSG_WAITALL != 0 + isPacket := s.isPacket() // Calculate the number of FDs for which we have space and if we are // requesting credentials. @@ -528,8 +542,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msgFlags |= linux.MSG_CTRUNC } - if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() { - if s.isPacket && n < int64(r.MsgSize) { + if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() { + if isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC } @@ -570,11 +584,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags total += n } - if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() { + if err != nil || !waitAll || isPacket || n >= dst.NumBytes() { if total > 0 { err = nil } - if s.isPacket && n < int64(r.MsgSize) { + if isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC } return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) @@ -601,6 +615,12 @@ func (s *SocketOperations) State() uint32 { return s.ep.State() } +// Type implements socket.Socket.Type. +func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { + // Unix domain sockets always have a protocol of 0. + return linux.AF_UNIX, s.stype, 0 +} + // provider is a unix domain socket provider. type provider struct{} @@ -613,21 +633,16 @@ func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs // Create the endpoint and socket. var ep transport.Endpoint - var isPacket bool switch stype { case linux.SOCK_DGRAM: - isPacket = true ep = transport.NewConnectionless() - case linux.SOCK_SEQPACKET: - isPacket = true - fallthrough - case linux.SOCK_STREAM: + case linux.SOCK_SEQPACKET, linux.SOCK_STREAM: ep = transport.NewConnectioned(stype, t.Kernel()) default: return nil, syserr.ErrInvalidArgument } - return New(t, ep, isPacket), nil + return New(t, ep, stype), nil } // Pair creates a new pair of AF_UNIX connected sockets. @@ -637,19 +652,17 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F return nil, nil, syserr.ErrProtocolNotSupported } - var isPacket bool switch stype { - case linux.SOCK_STREAM: - case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: - isPacket = true + case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + // Ok default: return nil, nil, syserr.ErrInvalidArgument } // Create the endpoints and sockets. ep1, ep2 := transport.NewPair(stype, t.Kernel()) - s1 := New(t, ep1, isPacket) - s2 := New(t, ep2, isPacket) + s1 := New(t, ep1, stype) + s2 := New(t, ep2, stype) return s1, s2, nil } -- cgit v1.2.3