/**********************************************************************
 * elgamal.c                                                August 2005
 *
 * ASYM: An implementation of Asymetric Cryptography in the Linux Kernel
 * Copyright (C) 2005  NTT COMWARE Corporation.
 *
 * This file based in part on code from LVS www.linuxvirtualserver.org
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
 * 02110-1301, USA.
 *
 **********************************************************************/

#ifdef __KERNEL__
#include <linux/types.h>
#include <linux/random.h>
#else
#include "compat.h"
#endif

#include "pk.h"
#include "unsx.h"

#define unsxLen (bitLen / unsx_Bit_Length)
#define RETURN(RET)\
do {\
  ret = RET;\
  goto keygen_done;\
} while(0)
/*
 * elgamal_keygen - generate an ElGamal key pair.
 *
 * pub:        receives p, g and ga (g^a mod p); pub->a is set to NULL.
 * pri:        receives p, g, ga and the private exponent a.
 * bitLen:     key size in bits; every number uses
 *             bitLen / unsx_Bit_Length words.
 * genOptions: not used by this function.
 *
 * The unsx layer is re-initialised for bitLen (unsx_deinit() followed by
 * unsx_init()).  p, g and a are filled from get_random_bytes(); p is then
 * passed through unsx_nextPrime() and must pass unsx_isPrimePRP(); g and a
 * are reduced mod p when they are not already smaller than p.
 *
 * The key arrays are kmalloc()ed here.  If a failure happens after they
 * have been assigned, both keys are released with
 * elgamal_key_destroy_data().  If the failure happens earlier, pub and pri
 * are left untouched, because their pointer fields have not been written
 * by this function yet and may not be safe to free.
 *
 * Returns the result of the final unsx_modPow() on success (compared
 * against PK_OK by the cleanup code), -ENOMEM on allocation failure,
 * PK_KEYGEN_FAILED when the candidate prime fails the PRP test, or a
 * negative code propagated from the unsx_* helpers.
 */
