/*
 * Copyright (c) 2003 NAIST
 * All rights reserved
 */

/* write_binhmm.c --- write HTK_HMM_INFO to disk in binary format */

/* $Id: write_binhmm.c,v 1.2 2003/12/08 03:44:38 ri Exp $ */

#include <sent/stddefs.h>
#include <sent/htk_param.h>
#include <sent/htk_hmm.h>


/* binary write function with byte swap (assume file is BIG ENDIAN) */
static void
wrt(FILE *fp, void *buf, size_t unitbyte, int unitnum)
{
  /* Write unitnum units of unitbyte bytes each from buf; the file is big-endian.
     On little-endian hosts multi-byte units are byte-swapped in place for the
     write and swapped back afterwards, so the caller's buffer is unchanged on
     return.  A short write is reported through perror() and j_error(). */
#ifndef WORDS_BIGENDIAN
  int do_swap = (unitbyte != 1);

  if (do_swap) swap_bytes((char *)buf, unitbyte, unitnum);
#endif
  if (myfwrite(buf, unitbyte, unitnum, fp) < unitnum) {
    perror("write_binhmm: wrt");
    j_error("write failed\n");
  }
#ifndef WORDS_BIGENDIAN
  /* restore the caller's original byte order */
  if (do_swap) swap_bytes((char *)buf, unitbyte, unitnum);
#endif
}

static void
wrt_str(FILE *fp, char *str)
{
  /* Write str including its terminating NUL.  A NULL str is written as an
     empty string, i.e. a single NUL byte. */
  static char noname = '\0';
  char *text = (str != NULL) ? str : &noname;

  wrt(fp, text, sizeof(char), (int)(strlen(text) + 1));
}


/* write header */
static char *binhmm_header = BINHMM_HEADER;	/* file signature string */

/* Emit the file signature (NUL-terminated) as the first record of the file. */
static void
wt_header(FILE *fp)
{
  char *signature = binhmm_header;

  wrt_str(fp, signature);
}


/* write option data */
static void
wt_opt(FILE *fp, HTK_HMM_Options *opt)
{
  /* Write the global HMM options, every field as sizeof(short) units:
     stream count, a fixed 50-entry vector-size table, then vec_size,
     cov_type, dur_type and param_type, in that order. */
  void *tail[4];
  int k;

  tail[0] = &(opt->vec_size);
  tail[1] = &(opt->cov_type);
  tail[2] = &(opt->dur_type);
  tail[3] = &(opt->param_type);

  wrt(fp, &(opt->stream_info.num), sizeof(short), 1);
  /* the 50-slot table width is part of the on-disk layout */
  wrt(fp, opt->stream_info.vsize, sizeof(short), 50);
  for (k = 0; k < 4; k++) {
    wrt(fp, tail[k], sizeof(short), 1);
  }
}

/* write type data */
static void
wt_type(FILE *fp, HTK_HMM_INFO *hmm)
{
  /* Write the model-type record: the tied-mixture flag (sizeof(boolean)
     bytes) followed by the maximum mixture count (sizeof(int) bytes). */
  void *tied = &(hmm->is_tied_mixture);
  void *maxmix = &(hmm->maxmixturenum);

  wrt(fp, tied, sizeof(boolean), 1);
  wrt(fp, maxmix, sizeof(int), 1);
}


/* write transition data */
static HTK_HMM_Trans **tr_index;
static unsigned int tr_num;

static int
qsort_tr_index(HTK_HMM_Trans **t1, HTK_HMM_Trans **t2)
{
  if (*t1 > *t2) return 1;
  else if (*t1 < *t2) return -1;
  else return 0;
}

/* Write all transition matrices.
   Collects every matrix on hmm->trstart into tr_index[], sorted by address so
   search_trid() can later map a pointer to its record number.  Then writes the
   count, followed by one record per matrix: name, statenum (short), and
   statenum rows of statenum PROB values each. */
