/**********************************************************************
 * prf.c                                                    August 2005
 *
 * KSSLD: An implementation of SSL/TLS in the Linux Kernel
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This file based in part on code from LVS www.linuxvirtualserver.org
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#include <linux/kernel.h>
#include <linux/net.h>
#include <linux/version.h>
#include <linux/crypto.h>
#include <linux/slab.h>

#include <asm/scatterlist.h>

#include "prf.h"
#include "kssl_alloc.h"
#include "log.h"

/*
 * xor_8 - XOR one byte buffer into another, in place.
 * @a:     destination; its first @count bytes are each replaced by a[i] ^ b[i]
 * @b:     source; only read
 * @count: number of bytes to process (0 leaves @a untouched)
 */
static inline void xor_8(u8 *a, u8 *b, u32 count)
{
	u32 i;

	for (i = 0; i < count; i++)
		a[i] ^= b[i];
}


/*
 * xor_32 - XOR buffer @b into buffer @a, one 32-bit word at a time.
 * @a:     destination, accessed as u32 words
 * @b:     source, accessed as u32 words
 * @count: length in BYTES (not words)
 *
 * All complete 4-byte words are combined first; the 1-3 bytes that may be
 * left over are handed to xor_8() so nothing past @count is touched.
 */
static inline void xor_32(u32 *a, u32 *b, u32 count)
{
	u32 words = count / 4;
	u32 tail = count % 4;
	u32 i;

	for (i = 0; i < words; i++)
		a[i] ^= b[i];

	if (tail)
		xor_8((u8 *)(a + words), (u8 *)(b + words), tail);
}


/*
 * kssl_prf_tls_hash - expand a secret with repeated HMAC of one hash
 * @secret, @secret_len: HMAC key; handed to crypto_hmac() by non-const
 *                       pointer (key) and by address (length)
 * @label, @client_random, @server_random (+ lengths):
 *                       the seed pieces, hashed in that order; a piece is
 *                       skipped when its pointer is NULL or its length is 0
 * @output_len:          in: bytes wanted; out: that value rounded up to a
 *                       whole number of digests, i.e. the size of the
 *                       returned buffer
 * @hash:                algorithm name passed to crypto_alloc_tfm()
 *
 * With "seed" being the concatenated pieces above, the code computes
 *	a      = HMAC(secret, seed)
 *	out[i] = HMAC(secret, a + seed);  a = HMAC(secret, a)
 * which is the shape of P_hash in the TLS 1.0 PRF.
 *
 * Every buffer is turned into a scatterlist entry with virt_to_page().
 * NOTE(review): that presumably requires all inputs to live in directly
 * mapped kernel memory - confirm against the callers.
 * NOTE(review): crypto_hmac() may rewrite the key/length it is given
 * (hence the non-const @secret) - confirm against the crypto API in use.
 *
 * Returns a kssl_kmalloc()ed buffer of *output_len bytes that the caller
 * releases (kssl_prf_tls() uses kssl_kfree()), or NULL if the transform
 * or any allocation could not be obtained.
 */
