/* $Id: thread.c,v 1.30 2002-05-23 18:20:27 rjkaes Exp $
 *
 * Handles the creation/destruction of the various threads required for
 * processing incoming connections.
 *
 * Copyright (C) 2000 Robert James Kaes (rjkaes@flarenet.com)
 *
 * This program is free software; you can redistribute it and/or modify it
 * under the terms of the GNU General Public License as published by the
 * Free Software Foundation; either version 2, or (at your option) any
 * later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 */

#include "tinyproxy.h"

#include "heap.h"
#include "log.h"
#include "reqs.h"
#include "sock.h"
#include "thread.h"
#include "utils.h"

/*
 * Stack size, in bytes, requested for every worker thread this module
 * creates (applied to thread_attr in thread_pool_create()): 32 KiB.
 */
#define THREAD_STACK_SIZE (32 * 1024)

/* Listening socket shared by all worker threads (set by thread_listening_sock()). */
static int listenfd;
/*
 * Address length filled in by listen_sock(); used to size each worker's
 * client-address buffer for accept(). Presumably the size of the socket's
 * address structure — confirm against listen_sock().
 */
static socklen_t addrlen;

/*
 * Stores the internal data needed for each thread (connection) slot in
 * the thread_ptr array.
 */
struct thread_s {
	/* Thread id written by pthread_create() when the slot is filled. */
	pthread_t tid;
	/*
	 * T_EMPTY:     slot free, may be reused for a new thread.
	 * T_WAITING:   thread created / about to block in accept().
	 * T_CONNECTED: thread is handling an accepted connection.
	 */
	enum { T_EMPTY, T_WAITING, T_CONNECTED } status;
	/* Connections handled; only counted when MaxRequestsPerChild != 0. */
	unsigned int connects;
};

/*
 * A pointer to an array of thread slots, one per MaxClients. A certain
 * number of threads are created when the program is started.
 */
static struct thread_s *thread_ptr;

/*
 * Serializes accept() on the shared listening socket so only one worker
 * at a time calls it.
 */
static pthread_mutex_t mlock;

/*
 * Lock/unlock the accept mutex. The pthread return value is checked with
 * assert(); the (void) cast keeps the variable "used" when NDEBUG compiles
 * the assert away, which otherwise triggers set-but-unused warnings.
 */
#define ACCEPT_LOCK() do { \
	int accept_lock_ret = pthread_mutex_lock(&mlock); \
	assert(accept_lock_ret == 0); \
	(void) accept_lock_ret; \
} while (0)
#define ACCEPT_UNLOCK() do { \
	int accept_lock_ret = pthread_mutex_unlock(&mlock); \
	assert(accept_lock_ret == 0); \
	(void) accept_lock_ret; \
} while (0)

/*
 * Attributes used for every worker thread; used to override the default
 * stack size (and to create the threads detached, see thread_pool_create()).
 */
static pthread_attr_t thread_attr;

/* Thread pool limits, set through thread_configure(). */
static struct thread_config_s {
	/* maxclients: number of thread slots; maxrequestsperchild: 0 = no limit */
	unsigned int maxclients, maxrequestsperchild;
	/* spare-server bounds and the number of threads started initially */
	unsigned int maxspareservers, minspareservers, startservers;
} thread_config;

/*
 * Number of worker threads currently counted as waiting for a connection,
 * protected by servers_mutex. Creators increment it after pthread_create()
 * and workers decrement it after accept(), so it can transiently be off by
 * the number of just-created threads (and even dip below zero).
 */
static int servers_waiting = 0;
static pthread_mutex_t servers_mutex;

/*
 * Lock/unlock servers_mutex. The (void) casts keep the return variables
 * "used" when NDEBUG removes the asserts.
 */
#define SERVER_COUNT_LOCK()   do { \
	int servers_mutex_ret = pthread_mutex_lock(&servers_mutex); \
	assert(servers_mutex_ret == 0); \
	(void) servers_mutex_ret; \
} while (0)
#define SERVER_COUNT_UNLOCK() do { \
	int servers_mutex_ret = pthread_mutex_unlock(&servers_mutex); \
	assert(servers_mutex_ret == 0); \
	(void) servers_mutex_ret; \
} while (0)