static void
wt_trans(FILE *fp, HTK_HMM_INFO *hmm)
{
  HTK_HMM_Trans *t;
  unsigned int idx;
  int i;

  tr_num = 0;
  for(t = hmm->trstart; t; t = t->next) tr_num++;
  tr_index = (HTK_HMM_Trans **)mymalloc(sizeof(HTK_HMM_Trans *) * tr_num);
  idx = 0;
  for(t = hmm->trstart; t; t = t->next) tr_index[idx++] = t;
  qsort(tr_index, tr_num, sizeof(HTK_HMM_Trans *), (int (*)(const void *, const void *))qsort_tr_index);

  wrt(fp, &tr_num, sizeof(unsigned int), 1);
  for (idx = 0; idx < tr_num; idx++) {
    t = tr_index[idx];
    wrt_str(fp, t->name);
    wrt(fp, &(t->statenum), sizeof(short), 1);
    /* one row per state; assumes t->a[i] holds statenum PROB values */
    for(i=0;i<t->statenum;i++) {
      wrt(fp, t->a[i], sizeof(PROB), t->statenum);
    }
  }

  /* %u: tr_num is unsigned */
  j_printf("%u transition matrices written\n", tr_num);
}

/* Return the position of transition t in the address-sorted tr_index[].
   This is a lower-bound binary search: the result is the first slot whose
   address is >= t, so the caller must still check tr_index[result] == t.
   Returns 0 when the table is empty. */
static unsigned int
search_trid(HTK_HMM_Trans *t)
{
  unsigned int left = 0;
  unsigned int right;
  unsigned int mid;

  /* without this guard tr_num - 1 wraps to UINT_MAX and the search
     reads far outside tr_index[] */
  if (tr_num == 0) return 0;
  right = tr_num - 1;

  while (left < right) {
    mid = left + (right - left) / 2;	/* cannot overflow, unlike (left+right)/2 */
    if (tr_index[mid] < t) {
      left = mid + 1;
    } else {
      right = mid;
    }
  }
  return(left);
}


/* write variance data */
static HTK_HMM_Var **vr_index;
static unsigned int vr_num;

static int
qsort_vr_index(HTK_HMM_Var **v1, HTK_HMM_Var **v2)
{
  if (*v1 > *v2) return 1;
  else if (*v1 < *v2) return -1;
  else return 0;
}

/* Write all variance vectors.
   Collects every vector on hmm->vrstart into vr_index[], sorted by address so
   search_vid() can map a pointer to its record number.  Then writes the count,
   followed by one record per vector: name, len (short), and len VECT values. */
static void
wt_var(FILE *fp, HTK_HMM_INFO *hmm)
{
  HTK_HMM_Var *v;
  unsigned int idx;

  vr_num = 0;
  for(v = hmm->vrstart; v; v = v->next) vr_num++;
  vr_index = (HTK_HMM_Var **)mymalloc(sizeof(HTK_HMM_Var *) * vr_num);
  idx = 0;
  for(v = hmm->vrstart; v; v = v->next) vr_index[idx++] = v;
  qsort(vr_index, vr_num, sizeof(HTK_HMM_Var *), (int (*)(const void *, const void *))qsort_vr_index);

  wrt(fp, &vr_num, sizeof(unsigned int), 1);
  for (idx = 0; idx < vr_num; idx++) {
    v = vr_index[idx];
    wrt_str(fp, v->name);
    wrt(fp, &(v->len), sizeof(short), 1);
    wrt(fp, v->vec, sizeof(VECT), v->len);
  }
  /* %u: vr_num is unsigned */
  j_printf("%u variance written\n", vr_num);
}

/* Return the position of variance v in the address-sorted vr_index[].
   This is a lower-bound binary search: the caller must check
   vr_index[result] == v.  Returns 0 when the table is empty. */
static unsigned int
search_vid(HTK_HMM_Var *v)
{
  unsigned int left = 0;
  unsigned int right;
  unsigned int mid;

  /* guard: vr_num - 1 would wrap to UINT_MAX on an empty table */
  if (vr_num == 0) return 0;
  right = vr_num - 1;

  while (left < right) {
    mid = left + (right - left) / 2;	/* overflow-safe midpoint */
    if (vr_index[mid] < v) {
      left = mid + 1;
    } else {
      right = mid;
    }
  }
  return(left);
}


/* write density data */
static HTK_HMM_Dens **dens_index;
static unsigned int dens_num;

