/**********************************************************************
 * TCPS (TCP Splicing Module)
 * info.c: manage tcps information.
 *
 * Copyright (C) 2006  Hirotaka Sasaki <hiro1967@mti.biglobe.ne.jp>
 *
 * This code is based on ip_vs_conn.c 1.28.2.1
 *
 * IPVS         An implementation of the IP virtual server support for the
 *              LINUX operating system.  IPVS is now implemented as a module
 *              over the Netfilter framework. IPVS can be used to build a
 *              high-performance and highly available server based on a
 *              cluster of servers.
 *
 * Authors:     Wensong Zhang <wensong@linuxvirtualserver.org>
 *              Peter Kese <peter.kese@ijs.si>
 *              Julian Anastasov <ja@ssi.bg>
 *
 * The IPVS code for kernel 2.2 was done by Wensong Zhang and Peter Kese,
 * with changes/fixes from Julian Anastasov, Lars Marowsky-Bree, Horms
 * and others. Many code here is taken from IP MASQ code of kernel 2.2.
 *
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#include <linux/config.h>
#include <linux/module.h>
#include <linux/types.h>
#include <linux/kernel.h>
#include <linux/fs.h>
#include <linux/sysctl.h>
#include <linux/proc_fs.h>
#include <linux/list.h>
#include <linux/spinlock.h>
#include <linux/interrupt.h>
#include <linux/errno.h>
#include <linux/timer.h>
#include <linux/vmalloc.h>
#include <linux/skbuff.h>
#include <linux/ip.h>
#include <linux/jhash.h>
#include <linux/random.h>
#include <asm/atomic.h>
#include <asm/uaccess.h>
#include <net/tcp.h>
#include <net/udp.h>
#include <net/icmp.h>
#include <net/ip.h>
#include <net/sock.h>

#include "tcps.h"
#include "tcps_compat.h"

/*
 *  Fine locking granularity for big tcps info hash table.
 *
 *  The table is guarded by an array of IT_LOCKARRAY_SIZE rwlocks
 *  (1 << IT_LOCKARRAY_BITS); the it_*lock*() helpers below pick a lock
 *  by masking their key with IT_LOCKARRAY_MASK.
 */
#define IT_LOCKARRAY_BITS  4
#define IT_LOCKARRAY_SIZE  (1<<IT_LOCKARRAY_BITS)
#define IT_LOCKARRAY_MASK  (IT_LOCKARRAY_SIZE-1)

/*
 *  tcps info hash table: an array of TCPS_INFO_TAB_SIZE list heads,
 *  allocated in tcps_info_init() and released in tcps_info_fini().
 */
static struct list_head *tcps_info_tab;

/* SLAB cache from which struct tcps_info entries are allocated */
static kmem_cache_t *tcps_info_cachep;

/*
 * Counter for tcps info entries: incremented in tcps_info_new(),
 * decremented in tcps_info_free().
 */
static atomic_t tcps_info_count = ATOMIC_INIT(0);

/*
 * One rwlock of the table lock array.  The structure is aligned to
 * SMP_CACHE_BYTES (presumably so that neighbouring locks do not share
 * a cache line -- confirm against the target architecture).
 */
struct tcps_info_aligned_lock {
	rwlock_t	l;	/* the lock itself */
} __attribute__((__aligned__(SMP_CACHE_BYTES)));

/*
 * Lock array for the tcps info table.  Each slot is initialised in
 * tcps_info_init() and is only accessed through the it_*lock*()
 * helpers in this file, which index it with (key & IT_LOCKARRAY_MASK).
 */
struct tcps_info_aligned_lock
__tcps_infotbl_lock_array[IT_LOCKARRAY_SIZE] __cacheline_aligned;

static inline void it_read_lock(unsigned key)
{
	read_lock(&__tcps_infotbl_lock_array[key&IT_LOCKARRAY_MASK].l);
}

static inline void it_read_unlock(unsigned key)
{
	read_unlock(&__tcps_infotbl_lock_array[key&IT_LOCKARRAY_MASK].l);
}

