/**********************************************************************
 * handshake.c                                              August 2005
 *
 * KSSLD: An implementation of SSL/TLS in the Linux Kernel
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This file based in part on code from LVS www.linuxvirtualserver.org
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#include <asm/types.h>
#include <linux/random.h>
#include <asm/scatterlist.h>
#include <linux/string.h>

#include <net/tcp.h>

#include "css.h"
#include "daemon.h"
#include "handshake.h"
#include "message.h"
#include "prf.h"
#include "record.h"
#include "kssl_asym.h"
#include "kssl_alloc.h"
#include "log.h"
#include "session.h"

#include "pk.h"

#include "types/handshake_t.h"

/* Sender identifiers mixed into the Finished computation.
 * The SSL3 variants are 4-byte tags (the ASCII bytes "CLNT" and "SRVR");
 * the TLS variants are NUL-terminated label strings whose strlen() is
 * passed to the PRF by __finished_generate_tls(). */
static const u8 finished_sender_client_ssl3[4] = { 0x43, 0x4c, 0x4e, 0x54 };
static const u8 finished_sender_server_ssl3[4] = { 0x53, 0x52, 0x56, 0x52 };
static const u8 * finished_sender_client_tls = "client finished";
static const u8 * finished_sender_server_tls = "server finished";

/* Inner pad for the SSL3 Finished hashes: 48 bytes of 0x36.
 * __finished_generate_ssl3() feeds all 48 bytes to MD5 and only the
 * first 40 bytes to SHA1. */
#define PAD1_LEN 48
static const u8 pad1[PAD1_LEN] =  { 
        0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36,
        0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36,
        0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36,
        0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36,
        0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36,
        0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36
};

/* Outer pad for the SSL3 Finished hashes: 48 bytes of 0x5c.
 * Used with the same 48 (MD5) / 40 (SHA1) byte split as pad1. */
#define PAD2_LEN 48
static const u8 pad2[PAD2_LEN] =  { 
        0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c,
        0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c,
        0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c,
        0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c,
        0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c,
        0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c
};


/* Forward declarations: these senders are defined further down in this
 * file but are called from client_hello_process() and finished_process()
 * which appear earlier. */
static int server_hello_send(kssl_record_t *cr, int reuse);
static int certificate_send(kssl_record_t *cr, struct iovec *cert, int reuse);
static int server_hello_done_send(kssl_record_t *cr, int reuse);
static int finished_send(kssl_record_t *cr, int reuse);

/* Non-static; declared here because the *_write() callbacks call it
 * before its definition near the end of the file. */
int kssl_handshake_digest_update(kssl_record_t *cr);


/*
 * random_from_buf - load a client random from a wire buffer
 * @random: destination; random_bytes is treated as RANDOM_NLEN bytes long
 * @buf:    source bytes
 * @len:    number of source bytes; values above RANDOM_NLEN are clamped
 *
 * The value is right-aligned in random_bytes and the unused leading
 * bytes are zeroed, so a short SSLv2 challenge ends up padded with
 * leading zeros.
 */
static void
random_from_buf(random_t *random, const u8 *buf,
		size_t len)
{
	size_t pad;

	/* Random is a fixed buffer, so there is no leading size field */
	if (len > RANDOM_NLEN)
		len = RANDOM_NLEN;

	/* Zero exactly the bytes the copy below will not overwrite.
	 * (Zeroing "len" bytes from the start, as was done previously,
	 * leaves a gap of uninitialised bytes whenever
	 * len < RANDOM_NLEN / 2.) */
	pad = RANDOM_NLEN - len;
	if (pad)
		memset(random->random_bytes, 0, pad);
	memcpy(random->random_bytes + pad, buf, len);
}


/*
 * session_id_from_buf - parse a one-byte-length-prefixed session id
 * @sid:     destination; sid->len is always set, sid->id only when the
 *           id is non-empty (it is left untouched for an empty id)
 * @buf:     points at the length byte
 * @buf_len: bytes available at @buf
 *
 * Returns the number of bytes consumed (1 + id length) or -EINVAL if
 * the buffer is too short or the allocation fails.
 */
static int
session_id_from_buf(session_id_t *sid, const u8 *buf, size_t buf_len)
{
	if (!buf_len)
		return -EINVAL;

	sid->len = buf[0];

	if (sid->len) {
		if (buf_len < 1 + sid->len)
			return -EINVAL;

		sid->id = kssl_kmalloc(sid->len, GFP_KERNEL);
		if (!sid->id)
			return -EINVAL;

		memcpy(sid->id, &buf[1], sid->len);
	}

	return sid->len + 1;
}


/*
 * cipher_suites_from_buf - parse the cipher suite list of a client hello
 * @ch:      client hello being filled in; cipher_suites.len is always
 *           set, cipher_suites.cs only when the list is non-empty
 * @buf:     points at the two-byte big-endian list length
 * @buf_len: bytes available at @buf
 *
 * Returns the number of bytes consumed (2 + list length) or -EINVAL if
 * the buffer is too short or the allocation fails.
 */
static int
cipher_suites_from_buf(client_hello_t *ch, const u8 *buf, size_t buf_len)
{
	if (buf_len < 2)
		return -EINVAL;

	ch->cipher_suites.len = (buf[0] << 8) + buf[1];

	if (ch->cipher_suites.len) {
		if (buf_len < 2 + ch->cipher_suites.len)
			return -EINVAL;

		ch->cipher_suites.cs = kssl_kmalloc(ch->cipher_suites.len,
				GFP_KERNEL);
		if (!ch->cipher_suites.cs)
			return -EINVAL;

		memcpy(ch->cipher_suites.cs, &buf[2], ch->cipher_suites.len);
	}

	return 2 + ch->cipher_suites.len;
}


/*
 * compression_methods_from_buf - parse the compression method list
 * @ch:      client hello being filled in; compression_methods.len is
 *           always set, compression_methods.cm only when non-empty
 * @buf:     points at the one-byte list length
 * @buf_len: bytes available at @buf
 *
 * Returns the number of bytes consumed (1 + list length) or -EINVAL if
 * the buffer is too short or the allocation fails.
 */
static int
compression_methods_from_buf(client_hello_t *ch, const u8 *buf, size_t buf_len)
{
	if (!buf_len)
		return -EINVAL;

	ch->compression_methods.len = buf[0];

	if (ch->compression_methods.len) {
		if (buf_len < 1 + ch->compression_methods.len)
			return -EINVAL;

		ch->compression_methods.cm = kssl_kmalloc(
				ch->compression_methods.len, GFP_KERNEL);
		if (!ch->compression_methods.cm)
			return -EINVAL;

		memcpy(ch->compression_methods.cm, &buf[1],
				ch->compression_methods.len);
	}

	return ch->compression_methods.len + 1;
}


/*
 * client_hello_destroy_data - release buffers owned by a client hello
 * @ch: client hello whose session id, cipher suite list and compression
 *      method list are freed
 *
 * Each pointer is reset to NULL after it is freed so that calling this
 * function again on the same structure is harmless.
 */