static u8 *kssl_prf_tls_hash(u8 *secret, u32 secret_len,
		const u8 *label, u32 label_len,
		const u8 *client_random, u32 client_random_len,
		const u8 *server_random, u32 server_random_len,
		u32 *output_len, const u8 *hash)
{
	const u8 *seed_buf[3];
	u32 seed_len[3];
	struct scatterlist sg[4];
	struct crypto_tfm *tfm;
	u8 *a_cur = NULL;
	u8 *a_next = NULL;
	u8 *output = NULL;
	u8 *chunk;
	u32 digest_len;
	u32 produced;
	u32 sg_no;
	u32 i;

	tfm = crypto_alloc_tfm(hash, 0);
	if (!tfm) {
		KSSL_DEBUG(6, "kssl_prf_tls_hash: "
				"failed to load transform for %s\n", hash);
		return NULL;
	}

	digest_len = crypto_tfm_alg_digestsize(tfm);

	a_cur = (u8 *)kssl_kmalloc(digest_len, GFP_KERNEL);
	if (!a_cur)
		goto leave;

	a_next = (u8 *)kssl_kmalloc(digest_len, GFP_KERNEL);
	if (!a_next)
		goto leave;

	/* Round the request up so the loop can emit whole digests only */
	*output_len = ((*output_len + digest_len - 1) / digest_len) *
		digest_len;
	output = (u8 *)kssl_kmalloc(*output_len, GFP_KERNEL);
	if (!output)
		goto leave;

	/* Slot 0 always carries the running "a" value */
	sg[0].page = virt_to_page(a_cur);
	sg[0].offset = ((long) a_cur & ~PAGE_MASK);
	sg[0].length = digest_len;
	sg_no = 1;

	/* Slots 1.. carry whichever seed pieces are actually present */
	seed_buf[0] = label;
	seed_len[0] = label_len;
	seed_buf[1] = client_random;
	seed_len[1] = client_random_len;
	seed_buf[2] = server_random;
	seed_len[2] = server_random_len;

	for (i = 0; i < 3; i++) {
		if (!seed_buf[i] || !seed_len[i])
			continue;
		sg[sg_no].page = virt_to_page(seed_buf[i]);
		sg[sg_no].offset = ((long) seed_buf[i] & ~PAGE_MASK);
		sg[sg_no].length = seed_len[i];
		sg_no++;
	}

	/* First "a": HMAC over the seed alone (slot 0 is skipped) */
	crypto_hmac(tfm, secret, &secret_len, &sg[1], sg_no - 1, a_cur);

	chunk = output;
	for (produced = 0; produced < *output_len; produced += digest_len) {
		/* Output block: HMAC over a + seed */
		crypto_hmac(tfm, secret, &secret_len, sg, sg_no, chunk);

		/* Next "a": HMAC over the current a only, via a scratch
		 * buffer so input and output never alias */
		crypto_hmac(tfm, secret, &secret_len, sg, 1, a_next);
		memcpy(a_cur, a_next, digest_len);

		chunk += digest_len;
	}

leave:
	crypto_free_tfm(tfm);
	if (a_cur)
		kssl_kfree(a_cur);
	if (a_next)
		kssl_kfree(a_next);
	return output;
}


/*
 * kssl_prf_tls - TLS pseudo-random function built from MD5 and SHA1
 * @secret, @secret_len: secret to expand; passed on non-const because
 *                       kssl_prf_tls_hash() hands it to crypto_hmac()
 * @label, @label_len:   label bytes; copied into a private buffer first
 * @client_random, @server_random (+ lengths): remaining seed material
 * @output_len:          number of meaningful bytes wanted
 *
 * The secret is split into two halves of (secret_len + 1) / 2 bytes each:
 * the first starts at offset 0, the second at offset secret_len / 2, so
 * they share the middle byte when secret_len is odd.  The first half keys
 * the "md5" expansion, the second the "sha1" expansion, and the first
 * @output_len bytes of the two results are XORed together.
 *
 * Returns the buffer obtained from the "sha1" expansion, whose allocated
 * size is @output_len rounded up to a multiple of the SHA1 digest size;
 * bytes past @output_len are zeroed.  The buffer comes from kssl_kmalloc()
 * and is owned by the caller.  Returns NULL on any failure.
 */
