/**********************************************************************
 * session.c                                                August 2005
 *
 * KSSLD: An implementation of SSL/TLS in the Linux Kernel
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This file based in part on code from LVS www.linuxvirtualserver.org
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#define __KERNEL_SYSCALLS__

#include <linux/config.h>
#include <linux/module.h>
#include <linux/kernel.h>
#include <linux/sched.h>
#include <linux/random.h>
#include <linux/jhash.h>
#include <linux/interrupt.h>
#include <asm/softirq.h>
#include <linux/proc_fs.h>

#include "types/base_t.h"
#include "types/security_parameters_t.h"
#include "types/cipher_suite_t.h"

#include "util.h"
#include "session.h"
#include "log.h"


/* Per-boot random seed mixed into every session id by kssl_session_key();
 * filled in by kssl_session_init(). */
static unsigned int kssl_session_rnd;

/* Session hash table geometry: a session lives in bucket
 * (id & KSSL_SESSION_TAB_MASK).
 * NOTE(review): KSSL_SESSION_TAB_SIZE is 1 << (KSSL_SESSION_TAB_BITS - 1),
 * i.e. 32 buckets for 6 bits rather than 64 -- confirm this is intended. */
#define KSSL_SESSION_TAB_BITS  6
#define KSSL_SESSION_TAB_SIZE (1 << (KSSL_SESSION_TAB_BITS - 1))
#define KSSL_SESSION_TAB_MASK (KSSL_SESSION_TAB_SIZE - 1)

/* Idle lifetime of a session: kssl_session_put() re-arms the session's
 * expiry timer to fire this many jiffies (2 minutes) in the future. */
#define KSSL_SESSION_TIMEOUT   (2*60*HZ)

/* Array of KSSL_SESSION_TAB_SIZE list heads allocated by
 * kssl_session_init().  Every walk or modification of the hash chains
 * in this file is done while holding kssl_session_tab_lock. */
static struct list_head *kssl_session_tab = NULL;
static rwlock_t kssl_session_tab_lock;

static void
kssl_session_expire(unsigned long data);

static kssl_session_t *
kssl_session_find(kssl_session_id_t id);


/* Take one reference on @session; drop it with kssl_session_put(). */
static inline void
kssl_session_get(kssl_session_t *session)
{
	atomic_t *refcount = &session->users;

	atomic_inc(refcount);
}


/* Drop one reference on @session without touching its expiry timer. */
static inline void
__kssl_session_put(kssl_session_t *session)
{
	atomic_t *refcount = &session->users;

	atomic_dec(refcount);
}


/*
 * Drop a reference on @session.  Before the count is decremented, the
 * session's expiry timer is pushed out to KSSL_SESSION_TIMEOUT from now,
 * so a session stays alive for that long after its last use.
 */
void
kssl_session_put(kssl_session_t *session)
{
	unsigned long expires = jiffies + KSSL_SESSION_TIMEOUT;

	mod_timer(&session->timer, expires);
	__kssl_session_put(session);
}


/* Insert @session at the head of the bucket selected by its id. */
static inline void
kssl_session_hash(kssl_session_t *session)
{
	struct list_head *bucket;

	write_lock(&kssl_session_tab_lock);
	bucket = kssl_session_tab + (session->id & KSSL_SESSION_TAB_MASK);
	list_add(&session->list, bucket);
	write_unlock(&kssl_session_tab_lock);
}


/* Remove @session from its bucket if it is currently linked into one;
 * the node is re-initialised so a second unhash is a no-op. */
static inline void
kssl_session_unhash(kssl_session_t *session)
{
	struct list_head *node = &session->list;

	write_lock(&kssl_session_tab_lock);
	if (likely(!list_empty(node)))
		list_del_init(node);
	write_unlock(&kssl_session_tab_lock);
}


