/*
 * Copyright (c) 1991-2004 Kyoto University
 * Copyright (c) 2000-2004 NAIST
 * All rights reserved
 */

/* mkwhmm.c --- make word(phrase) HMM from HTK_HMM_INFO */

/* $Id: mkwhmm.c,v 1.6 2004/03/23 03:00:16 ri Exp $ */

#include <sent/stddefs.h>
#include <sent/hmm.h>

/*
 * Count the states of the word HMM that new_make_word_hmm_with_lm() will
 * build from hdseq[0..hdseqlen-1].
 *
 * Every model contributes its state count minus its own entry and exit
 * states.  A short-pause model (hmminfo->sp) contributes likewise after each
 * phone i whose has_sp[i] is set.  Two more states are added for the
 * shared initial and final states of the concatenated HMM.
 * j_error() is called if has_sp[] requests sp but hmminfo->sp is NULL.
 */
static int
totalstatelen(HMM_Logical **hdseq, int hdseqlen, boolean *has_sp, HTK_HMM_INFO *hmminfo)
{
  int k;
  int total = 2;		/* shared initial + final state */

  for (k = 0; k < hdseqlen; k++) {
    total += hmm_logical_state_num(hdseq[k]) - 2;
    if (!has_sp[k]) continue;
    if (hmminfo->sp == NULL) j_error("Error: no hmminfo->sp!!\n");
    total += hmm_logical_state_num(hmminfo->sp) - 2;
  }
  return total;
}

/*
 * Add a transition from `state` to state index `arc` with log probability
 * `a`.  The new cell is allocated with mymalloc() and pushed onto the FRONT
 * of the state's arc list, so arcs end up in reverse insertion order.
 */
static void
add_arc(HMM_STATE *state, int arc, LOGPROB a)
{
  A_CELL *cell = (A_CELL *)mymalloc(sizeof(A_CELL));

  cell->next = state->ac;
  cell->arc = arc;
  cell->a = a;
  state->ac = cell;
}