int elgamal_keygen(elgamal_key_t *pub, elgamal_key_t *pri,
                   int bitLen, unsigned genOptions) {
  int i;
  int ret = PK_OK;
  int keysAssigned = 0;   /* set once pub/pri pointer fields are ours */
  UNSX_NEW(p, unsxLen);
  UNSX_NEW(g, unsxLen);
  UNSX_NEW(ga, unsxLen);
  UNSX_NEW(a, unsxLen);

  if(!p || !g || !ga || !a) {
    RETURN(-ENOMEM);
  }

  unsx_deinit();
  ret = unsx_init(bitLen);
  if(ret < 0) {
    RETURN(ret);
  }

  /* From here on every pointer field of both keys is assigned (possibly
   * to NULL), so the cleanup path may hand them to
   * elgamal_key_destroy_data(). */
  keysAssigned = 1;
  pub->len = pri->len = unsxLen;
  pub->p = (unsx*)kmalloc(pub->len * sizeof(unsx), GFP_KERNEL);
  pub->g = (unsx*)kmalloc(pub->len * sizeof(unsx), GFP_KERNEL);
  pub->ga = (unsx*)kmalloc(pub->len * sizeof(unsx), GFP_KERNEL);
  pub->a = NULL;

  pri->p = (unsx*)kmalloc(pri->len * sizeof(unsx), GFP_KERNEL);
  pri->g = (unsx*)kmalloc(pri->len * sizeof(unsx), GFP_KERNEL);
  pri->ga = (unsx*)kmalloc(pri->len * sizeof(unsx), GFP_KERNEL);
  pri->a = (unsx*)kmalloc(pri->len * sizeof(unsx), GFP_KERNEL);

  if(!pub->p || !pub->g || !pub->ga ||
     !pri->p || !pri->g || !pri->ga || !pri->a) {
    RETURN(-ENOMEM);
  }

  unsx_setZeroDual(p, g, unsxLen);
  unsx_setZero(a, unsxLen);

  /* No explicit seeding is done here; the original authors noted that
   * get_random_bytes() should not need it inside the kernel. */
  for (i=0; i<unsxLen; i++) {
    get_random_bytes(&(p[i]), sizeof(p[i]));
    get_random_bytes(&(g[i]), sizeof(g[i]));
    get_random_bytes(&(a[i]), sizeof(a[i]));
  }

  ret = unsx_nextPrime(p, unsxLen);
  if(ret < 0)
    RETURN(ret);

  /* A zero result means p did not pass the probable-prime test. */
  ret = unsx_isPrimePRP(p,unsxLen);
  if(ret < 0)
    RETURN(ret);
  if (!ret)
    RETURN(PK_KEYGEN_FAILED);

  /* Bring g and a below p. */
  if (unsx_cmp(p, g, unsxLen) <= 0) {
    ret = unsx_mod(g, g, unsxLen, p, unsxLen);
    if(ret < 0) {
      RETURN(ret);
    }
  }
  if (unsx_cmp(p, a, unsxLen) <= 0) {
    ret = unsx_mod(a, a, unsxLen, p, unsxLen);
    if(ret < 0) {
      RETURN(ret);
    }
  }

  ret = unsx_modPow(ga, g, a, p, unsxLen);             /* ga = g^a mod p */
  if(ret < 0) {
    RETURN(ret);
  }

  memcpy(pub->p, p, unsxLen*sizeof(unsx));
  memcpy(pub->g, g, unsxLen*sizeof(unsx));
  memcpy(pub->ga, ga, unsxLen*sizeof(unsx));
  memcpy(pri->p, p, unsxLen*sizeof(unsx));
  memcpy(pri->g, g, unsxLen*sizeof(unsx));
  memcpy(pri->ga, ga, unsxLen*sizeof(unsx));
  memcpy(pri->a, a, unsxLen*sizeof(unsx));

keygen_done:
  if(p)
    UNSX_FREE(p, unsxLen);
  if(g)
    UNSX_FREE(g, unsxLen);
  if(ga)
    UNSX_FREE(ga, unsxLen);
  if(a)
    UNSX_FREE(a, unsxLen);
  /* Only tear down the keys if this call populated their fields. */
  if(ret && keysAssigned) {
    elgamal_key_destroy_data(pub);
    elgamal_key_destroy_data(pri);
  }

  return ret;
}
#undef unsxLen
#undef RETURN

#define RETURN(RET)\
do {\
  ret = RET;\
  goto elgamal_encrypt_ret;\
} while(0)
/*
 * elgamal_encrypt - encrypt one message block with an ElGamal public key.
 *
 * out:     receives the ciphertext, 2 * UNSX_LENGTH * sizeof(unsx) bytes
 *          (Cy followed by Cd); the caller must provide that much room.
 * outLen:  set to the number of bytes written to out on success.
 * pub:     public key; p, g and ga must be non-NULL.
 * in:      plaintext bytes.
 * inLen:   plaintext length; must be between 0 and
 *          UNSX_LENGTH * sizeof(unsx) - 3 inclusive.
 * options: not used by this function.
 *
 * The plaintext is framed as 0x00 0x02 <non-zero random bytes> 0x00 <in>
 * so that it fills one UNSX_LENGTH-word number, and then encrypted as
 * Cy = g^K mod p, Cd = M * (g^a)^K mod p with a random K.
 *
 * Returns PK_OK on success, -ENOMEM on allocation failure, PK_INVALID_KEY
 * when the key is incomplete, -EINVAL when inLen is out of range, or a
 * negative code propagated from the unsx_* helpers.
 */
