// Copyright 2018 The gVisor Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef GVISOR_TEST_SYSCALLS_SOCKET_TEST_UTIL_H_
#define GVISOR_TEST_SYSCALLS_SOCKET_TEST_UTIL_H_

#include <errno.h>
#include <netinet/ip.h>
#include <netinet/ip_icmp.h>
#include <netinet/udp.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <sys/un.h>

#include <functional>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "gtest/gtest.h"
#include "absl/strings/str_format.h"
#include "test/util/file_descriptor.h"
#include "test/util/posix_error.h"
#include "test/util/test_util.h"

namespace gvisor {
namespace testing {

// Creates a socket with socket(2) and wraps the resulting descriptor in a
// FileDescriptor. If socket(2) fails, returns a PosixError built from errno
// and a rendering of the attempted call.
inline PosixErrorOr<FileDescriptor> Socket(int family, int type, int protocol) {
  const int sock = socket(family, type, protocol);
  MaybeSave();
  if (sock >= 0) {
    return FileDescriptor(sock);
  }
  return PosixError(
      errno, absl::StrFormat("socket(%d, %d, %d)", family, type, protocol));
}

// Accepts a connection on `sockfd` with accept(2), issued through the
// RetryEINTR helper, and wraps the new descriptor in a FileDescriptor. `addr`
// and `addrlen` are passed straight through to accept(2). On failure, returns
// a PosixError built from errno and a rendering of the attempted call.
inline PosixErrorOr<FileDescriptor> Accept(int sockfd, sockaddr* addr,
                                           socklen_t* addrlen) {
  const int accepted = RetryEINTR(accept)(sockfd, addr, addrlen);
  MaybeSave();
  if (accepted >= 0) {
    return FileDescriptor(accepted);
  }
  return PosixError(
      errno, absl::StrFormat("accept(%d, %p, %p)", sockfd, addr, addrlen));
}

// Accepts a connection on `sockfd` with accept4(2), issued through the
// RetryEINTR helper, and wraps the new descriptor in a FileDescriptor. `addr`,
// `addrlen` and `flags` are passed straight through to accept4(2). On failure,
// returns a PosixError built from errno and a rendering of the attempted call.
inline PosixErrorOr<FileDescriptor> Accept4(int sockfd, sockaddr* addr,
                                            socklen_t* addrlen, int flags) {
  const int accepted = RetryEINTR(accept4)(sockfd, addr, addrlen, flags);
  MaybeSave();
  if (accepted >= 0) {
    return FileDescriptor(accepted);
  }
  return PosixError(errno, absl::StrFormat("accept4(%d, %p, %p, %#x)", sockfd,
                                           addr, addrlen, flags));
}

// Sends `count` bytes starting at `buf` on `fd` using sendto(2) with no
// destination address, passing `flags` through unchanged. The individual
// sendto calls are driven by internal::ApplyFileIoSyscall, which supplies the
// `completed` offset used to skip the part of the buffer already handled.
// Returns whatever internal::ApplyFileIoSyscall returns.
inline ssize_t SendFd(int fd, void* buf, size_t count, int flags) {
  char* const data = static_cast<char*>(buf);
  auto send_remaining = [&](size_t completed) {
    return sendto(fd, data + completed, count - completed, flags, nullptr, 0);
  };
  return internal::ApplyFileIoSyscall(send_remaining, count);
}

// UniqueUnixAddr returns a sockaddr_un for use by tests, or a PosixError if
// one could not be produced.
// NOTE(review): presumably `abstract` selects an abstract-namespace address
// instead of a filesystem path and `domain` ends up in sun_family; the
// definition is not in this header -- confirm there.
PosixErrorOr<struct sockaddr_un> UniqueUnixAddr(bool abstract, int domain);

// A Creator<T> is a function that attempts to create and return a new T. (This
// is copy/pasted from cloud/gvisor/api/sandbox_util.h and is just duplicated
// here for clarity.)
//
// A Creator<T> takes no arguments and yields a
// PosixErrorOr<std::unique_ptr<T>>: on success the caller receives ownership
// of the new T through the unique_ptr, otherwise a PosixError is reported.
template <typename T>
using Creator = std::function<PosixErrorOr<std::unique_ptr<T>>()>;

// SocketPair is the abstract interface to a pair of socket file descriptors
// owned by the implementing object, together with the (optional) address of
// each end.
class SocketPair {
 public:
  virtual ~SocketPair() = default;