u8 *kssl_prf_tls(u8 *secret, u32 secret_len,
		const u8 *label, u32 label_len,
		const u8 *client_random, u32 client_random_len,
		const u8 *server_random, u32 server_random_len,
		u32 output_len)
{
	u8 *md5_result = NULL;
	u32 md5_result_len;
	u8 *sha1_result = NULL;
	u32 sha1_result_len;
	u32 s_len;
	u32 s_offset;
	u8 *label_cpy;

	/*
	 * kssl_prf_tls_hash() feeds the label to virt_to_page(), so work
	 * on a kssl_kmalloc()ed copy.
	 * NOTE(review): presumably callers may pass labels (e.g. string
	 * literals) that are not in directly mapped memory - confirm
	 * before removing this copy.
	 */
	label_cpy = kssl_kmalloc(label_len, GFP_KERNEL);
	if (!label_cpy)
		return NULL;
	memcpy(label_cpy, label, label_len);

	KSSL_DEBUG(12, "kssl_prf_tls enter: output_len=%u\n", output_len);

#if 0 /* DEBUG */
	asym_print_char(KERN_DEBUG "secret", secret, secret_len);
	asym_print_char(KERN_DEBUG "label", label, label_len);
	asym_print_char(KERN_DEBUG "client_random", client_random, 
			client_random_len);
	asym_print_char(KERN_DEBUG "server_random", server_random, 
			server_random_len);
#endif

	/* Both halves are ceil(secret_len / 2) long; the second one starts
	 * at floor(secret_len / 2) */
	s_offset = s_len = (secret_len >> 1);
	if (secret_len & 0x1)
		s_len++;

	md5_result_len = output_len;
	md5_result = kssl_prf_tls_hash(secret, s_len, label_cpy, label_len,
			client_random, client_random_len,
			server_random, server_random_len,
			&md5_result_len, "md5");
	if (!md5_result) {
		KSSL_DEBUG(6, "kssl_prf_tls: md5 hash failed\n");
		goto leave;
	}

	sha1_result_len = output_len;
	sha1_result = kssl_prf_tls_hash(secret + s_offset, s_len,
			label_cpy, label_len, client_random, client_random_len,
			server_random, server_random_len,
			&sha1_result_len, "sha1");
	if (!sha1_result) {
		KSSL_DEBUG(6, "kssl_prf_tls: sha1 hash failed\n");
		goto leave;
	}

	/* Both buffers are at least output_len bytes (each was rounded up
	 * from it), so the XOR stays in bounds */
	xor_32((u32 *)sha1_result, (u32 *)md5_result, output_len);

	/* Do not hand back un-XORed SHA1 output in the rounding slack */
	if (sha1_result_len > output_len)
		memset(sha1_result + output_len, 0,
				sha1_result_len - output_len);

leave:
	/* label_cpy is known to be non-NULL once we get here */
	kssl_kfree(label_cpy);
	if (md5_result)
		kssl_kfree(md5_result);
	return(sha1_result);
}


/*
 * __kssl_prf_ssl3_hash - one round of the SSLv3-style nested hash
 * @secret, @label, @random_a, @random_b (+ lengths):
 *             inputs; any of them is left out of a hash when its pointer
 *             is NULL or its length is 0
 * @sha1_buf:  scratch buffer that receives the inner digest; must hold
 *             crypto_tfm_alg_digestsize(tfm_sha1) bytes
 * @output:    receives the outer digest produced by @tfm_md5
 * @tfm_md5, @tfm_sha1: transforms supplied by the caller
 *
 * Computes
 *	sha1_buf = tfm_sha1(label + secret + random_a + random_b)
 *	output   = tfm_md5(secret + sha1_buf)
 *
 * Every input is mapped with virt_to_page() to build the scatterlists.
 */
static void __kssl_prf_ssl3_hash(const u8 *secret, u32 secret_len,
		const u8 *label, u32 label_len,
		const u8 *random_a, u32 random_a_len,
		const u8 *random_b, u32 random_b_len,
		u8 *sha1_buf, u8 *output,
		struct crypto_tfm *tfm_md5, struct crypto_tfm *tfm_sha1)
{
	const u8 *inner_buf[4];
	u32 inner_len[4];
	struct scatterlist sg[4];
	u32 sg_no;
	u32 i;

	/* Inner hash input, in the order it is digested */
	inner_buf[0] = label;
	inner_len[0] = label_len;
	inner_buf[1] = secret;
	inner_len[1] = secret_len;
	inner_buf[2] = random_a;
	inner_len[2] = random_a_len;
	inner_buf[3] = random_b;
	inner_len[3] = random_b_len;

