/*-
 * Copyright (c) 2005 masashi osakabe
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
 * ARE DISCLAIMED.  IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
 * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
 * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
 * OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
 * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
 * LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
 * OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
 * SUCH DAMAGE.
 *
 * $Id: SSLConnection.cpp,v 1.25 2008/03/01 03:37:10 cvsuser Exp $
 */

#include <sys/types.h>
#include <sys/socket.h>
#include <arpa/inet.h>
#include <netdb.h>
#include <netinet/in.h>
#include <unistd.h>

#include <boost/algorithm/string/predicate.hpp>
#if 0
#include <boost/iostreams/copy.hpp>
#include <boost/iostreams/filter/gzip.hpp>
#include <boost/iostreams/filtering_streambuf.hpp>
#endif
#include <boost/lexical_cast.hpp>
using namespace std;
using namespace boost;

#include <openssl/err.h>
#include <openssl/ssl.h>

#include <sl/sys/print_trace.h>
#include <sl/net/http/date.h>
#include <sl/net/http/http_function.h>

#include "include/si.h"
#include "SSLConnector.h"
#include "SSLConnection.h"
#include "SSLRequest.h"
#include "SSLResponse.h"

#define HOST_NAME_MAX 1024


//
// Constructor/Destructor
//

/*
 * Creates an idle HTTPS connection slot bound to the listening socket `sock`.
 *
 * Reads per-connector settings (compression, reverse DNS lookups, timeout),
 * creates this connection's own SSL_CTX/SSL pair, hooks the SSL reader/writer
 * into the HTTP stream and finally starts the thread that runs wait().
 *
 * @param parent  owning connector (configuration, thread bookkeeping)
 * @param sock    listening socket that wait() accepts clients from
 */
SSLConnection::SSLConnection(SSLConnector *parent, int sock)
	: _available(false), _reading(false), _thread_loop(true)
{
#ifdef DEBUG
	cout << "SSLConnection::SSLConnection " << hex << this << dec << endl;
#endif
	_parent = parent;
	_accept_socket = sock;

	_request.parent(this);
	_response.parent(this);

	_compress = parent->config().attr("compression") == "true";
	_lookup   = parent->config().attr("enableLookups") == "true";
	_timeout  = parent->getTimeout();

	// SSLv2 is cryptographically broken and SSLv2_server_method() no longer
	// exists in OpenSSL 1.1+. SSLv23_server_method() negotiates the best
	// protocol both peers support (it is an alias of TLS_server_method() in
	// OpenSSL 1.1+); the obsolete SSLv2/SSLv3 protocols are refused explicitly.
	_ctx = SSL_CTX_new(SSLv23_server_method());
	if (_ctx == NULL) {
		ERR_print_errors_fp(stderr);
		cerr << "Error :SSL_CTX_new in SSLConnection::SSLConnection" << endl;
	} else {
		SSL_CTX_set_options(_ctx, SSL_OP_NO_SSLv2 | SSL_OP_NO_SSLv3);
	}

	// SSL_new() reports (and returns NULL) when given a NULL context.
	_ssl = SSL_new(_ctx);
	if (_ssl == NULL) {
		ERR_print_errors_fp(stderr);
		cerr << "Error :SSL_new in SSLConnection::SSLConnection" << endl;
	}

	// All HTTP I/O on _iostream goes through the SSL object.
	_reader = new ssl_reader(_ssl);
	_writer = new ssl_writer(_ssl);
	_iostream.reading_method(_reader);
	_iostream.writing_method(_writer);

	// Started last: wait() uses every member initialised above.
	_wait_thread = new thread(boost::bind(&SSLConnection::wait, this));
}


/*
 * Stops the wait() thread and releases the OpenSSL resources.
 *
 * The order below matters: the loop flag is cleared and the connection is
 * closed first, so that wait() can leave both its request loop (init() fails
 * once the connection is no longer available) and its accept loop (it
 * re-checks _thread_loop after each 1 s select() timeout while idle). Only
 * after join() returns is it safe to free the reader/writer and the SSL
 * objects that thread was using.
 *
 * NOTE(review): _thread_loop is written here and read by the wait thread; if
 * it is a plain bool this is a data race - consider an atomic flag.
 */