/* Atomically (under servers_mutex) increment/decrement servers_waiting. */
#define SERVER_INC() do { \
	SERVER_COUNT_LOCK(); \
	++servers_waiting; \
	DEBUG2("INC: servers_waiting: %d", servers_waiting); \
	SERVER_COUNT_UNLOCK(); \
} while (0)

#define SERVER_DEC() do { \
	SERVER_COUNT_LOCK(); \
	--servers_waiting; \
	DEBUG2("DEC: servers_waiting: %d", servers_waiting); \
	SERVER_COUNT_UNLOCK(); \
} while (0)

/*
 * Store one thread-pool setting.
 *
 * type: which setting to change (THREAD_MAXCLIENTS, THREAD_MAXSPARESERVERS,
 *       THREAD_MINSPARESERVERS, THREAD_STARTSERVERS or
 *       THREAD_MAXREQUESTSPERCHILD).
 * val:  the new value.
 *
 * Returns 0 when the setting was stored, -1 for an unknown type (in which
 * case nothing is modified).
 */
short int
thread_configure(thread_config_t type, unsigned int val)
{
	unsigned int *setting;

	/* Pick the field to update; the assignment happens once, below. */
	switch (type) {
	case THREAD_MAXCLIENTS:
		setting = &thread_config.maxclients;
		break;
	case THREAD_MAXSPARESERVERS:
		setting = &thread_config.maxspareservers;
		break;
	case THREAD_MINSPARESERVERS:
		setting = &thread_config.minspareservers;
		break;
	case THREAD_STARTSERVERS:
		setting = &thread_config.startservers;
		break;
	case THREAD_MAXREQUESTSPERCHILD:
		setting = &thread_config.maxrequestsperchild;
		break;
	default:
		DEBUG2("Invalid type (%d)", type);
		return -1;
	}

	*setting = val;
	return 0;
}

/*
 * This is the main (per thread) loop.
 *
 * Each worker repeatedly accepts a connection on the shared listening
 * socket (accept() is serialized by ACCEPT_LOCK), handles it, and then
 * decides whether to keep running. It exits when config.quit is set, when
 * it has served MaxRequestsPerChild connections (if that limit is
 * non-zero), or when more than MaxSpareServers threads are already waiting.
 *
 * servers_waiting accounting: the creating thread calls SERVER_INC() right
 * after pthread_create(), so a worker starts out counted as waiting. It
 * calls SERVER_DEC() once it has accepted a connection and SERVER_INC()
 * again before looping back to accept().
 *
 * arg: pointer to this thread's slot in thread_ptr.
 * Always returns NULL.
 */
