// Copyright 2018 Google Inc.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//     http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package linux

import (
	"time"

	"gvisor.googlesource.com/gvisor/pkg/abi/linux"
	"gvisor.googlesource.com/gvisor/pkg/sentry/arch"
	"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
	"gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
	ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
	"gvisor.googlesource.com/gvisor/pkg/sentry/limits"
	"gvisor.googlesource.com/gvisor/pkg/sentry/syscalls"
	"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
	"gvisor.googlesource.com/gvisor/pkg/syserror"
	"gvisor.googlesource.com/gvisor/pkg/waiter"
)

// fileCap is the upper bound (2^20 == 1024*1024) on the number of file
// descriptors a single poll or select call may request. It is passed as the
// cap argument when the task's NumberOfFiles limit is queried.
const fileCap = 1 << 20

// Masks for "readable", "writable", and "exceptional" events as defined by
// select(2).
const (
	// selectReadEvents is analogous to the Linux kernel's
	// fs/select.c:POLLIN_SET.
	selectReadEvents = waiter.EventIn | waiter.EventHUp | waiter.EventErr

	// selectWriteEvents is analogous to the Linux kernel's
	// fs/select.c:POLLOUT_SET.
	selectWriteEvents = waiter.EventOut | waiter.EventErr

	// selectExceptEvents is analogous to the Linux kernel's
	// fs/select.c:POLLEX_SET.
	selectExceptEvents = waiter.EventPri

func doPoll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (time.Duration, uintptr, error) {
	if uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
		return timeout, 0, syserror.EINVAL

	pfd := make([]syscalls.PollFD, nfds)
	if nfds > 0 {
		if _, err := t.CopyIn(pfdAddr, &pfd); err != nil {
			return timeout, 0, err

	// Compatibility warning: Linux adds POLLHUP and POLLERR just before
	// polling, in fs/select.c:do_pollfd(). Since pfd is copied out after
	// polling, changing event masks here is an application-visible difference.
	// (Linux also doesn't copy out event masks at all, only revents.)
	for i := range pfd {
		pfd[i].Events |= waiter.EventHUp | waiter.EventErr
	remainingTimeout, n, err := syscalls.Poll(t, pfd, timeout)
	err = syserror.ConvertIntr(err, syserror.EINTR)

	// The poll entries are copied out regardless of whether
	// any are set or not. This aligns with the Linux behavior.
	if nfds > 0 && err == nil {
		if _, err := t.CopyOut(pfdAddr, pfd); err != nil {
			return remainingTimeout, 0, err

	return remainingTimeout, n, err

func doSelect(t *kernel.Task, nfds int, readFDs, writeFDs, exceptFDs usermem.Addr, timeout time.Duration) (uintptr, error) {
	if nfds < 0 || uint64(nfds) > t.ThreadGroup().Limits().GetCapped(limits.NumberOfFiles, fileCap) {
		return 0, syserror.EINVAL

	// Capture all the provided input vectors.
	// N.B. This only works on little-endian architectures.
	byteCount := (nfds + 7) / 8
	bitsInLastPartialByte := uint(nfds % 8)
	r := make([]byte, byteCount)
	w := make([]byte, byteCount)
	e := make([]byte, byteCount)

	if readFDs != 0 {
		if _, err := t.CopyIn(readFDs, &r); err != nil {
			return 0, err
		// Mask out bits above nfds.
		if bitsInLastPartialByte != 0 {
			r[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte

	if writeFDs != 0 {
		if _, err := t.CopyIn(writeFDs, &w); err != nil {
			return 0, err
		if bitsInLastPartialByte != 0 {
			w[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte

	if exceptFDs != 0 {
		if _, err := t.CopyIn(exceptFDs, &e); err != nil {
			return 0, err
		if bitsInLastPartialByte != 0 {
			e[byteCount-1] &^= byte(0xff) << bitsInLastPartialByte

	// Count how many FDs are actually being requested so that we can build
	// a PollFD array.
	fdCount := 0
	for i := 0; i < byteCount; i++ {
		v := r[i] | w[i] | e[i]
		for v != 0 {
			v &= (v - 1)

	// Build the PollFD array.
	pfd := make([]syscalls.PollFD, 0, fdCount)
	fd := kdefs.FD(0)
	for i := 0; i < byteCount; i++ {
		rV, wV, eV := r[i], w[i], e[i]
		v := rV | wV | eV
		m := byte(1)
		for j := 0; j < 8; j++ {
			if (v & m) != 0 {
				// Make sure the fd is valid and decrement the reference
				// immediately to ensure we don't leak. Note, another thread
				// might be about to close fd. This is racy, but that's
				// OK. Linux is racy in the same way.
				file := t.FDMap().GetFile(fd)
				if file == nil {
					return 0, syserror.EBADF

				mask := waiter.EventMask(0)
				if (rV & m) != 0 {
					mask |= selectReadEvents

				if (wV & m) != 0 {
					mask |= selectWriteEvents

				if (eV & m) != 0 {
					mask |= selectExceptEvents

				pfd = append(pfd, syscalls.PollFD{
					FD:     fd,
					Events: mask,

			m <<= 1

	// Do the syscall, then count the number of bits set.
	_, _, err := syscalls.Poll(t, pfd, timeout)
	if err != nil {
		return 0, syserror.ConvertIntr(err, syserror.EINTR)

	// r, w, and e are currently event mask bitsets; unset bits corresponding
	// to events that *didn't* occur.
	bitSetCount := uintptr(0)
	for idx := range pfd {
		events := pfd[idx].REvents
		i, j := pfd[idx].FD/8, uint(pfd[idx].FD%8)
		m := byte(1) << j
		if r[i]&m != 0 {
			if (events & selectReadEvents) != 0 {
			} else {
				r[i] &^= m
		if w[i]&m != 0 {
			if (events & selectWriteEvents) != 0 {
			} else {
				w[i] &^= m
		if e[i]&m != 0 {
			if (events & selectExceptEvents) != 0 {
			} else {
				e[i] &^= m

	// Copy updated vectors back.
	if readFDs != 0 {
		if _, err := t.CopyOut(readFDs, r); err != nil {
			return 0, err

	if writeFDs != 0 {
		if _, err := t.CopyOut(writeFDs, w); err != nil {
			return 0, err

	if exceptFDs != 0 {
		if _, err := t.CopyOut(exceptFDs, e); err != nil {
			return 0, err

	return bitSetCount, nil

// timeoutRemaining returns the amount of time remaining for the specified
// timeout or 0 if it has elapsed.
// startNs must be from CLOCK_MONOTONIC.
func timeoutRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration) time.Duration {
	now := t.Kernel().MonotonicClock().Now()
	remaining := timeout - now.Sub(startNs)
	if remaining < 0 {
		remaining = 0
	return remaining

// copyOutTimespecRemaining copies the time remaining in timeout to timespecAddr.
// startNs must be from CLOCK_MONOTONIC.
func copyOutTimespecRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timespecAddr usermem.Addr) error {
	if timeout <= 0 {
		return nil
	remaining := timeoutRemaining(t, startNs, timeout)
	tsRemaining := linux.NsecToTimespec(remaining.Nanoseconds())
	return copyTimespecOut(t, timespecAddr, &tsRemaining)

// copyOutTimevalRemaining copies the time remaining in timeout to timevalAddr.
// startNs must be from CLOCK_MONOTONIC.
func copyOutTimevalRemaining(t *kernel.Task, startNs ktime.Time, timeout time.Duration, timevalAddr usermem.Addr) error {
	if timeout <= 0 {
		return nil
	remaining := timeoutRemaining(t, startNs, timeout)
	tvRemaining := linux.NsecToTimeval(remaining.Nanoseconds())
	return copyTimevalOut(t, timevalAddr, &tvRemaining)

// pollRestartBlock encapsulates the state required to restart poll(2) via
// restart_syscall(2).
type pollRestartBlock struct {
	pfdAddr usermem.Addr
	nfds    uint
	timeout time.Duration

// Restart implements kernel.SyscallRestartBlock.Restart.
func (p *pollRestartBlock) Restart(t *kernel.Task) (uintptr, error) {
	return poll(t, p.pfdAddr, p.nfds, p.timeout)

func poll(t *kernel.Task, pfdAddr usermem.Addr, nfds uint, timeout time.Duration) (uintptr, error) {
	remainingTimeout, n, err := doPoll(t, pfdAddr, nfds, timeout)
	// On an interrupt poll(2) is restarted with the remaining timeout.
	if err == syserror.EINTR {
			pfdAddr: pfdAddr,
			nfds:    nfds,
			timeout: remainingTimeout,
		return 0, kernel.ERESTART_RESTARTBLOCK
	return n, err

// Poll implements linux syscall poll(2).
func Poll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	pfdAddr := args[0].Pointer()
	nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
	timeout := time.Duration(args[2].Int()) * time.Millisecond
	n, err := poll(t, pfdAddr, nfds, timeout)
	return n, nil, err

// Ppoll implements linux syscall ppoll(2).
func Ppoll(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	pfdAddr := args[0].Pointer()
	nfds := uint(args[1].Uint()) // poll(2) uses unsigned long.
	timespecAddr := args[2].Pointer()
	maskAddr := args[3].Pointer()
	maskSize := uint(args[4].Uint())

	timeout, err := copyTimespecInToDuration(t, timespecAddr)
	if err != nil {
		return 0, nil, err

	var startNs ktime.Time
	if timeout > 0 {
		startNs = t.Kernel().MonotonicClock().Now()

	if maskAddr != 0 {
		mask, err := copyInSigSet(t, maskAddr, maskSize)
		if err != nil {
			return 0, nil, err

		oldmask := t.SignalMask()

	_, n, err := doPoll(t, pfdAddr, nfds, timeout)
	copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
	// doPoll returns EINTR if interrupted, but ppoll is normally restartable
	// if interrupted by something other than a signal handled by the
	// application (i.e. returns ERESTARTNOHAND). However, if
	// copyOutTimespecRemaining failed, then the restarted ppoll would use the
	// wrong timeout, so the error should be left as EINTR.
	// Note that this means that if err is nil but copyErr is not, copyErr is
	// ignored. This is consistent with Linux.
	if err == syserror.EINTR && copyErr == nil {
		err = kernel.ERESTARTNOHAND
	return n, nil, err

// Select implements linux syscall select(2).
func Select(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	nfds := int(args[0].Int()) // select(2) uses an int.
	readFDs := args[1].Pointer()
	writeFDs := args[2].Pointer()
	exceptFDs := args[3].Pointer()
	timevalAddr := args[4].Pointer()

	// Use a negative Duration to indicate "no timeout".
	timeout := time.Duration(-1)
	if timevalAddr != 0 {
		timeval, err := copyTimevalIn(t, timevalAddr)
		if err != nil {
			return 0, nil, err
		if timeval.Sec < 0 || timeval.Usec < 0 {
			return 0, nil, syserror.EINVAL
		timeout = time.Duration(timeval.ToNsecCapped())
	startNs := t.Kernel().MonotonicClock().Now()
	n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
	copyErr := copyOutTimevalRemaining(t, startNs, timeout, timevalAddr)
	// See comment in Ppoll.
	if err == syserror.EINTR && copyErr == nil {
		err = kernel.ERESTARTNOHAND
	return n, nil, err

// Pselect implements linux syscall pselect(2).
func Pselect(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
	nfds := int(args[0].Int()) // select(2) uses an int.
	readFDs := args[1].Pointer()
	writeFDs := args[2].Pointer()
	exceptFDs := args[3].Pointer()
	timespecAddr := args[4].Pointer()
	maskWithSizeAddr := args[5].Pointer()

	timeout, err := copyTimespecInToDuration(t, timespecAddr)
	if err != nil {
		return 0, nil, err

	var startNs ktime.Time
	if timeout > 0 {
		startNs = t.Kernel().MonotonicClock().Now()

	if maskWithSizeAddr != 0 {
		maskAddr, size, err := copyInSigSetWithSize(t, maskWithSizeAddr)
		if err != nil {
			return 0, nil, err

		if maskAddr != 0 {
			mask, err := copyInSigSet(t, maskAddr, size)
			if err != nil {
				return 0, nil, err
			oldmask := t.SignalMask()

	n, err := doSelect(t, nfds, readFDs, writeFDs, exceptFDs, timeout)
	copyErr := copyOutTimespecRemaining(t, startNs, timeout, timespecAddr)
	// See comment in Ppoll.
	if err == syserror.EINTR && copyErr == nil {
		err = kernel.ERESTARTNOHAND
	return n, nil, err