	sg_no = 0;
	for (i = 0; i < 4; i++) {
		if (!inner_buf[i] || !inner_len[i])
			continue;
		sg[sg_no].page = virt_to_page(inner_buf[i]);
		sg[sg_no].offset = ((long) inner_buf[i] & ~PAGE_MASK);
		sg[sg_no].length = inner_len[i];
		sg_no++;
	}

	crypto_digest_init(tfm_sha1);
	crypto_digest_update(tfm_sha1, sg, sg_no);
	crypto_digest_final(tfm_sha1, sha1_buf);

	/* Outer hash input: the secret (if any) followed by the inner digest */
	sg_no = 0;
	if (secret && secret_len) {
		sg[sg_no].page = virt_to_page(secret);
		sg[sg_no].offset = ((long) secret & ~PAGE_MASK);
		sg[sg_no].length = secret_len;
		sg_no++;
	}

	sg[sg_no].page = virt_to_page(sha1_buf);
	sg[sg_no].offset = ((long) sha1_buf & ~PAGE_MASK);
	sg[sg_no].length = crypto_tfm_alg_digestsize(tfm_sha1);
	sg_no++;

	crypto_digest_init(tfm_md5);
	crypto_digest_update(tfm_md5, sg, sg_no);
	crypto_digest_final(tfm_md5, output);
}


/*
 * kssl_prf_ssl3_export_md5 - MD5 over secret + client_random + server_random
 * @secret, @client_random, @server_random (+ lengths):
 *               inputs, digested in that order; an input is skipped when
 *               its pointer is NULL or its length is 0
 * @output_len:  in: bytes wanted; out: that value rounded up to a whole
 *               number of MD5 digests
 *
 * A single digest is computed and then copied repeatedly until
 * *output_len bytes are filled, so every digest-sized block of the result
 * is identical.
 *
 * The digest is always written, even when the rounded *output_len is 0,
 * so the allocation is never smaller than one digest.
 *
 * Returns a kssl_kmalloc()ed buffer owned by the caller, or NULL if the
 * transform or the buffer could not be allocated.
 */
u8 *kssl_prf_ssl3_export_md5(const u8 *secret, u32 secret_len,
		const u8 *client_random, u32 client_random_len,
		const u8 *server_random, u32 server_random_len,
		u32 *output_len)
{
	struct scatterlist sg[3];
	struct crypto_tfm *tfm = NULL;
	u8 *buf = NULL;
	u32 digest_len;
	u32 alloc_len;
	u32 len;
	u32 sg_no = 0;

	tfm = crypto_alloc_tfm("md5", 0);
	if (!tfm) {
		KSSL_DEBUG(6, "failed to load transform for md5\n");
		return(NULL);
	}

	digest_len = crypto_tfm_alg_digestsize(tfm);

	*output_len = ((*output_len + digest_len - 1) / digest_len) *
		digest_len;

	/* crypto_digest_final() below stores one full digest no matter how
	 * little was asked for, so never allocate less than that */
	alloc_len = *output_len;
	if (alloc_len < digest_len)
		alloc_len = digest_len;

	buf = (u8 *)kssl_kmalloc(alloc_len, GFP_KERNEL);
	if (!buf)
		goto leave;

	if (secret && secret_len) {
		sg[sg_no].page = virt_to_page(secret);
		sg[sg_no].offset = ((long) secret & ~PAGE_MASK);
		sg[sg_no].length = secret_len;
		sg_no++;
	}

	if (client_random && client_random_len) {
		sg[sg_no].page = virt_to_page(client_random);
		sg[sg_no].offset = ((long) client_random & ~PAGE_MASK);
		sg[sg_no].length = client_random_len;
		sg_no++;
	}

	if (server_random && server_random_len) {
		sg[sg_no].page = virt_to_page(server_random);
		sg[sg_no].offset = ((long) server_random & ~PAGE_MASK);
		sg[sg_no].length = server_random_len;
		sg_no++;
	}

	crypto_digest_init(tfm);
	crypto_digest_update(tfm, sg, sg_no);
	crypto_digest_final(tfm, buf);

	/* Replicate the digest into the rest of the buffer.
	 * XXX: the original author was unsure whether this is needed. */
	for (len = digest_len; len < *output_len; len += digest_len)
		memcpy(buf + len, buf, digest_len);

leave:
	crypto_free_tfm(tfm);
	return(buf);
}


