/* Copyright (c) 1991-2002 Doshita Lab. Speech Group, Kyoto University */
/* Copyright (c) 2000-2002 Speech and Acoustics Processing Lab., NAIST */
/*   All rights reserved   */

/* gprune_beam.c --- calculate probability of Gaussian densities */
/*                   with Gaussian pruning (beam) */

/* $Id: gprune_beam.c,v 1.2 2002/09/11 22:01:50 ri Exp $ */

#include <sent/stddefs.h>
#include <sent/htk_hmm.h>
#include <sent/htk_param.h>
#include <sent/hmm.h>
#include <sent/gprune.h>
#include "globalvars.h"

/*

  best_mixtures_on_last_frame[]
  
  dim:  0 1 2 3 4 .... veclen-1    -> sum up
 ================================
  thres
 --------------------------------
  mix1  | |
  mix2  | |
  mix3  v v
  ...
  mixN  
 ================================
         \_\_ vecprob[0],vecprob[1]

  algorithm 1:
	 
  foreach dim {
     foreach all_mixtures in best_mixtures_on_last_frame {
        compute score
     }
     threshold = the current lowest score + beam_width?
     foreach rest_mixtures {
        if (already marked as pruned at previous dim) {
	   skip
	}
	compute score
        if (score < threshold) {
	   mark as pruned
	   skip
	}
	if (score > threshold) {
	   update threshold
	}
     }
  }

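  algorithm 2 (the one implemented in this file):

  Here "score" is the variance-normalized squared distance accumulated
  up to the current dim, so a larger value means a worse match.

  foreach all_mixtures in best_mixtures_on_last_frame {
     foreach dim {
       compute score
       if (threshold[dim] < score) threshold[dim] = score
     }
  }
  foreach dim {
     threshold[dim] += beam_width
  }
  foreach rest_mixtures {
     foreach dim {
        compute score
	if (score > threshold[dim]) skip this mixture
     }
  }
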
*/

/* pruning threshold */
static LOGPROB *dimthres;	/* threshold for each dimension (inversed) */
static int dimthres_num;	/* veclen */

static boolean *mixcalced;	/* mark which Gaussian has been computed */

/* Reset every per-dimension threshold to zero before building new ones. */
static void
clear_dimthres()
{
  int d = dimthres_num;

  while (d-- > 0) {
    dimthres[d] = 0.0;
  }
}

/* Widen each per-dimension threshold by the beam width TMBEAMWIDTH. */
static void
set_dimthres()
{
  int d = dimthres_num;

  while (d-- > 0) {
    dimthres[d] = dimthres[d] + TMBEAMWIDTH;
  }
}

/*
 * Compute the log output probability of one Gaussian for the current input
 * vector OP_vec, raising the per-dimension thresholds as it goes.
 *
 * After each dimension d, the variance-normalized squared distance
 * accumulated so far is compared with dimthres[d].  dimthres[d] is raised
 * to that value if it is larger.  dimthres[] therefore ends up holding,
 * per dimension, the largest partial distance over all Gaussians computed
 * by this function since the last clear_dimthres().
 *
 * binfo : Gaussian density (mean, variance vector, gconst); may be NULL
 * return: (distance + gconst) / -2, or LOG_ZERO when binfo is NULL
 *
 * NOTE(review): assumes OP_veclen <= dimthres_num -- confirm at init.
 */
static LOGPROB
compute_g_beam_updating(HTK_HMM_Dens *binfo)
{
  VECT tmp, x;
  VECT *mean;
  VECT *var;
  LOGPROB *th = dimthres;	/* same element type as dimthres itself */
  VECT *vec = OP_vec;
  short veclen = OP_veclen;

  if (binfo == NULL) return(LOG_ZERO);
  mean = binfo->mean;
  var = binfo->var->vec;

  tmp = 0.0;
  for (; veclen > 0; veclen--) {
    x = *(vec++) - *(mean++);
    tmp += x * x / *(var++);
    if (*th < tmp) *th = tmp;	/* keep the per-dimension maximum */
    th++;
  }
  return((tmp + binfo->gconst) / -2.0);
}
/*
 * Compute the log output probability of one Gaussian for the current input
 * vector OP_vec, with early termination by the per-dimension thresholds.
 *
 * The variance-normalized squared distance is accumulated one dimension at
 * a time.  The Gaussian is abandoned as soon as the partial sum exceeds
 * dimthres[d] for the current dimension d.
 *
 * binfo : Gaussian density (mean, variance vector, gconst); may be NULL
 * return: (distance + gconst) / -2, or LOG_ZERO when binfo is NULL or
 *         when the Gaussian was pruned
 *
 * NOTE(review): assumes OP_veclen <= dimthres_num -- confirm at init.
 */
static LOGPROB
compute_g_beam_pruning(HTK_HMM_Dens *binfo)
{
  VECT tmp, x;
  VECT *mean;
  VECT *var;
  LOGPROB *th = dimthres;	/* same element type as dimthres itself */
  VECT *vec = OP_vec;
  short veclen = OP_veclen;

  if (binfo == NULL) return(LOG_ZERO);
  mean = binfo->mean;
  var = binfo->var->vec;

  tmp = 0.0;
  for (; veclen > 0; veclen--) {
    x = *(vec++) - *(mean++);
    tmp += x * x / *(var++);
    if (tmp > *(th++)) {
      return LOG_ZERO;		/* outside the beam: prune */
    }
  }
  return((tmp + binfo->gconst) / -2.0);
}


/*
 * Allocate the work areas used by gprune_beam().
 *
 * - OP_calced_score[] / OP_calced_id[]: result cache of OP_gprune_num
 *   entries (filled by cache_push())
 * - mixcalced[]: one "already computed" flag per Gaussian, up to the
 *   maximum mixture size, all initialized to FALSE
 * - dimthres[]: one pruning threshold per feature dimension
 *
 * return: always TRUE
 */
boolean
gprune_beam_init()
{
  int i;
  /* maximum Gaussian set size = maximum mixture size */
  OP_calced_maxnum = OP_hmminfo->maxmixturenum;
  OP_calced_score = (LOGPROB *)mymalloc(sizeof(LOGPROB) * OP_gprune_num);
  OP_calced_id = (int *)mymalloc(sizeof(int) * OP_gprune_num);
  /* element type is boolean; sizing by int was wrong */
  mixcalced = (boolean *)mymalloc(sizeof(boolean) * OP_calced_maxnum);
  for(i=0;i<OP_calced_maxnum;i++) mixcalced[i] = FALSE;
  dimthres_num = OP_hmminfo->opt.vec_size;
  dimthres = (LOGPROB *)mymalloc(sizeof(LOGPROB) * dimthres_num);

  return TRUE;
}

/*
 * Compute a set of Gaussians for the current frame with beam pruning.
 *
 * Surviving results are accumulated via cache_push() into
 * OP_calced_score[] / OP_calced_id[].  The final entry count is stored in
 * OP_calced_num.
 *
 * g       : array of gnum Gaussian densities (elements may be NULL)
 * gnum    : number of Gaussians in g
 * last_id : indices into g selected on the previous frame, or NULL.
 *           When given, those Gaussians (at most OP_gprune_num of them) are
 *           computed first to build per-dimension beam thresholds.  The
 *           remaining Gaussians are then computed with beam pruning.
 *           When NULL (e.g. the first frame), the first OP_gprune_num
 *           Gaussians are computed fully.  The rest use compute_g_safe()
 *           against the current cache threshold.
 *
 * NOTE(review): assumes gnum <= OP_calced_maxnum (size of mixcalced[]).
 */
void
gprune_beam(HTK_HMM_Dens **g, int gnum, int *last_id)
{
  int i, j, lnum, num = 0;
  LOGPROB score, thres;

  if (last_id != NULL) {	/* compute them first to form thresholds */
    /* 1. clear dimthres */
    clear_dimthres();
    /* 2. calculate the previous best ones and set initial thresholds.
       A set of gnum Gaussians presumably cannot have produced more than
       gnum IDs, so never read past that in last_id[]. */
    lnum = (gnum < OP_gprune_num) ? gnum : OP_gprune_num;
    for (j = 0; j < lnum; j++) {
      i = last_id[j];
      /* Skip IDs that would index g[]/mixcalced[] out of range.  Also skip
         duplicates: mixcalced[] is all FALSE on entry because step 4 resets
         every flag this loop sets. */
      if (i < 0 || i >= gnum || mixcalced[i]) continue;
      score = compute_g_beam_updating(g[i]);
      num = cache_push(i, score, num);
      mixcalced[i] = TRUE;      /* mark them as calculated */
    }
    /* 3. widen the per-dimension thresholds by the beam width */
    set_dimthres();

    /* 4. calculate the rest with pruning */
    for (i = 0; i < gnum; i++) {
      /* skip ones computed in step 2, restoring their flag to FALSE */
      if (mixcalced[i]) {
        mixcalced[i] = FALSE;
        continue;
      }
      /* compute with beam pruning; LOG_ZERO means pruned (or NULL) */
      score = compute_g_beam_pruning(g[i]);
      if (score > LOG_ZERO) {
	num = cache_push(i, score, num);
      }
    }
  } else {			/* in case the last_id not available */
    /* at the first 0 frame: calculate with safe pruning */
    thres = LOG_ZERO;
    for (i = 0; i < gnum; i++) {
      if (num < OP_gprune_num) {
	score = compute_g_base(g[i]);
      } else {
	score = compute_g_safe(g[i], thres);
	if (score <= thres) continue;
      }
      num = cache_push(i, score, num);
      /* presumably the worst score still kept in the cache */
      thres = OP_calced_score[num-1];
    }
  }
  OP_calced_num = num;
}
