/*
 * Copyright (c) 2003-2005 RIKEN Japan, All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY RIKEN AND CONTRIBUTORS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL RIKEN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGE.
 */

/* $SATELLITE: satellite4/modules/bps/command/errfunc.cpp,v 1.5 2005/02/21 11:53:13 ninja Exp $ */

#ifdef HAVE_CONFIG_H
# include "config.h"
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "libbps.h"

/************************************************
*                                               *
*       Back Propagation Simulator (BPS)        *
*             subroutine package                *
*               Ver 				*
*         coded         in Nov.16 1989          *
*         modified by   K.Kuroda                *
*                                               *
*************************************************
*                                               *
*       filename errfunc.c                      *
*                                               *
************************************************/

#ifdef __cplusplus
extern "C" {
#endif

/* Arguments of the errfunc command; filled in by set_param_errfunc(). */
/* weight-history entry number handed to ReadWeight2() */
static int    wgt_hist_no;
/* first swept connection: layer, source unit, destination unit, number of grid points */
static int    lay_no1, unit_from1, unit_to1, dpoint1;
/* second swept connection: layer, source unit, destination unit, number of grid points */
static int    lay_no2, unit_from2, unit_to2, dpoint2;
/* sweep range of the first connection weight */
static float  min_wgt1, max_wgt1;
/* sweep range of the second connection weight */
static float  min_wgt2, max_wgt2;

/* dpoint1 x dpoint2 grid of summed errors; allocated and released inside cal_err().
   NOTE(review): not static -- may be referenced from other BPS files; confirm before narrowing. */
Buffer       *err_data;


/************************************************
  set parameters

  Reads the 13 scalar arguments of the errfunc command
  (GetScalar(0) .. GetScalar(12)) into the file-scope
  parameters and validates them:
    0      weight-history number
    1-6    connection 1: layer, from-unit, to-unit,
           min weight, max weight, number of grid points
    7-12   connection 2: same layout
  Returns 0 on success or the BPS error code of the
  first invalid argument.
  ************************************************/

/* Validate one swept connection (layer lay_no -> lay_no+1); returns 0 or an error code. */
static int
check_errfunc_connection(int lay_no, int unit_from, int unit_to, int dpoint,
                         int layer_err, int dpoint_err)
{
  /* The connection runs from layer lay_no to lay_no+1, and the last layer is
     NumOfLayer-1, so lay_no may be at most NumOfLayer-2.  Allowing NumOfLayer-1
     would index NumOfCell[NumOfLayer] below (and BPNet[NumOfLayer] in cal_err). */
  if ((lay_no < 0) || ((bps_cont.NumOfLayer - 2) < lay_no))
    return layer_err;                     /* Illegal Layer No. */
  if ((unit_from < 1) || (bps_cont.NumOfCell[lay_no] < unit_from))
    return 21;                            /* Illegal Unit No. */
  if ((unit_to < 1) || (bps_cont.NumOfCell[lay_no + 1] < unit_to))
    return 21;                            /* Illegal Unit No. */
  if (dpoint < 1)
    return dpoint_err;                    /* Illegal Parameter (grid size) */
  return 0;
}

int
set_param_errfunc()
{
  int code;

  wgt_hist_no =   (int)GetScalar(0);

  lay_no1     =   (int)GetScalar(1);
  unit_from1  =   (int)GetScalar(2);
  unit_to1    =   (int)GetScalar(3);
  min_wgt1    = (float)GetScalar(4);
  max_wgt1    = (float)GetScalar(5);
  dpoint1     =   (int)GetScalar(6);

  lay_no2     =   (int)GetScalar(7);
  unit_from2  =   (int)GetScalar(8);
  unit_to2    =   (int)GetScalar(9);
  min_wgt2    = (float)GetScalar(10);
  max_wgt2    = (float)GetScalar(11);
  dpoint2     =   (int)GetScalar(12);

  if (wgt_hist_no <= 0)
    return 17; /* Illegal Buffer No. */
  if (bps_cont.MaxLearnCount < wgt_hist_no)
    return 25; /* Illegal History No. */

  /* 57: Illegal Parameter 8 (dpoint1) */
  code = check_errfunc_connection(lay_no1, unit_from1, unit_to1, dpoint1, 27, 57);
  if (code != 0)
    return code;

  /* 63: Illegal Parameter 14 (dpoint2).
     NOTE(review): the layer error for connection 2 is 20 while connection 1
     uses 27 for the same "Illegal Layer No." condition; kept as in the
     original -- confirm against the BPS error table. */
  return check_errfunc_connection(lay_no2, unit_from2, unit_to2, dpoint2, 20, 63);
}


/************************************************
  clean OutCellErr

  Zero OutCellErr for every cell of the output layer
  (layer NumOfLayer-1, entries 0 .. NumOfCell[NumOfLayer-1]-1).
  ************************************************/
void
clean_outcell()
{
  int  n_out = bps_cont.NumOfCell[bps_cont.NumOfLayer - 1];
  int  unit;

  for (unit = 0; unit < n_out; unit++)
    bps_cont.OutCellErr[unit] = 0.0;
}


/************************************************
  calculate error function and
  store data

  Sweeps the two selected connection weights over
  dpoint1 x dpoint2 evenly spaced values in
  [min_wgt1, max_wgt1] x [min_wgt2, max_wgt2].  For each
  grid point the sum of forward_learn() over all NumOfPtrn
  training patterns is stored in err_data[i*dpoint2 + j].
  The grid is passed to ReturnSnapshot() and the buffer is
  then released.

  Returns 0 on success, 23 if the buffer cannot be allocated.
  ************************************************/
int
cal_err()
{
  double   weight1, weight2, delta1, delta2, sum;
  int      i, j, ptrn;
  int      idx[2];              /* snapshot extents: dpoint1 x dpoint2 */

  idx[0] = dpoint1;
  idx[1] = dpoint2;

  printf("  ### Snapshot : 2d_Dimension ");
  for (i = 0; i < 2; i++)
    printf("[%d]", idx[i]);
  printf(" ###\n");

  printf("  ### ERROR FUNCTION ==> 2D Snapshot ###\n");

  err_data = AllocBuffer(IndexSize(2, idx));
  if (err_data == NULL)
    return 23; /* Can't Allocate To Array */

  /* Grid step.  A single-point axis (dpoint == 1 passes validation) would
     divide by zero, so its step is 0 and only min_wgt is sampled. */
  delta1 = (dpoint1 > 1) ? (max_wgt1 - min_wgt1) / (dpoint1 - 1) : 0.0;
  delta2 = (dpoint2 > 1) ? (max_wgt2 - min_wgt2) / (dpoint2 - 1) : 0.0;

  for (i = 0; i < dpoint1; i++) {
    /* Computed from the index rather than accumulated with +=, so rounding
       error does not drift across the sweep. */
    weight1 = min_wgt1 + i * delta1;
    SetWeight(bps_cont.BPNet[lay_no1][unit_from1].CellNode,
              bps_cont.BPNet[lay_no1 + 1][unit_to1].CellNode, weight1);

    for (j = 0; j < dpoint2; j++) {
      weight2 = min_wgt2 + j * delta2;
      SetWeight(bps_cont.BPNet[lay_no2][unit_from2].CellNode,
                bps_cont.BPNet[lay_no2 + 1][unit_to2].CellNode, weight2);

      /* total error over all training patterns at this grid point */
      sum = 0.0;
      for (ptrn = 0; ptrn < bps_cont.NumOfPtrn; ptrn++)
        sum += (double)forward_learn(bps_cont.InputData[ptrn],
                                     bps_cont.TeachData[ptrn]);
      err_data[i * dpoint2 + j] = sum;
    }

    printf("     ERROR FUNCTION ==> INDEX[%d] \r", i);
    fflush(stdout);
  }
  printf("\n");
  ReturnSnapshot(err_data, 2, idx);

  efree(err_data);
  err_data = NULL;  /* err_data is global: do not leave it dangling */
  return 0;
}


/***********************************************
  error function

  Entry point of the bps errfunc command: reads the
  structure/learning parameters and the command arguments,
  builds the network, reads weights from WgtHistoryFile
  (entry wgt_hist_no), evaluates the 2-D error grid with
  cal_err() and releases the network again.

  Returns 0 on success, otherwise the BPS error code from
  set_param_errfunc() or cal_err().
  NOTE(review): the original always returned 0; this
  assumes the module caller treats a non-zero return as an
  error code -- confirm.
  ***********************************************/
DLLEXPORT int mod_bps_errfunc()
{
  int code;

  /* GET PARAMETERS */
  rebps();

  GetStructureParameters();
  GetLearningParameters();

  /* Layer/unit arguments are used as array indices in cal_err(), so stop
     here if any of them is invalid. */
  code = set_param_errfunc();
  if (code != 0) {
    /* keep rebps()/wrbps() paired as on the normal path.
       NOTE(review): confirm wrbps() is safe before the network is built. */
    wrbps();
    return code;
  }

  /* SYSTEM INITIALIZE */
  system_initialize();
  MakeNetwork();

  ReadWeight2(bps_cont.WgtHistoryFile, wgt_hist_no);

  clean_outcell();
  workspace_initialize();

  /* CALCULATE ERROR FUNCTION */
  code = cal_err();

  /* SYSTEM END -- also runs when cal_err() failed, so the network is released */
  BreakNetwork();
  system_end();

  wrbps();
  return code;
}

#ifdef __cplusplus
}
#endif