static inline void it_write_lock(unsigned key)
{
	write_lock(&__tcps_infotbl_lock_array[key&IT_LOCKARRAY_MASK].l);
}

static inline void it_write_unlock(unsigned key)
{
	write_unlock(&__tcps_infotbl_lock_array[key&IT_LOCKARRAY_MASK].l);
}

static inline void it_read_lock_bh(unsigned key)
{
	read_lock_bh(&__tcps_infotbl_lock_array[key&IT_LOCKARRAY_MASK].l);
}

static inline void it_read_unlock_bh(unsigned key)
{
	read_unlock_bh(&__tcps_infotbl_lock_array[key&IT_LOCKARRAY_MASK].l);
}

static inline void it_write_lock_bh(unsigned key)
{
	write_lock_bh(&__tcps_infotbl_lock_array[key&IT_LOCKARRAY_MASK].l);
}

static inline void it_write_unlock_bh(unsigned key)
{
	write_unlock_bh(&__tcps_infotbl_lock_array[key&IT_LOCKARRAY_MASK].l);
}

/*
 * Map a tcps info state bitmask to a printable name.
 *
 * The bits are tested in a fixed priority order (EXPIRED, RST, FIN, SENT,
 * SYN, NEW); the name of the first bit found set in @state is returned,
 * or "Unknown" when none of them is set.
 */
static const char *
tcps_info_state_name(u32 state)
{
	static const struct {
		u32		bit;
		const char	*name;
	} by_priority[] = {
		{ TCPS_INFO_S_EXPIRED,	"EXPIRED" },
		{ TCPS_INFO_S_RST,	"RST" },
		{ TCPS_INFO_S_FIN,	"FIN" },
		{ TCPS_INFO_S_SENT,	"SENT" },
		{ TCPS_INFO_S_SYN,	"SYN" },
		{ TCPS_INFO_S_NEW,	"NEW" },
	};
	unsigned int i;

	for (i = 0; i < sizeof(by_priority) / sizeof(by_priority[0]); i++) {
		if (state & by_priority[i].bit)
			return by_priority[i].name;
	}

	return "Unknown";
}

/*
 * Read handler registered for the "tcps_info" proc entry.
 *
 * Produces a stream of fixed-size 40-byte records (39 characters plus a
 * newline): one header record followed by one record per entry hashed in
 * tcps_info_tab.  Records lying entirely before @offset are counted but
 * not formatted.
 *
 * @buffer: output buffer supplied by the caller
 * @start:  set to the first byte of the requested data inside @buffer
 * @offset: byte offset of the request within the record stream
 * @length: maximum number of bytes wanted
 *
 * Returns the number of valid bytes available at *@start, clamped to the
 * range [0, @length].
 */
static int
tcps_info_getinfo(char *buffer, char **start, off_t offset, int length)
{
	off_t pos;
	int idx, len = 0;
	long remain;
	char temp[50];
	struct tcps_info *tcpsi;
	struct list_head *l, *e;

	pos = 40;
	if (pos > offset) {
		len += sprintf(buffer+len, "%-39s\n",
			       "Idx   SockAdr  InfoAdr  State   Expires");
	}

	/*
	 * Walk the info table itself: it is allocated with
	 * TCPS_INFO_TAB_SIZE buckets, so that is the only valid bound.
	 */
	for (idx = 0; idx < TCPS_INFO_TAB_SIZE; idx++) {
		/*
		 *	Lock is actually only need in next loop
		 *	we are called from uspace: must stop bh.
		 */
		it_read_lock_bh(idx);

		l = &tcps_info_tab[idx];
		for (e = l->next; e != l; e = e->next) {
			tcpsi = list_entry(e, struct tcps_info, list);
			pos += 40;
			if (pos <= offset)
				continue;

			/*
			 * A timer whose expiry already lies in the past would
			 * otherwise wrap to a huge unsigned value; show 0.
			 */
			remain = (long)(tcpsi->timer.expires - jiffies);
			if (remain < 0)
				remain = 0;

			/* Bounded formatting: never write past temp[]. */
			snprintf(temp, sizeof(temp),
				 "%05d %08X %08X %-7s %7lu",
				 idx,
				 (unsigned int)(unsigned long)tcpsi->sk,
				 (unsigned int)(unsigned long)tcpsi,
				 tcps_info_state_name(tcpsi->state),
				 (unsigned long)remain / HZ);

			/*
			 * The pos accounting assumes every record is exactly
			 * 40 bytes, so pad AND truncate to 39 characters.
			 */
			len += sprintf(buffer + len, "%-39.39s\n", temp);

			if (pos >= offset + length) {
				it_read_unlock_bh(idx);
				goto done;
			}
		}
		it_read_unlock_bh(idx);
	}

  done:
	*start = buffer + len - (pos - offset);  /* Start of wanted data */
	len = pos - offset;
	if (len > length)
		len = length;
	if (len < 0)
		len = 0;
	return len;
}

