// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <arpa/inet.h>
#include <netinet/in.h>
#include <sys/sendfile.h>
#include <sys/socket.h>
#include <unistd.h>

#include <iostream>
#include <vector>

#include "gtest/gtest.h"
#include "absl/strings/string_view.h"
#include "test/syscalls/linux/socket_test_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
#include "test/util/thread_util.h"

namespace gvisor {
namespace testing {
namespace {

class SendFileTest : public ::testing::TestWithParam<int> {
 protected:
  // Creates a connected pair of SOCK_STREAM sockets of the family given by
  // GetParam() (AF_INET or AF_UNIX) by binding and listening on a server
  // socket, connecting a client to it, and accepting the connection.
  //
  // Returns (connecting socket, accepted socket); the caller owns both
  // descriptors. The listening socket is always closed before returning, and
  // every descriptor opened here is closed again on each error path.
  //
  // Errors: PosixError(EINVAL) for an unsupported family, otherwise the errno
  // of the first syscall that failed.
  PosixErrorOr<std::tuple<int, int>> Sockets() {
    // Bind a server socket.
    int family = GetParam();
    struct sockaddr server_addr = {};
    switch (family) {
      case AF_INET: {
        // Port stays 0 (zero-initialized), so the kernel picks one at bind.
        struct sockaddr_in *server_addr_in =
            reinterpret_cast<struct sockaddr_in *>(&server_addr);
        server_addr_in->sin_family = family;
        server_addr_in->sin_addr.s_addr = INADDR_ANY;
        break;
      }
      case AF_UNIX: {
        // A leading NUL in sun_path selects the abstract namespace (see
        // unix(7)); the remaining bytes of server_addr stay zero.
        struct sockaddr_un *server_addr_un =
            reinterpret_cast<struct sockaddr_un *>(&server_addr);
        server_addr_un->sun_family = family;
        server_addr_un->sun_path[0] = '\0';
        break;
      }
      default:
        return PosixError(EINVAL);
    }

    int server = socket(family, SOCK_STREAM, 0);
    if (server < 0) {
      return PosixError(errno);
    }
    if (bind(server, &server_addr, sizeof(server_addr)) < 0) {
      // Capture errno before close(), which could otherwise overwrite it.
      const int err = errno;
      close(server);
      return PosixError(err);
    }
    if (listen(server, 1) < 0) {
      const int err = errno;
      close(server);
      return PosixError(err);
    }

    // Fetch the address that was actually bound; both are anonymous.
    socklen_t length = sizeof(server_addr);
    if (getsockname(server, &server_addr, &length) < 0) {
      const int err = errno;
      close(server);
      return PosixError(err);
    }

    // Connect the client.
    int client = socket(family, SOCK_STREAM, 0);
    if (client < 0) {
      const int err = errno;
      close(server);
      return PosixError(err);
    }
    if (connect(client, &server_addr, length) < 0) {
      const int err = errno;
      close(server);
      close(client);
      return PosixError(err);
    }

    // Accept on the server; the peer address is not needed.
    int server_client = accept(server, nullptr, nullptr);
    if (server_client < 0) {
      const int err = errno;
      close(server);
      close(client);
      return PosixError(err);
    }
    close(server);
    return std::make_tuple(client, server_client);
  }
};

// Sends a large (5 MiB) file through a connected socket pair with sendfile(2)
// to exercise the path that reads and writes data multiple times, especially
// when more data is read than can be written. A helper thread drains the peer
// socket into a second file, whose contents are then compared to the input.
TEST_P(SendFileTest, SendMultiple) {
  std::vector<char> data(5 * 1024 * 1024);
  RandomizeBuffer(data.data(), data.size());

  // Create temp files.
  const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
      GetAbsoluteTestTmpdir(), absl::string_view(data.data(), data.size()),
      TempPath::kDefaultFileMode));
  const TempPath out_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());