static void *
thread_main(void *arg)
{
	struct thread_s *ptr = (struct thread_s *) arg;
	struct sockaddr *cliaddr;
	socklen_t clilen;
	int connfd;
	int accept_errno;

#ifdef HAVE_PTHREAD_CANCEL
	/* Set the cancelation type to immediate. */
	pthread_setcanceltype(PTHREAD_CANCEL_ASYNCHRONOUS, NULL);
#endif

	cliaddr = safemalloc(addrlen);
	if (!cliaddr) {
		/*
		 * Our creator already counted us as a waiting server and
		 * marked the slot in use. Undo both, otherwise the slot is
		 * lost forever and the spare-server count stays inflated.
		 */
		log_message(LOG_ERR,
			    "Could not allocate memory for client address. Killing thread.");
		SERVER_DEC();
		ptr->status = T_EMPTY;
		return NULL;
	}

	ptr->connects = 0;

	while (!config.quit) {
		ptr->status = T_WAITING;
		clilen = addrlen;

		ACCEPT_LOCK();
		connfd = accept(listenfd, cliaddr, &clilen);
		/* Save errno before the unlock call can disturb it. */
		accept_errno = errno;
		ACCEPT_UNLOCK();

		if (connfd < 0) {
			/*
			 * Accept could return an "error" if it was
			 * interrupted by a signal (like when the program
			 * should be killed. :)
			 */
			if (config.quit)
				break;

			log_message(LOG_ERR,
				    "Accept returned an error (%s) ... retrying.",
				    strerror(accept_errno));
			continue;
		}

		ptr->status = T_CONNECTED;
		SERVER_DEC();

		handle_connection(connfd);

		if (thread_config.maxrequestsperchild != 0) {
			ptr->connects++;
			DEBUG2("%u connections so far...", ptr->connects);

			if (ptr->connects >= thread_config.maxrequestsperchild) {
				log_message(LOG_NOTICE,
					    "Thread has reached MaxRequestsPerChild (%u >= %u). Killing thread.",
					    ptr->connects,
					    thread_config.maxrequestsperchild);
				break;
			}
		}

		SERVER_COUNT_LOCK();
		/*
		 * servers_waiting may transiently be negative (a new thread
		 * can accept and decrement before its creator increments).
		 * Check the sign first: comparing a negative int directly
		 * against the unsigned limit would convert it to a huge
		 * value and wrongly kill this thread.
		 */
		if (servers_waiting > 0
		    && (unsigned int) servers_waiting > thread_config.maxspareservers) {
			/* Too many spare threads, kill ourself off. */
			log_message(LOG_NOTICE,
				    "Waiting servers (%d) exceeds MaxSpareServers (%u). Killing thread.",
				    servers_waiting, thread_config.maxspareservers);
			SERVER_COUNT_UNLOCK();
			break;
		}
		SERVER_COUNT_UNLOCK();

		SERVER_INC();
	}

	safefree(cliaddr);
	/* Release the slot last, once this thread no longer needs anything. */
	ptr->status = T_EMPTY;
	return NULL;
}

/*
 * Create the initial pool of threads.
 *
 * Initializes the shared thread attributes and mutexes, validates the
 * MaxClients/StartServers settings, allocates one slot per MaxClients and
 * starts StartServers worker threads (clamped to MaxClients). Each started
 * thread is counted as a waiting server.
 *
 * Returns 0 on success. Returns -1 on a configuration error, an allocation
 * failure, or if any initial thread could not be created; threads created
 * before such a failure keep running.
 */
short int
thread_pool_create(void)
{
	unsigned int i;
	int pthread_ret;

	/*
	 * Initialize thread_attr to contain a non-default stack size
	 * because the default on some OS's is too small. Also, make sure
	 * we're using a detached creation method so all resources are
	 * reclaimed when the thread exits.
	 */
	pthread_attr_init(&thread_attr);
	pthread_attr_setdetachstate(&thread_attr, PTHREAD_CREATE_DETACHED);
	pthread_ret = pthread_attr_setstacksize(&thread_attr, THREAD_STACK_SIZE);
	if (pthread_ret != 0) {
		/* e.g. EINVAL if THREAD_STACK_SIZE < PTHREAD_STACK_MIN */
		log_message(LOG_WARNING,
			    "Could not set thread stack size to %d bytes: %s",
			    THREAD_STACK_SIZE, strerror(pthread_ret));
	}

	pthread_mutex_init(&mlock, NULL);
	pthread_mutex_init(&servers_mutex, NULL);

	if (thread_config.maxclients == 0) {
		log_message(LOG_ERR,
			    "thread_pool_create: \"MaxClients\" must be greater than zero.");
		return -1;
	}
	if (thread_config.startservers == 0) {
		log_message(LOG_ERR,
			    "thread_pool_create: \"StartServers\" must be greater than zero.");
		return -1;
	}

	thread_ptr = safecalloc((size_t) thread_config.maxclients,
				sizeof(struct thread_s));
	if (!thread_ptr) {
		log_message(LOG_ERR, "Could not allocate memory for threads.");
		return -1;
	}

	if (thread_config.startservers > thread_config.maxclients) {
		log_message(LOG_WARNING,
			    "Can not start more than \"MaxClients\" servers. Starting %u servers instead.",
			    thread_config.maxclients);
		thread_config.startservers = thread_config.maxclients;
	}

	for (i = 0; i < thread_config.maxclients; i++) {
		thread_ptr[i].status = T_EMPTY;
		thread_ptr[i].connects = 0;
	}

	for (i = 0; i < thread_config.startservers; i++) {
		DEBUG2("Trying to create thread %u of %u", i + 1,
		       thread_config.startservers);
		thread_ptr[i].status = T_WAITING;
		pthread_ret = pthread_create(&thread_ptr[i].tid, &thread_attr,
					     &thread_main, &thread_ptr[i]);
		if (pthread_ret != 0) {
			/*
			 * No thread runs in this slot and its tid is
			 * unspecified after a failed pthread_create(); mark
			 * it empty so thread_kill_threads() never cancels it.
			 */
			thread_ptr[i].status = T_EMPTY;
			log_message(LOG_WARNING,
				    "Could not create thread number %u of %u: %s",
				    i + 1, thread_config.startservers,
				    strerror(pthread_ret));
			return -1;
		}

		log_message(LOG_INFO,
			    "Creating thread number %u of %u ...",
			    i + 1, thread_config.startservers);
		SERVER_INC();
	}

	log_message(LOG_INFO, "Finished creating all threads.");

	return 0;
}