static void
client_hello_destroy_data(client_hello_t *ch)
{
	KSSL_DEBUG(12, "client_hello_destroy_data: enter\n");

	if (ch->session_id.id) {
		kssl_kfree(ch->session_id.id);
		ch->session_id.id = NULL;
	}
	if (ch->cipher_suites.cs) {
		kssl_kfree(ch->cipher_suites.cs);
		ch->cipher_suites.cs = NULL;
	}
	if (ch->compression_methods.cm) {
		kssl_kfree(ch->compression_methods.cm);
		/* Clear the pointer that was just freed.  This used to
		 * reset session_id.id instead, leaving cm dangling. */
		ch->compression_methods.cm = NULL;
	}
}


/*
 * client_hello_from_buf_tls - parse an SSL3/TLS format client hello body
 * @ch:      destination
 * @buf:     start of the hello body (protocol version first)
 * @buf_len: length of the body
 *
 * Reads, in order: protocol version (2 bytes), random (RANDOM_NLEN
 * bytes), session id, cipher suites and compression methods.
 *
 * Returns the number of bytes consumed, or a negative value on error,
 * in which case anything allocated into @ch has been released again.
 */
static int
client_hello_from_buf_tls(client_hello_t *ch, u8 *buf, size_t buf_len)
{
	size_t used = 2 + RANDOM_NLEN;
	int rc;

	ch->session_id.id = NULL;
	ch->cipher_suites.cs = NULL;
	ch->compression_methods.cm = NULL;

	if (buf_len < used) {
		KSSL_DEBUG(3, "client_hello_from_buf_tls: buffer too short\n");
		rc = -EINVAL;
		goto fail;
	}

	protocol_version_from_buf(&ch->client_version, buf);
	random_from_buf(&ch->random, buf + 2, RANDOM_NLEN);

	rc = session_id_from_buf(&ch->session_id, buf + used, buf_len - used);
	if (rc < 0) {
		KSSL_DEBUG(6, "client_hello_from_buf_tls: "
				"error reading session id\n");
		goto fail;
	}
	used += rc;

	rc = cipher_suites_from_buf(ch, buf + used, buf_len - used);
	if (rc < 0) {
		KSSL_DEBUG(6, "client_hello_from_buf_tls: "
				"error reading cipher suites\n");
		goto fail;
	}
	used += rc;

	rc = compression_methods_from_buf(ch, buf + used, buf_len - used);
	if (rc < 0) {
		KSSL_DEBUG(6, "client_hello_from_buf_tls: "
				"error reading compression methods\n");
		goto fail;
	}
	used += rc;

	return used;

fail:
	client_hello_destroy_data(ch);
	return rc;
}


/*
 * client_hello_from_buf_ssl2 - parse an SSLv2 format client hello
 * @ch:      destination
 * @buf:     start of the message
 * @buf_len: length of the message
 * @cr:      record the message arrived in (currently unused)
 *
 * Layout as read here: version at bytes 1-2, cipher spec length at 3-4,
 * session id length at 5-6, challenge length at 7-8, followed from
 * byte 9 by the cipher specs, the session id and the challenge, in
 * that order.
 *
 * Only 3-byte cipher specs whose first byte is zero are kept; they are
 * stored as 2-byte SSL3/TLS cipher suites.  No compression methods
 * exist in this format, so that list is left empty.
 *
 * Returns @buf_len on success or -EINVAL on error, in which case
 * anything allocated into @ch has been released again.
 */
static int
client_hello_from_buf_ssl2(client_hello_t *ch, const u8 *buf,
		size_t buf_len, kssl_record_t *cr)
{
	size_t challenge_len;
	const u8 *cs_in;
	cipher_suite_t *cs_out;
	size_t i;
	size_t nocs;

	ch->session_id.id = NULL;
	ch->cipher_suites.cs = NULL;

	/* There are no Compression Methods in SSLv2 */
	ch->compression_methods.cm = NULL;
	ch->compression_methods.len = 0;

	if (buf_len < 9 + RANDOM_NLEN) {
		KSSL_DEBUG(3, "client_hello_from_buf_ssl2: too short 1\n");
		goto error;
	}

	/* Version and lengths come first */
	protocol_version_from_buf(&(ch->client_version), buf+1);

	/* Cipher Suites must be a multiple of 3 bytes long in SSL2 */
	ch->cipher_suites.len = (*(buf+3) << 8) | *(buf+4);
	if (ch->cipher_suites.len % 3) {
		KSSL_DEBUG(3, "client_hello_from_buf_ssl2: "
				"ciphers buffer is the wrong length\n");
		goto error;
	}

	ch->session_id.len = (*(buf+5) << 8) | *(buf+6);

	/* Random (challenge) must be at least 16 bytes and not
	 * more than RANDOM_NLEN bytes long */
	challenge_len = (*(buf+7) << 8) | *(buf+8);
	if (challenge_len < 16 || challenge_len > RANDOM_NLEN) {
		KSSL_DEBUG(3, "client_hello_from_buf_ssl2: "
				"random buffer is the wrong length\n");
		goto error;
	}

	if (buf_len < 9 + ch->cipher_suites.len + ch->session_id.len
			+ challenge_len) {
		KSSL_DEBUG(3, "client_hello_from_buf_ssl2: too short 2\n");
		goto error;
	}

	/* An empty session id is left as NULL/0, the same representation
	 * the SSL3/TLS parser produces, rather than a zero byte
	 * allocation. */
	if (ch->session_id.len) {
		ch->session_id.id = kssl_kmalloc(ch->session_id.len,
				GFP_KERNEL);
		if (!ch->session_id.id) {
			KSSL_DEBUG(6, "client_hello_from_buf_ssl2: "
					"kssl_kmalloc\n");
			goto error;
		}
		/* The session id follows the cipher specs, so its offset
		 * is derived from the cipher spec length (not from the
		 * session id length). */
		memcpy(ch->session_id.id, buf+9+ch->cipher_suites.len,
				ch->session_id.len);
	}

	random_from_buf(&(ch->random),
			buf+9+ch->cipher_suites.len+ch->session_id.len,
			challenge_len);

	/* From here on cipher_suites.len counts the bytes of converted
	 * 2-byte suites rather than the on-the-wire length. */
	nocs = ch->cipher_suites.len / 3;
	ch->cipher_suites.len = 0;
	cs_out = kssl_kmalloc(nocs * 2, GFP_KERNEL);
	ch->cipher_suites.cs = cs_out;
	if (!ch->cipher_suites.cs) {
		goto error;
	}
	cs_in = buf + 9;
	for(i = 0; i < nocs ; i++) {
		if (!*cs_in) { /* Only Copy SSL3/TLS Ciphers */
			cipher_suite_cpy(cs_out, cs_in+1);
			KSSL_DEBUG(12, "client_hello_from_buf_ssl2: "
					"added   SSLv3 cipher "
					"cs=0x00%02x%02x\n",
					cs_out->cs[0], cs_out->cs[1]);
			cs_out++;
			ch->cipher_suites.len += 2;
		}
		else {
			KSSL_DEBUG(12, "client_hello_from_buf_ssl2: "
					"ignored SSLv2 cipher "
					"cs=0x%02x%02x%02x\n",
					cs_in[0], cs_in[1], cs_in[2]);
		}
		cs_in += 3;
	}

	return buf_len;
error:
	client_hello_destroy_data(ch);
	return -EINVAL;
}


