/* tinyproxy - A fast light-weight HTTP proxy
 * Copyright (C) 2000, 2002 Robert James Kaes <rjkaes@users.sourceforge.net>
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 */

/* This system handles Access Control for use of this daemon. A list of
 * domains or IP addresses (including IP blocks) is stored in a vector,
 * and each incoming connection is compared against that list.
 */

#include "tinyproxy.h"

#include "acl.h"
#include "heap.h"
#include "log.h"
#include "network.h"
#include "sock.h"
#include "vector.h"

/*
 * Define how long an IPv6 address is in bytes (128 bits, 16 bytes).
 * Every numeric ACL address and netmask in this file is held in a
 * buffer of this size.
 */
#define IPV6_LEN 16

/*
 * Kind of address held by an ACL entry.  The value selects which member
 * of the "address" union in struct acl_s is meaningful.
 */
enum acl_type
{
  ACL_STRING,                   /* address.string: host or domain name text */
  ACL_NUMERIC,                  /* address.ip: address bytes plus netmask */
};

/*
 * Hold the information about a particular access control.  We store
 * whether it's an ALLOW or DENY entry, and also whether it's a string
 * entry (like a domain name) or an IP entry.
 *
 * Only the union member selected by "type" is valid; the two members
 * share storage.
 */
struct acl_s
{
  acl_access_t access;          /* permission granted when the entry matches */
  enum acl_type type;           /* selects the active member of "address" */
  union
  {
    char *string;               /* ACL_STRING: host or domain name text */
    struct
    {
      /* ACL_NUMERIC: address bytes as written by full_inet_pton() */
      unsigned char octet[IPV6_LEN];
      /* ACL_NUMERIC: bits of "octet" that take part in a comparison */
      unsigned char mask[IPV6_LEN];
    } ip;
  } address;
};

/*
 * All the access lists are stored in a vector.
 *
 * The vector is created by insert_acl() the first time an entry is
 * added.  While it is still NULL, check_acl() allows every connection.
 */
static vector_t access_list = NULL;


/*
 * Fills in the netmask array from a decimal prefix length.
 *
 * bitmask_string is the text of the prefix length (for example "24").
 * On success the first "mask" bits of array[0 .. len-1] are set and all
 * remaining bits are cleared.
 *
 * Returns:
 *   0 on success
 *  -1 on failure: the string is not a number, has characters left over
 *     after the number, or the value is outside 0 .. 8 * len.  The array
 *     is not modified in that case.
 *
 */
static inline int
fill_netmask_array (const char *bitmask_string, unsigned char array[],
                    unsigned int len)
{
  unsigned int i;
  long int mask;
  char *endptr;

  errno = 0;                    /* to distinguish success/failure after call */
  mask = strtol (bitmask_string, &endptr, 10);

  /*
   * Reject range errors, input with no digits at all, and input with
   * anything left over after the digits (e.g. "24abc"), which strtol
   * would otherwise silently accept as 24.
   */
  if (errno != 0 || endptr == bitmask_string || *endptr != '\0')
    return -1;

  /* valid range for a bit mask; negative is checked before the cast */
  if (mask < 0 || (unsigned long) mask > 8UL * len)
    return -1;

  /* we have a valid range to fill in the array */
  for (i = 0; i != len; ++i)
    {
      if (mask >= 8)
        {
          /* a fully covered byte */
          array[i] = 0xff;
          mask -= 8;
        }
      else if (mask > 0)
        {
          /* the one partially covered byte: keep the top "mask" bits */
          array[i] = (unsigned char) (0xff << (8 - (int) mask));
          mask = 0;
        }
      else
        {
          array[i] = 0;
        }
    }

  return 0;
}


/*
 * Inserts a new access control into the list. The function will figure out
 * whether the location is an IP address (with optional netmask) or a
 * domain name.
 *
 * location is modified in place when it contains a '/': the slash is
 * overwritten with a NUL so the address part can be parsed on its own.
 *
 * For a domain-name entry the string is duplicated, and the duplicate is
 * owned by the entry stored in the access list from then on.
 *
 * Returns:
 *    -1 on failure
 *     0 otherwise.
 */