/*
 * Read handler registered for the "tcps_info_tab" proc entry.
 *
 * Produces a stream of fixed-size 22-byte records (21 characters plus a
 * newline): one header record followed by one record per bucket of
 * tcps_info_tab giving the bucket index and the number of entries
 * currently chained in it.  Buckets whose record lies entirely before
 * @offset are skipped without being counted.
 *
 * @buffer: output buffer supplied by the caller
 * @start:  set to the first byte of the requested data inside @buffer
 * @offset: byte offset of the request within the record stream
 * @length: maximum number of bytes wanted
 *
 * Returns the number of valid bytes available at *@start, clamped to the
 * range [0, @length].
 */
static int
tcps_info_tab_getinfo(char *buffer, char **start, off_t offset, int length)
{
	off_t pos;
	int idx, len = 0;
	char temp[30];
	struct list_head *l, *e;
	int entry_count;

	pos = 22;
	if (pos > offset) {
		len += sprintf(buffer+len, "%-21s\n", "Index      Count     ");
	}

	/*
	 * Walk the info table itself: it is allocated with
	 * TCPS_INFO_TAB_SIZE buckets, so that is the only valid bound.
	 */
	for (idx = 0; idx < TCPS_INFO_TAB_SIZE; idx++) {

		pos += 22;
		if (pos <= offset)
			continue;

		/* Count the chain length with bottom halves kept out. */
		entry_count = 0;
		it_read_lock_bh(idx);
		l = &tcps_info_tab[idx];
		for (e = l->next; e != l; e = e->next)
			entry_count++;
		it_read_unlock_bh(idx);

		snprintf(temp, sizeof(temp), "%010d %010d", idx, entry_count);
		/* Keep every record at exactly 22 bytes for the pos math. */
		len += sprintf(buffer + len, "%-21.21s\n", temp);
		if (pos >= offset + length) {
			goto done;
		}
	}

  done:
	*start = buffer + len - (pos - offset);  /* Start of wanted data */
	len = pos - offset;
	if (len > length)
		len = length;
	if (len < 0)
		len = 0;
	return len;
}

#ifdef TCPS_INFO_USE_JHASH

/*
 * Seed value for the tcps info hash; passed as the initial value to
 * jhash_2words() in tcps_info_hashkey().
 */
static unsigned int tcps_info_hash_sval;

/*
 * Fill the hash seed with random bytes.  Called once from
 * tcps_info_init().
 */
static void
tcps_info_hash_init(void)
{
	/* calculate the seed value for tcps info hash */
	get_random_bytes(&tcps_info_hash_sval, sizeof(tcps_info_hash_sval));
}

/*
 *	Compute the table bucket for a (sock, tcps info) pair.
 *
 *	Both addresses are fed to jhash_2words() together with the random
 *	seed, and the result is reduced with TCPS_INFO_TAB_MASK.
 */
static inline unsigned long
tcps_info_hashkey(struct sock *sk, struct tcps_info *tcpsi)
{
	return jhash_2words((unsigned long)sk, (unsigned long)tcpsi,
			    tcps_info_hash_sval) & TCPS_INFO_TAB_MASK;
}

#else /* not use jhash */