static kssl_session_t *
__kssl_session_create(void)
{
	kssl_session_t *session;

	session = (kssl_session_t *)kmalloc(sizeof(kssl_session_t),
			GFP_KERNEL);
	if (!session)
		return NULL;

	memset(session, 0, sizeof(kssl_session_t));

	INIT_LIST_HEAD(&(session->list));	
	kssl_session_get(session);
	
	init_timer(&session->timer);
	session->timer.data = (unsigned long)session;
	session->timer.function = kssl_session_expire;

	return session;
}


/*
 * Create an unhashed session with id @id, a by-value copy of cipher
 * suite @cs and compression method @cm.
 * Returns the session holding one reference, or NULL on allocation failure.
 */
static kssl_session_t *
kssl_session_create(kssl_session_id_t id, cipher_suite_t *cs,
		compression_method_t cm)
{
	kssl_session_t *session = __kssl_session_create();

	if (session != NULL) {
		session->id = id;
		memcpy(&session->cs, cs, sizeof(session->cs));
		session->cm = cm;
	}

	return session;
}


/*
 * Invalidate session @id.  Its master secret is scrubbed and its id is
 * set to 0, so kssl_session_valid() reports it as invalid.  The session
 * stays hashed; its expiry timer (re-armed by kssl_session_put()) frees it
 * later.  Does nothing if no session with @id is found.
 */
void
kssl_session_invalidate(kssl_session_id_t id)
{
	kssl_session_t *victim;

	KSSL_DEBUG(12, "kssl_session_invalidate: enter\n");

	victim = kssl_session_find(id);
	if (victim != NULL) {
		memset(victim->master_secret, 0, MASTER_SECRET_LEN);
		victim->id = 0;
		kssl_session_put(victim);
	}
}


/* Returns 1 if @session still has a non-zero id, 0 otherwise
 * (kssl_session_invalidate() sets the id to 0). */
static int
kssl_session_valid(kssl_session_t *session)
{
	return session->id ? 1 : 0;
}


/* Unhash @session, wipe its whole contents (master secret included)
 * and free it. */
static void
kssl_session_destroy(kssl_session_t *session)
{
	const size_t size = sizeof(*session);

	kssl_session_unhash(session);
	memset(session, 0, size);
	kfree(session);
}


/*
 * Look up a hashed session by id and take a reference on it.
 * @id: session id, as returned by kssl_session_add()
 *
 * Returns the session with one extra reference (release it with
 * kssl_session_put()), or NULL if no hashed session has this id.
 *
 * Id 0 is never looked up.  kssl_session_add() never hands out 0: it
 * fudges a zero hash to a non-zero id.  kssl_session_invalidate(), by
 * contrast, sets the id of the sessions it scrubs to 0 and leaves them
 * hashed.  Without this check, a lookup for 0 could return such an
 * invalidated session.  kssl_session_str_find() performs that lookup
 * whenever a session ID string encodes 0.
 */
static kssl_session_t *
kssl_session_find(kssl_session_id_t id)
{
	kssl_session_t *session;
	kssl_session_t *found = NULL;
	struct list_head *c_list;

	if (!id)
		return NULL;

	read_lock(&kssl_session_tab_lock);
	/* Nothing is removed while walking, so the _safe variant is unneeded. */
	list_for_each(c_list, &kssl_session_tab[id & KSSL_SESSION_TAB_MASK]) {
		session = list_entry(c_list, kssl_session_t, list);
		if (session->id == id) {
			/* Take the reference while still holding the lock, so
			 * the session cannot be unhashed and freed first. */
			kssl_session_get(session);
			found = session;
			break;
		}
	}
	read_unlock(&kssl_session_tab_lock);

	return found;
}


/*
 * Derive a session id from @addr and @port, fresh random bytes and the
 * per-boot seed kssl_session_rnd.  Because new random bytes are drawn on
 * every call, repeated calls with the same arguments give different ids.
 */
static unsigned int
kssl_session_key(u32 addr, u16 port)
{
	u32 salt;

	get_random_bytes(&salt, sizeof(salt));
	return jhash_3words(addr, port, salt, kssl_session_rnd);
}