  // File descriptor of each end of the pair.
  virtual int first_fd() const = 0;
  virtual int second_fd() const = 0;

  // Releases the descriptor of one end to the caller and returns it. The
  // implementations in this header forward to FileDescriptor::release().
  virtual int release_first_fd() = 0;
  virtual int release_second_fd() = 0;

  // Address of each end; nullptr when the implementation stores no address.
  virtual const sockaddr* first_addr() const = 0;
  virtual const sockaddr* second_addr() const = 0;

  // Size of the address structure of each end (0 when there is no address).
  virtual size_t first_addr_size() const = 0;
  virtual size_t second_addr_size() const = 0;

  // Length of the address of each end (0 when there is no address). This can
  // differ from the structure size: for Unix addresses AddrFDSocketPair
  // derives it from sun_path via CalculateUnixSockAddrLen.
  virtual size_t first_addr_len() const = 0;
  virtual size_t second_addr_len() const = 0;
};

// An FDSocketPair is the SocketPair for sockets that carry no address
// information: it holds just the two file descriptors and reports null
// addresses with zero size and zero length.
class FDSocketPair : public SocketPair {
 public:
  // Wraps two raw descriptors in FileDescriptor members.
  FDSocketPair(int first_fd, int second_fd)
      : first_(first_fd), second_(second_fd) {}

  // Takes over the descriptors released from two FileDescriptor objects. Both
  // pointers are dereferenced, so neither may be null.
  FDSocketPair(std::unique_ptr<FileDescriptor> first_fd,
               std::unique_ptr<FileDescriptor> second_fd)
      : FDSocketPair(first_fd->release(), second_fd->release()) {}

  // Descriptor access.
  int first_fd() const override { return first_.get(); }
  int second_fd() const override { return second_.get(); }

  // Descriptor hand-off to the caller.
  int release_first_fd() override { return first_.release(); }
  int release_second_fd() override { return second_.release(); }

  // No addresses are tracked, so every address query reports "none".
  const sockaddr* first_addr() const override { return nullptr; }
  const sockaddr* second_addr() const override { return nullptr; }
  size_t first_addr_size() const override { return 0; }
  size_t second_addr_size() const override { return 0; }
  size_t first_addr_len() const override { return 0; }
  size_t second_addr_len() const override { return 0; }

 private:
  FileDescriptor first_;
  FileDescriptor second_;
};

// CalculateUnixSockAddrLen calculates the length returned by recvfrom(2) and
// recvmsg(2) for Unix sockets.
//
// `sun_path` is the path member of the sockaddr_un whose length is wanted
// (AddrFDSocketPair passes `address.sun_path`).
// NOTE(review): abstract-namespace paths begin with a NUL byte, so this
// presumably does not rely on strlen() alone -- confirm in the definition.
size_t CalculateUnixSockAddrLen(const char* sun_path);

// An AddrFDSocketPair is a SocketPair that consists of a pair of file
// descriptors in addition to a pair of socket addresses.
//
// Each address is kept as a sockaddr_storage. For every end the pair reports
// the size of the original address structure (first_addr_size() /
// second_addr_size()) and an address length (first_addr_len() /
// second_addr_len()). For IPv4 and IPv6 addresses both are sizeof the address
// structure; for Unix addresses the length is computed from sun_path by
// CalculateUnixSockAddrLen.
class AddrFDSocketPair : public SocketPair {
 public:
  // Unix-domain (sockaddr_un) addresses.
  AddrFDSocketPair(int first_fd, int second_fd,
                   const struct sockaddr_un& first_address,
                   const struct sockaddr_un& second_address)
      : AddrFDSocketPair(first_fd, second_fd, to_storage(first_address),
                         to_storage(second_address),
                         CalculateUnixSockAddrLen(first_address.sun_path),
                         CalculateUnixSockAddrLen(second_address.sun_path),
                         sizeof(first_address), sizeof(second_address)) {}