static int
qsort_dens_index(HTK_HMM_Dens **d1, HTK_HMM_Dens **d2)
{
  if (*d1 > *d2) return 1;
  else if (*d1 < *d2) return -1;
  else return 0;
}

/* Write all Gaussian densities.
   Collects every density on hmm->dnstart into dens_index[], sorted by address.
   Then writes the count, followed by one record per density: name, meanlen
   (short), meanlen VECT mean values, the record number of its variance
   (position in vr_index[], so wt_var() must have run), and gconst. */
static void
wt_dens(FILE *fp, HTK_HMM_INFO *hmm)
{
  HTK_HMM_Dens *d;
  unsigned int idx;
  unsigned int vid;

  /* Size the index from the list that fills it, not from hmm->totalmixnum:
     if the two disagree, the old code overflowed dens_index[] or left unset
     slots for qsort to read.  This matches wt_trans()/wt_var(). */
  dens_num = 0;
  for(d = hmm->dnstart; d; d = d->next) dens_num++;
  dens_index = (HTK_HMM_Dens **)mymalloc(sizeof(HTK_HMM_Dens *) * dens_num);
  idx = 0;
  for(d = hmm->dnstart; d; d = d->next) dens_index[idx++] = d;
  qsort(dens_index, dens_num, sizeof(HTK_HMM_Dens *), (int (*)(const void *, const void *))qsort_dens_index);

  wrt(fp, &dens_num, sizeof(unsigned int), 1);
  for (idx = 0; idx < dens_num; idx++) {
    d = dens_index[idx];
    wrt_str(fp, d->name);
    wrt(fp, &(d->meanlen), sizeof(short), 1);
    wrt(fp, d->mean, sizeof(VECT), d->meanlen);
    vid = search_vid(d->var);
    /* consistency check: the variance must be present in vr_index[]
       (bounds-checked first so an empty table is never indexed) */
    if (vid >= vr_num || d->var != vr_index[vid]) j_error("index not match!!! dens\n");
    wrt(fp, &vid, sizeof(unsigned int), 1);
    wrt(fp, &(d->gconst), sizeof(LOGPROB), 1);
  }
  /* %u: dens_num is unsigned */
  j_printf("%u gaussian densities written\n", dens_num);
}

/* Return the position of density d in the address-sorted dens_index[].
   This is a lower-bound binary search: the caller must check
   dens_index[result] == d.  Returns 0 when the table is empty. */
static unsigned int
search_did(HTK_HMM_Dens *d)
{
  unsigned int left = 0;
  unsigned int right;
  unsigned int mid;

  /* guard: dens_num - 1 would wrap to UINT_MAX on an empty table */
  if (dens_num == 0) return 0;
  right = dens_num - 1;

  while (left < right) {
    mid = left + (right - left) / 2;	/* overflow-safe midpoint */
    if (dens_index[mid] < d) {
      left = mid + 1;
    } else {
      right = mid;
    }
  }
  return(left);
}


/* write tmix data */
static GCODEBOOK **tm_index;
static unsigned int tm_num;
static unsigned int tm_idx;

static void
tmix_list_callback(void *p)
{
  GCODEBOOK *tm;
  tm = p;
  tm_index[tm_idx++] = tm;
}

static int
qsort_tm_index(GCODEBOOK **tm1, GCODEBOOK **tm2)
{
  if (*tm1 > *tm2) return 1;
  else if (*tm1 < *tm2) return -1;
  else return 0;
}

/* Write all tied-mixture codebooks.
   Collects the codebooks from hmm->codebook_root into tm_index[], sorted by
   address.  Then writes the count, followed by one record per codebook: name,
   num (int), and num density record numbers.  Each record number is a
   position in dens_index[], so wt_dens() must have run; dens_num marks a
   NULL entry. */