/*
 * Hash initialisation hook for the non-jhash build: this hash variant
 * needs no seed, so there is nothing to do.
 */
static void
tcps_info_hash_init(void)
{
}

/*
 *	Compute the table bucket for a (sock, tcps info) pair.
 *
 *	Each address is shifted right by L1_CACHE_SHIFT, the two results
 *	are xor-ed together and reduced with TCPS_INFO_TAB_MASK.
 */
static inline unsigned long
tcps_info_hashkey(struct sock *sk, struct tcps_info *tcpsi)
{
	unsigned long sk_bits = (unsigned long)sk >> L1_CACHE_SHIFT;
	unsigned long info_bits = (unsigned long)tcpsi >> L1_CACHE_SHIFT;

	return (sk_bits ^ info_bits) & TCPS_INFO_TAB_MASK;
}

#endif /* TCPS_INFO_USE_JHASH */

/*
 * Insert @tcpsi into the info hash table.
 *
 * The bucket is derived from the entry's sock address and its own
 * address.  Under the bucket's write lock the entry is linked in, flagged
 * TCPS_INFO_F_HASHED, and its reference count is raised by one on behalf
 * of the table.
 *
 * Returns 1 on success, 0 (after logging an error) if the entry was
 * already flagged as hashed.
 */
int
tcps_info_hash(struct tcps_info *tcpsi)
{
	unsigned bucket;

	if (!(tcpsi->flags & TCPS_INFO_F_HASHED)) {
		/* Hash by sock address and tcpsi address */
		bucket = tcps_info_hashkey(tcpsi->sk, tcpsi);

		it_write_lock_bh(bucket);
		list_add(&tcpsi->list, &tcps_info_tab[bucket]);
		tcpsi->flags |= TCPS_INFO_F_HASHED;
		atomic_inc(&tcpsi->refcnt);
		it_write_unlock_bh(bucket);

		return 1;
	}

	TCPS_ERR("tcps_info_hash(): request for already hashed, "
		 "called from %p\n", __builtin_return_address(0));
	return 0;
}

/*
 * Remove @tcpsi from the info hash table.
 *
 * The bucket is derived from the entry's sock address and its own
 * address.  Under the bucket's write lock the entry is unlinked, its
 * TCPS_INFO_F_HASHED flag is cleared, and the reference held by the
 * table is dropped.
 *
 * Returns 1 on success, 0 (after logging an error) if the entry was not
 * flagged as hashed.
 */
int
tcps_info_unhash(struct tcps_info *tcpsi)
{
	unsigned bucket;

	if (tcpsi->flags & TCPS_INFO_F_HASHED) {
		/* Hash by sock address and tcpsi address */
		bucket = tcps_info_hashkey(tcpsi->sk, tcpsi);

		it_write_lock_bh(bucket);
		list_del(&tcpsi->list);
		tcpsi->flags &= ~TCPS_INFO_F_HASHED;
		atomic_dec(&tcpsi->refcnt);
		it_write_unlock_bh(bucket);

		return 1;
	}

	TCPS_ERR("tcps_info_unhash(): request for unhash flagged, "
		 "called from %p\n", __builtin_return_address(0));
	return 0;
}

/*
 * Look up the tcps info entry attached to @sk and take a reference on it.
 *
 * The entry must be present in the hash table, must be the one @sk points
 * at and must point back at @sk; entries whose state carries
 * TCPS_INFO_S_EXPIRED are ignored.
 *
 * Returns the entry with its reference count incremented, or NULL when
 * @sk has no entry attached or no matching live entry is hashed.
 */
struct tcps_info *
tcps_info_get(struct sock *sk)
{
	struct tcps_info *cur;
	struct tcps_info *result = NULL;
	struct list_head *head, *pos;
	unsigned int bucket;

	if (TCPS_SK_TCPSI(sk) == NULL)
		return NULL;

	bucket = tcps_info_hashkey(sk, TCPS_SK_TCPSI(sk));
	head = &tcps_info_tab[bucket];

	it_read_lock_bh(bucket);
	pos = head->next;
	while (pos != head) {
		cur = list_entry(pos, struct tcps_info, list);
		pos = pos->next;

		/* Not the entry belonging to this sock. */
		if (TCPS_SK_TCPSI(sk) != cur || sk != cur->sk)
			continue;
		/* Expired entries are never handed out. */
		if (cur->state & TCPS_INFO_S_EXPIRED)
			continue;

		atomic_inc(&cur->refcnt);
		result = cur;
		break;
	}
	it_read_unlock_bh(bucket);

	return result;
}

