/*
 * Copyright (c) 1991-2003 Kyoto University
 * Copyright (c) 2000-2003 NAIST
 * All rights reserved
 */

/* vsegment.c --- generic routine to do viterbi segmentation */

/* $Id: vsegment.c,v 1.6 2003/09/29 06:01:23 ri Exp $ */

/* any segmentation unit is allowed: the segmentation points should be specified by a sequence of HMM state numbers to pick up */

#include <sent/stddefs.h>
#include <sent/htk_param.h>
#include <sent/hmm.h>

/*
 * Segmentation token: one record per unit boundary crossed by a Viterbi
 * path.  Tokens are chained through `next` from the most recent boundary
 * back towards the beginning of the input.
 *
 * The struct tag avoids a leading double underscore, which is reserved
 * for the C implementation; code refers to the type as SEGTOKEN.
 */
typedef struct seg_token {
  int last_id;			/* unit index that ended at this boundary (-1 in an initial token) */
  int last_end_frame;		/* last frame of that unit (-1 in an initial token) */
  LOGPROB last_end_score;	/* path score accumulated up to that frame */
  struct seg_token *next;	/* token of the preceding boundary, or NULL */
} SEGTOKEN;

LOGPROB				/* returns the acoustic score */
viterbi_segment(HMM *hmm,	/* concatinated sentence HMMs */
	     HTK_Param *param,	/* input parameter vectors */
	     int *endstates,	/* where unit ends in the hmm */
	     int ulen,		/* total unit numin the hmm */
	     int **id_ret,       /* return sequence of result */
	     int **seg_ret,	/* return segmented frame number at each endstates */
	     LOGPROB **uscore_ret,/* return normalized scores for each unit */
	     int *slen_ret)
{
  /* for viterbi */
  LOGPROB *nodescore[2];	/* node buffer */
  SEGTOKEN **tokenp[2];		/* propagating token which holds segment info */
  int *from_node;
  int *u_end, *u_start;	/* the node is an end of the word, or -1 */
  int i, n, t;
  int tl,tn;
  LOGPROB tmpsum;
  A_CELL *ac;
  SEGTOKEN *newtoken, *token, *tmptoken;
  LOGPROB result_score;
  LOGPROB maxscore, minscore;	/* for debug */
  int maxnode;			/* for debug */
  int *id, *seg, slen;
  LOGPROB *uscore;

  /* assume more than 1 units */
  if (ulen < 1) {
    j_printerr("Error: viterbi_segment: no unit?\n");
    return LOG_ZERO;
  }

  /* initialize unit start/end marker */
  u_start = (int *)mymalloc(hmm->len * sizeof(int));
  u_end   = (int *)mymalloc(hmm->len * sizeof(int));
  for (n = 0; n < hmm->len; n++) {
    u_start[n] = -1;
    u_end[n] = -1;
  }
  u_start[0] = 0;
  u_end[endstates[0]] = 0;
  for (i=1;i<ulen;i++) {
    u_start[endstates[i-1]+1] = i;
    u_end[endstates[i]] = i;
  }
#if 0
  for (i=0;i<hmm->len;i++) {
    printf("unit %d: start=%d, end=%d\n", i, u_start[i], u_end[i]);
  }
#endif

  /* initialize node buffers */
  tn = 0;
  tl = 1;
  for (i=0;i<2;i++){
    nodescore[i] = (LOGPROB *)mymalloc(hmm->len * sizeof(LOGPROB));
    tokenp[i] = (SEGTOKEN **)mymalloc(hmm->len * sizeof(SEGTOKEN *));
  }
  for (n = 0; n < hmm->len; n++) {
    nodescore[tn][n] = LOG_ZERO;
    newtoken = (SEGTOKEN *)mymalloc(sizeof(SEGTOKEN));
    newtoken->last_id = -1;
    newtoken->last_end_frame = -1;
    newtoken->last_end_score = 0.0;
    newtoken->next = NULL;
    tokenp[tn][n] = newtoken;
  }
  from_node = (int *)mymalloc(sizeof(int) * hmm->len);
  
  /* first frame: only set initial score */
  /*if (hmm->state[0].is_pseudo_state) {
    j_printerr("Warning: state %d: pseudo state?\n", 0);
    }*/
  nodescore[tn][0] = outprob(0, &(hmm->state[0]), param);

  /* do viterbi for rest frame */
  for (t = 1; t < param->samplenum; t++) {
    i = tl;
    tl = tn;
    tn = i;
    maxscore = LOG_ZERO;
    minscore = 0.0;

    /* clear next scores */
    for (i=0;i<hmm->len;i++) {
      nodescore[tn][i] = LOG_ZERO;
      from_node[i] = -1;
    }

    /* select viterbi path for each node */
    for (n = 0; n < hmm->len; n++) {
      if (nodescore[tl][n] <= LOG_ZERO) continue;
      for (ac = hmm->state[n].ac; ac; ac = ac->next) {
        tmpsum = nodescore[tl][n] + ac->a;
        if (nodescore[tn][ac->arc] < tmpsum) {
          nodescore[tn][ac->arc] = tmpsum;
	  from_node[ac->arc] = n;
	}
      }
    }
    /* propagate token, appending new if path was selected between units */
    for (n = 0; n < hmm->len; n++) {
      if (from_node[n] == -1) {
	tokenp[tn][n] = NULL;
      } else if (nodescore[tn][n] <= LOG_ZERO) {
	tokenp[tn][n] = tokenp[tl][from_node[n]];
      } else {
	if (u_end[from_node[n]] != -1 && u_start[n] != -1
	    && from_node[n] !=  n) {
	  newtoken = (SEGTOKEN *)mymalloc(sizeof(SEGTOKEN));
	  newtoken->last_id = u_end[from_node[n]];
	  newtoken->last_end_frame = t-1;
	  newtoken->last_end_score = nodescore[tl][from_node[n]];
	  newtoken->next = tokenp[tl][from_node[n]];
	  tokenp[tn][n] = newtoken;
	} else {
	  tokenp[tn][n] = tokenp[tl][from_node[n]];
	}
      }
    }
    /* calc outprob to new nodes */
    for (n = 0; n < hmm->len; n++) {
      if (nodescore[tn][n] > LOG_ZERO) {
	if (hmm->state[n].is_pseudo_state) {
	  j_printerr("Warning: state %d: pseudo state?\n", n);
	}
	nodescore[tn][n] += outprob(t, &(hmm->state[n]), param);
      }
      if (nodescore[tn][n] > maxscore) { /* for debug */
	maxscore = nodescore[tn][n];
	maxnode = n;
      }
    }
    
#if 0
    for (i=0;i<ulen;i++) {
      printf("%d: unit %d(%d-%d): begin_frame = %d\n", t - 1, i,
	     (i > 0) ? endstates[i-1]+1 : 0, endstates[i],
	     tokenp[tl][endstates[i]]->last_end_frame + 1);
    }
#endif

    /* printf("t=%3d max=%f n=%d\n",t,maxscore, maxnode); */
    
  }

  result_score = nodescore[tn][hmm->len-1];

  /* parse back the last token to see the trail of best viterbi path */
  /* and store the informations to returning buffer */
  slen = 1;
  for(token = tokenp[tn][hmm->len-1]; token; token = token->next) {
    if (token->last_end_frame == -1) break;
    slen++;
  }
  id = (int *)mymalloc(sizeof(int)*slen);
  seg = (int *)mymalloc(sizeof(int)*slen);
  uscore = (LOGPROB *)mymalloc(sizeof(LOGPROB)*slen);

  id[slen-1] = ulen - 1;
  seg[slen-1] = t - 1;
  uscore[slen-1] = result_score;
  i = slen - 2;
  for(token = tokenp[tn][hmm->len-1]; token; token = token->next) {
    if (i < 0 || token->last_end_frame == -1) break;
    id[i] = token->last_id;
    seg[i] = token->last_end_frame;
    uscore[i] = token->last_end_score;
    i--;
  }

  /* normalize scores by frame */
  for (i=slen-1;i>0;i--) {
    uscore[i] = (uscore[i] - uscore[i-1]) / (seg[i] - seg[i-1]);
  }
  uscore[0] = uscore[0] / (seg[0] + 1);

  /* set return value */
  *id_ret = id;
  *seg_ret = seg;
  *uscore_ret = uscore;
  *slen_ret = slen;

  /* free memory */
  free(u_start);
  free(u_end);
  free(from_node);
  for (n = 0; n < hmm->len; n++) {
    token = tokenp[tn][n];
    while (token) {
      for (i = n + 1; i < hmm->len; i++) {
	if (tokenp[tn][i] == token) {
	  tokenp[tn][i] = NULL;
	} else {
	  tmptoken = tokenp[tn][i];
	  while (tmptoken) {
	    if (tmptoken->next == token) {
	      tmptoken->next = NULL;
	      break;
	    } else {
	      tmptoken = tmptoken->next;
	    }
	  }
	}
      }
      tmptoken = token->next;
      free(token);
      token = tmptoken;
    }
  }
  for (i=0;i<2;i++) {
    free(nodescore[i]);
    free(tokenp[i]);
  }

  return(result_score);

}