/* make word(phrase) HMM from HTK_HMM_INFO */
/* LM prob will be assigned for cross-word arcs */
/* new HMM is malloced and returned */
HMM *
new_make_word_hmm_with_lm(HTK_HMM_INFO *hmminfo, HMM_Logical **hdseq, int hdseqlen, boolean *has_sp, LOGPROB *lscore)
{
  HMM *new;
  int i,j,n;
  int afrom, ato;
  LOGPROB logprob;
  HTK_HMM_Trans *tr;
  int state_num;

  new = (HMM *)mymalloc(sizeof(HMM));
  new->len = totalstatelen(hdseq, hdseqlen, has_sp, hmminfo);
  new->state = (HMM_STATE *)mymalloc(sizeof(HMM_STATE) * new->len);
  for (i=0;i<new->len;i++) {
    new->state[i].ac = NULL;
    new->state[i].is_pseudo_state = FALSE;
    new->state[i].out.state = NULL;
    new->state[i].out.cdset = NULL;
  }

  /* assign state outprob info  */
  n = 1;			/* skip first state */
  for (i = 0; i < hdseqlen; i++) {
    if (hdseq[i]->is_pseudo) {
      for (j = 1; j < hdseq[i]->body.pseudo->state_num - 1; j++) {
	new->state[n].is_pseudo_state = TRUE;
	new->state[n].out.cdset = &(hdseq[i]->body.pseudo->stateset[j]);
	n++;
      }
    } else {
      for (j = 1; j < hdseq[i]->body.defined->state_num - 1; j++) {
	new->state[n].is_pseudo_state = FALSE;
	new->state[n].out.state = hdseq[i]->body.defined->s[j];
	n++;
      }
    }
    if (has_sp[i]) {
      /* append sp at the end of the phone */
      if (hmminfo->sp->is_pseudo) {
	for (j = 1; j < hmm_logical_state_num(hmminfo->sp) - 1; j++) {
	  new->state[n].is_pseudo_state = TRUE;
	  new->state[n].out.cdset = &(hmminfo->sp->body.pseudo->stateset[j]);
	  n++;
	}
      } else {
	for (j = 1; j < hmm_logical_state_num(hmminfo->sp) - 1; j++) {
	  new->state[n].is_pseudo_state = FALSE;
	  new->state[n].out.state = hmminfo->sp->body.defined->s[j];
	  n++;
	}
      }
    }
  }
  
  /* make transition arcs */
  /* initial state check */
/* 
 *   for (i=0;i<hdseq[0]->def->state_num;i++) {
 *     if (i != 1 && (hdseq[0]->def->tr->a[0][i]) != LOG_ZERO) {
 *	 j_printerr("initial state contains more than 1 arc.\n");
 *     }
 *   }
 */
  {
    int *out_from, *out_from_next;
    LOGPROB *out_a, *out_a_next;
    int out_num_prev, out_num_next;
    out_from = (int *)mymalloc(sizeof(int) * new->len);
    out_from_next = (int *)mymalloc(sizeof(int) * new->len);
    out_a = (LOGPROB *)mymalloc(sizeof(LOGPROB) * new->len);
    out_a_next = (LOGPROB *)mymalloc(sizeof(LOGPROB) * new->len);

    n = 0;			/* n points to previous state */

    out_from[0] = 0;
    out_a[0] = 0.0;
    out_num_prev = 1;
    for (i = 0; i < hdseqlen; i++) {
      state_num = hmm_logical_state_num(hdseq[i]);
      tr = hmm_logical_trans(hdseq[i]);
      out_num_next = 0;
      /* arc from initial state */
      for (ato = 1; ato < state_num; ato++) {
	logprob = tr->a[0][ato];
	if (logprob != LOG_ZERO) {
	  /* expand arc */
	  if (ato == state_num-1) {
	    /* from initial to final ... register all previously registered arcs for next expansion */
	    if (lscore != NULL) logprob += lscore[i];
	    for(j=0;j<out_num_prev;j++) {
	      out_from_next[out_num_next] = out_from[j];
	      out_a_next[out_num_next] = out_a[j] + logprob;
	      out_num_next++;
	    }
	  } else {
	    for(j=0;j<out_num_prev;j++) {
	      add_arc(&(new->state[out_from[j]]), n + ato,
		      out_a[j] + logprob);
	    }
	  }
	}
      }
      /* arc from output state */
      for(afrom = 1; afrom < state_num - 1; afrom++) {
	for (ato = 1; ato < state_num; ato++) {
	  logprob = tr->a[afrom][ato];
	  if (logprob != LOG_ZERO) {
	    if (ato == state_num - 1) {
	      /* from output state to final ... register the arc for next expansion */
	      if (lscore != NULL) logprob += lscore[i];
	      out_from_next[out_num_next] = n+afrom;
	      out_a_next[out_num_next++] = logprob;
	    } else {
	      add_arc(&(new->state[n+afrom]), n + ato, logprob);
	    }
	  }
	}
      }
      n += state_num - 2;
      for(j=0;j<out_num_next;j++) {
	out_from[j] = out_from_next[j];
	out_a[j] = out_a_next[j];
      }
      out_num_prev = out_num_next;

      /* inter-word short pause handling */
      if (has_sp[i]) {
      
	out_num_next = 0;

	/* arc from initial state */
	for (ato = 1; ato < hmm_logical_state_num(hmminfo->sp); ato++) {
	  logprob = hmm_logical_trans(hmminfo->sp)->a[0][ato];
	  if (logprob != LOG_ZERO) {
	    /* to control short pause insertion, transition probability toward
	       the word-end short pause will be given a penalty */
	    logprob += hmminfo->iwsp_penalty;
	    /* expand arc */
	    if (ato == hmm_logical_state_num(hmminfo->sp)-1) {
	      /* from initial to final ... register all previously registered arcs for next expansion */
	      for(j=0;j<out_num_prev;j++) {
		out_from_next[out_num_next] = out_from[j];
		out_a_next[out_num_next] = out_a[j] + logprob;
		out_num_next++;
	      }
	    } else {
	      for(j=0;j<out_num_prev;j++) {
		add_arc(&(new->state[out_from[j]]), n + ato,
			out_a[j] + logprob);
	      }
	    }
	  }
	}
	/* if short pause model doesn't have a model skip transition, also add it */
	if (hmm_logical_trans(hmminfo->sp)->a[0][hmm_logical_state_num(hmminfo->sp)-1] == LOG_ZERO) {
	  /* to make insertion sp model to have no effect on the original path,
	     the skip transition probability should be 0.0 (=100%) */
	  logprob = 0.0;
	  for(j=0; j<out_num_prev; j++) {
	    out_from_next[out_num_next] = out_from[j];
	    out_a_next[out_num_next] = out_a[j] + logprob;
	    out_num_next++;
	  }
	}
	/* arc from output state */
	for(afrom = 1; afrom < hmm_logical_state_num(hmminfo->sp) - 1; afrom++) {
	  for (ato = 1; ato < hmm_logical_state_num(hmminfo->sp); ato++) {
	    logprob = hmm_logical_trans(hmminfo->sp)->a[afrom][ato];
	    if (logprob != LOG_ZERO) {
	      if (ato == hmm_logical_state_num(hmminfo->sp) - 1) {
		/* from output state to final ... register the arc for next expansion */
		out_from_next[out_num_next] = n+afrom;
		out_a_next[out_num_next++] = logprob;
	      } else {
		add_arc(&(new->state[n+afrom]), n + ato, logprob);
	      }
	    }
	  }
	}
	n += hmm_logical_state_num(hmminfo->sp) - 2;
	for(j=0;j<out_num_next;j++) {
	  out_from[j] = out_from_next[j];
	  out_a[j] = out_a_next[j];
	}
	out_num_prev = out_num_next;
      }
    }
      
    
    for(j=0;j<out_num_prev;j++) {
      add_arc(&(new->state[out_from[j]]), new->len-1, out_a[j]);
    }
    free(out_from);
    free(out_from_next);
    free(out_a);
    free(out_a_next);
  }

  return (new);
}

/*
 * Build a word (phrase) HMM from hdseq[0..hdseqlen-1] (with short pauses
 * where has_sp[] is set) without attaching any LM scores.  Equivalent to
 * new_make_word_hmm_with_lm() with a NULL lscore.
 */
HMM *
new_make_word_hmm(HTK_HMM_INFO *hmminfo, HMM_Logical **hdseq, int hdseqlen, boolean *has_sp)
{
  LOGPROB *no_lm_score = NULL;	/* no per-phone LM scores */

  return new_make_word_hmm_with_lm(hmminfo, hdseq, hdseqlen, has_sp, no_lm_score);
}

/*
 * Release an HMM such as one returned by new_make_word_hmm_with_lm().
 * This frees every transition cell on every state's arc list, then the
 * state array, then the HMM itself.
 * Passing NULL is a no-op, mirroring free().
 */
void
free_hmm(HMM *d)
{
  A_CELL *ac, *atmp;
  int i;

  if (d == NULL) return;

  for (i = 0; i < d->len; i++) {
    for (ac = d->state[i].ac; ac != NULL; ac = atmp) {
      atmp = ac->next;		/* read link before the cell is freed */
      free(ac);
    }
  }
  free(d->state);
  free(d);
}