/*
 * Create and hash a new session keyed on @addr and @port, recording
 * cipher suite @cs and compression method @cm.  The final
 * kssl_session_put() drops the creation reference and arms the
 * KSSL_SESSION_TIMEOUT idle timer.
 *
 * Returns the new session's id, which is never 0, or 0 if allocation failed.
 */
kssl_session_id_t
kssl_session_add(u32 addr, u16 port, cipher_suite_t *cs,
		compression_method_t cm)
{
	kssl_session_id_t id = kssl_session_key(addr, port);
	kssl_session_t *session;

	session = kssl_session_create(id, cs, cm);
	if (session == NULL)
		return 0;

	/* 0 doubles as the failure return, so remap a zero hash. */
	if (unlikely(id == 0)) {
		id = KSSL_SESSION_TAB_SIZE;
		session->id = id;
	}

	kssl_session_hash(session);
	kssl_session_put(session);
	return id;
}


/*
 * Store @master_secret (MASTER_SECRET_LEN bytes) in session @id.
 *
 * Returns 0 on success, or -EINVAL if no session with @id is hashed.
 * If the session has been invalidated, nothing is stored, but 0 is
 * still returned.
 */
int
kssl_session_set_master_secret(kssl_session_id_t id,
		opaque_t master_secret[MASTER_SECRET_LEN])
{
	kssl_session_t *target = kssl_session_find(id);

	if (target == NULL)
		return -EINVAL;

	if (kssl_session_valid(target)) {
		memcpy(target->master_secret, master_secret,
				MASTER_SECRET_LEN);
	}

	kssl_session_put(target);
	return 0;
}


/* Re-arm @session's expiry timer to fire two jiffies from now. */
static void __exit
kssl_session_expire_now(kssl_session_t *session)
{
	unsigned long soon = jiffies + 2;

	mod_timer(&session->timer, soon);
}


/*
 * Schedule every hashed session to expire almost immediately.
 * Returns the number of sessions that were found in the table.
 */
static int __exit
kssl_expire_all_now(void)
{
	int bucket;
	int count = 0;
	struct list_head *node;

	write_lock(&kssl_session_tab_lock);
	for (bucket = 0; bucket < KSSL_SESSION_TAB_SIZE; bucket++) {
		/* mod_timer() does not touch the list, so a plain walk is enough. */
		list_for_each(node, &kssl_session_tab[bucket]) {
			count++;
			kssl_session_expire_now(list_entry(node,
						kssl_session_t, list));
		}
	}
	write_unlock(&kssl_session_tab_lock);

	return count;
}


/*
 * Expiry-timer callback for a session.
 * @data: the kssl_session_t pointer stored in timer.data by
 *        __kssl_session_create().
 *
 * A hashed, idle session holds no references: it is created with one, and
 * kssl_session_add() drops it with kssl_session_put().  This callback
 * takes a temporary reference and then unhashes the session, so that
 * kssl_session_find() can no longer hand out new references.  If the
 * temporary reference is then the only one left, the session is freed.
 * Otherwise some holder is still using it, so it is put back into the
 * hash table.  That holder's kssl_session_put() is expected to re-arm the
 * timer.
 */
static void
kssl_session_expire(unsigned long data)
{
	kssl_session_t *session;

	session = (kssl_session_t *)data;

	kssl_session_get(session);
	kssl_session_unhash(session);

	if (likely(atomic_read(&session->users) == 1)) {
		/* A holder's kssl_session_put() may have re-armed the timer
		 * (mod_timer) before dropping its reference; cancel it so
		 * it cannot fire on freed memory. */
		if (timer_pending(&session->timer))
			del_timer(&session->timer);
		kssl_session_destroy(session);
	}
	else {
		kssl_session_hash(session);
		/* Raw put: do not re-arm our own timer from here. */
		__kssl_session_put(session);
	}

	return;
}


/* Two hex digits per master-secret byte, plus one byte (presumably for
 * kssl_hexdump()'s terminating NUL -- confirm). */