SSLConnection::~SSLConnection()
{
#ifdef DEBUG
	cout << "SSLConnection::~SSLConnection " << hex << this << dec << endl;
#endif
	_thread_loop = false;

	this->close();
	_wait_thread->join();
	delete _wait_thread;

	delete _reader;
	delete _writer;

	// The SSL object must be freed before the context it was created from.
	SSL_free(_ssl);
	SSL_CTX_free(_ctx);
}



//
// Operators.
//

/*
 * Boolean negation: true while no client is attached to this connection.
 */
bool SSLConnection::operator!()
{
	return _available ? false : true;
}


//
// Member functions.
//
/*
 * Loads a PEM-encoded certificate from `file` into this connection's SSL
 * object. Failures are reported on stderr (OpenSSL error queue + a message)
 * but are not propagated to the caller.
 */
void SSLConnection::certificate_file(const std::string &file)
{
#ifdef DEBUG
	cerr << "SSLConnection::certificate_file : " << file << endl;
#endif
	// SSL_use_certificate_file() returns 1 on success.
	if (SSL_use_certificate_file(_ssl, file.c_str(), SSL_FILETYPE_PEM) == 1)
		return;

	ERR_print_errors_fp(stderr);
	cerr << "Error:SSLConnection::certificate_file" << endl;
}


/*
 * Loads a PEM-encoded private key from `file` into this connection's SSL
 * object. Failures are reported on stderr (OpenSSL error queue + a message)
 * but are not propagated to the caller.
 */
void SSLConnection::private_key_file(const std::string &file)
{
#ifdef DEBUG
	cerr << "SSLConnection::private_key_file : " << file << endl;
#endif
	// SSL_use_PrivateKey_file() returns 1 on success.
	if (SSL_use_PrivateKey_file(_ssl, file.c_str(), SSL_FILETYPE_PEM) == 1)
		return;

	ERR_print_errors_fp(stderr);
	cerr << "Error:SSLConnection::private_key_file" << endl;
}


/*
 * Runs the server side of the SSL handshake on the socket attached to _ssl.
 * A failed handshake is only reported on stderr; the caller carries on and
 * the subsequent read in init() is expected to fail.
 */
void SSLConnection::ssl_pre_init()
{
	const int handshake = SSL_accept(_ssl);
	if (handshake == 1)
		return;

	ERR_print_errors_fp(stderr);
	cerr << "Error :SSLConnection::ssl_pre_init" << endl;
}

void SSLConnection::wait()
{
	while (_thread_loop) {
		fd_set  fds;
		FD_ZERO(&fds);
		FD_SET(_accept_socket, &fds);

		struct timeval t;
		t.tv_sec = 1;
		t.tv_usec= 0;

		if (select(_accept_socket + 1, &fds, NULL, NULL, &t) < 0)
			return;

		if (!FD_ISSET(_accept_socket, &fds))
			continue;

		int s = ::accept(_accept_socket,
						 (struct sockaddr *)&_parent->_addr,
						 &_parent->_addr_length);
		if (s == -1)
			return;

		if (SSL_set_fd(_ssl, s) != 1) {
			cerr << "Error :SSL_set_fd in SSLConnection::wait()" << endl;
			return;
		}

		_parent->removeSpareThread(this);
		_parent->setExecuteThread(this);
		_socket = s;
		_available = true;
		_iostream.reset(_socket);
		_iostream.clear();

		std::string home = _parent->home();
		certificate_file(home +"/"+ _parent->config().attr("certificateFile"));
		private_key_file(home +"/"+ _parent->config().attr("privateKeyFile"));
		ssl_pre_init();

		Connector::handler_t handler = _parent->acceptHandler();
		while (init())
			handler(*this);

		if (available())
			this->close();
		_parent->removeExecuteThread(this);
		_parent->setSpareThread(this);
	}
}