/*
 * client_key_exchange_destroy_data - reset a parsed client key exchange
 *
 * Nothing is freed: encrypted_data is set by
 * client_key_exchange_from_buf() to point into the caller's buffer, so
 * the reference is simply dropped.
 */
static void client_key_exchange_destroy_data(client_key_exchange_t *cke)
{
	/* XXX: Only deals with RSA */
	cke->exchange_keys.rsa.len = 0;
	cke->exchange_keys.rsa.encrypted_data = NULL;
}


/*
 * client_key_exchange_from_buf - record where the RSA-encrypted
 * pre-master secret lives inside a client key exchange message
 * @cr:      record being processed; supplies the negotiated key
 *           exchange algorithm, protocol version and message length
 * @cke:     destination
 * @buf:     start of the message body; not copied, only referenced
 * @buf_len: length of the body (not otherwise used; the handshake
 *           header length is authoritative here)
 *
 * For TLS (minor version non-zero) the encrypted data is preceded by
 * two bytes, which are skipped via the offset field; for SSL3 the data
 * starts immediately.
 *
 * Returns the handshake message length, or -EINVAL if the key exchange
 * algorithm is not RSA or a TLS message is too short to contain the
 * two leading bytes.
 */
static int client_key_exchange_from_buf(kssl_record_t *cr,
		client_key_exchange_t *cke, u8 *buf, size_t buf_len)
{
	KSSL_DEBUG(12, "client_key_exchange_from_buf: enter\n");

	/* XXX: Only RSA is supported */
	if (cr->conn->conn_state.kes.kea != kssl_kea_rsa) {
		return -EINVAL;
	}

	/* Assume major version is 3 and
	 * minor version is 0 (SSL3) or 1 (TLS1).
	 * This should already have been checked */
	if (cr->conn->conn_state.version.minor) {
		/* Without this check "length - 2" would wrap around for
		 * a truncated message from the peer. */
		if (cr->msg->data.handshake.length < 2) {
			KSSL_DEBUG(3, "client_key_exchange_from_buf: "
					"message too short\n");
			return -EINVAL;
		}
		cke->exchange_keys.rsa.len = cr->msg->data.handshake.length-2;
		cke->exchange_keys.rsa.offset = 2;
	}
	else {
		cke->exchange_keys.rsa.len = cr->msg->data.handshake.length;
		cke->exchange_keys.rsa.offset = 0;
	}
	cke->exchange_keys.rsa.encrypted_data = buf;

	return cr->msg->data.handshake.length;
}


/*
 * kssl_handshake_destroy_data - release per-message data of a handshake
 * @hs: handshake message; msg_type selects which body member is cleaned
 *
 * Only client hello and client key exchange messages have anything to
 * clean up; every other message type is a no-op.
 */
void kssl_handshake_destroy_data(handshake_t *hs)
{
	KSSL_DEBUG(12, "kssl_handshake_destroy_data: enter: %u\n",
			hs->msg_type);

	if (hs->msg_type == ht_client_hello)
		client_hello_destroy_data(&hs->body.client_hello);
	else if (hs->msg_type == ht_client_key_exchange)
		client_key_exchange_destroy_data(
				&hs->body.client_key_exchange);
}


/*
 * finished_from_buf - parse the body of a Finished message
 * @cr:       record being processed (unused)
 * @finished: destination, zeroed before anything is copied
 * @buf:      start of the message body
 * @buf_len:  length of the body
 *
 * A 12 byte body is taken to be TLS verify_data; anything else is
 * treated as the SSL3 form, a 16 byte MD5 hash followed by a 20 byte
 * SHA1 hash.
 *
 * Returns @buf_len on success, or -EINVAL if the body is not 12 bytes
 * and is too short to hold the 36 bytes of SSL3 hashes.
 */
static int
finished_from_buf(kssl_record_t *cr, finished_t *finished,
		u8 *buf, size_t buf_len)
{
	memset(finished, 0, sizeof(finished_t));

	if (buf_len == 12) {
		memcpy(finished->tls.verify_data, buf, 12);
		return buf_len;
	}

	/* The length comes from the peer: never read the 16 + 20 bytes
	 * of SSL3 hashes from a shorter message. */
	if (buf_len < 16 + 20) {
		KSSL_DEBUG(3, "finished_from_buf: message too short\n");
		return -EINVAL;
	}

	memcpy(finished->ssl3.md5_hash, buf, 16);
	memcpy(finished->ssl3.sha1_hash, buf+16, 20);

	return buf_len;
}


/*
 * finished_to_buf - serialise a Finished body
 * @finished: source
 * @buf:      destination
 * @buf_len:  12 selects the TLS form (verify_data); any other value
 *            selects the SSL3 form and writes 36 bytes (MD5 hash then
 *            SHA1 hash)
 */
static void
finished_to_buf(const finished_t *finished, u8 *buf, size_t buf_len)
{
	if (buf_len != 12) {
		memcpy(buf, finished->ssl3.md5_hash, 16);
		memcpy(buf + 16, finished->ssl3.sha1_hash, 20);
		return;
	}

	memcpy(buf, finished->tls.verify_data, 12);
}


/*
 * kssl_handshake_body_tls - parse the body of an SSL3/TLS handshake
 * message into cr->msg
 * @cr: record holding the message; the body is read in place starting
 *      HANDSHAKE_HEAD_NLEN + cr->offset bytes into cr->iov
 *
 * Client hello, client key exchange and finished messages are parsed;
 * every other message type is rejected.
 *
 * Returns the value of the message specific parser (non-negative on
 * success) or a negative value on failure.
 */
int
kssl_handshake_body_tls(kssl_record_t *cr)
{
	handshake_t *hs = &(cr->msg->data.handshake);
	const char *op_str;
	int status;
	u8 *buf;

	KSSL_DEBUG(12, "kssl_handshake_body_tls enter msg.type=%d\n",
			hs->msg_type);

	buf = cr->iov->iov_base + HANDSHAKE_HEAD_NLEN + cr->offset;

	/* N:B. functions in here must consume buf */
	if (hs->msg_type == ht_client_hello) {
		op_str = "client_hello_from_buf_tls";
		status = client_hello_from_buf_tls(&(hs->body.client_hello),
				buf, hs->length);
	}
	else if (hs->msg_type == ht_client_key_exchange) {
		op_str = "client_key_exchange_from_buf";
		status = client_key_exchange_from_buf(cr,
				&(hs->body.client_key_exchange),
				buf, hs->length);
	}
	else if (hs->msg_type == ht_finished) {
		op_str = "finished_from_buf";
		status = finished_from_buf(cr, &(hs->body.finished),
				buf, hs->length);
	}
	else {
		/* Everything else is not accepted from a client here */
		op_str = "unknown/unsuported";
		status = -EINVAL;
	}

	if (status < 0)
		KSSL_DEBUG(6, "kssl_handshake_body_tls: %s\n", op_str);

	return status;
}