static void
wt_tmix(FILE *fp, HTK_HMM_INFO *hmm)
{
  GCODEBOOK *tm;
  unsigned int idx;
  unsigned int did;
  int i;

  tm_num = hmm->codebooknum;
  tm_index = (GCODEBOOK **)mymalloc(sizeof(GCODEBOOK *) * tm_num);
  tm_idx = 0;
  aptree_traverse_and_do(hmm->codebook_root, tmix_list_callback);
  /* every slot must be filled before qsort reads the array */
  if (tm_idx != tm_num) {
    j_error("write_binhmm: codebook count mismatch\n");
  }
  qsort(tm_index, tm_num, sizeof(GCODEBOOK *), (int (*)(const void *, const void *))qsort_tm_index);

  wrt(fp, &tm_num, sizeof(unsigned int), 1);
  for (idx = 0; idx < tm_num; idx++) {
    tm = tm_index[idx];
    wrt_str(fp, tm->name);
    wrt(fp, &(tm->num), sizeof(int), 1);
    for(i=0;i<tm->num;i++) {
      if (tm->d[i] == NULL) {
	did = dens_num;		/* sentinel for an empty slot */
      } else {
	did = search_did(tm->d[i]);
	/* consistency check, bounds-checked before indexing */
	if (did >= dens_num || tm->d[i] != dens_index[did]) j_error("index not match!!! dens\n");
      }
      wrt(fp, &did, sizeof(unsigned int), 1);
    }
  }
  /* %u: tm_num is unsigned */
  j_printf("%u tied-mixture codebooks written\n", tm_num);
}

/* Return the position of codebook tm in the address-sorted tm_index[].
   This is a lower-bound binary search: the caller must check
   tm_index[result] == tm.  Returns 0 when the table is empty. */
static unsigned int
search_tmid(GCODEBOOK *tm)
{
  unsigned int left = 0;
  unsigned int right;
  unsigned int mid;

  /* guard: tm_num - 1 would wrap to UINT_MAX on an empty table */
  if (tm_num == 0) return 0;
  right = tm_num - 1;

  while (left < right) {
    mid = left + (right - left) / 2;	/* overflow-safe midpoint */
    if (tm_index[mid] < tm) {
      left = mid + 1;
    } else {
      right = mid;
    }
  }
  return(left);
}


/* write state data */
static HTK_HMM_State **st_index;
static unsigned int st_num;

static int
qsort_st_index(HTK_HMM_State **s1, HTK_HMM_State **s2)
{
  if (*s1 > *s2) return 1;
  else if (*s1 < *s2) return -1;
  else return 0;
}

static void
wt_state(FILE *fp, HTK_HMM_INFO *hmm)
{
  HTK_HMM_State *s;
  unsigned int idx;
  unsigned int did;
  int i;
  short dummy;

  st_num = hmm->totalstatenum;
  st_index = (HTK_HMM_State **)mymalloc(sizeof(HTK_HMM_State *) * st_num);
  idx = 0;
  for(s = hmm->ststart; s; s = s->next) st_index[idx++] = s;
  qsort(st_index, st_num, sizeof(HTK_HMM_State *), (int (*)(const void *, const void *))qsort_st_index);
  
  wrt(fp, &st_num, sizeof(unsigned int), 1);
  for (idx = 0; idx < st_num; idx++) {
    s = st_index[idx];
    wrt_str(fp, s->name);
    if (hmm->is_tied_mixture) {
      /* try tmix */
      did = search_tmid((GCODEBOOK *)(s->b));
      if ((GCODEBOOK *)s->b == tm_index[did]) {
	/* tmix */
	dummy = -1;
	wrt(fp, &dummy, sizeof(short), 1);
	wrt(fp, &did, sizeof(unsigned int), 1);
      } else {
	/* tmix failed -> normal mixture */
	wrt(fp, &(s->mix_num), sizeof(short), 1);
	for (i=0;i<s->mix_num;i++) {
	  if (s->b[i] == NULL) {
	    did = dens_num;
	  } else {
	    did = search_did(s->b[i]);
	    if (s->b[i] != dens_index[did]) {
	      j_error("index not match!!!");
	    }
	  }
	  wrt(fp, &did, sizeof(unsigned int), 1);
	}
      }
    } else {			/* not tied mixture */
      wrt(fp, &(s->mix_num), sizeof(short), 1);
      for (i=0;i<s->mix_num;i++) {
	if (s->b[i] == NULL) {
	  did = dens_num;
	} else {
	  did = search_did(s->b[i]);
	  if (s->b[i] != dens_index[did]) {
	    j_error("index not match!!!");
	  }
	}
	wrt(fp, &did, sizeof(unsigned int), 1);
      }
    }
    wrt(fp, s->bweight, sizeof(PROB), s->mix_num);
  }
  j_printf("%d states written\n", st_num);
}