  // IPv4 (sockaddr_in) addresses.
  AddrFDSocketPair(int first_fd, int second_fd,
                   const struct sockaddr_in& first_address,
                   const struct sockaddr_in& second_address)
      : AddrFDSocketPair(first_fd, second_fd, to_storage(first_address),
                         to_storage(second_address), sizeof(first_address),
                         sizeof(second_address), sizeof(first_address),
                         sizeof(second_address)) {}

  // IPv6 (sockaddr_in6) addresses.
  AddrFDSocketPair(int first_fd, int second_fd,
                   const struct sockaddr_in6& first_address,
                   const struct sockaddr_in6& second_address)
      : AddrFDSocketPair(first_fd, second_fd, to_storage(first_address),
                         to_storage(second_address), sizeof(first_address),
                         sizeof(second_address), sizeof(first_address),
                         sizeof(second_address)) {}

  // Descriptor access.
  int first_fd() const override { return first_.get(); }
  int second_fd() const override { return second_.get(); }

  // Descriptor hand-off to the caller.
  int release_first_fd() override { return first_.release(); }
  int release_second_fd() override { return second_.release(); }

  // Stored addresses, viewed as generic sockaddr pointers.
  const struct sockaddr* first_addr() const override {
    return reinterpret_cast<const struct sockaddr*>(&first_addr_);
  }
  const struct sockaddr* second_addr() const override {
    return reinterpret_cast<const struct sockaddr*>(&second_addr_);
  }

  // Size of the original address structures.
  size_t first_addr_size() const override { return first_size_; }
  size_t second_addr_size() const override { return second_size_; }

  // Address lengths (see the class comment).
  size_t first_addr_len() const override { return first_len_; }
  size_t second_addr_len() const override { return second_len_; }

 private:
  // Common constructor that the public, family-specific constructors delegate
  // to once they have converted their addresses and computed the lengths.
  AddrFDSocketPair(int first_fd, int second_fd,
                   const struct sockaddr_storage& first_storage,
                   const struct sockaddr_storage& second_storage,
                   size_t first_len, size_t second_len, size_t first_size,
                   size_t second_size)
      : first_(first_fd),
        second_(second_fd),
        first_addr_(first_storage),
        second_addr_(second_storage),
        first_len_(first_len),
        second_len_(second_len),
        first_size_(first_size),
        second_size_(second_size) {}

  // to_storage converts a sockaddr_* to a sockaddr_storage.
  static struct sockaddr_storage to_storage(const sockaddr_un& addr);
  static struct sockaddr_storage to_storage(const sockaddr_in& addr);
  static struct sockaddr_storage to_storage(const sockaddr_in6& addr);