/*
 * kssl_handshake_body_ssl2 - parse a handshake message that arrived in
 * an SSLv2 format record
 * @cr: record holding the message
 *
 * Only client hello messages are understood in this format.  The
 * message is copied out of the record into a temporary linear buffer,
 * parsed, and the temporary buffer is freed again.
 *
 * Returns the result of client_hello_from_buf_ssl2(), or -EINVAL if
 * the message is not a client hello or the buffer cannot be allocated.
 */
int
kssl_handshake_body_ssl2(kssl_record_t *cr)
{
	handshake_t *hs = &(cr->msg->data.handshake);
	u8 *copy;
	int status;

	/* Verify Content Type */
	/* We only know how to do handshake messages in ssl2 format */
	if (hs->msg_type != ht_client_hello) {
		KSSL_DEBUG(3, "kssl_handshake_body_ssl2: "
				"not a client hello message\n");
		return -EINVAL;
	}

	copy = (u8 *)kssl_kmalloc(hs->length, GFP_KERNEL);
	if (!copy)
		return -EINVAL;

	kssl_record_vec_cpy(cr, copy, hs->length, cr->offset);

	status = client_hello_from_buf_ssl2(&(hs->body.client_hello),
					copy, hs->length, cr);
	if (status < 0)
		KSSL_DEBUG(6, "kssl_handshake_body_ssl2: "
				"client_hello_from_buf_ssl2\n");

	kssl_kfree(copy);
	return status;
}


/*
 * client_hello_process - act on a parsed client hello
 * @cr:    record holding the client hello
 * @alert: filled in with a fatal handshake_failure alert on error
 *
 * Steps:
 *  - verify and adopt the client's protocol version;
 *  - look the offered session id up; for a known session the session's
 *    cipher suite must be both offered by the client and still
 *    configured on the daemon, otherwise a suite is chosen from the
 *    client's list;
 *  - require the null compression method if any methods were offered;
 *  - install the pending cipher state and send the server hello;
 *  - new session: send certificate and server hello done;
 *    resumed session: derive keys from the stored master secret, send
 *    change cipher spec and finished.
 *
 * The session returned by kssl_session_str_find() is only handed back
 * with kssl_session_put() on the way out, because its id and master
 * secret are still read after the cipher suite has been copied.
 * NOTE(review): this assumes kssl_session_put() is a non-blocking
 * reference drop; confirm against session.c.
 *
 * Returns 0 on success or a negative value on failure.
 */
int client_hello_process(kssl_record_t *cr, alert_t *alert)
{
	cipher_suite_t *cs = NULL;
	cipher_suite_t session_cs;
	client_hello_t *ch;
	kssl_session_t *session = NULL;
	int status = -EINVAL;
	size_t i;

	KSSL_DEBUG(12, "client_hello_process: enter\n");

	ch = &(cr->msg->data.handshake.body.client_hello);

	/* Select Version */
	if (kssl_version_verify(&(ch->client_version)) < 0) {
		KSSL_DEBUG(3, "client_hello_process: invalid version\n");
		goto leave;
	}
	memcpy(&(cr->conn->conn_state.version), &(ch->client_version),
			sizeof(protocol_version_t));

	session = kssl_session_str_find(ch->session_id.id, ch->session_id.len);
	if (session) {
		memcpy(&session_cs, &session->cs, sizeof(session_cs));
		cs =  cipher_suite_nfind(ch->cipher_suites.cs, &session_cs,
				ch->cipher_suites.len);
		if (!cs) {
			KSSL_DEBUG(3, "client_hello_process: "
					"client did not offer session cipher "
					"suite for resumed session\n");
			goto leave;
		}
		cs =  kssl_daemon_cipher_suite_find(cr->conn->daemon,
				&session_cs);
		if (!cs) {
			KSSL_DEBUG(3, "client_hello_process: "
					"server no longer has session cipher "
					"suite for resumed session\n");
			goto leave;
		}
		KSSL_DEBUG(9, "client_hello_process: Found ID: %08x %p\n",
				session->id, cs);
	}
	else {
		cs =  kssl_daemon_cipher_suite_find_list(cr->conn->daemon,
				ch->cipher_suites.cs, ch->cipher_suites.len);
		if (!cs) {
			KSSL_DEBUG(3, "client_hello_process: "
					"no valid cipher suite\n");
			goto leave;
		}
	}

	/* Choose Compression Method */
	/* If the handshake came in an SSL2 record then
	 * no compression methods are supplied, cm_null is assumed
	 */
	if (ch->compression_methods.len) {
		for(i = 0; i < ch->compression_methods.len; i++) {
			/* Only accept cm_null */
			if (ch->compression_methods.cm[i] == cm_null)
				break;
		}
		if (i == ch->compression_methods.len) {
			KSSL_DEBUG(3, "client_hello_process: "
					"no valid compression methods\n");
			goto leave;
		}
	}

	if (kssl_conn_set_cipher_pending(cr->conn, &(ch->random),
				cs, cm_null) < 0) {
		KSSL_DEBUG(6, "client_hello_process: "
				"kssl_conn_set_cipher_pending\n");
		goto leave;
	}

	if (!session) {
		cr->conn->sess_state.id = kssl_session_add(
				cr->conn->ssl_sock->sk->saddr,
				cr->conn->ssl_sock->sk->sport, cs, cm_null);
	}
	else {
		cr->conn->sess_state.id = session->id;
	}

	status = server_hello_send(cr, 0);
	if (status < 0) {
		KSSL_DEBUG(6, "client_hello_process: server_hello_send\n");
		goto leave;
	}

	if (!session) {
		kssl_daemon_get_read(cr->conn->daemon);
		if (cr->conn->daemon->key.type != kssl_key_type_rsa) {
			KSSL_DEBUG(3, "client_hello_process: "
					"no certificate / non RSA key\n");
			kssl_daemon_put_read(cr->conn->daemon);
			status = -EEXIST;
			goto leave;
		}
		status = certificate_send(cr, &(cr->conn->daemon->cert.cert),
				0);
		kssl_daemon_put_read(cr->conn->daemon);
		if (status < 0) {
			KSSL_DEBUG(6, "client_hello_process: "
					"certificate_send\n");
			goto leave;
		}

		status = server_hello_done_send(cr, 0);
		if (status < 0) {
			KSSL_DEBUG(6, "client_hello_process: "
					"server_hello_done_send\n");
			goto leave;
		}
	}
	else {
		status = kssl_conn_keying_material_init_from_master_secret(
				cr->conn, session->master_secret);
		if (status < 0) {
			KSSL_DEBUG(6, "client_hello_process: "
			"kssl_conn_keying_material_init_from_master_secret\n");
			goto leave;
		}

		status = kssl_change_cipher_spec_send(cr, 0);
		if (status < 0)
			goto leave;

		KSSL_NOTICE(3, "ISSL010: Send(ssl): CHANGE_CIPHER_SPEC (to client)\n");

		/* Call this after a change cipher spec is sent */
		kssl_conn_activate_sec_param_out(cr->conn);

		status = finished_send(cr, 0);
		if (status < 0)
			goto leave;
	}

	status = 0;
leave:
	/* Hand the session back only now that nothing reads it any more */
	if (session)
		kssl_session_put(session);
	if (status < 0) {
		alert->level = al_fatal;
		alert->description = ad_handshake_failure;
	}

	return status;
}