#define KSSLD_MASTER_SECRET_STR_LEN ((MASTER_SECRET_LEN * 2) + 1)
/* Fixed width in bytes of every line kssld_get_session() emits; the line
 * is space padded and its last byte is '\n'.  The extra 25 bytes hold the
 * bucket index, session id, cipher suite and separators. */
#define KSSLD_GET_SESSION_LINE_LEN  (KSSLD_MASTER_SECRET_STR_LEN + 25)

/*
 * Read handler for /proc/net/kssld_session, registered with
 * proc_net_create() in kssl_session_init().
 *
 * Emits one line per hashed session:
 * "<bucket> <id> <cipher suite> <master secret hex>".  Every line is
 * space padded to exactly KSSLD_GET_SESSION_LINE_LEN bytes and ends in
 * '\n'.  Because all lines are the same width, the file offset maps to
 * (line, column) with simple subtraction, so the output can be read in
 * arbitrarily small pieces.
 *
 * @buf:    output buffer
 * @start:  set to @buf; the data is returned at the start of @buf
 * @offset: file position to start from
 * @length: maximum number of bytes to write into @buf
 *
 * Returns the number of bytes written to @buf.
 */
static int
kssld_get_session(char *buf, char **start, off_t offset, int length)
{
	int pos = 0;
	int len;
	int i;
	char tmp[KSSLD_GET_SESSION_LINE_LEN];
	char master_secret[KSSLD_MASTER_SECRET_STR_LEN];
	kssl_session_t *session;
	struct list_head *c_list;

	*start = buf;

	if (length <= 0)
		return 0;

	read_lock(&kssl_session_tab_lock);

	for (i = 0; i < KSSL_SESSION_TAB_SIZE; i++) {
		list_for_each(c_list, &kssl_session_tab[i]) {
			if (pos >= length)
				goto leave;

			/* Skip lines ending at or before the read position;
			 * a line ending exactly at @offset has nothing left. */
			if (offset >= KSSLD_GET_SESSION_LINE_LEN) {
				offset -= KSSLD_GET_SESSION_LINE_LEN;
				continue;
			}

			session = list_entry(c_list, kssl_session_t, list);

			kssl_hexdump(session->master_secret, master_secret,
					MASTER_SECRET_LEN);

			/* Build the fixed-width line: text, space padding,
			 * then '\n' in the last byte.  Clamp in case the
			 * text was ever longer than the buffer. */
			len = snprintf(tmp, sizeof(tmp), "%08x %08x %02x%02x %s",
					i, session->id, session->cs.cs[0],
					session->cs.cs[1], master_secret);
			if (len > KSSLD_GET_SESSION_LINE_LEN - 1)
				len = KSSLD_GET_SESSION_LINE_LEN - 1;
			memset(tmp + len, ' ', KSSLD_GET_SESSION_LINE_LEN - len);
			tmp[KSSLD_GET_SESSION_LINE_LEN - 1] = '\n';

			/* Copy what remains of this line from @offset,
			 * bounded by the room left in @buf.  Subtracting
			 * @offset before clamping keeps len non-negative;
			 * the old "clamp then subtract" form went negative
			 * for small reads and passed memcpy a huge size. */
			len = KSSLD_GET_SESSION_LINE_LEN - (int)offset;
			if (len > length - pos)
				len = length - pos;
			memcpy(buf + pos, tmp + offset, len);
			pos += len;
			offset = 0;
		}
	}

leave:
	read_unlock(&kssl_session_tab_lock);

	return pos;
}

/* Fixed filler that follows the encoded id in a session ID string made by
 * kssl_session_str_generate(); kssl_session_str_find() checks it again.
 * It must be at least KSSL_SESSION_ID_LEN - sizeof(kssl_session_id_t)
 * bytes long, since exactly that many bytes are copied from and compared
 * against it.  Longer is OK, but the extra bytes are never used. */
#define KSSL_SESSION_PAD "its just 28 bytes of crap..."