int elgamal_encrypt(char *out, int *outLen,
                           const elgamal_key_t *pub,
                           const char *in, int inLen,
                           unsigned options) {
  char *inBuf = (char*)kmalloc(UNSX_LENGTH * sizeof(unsx), GFP_KERNEL);
  UNSX_NEW(K, UNSX_LENGTH);
  UNSX_NEW(Cy, UNSX_LENGTH);
  UNSX_NEW(Cd, UNSX_LENGTH);
  UNSX_NEW(IN, UNSX_LENGTH);
  const int bufLen = (int)(UNSX_LENGTH * sizeof(unsx));
  int i,k, padLen, ret=PK_OK;
  unsx t;

  if(!inBuf || !K || !Cy || !Cd || !IN) {
    RETURN(-ENOMEM);
  }

  if (pub->p == NULL || pub->g == NULL || pub->ga == NULL) {
    RETURN(PK_INVALID_KEY);
  }

  /* Three bytes of framing (0x00, 0x02 and the 0x00 separator) must fit
   * next to the message.  Without this check the padding length below
   * would wrap around and the buffer would be overrun. */
  if (inLen < 0 || inLen > bufLen - 3) {
    RETURN(-EINVAL);
  }

  memset(inBuf, 0, bufLen);

  k = 0;
  inBuf[k++] = 0x00;
  inBuf[k++] = 0x02;
  padLen = bufLen - 3 - inLen;
  for (i=0; i<padLen; i++) {
    /* Padding bytes must be non-zero so the separator stays unique. */
    do {
      get_random_bytes(&(inBuf[k]), 1);
    } while (inBuf[k] == 0);
    k++;
  }
  /* inBuf[k] is still 0 from the memset: that is the separator. */
  memcpy(&inBuf[k+1], in, inLen);

  /* Pack the byte buffer into words, most significant word last in the
   * array and most significant byte first within each word. */
  unsx_setZero(IN, UNSX_LENGTH);
  for (i=0; i<bufLen; i+=4) {
    IN[UNSX_LENGTH-1-(i/4)] = ((unsx)(inBuf[i  ]&0xff)<<24)
                             |((unsx)(inBuf[i+1]&0xff)<<16)
                             |((unsx)(inBuf[i+2]&0xff)<< 8)
                             |((unsx)(inBuf[i+3]&0xff)    );
  }

  /* Random K; the top word is left at zero. */
  unsx_setZero(K, UNSX_LENGTH);
  for (i=0; i<pub->len-1; i++) {
    get_random_bytes(&(K[i]), sizeof(K[i]));
    K[i] *= 0x00010001;
  }

  if (unsx_cmp(K, pub->p, pub->len) >= 0) {
    ret = unsx_mod(K, K, UNSX_LENGTH, pub->p, pub->len);
    if(ret < 0)
      RETURN(ret);
  }
  ret = unsx_modPow(Cy, pub->g, K, pub->p, pub->len);  /* Cy = g^K mod p     */
  if(ret < 0)
    RETURN(ret);
  ret = unsx_modPow(Cd, pub->ga, K, pub->p, pub->len); /* Cd = (g^a)^k mod p */
  if(ret < 0)
    RETURN(ret);
  ret = unsx_modMult(Cd, Cd, IN, pub->p, pub->len);    /* Cd =               */
  if(ret < 0)                                          /* M * (g^a)^k mod p  */
    RETURN(ret);

  /* Serialise Cy then Cd: words from the top of the array down, low byte
   * of each word first (elgamal_decrypt() reads the same layout). */
  for (k=0,i=0; i<UNSX_LENGTH; i++) {
    t = Cy[UNSX_LENGTH-1-i];
    out[k  ] = t; t >>= 8;
    out[k+1] = t; t >>= 8;
    out[k+2] = t; t >>= 8;
    out[k+3] = t;
    k+=4;
  }

  for (i=0; i<UNSX_LENGTH; i++) {
    t = Cd[UNSX_LENGTH-1-i];
    out[k  ] = t; t >>= 8;
    out[k+1] = t; t >>= 8;
    out[k+2] = t; t >>= 8;
    out[k+3] = t;
    k+=4;
  }
  *outLen = 2 * bufLen;

elgamal_encrypt_ret:
  if(IN)
    UNSX_FREE(IN, UNSX_LENGTH);
  if(Cy)
    UNSX_FREE(Cy, UNSX_LENGTH);
  if(Cd)
    UNSX_FREE(Cd, UNSX_LENGTH);
  if(K)
    UNSX_FREE(K, UNSX_LENGTH);
  if(inBuf) {
    /* Wipe the padded plaintext before releasing it. */
    memset(inBuf, 0, UNSX_LENGTH * sizeof(unsx));
    kfree(inBuf); inBuf = NULL;
  }

  return ret;
}
#undef RETURN