/*
 * Store @state in @tcpsi and select the matching timeout.
 *
 * The state bits are examined in the order EXPIRED, RST, FIN, SENT; the
 * first one found set decides the timeout.  When none of them is set the
 * SYN timeout is used.
 */
void
tcps_info_set_state(struct tcps_info *tcpsi, u32 state)
{
	tcpsi->state = state;

	if (state & TCPS_INFO_S_EXPIRED) {
		tcpsi->timeout = TCPS_INFO_TIMEOUT_EXPIRED;
		return;
	}
	if (state & TCPS_INFO_S_RST) {
		tcpsi->timeout = TCPS_INFO_TIMEOUT_RST;
		return;
	}
	if (state & TCPS_INFO_S_FIN) {
		tcpsi->timeout = TCPS_INFO_TIMEOUT_FIN;
		return;
	}
	if (state & TCPS_INFO_S_SENT) {
		tcpsi->timeout = TCPS_INFO_TIMEOUT;
		return;
	}

	tcpsi->timeout = TCPS_INFO_TIMEOUT_SYN;
}

/*
 * Re-arm the timer of the tcps info entry attached to @sk so that it
 * fires after the entry's current timeout, then call __tcps_info_put()
 * for @sk.
 */
void
tcps_info_put(struct sock *sk)
{
	struct tcps_info *info = TCPS_SK_TCPSI(sk);
	unsigned long expires = jiffies + info->timeout;

	mod_timer(&info->timer, expires);
	__tcps_info_put(sk);
}

/*
 * Timer callback installed by tcps_info_new(); @data is the struct sock
 * pointer stored in timer.data.
 *
 * Takes the entry out of the hash table, marks it EXPIRED and frees it
 * if no other reference remains; otherwise puts it back into the table
 * and re-arms the timer via tcps_info_put().
 *
 * NOTE(review): tcps_info_get() skips entries whose state carries
 * TCPS_INFO_S_EXPIRED, so once this function has set that bit and
 * re-hashed the entry, a later run of this handler for the same sock
 * appears to return early at the lookup below -- confirm how such
 * entries are finally released.
 */
static void
tcps_info_expire(unsigned long data)
{
	struct sock *sk = (struct sock *)data;
	struct tcps_info *tcpsi;

	/* Takes a reference on success; NULL means nothing to expire. */
	tcpsi = tcps_info_get(sk);
	if (!tcpsi) return;

	/* Unhash drops the table's reference; on failure only undo our get. */
	if (!tcps_info_unhash(tcpsi)) {
		__tcps_info_put(sk);
		return;
	}

	/* Add the EXPIRED bit; this also switches to the EXPIRED timeout. */
	spin_lock_bh(&tcpsi->lock);
	tcps_info_set_state(tcpsi, tcpsi->state|TCPS_INFO_S_EXPIRED);
	spin_unlock_bh(&tcpsi->lock);

	/* Only the reference taken by tcps_info_get() above is left. */
	if (atomic_read(&tcpsi->refcnt) == 1) {
		if (timer_pending(&tcpsi->timer))
			del_timer(&tcpsi->timer);

		TCPS_DBG("tcps_info_expire: tcpsi=%p\n", tcpsi);
		__tcps_info_put(sk);
		tcps_info_free(tcpsi);
		return;
	}

	/* Still referenced elsewhere: put it back and try again later. */
	tcps_info_hash(tcpsi);

	TCPS_DBG("tcps_info_expire: delayed: tcpsi=%p refcnt-1=%d\n",
		 tcpsi, atomic_read(&tcpsi->refcnt)-1);

	/* Re-arms the timer with the current timeout, then puts. */
	tcps_info_put(sk);
}