/*
 * kssl_prf_ssl3 - SSLv3-style key material expansion
 * @secret, @secret_len:     secret mixed into every round
 * @random_a, @random_b (+ lengths): random values, hashed in that order
 * @output_len:              in: bytes wanted; out: that value rounded up
 *                           to a whole number of MD5 digests
 *
 * Round N (N = 1, 2, ...) builds a label of N copies of the Nth capital
 * letter ("A", "BB", "CCC", ...) and stores
 *	MD5(secret + SHA1(label + secret + random_a + random_b))
 * as the Nth digest-sized block of the output (see __kssl_prf_ssl3_hash()).
 * For the master secret the output is 3 * 16 = 48 bytes
 * (MASTER_SECRET_LEN).
 *
 * The label buffer holds 26 bytes, so at most 26 rounds ("A" .. 26 x "Z")
 * are possible; larger requests are rejected before any work is done.
 *
 * NOTE(review): the label lives on this function's stack and is passed to
 * virt_to_page() by __kssl_prf_ssl3_hash(); this presumably relies on
 * kernel stacks being in directly mapped memory - confirm for the target
 * kernel.
 *
 * Returns a kssl_kmalloc()ed buffer of *output_len bytes owned by the
 * caller, or NULL on allocation failure or an over-long request.
 */
u8 *kssl_prf_ssl3(const u8 *secret, u32 secret_len,
		const u8 *random_a, u32 random_a_len,
		const u8 *random_b, u32 random_b_len, u32 *output_len)
{
	struct crypto_tfm *tfm_md5 = NULL;
	struct crypto_tfm *tfm_sha1 = NULL;
	u8 *sha1_buf = NULL;
	u8 *output = NULL;
	u8 label[26];
	u32 md5_len;
	u32 round;
	u32 len;

	tfm_md5 = crypto_alloc_tfm("md5", 0);
	if (!tfm_md5) {
		KSSL_DEBUG(6, "kssl_prf_ssl3_hash: "
				"failed to load transform for md5\n");
		goto leave;
	}

	tfm_sha1 = crypto_alloc_tfm("sha1", 0);
	if (!tfm_sha1) {
		KSSL_DEBUG(6, "kssl_prf_ssl3_hash: "
				"failed to load transform for sha1\n");
		goto leave;
	}

	sha1_buf = (u8 *)kssl_kmalloc(crypto_tfm_alg_digestsize(tfm_sha1),
			GFP_KERNEL);
	if (!sha1_buf) {
		KSSL_DEBUG(6, "kssl_prf_ssl3_hash: "
				"failed to allocate memory\n");
		goto leave;
	}

	md5_len = crypto_tfm_alg_digestsize(tfm_md5);
	*output_len = ((*output_len + md5_len - 1) / md5_len) * md5_len;

	/*
	 * One round per output digest, and round N needs an N-byte label.
	 * Checking up front (rather than after each round) lets a request
	 * that needs exactly as many rounds as the label allows succeed.
	 */
	if (*output_len / md5_len > sizeof(label))
		goto leave;

	output = (u8 *)kssl_kmalloc(*output_len, GFP_KERNEL);
	if (!output)
		goto leave;

	round = 0;
	for (len = 0; len < *output_len; len += md5_len) {
		/* round 0 -> "A", round 1 -> "BB", round 2 -> "CCC", ... */
		memset(label, 'A' + round, round + 1);

		__kssl_prf_ssl3_hash(secret, secret_len, label, round + 1,
				random_a, random_a_len, random_b, random_b_len,
				sha1_buf, output + len, tfm_md5, tfm_sha1);

		round++;
	}

leave:
	if (tfm_md5)
		crypto_free_tfm(tfm_md5);
	if (tfm_sha1)
		crypto_free_tfm(tfm_sha1);
	if (sha1_buf)
		kssl_kfree(sha1_buf);
	return(output);
}