/* Return the position of state s in the address-sorted st_index[].
   This is a lower-bound binary search: the caller must check
   st_index[result] == s.  Returns 0 when the table is empty. */
static unsigned int
search_stid(HTK_HMM_State *s)
{
  unsigned int left = 0;
  unsigned int right;
  unsigned int mid;

  /* guard: st_num - 1 would wrap to UINT_MAX on an empty table */
  if (st_num == 0) return 0;
  right = st_num - 1;

  while (left < right) {
    mid = left + (right - left) / 2;	/* overflow-safe midpoint */
    if (st_index[mid] < s) {
      left = mid + 1;
    } else {
      right = mid;
    }
  }
  return(left);
}


/* Write the top-level HMM definitions.
   Writes the count, followed by one record per model on hmm->start: name,
   state_num (short), state_num state record numbers (positions in
   st_index[]; hmm->totalstatenum + 1 marks a NULL state), and the record
   number of its transition matrix (position in tr_index[]).  wt_state() and
   wt_trans() must therefore have run first. */
static void
wt_data(FILE *fp, HTK_HMM_INFO *hmm)
{
  HTK_HMM_Data *d;
  unsigned int md_num;
  unsigned int sid, tid;
  int i;

  /* The count header must equal the number of records written below, which
     come from the list, so count the list instead of using totalhmmnum. */
  md_num = 0;
  for(d = hmm->start; d; d = d->next) md_num++;

  wrt(fp, &(md_num), sizeof(unsigned int), 1);
  for(d = hmm->start; d; d = d->next) {
    wrt_str(fp, d->name);
    wrt(fp, &(d->state_num), sizeof(short), 1);
    for (i=0;i<d->state_num;i++) {
      if (d->s[i] != NULL) {
	sid = search_stid(d->s[i]);
	/* consistency check, bounds-checked before indexing */
	if (sid >= st_num || d->s[i] != st_index[sid]) j_error("index not match!!! data state\n");
      } else {
	sid = hmm->totalstatenum + 1; /* error value */
      }
      wrt(fp, &sid, sizeof(unsigned int), 1);
    }
    tid = search_trid(d->tr);
    /* consistency check, bounds-checked before indexing */
    if (tid >= tr_num || d->tr != tr_index[tid]) j_error("index not match!!! data trans\n");
    wrt(fp, &tid, sizeof(unsigned int), 1);
  }
  /* %u: md_num is unsigned */
  j_printf("%u HMM model definition written\n", md_num);
}



boolean
write_binhmm(FILE *fp, HTK_HMM_INFO *hmm)
{
  /* Serialize hmm to fp in the binary HMM format.
     The section order matters: each section refers to earlier ones by record
     number, using the address-sorted index tables those sections built
     (densities -> variances, codebooks/states -> densities,
     models -> states/transitions).
     Always returns TRUE.  Write failures are reported through j_error()
     inside wrt(); presumably that call does not return (verify). */

  /* write header */
  wt_header(fp);

  /* write option data */
  wt_opt(fp, &(hmm->opt));

  /* write type data */
  wt_type(fp, hmm);

  /* write transition data */
  wt_trans(fp, hmm);

  /* write variance data */
  wt_var(fp, hmm);

  /* write density data */
  wt_dens(fp, hmm);

  /* write tmix data */
  if (hmm->is_tied_mixture) {
    wt_tmix(fp, hmm);
  }

  /* write state data */
  wt_state(fp, hmm);

  /* write model data */
  wt_data(fp, hmm);

  /* Free the pointer->index work areas and reset the file-scope pointers so
     they never dangle into freed memory between calls. */
  free(tr_index);
  tr_index = NULL;
  free(vr_index);
  vr_index = NULL;
  free(dens_index);
  dens_index = NULL;
  if (hmm->is_tied_mixture) {
    free(tm_index);
    tm_index = NULL;
  }
  free(st_index);
  st_index = NULL;

  return (TRUE);
}