/*
 * Reads the next HTTP request from the client.
 *
 * On success, refreshes the server/remote/local endpoint information, limits
 * further stream reads to the request's Content-Length (when positive) and
 * loads the parsed message into _request.
 *
 * @return true if a request was read; false if the connection is not
 *         available or read_message() returned <= 0.
 */
bool SSLConnection::init()
{
#ifdef DEBUG
	cerr << "SSLConnection::init Enter" << endl;
#endif
	if (!available())
		return false;

	// Read the request with a 30 second timeout.
	// rlength(-1) presumably lifts the read-length limit while the message is
	// parsed - TODO confirm against the sl iostream documentation.
	_iostream.timeout(30);
	_iostream.rlength(-1);

	const int received = sl::net::http::read_message(_iostream, _real_msg);
	if (received <= 0) {
		_real_msg.cleanup();
		return false;
	}

	server_info(_socket);
	remote_info(_socket);
	local_info(_socket);

	const int body_length = sl::net::http::content_length(_real_msg);
#ifdef DEBUG
	cerr << "CONTENT-LENGTH:" << body_length << std::endl;
#endif
	if (body_length > 0)
		_iostream.rlength(body_length);

	_request.set(_real_msg);

#ifdef DEBUG
	cerr << "SSLConnection::init Exit" << endl;
#endif
	return true;
}

/*
 * Completes and writes the response headers.
 *
 * Adds the Server and Date headers, then decides persistence: the connection
 * is kept alive only for an HTTP/1.1 request carrying "Connection: Keep-Alive"
 * (case-insensitive). A kept-alive response without Content-Length is sent
 * with chunked transfer coding, which is recorded in _chunked for flushBody()
 * and flush().
 */
void SSLConnection::flushHeader()
{
#ifdef DEBUG
	cout << "SSLConnection::flashHeader Enter reading:" << _reading << endl;
#endif
	// Server / Date headers.
	_response.header("Server", si::server());
	_response.header("Date", sl::net::http::date::current_time());

	// Connection header.
	const bool keep_alive =
		_request.protocol() == "HTTP/1.1" &&
		equals(_request.header("Connection"), "Keep-Alive", is_iequal());

	_chunked = false;
	if (!keep_alive) {
		_response.header("Connection", "close");
	} else {
		_response.header("Connection", "Keep-Alive");

		// No length known up front: fall back to chunked framing.
		const bool has_length = _response.containsHeader("Content-Length");
		if (!has_length) {
			_response.header("Transfer-Encoding", "chunked");
			_chunked = true;
		}
	}

	sl::net::http::write_headers(_iostream, _response.get());
}


/*
 * Writes the pending response body, if any.
 *
 * When chunked transfer coding is active (see flushHeader()), the body is
 * first wrapped as a single chunk: "<length in hex>\r\n<data>\r\n".
 * An empty body writes nothing at all.
 */
void SSLConnection::flushBody()
{
	if (_response.get().msg_body().empty())
		return;

	if (_chunked) {
		const std::string payload = _response.get().msg_body();

		std::ostringstream framed;
		framed << std::hex << payload.length() << "\r\n" << payload << "\r\n";

		std::string framed_body = framed.str();
		_response.get().msg_body(framed_body);
	}

	sl::net::http::write_body(_iostream, _response.get());
}

/*
 * Finishes the current request/response exchange.
 *
 * Drains leftover request data, terminates a chunked body, closes the
 * connection unless the response was sent with "Connection: Keep-Alive",
 * and resets the request/response objects for the next exchange.
 */