/*
 * Make a pending timer of @tcpsi fire as soon as possible.
 *
 * The timer is only re-armed (for the current jiffies value) when
 * del_timer() reports that it was pending; otherwise nothing is done.
 */
static void
tcps_info_expire_now(struct tcps_info *tcpsi)
{
	if (!del_timer(&tcpsi->timer))
		return;

	mod_timer(&tcpsi->timer, jiffies);
}

/*
 *  Create a new tcps info entry and hash it into the tcps_info_tab.
 *
 *  The entry is allocated from the slab cache with GFP_ATOMIC, zeroed,
 *  linked to @sk in both directions (a reference on @sk is taken with
 *  sock_hold()), put into state TCPS_INFO_S_NEW and inserted into the
 *  hash table.  Its timer is initialised with tcps_info_expire() as the
 *  handler and @sk as the data, but is not armed here.
 *
 *  On return the entry's reference count is 2: one reference for the
 *  caller plus the one taken by tcps_info_hash() for the table.
 *
 *  Returns the new entry, or ERR_PTR(-ENOMEM) if the allocation fails
 *  (never NULL, so callers must test with IS_ERR()).
 */
struct tcps_info *
tcps_info_new(struct sock *sk)
{
	struct tcps_info *tcpsi;

	tcpsi = kmem_cache_alloc(tcps_info_cachep, GFP_ATOMIC);
	if (tcpsi == NULL) {
		/* Terminate the line like every other TCPS_INF() message. */
		TCPS_INF("tcps_info_new: no memory available\n");
		return ERR_PTR(-ENOMEM);
	}

	memset(tcpsi, 0, sizeof(*tcpsi));
	INIT_LIST_HEAD(&tcpsi->list);
	spin_lock_init(&tcpsi->lock);

	init_timer(&tcpsi->timer);
	tcpsi->timer.data     = (unsigned long)sk;
	tcpsi->timer.function = tcps_info_expire;

	/* Cross-link sock and entry; the entry pins the sock. */
	TCPS_SK_TCPSI(sk)    = tcpsi;
	tcpsi->sk            = sk;
	sock_hold(sk);

	atomic_inc(&tcps_info_count);

	/* Set its state and timeout */
	tcps_info_set_state(tcpsi, TCPS_INFO_S_NEW);

	/*
	 * Set the entry is referenced by the current thread before hashing
	 * it in the table.
	 */
	atomic_set(&tcpsi->refcnt, 1);

	tcps_info_hash(tcpsi);

	TCPS_DBG("tcps_info_new: sk=%p tcpsi=%p\n", sk, tcpsi);
	return tcpsi;
}

/*
 * Release a tcps info entry.
 *
 * Under the entry's lock the sock <-> entry links are cleared in both
 * directions.  Afterwards the sock reference held by the entry is
 * dropped, the entry is returned to the slab cache and the global entry
 * counter is decremented.  A NULL @tcpsi is ignored.
 */
void
tcps_info_free(struct tcps_info *tcpsi)
{
	struct sock *owner;

	if (tcpsi == NULL) {
		return;
	}

	/* Detach sock and entry from each other. */
	spin_lock_bh(&tcpsi->lock);
	owner = tcpsi->sk;
	TCPS_SK_TCPSI(owner) = NULL;
	tcpsi->sk = NULL;
	spin_unlock_bh(&tcpsi->lock);

	sock_put(owner);
	kmem_cache_free(tcps_info_cachep, tcpsi);
	atomic_dec(&tcps_info_count);

	/* Only the pointer value of the freed entry is printed here. */
	TCPS_DBG("tcps_info_free: sk=%p tcpsi=%p info_count=%d\n",
		 owner, tcpsi, atomic_read(&tcps_info_count));
}

/*
 *  Flush all tcps info entries in the tcps_info_tab.
 *
 *  Every hashed entry with a pending timer has that timer re-armed to
 *  fire immediately.  The sweep is repeated, calling schedule() between
 *  passes, until the global entry counter has dropped to zero.
 *
 *  note: You need to delete netfilter hooks before calling this function.
 */