  FileDescriptor first_;
  FileDescriptor second_;
  const struct sockaddr_storage first_addr_;
  const struct sockaddr_storage second_addr_;
  const size_t first_len_;
  const size_t second_len_;
  const size_t first_size_;
  const size_t second_size_;
};

// SyscallSocketPairCreator returns a Creator<SocketPair> that obtains file
// descriptors by invoking the socketpair() syscall.
// NOTE(review): domain/type/protocol are presumably forwarded unchanged to
// socketpair(2); the definition is not in this header -- confirm there.
Creator<SocketPair> SyscallSocketPairCreator(int domain, int type,
                                             int protocol);

// SyscallSocketCreator returns a Creator<FileDescriptor> that obtains a file
// descriptor by invoking the socket() syscall.
// NOTE(review): domain/type/protocol are presumably forwarded unchanged to
// socket(2); the definition is not in this header -- confirm there.
Creator<FileDescriptor> SyscallSocketCreator(int domain, int type,
                                             int protocol);

// FilesystemBidirectionalBindSocketPairCreator returns a Creator<SocketPair>
// that obtains file descriptors by invoking the bind() and connect() syscalls
// on filesystem paths. Only works for DGRAM sockets.
Creator<SocketPair> FilesystemBidirectionalBindSocketPairCreator(int domain,
                                                                 int type,
                                                                 int protocol);

// AbstractBidirectionalBindSocketPairCreator returns a Creator<SocketPair> that
// obtains file descriptors by invoking the bind() and connect() syscalls on
// abstract namespace paths. Only works for DGRAM sockets.
Creator<SocketPair> AbstractBidirectionalBindSocketPairCreator(int domain,
                                                               int type,
                                                               int protocol);

// SocketpairGoferSocketPairCreator returns a Creator<SocketPair> that
// obtains file descriptors by connect() syscalls on two sockets with socketpair
// gofer paths.
Creator<SocketPair> SocketpairGoferSocketPairCreator(int domain, int type,
                                                     int protocol);

// SocketpairGoferFileSocketPairCreator returns a Creator<SocketPair> that
// obtains file descriptors by open() syscalls on socketpair gofer paths.
// NOTE(review): `flags` are presumably the flags given to open(2) -- confirm
// in the definition.
Creator<SocketPair> SocketpairGoferFileSocketPairCreator(int flags);

// FilesystemAcceptBindSocketPairCreator returns a Creator<SocketPair> that
// obtains file descriptors by invoking the accept() and bind() syscalls on
// a filesystem path. Only works for STREAM and SEQPACKET sockets.
Creator<SocketPair> FilesystemAcceptBindSocketPairCreator(int domain, int type,
                                                          int protocol);

// AbstractAcceptBindSocketPairCreator returns a Creator<SocketPair> that
// obtains file descriptors by invoking the accept() and bind() syscalls on an
// abstract namespace path. Only works for STREAM and SEQPACKET sockets.
Creator<SocketPair> AbstractAcceptBindSocketPairCreator(int domain, int type,
                                                        int protocol);

// FilesystemUnboundSocketPairCreator returns a Creator<SocketPair> that obtains
// file descriptors by invoking the socket() syscall and generates a filesystem
// path for binding.
Creator<SocketPair> FilesystemUnboundSocketPairCreator(int domain, int type,
                                                       int protocol);

// AbstractUnboundSocketPairCreator returns a Creator<SocketPair> that obtains
// file descriptors by invoking the socket() syscall and generates an abstract
// path for binding.
Creator<SocketPair> AbstractUnboundSocketPairCreator(int domain, int type,
                                                     int protocol);

// TCPAcceptBindSocketPairCreator returns a Creator<SocketPair> that obtains
// file descriptors by invoking the accept() and bind() syscalls on TCP sockets.
// NOTE(review): `dual_stack` presumably selects a dual-stack (IPv4-over-IPv6)
// configuration; its exact effect is defined outside this header -- confirm.
Creator<SocketPair> TCPAcceptBindSocketPairCreator(int domain, int type,
                                                   int protocol,
                                                   bool dual_stack);

// TCPAcceptBindPersistentListenerSocketPairCreator is like
// TCPAcceptBindSocketPairCreator, except it uses the same listening socket to
// create all SocketPairs.
Creator<SocketPair> TCPAcceptBindPersistentListenerSocketPairCreator(
    int domain, int type, int protocol, bool dual_stack);

// UDPBidirectionalBindSocketPairCreator returns a Creator<SocketPair> that
// obtains file descriptors by invoking the bind() and connect() syscalls on UDP
// sockets.
Creator<SocketPair> UDPBidirectionalBindSocketPairCreator(int domain, int type,
                                                          int protocol,
                                                          bool dual_stack);

// UDPUnboundSocketPairCreator returns a Creator<SocketPair> that obtains file
// descriptors by creating UDP sockets.
// NOTE(review): going by the name, the sockets are presumably left unbound and
// unconnected -- confirm in the definition.
Creator<SocketPair> UDPUnboundSocketPairCreator(int domain, int type,
                                                int protocol, bool dual_stack);

// UnboundSocketCreator returns a Creator<FileDescriptor> that obtains a file
// descriptor by creating a socket.
Creator<FileDescriptor> UnboundSocketCreator(int domain, int type,
                                             int protocol);

// A SocketPairKind couples a human-readable description of a socket pair with
// a function that creates such a socket pair.
//
// This is a plain aggregate; code in this header brace-initializes it
// positionally, so the member order below is part of its interface.
struct SocketPairKind {
  // Human-readable name; SocketPairTest prints it when a test starts.
  std::string description;
  // Socket domain, type and protocol associated with this kind.
  // NOTE(review): presumably the same values the creator was built with.
  int domain;
  int type;
  int protocol;
  // Function that produces a new socket pair of this kind.
  Creator<SocketPair> creator;