int
insert_acl (char *location, acl_access_t access_type)
{
  struct acl_s acl;
  int ret;
  char *p, ip_dst[IPV6_LEN];

  assert (location != NULL);

  /*
   * If the access list has not been set up, create it.
   */
  if (!access_list)
    {
      access_list = vector_create ();
      if (!access_list)
        {
          log_message (LOG_ERR, "Unable to allocate memory for access list");
          return -1;
        }
    }

  /*
   * Start populating the access control structure.
   */
  memset (&acl, 0, sizeof (struct acl_s));
  acl.access = access_type;

  /*
   * Check for a valid IP address (the simplest case) first.
   */
  if (full_inet_pton (location, ip_dst) > 0)
    {
      /* A bare address matches exactly: every mask bit is set. */
      acl.type = ACL_NUMERIC;
      memcpy (acl.address.ip.octet, ip_dst, IPV6_LEN);
      memset (acl.address.ip.mask, 0xff, IPV6_LEN);
    }
  else
    {
      /*
       * At this point we're either a hostname or an
       * IP address with a slash.
       */
      p = strchr (location, '/');
      if (p != NULL)
        {
          /*
           * We have a slash, so it's intended to be an
           * IP address with mask
           */
          *p = '\0';
          if (full_inet_pton (location, ip_dst) <= 0)
            return -1;

          acl.type = ACL_NUMERIC;
          memcpy (acl.address.ip.octet, ip_dst, IPV6_LEN);

          /*
           * NOTE(review): the prefix length is applied starting at
           * byte 0 of the 16-byte address.  If full_inet_pton() stores
           * IPv4 addresses in IPv4-mapped form, an IPv4 prefix such as
           * /24 would need to be offset by 96 bits -- confirm against
           * full_inet_pton().
           */
          if (fill_netmask_array (p + 1, &(acl.address.ip.mask[0]), IPV6_LEN)
              < 0)
            return -1;
        }
      else
        {
          /* In all likelihood a string */
          acl.type = ACL_STRING;
          acl.address.string = safestrdup (location);
          if (!acl.address.string)
            return -1;
        }
    }

  /*
   * Add the entry.  Only the bytes of the struct are handed to the
   * vector, so the stored entry refers to the very same string pointer
   * as "acl" does; the string must stay allocated for as long as the
   * entry is in the list.
   */
  ret = vector_append (access_list, &acl, sizeof (struct acl_s));

  /*
   * Clean up only if the entry was not stored, and only for string
   * entries: for numeric entries address.string aliases the address
   * bytes through the union and is not a pointer that may be freed.
   */
  if (ret < 0 && acl.type == ACL_STRING)
    {
      safefree (acl.address.string);
    }

  return ret;
}

/*
 * Evaluate a "string" access control against a connecting peer.
 *
 * Two tests are made.  Unless the ACL string begins with a period, it is
 * first resolved with getaddrinfo() and every resulting address is
 * compared, as text, with ip_address.  If that does not settle it (or is
 * skipped, or the lookup fails), string_address is tested for ending with
 * the ACL string, ignoring case.
 *
 * Return: 0 if host is denied
 *         1 if host is allowed
 *        -1 if no tests match, so skip
 */