/*
 * Keep the proper number of servers running. This is the birth of the
 * servers.
 *
 * Roughly every five seconds it checks whether fewer than MinSpareServers
 * threads are waiting for a connection and, if so, starts one new thread
 * in the first free slot. It also services pending log rotation requests.
 * Returns once config.quit is set.
 */
void
thread_main_loop(void)
{
	unsigned int i;
	int pthread_ret;
	int need_server;

	while (!config.quit) {
		/* If there are not enough spare servers, create more. */
		SERVER_COUNT_LOCK();
		/*
		 * servers_waiting may transiently be negative; test the sign
		 * before the unsigned comparison so a negative count is
		 * treated as "too few" rather than converted to a huge value.
		 */
		need_server = servers_waiting < 0
		    || (unsigned int) servers_waiting < thread_config.minspareservers;
		if (need_server) {
			log_message(LOG_NOTICE,
				    "Waiting servers (%d) is less than MinSpareServers (%u). Creating new thread.",
				    servers_waiting, thread_config.minspareservers);
		}
		SERVER_COUNT_UNLOCK();

		if (need_server) {
			/* Start at most one thread per pass, in the first free slot. */
			for (i = 0; i < thread_config.maxclients; i++) {
				if (thread_ptr[i].status != T_EMPTY)
					continue;

				thread_ptr[i].status = T_WAITING;
				pthread_ret = pthread_create(&thread_ptr[i].tid,
							     &thread_attr,
							     &thread_main,
							     &thread_ptr[i]);
				if (pthread_ret != 0) {
					log_message(LOG_NOTICE,
						    "Could not create thread: %s",
						    strerror(pthread_ret));
					thread_ptr[i].status = T_EMPTY;
				} else {
					SERVER_INC();
				}
				break;
			}
		}

		sleep(5);

		/* Handle log rotation if it was requested */
		if (log_rotation_request) {
			rotate_log_files();
			log_rotation_request = FALSE;
		}
	}
}

/*
 * Go through all the non-empty threads and cancel them.
 *
 * Safe to call even if thread_pool_create() failed before allocating the
 * slot array. When pthread_cancel() is not available this is a no-op.
 */
#ifdef HAVE_PTHREAD_CANCEL
void
thread_kill_threads(void)
{
	unsigned int i;

	/* The pool may never have been allocated (e.g. StartServers == 0). */
	if (!thread_ptr)
		return;

	for (i = 0; i < thread_config.maxclients; i++) {
		if (thread_ptr[i].status != T_EMPTY)
			pthread_cancel(thread_ptr[i].tid);
	}
}
#else
void
thread_kill_threads(void)
{
	/* No pthread_cancel(): threads are not forcibly stopped. */
}
#endif

/*
 * Open the listening socket on `port` via listen_sock() and remember both
 * the descriptor and the address length it reports for the worker threads.
 * Returns whatever descriptor listen_sock() returned.
 */
int
thread_listening_sock(uint16_t port)
{
	int fd = listen_sock(port, &addrlen);

	listenfd = fd;
	return listenfd;
}

/*
 * Close the shared listening socket opened by thread_listening_sock().
 * The result of close() is intentionally discarded.
 */
void
thread_close_sock(void)
{
	(void) close(listenfd);
}