/*
 * Encode session @id as a length-prefixed session ID string in @buf.
 * The layout is: buf[0] = KSSL_SESSION_ID_LEN, then @id in network byte
 * order, then KSSL_SESSION_PAD filler.  That gives KSSL_SESSION_ID_LEN
 * bytes after the length byte.
 * buf must be at least KSSL_SESSION_ID_LEN+1 bytes long.
 */
void
kssl_session_str_generate(kssl_session_id_t id, u8 *buf)
{
	kssl_session_id_t wire_id = htonl(id);

	buf[0] = KSSL_SESSION_ID_LEN;
	memcpy(buf + 1, &wire_id, sizeof(wire_id));
	memcpy(buf + 1 + sizeof(wire_id), KSSL_SESSION_PAD,
			KSSL_SESSION_ID_LEN - sizeof(wire_id));
}


/*
 * Decode a session ID string produced by kssl_session_str_generate()
 * and look up the session it names.
 * @buf:     the session ID bytes, without the leading length byte
 * @buf_len: number of bytes in @buf; must equal KSSL_SESSION_ID_LEN
 *
 * Returns the session with a reference taken (release it with
 * kssl_session_put()).  Returns NULL if @buf has the wrong length, its
 * filler does not match KSSL_SESSION_PAD, or no session matches.
 */
kssl_session_t *
kssl_session_str_find(u8 *buf, size_t buf_len)
{
	kssl_session_id_t wire_id;
	const size_t pad_len = KSSL_SESSION_ID_LEN - sizeof(wire_id);

	if (buf_len != KSSL_SESSION_ID_LEN ||
	    memcmp(buf + sizeof(wire_id), KSSL_SESSION_PAD, pad_len) != 0)
		return NULL;

	memcpy(&wire_id, buf, sizeof(wire_id));
	return kssl_session_find(ntohl(wire_id));
}


/*
 * Module-init setup for the session cache:
 *  - seed kssl_session_rnd;
 *  - allocate and initialise the KSSL_SESSION_TAB_SIZE-bucket hash table
 *    and its lock;
 *  - register /proc/net/kssld_session, served by kssld_get_session().
 *
 * Returns 0 on success, or -ENOMEM if the table cannot be allocated.
 * Returns -EINVAL if the proc entry cannot be created; the table is then
 * freed again.
 */
int __init
kssl_session_init(void)
{
	int i;

	get_random_bytes(&kssl_session_rnd, sizeof(kssl_session_rnd));

	kssl_session_tab = kmalloc(KSSL_SESSION_TAB_SIZE *
			sizeof(*kssl_session_tab), GFP_KERNEL);
	if (!kssl_session_tab)
		return -ENOMEM;

	for (i = 0; i < KSSL_SESSION_TAB_SIZE; i++)
		INIT_LIST_HEAD(&kssl_session_tab[i]);

	rwlock_init(&kssl_session_tab_lock);

	/*
	 * Every line of this file contains a session's master secret in
	 * hex, so make it readable by its owner (root) only.  With mode 0,
	 * procfs falls back to S_IRUGO, which would let any local user read
	 * every TLS master secret.
	 */
	if (!proc_net_create("kssld_session", S_IRUSR, kssld_get_session))
		goto err_free_tab;

	return 0;

err_free_tab:
	kfree(kssl_session_tab);
	kssl_session_tab = NULL;
	return -EINVAL;
}


/*
 * Module-exit teardown: remove /proc/net/kssld_session, then force all
 * hashed sessions to expire.  Once the table has drained, it is freed.
 */
void __exit
kssl_session_cleanup(void)
{
	proc_net_remove("kssld_session");

	if (!kssl_session_tab)
		return;

	/*
	 * Fire every remaining session's timer soon and wait until none are
	 * left hashed.  NOTE(review): kssl_session_expire() re-hashes any
	 * session that still has a holder.  A reference that is never
	 * dropped would therefore keep this loop spinning forever.
	 */
	while (kssl_expire_all_now())
		kssl_jsleep(2);

	kfree(kssl_session_tab);
	/* Match the init failure path: never leave a dangling table pointer. */
	kssl_session_tab = NULL;
}