/*
 * client_key_exchange_process - act on a parsed client key exchange
 * @cr:    record holding the message
 * @alert: filled in with a fatal handshake_failure alert on error
 *
 * Decrypts the 48 byte pre-master secret with kssl_asym_decrypt() and
 * hands it to kssl_conn_keying_material_init().  The temporary buffer
 * holding the secret is wiped before it is freed.
 *
 * Returns 0 on success or a negative value on failure.
 */
int client_key_exchange_process(kssl_record_t *cr, alert_t *alert)
{
	size_t byte_len = 48; /* Encrypted data is always 48 bytes */
	int status = -ENOMEM;
	u8 *dec = NULL;

	KSSL_DEBUG(12, "client_key_exchange_process enter\n");

	/* Decrypt Pre Master Secret */
	dec = (u8 *) kssl_kmalloc(byte_len, GFP_KERNEL);
	if (!dec) {
		KSSL_DEBUG(6, "client_key_exchange_process: kmalloc\n");
		goto leave;
	}

	status = kssl_asym_decrypt(cr, dec, byte_len);
	if (status < 0) {
		KSSL_DEBUG(6, "client_key_exchange_process: "
				"kssl_asym_decrypt\n");
		goto leave;
	}

	status = kssl_conn_keying_material_init(cr->conn, dec, byte_len);
	if (status < 0) {
		KSSL_DEBUG(6, "client_key_exchange_process: "
				"key material generation failed\n");
		goto leave;
	}

	status = 0;
leave:
	if (dec) {
		/* Do not leave the pre-master secret behind in freed
		 * memory */
		memset(dec, 0, byte_len);
		kssl_kfree(dec);
	}
	if (status < 0) {
		alert->level = al_fatal;
		alert->description = ad_handshake_failure;
	}
	return status;

}


static int __finished_generate_ssl3(finished_t *finished,
		u8 *master_secret, const u8 *sender,
		struct crypto_tfm *hs_digest_md5, 
		struct crypto_tfm *hs_digest_sha1)
{
	struct scatterlist sg[3];
	u8 *buf;

	buf = (u8*)kssl_kmalloc(PAD1_LEN, GFP_KERNEL);
	if (!buf) {
		KSSL_DEBUG(6, "__finished_generate_ssl3: kssl_kmalloc\n");
		return -ENOMEM;
	}

	memcpy(buf, sender, 4);
	sg[0].page = virt_to_page(buf);
	sg[0].offset = ((long) (buf) & ~PAGE_MASK);
	sg[0].length = 4;

	sg[1].page = virt_to_page(master_secret);
	sg[1].offset = ((long) (master_secret) & ~PAGE_MASK);
	sg[1].length = MASTER_SECRET_LEN;

	crypto_digest_update(hs_digest_md5, sg, 2);
	crypto_digest_update(hs_digest_sha1, sg, 2);

	memcpy(buf, pad1, PAD1_LEN);
	sg[0].page = virt_to_page(buf);
	sg[0].offset = ((long) (buf) & ~PAGE_MASK);
	sg[0].length = 48;

	crypto_digest_update(hs_digest_md5, sg, 1);

	sg[0].length = 40;

	crypto_digest_update(hs_digest_sha1, sg, 1);

	crypto_digest_final(hs_digest_md5, finished->ssl3.md5_hash);
	crypto_digest_final(hs_digest_sha1, finished->ssl3.sha1_hash);

	crypto_digest_init(hs_digest_md5);
	crypto_digest_init(hs_digest_sha1);

	sg[0].page = virt_to_page(master_secret);
	sg[0].offset = ((long) (master_secret) & ~PAGE_MASK);
	sg[0].length = MASTER_SECRET_LEN;

	memcpy(buf, pad2, PAD2_LEN);
	sg[1].page = virt_to_page(buf);
	sg[1].offset = ((long) (buf) & ~PAGE_MASK);
	sg[1].length = 48;

	sg[2].page = virt_to_page(finished->ssl3.md5_hash);
	sg[2].offset = ((long) (finished->ssl3.md5_hash) & ~PAGE_MASK);
	sg[2].length = 16;

	crypto_digest_update(hs_digest_md5, sg, 3);

	sg[1].length = 40;

	sg[2].page = virt_to_page(finished->ssl3.sha1_hash);
	sg[2].offset = ((long) (finished->ssl3.sha1_hash) & ~PAGE_MASK);
	sg[2].length = 20;

	crypto_digest_update(hs_digest_sha1, sg, 3);

	crypto_digest_final(hs_digest_md5, finished->ssl3.md5_hash);
	crypto_digest_final(hs_digest_sha1, finished->ssl3.sha1_hash);

	if (buf)
		kssl_kfree(buf);
	return 0;
}


/*
 * __finished_generate_tls - compute the TLS Finished verify_data
 * @finished:       result; tls.verify_data (12 bytes) is written
 * @master_secret:  MASTER_SECRET_LEN bytes
 * @label:          NUL-terminated label string
 * @hs_digest_md5:  MD5 transform already fed with the handshake messages
 * @hs_digest_sha1: SHA1 transform already fed with the handshake messages
 *
 * Both transforms are finalised, and the two hashes are passed together
 * with the label and master secret to kssl_prf_tls(), asking for 12
 * bytes of output.
 *
 * Returns 0 on success or -ENOMEM if kssl_prf_tls() returns NULL.
 */
static int __finished_generate_tls(finished_t *finished,
		u8 *master_secret, const u8 *label,
		struct crypto_tfm *hs_digest_md5,
		struct crypto_tfm *hs_digest_sha1)
{
	u8 md5_sum[16];
	u8 sha1_sum[20];
	u8 *verify;

	crypto_digest_final(hs_digest_md5, md5_sum);
	crypto_digest_final(hs_digest_sha1, sha1_sum);

	verify = kssl_prf_tls(master_secret, MASTER_SECRET_LEN,
			label, strlen(label), md5_sum, 16,
			sha1_sum, 20, 12);
	if (!verify)
		return -ENOMEM;

	memcpy(finished->tls.verify_data, verify, 12);
	kssl_kfree(verify);

	return 0;
}

/* Which side's Finished message is being computed */
#define HOST_TYPE_SERVER 1
#define HOST_TYPE_CLIENT 2