static void tcps_info_flush(void)
{
	unsigned long initial = atomic_read(&tcps_info_count);
	struct list_head *head, *pos;
	struct tcps_info *entry;
	int bucket;

	for (;;) {
		for (bucket = 0; bucket < TCPS_INFO_TAB_SIZE; bucket++) {
			/*
			 *  Lock is actually needed in this loop.
			 */
			it_write_lock_bh(bucket);

			head = &tcps_info_tab[bucket];
			pos = head->next;
			while (pos != head) {
				entry = list_entry(pos, struct tcps_info, list);
				TCPS_DBG("tcps_info_flush: tcpsi=%p\n", entry);
				tcps_info_expire_now(entry);
				pos = pos->next;
			}

			it_write_unlock_bh(bucket);
		}

		/* All entries gone: done.  Otherwise yield and sweep again. */
		if (atomic_read(&tcps_info_count) == 0)
			break;
		schedule();
	}

	TCPS_INF("tcps_info_flush: %lu entries flushed\n", initial);
}

/*
 * Set up the tcps info subsystem.
 *
 * Allocates the hash table (TCPS_INFO_TAB_SIZE list heads) and the slab
 * cache for entries, initialises every bucket and every lock of the lock
 * array, registers the "tcps_info" and "tcps_info_tab" proc entries and
 * initialises the hash function.
 *
 * Returns 0 on success or -ENOMEM if the table or the slab cache cannot
 * be allocated (nothing is left allocated in that case).
 *
 * NOTE(review): the results of proc_net_create() are not checked, so a
 * failure to register either proc entry goes unnoticed.
 */
int
tcps_info_init(void)
{
	int idx;

	/*
	 * Allocate the tcps info hash table and initialize its list heads
	 */
	tcps_info_tab = vmalloc(TCPS_INFO_TAB_SIZE*sizeof(struct list_head));
	if (!tcps_info_tab)
		return -ENOMEM;

	/* Allocate tcps_info slab cache */
	tcps_info_cachep = kmem_cache_create("tcps_info",
					      sizeof(struct tcps_info), 0,
					      SLAB_HWCACHE_ALIGN, NULL, NULL);
	if (!tcps_info_cachep) {
		vfree(tcps_info_tab);
		return -ENOMEM;
	}

	TCPS_INF("Info hash table configured "
		  "(size=%d, memory=%ldKbytes)\n",
		  TCPS_INFO_TAB_SIZE,
		  (long)(TCPS_INFO_TAB_SIZE*sizeof(struct list_head))/1024);
	/* sizeof yields size_t; cast so the argument matches "%d". */
	TCPS_DBG("Each info entry needs %d bytes at least\n",
		 (int)sizeof(struct tcps_info));

	for (idx = 0; idx < TCPS_INFO_TAB_SIZE; idx++) {
		INIT_LIST_HEAD(&tcps_info_tab[idx]);
	}

	for (idx = 0; idx < IT_LOCKARRAY_SIZE; idx++)  {
		__tcps_infotbl_lock_array[idx].l = RW_LOCK_UNLOCKED;
	}

	/* Table and locks are ready: expose them through /proc. */
	proc_net_create("tcps_info", 0, tcps_info_getinfo);
	proc_net_create("tcps_info_tab", 0, tcps_info_tab_getinfo);

	tcps_info_hash_init();

	return 0;
}

/*
 * Tear down the tcps info subsystem: the counterpart of tcps_info_init().
 *
 * tcps_info_flush() only returns once the entry counter has reached
 * zero, so the slab cache is destroyed and the table freed only after
 * every entry has been released.
 */
void
tcps_info_fini(void)
{
	/* Flush all the tcps info entries */
	tcps_info_flush();

	/* Remove the proc entries before the data they read is freed. */
	proc_net_remove("tcps_info");
	proc_net_remove("tcps_info_tab");
	kmem_cache_destroy(tcps_info_cachep);
	vfree(tcps_info_tab);
}