  // Create creates a socket pair of this kind.
  PosixErrorOr<std::unique_ptr<SocketPair>> Create() const { return creator(); }
};

// A SocketKind couples a human-readable description of a socket with
// a function that creates such a socket.
struct SocketKind {
  // Human-readable name; SimpleSocketTest prints it when a test starts.
  std::string description;
  // Socket domain, type and protocol associated with this kind.
  // NOTE(review): presumably the same values the creator was built with.
  int domain;
  int type;
  int protocol;
  // Function that produces a new socket of this kind.
  Creator<FileDescriptor> creator;

  // Create creates a socket of this kind.
  PosixErrorOr<std::unique_ptr<FileDescriptor>> Create() const {
    return creator();
  }
};

// A ReversedSocketPair owns another SocketPair and presents it with its two
// ends exchanged: every "first" query is answered by the wrapped pair's
// "second" end and vice versa. It is used to test socket pairs that should be
// symmetric.
class ReversedSocketPair : public SocketPair {
 public:
  // Takes ownership of `base`, the pair whose ends are to be exchanged.
  explicit ReversedSocketPair(std::unique_ptr<SocketPair> base)
      : swapped_(std::move(base)) {}

  // Descriptor access, ends exchanged.
  int first_fd() const override { return swapped_->second_fd(); }
  int second_fd() const override { return swapped_->first_fd(); }

  // Descriptor hand-off, ends exchanged.
  int release_first_fd() override { return swapped_->release_second_fd(); }
  int release_second_fd() override { return swapped_->release_first_fd(); }

  // Addresses, ends exchanged.
  const sockaddr* first_addr() const override {
    return swapped_->second_addr();
  }
  const sockaddr* second_addr() const override {
    return swapped_->first_addr();
  }

  // Address structure sizes, ends exchanged.
  size_t first_addr_size() const override {
    return swapped_->second_addr_size();
  }
  size_t second_addr_size() const override {
    return swapped_->first_addr_size();
  }

  // Address lengths, ends exchanged.
  size_t first_addr_len() const override { return swapped_->second_addr_len(); }
  size_t second_addr_len() const override { return swapped_->first_addr_len(); }

 private:
  // The wrapped pair, in its original orientation.
  std::unique_ptr<SocketPair> swapped_;
};

// Reversed returns a SocketPairKind that represents SocketPairs created by
// flipping the file descriptors provided by another SocketPair.
// NOTE(review): presumably the returned kind wraps each pair created by `base`
// in a ReversedSocketPair -- confirm in the definition.
SocketPairKind Reversed(SocketPairKind const& base);

// IncludeReversals returns a vector<SocketPairKind> that contains all
// SocketPairKinds in `vec` as well as all SocketPairKinds obtained by flipping
// the file descriptors provided by the kinds in `vec`.
std::vector<SocketPairKind> IncludeReversals(std::vector<SocketPairKind> vec);

// A Middleware is a function that wraps a SocketPairKind: it takes one
// SocketPairKind and returns another.
using Middleware = std::function<SocketPairKind(SocketPairKind)>;

// SetSockOpt returns a Middleware that wraps a SocketPairKind so that every
// socket pair it creates has the socket option (level, optname) set to
// `*value` (sizeof(T) bytes) on each end whose descriptor is non-negative. The
// wrapped kind keeps the base kind's domain, type and protocol, and its
// description is the base description prefixed with the option being set.
//
// `value` is captured as a pointer, not copied: it is dereferenced when the
// Middleware is applied (to build the description) and handed to setsockopt(2)
// every time a pair is created, so the pointee must outlive both.
template <typename T>
Middleware SetSockOpt(int level, int optname, T* value) {
  return [level, optname, value](SocketPairKind const& base) -> SocketPairKind {
    std::string description =
        absl::StrCat("setsockopt(", level, ", ", optname, ", ", *value, ") ",
                     base.description);

    auto const& creator = base.creator;
    // NOTE(review): the setsockopt(...) expressions below are kept spelled
    // exactly as before in case RETURN_ERROR_IF_SYSCALL_FAIL embeds the
    // expression text in the error it reports.
    Creator<SocketPair> with_option =
        [creator, level, optname,
         value]() -> PosixErrorOr<std::unique_ptr<SocketPair>> {
      ASSIGN_OR_RETURN_ERRNO(auto creator_value, creator());
      if (creator_value->first_fd() >= 0) {
        RETURN_ERROR_IF_SYSCALL_FAIL(setsockopt(
            creator_value->first_fd(), level, optname, value, sizeof(T)));
      }
      if (creator_value->second_fd() >= 0) {
        RETURN_ERROR_IF_SYSCALL_FAIL(setsockopt(
            creator_value->second_fd(), level, optname, value, sizeof(T)));
      }
      return creator_value;
    };

    return SocketPairKind{std::move(description), base.domain, base.type,
                          base.protocol, std::move(with_option)};
  };
}

// Integer values for on/off style socket options: kSockOptOn is 1 and
// kSockOptOff is 0.
// NOTE(review): presumably meant to be passed by address as the option value
// to helpers such as SetSockOpt or to setsockopt(2) directly -- confirm usage.
constexpr int kSockOptOn = 1;
constexpr int kSockOptOff = 0;

// NoOp returns the same SocketPairKind that it is passed.
SocketPairKind NoOp(SocketPairKind const& base);

// TransferTest tests that data can be sent back and forth between two
// specified FDs. Note that calls to this function should be wrapped in
// ASSERT_NO_FATAL_FAILURE().
void TransferTest(int fd1, int fd2);

// Fills [buf, buf+len) with random bytes.
void RandomizeBuffer(char* buf, size_t len);

// Base test fixture for tests that operate on pairs of connected sockets.
//
// The test parameter is a SocketPairKind. Its description is printed (and
// stdout flushed) when the fixture is constructed, and NewSocketPair() creates
// a fresh pair of that kind.
class SocketPairTest : public ::testing::TestWithParam<SocketPairKind> {
 protected:
  SocketPairTest() {
    // gUnit reports through printf, so this banner does as well.
    const SocketPairKind& kind = GetParam();
    printf("Testing with %s\n", kind.description.c_str());
    fflush(stdout);
  }