/*
 * finished_generate - compute the expected Finished contents
 * @finished:       result
 * @cr:             record; only the negotiated minor version is read
 * @host_type:      HOST_TYPE_SERVER selects the server sender value,
 *                  anything else the client one
 * @master_secret:  master secret to mix in
 * @hs_digest_md5:  MD5 handshake digest, consumed by the computation
 * @hs_digest_sha1: SHA1 handshake digest, consumed by the computation
 *
 * A non-zero minor version selects the TLS computation, zero the SSL3
 * one.  Returns whatever the selected generator returns.
 */
static int finished_generate(finished_t *finished,
		kssl_record_t *cr, int host_type,
		opaque_t *master_secret,
		struct crypto_tfm *hs_digest_md5,
		struct crypto_tfm *hs_digest_sha1)
{
	const int is_server = (host_type == HOST_TYPE_SERVER);
	const u8 *sender;

	/* Major version 3 and minor version of 0 or 1
	 * should already have been checked */
	if (!cr->conn->conn_state.version.minor) {
		if (is_server)
			sender = finished_sender_server_ssl3;
		else
			sender = finished_sender_client_ssl3;

		return __finished_generate_ssl3(finished, master_secret,
				sender, hs_digest_md5, hs_digest_sha1);
	}

	if (is_server)
		sender = finished_sender_server_tls;
	else
		sender = finished_sender_client_tls;

	return __finished_generate_tls(finished, master_secret, sender,
			hs_digest_md5, hs_digest_sha1);
}

/*
 * finished_process - verify the client's Finished message and, unless
 * the session was resumed, answer with change cipher spec and Finished
 * @cr:             record holding the parsed Finished message
 * @alert:          filled in with a fatal handshake_failure on error
 * @hs_digest_md5:  copy of the MD5 handshake digest to verify against
 * @hs_digest_sha1: copy of the SHA1 handshake digest to verify against
 *
 * Returns 0 (or the result of finished_send()) on success.  A
 * verification mismatch is reported as -ENOMEM; other failures pass
 * through the negative value of the step that failed.
 */
int finished_process(kssl_record_t *cr, alert_t *alert,
		struct crypto_tfm *hs_digest_md5,
		struct crypto_tfm *hs_digest_sha1)
{
	finished_t expected;
	finished_t *received;
	int status;

	KSSL_DEBUG(12, "finished_process enter\n");

	status = finished_generate(&expected, cr, HOST_TYPE_CLIENT,
			cr->conn->sec_param_in_act->master_secret,
			hs_digest_md5, hs_digest_sha1);
	if (status) {
		KSSL_DEBUG(6, "finished_process: finished_generate\n");
		goto leave;
	}

	/* Verify record */

	/* Major version 3 and minor version of 0 or 1
	 * should already have been checked */
	/* plaintext should have been checked to make sure that
	 * it is sufficient length */
	received = &(cr->msg->data.handshake.body.finished);

	if (cr->conn->conn_state.version.minor) {
		if (memcmp(received->tls.verify_data,
				expected.tls.verify_data, 12) != 0) {
			KSSL_DEBUG(6, "finished_process: missmatch (tls)\n");
			status = -ENOMEM;
			goto leave;
		}
	}
	else if (memcmp(received->ssl3.md5_hash,
				expected.ssl3.md5_hash, 16) != 0 ||
			memcmp(received->ssl3.sha1_hash,
				expected.ssl3.sha1_hash, 20) != 0) {
		KSSL_DEBUG(6, "finished_process: missmatch (ssl3)\n");
		status = -ENOMEM;
		goto leave;
	}

	/* The session may have been resumed in which case
	 * both the outgoing and incoming security parameters
	 * will already have been activated and there is nothing
	 * left to send. */
	if (cr->conn->sec_param_out_act != cr->conn->sec_param_in_act) {
		status = kssl_change_cipher_spec_send(cr, 0);
		if (status < 0)
			goto leave;

		KSSL_NOTICE(3, "ISSL010: Send(ssl): CHANGE_CIPHER_SPEC (to client)\n");

		/* Call this after a change cipher spec is sent */
		kssl_conn_activate_sec_param_out(cr->conn);

		status = finished_send(cr, 0);
	}

leave:
	if (status < 0) {
		alert->level = al_fatal;
		alert->description = ad_handshake_failure;
	}
	return status;
}


/*
 * server_hello_write - build a server hello record in cr->iov
 * @cr:   record being built; its buffer must already be large enough
 * @data: unused
 *
 * Writes the record head, the handshake head and the server hello body
 * (version, server random, session id, cipher suite, compression
 * method), replaces cr->msg with a fresh handshake message describing
 * what was written, and adds the message to the handshake digests.
 *
 * Returns 0 on success, -ENOMEM if the message cannot be allocated, or
 * the error from kssl_handshake_digest_update().
 */
static int server_hello_write(kssl_record_t *cr, void *data)
{
	u8 *buf;

	KSSL_DEBUG(12, "server_hello_write enter\n");

	buf = cr->iov->iov_base;

	/* Record Head */
	kssl_record_head_set(cr, ct_handshake,
			SERVER_HELLO_NLEN(KSSL_SESSION_ID_LEN));
	buf += TLS_HEAD_NLEN;

	/* Handshake Head */
	handshake_head_parts_to_buf(
			SERVER_HELLO_BODY_NLEN(KSSL_SESSION_ID_LEN),
			ht_server_hello, buf);
	buf += HANDSHAKE_HEAD_NLEN;

	/* Server Hello */
	protocol_version_to_buf(&(cr->conn->conn_state.version), buf);
	buf += PROTOCOL_VERSION_NLEN;

	memcpy(buf, cr->conn->sec_param_out_pend->server_random, RANDOM_NLEN);
	buf += RANDOM_NLEN;

	/* The session id field occupies KSSL_SESSION_ID_LEN + 1 bytes
	 * (presumably a length byte followed by the id) */
	kssl_session_str_generate(cr->conn->sess_state.id, buf);
	buf += KSSL_SESSION_ID_LEN + 1;

	cipher_suite_cpy(buf, &(cr->conn->conn_state.cs));
	buf += CIPHER_SUITE_NLEN;

	*buf = cr->conn->sec_param_in_pend->compression_algorithm;

	/* Fill out handshake structure */
	if (cr->msg)
		kssl_message_destroy(cr->msg);
	cr->msg = kssl_message_create(ct_handshake);
	if (!cr->msg) {
		return -ENOMEM;
	}
	handshake_head_from_buf(&(cr->msg->data.handshake),
			cr->iov->iov_base + TLS_HEAD_NLEN);

	/* Update digests.  A failure here would leave the handshake
	 * hashes out of step with what is sent, so report it (as the
	 * receive path in kssl_handshake_process() does). */
	return kssl_handshake_digest_update(cr);
}


/*
 * server_hello_send - build and send a server hello
 * @cr:    record to build the message in
 * @reuse: passed through to kssl_record_build_send()
 *
 * Returns the result of kssl_record_build_send().
 */
static int server_hello_send(kssl_record_t *cr, int reuse)
{
	int status;

	KSSL_NOTICE(3, "ISSL003: Send(handshake): SERVER_HELLO\n");

	status = kssl_record_build_send(cr, server_hello_write,
			SERVER_HELLO_NLEN(KSSL_SESSION_ID_LEN) + TLS_HEAD_NLEN,
			NULL, reuse);
	return status;
}

