/*
 * Copyright (c) 2003-2005 RIKEN Japan, All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY RIKEN AND CONTRIBUTORS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL RIKEN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGE.
 */

/* $SATELLITE: satellite4/modules/bps/command/wgtload.cpp,v 1.8 2005/02/22 07:40:22 ninja Exp $ */

#ifdef HAVE_CONFIG_H
# include "config.h"
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "libbps.h"

/************************************************
 *						*
 *	Back Propagation Simulator (BPS)	*
 *	      subroutine package		*
 *	  coded 	in Sep.1  1990		*
 *	  last modified	in Nov.15 1990		*
 *	  coded by	K.Kuroda		*
 *						*
 *************************************************
 *						*
 *	filename wgtload.c			*
 *	   BPS WGTLOAD command			*
 *						*
 ************************************************/

#ifdef __cplusplus
extern "C" {
#endif

/* Source layer of the weights being addressed: weights run from layer
   lay_num to layer lay_num + 1.  Set by mod_bps_wgtload() from argument 1
   and read by wnum1(), wnum2() and wnum3(). */
int      lay_num;
/* Weight indices collected by wnum2(), one per destination unit. */
int      wgt_number[HBUFSIZE];

/************************************************
  lowest legal source-unit number for a layer

  Returns 0 when CheckBias(laynum) is true, i.e. the
  bias slot (unit 0) is present and may be addressed,
  and 1 otherwise.
  ************************************************/
int
isbias_r(int laynum)
{
  return CheckBias(laynum) ? 0 : 1;
}


/************************************************
  offset of the first stored weight of destination
  unit `unum` in layer lay_num + 1

  The count walks layers 1 .. NumOfLayer-1 and, in
  each, every unit: one bias slot when CheckBias(layer)
  is true, then one slot per unit of the previous
  layer.  The returned offset is the start of unit
  `unum`'s slot group (its bias slot when present).
  Returns 0 if no unit matches.
  ************************************************/
int
wnum1(int unum)
{
  int layer, cell, fan_in, offset;

  offset = 0;
  for (layer = 1; layer < bps_cont.NumOfLayer; layer++) {
    for (cell = 1; cell <= bps_cont.NumOfCell[layer]; cell++) {
      if ((unum == cell) && (lay_num == (layer - 1)))
        return offset;
      if (CheckBias(layer))
        offset++;
      /* one slot per source unit; a non-positive count adds nothing */
      fan_in = bps_cont.NumOfCell[layer - 1];
      if (fan_in > 0)
        offset += fan_in;
    }
  }
  return 0;
}


/************************************************
  count weight number
  ************************************************/
void
wnum2(int unum)
{
  int   i, j, k, l, wgt_cnt;

  l = 0;
  wgt_cnt = 0;

  for (i = 1; i < bps_cont.NumOfLayer; i++) {
    for (j = 1; j <= bps_cont.NumOfCell[i]; j++) {
      if ((lay_num == (i - 1)) && (unum == 0))
	wgt_number[l++] = wgt_cnt;
      if (CheckBias(i))
	wgt_cnt++;
      for (k = 1; k <= bps_cont.NumOfCell[i - 1]; k++) {
	if ((lay_num == (i - 1)) && (unum == k))
	  wgt_number[l++] = wgt_cnt;
	wgt_cnt++;
      }
    }
  }
}


/************************************************
  offset of the single weight from source unit
  `unum1` (0 = bias slot) of layer lay_num to
  destination unit `unum2` of layer lay_num + 1

  Uses the same walk as wnum1(): per unit, an
  optional bias slot followed by one slot per unit
  of the previous layer.  Returns 0 if no match.
  ************************************************/
int
wnum3(int unum1, int unum2)
{
  int layer, dst, src, offset = 0;

  for (layer = 1; layer < bps_cont.NumOfLayer; layer++) {
    for (dst = 1; dst <= bps_cont.NumOfCell[layer]; dst++) {
      if ((unum1 == 0) && (unum2 == dst) && (lay_num == (layer - 1)))
        return offset;
      if (CheckBias(layer))
        offset++;
      for (src = 1; src <= bps_cont.NumOfCell[layer - 1]; src++, offset++) {
        if ((unum1 == src) && (unum2 == dst) && (lay_num == (layer - 1)))
          return offset;
      }
    }
  }
  return 0;
}


/************************************************
  load the weight-file header and resolve the
  history number of a single-record request

  *iterat == 0 selects the latest record.  The file
  holds header->index[0] - 1 records (the same count
  the x/y history loops iterate over), so the valid
  history numbers are 1 .. header->index[0] - 1.
  Returns 0 on success, otherwise the WGTLOAD error
  code (11 header not loadable, 131 bad history no.).
  ************************************************/
static int
wgtload_open_history(Header *header, int *iterat)
{
  if (LoadHeader(bps_cont.bps_sp.FileName, header) == -1)
    return 11; /* Illegal File No. */
  if (*iterat == 0)
    *iterat = header->index[0] - 1;
  /* LoadData() is called with *iterat - 1, which must stay inside the
     stored records; this also rejects "latest" on an empty file. */
  if ((*iterat < 1) || (*iterat > header->index[0] - 1))
    return 131; /* Illegal History No. */
  return 0;
}

/************************************************
  BPS WGTLOAD command: load connection weights
  from a weight file

  SATELLITE arguments, as read below:
    0  history selector: 'x'/'X' or 'y'/'Y' returns
       the weight(s) for every stored record, the
       letter choosing which output axis holds the
       record index; otherwise the scalar value is a
       history number (0 = latest record)
    1  lay_num: weights run from layer lay_num to
       layer lay_num + 1
    2  source unit: 'x'/'X', 'y'/'Y' or a unit number
       of layer lay_num (0 = bias slot)
    3  destination unit: 'x'/'X', 'y'/'Y' or a unit
       number of layer lay_num + 1
    4  weight file name
    5  optional flag; when present it must be 1 and
       includes the bias slot in returned fan-in rows

  Returns 0 on success, otherwise:
    2 illegal parameter, 11 header not loadable,
    130 illegal unit no., 131 illegal history no.,
    133 unknown flag, 134 illegal layer no.,
    202 unknown store mode.
  ************************************************/
DLLEXPORT int mod_bps_wgtload()
{
  int     nargs;
  int     idx[10];
  int     p_int6;
  int     i, j, k, iterat_int, unum1, unum2, wgt_num;
  int     bias_flg1, bias_flg2;
  int     rc;
  char    iterat_char, unit_num1, unit_num2;
  char    *str1, *str2, *str3, *fname;
  float   *dum;
  Buffer  *Data;
  Header  header;

  rebps();

  GetStructureParameters();
  GetLearningParameters();
  SetNumOfLink();

  /* Get Arguments from SATELLITE Language */
  nargs = GetArgNum();

  str1 = GetString(0);
  if (str1 == NULL)
    return 2; /* illegal parameter */
  iterat_char = *str1;

  iterat_int = (int)GetScalar(0);
  lay_num    = (int)GetScalar(1);

  str2 = GetString(2);
  if (str2 == NULL)
    return 2; /* illegal parameter */
  unit_num1 = *str2;

  unum1 = (int)GetScalar(2);

  str3 = GetString(3);
  if (str3 == NULL)
    return 2; /* illegal parameter */
  unit_num2 = *str3;

  unum2 = (int)GetScalar(3);

  /* GetString() may return NULL (checked above); strcpy(NULL) is undefined */
  fname = GetString(4);
  if (fname == NULL)
    return 2; /* illegal parameter */
  strcpy(bps_cont.bps_sp.FileName, fname);

  /* bias_flg1: count the bias slot in returned rows;
     bias_flg2: skip over the bias slot instead */
  if (nargs == 6) {
    p_int6 = (int)GetScalar(5);
    if (p_int6 != 1)
      return 133; /* Unknown Flag */
    bias_flg1 = 1;
    bias_flg2 = 0;
  } else {
    bias_flg1 = 0;
    bias_flg2 = 1;
  }

  /* Weights connect layer lay_num to lay_num + 1 and wnum1/2/3 only match
     lay_num in 0 .. NumOfLayer - 2; anything else would index NumOfCell[]
     outside the layer table. */
  if ((lay_num < 0) || (lay_num > (bps_cont.NumOfLayer - 2)))
    return 134; /* Illegal Layer No. */

  switch (bps_cont.WgtStorMode) {
  case BPS_STOREMODE_APPEND:
    switch (iterat_char) {
    case 'x':
    case 'X':
      if ((unit_num1 == 'y') || (unit_num1 == 'Y')) {
        /* fan-in of destination unit unum2, for every stored record */
        if ((unum2 < 1) || (unum2 > bps_cont.NumOfCell[lay_num + 1]))
          return 130; /* Illegal Unit No. */
        if (LoadHeader(bps_cont.bps_sp.FileName, &header) == -1)
          return 11; /* Illegal File No. */

        idx[0] = bps_cont.NumOfCell[lay_num] + CheckBias(lay_num + 1) * bias_flg1;
        idx[1] = header.index[0] - 1;
        wgt_num = wnum1(unum2) + CheckBias(lay_num + 1) * bias_flg2;

        message_2d_before(idx);

        Data = AllocBuffer(IndexSize(2, idx));
        for (i = 0; i < idx[1]; i++) {
          dum = (float *)LoadData(bps_cont.bps_sp.FileName, i, &header);
          for (j = 0; j < idx[0]; j++)
            *(Data + j * idx[1] + i) = (double)dum[j + wgt_num];
          FreeData(dum);
          message_2d_after(idx[0] - 1, "WEIGHT");
        }
        ReturnSnapshot(Data, 2, idx);
        efree(Data);
        printf("\n");

      } else { /* iterate = X, APPEND, unit_num1 != Y */
        if ((unum1 < isbias_r(lay_num + 1)) || (unum1 > bps_cont.NumOfCell[lay_num]))
          return 130; /* Illegal Unit No. */
        if ((unit_num2 == 'y') || (unit_num2 == 'Y')) {
          /* fan-out of source unit unum1, for every stored record */
          if (LoadHeader(bps_cont.bps_sp.FileName, &header) == -1)
            return 11; /* Illegal File No. */

          idx[0] = bps_cont.NumOfCell[lay_num + 1];
          idx[1] = header.index[0] - 1;
          wnum2(unum1);

          message_2d_before(idx);

          Data = AllocBuffer(IndexSize(2, idx));
          for (i = 0; i < idx[1]; i++) {
            dum = (float *)LoadData(bps_cont.bps_sp.FileName, i, &header);
            for (j = 0; j < idx[0]; j++)
              *(Data + j * idx[1] + i) = (double)dum[wgt_number[j]];
            FreeData(dum);
          }
          message_2d_after(idx[1] - 1, "WEIGHT");
          ReturnSnapshot(Data, 2, idx);
          efree(Data);
          printf("\n");

        } else { /* iterate = X, APPEND, unit_num1 != Y, unit_num2 != Y */
          /* one weight unum1 -> unum2, for every stored record */
          if ((unum2 < 1) || (unum2 > bps_cont.NumOfCell[lay_num + 1]))
            return 130; /* Illegal Unit No. */
          if (LoadHeader(bps_cont.bps_sp.FileName, &header) == -1)
            return 11; /* Illegal File No. */

          wgt_num = wnum3(unum1, unum2);
          idx[0] = header.index[0] - 1;
          Data = AllocBuffer(IndexSize(1, idx));
          for (i = 0; i < idx[0]; i++) {
            dum = (float *)LoadData(bps_cont.bps_sp.FileName, i, &header);
            Data[i] = (double)dum[wgt_num];
            FreeData(dum);
          }
          ReturnSnapshot(Data, 1, idx);
          efree(Data);
          message_1d_after("WEIGHT");
        }
      }
      break;

    case 'y':
    case 'Y':
      if ((unit_num1 == 'x') || (unit_num1 == 'X')) {
        /* fan-in of destination unit unum2, one row per stored record */
        if ((unum2 < 1) || (unum2 > bps_cont.NumOfCell[lay_num + 1]))
          return 130; /* Illegal Unit No. */
        if (LoadHeader(bps_cont.bps_sp.FileName, &header) == -1)
          return 11; /* Illegal File No. */

        idx[0] = header.index[0] - 1;
        idx[1] = bps_cont.NumOfCell[lay_num] + CheckBias(lay_num + 1) * bias_flg1;
        wgt_num = wnum1(unum2) + CheckBias(lay_num + 1) * bias_flg2;

        message_2d_before(idx);

        Data = AllocBuffer(IndexSize(2, idx));
        for (i = 0; i < idx[0]; i++) {
          dum = (float *)LoadData(bps_cont.bps_sp.FileName, i, &header);
          for (j = 0; j < idx[1]; j++)
            *(Data + j + i * idx[1]) = (double)dum[j + wgt_num];
          FreeData(dum);
          message_2d_after(i, "WEIGHT");
        }
        ReturnSnapshot(Data, 2, idx);
        efree(Data);
        printf("\n");

      } else {
        /* fan-out of source unit unum1, one row per stored record */
        if ((unum1 < isbias_r(lay_num + 1)) || (unum1 > bps_cont.NumOfCell[lay_num]))
          return 130; /* Illegal Unit No. */
        if ((unit_num2 != 'x') && (unit_num2 != 'X'))
          return 130; /* Illegal Unit No. */
        if (LoadHeader(bps_cont.bps_sp.FileName, &header) == -1)
          return 11; /* Illegal File No. */

        idx[0] = header.index[0] - 1;
        idx[1] = bps_cont.NumOfCell[lay_num + 1];
        wnum2(unum1);

        Data = AllocBuffer(IndexSize(2, idx));
        message_2d_before(idx);
        for (i = 0; i < idx[0]; i++) {
          dum = (float *)LoadData(bps_cont.bps_sp.FileName, i, &header);
          for (j = 0; j < idx[1]; j++)
            *(Data + j + i * idx[1]) = (double)dum[wgt_number[j]];
          FreeData(dum);
          message_2d_after(i, "WEIGHT");
        }
        ReturnSnapshot(Data, 2, idx);
        efree(Data);
        printf("\n");
      }
      break;

    default: /* iterat != X/Y, APPEND: a single history record */
      if (iterat_int < 0)
        return 131; /* Illegal History No. */
      switch (unit_num1) {
      case 'x':
      case 'X':
        if ((unit_num2 == 'y') || (unit_num2 == 'Y')) {
          /* full weight matrix: one row per destination unit */
          rc = wgtload_open_history(&header, &iterat_int);
          if (rc != 0)
            return rc;

          idx[0] = bps_cont.NumOfCell[lay_num + 1];
          idx[1] = bps_cont.NumOfCell[lay_num] + CheckBias(lay_num + 1) * bias_flg1;
          wgt_num = wnum1(1) + CheckBias(lay_num + 1) * bias_flg2;

          message_2d_before(idx);

          dum = (float *)LoadData(bps_cont.bps_sp.FileName, iterat_int - 1, &header);
          Data = AllocBuffer(IndexSize(2, idx));
          k = wgt_num;
          for (i = 0; i < idx[0]; i++) {
            for (j = 0; j < idx[1]; j++)
              *(Data + j + i * idx[1]) = (double)dum[k++];
            /* skip the next unit's bias slot when it is excluded */
            k += CheckBias(lay_num + 1) * bias_flg2;
            message_2d_after(i, "WEIGHT");
          }
          /* every row reads from dum: release it only after the loop */
          FreeData(dum);
          ReturnSnapshot(Data, 2, idx);
          efree(Data);
          printf("\n");

        } else { /* iterat != X/Y, APPEND, unit_num1 = X */
          /* fan-in of destination unit unum2 */
          if ((unum2 < 1) || (unum2 > bps_cont.NumOfCell[lay_num + 1]))
            return 130; /* Illegal Unit No. */
          rc = wgtload_open_history(&header, &iterat_int);
          if (rc != 0)
            return rc;

          wgt_num = wnum1(unum2) + CheckBias(lay_num + 1) * bias_flg2;

          idx[0] = bps_cont.NumOfCell[lay_num] + CheckBias(lay_num + 1) * bias_flg1;
          dum = (float *)LoadData(bps_cont.bps_sp.FileName, iterat_int - 1, &header);
          Data = AllocBuffer(IndexSize(1, idx));
          for (i = 0; i < idx[0]; i++)
            Data[i] = (double)dum[i + wgt_num];
          message_1d_after("WEIGHT");
          ReturnSnapshot(Data, 1, idx);
          FreeData(dum);
          efree(Data);
        }
        break;

      case 'y': /* iterat != X/Y, unit_num1 == Y, APPEND */
      case 'Y':
        /* full weight matrix, transposed: one row per source slot */
        if ((unit_num2 != 'x') && (unit_num2 != 'X'))
          return 130; /* Illegal Unit No. */
        rc = wgtload_open_history(&header, &iterat_int);
        if (rc != 0)
          return rc;

        idx[0] = bps_cont.NumOfCell[lay_num] + CheckBias(lay_num + 1) * bias_flg1;
        idx[1] = bps_cont.NumOfCell[lay_num + 1];
        wgt_num = wnum1(1) + CheckBias(lay_num + 1) * bias_flg2;

        message_2d_before(idx);

        Data = AllocBuffer(IndexSize(2, idx));
        k = wgt_num;
        dum = (float *)LoadData(bps_cont.bps_sp.FileName, iterat_int - 1, &header);
        for (i = 0; i < idx[1]; i++) {
          for (j = 0; j < idx[0]; j++)
            *(Data + j * idx[1] + i) = (double)dum[k++];
          k += CheckBias(lay_num + 1) * bias_flg2;
        }
        message_2d_after(idx[0] - 1, "WEIGHT");
        ReturnSnapshot(Data, 2, idx);
        FreeData(dum);
        efree(Data);
        printf("\n");
        break;

      default: /* iterat != X/Y, unit_num1 != X/Y, APPEND */
        if ((unum1 < isbias_r(lay_num + 1)) || (unum1 > bps_cont.NumOfCell[lay_num]))
          return 130; /* Illegal Unit No. */
        if ((unit_num2 == 'x') || (unit_num2 == 'X')) {
          /* fan-out of source unit unum1 */
          rc = wgtload_open_history(&header, &iterat_int);
          if (rc != 0)
            return rc;

          idx[0] = bps_cont.NumOfCell[lay_num + 1];
          wnum2(unum1);

          Data = AllocBuffer(IndexSize(1, idx));
          dum = (float *)LoadData(bps_cont.bps_sp.FileName, iterat_int - 1, &header);
          for (i = 0; i < idx[0]; i++)
            Data[i] = (double)dum[wgt_number[i]];

          message_1d_after("WEIGHT");
          ReturnSnapshot(Data, 1, idx);
          efree(Data);
          FreeData(dum);

        } else {
          /* single weight unum1 -> unum2 */
          if ((unum2 < 1) || (unum2 > bps_cont.NumOfCell[lay_num + 1]))
            return 130; /* Illegal Unit No. */
          rc = wgtload_open_history(&header, &iterat_int);
          if (rc != 0)
            return rc;

          wgt_num = wnum3(unum1, unum2);

          idx[0] = 1;
          dum = (float *)LoadData(bps_cont.bps_sp.FileName, iterat_int - 1, &header);
          Data = AllocBuffer(IndexSize(1, idx));
          Data[0] = (double)dum[wgt_num];
          ReturnSnapshot(Data, 1, idx);
          message_1d_after("WEIGHT");
          FreeData(dum);
          efree(Data);
        }
      }
      break;
    }
    break;

  case BPS_STOREMODE_OVERWRITE:
    /* NOTE(review): this mode accepts only history numbers 1 and 2 */
    if ((iterat_int != 1) && (iterat_int != 2))
      return 131; /* Illegal History No. */
    switch (unit_num1) {
    case 'x':
    case 'X':
      if ((unit_num2 == 'y') || (unit_num2 == 'Y')) {
        /* full weight matrix: one row per destination unit */
        if (LoadHeader(bps_cont.bps_sp.FileName, &header) == -1)
          return 11; /* Illegal File No. */

        idx[0] = bps_cont.NumOfCell[lay_num + 1];
        idx[1] = bps_cont.NumOfCell[lay_num] + CheckBias(lay_num + 1) * bias_flg1;
        wgt_num = wnum1(1) + CheckBias(lay_num + 1) * bias_flg2;

        message_2d_before(idx);
        k = wgt_num;
        Data = AllocBuffer(IndexSize(2, idx));
        /* every row comes from the same record: load it once */
        dum = (float *)LoadData(bps_cont.bps_sp.FileName, iterat_int - 1, &header);
        for (i = 0; i < idx[0]; i++) {
          for (j = 0; j < idx[1]; j++)
            *(Data + i * idx[1] + j) = (double)dum[k++];
          k += CheckBias(lay_num + 1) * bias_flg2;
          message_2d_after(i, "WEIGHT");
        }
        FreeData(dum);
        ReturnSnapshot(Data, 2, idx);
        efree(Data);
        printf("\n");

      } else {
        /* fan-in of destination unit unum2; reject numbers outside
           1 .. NumOfCell[lay_num + 1] */
        if ((unum2 < 1) || (unum2 > bps_cont.NumOfCell[lay_num + 1]))
          return 130; /* Illegal Unit No. */
        if (LoadHeader(bps_cont.bps_sp.FileName, &header) == -1)
          return 11; /* Illegal File No. */

        wgt_num = wnum1(unum2) + CheckBias(lay_num + 1) * bias_flg2;

        dum = (float *)LoadData(bps_cont.bps_sp.FileName, iterat_int - 1, &header);
        idx[0] = bps_cont.NumOfCell[lay_num] + CheckBias(lay_num + 1) * bias_flg1;
        Data = AllocBuffer(IndexSize(1, idx));
        for (i = 0; i < idx[0]; i++)
          Data[i] = (double)dum[i + wgt_num];

        message_1d_after("WEIGHT");
        ReturnSnapshot(Data, 1, idx);
        FreeData(dum);
        efree(Data);
      }
      break;

    case 'y':
    case 'Y': /* unit_num1 == Y, OVERWRITE */
      /* full weight matrix, transposed: one row per source slot */
      if ((unit_num2 != 'x') && (unit_num2 != 'X'))
        return 130; /* Illegal Unit No. */
      if (LoadHeader(bps_cont.bps_sp.FileName, &header) == -1)
        return 11; /* Illegal File No. */

      idx[0] = bps_cont.NumOfCell[lay_num] + CheckBias(lay_num + 1) * bias_flg1;
      idx[1] = bps_cont.NumOfCell[lay_num + 1];
      wgt_num = wnum1(1) + CheckBias(lay_num + 1) * bias_flg2;

      message_2d_before(idx);

      dum = (float *)LoadData(bps_cont.bps_sp.FileName, iterat_int - 1, &header);
      Data = AllocBuffer(IndexSize(2, idx));
      k = wgt_num;
      for (i = 0; i < idx[1]; i++) {
        for (j = 0; j < idx[0]; j++)
          *(Data + j * idx[1] + i) = (double)dum[k++];
        k += CheckBias(lay_num + 1) * bias_flg2;
      }
      message_2d_after(idx[0] - 1, "WEIGHT");

      ReturnSnapshot(Data, 2, idx);
      efree(Data);
      FreeData(dum);
      printf("\n");
      break;

    default: /* unit_num1 != X/Y, OVERWRITE */
      if ((unum1 < isbias_r(lay_num + 1)) || (unum1 > bps_cont.NumOfCell[lay_num]))
        return 130; /* Illegal Unit No. */
      if ((unit_num2 == 'x') || (unit_num2 == 'X')) {
        /* fan-out of source unit unum1 */
        if (LoadHeader(bps_cont.bps_sp.FileName, &header) == -1)
          return 11; /* Illegal File No. */

        idx[0] = bps_cont.NumOfCell[lay_num + 1];
        wnum2(unum1);

        dum = (float *)LoadData(bps_cont.bps_sp.FileName, iterat_int - 1, &header);
        Data = AllocBuffer(IndexSize(1, idx));
        for (i = 0; i < idx[0]; i++)
          Data[i] = (double)dum[wgt_number[i]];

        message_1d_after("WEIGHT");
        ReturnSnapshot(Data, 1, idx);
        FreeData(dum);
        efree(Data);
      } else {
        /* single weight unum1 -> unum2 */
        if ((unum2 < 1) || (unum2 > bps_cont.NumOfCell[lay_num + 1]))
          return 130; /* Illegal Unit No. */

        if (LoadHeader(bps_cont.bps_sp.FileName, &header) == -1)
          return 11; /* Illegal File No. */
        wgt_num = wnum3(unum1, unum2);

        dum = (float *)LoadData(bps_cont.bps_sp.FileName, iterat_int - 1, &header);
        idx[0] = 1;
        Data = AllocBuffer(IndexSize(1, idx));
        Data[0] = (double)dum[wgt_num];

        message_1d_after("WEIGHT");
        ReturnSnapshot(Data, 1, idx);
        FreeData(dum);
        efree(Data);
      }
    }
    break;

  default:
    return 202; /* Unknown Store Mode */
  }

  wrbps();
  return 0;
}

#ifdef __cplusplus
}
#endif