  // Creates a new socket pair of the kind this test is parameterized with.
  PosixErrorOr<std::unique_ptr<SocketPair>> NewSocketPair() const {
    const SocketPairKind& kind = GetParam();
    return kind.Create();
  }
};

// Base test fixture for tests that operate on simple Sockets.
//
// The test parameter is a SocketKind. Its description is printed (and stdout
// flushed) when the fixture is constructed, and NewSocket() creates a fresh
// socket of that kind.
class SimpleSocketTest : public ::testing::TestWithParam<SocketKind> {
 protected:
  SimpleSocketTest() {
    // gUnit uses printf, so so will we.
    printf("Testing with %s\n", GetParam().description.c_str());
    // Flush immediately, as SocketPairTest does, so the banner does not sit in
    // the stdio buffer where it could be lost or show up out of order if the
    // test aborts.
    fflush(stdout);
  }

  // Creates a new socket of the kind this test is parameterized with.
  PosixErrorOr<std::unique_ptr<FileDescriptor>> NewSocket() const {
    return GetParam().Create();
  }
};

// SimpleSocket returns a SocketKind for the given family, type and protocol.
// NOTE(review): presumably the kind's creator makes a plain socket(2) call
// with these values -- confirm in the definition.
SocketKind SimpleSocket(int fam, int type, int proto);

// Send a buffer of size 'size' to sockets->first_fd(), returning the result of
// sendmsg.
//
// If reader, read from second_fd() until size bytes have been read.
ssize_t SendLargeSendMsg(const std::unique_ptr<SocketPair>& sockets,
                         size_t size, bool reader);

// Initializes the given buffer with random data.
// (This redeclares the RandomizeBuffer already declared earlier in this
// header; the repetition is redundant but harmless.)
void RandomizeBuffer(char* ptr, size_t len);

// Address family and socket type selectors taken by PortAvailable and
// internal::TryPortAvailable.
// NOTE(review): kDualStack presumably means an IPv6 socket that also accepts
// IPv4 traffic -- confirm against the implementation.
enum class AddressFamily { kIpv4 = 1, kIpv6 = 2, kDualStack = 3 };
enum class SocketType { kUdp = 1, kTcp = 2 };

// Returns a PosixError or a port that is available. If 0 is specified as the
// port it will bind port 0 (and allow the kernel to select any free port).
// Otherwise, it will try to bind the specified port and validate that it can be
// used for the requested family and socket type. The final option is
// reuse_addr. This specifies whether SO_REUSEADDR should be applied before a
// bind(2) attempt. SO_REUSEADDR means that sockets in TIME_WAIT states or other
// bound UDP sockets would not cause an error on bind(2). This option should be
// set if subsequent calls to bind on the returned port will also use
// SO_REUSEADDR.
//
// Note that this helper will attempt to bind the ANY address for the
// respective protocol.
PosixErrorOr<int> PortAvailable(int port, AddressFamily family, SocketType type,
                                bool reuse_addr);

// FreeAvailablePort is used to return a port that was obtained by using
// the PortAvailable helper with port 0.
PosixError FreeAvailablePort(int port);

// SendMsg converts a buffer to an iovec and adds it to msg before sending it.
// NOTE(review): `buf_size` is presumably the number of bytes of `buf` to send
// and the returned int the result of the send call -- confirm in the
// definition.
PosixErrorOr<int> SendMsg(int sock, msghdr* msg, char buf[], int buf_size);

// RecvNoData checks that no data is receivable on sock.
void RecvNoData(int sock);

// Base test fixture for tests that apply to all kinds of pairs of connected
// sockets.
// (A plain alias: it adds nothing on top of SocketPairTest.)
using AllSocketPairTest = SocketPairTest;

// TestAddress bundles a human-readable description with a socket address and
// the length of that address.
struct TestAddress {
  std::string description;
  sockaddr_storage addr;
  socklen_t addr_len;