#define RETURN(RET)\
do {\
  ret = RET;\
  goto elgamal_decrypt_ret;\
} while(0)
/*
 * elgamal_decrypt - decrypt one ciphertext block with an ElGamal private key.
 *
 * out:     receives the recovered message; at most
 *          UNSX_LENGTH * sizeof(unsx) - 3 bytes are written.
 * outLen:  set to the message length on success.
 * pri:     private key; a and p must be non-NULL.
 * in:      ciphertext: Cy followed by Cd, each inLen / 2 bytes, in the
 *          layout written by elgamal_encrypt().
 * inLen:   ciphertext length; must be a multiple of 8 and no larger than
 *          2 * UNSX_LENGTH * sizeof(unsx).
 * options: not used by this function.
 *
 * Computes M = Cd / (Cy ^ a) mod p and strips the
 * 0x00 0x02 <padding> 0x00 framing.
 *
 * Returns PK_OK on success, -ENOMEM on allocation failure, PK_INVALID_KEY
 * when the key is incomplete, PK_UNWRAP_FAILED when the ciphertext length
 * or the recovered framing is invalid, or a negative code propagated from
 * the unsx_* helpers.
 */
int elgamal_decrypt(char *out, int *outLen,
                           const elgamal_key_t *pri,
                           const char *in, int inLen,
                           unsigned options) {
  char *tmpBuf = (char*)kmalloc(UNSX_LENGTH * sizeof(unsx), GFP_KERNEL);
  UNSX_NEW(Cy, UNSX_LENGTH);
  UNSX_NEW(Cd, UNSX_LENGTH);
  const int bufLen = (int)(UNSX_LENGTH * sizeof(unsx));
  int i, k=0, ret=PK_OK;
  unsx t;

  if(!tmpBuf || !Cy || !Cd) {
    RETURN(-ENOMEM);
  }

  if (pri->a == NULL || pri->p == NULL) {
    RETURN(PK_INVALID_KEY);
  }

  /* Each half is consumed four bytes at a time into an UNSX_LENGTH-word
   * array.  A longer input would index before the start of Cy/Cd, and a
   * half that is not a whole number of words would read past the end of
   * the input. */
  if (inLen < 0 || inLen > 2 * bufLen || (inLen % 8) != 0) {
    RETURN(PK_UNWRAP_FAILED);
  }

  unsx_setZero(Cy, UNSX_LENGTH);
  for (i=0; i<inLen/2; i+=4) {
    Cy[UNSX_LENGTH-1-(i/4)] = ((unsx)(in[k+3]&0xff)<<24)
                             |((unsx)(in[k+2]&0xff)<<16)
                             |((unsx)(in[k+1]&0xff)<< 8)
                             |((unsx)(in[k  ]&0xff)    );
    k+=4;
  }

  unsx_setZero(Cd, UNSX_LENGTH);
  for (i=0; i<inLen/2; i+=4) {
    Cd[UNSX_LENGTH-1-(i/4)] = ((unsx)(in[k+3]&0xff)<<24)
                             |((unsx)(in[k+2]&0xff)<<16)
                             |((unsx)(in[k+1]&0xff)<< 8)
                             |((unsx)(in[k  ]&0xff)    );
    k+=4;
  }

  ret = unsx_modPow(Cy, Cy, pri->a, pri->p, pri->len); /* Cy = Cy ^ a       */
  if(ret < 0)
    RETURN(ret);
  ret = unsx_modInv(Cy, Cy, pri->p, pri->len);         /* Cy = 1 / (Cy ^ a) */
  if(ret < 0)
    RETURN(ret);
  ret = unsx_modMult(Cd, Cd, Cy, pri->p, pri->len);    /* M = Cd =          */
  if(ret < 0)                                          /*  Cd / (Cy ^ a)    */
    RETURN(ret);

  /* Unpack M into bytes, most significant byte first. */
  for (i=0; i<UNSX_LENGTH; i++) {
    t = Cd[UNSX_LENGTH-1-i];
    tmpBuf[4*i  ] = (t >> 24);
    tmpBuf[4*i+1] = (t >> 16);
    tmpBuf[4*i+2] = (t >>  8);
    tmpBuf[4*i+3] = t;
  }

  if ((tmpBuf[0]&0xff) != 0 ||(tmpBuf[1]&0xff) != 0x02) {
    RETURN(PK_UNWRAP_FAILED);
  }

  /* Find the zero byte that separates padding from message. */
  for (i=2; i<bufLen; i++) {
    if (tmpBuf[i] == 0) {
      break;
    }
  }
  /* No separator: the block is malformed.  Continuing would make the
   * copy length below negative. */
  if (i >= bufLen) {
    RETURN(PK_UNWRAP_FAILED);
  }

  memcpy(out, &tmpBuf[i+1], bufLen-1-i);

  *outLen = bufLen -1 -i;

elgamal_decrypt_ret:
  if(Cy)
    UNSX_FREE(Cy, UNSX_LENGTH);
  if(Cd)
    UNSX_FREE(Cd, UNSX_LENGTH);
  if(tmpBuf) {
    /* Wipe the recovered plaintext before releasing it. */
    memset(tmpBuf, 0, UNSX_LENGTH * sizeof(unsx));
    kfree(tmpBuf); 
  }

  return ret;
}
#undef RETURN