/*
 * certificate_write - build a certificate record in cr->iov
 * @cr:   record being built; its buffer must already be large enough
 * @data: struct iovec describing the certificate bytes to send
 *
 * Writes the record head, the handshake head, the 3 byte certificate
 * list length, the 3 byte certificate length and the certificate
 * itself, replaces cr->msg with a fresh handshake message describing
 * what was written, and adds the message to the handshake digests.
 *
 * Returns 0 on success, -ENOMEM if the message cannot be allocated, or
 * the error from kssl_handshake_digest_update().
 */
static int certificate_write(kssl_record_t *cr, void *data)
{
	u8 *buf;
	struct iovec *cert;
	u32 tmp_len;

	KSSL_DEBUG(12, "certificate_write enter\n");

	cert = (struct iovec *)data;
	buf = cr->iov->iov_base;

	/* Record Head */
	kssl_record_head_set(cr, ct_handshake, cert->iov_len +
			HANDSHAKE_HEAD_NLEN + CERTIFICATE_HEAD_NLEN);
	buf += TLS_HEAD_NLEN;

	/* Handshake Head */
	handshake_head_parts_to_buf(cert->iov_len + CERTIFICATE_HEAD_NLEN,
			ht_certificate, buf);
	buf += HANDSHAKE_HEAD_NLEN;

	/* Certificate List Head */
	/* Lengths are 3 bytes, big-endian: convert to network order and
	 * drop the most significant byte of the 32 bit value */
	tmp_len = htonl(cert->iov_len + CERTIFICATE_CERT_HEAD_NLEN);
	memcpy(buf, ((u8 *)&tmp_len) + 1, 3);
	buf += CERTIFICATE_LIST_HEAD_NLEN;

	/* Certificate Head */
	tmp_len = htonl(cert->iov_len);
	memcpy(buf, ((u8 *)&tmp_len) + 1, 3);
	buf += CERTIFICATE_CERT_HEAD_NLEN;

	/* Certificate */
	memcpy(buf, cert->iov_base, cert->iov_len);

	/* Fill out handshake structure */
	if (cr->msg)
		kssl_message_destroy(cr->msg);
	cr->msg = kssl_message_create(ct_handshake);
	if (!cr->msg) {
		KSSL_DEBUG(6, "certificate_write: kssl_message_create\n");
		return -ENOMEM;
	}
	handshake_head_from_buf(&(cr->msg->data.handshake),
			cr->iov->iov_base + TLS_HEAD_NLEN);

	/* Update digests.  A failure here would leave the handshake
	 * hashes out of step with what is sent, so report it (as the
	 * receive path in kssl_handshake_process() does). */
	return kssl_handshake_digest_update(cr);
}

/*
 * certificate_send - build and send a certificate message
 * @cr:    record to build the message in
 * @cert:  certificate bytes
 * @reuse: passed through to kssl_record_build_send()
 *
 * Returns the result of kssl_record_build_send().
 */
static int certificate_send(kssl_record_t *cr, struct iovec *cert, int reuse)
{
	int status;

	KSSL_NOTICE(3, "ISSL004: Send(handshake): CERTIFICATE\n");

	/* record head + handshake head + certificate heads + certificate */
	status = kssl_record_build_send(cr, certificate_write,
				cert->iov_len + TLS_HEAD_NLEN +
				HANDSHAKE_HEAD_NLEN +
				CERTIFICATE_HEAD_NLEN, cert, reuse);
	return status;
}


/*
 * server_hello_done_write - build a server hello done record in cr->iov
 * @cr:   record being built; its buffer must already be large enough
 * @data: unused
 *
 * The message consists of a handshake head only.  cr->msg is replaced
 * with a fresh handshake message describing what was written, and the
 * message is added to the handshake digests.
 *
 * Returns 0 on success, -ENOMEM if the message cannot be allocated, or
 * the error from kssl_handshake_digest_update().
 */
static int server_hello_done_write(kssl_record_t *cr, void *data)
{
	u8 *buf;

	KSSL_DEBUG(12, "server_hello_done_write enter\n");

	buf = cr->iov->iov_base;

	/* Record Head */
	kssl_record_head_set(cr, ct_handshake, HANDSHAKE_HEAD_NLEN);

	/* Handshake Head */
	handshake_head_parts_to_buf(0, ht_server_hello_done,
			buf + TLS_HEAD_NLEN);

	/* Server Hello Done has no body */

	/* Fill out handshake structure */
	if (cr->msg)
		kssl_message_destroy(cr->msg);
	cr->msg = kssl_message_create(ct_handshake);
	if (!cr->msg) {
		return -ENOMEM;
	}
	handshake_head_from_buf(&(cr->msg->data.handshake),
			buf + TLS_HEAD_NLEN);

	/* Update digests.  A failure here would leave the handshake
	 * hashes out of step with what is sent, so report it (as the
	 * receive path in kssl_handshake_process() does). */
	return kssl_handshake_digest_update(cr);
}


/*
 * server_hello_done_send - build and send a server hello done message
 * @cr:    record to build the message in
 * @reuse: passed through to kssl_record_build_send()
 *
 * Returns the result of kssl_record_build_send().
 */
static int server_hello_done_send(kssl_record_t *cr, int reuse)
{
	int status;

	KSSL_NOTICE(3, "ISSL005: Send(handshake): SERVER_DONE\n");

	status = kssl_record_build_send(cr, server_hello_done_write,
			HANDSHAKE_HEAD_NLEN + TLS_HEAD_NLEN, NULL, reuse);
	return status;
}


/*
 * finished_write - build a Finished record in cr->iov
 * @cr:   record being built; its buffer must already be large enough
 * @data: finished_t holding the values to send
 *
 * The body is 12 bytes for TLS (minor version non-zero) and 16 + 20
 * bytes for SSL3.  cr->msg is replaced with a fresh handshake message
 * describing what was written, and the message is added to the
 * handshake digests.
 *
 * Returns 0 on success, -ENOMEM if the message cannot be allocated, or
 * the error from kssl_handshake_digest_update().
 */
static int finished_write(kssl_record_t *cr, void *data)
{
	u8 *buf;
	finished_t *finished;
	u16 len;

	KSSL_DEBUG(12, "finished_write enter\n");

	finished = (finished_t *)data;
	buf = cr->iov->iov_base;

	len = cr->conn->conn_state.version.minor ? 12 : 16 + 20;

	/* Record Head */
	kssl_record_head_set(cr, ct_handshake, HANDSHAKE_HEAD_NLEN + len);

	/* Handshake Head */
	handshake_head_parts_to_buf(len, ht_finished,  buf + TLS_HEAD_NLEN);

	/* Finished content */
	finished_to_buf(finished, buf + TLS_HEAD_NLEN + HANDSHAKE_HEAD_NLEN,
			len);

	/* Fill out handshake structure */
	if (cr->msg)
		kssl_message_destroy(cr->msg);
	cr->msg = kssl_message_create(ct_handshake);
	if (!cr->msg) {
		return -ENOMEM;
	}
	handshake_head_from_buf(&(cr->msg->data.handshake),
			buf + TLS_HEAD_NLEN);

	/* Update digests.  A failure here would leave the handshake
	 * hashes out of step with what is sent, so report it (as the
	 * receive path in kssl_handshake_process() does). */
	return kssl_handshake_digest_update(cr);
}