void SSLConnection::flush()
{
	if (!available())
		return;

	// NOTE(review): this operates on the raw socket fd; on a TLS connection
	// raw reads would consume encrypted bytes and corrupt the SSL stream -
	// confirm read_unnecessary() is a no-op here or route it through _iostream.
	sl::net::http::read_unnecessary(_socket, 0);

	// Terminate a chunked body before the connection may be closed below.
	// Every chunk written by flushBody() already ends in CRLF, so the
	// terminator is just the last-chunk "0\r\n" plus the final empty line;
	// an extra leading CRLF would be parsed as an invalid chunk-size line.
	if (_chunked) {
		_response.get().msg_body("0\r\n\r\n");
		sl::net::http::write_body(_iostream, _response.get());
	}

	if (!equals(_response.getHeader("Connection"), "Keep-Alive", is_iequal()))
		this->close();

	_request.cleanup();
	_response.cleanup();
}


/*
 * Marks the connection unavailable and shuts down the SSL session.
 * Calling it again while already unavailable does nothing.
 *
 * NOTE(review): the client socket descriptor (_socket) is not closed here;
 * confirm that _iostream releases it, otherwise every connection leaks an fd.
 */
void SSLConnection::close()
{
	if (!available())
		return;

	_available = false;
	SSL_shutdown(_ssl);
}


/*
 * Returns the host the client addressed: the request's Host header, or,
 * when that is empty, "<serverHost()>:<serverPort()>". If building the
 * fallback throws, an empty string is returned.
 */
string SSLConnection::targetHost()
{
	const std::string from_header = _request.header("Host");
	if (!from_header.empty())
		return from_header;

	std::string host;
	try {
		const std::string port_text = boost::lexical_cast<std::string>(serverPort());
		host = serverHost() + ":" + port_text;
	} catch (...) {
		// best effort: leave the result empty
	}
	return host;
}


/// The request currently being served on this connection.
Request &SSLConnection::request() { return _request; }


/// The response being built for the current request.
Response &SSLConnection::response() { return _response; }


/// True while a client socket is attached and not yet closed.
bool SSLConnection::available() const { return _available; }


// Endpoint information below is refreshed for each request by init()
// (via local_info(), remote_info() and server_info()).

/// Name of the local end of the connection (the address if lookups are off).
string SSLConnection::localHost() const { return _local_name; }


/// Numeric address of the local end of the connection.
string SSLConnection::localAddr() const { return _local_addr; }


/// Port of the local end of the connection.
int SSLConnection::localPort() const { return _local_port; }


/// Name of the client (the address if lookups are off).
string SSLConnection::remoteHost() const { return _remote_name; }


/// Numeric address of the client.
string SSLConnection::remoteAddr() const { return _remote_addr; }


/// Port of the client.
int SSLConnection::remotePort() const { return _remote_port; }


/// Host name of this server machine.
string SSLConnection::serverHost() const { return _server_name; }


/// Server address as stored in _server_addr.
string SSLConnection::serverAddr() const { return _server_addr; }


/// Port the connector listens on.
int SSLConnection::serverPort() const { return _server_port; }


//
// --- private --------------------------------------------------------------
//

/*
 * Intended to gzip-compress `content` in place.
 *
 * Currently a no-op: the implementation (and the boost::iostreams includes
 * it needs at the top of the file) is disabled with `#if 0`, so `content` is
 * returned unchanged.
 */
void SSLConnection::gzip_encode(std::string& content)
{
#if 0
	namespace io = boost::iostreams;

	stringstream is(content);
	stringstream os;

	io::filtering_streambuf<io::input> filter;
	filter.push(io::gzip_compressor());
	filter.push(is);

	io::copy(filter, os);
	content.clear();
	content = os.str();
#endif
}


/*
 * Fills _remote_addr, _remote_port and _remote_name from the peer of `sock`.
 *
 * _remote_name is the reverse-DNS name when lookups are enabled, otherwise
 * (and when the lookup fails) the numeric address.
 *
 * @return false if the peer could not be determined or the reverse lookup
 *         failed; true otherwise.
 */