#define RETURN(RET)\
do {\
  ret = RET;\
  goto elgamal_sign_ret;\
} while(0)
/*
 * elgamal_sign - sign a hash value with an ElGamal private key.
 *
 * out:     receives the signature, 2 * UNSX_LENGTH * sizeof(unsx) bytes
 *          (r followed by s, most significant byte first); the caller
 *          must provide that much room.
 * outLen:  set to the number of bytes written to out on success.
 * pri:     private key; a, p and g must be non-NULL.
 * in:      data to sign (elgamal_verify() calls the matching argument
 *          "hash").
 * inLen:   length of in; must be between 0 and
 *          UNSX_LENGTH * sizeof(unsx) - 3 inclusive.
 * options: not used by this function.
 *
 * The input is framed as 0x00 0x01 0xff..0xff 0x00 <in> to form H, a
 * random K with gcd(K, p-1) == 1 is chosen, and the signature is
 *   r = g^K mod p,  s = (a*H + K*r) mod (p-1),
 * which elgamal_verify() checks as g^s == (g^a)^H * r^r mod p.
 *
 * Returns PK_OK on success, -ENOMEM on allocation failure, PK_INVALID_KEY
 * when the key is incomplete, -EINVAL when inLen is out of range, or a
 * negative code propagated from the unsx_* helpers.
 */
