From 81f4829d1195276d037f8bd23a2ef69e88f5ae6c Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Wed, 20 Mar 2019 14:30:00 -0700 Subject: Record sockets created during accept(2) for all families. Track new sockets created during accept(2) in the socket table for all families. Previously we were only doing this for unix domain sockets. PiperOrigin-RevId: 239475550 Change-Id: I16f009f24a06245bfd1d72ffd2175200f837c6ac --- pkg/sentry/socket/hostinet/socket.go | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) (limited to 'pkg/sentry/socket/hostinet') diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 2c54e8de2..a0a8a3220 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -53,14 +53,15 @@ type socketOperations struct { fsutil.FileNoMMap `state:"nosave"` socket.SendReceiveTimeout - fd int // must be O_NONBLOCK - queue waiter.Queue + family int // Read-only. + fd int // must be O_NONBLOCK + queue waiter.Queue } var _ = socket.Socket(&socketOperations{}) -func newSocketFile(ctx context.Context, fd int, nonblock bool) (*fs.File, *syserr.Error) { - s := &socketOperations{fd: fd} +func newSocketFile(ctx context.Context, family int, fd int, nonblock bool) (*fs.File, *syserr.Error) { + s := &socketOperations{family: family, fd: fd} if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } @@ -218,7 +219,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) } - f, err := newSocketFile(t, fd, flags&syscall.SOCK_NONBLOCK != 0) + f, err := newSocketFile(t, s.family, fd, flags&syscall.SOCK_NONBLOCK != 0) if err != nil { syscall.Close(fd) return 0, nil, 0, err @@ -229,6 +230,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, } kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits()) + t.Kernel().RecordSocket(f, s.family) return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr) } @@ -552,7 +554,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p if err != nil { return nil, syserr.FromError(err) } - return newSocketFile(t, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) + return newSocketFile(t, p.family, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) } // Pair implements socket.Provider.Pair. -- cgit v1.2.3