  // Builds a TestAddress carrying `description`. The address starts out
  // zero-initialized and the length starts at 0; callers fill them in.
  explicit TestAddress(std::string description = "")
      : description(std::move(description)), addr(), addr_len(0) {}

  // Address family stored in `addr` (its ss_family field).
  int family() const {
    const int fam = addr.ss_family;
    return fam;
  }
};

// Dotted-quad IPv4 address strings: a multicast group address and the limited
// broadcast address.
constexpr char kMulticastAddress[] = "224.0.2.1";
constexpr char kBroadcastAddress[] = "255.255.255.255";

// Factories for commonly used TestAddress values.
// NOTE(review): going by the names, these presumably yield the IPv4/IPv6
// wildcard ("any"), loopback, broadcast and multicast addresses, with the
// V4Mapped* variants being IPv4-mapped IPv6 addresses; the definitions are not
// in this header -- confirm there.
TestAddress V4Any();
TestAddress V4Broadcast();
TestAddress V4Loopback();
TestAddress V4MappedAny();
TestAddress V4MappedLoopback();
TestAddress V4Multicast();
TestAddress V6Any();
TestAddress V6Loopback();

// Compute the internet checksum of an IP header.
// The header is taken by value, so the caller's copy is left untouched.
uint16_t IPChecksum(struct iphdr ip);

// Compute the internet checksum of a UDP header.
// Both headers are taken by value. `payload` points at the datagram payload.
// NOTE(review): `payload_len` is presumably the payload size in bytes, and the
// IP header presumably supplies the pseudo-header fields -- confirm in the
// definition.
uint16_t UDPChecksum(struct iphdr iphdr, struct udphdr udphdr,
                     const char* payload, ssize_t payload_len);

// Compute the internet checksum of an ICMP header.
// The header is taken by value. `payload` points at the ICMP payload.
// NOTE(review): `payload_len` is presumably the payload size in bytes --
// confirm in the definition.
uint16_t ICMPChecksum(struct icmphdr icmphdr, const char* payload,
                      ssize_t payload_len);

namespace internal {
// Implementation detail; takes the same parameters as PortAvailable.
// NOTE(review): presumably a single availability attempt that PortAvailable
// builds on -- confirm in the definition.
PosixErrorOr<int> TryPortAvailable(int port, AddressFamily family,
                                   SocketType type, bool reuse_addr);
}  // namespace internal

}  // namespace testing
}  // namespace gvisor

#endif  // GVISOR_TEST_SYSCALLS_SOCKET_TEST_UTIL_H_