summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/control/control.go10
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go25
-rw-r--r--pkg/sentry/socket/hostinet/socket.go9
-rw-r--r--pkg/sentry/socket/netlink/socket.go3
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go8
-rw-r--r--pkg/sentry/socket/socket.go3
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go4
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go30
-rw-r--r--pkg/sentry/socket/unix/unix.go17
9 files changed, 55 insertions, 54 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index b646dc258..4f4a20dfe 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
- "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
@@ -63,7 +62,7 @@ type RightsFiles []*fs.File
func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
files := make(RightsFiles, 0, len(fds))
for _, fd := range fds {
- file, _ := t.FDMap().GetDescriptor(kdefs.FD(fd))
+ file := t.GetFile(fd)
if file == nil {
files.Release()
return nil, syserror.EBADF
@@ -109,7 +108,9 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32
files, trunc := rights.Files(t, max)
fds := make([]int32, 0, len(files))
for i := 0; i < max && len(files) > 0; i++ {
- fd, err := t.FDMap().NewFDFrom(0, files[0], kernel.FDFlags{cloexec}, t.ThreadGroup().Limits())
+ fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{
+ CloseOnExec: cloexec,
+ })
files[0].DecRef()
files = files[1:]
if err != nil {
@@ -315,8 +316,7 @@ func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
// Parse parses a raw socket control message into portable objects.
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) {
var (
- fds linux.ControlMessageRights
-
+ fds linux.ControlMessageRights
haveCreds bool
creds linux.ControlMessageCredentials
)
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 2a38e370a..9d1bcfd41 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -40,7 +40,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -286,14 +285,14 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
// GetAddress reads an sockaddr struct from the given address and converts it
// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
// addresses.
-func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
+func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
family := usermem.ByteOrder.Uint16(addr)
- if family != uint16(sfamily) {
+ if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
}
@@ -318,7 +317,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
case linux.AF_INET:
var a linux.SockAddrInet
if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, syserr.ErrBadAddress
+ return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
@@ -331,7 +330,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
case linux.AF_INET6:
var a linux.SockAddrInet6
if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, syserr.ErrBadAddress
+ return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
@@ -344,6 +343,9 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
}
return out, nil
+ case linux.AF_UNSPEC:
+ return tcpip.FullAddress{}, nil
+
default:
return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
}
@@ -466,7 +468,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr)
+ addr, err := GetAddress(s.family, sockaddr, false /* strict */)
if err != nil {
return err
}
@@ -499,7 +501,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr)
+ addr, err := GetAddress(s.family, sockaddr, true /* strict */)
if err != nil {
return err
}
@@ -537,7 +539,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait
// Accept implements the linux syscall accept(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
ep, wq, terr := s.Endpoint.Accept()
if terr != nil {
@@ -575,10 +577,9 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- fdFlags := kernel.FDFlags{
+ fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
- }
- fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
+ })
t.Kernel().RecordSocket(ns)
@@ -1924,7 +1925,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, err := GetAddress(s.family, to)
+ addrBuf, err := GetAddress(s.family, to, true /* strict */)
if err != nil {
return 0, err
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index c63f3aacf..7f69406b7 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -190,7 +189,7 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
var peerAddr []byte
var peerAddrlen uint32
var peerAddrPtr *byte
@@ -236,11 +235,11 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
defer f.DecRef()
- fdFlags := kernel.FDFlags{
+ kfd, kerr := t.NewFDFrom(0, f, kernel.FDFlags{
CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
- }
- kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits())
+ })
t.Kernel().RecordSocket(f)
+
return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index ecc1e2d53..f3d6c1e9b 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
@@ -272,7 +271,7 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr
}
// Accept implements socket.Socket.Accept.
-func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
// Netlink sockets never support accept.
return 0, nil, 0, syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index cc7b964ea..ccaaddbfc 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
@@ -286,7 +285,7 @@ func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultP
}
// Accept implements socket.Socket.Accept.
-func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
payload, se := rpcAccept(t, s.fd, peerRequested)
// Check if we need to block.
@@ -336,10 +335,9 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
})
defer file.DecRef()
- fdFlags := kernel.FDFlags{
+ fd, err := t.NewFDFrom(0, file, kernel.FDFlags{
CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
- }
- fd, err := t.FDMap().NewFDFrom(0, file, fdFlags, t.ThreadGroup().Limits())
+ })
if err != nil {
return 0, nil, 0, syserr.FromError(err)
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 933120f34..0efa58a58 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -27,7 +27,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
@@ -53,7 +52,7 @@ type Socket interface {
// Accept implements the accept4(2) linux syscall.
// Returns fd, real peer address length and error. Real peer address
// length is only set if len(peer) > 0.
- Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error)
+ Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error)
// Bind implements the bind(2) linux syscall.
Bind(t *kernel.Task, sockaddr []byte) *syserr.Error
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index e4c416233..73d2df15d 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -143,7 +143,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E
}
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
+ q1.EnableLeakCheck("transport.queue")
q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
+ q2.EnableLeakCheck("transport.queue")
if stype == linux.SOCK_STREAM {
a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
@@ -294,12 +296,14 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
}
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
+ readQueue.EnableLeakCheck("transport.queue")
ne.connected = &connectedEndpoint{
endpoint: ce,
writeQueue: readQueue,
}
writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
+ writeQueue.EnableLeakCheck("transport.queue")
if e.stype == linux.SOCK_STREAM {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
} else {
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index e987519f0..c7f7c5b16 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -41,7 +41,9 @@ var (
// NewConnectionless creates a new unbound dgram endpoint.
func NewConnectionless(ctx context.Context) Endpoint {
ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
- ep.receiver = &queueReceiver{readQueue: &queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}}
+ q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
+ q.EnableLeakCheck("transport.queue")
+ ep.receiver = &queueReceiver{readQueue: &q}
return ep
}
@@ -52,29 +54,24 @@ func (e *connectionlessEndpoint) isBound() bool {
// Close puts the endpoint in a closed state and frees all resources associated
// with it.
-//
-// The socket will be a fresh state after a call to close and may be reused.
-// That is, close may be used to "unbind" or "disconnect" the socket in error
-// paths.
func (e *connectionlessEndpoint) Close() {
e.Lock()
- var r Receiver
- if e.Connected() {
- e.receiver.CloseRecv()
- r = e.receiver
- e.receiver = nil
-
+ if e.connected != nil {
e.connected.Release()
e.connected = nil
}
+
if e.isBound() {
e.path = ""
}
+
+ e.receiver.CloseRecv()
+ r := e.receiver
+ e.receiver = nil
e.Unlock()
- if r != nil {
- r.CloseNotify()
- r.Release()
- }
+
+ r.CloseNotify()
+ r.Release()
}
// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
@@ -137,6 +134,9 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi
}
e.Lock()
+ if e.connected != nil {
+ e.connected.Release()
+ }
e.connected = connected
e.Unlock()
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 6190de0c5..637168714 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -28,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
@@ -69,10 +68,13 @@ func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType)
// NewWithDirent creates a new unix socket using an existing dirent.
func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File {
- return fs.NewFile(ctx, d, flags, &SocketOperations{
+ s := SocketOperations{
ep: ep,
stype: stype,
- })
+ }
+ s.EnableLeakCheck("unix.SocketOperations")
+
+ return fs.NewFile(ctx, d, flags, &s)
}
// DecRef implements RefCounter.DecRef.
@@ -108,7 +110,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr)
+ addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
return "", err
}
@@ -191,7 +193,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
// Accept implements the linux syscall accept(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
ep, err := s.ep.Accept()
if err != nil {
@@ -226,10 +228,9 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- fdFlags := kernel.FDFlags{
+ fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
- }
- fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
+ })
if e != nil {
return 0, nil, 0, syserr.FromError(e)
}