int elgamal_sign(char *out, int *outLen,
                        const elgamal_key_t *pri,
                        const char *in, int inLen,
                        unsigned options) {
  char *inBuf = (char*)kmalloc(UNSX_LENGTH * sizeof(unsx), GFP_KERNEL);
  UNSX_NEW(K, UNSX_LENGTH);
  UNSX_NEW(pMinus1, UNSX_LENGTH);
  UNSX_NEW(kr, UNSX_LENGTH);
  UNSX_NEW(ah, UNSX_LENGTH);
  UNSX_NEW(r, UNSX_LENGTH);
  UNSX_NEW(s, 1+UNSX_LENGTH);
  UNSX_NEW(IN, UNSX_LENGTH);
  const int bufLen = (int)(UNSX_LENGTH * sizeof(unsx));
  int i,k, padLen, ret=PK_OK;
  unsx t;

  /* Report allocation failure as -ENOMEM, like the other entry points,
   * rather than as a key error. */
  if (!inBuf || !K || !pMinus1 || !kr || !ah || !r || !s || !IN) {
    RETURN(-ENOMEM);
  }

  if (pri->a == NULL || pri->p == NULL || pri->g == NULL) {
    RETURN(PK_INVALID_KEY);
  }

  /* Three bytes of framing (0x00, 0x01 and the 0x00 separator) must fit
   * next to the input.  Without this check the padding length below
   * would wrap around and the buffer would be overrun. */
  if (inLen < 0 || inLen > bufLen - 3) {
    RETURN(-EINVAL);
  }

  memset(inBuf, 0, bufLen);

  k = 0;
  inBuf[k++] = 0x00;
  inBuf[k++] = 0x01;
  padLen = bufLen - 3 - inLen;
  memset(&inBuf[k], 0xff, padLen);
  k += padLen;
  /* inBuf[k] is still 0 from the first memset: that is the separator. */
  memcpy(&inBuf[k+1], in, inLen);

  /* Pack the byte buffer into words, most significant byte first. */
  unsx_setZero(IN, UNSX_LENGTH);
  for (i=0; i<bufLen; i+=4) {
    IN[UNSX_LENGTH-1-(i/4)] = ((unsx)(inBuf[i  ]&0xff)<<24)
                             |((unsx)(inBuf[i+1]&0xff)<<16)
                             |((unsx)(inBuf[i+2]&0xff)<< 8)
                             |((unsx)(inBuf[i+3]&0xff)    );
  }


  unsx_dec(pMinus1, pri->p, pri->len);  /* pMinus1 = p - 1 */
  unsx_setZero(K, UNSX_LENGTH);

  /* s is randomised alongside K as in the original code; it is passed
   * to unsx_gcd() as the first argument right below (presumably the
   * output operand - confirm against unsx.h). */
  for (i=0; i<pri->len; i++) {
    get_random_bytes(&(K[i]), sizeof(K[i]));
    K[i] *= 0x00010001;
    get_random_bytes(&(s[i]), sizeof(s[i]));
    s[i] *= 0x00010001;
  }

  if (unsx_cmp(K, pMinus1, pri->len) >= 0) {
    ret = unsx_mod(K, K, UNSX_LENGTH, pMinus1, pri->len);
    if(ret < 0) {
      RETURN(ret);
    }
  }

  /* Step K upwards until gcd(K, p-1) == 1. */
  ret = unsx_gcd(s, K, pMinus1, pri->len);
  if(ret < 0) {
    RETURN(ret);
  }

  while (!unsx_isOne(s,pri->len)) {
    unsx_inc(K, K, pri->len);
    ret = unsx_gcd(s, K, pMinus1, pri->len);
    if(ret < 0) {
      RETURN(ret);
    }
  }

  ret = unsx_modPow( r,    pri->g,   K, pri->p,  pri->len);  /* r  = g^K mod p     */
  if(ret < 0) {
    RETURN(ret);
  }
  ret = unsx_modMult(kr,   K,        r, pMinus1, pri->len);  /* kr = K*r mod (p-1) */
  if(ret < 0) {
    RETURN(ret);
  }
  ret = unsx_modMult(ah,   pri->a,   IN, pMinus1, pri->len); /* ah = a*H mod (p-1) */
  if(ret < 0) {
    RETURN(ret);
  }

  /* s = ah + kr; the carry out of the addition goes into the extra top
   * word so the sum can be reduced mod (p-1). */
  s[pri->len] = unsx_add(s, ah, kr, pri->len);
  if (unsx_cmp(s, pMinus1, pri->len) > 0  || s[pri->len] != 0) {
    ret = unsx_mod(s, s, pri->len+1, pMinus1, pri->len);
    if(ret < 0) {
      RETURN(ret);
    }
  } else {
    unsx_set(s, s, pri->len);
  }

  /* Serialise r then s, most significant byte first. */
  for (k=0,i=0; i<UNSX_LENGTH; i++) {
    t = r[UNSX_LENGTH-1-i];
    out[k+3] = t; t >>= 8;
    out[k+2] = t; t >>= 8;
    out[k+1] = t; t >>= 8;
    out[k  ] = t;
    k+=4;
  }

  for (i=0; i<UNSX_LENGTH; i++) {
    t = s[UNSX_LENGTH-1-i];
    out[k+3] = t; t >>= 8;
    out[k+2] = t; t >>= 8;
    out[k+1] = t; t >>= 8;
    out[k  ] = t;
    k+=4;
  }
  *outLen = 2 * bufLen;

elgamal_sign_ret:
  if(K)
    UNSX_FREE(K, UNSX_LENGTH);
  if(pMinus1)
    UNSX_FREE(pMinus1, UNSX_LENGTH);
  if(kr)
    UNSX_FREE(kr, UNSX_LENGTH);
  if(ah)
    UNSX_FREE(ah, UNSX_LENGTH);
  if(r)
    UNSX_FREE(r, UNSX_LENGTH);
  if(s)
    UNSX_FREE(s, 1+UNSX_LENGTH);
  if(IN)
    UNSX_FREE(IN, UNSX_LENGTH);
  if(inBuf) {
    memset(inBuf, 0, UNSX_LENGTH * sizeof(unsx));
    kfree(inBuf); inBuf = NULL;
  }

  return ret;
}
#undef RETURN