  // Create sockets. Sockets() returns (connecting socket, accepted socket);
  // the names below describe the roles in this test only: "server" is read
  // from by the helper thread, "client" is the sendfile() destination.
  std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets());
  const FileDescriptor server(std::get<0>(fds));
  FileDescriptor client(std::get<1>(fds));  // non-const, reset is used.

  // Thread that reads data from socket and dumps to a file.
  ScopedThread th([&] {
    FileDescriptor outf =
        ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_WRONLY));

    // Read until socket is closed.
    char buf[10240];
    for (int cnt = 0;; cnt++) {
      int r = RetryEINTR(read)(server.get(), buf, sizeof(buf));
      // We cannot afford to save on every read() call.
      if (cnt % 1000 == 0) {
        ASSERT_THAT(r, SyscallSucceeds());
      } else {
        const DisableSave ds;
        ASSERT_THAT(r, SyscallSucceeds());
      }
      if (r == 0) {
        // EOF
        break;
      }
      int w = RetryEINTR(write)(outf.get(), buf, r);
      // We cannot afford to save on every write() call.
      if (cnt % 1010 == 0) {
        ASSERT_THAT(w, SyscallSucceedsWithValue(r));
      } else {
        const DisableSave ds;
        ASSERT_THAT(w, SyscallSucceedsWithValue(r));
      }
    }
  });

  // Open the input file as read only.
  const FileDescriptor inf =
      ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));

  int cnt = 0;
  for (size_t sent = 0; sent < data.size(); cnt++) {
    const size_t remain = data.size() - sent;
    // One progress line per iteration.
    std::cout << "sendfile, size=" << data.size() << ", sent=" << sent
              << ", remain=" << remain << '\n';

    // Send data and verify that sendfile returns the correct value.
    const ssize_t res = sendfile(client.get(), inf.get(), nullptr, remain);
    // Capture errno right away: MaybeSave() below runs before errno is
    // inspected, and it must not be able to change the value we check.
    const int sendfile_errno = errno;
    // We cannot afford to save on every sendfile() call.
    if (cnt % 120 == 0) {
      MaybeSave();
    }
    if (res == 0) {
      // EOF
      break;
    }
    if (res > 0) {
      sent += res;
    } else {
      ASSERT_TRUE(sendfile_errno == EINTR || sendfile_errno == EAGAIN)
          << "errno=" << sendfile_errno;
    }
  }

  // Close socket to stop thread.
  client.reset();
  th.Join();

  // Verify that the output file has the correct data.
  const FileDescriptor outf =
      ASSERT_NO_ERRNO_AND_VALUE(Open(out_file.path(), O_RDONLY));
  std::vector<char> actual(data.size(), '\0');
  ASSERT_THAT(RetryEINTR(read)(outf.get(), actual.data(), actual.size()),
              SyscallSucceedsWithValue(actual.size()));
  ASSERT_EQ(memcmp(data.data(), actual.data(), data.size()), 0);
}

// Streams a file into a socket with sendfile(2) while the peer reads 1 MiB
// and then closes its end. Every sendfile() call must either succeed or fail
// with ECONNRESET/EPIPE; the loop stops at the first non-positive result.
TEST_P(SendFileTest, Shutdown) {
  // Create a socket.
  std::tuple<int, int> fds = ASSERT_NO_ERRNO_AND_VALUE(Sockets());
  const FileDescriptor client(std::get<0>(fds));
  FileDescriptor server(std::get<1>(fds));  // non-const, released below.

  // If this is a TCP socket, enable SO_LINGER with a zero timeout. This makes
  // close() abort the connection (reset) instead of closing it gracefully,
  // which is why ECONNRESET is one of the accepted outcomes below.
  if (GetParam() == AF_INET) {
    struct linger sl;
    sl.l_onoff = 1;
    sl.l_linger = 0;
    ASSERT_THAT(
        setsockopt(server.get(), SOL_SOCKET, SO_LINGER, &sl, sizeof(sl)),
        SyscallSucceeds());
  }

  // Create a 1 MiB file with random data.
  std::vector<char> data(1024 * 1024);
  RandomizeBuffer(data.data(), data.size());
  const TempPath in_file = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
      GetAbsoluteTestTmpdir(), absl::string_view(data.data(), data.size()),
      TempPath::kDefaultFileMode));
  const FileDescriptor inf =
      ASSERT_NO_ERRNO_AND_VALUE(Open(in_file.path(), O_RDONLY));

  // Read some data, then shutdown the socket. We don't actually care about
  // checking the contents (other tests do that), so we just re-use the same
  // buffer as above.
  ScopedThread t([&]() {
    size_t done = 0;
    while (done < data.size()) {
      const ssize_t n =
          RetryEINTR(read)(server.get(), data.data(), data.size());
      ASSERT_THAT(n, SyscallSucceeds());
      if (n == 0) {
        // EOF: nothing more will arrive, so stop rather than spin forever.
        break;
      }
      done += n;
    }
    // Close the server side socket.
    ASSERT_THAT(close(server.release()), SyscallSucceeds());
  });

  // Continuously stream from the file to the socket. Note we do not assert
  // that a specific amount of data has been written at any time, just that some
  // data is written. Eventually, we should get a connection reset error.
  while (true) {
    off_t offset = 0;  // Always read from the start.
    const ssize_t n = sendfile(client.get(), inf.get(), &offset, data.size());
    EXPECT_THAT(n, AnyOf(SyscallFailsWithErrno(ECONNRESET),
                         SyscallFailsWithErrno(EPIPE), SyscallSucceeds()));
    if (n <= 0) {
      break;
    }
  }
}

INSTANTIATE_TEST_SUITE_P(AddressFamily, SendFileTest,
                         ::testing::Values(AF_UNIX, AF_INET));

}  // namespace
}  // namespace testing
}  // namespace gvisor