/*
 * finished_send - compute and send the server's Finished message
 * @cr:    record to build the message in
 * @reuse: passed through to kssl_record_build_send()
 *
 * Works on private copies of the connection's handshake digests (taken
 * with kssl_conn_cpy_digest()), so computing the Finished value does
 * not disturb the running digests.  The copies are freed before
 * returning.  The "sent" and "session established" notices are only
 * logged once the record has actually been built and sent.
 *
 * Returns the result of kssl_record_build_send() on success, -ENOMEM
 * if the digests cannot be copied, or the error from
 * finished_generate().
 */
static int finished_send(kssl_record_t *cr, int reuse)
{
	struct crypto_tfm *hs_digest_sha1 = NULL;
	struct crypto_tfm *hs_digest_md5 = NULL;
	int status = -ENOMEM;
	finished_t finished;

	KSSL_DEBUG(12, "finished_send: enter\n");

	if (kssl_conn_cpy_digest(cr->conn, &hs_digest_md5, &hs_digest_sha1)) {
		KSSL_DEBUG(6, "finished_send: digest copy failed\n");
		return status;
	}

	if ((status = finished_generate(&finished, cr, HOST_TYPE_SERVER,
			cr->conn->sec_param_out_act->master_secret,
			hs_digest_md5, hs_digest_sha1))) {
		KSSL_DEBUG(6, "finished_send: generate failed\n");
		goto leave;
	}

	/* Major version 3 and minor version of 0 or 1
	 * should already have been checked */
	status = kssl_record_build_send(cr, finished_write,
			HANDSHAKE_HEAD_NLEN + TLS_HEAD_NLEN +
			(cr->conn->conn_state.version.minor ? 12 : 16 + 20),
			&finished, reuse);
	if (status < 0) {
		/* Do not announce an established session if the
		 * Finished message never went out */
		KSSL_DEBUG(6, "finished_send: kssl_record_build_send\n");
		goto leave;
	}

	KSSL_NOTICE(3, "ISSL008: Send(handshake): SERVER_FINISHED\n");
	KSSL_NOTICE(3, "ISSL011: Info(session): Client session established (session ID=%x\n", cr->conn->sess_state.id);

leave:
	if (hs_digest_md5)
		crypto_free_tfm(hs_digest_md5);
	if (hs_digest_sha1)
		crypto_free_tfm(hs_digest_sha1);
	return status;
}


/*
 * kssl_handshake_digest_update - add the handshake message in @cr to
 * the connection's running MD5 and SHA1 handshake digests
 * @cr: record whose cr->msg describes the handshake message
 *
 * The bytes hashed are the handshake head plus body.  For an SSLv2
 * format record they start 2 bytes into the record and two extra bytes
 * are included; otherwise they start after the TLS record head.
 *
 * Returns 0 on success or the negative value returned by
 * kssl_record_to_sg().
 */
int kssl_handshake_digest_update(kssl_record_t *cr)
{
	struct scatterlist *sg;
	size_t total_len;
	size_t nvec;
	size_t len;
	size_t offset;
	int status;

	KSSL_DEBUG(12, "kssl_handshake_digest_update: enter\n");

	if (cr->record.head.type != ct_ssl2) {
		len = HANDSHAKE_HEAD_NLEN + cr->msg->data.handshake.length;
		offset = TLS_HEAD_NLEN;
	}
	else {
		len = HANDSHAKE_HEAD_NLEN + cr->msg->data.handshake.length + 2;
		offset = 2;
	}

	status = kssl_record_to_sg(cr, offset, len, &sg, &nvec, &total_len);
	if (status < 0)
		return status;

	crypto_digest_update(cr->conn->conn_state.hs_digest_md5, sg, nvec);
	crypto_digest_update(cr->conn->conn_state.hs_digest_sha1, sg, nvec);

	/* The scatterlist array produced by kssl_record_to_sg() is ours
	 * to free */
	kssl_kfree(sg);

	return 0;
}


/*
 * kssl_handshake_process - dispatch a parsed handshake message
 * @cr:    record holding the message; destroyed here when the returned
 *         status is non-negative and the message type is a known one
 * @alert: passed to the per-message handlers, which fill it in on error
 *
 * For a Finished message a copy of the handshake digests is taken
 * first, before the message itself is hashed, and handed to
 * finished_process().  Every message except a hello request is then
 * added to the running digests.  Client hello, client key exchange and
 * finished messages are passed to their handlers; the remaining known
 * types are accepted without further action.
 *
 * Returns the handler's status, the status of the digest steps, or
 * -EINVAL for a hello request.
 */
int kssl_handshake_process(kssl_record_t *cr, alert_t *alert)
{
	struct crypto_tfm *md5_copy = NULL;
	struct crypto_tfm *sha1_copy = NULL;
	handshake_t *hs;
	int status = -EINVAL;

	KSSL_DEBUG(12, "kssl_handshake_process: enter\n");

	hs = &(cr->msg->data.handshake);

	if (hs->msg_type == ht_finished) {
		status = kssl_conn_cpy_digest(cr->conn, &md5_copy, &sha1_copy);
		if (status)
			goto leave;
	}

	/* RFC 2246 7.4.9:
	 * Also, Hello Request messages are omitted from handshake hashes. */
	if (hs->msg_type != ht_hello_request) {
		status = kssl_handshake_digest_update(cr);
		if (status)
			goto leave;
	}

	switch (hs->msg_type) {
	case ht_client_hello:
		KSSL_NOTICE(3, "ISSL002: Recv(handshake): CLIENT_HELLO\n");
		status = client_hello_process(cr, alert);
		break;
	case ht_client_key_exchange:
		KSSL_NOTICE(3, "ISSL006: Recv(handshake): CLIENT_KEY_EXCHANGE\n");
		status = client_key_exchange_process(cr, alert);
		break;
	case ht_finished:
		KSSL_NOTICE(3, "ISSL007: Recv(handshake): CLIENT_FINISHED\n");
		status = finished_process(cr, alert, md5_copy, sha1_copy);
		break;
	case ht_hello_request:
	case ht_server_hello:
	case ht_certificate:
	case ht_server_key_exchange:
	case ht_certificate_request:
	case ht_server_hello_done:
	case ht_certificate_verify:
		/* Known type, nothing to do */
		break;
	case ht_last:
	default:
		/* Unknown type: leave without destroying the record */
		goto leave;
	}

	if (status >= 0)
		kssl_record_destroy(cr);

leave:
	if (md5_copy)
		crypto_free_tfm(md5_copy);
	if (sha1_copy)
		crypto_free_tfm(sha1_copy);
	return status;
}