#define RETURN(RET)\
do {\
  ret = RET;\
  goto elgamal_verify_ret;\
} while(0)
/*
 * elgamal_verify - check an ElGamal signature against a hash value.
 *
 * hash:    the value that was signed.
 * hashLen: length of hash; must be between 0 and
 *          UNSX_LENGTH * sizeof(unsx) - 3 inclusive.
 * pub:     public key; p, g and ga must be non-NULL.
 * in:      signature: r followed by s, each inLen / 2 bytes, most
 *          significant byte first (the layout written by elgamal_sign()).
 * inLen:   signature length; must be a multiple of 8 and no larger than
 *          2 * UNSX_LENGTH * sizeof(unsx).
 * options: not used by this function.
 *
 * The hash is framed as 0x00 0x01 0xff..0xff 0x00 <hash> to form H and the
 * signature is accepted when g^s == (g^a)^H * r^r mod p.
 *
 * Returns PK_OK when the signature matches, PK_UNWRAP_FAILED when it does
 * not or when the signature/hash lengths or the range of r are invalid,
 * -ENOMEM on allocation failure, PK_INVALID_KEY when the key is
 * incomplete, or a negative code propagated from the unsx_* helpers.
 */
int elgamal_verify(const char *hash, int hashLen, const elgamal_key_t *pub,
                          const char *in, int inLen, unsigned options) 
{
  char *hashBuf = (char*)kmalloc(UNSX_LENGTH * sizeof(unsx), GFP_KERNEL);
  UNSX_NEW(r, UNSX_LENGTH);
  UNSX_NEW(s, UNSX_LENGTH);
  UNSX_NEW(H, UNSX_LENGTH);
  UNSX_NEW(v1, UNSX_LENGTH);
  UNSX_NEW(v2, UNSX_LENGTH);
  const int bufLen = (int)(UNSX_LENGTH * sizeof(unsx));
  int i, k=0, padLen, ret=PK_OK;

  if(!hashBuf || !r || !s || !H || !v1 || !v2) {
    RETURN(-ENOMEM);
  }

  if (pub->p == NULL || pub->g == NULL || pub->ga == NULL) {
    RETURN(PK_INVALID_KEY);
  }

  /* Each half is consumed four bytes at a time into an UNSX_LENGTH-word
   * array.  A longer input would index before the start of r/s, and a
   * half that is not a whole number of words would read past the end of
   * the input. */
  if (inLen < 0 || inLen > 2 * bufLen || (inLen % 8) != 0) {
    RETURN(PK_UNWRAP_FAILED);
  }

  /* A hash that cannot be framed into one block can never have been
   * signed; the padding length below would otherwise wrap around. */
  if (hashLen < 0 || hashLen > bufLen - 3) {
    RETURN(PK_UNWRAP_FAILED);
  }

  unsx_setZero(r, UNSX_LENGTH);
  for (i=0; i<inLen/2; i+=4) {
    r[UNSX_LENGTH-1-(i/4)] = ((unsx)(in[k  ]&0xff)<<24)
                            |((unsx)(in[k+1]&0xff)<<16)
                            |((unsx)(in[k+2]&0xff)<< 8)
                            |((unsx)(in[k+3]&0xff)    );
    k+=4;
  }

  unsx_setZero(s, UNSX_LENGTH);
  for (i=0; i<inLen/2; i+=4) {
    s[UNSX_LENGTH-1-(i/4)] = ((unsx)(in[k  ]&0xff)<<24)
                            |((unsx)(in[k+1]&0xff)<<16)
                            |((unsx)(in[k+2]&0xff)<< 8)
                            |((unsx)(in[k+3]&0xff)    );
    k+=4;
  }

  /* elgamal_sign() produces r = g^K mod p, so a genuine r is always
   * below p.  Reject anything else before doing arithmetic on it. */
  if (unsx_cmp(r, pub->p, pub->len) >= 0) {
    RETURN(PK_UNWRAP_FAILED);
  }

  memset(hashBuf, 0, bufLen);

  k = 0;
  hashBuf[k++] = 0x00;
  hashBuf[k++] = 0x01;
  padLen = bufLen - 3 - hashLen;
  memset(&hashBuf[k], 0xff, padLen);
  k += padLen;
  /* hashBuf[k] is still 0 from the first memset: that is the separator. */
  memcpy(&hashBuf[k+1], hash, hashLen);

  /* Pack the framed hash into words, most significant byte first. */
  unsx_setZero(H, UNSX_LENGTH);
  for (i=0; i<bufLen; i+=4) {
    H[UNSX_LENGTH-1-(i/4)] = ((unsx)(hashBuf[i  ]&0xff)<<24)
                            |((unsx)(hashBuf[i+1]&0xff)<<16)
                            |((unsx)(hashBuf[i+2]&0xff)<< 8)
                            |((unsx)(hashBuf[i+3]&0xff)    );
  }

  ret = unsx_modPow(v2, pub->g, s, pub->p, pub->len); /* v2 = g^s mod p     */
  if(ret < 0) {
    RETURN(ret);
  }
  ret = unsx_modPow(H, pub->ga, H, pub->p, pub->len); /* H = (g^a)^h mod p  */
  if(ret < 0) {
    RETURN(ret);
  }
  ret = unsx_modPow(v1, r, r, pub->p, pub->len);      /* v1 = r^r mod p     */
  if(ret < 0) {
    RETURN(ret);
  }
  ret = unsx_modMult(v1, H, v1, pub->p, pub->len);    /* v1                 */
  if(ret < 0) {                                       /* = g^ah * r^r mod p */
    RETURN(ret);
  }

  if (unsx_cmp(v1, v2, pub->len) != 0) {
    RETURN(PK_UNWRAP_FAILED);
  }

elgamal_verify_ret:
  if(r)
    UNSX_FREE(r, UNSX_LENGTH);
  if(s)
    UNSX_FREE(s, UNSX_LENGTH);
  if(H)
    UNSX_FREE(H, UNSX_LENGTH);
  if(v1)
    UNSX_FREE(v1, UNSX_LENGTH);
  if(v2)
    UNSX_FREE(v2, UNSX_LENGTH);
  if(hashBuf) {
    memset(hashBuf, 0, UNSX_LENGTH * sizeof(unsx));
    kfree(hashBuf); hashBuf = NULL;
  }
  return ret;
}
#undef RETURN