static int
acl_string_processing (struct acl_s *acl,
                       const char *ip_address, const char *string_address)
{
  struct addrinfo hints, *results, *cur;
  size_t host_len, pattern_len;
  char resolved[512];
  int found = FALSE;

  assert (acl && acl->type == ACL_STRING);
  assert (ip_address && strlen (ip_address) > 0);
  assert (string_address && strlen (string_address) > 0);

  /*
   * A leading period marks a pure domain suffix, for which only the
   * text comparison further down applies.  Anything else is looked up
   * first.
   */
  if (acl->address.string[0] != '.')
    {
      memset (&hints, 0, sizeof (struct addrinfo));
      hints.ai_family = AF_UNSPEC;
      hints.ai_socktype = SOCK_STREAM;

      /* A failed lookup simply falls through to the text comparison. */
      if (getaddrinfo (acl->address.string, NULL, &hints, &results) == 0)
        {
          for (cur = results; cur != NULL; cur = cur->ai_next)
            {
              get_ip_string (cur->ai_addr, resolved, sizeof (resolved));
              if (strcmp (ip_address, resolved) == 0)
                {
                  found = TRUE;
                  break;
                }
            }

          freeaddrinfo (results);

          if (found)
            return (acl->access == ACL_DENY) ? 0 : 1;
        }
    }

  host_len = strlen (string_address);
  pattern_len = strlen (acl->address.string);

  /*
   * A host name shorter than the ACL string cannot end with it; tell
   * the caller to move on to the next control.
   */
  if (host_len < pattern_len)
    return -1;

  /* Compare the tail of the host name with the ACL string. */
  if (strcasecmp (string_address + (host_len - pattern_len),
                  acl->address.string) != 0)
    return -1;

  return (acl->access == ACL_DENY) ? 0 : 1;
}

/*
 * Compare the supplied numeric IP address with the supplied ACL structure.
 *
 * ip is the textual address of the peer; it is parsed with
 * full_inet_pton() and compared with the entry's address byte by byte,
 * looking only at the bits selected by the entry's mask.
 *
 * Return:
 *   1  IP address matches the entry and the entry allows it
 *   0  IP address matches the entry and the entry denies it
 *  -1  neither allowed nor denied: the address could not be parsed, or
 *      it does not match this entry, so the caller should go on to the
 *      next entry.
 */
static int
check_numeric_acl (const struct acl_s *acl, const char *ip)
{
  uint8_t addr[IPV6_LEN], x, y;
  int i;

  assert (acl && acl->type == ACL_NUMERIC);
  assert (ip && strlen (ip) > 0);

  if (full_inet_pton (ip, &addr) <= 0)
    return -1;

  for (i = 0; i != IPV6_LEN; ++i)
    {
      x = addr[i] & acl->address.ip.mask[i];
      y = acl->address.ip.octet[i] & acl->address.ip.mask[i];

      /*
       * If x and y don't match, the IP addresses don't match.  This
       * entry then has nothing to say about the address; reporting
       * "denied" here would stop the caller from ever consulting the
       * entries that follow.
       */
      if (x != y)
        return -1;
    }

  /* The addresses match, return the permission */
  return (acl->access == ACL_ALLOW);
}

/*
 * Decides whether a connection may use the proxy.
 *
 * ip and host are the peer's address and host name as text.  The access
 * list is walked in order and the first entry that gives a definite
 * answer decides the outcome.  Numeric entries are skipped when ip is an
 * empty string.
 *
 * Returns:
 *     1 if allowed (also when no access list has been configured)
 *     0 if denied (an entry denied it, or no entry gave an answer);
 *       a notice is logged in this case
 */
int
check_acl (const char *ip, const char *host)
{
  struct acl_s *entry;
  size_t idx;
  int verdict = 0;

  assert (ip != NULL);
  assert (host != NULL);

  /* Without an access list nothing is restricted. */
  if (!access_list)
    return 1;

  for (idx = 0; idx != (size_t) vector_length (access_list); ++idx)
    {
      entry = (struct acl_s *) vector_getentry (access_list, idx, NULL);

      if (entry->type == ACL_STRING)
        {
          verdict = acl_string_processing (entry, ip, host);
        }
      else if (entry->type == ACL_NUMERIC)
        {
          /* No address text to compare against a numeric entry. */
          if (ip[0] == '\0')
            continue;
          verdict = check_numeric_acl (entry, ip);
        }

      /*
       * 1 allows immediately, 0 stops the search and denies; any
       * other value means this entry had no opinion.
       */
      if (verdict == 1)
        return verdict;
      if (verdict == 0)
        break;
    }

  /* Either explicitly denied or nothing matched: deny by default. */
  log_message (LOG_NOTICE, "Unauthorized connection from \"%s\" [%s].",
               host, ip);
  return 0;
}