bool SSLConnection::remote_info(int sock)
{
	// Reset first so a failure never leaves the previous client's values on
	// this reused connection object.
	_remote_addr.clear();
	_remote_name.clear();
	_remote_port = 0;

	// sockaddr_storage fits both IPv4 and IPv6 peers.
	struct sockaddr_storage addr;
	socklen_t               len = sizeof(addr);

	if (getpeername(sock, reinterpret_cast<struct sockaddr *>(&addr), &len) < 0)
		return false;

	// inet_ntoa()/gethostbyaddr() return static buffers and are not
	// thread-safe, while every SSLConnection runs on its own thread; the
	// reentrant getnameinfo() writes into local buffers instead.
	char numeric[NI_MAXHOST];
	if (getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), len,
					numeric, sizeof(numeric), NULL, 0, NI_NUMERICHOST) != 0)
		return false;
	_remote_addr = numeric;

	if (addr.ss_family == AF_INET)
		_remote_port = ntohs(reinterpret_cast<struct sockaddr_in *>(&addr)->sin_port);
	else if (addr.ss_family == AF_INET6)
		_remote_port = ntohs(reinterpret_cast<struct sockaddr_in6 *>(&addr)->sin6_port);

	if (!_lookup) {
		_remote_name = _remote_addr;
		return true;
	}

	char resolved[NI_MAXHOST];
	if (getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), len,
					resolved, sizeof(resolved), NULL, 0, NI_NAMEREQD) != 0) {
		// Reverse lookup failed: report it, but still expose a usable name.
		_remote_name = _remote_addr;
		return false;
	}

	_remote_name = resolved;
	return true;
}


/*
 * Fills _server_port from the connector and _server_name from this machine's
 * host name. The socket argument is unused.
 *
 * @return false if gethostname() failed (the port is still set).
 */
bool SSLConnection::server_info(int /* sock */)
{
	// The listening port does not depend on the host-name lookup.
	_server_port = _parent->getPort();

	// POSIX leaves it unspecified whether a truncated host name is
	// NUL-terminated, so reserve one extra byte and terminate explicitly.
	char buffer[HOST_NAME_MAX + 1];
	if (gethostname(buffer, HOST_NAME_MAX) < 0)
		return false;
	buffer[HOST_NAME_MAX] = '\0';

	_server_name = std::string(buffer);
	return true;
}

/*
 * Fills _local_addr, _local_port and _local_name from the local end of `sock`.
 *
 * _local_name is the reverse-DNS name when lookups are enabled, otherwise
 * (and when the lookup fails) the numeric address.
 *
 * @return false if the local address could not be determined or the reverse
 *         lookup failed; true otherwise.
 */
bool SSLConnection::local_info(int sock)
{
	// Reset first so a failure never leaves the previous client's values on
	// this reused connection object.
	_local_addr.clear();
	_local_name.clear();
	_local_port = 0;

	// sockaddr_storage fits both IPv4 and IPv6 sockets.
	struct sockaddr_storage addr;
	socklen_t               len = sizeof(addr);

	if (getsockname(sock, reinterpret_cast<struct sockaddr *>(&addr), &len) < 0)
		return false;

	// inet_ntoa()/gethostbyaddr() return static buffers and are not
	// thread-safe, while every SSLConnection runs on its own thread; the
	// reentrant getnameinfo() writes into local buffers instead.
	char numeric[NI_MAXHOST];
	if (getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), len,
					numeric, sizeof(numeric), NULL, 0, NI_NUMERICHOST) != 0)
		return false;
	_local_addr = numeric;

	if (addr.ss_family == AF_INET)
		_local_port = ntohs(reinterpret_cast<struct sockaddr_in *>(&addr)->sin_port);
	else if (addr.ss_family == AF_INET6)
		_local_port = ntohs(reinterpret_cast<struct sockaddr_in6 *>(&addr)->sin6_port);

	if (!_lookup) {
		_local_name = _local_addr;
		return true;
	}

	char resolved[NI_MAXHOST];
	if (getnameinfo(reinterpret_cast<struct sockaddr *>(&addr), len,
					resolved, sizeof(resolved), NULL, 0, NI_NAMEREQD) != 0) {
		// Reverse lookup failed: report it, but still expose a usable name.
		_local_name = _local_addr;
		return false;
	}

	_local_name = resolved;
	return true;
}
