/*
 * Copyright (c) 2003-2005 RIKEN Japan, All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY RIKEN AND CONTRIBUTORS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL RIKEN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGE.
 */

/* $SATELLITE: satellite4/modules/bps/command/wgtrenew.cpp,v 1.4 2005/02/18 13:20:23 orrisroot Exp $ */

#ifdef HAVE_CONFIG_H
# include "config.h"
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "libbps.h"

/************************************************
 *                                               *
 *        Back Propagation Simulator (BPS)       *
 *              subroutine package               *
 *                Version 4.0                    *
 *          coded                in May.17 1989  *
 *          last modified in Jul.3  1990         *
 *          coded by        Y.Okamura            *
 *          modified by         K.Kuroda         *
 *                                               *
 *************************************************
 *                                               *
 *        filename wgtren.c                      *
 *            weight renew routine               *
 *                                               *
 ************************************************/

#ifdef __cplusplus
extern "C" {
#endif

/************************************************
  structure learning term (modified by higashi)

  Returns the extra weight change the caller adds to
  enablePoint->Weight.  The rule is selected by
  bps_cont.bps_sp.StrLrnMode and scaled by
  bps_cont.bps_sp.Ramuda:

    0                       : 0.0
    BPS_STRLRNMODE_PLAUT    : -Weight * Ramuda
    BPS_STRLRNMODE_ISHIKAWA : +Ramuda if Weight < 0,
                              otherwise -Ramuda
    BPS_STRLRNMODE_YASUI    : signed sum of tempWeight[] over
                              every link entering the cell,
                              combined with the link's own Weight
    any other value         : 0.0

  enableLay, enableUnit : cell index into bps_cont.BPNet
                          (used in Yasui mode only)
  enablePoint           : the link being renewed
  tempWeight            : one value per incoming link of the
                          cell, in list order (used in Yasui
                          mode only)
  ************************************************/
double structlearn(int enableLay, int enableUnit, bps_ilin_t *enablePoint,
                   double *tempWeight)
{
  bps_ilin_t *in_link;
  double penalty = 0.0;

  switch (bps_cont.bps_sp.StrLrnMode) {
  case 0:
    /* structure learning switched off */
    break;

  case BPS_STRLRNMODE_PLAUT:
    penalty = -enablePoint->Weight;
    penalty = penalty * bps_cont.bps_sp.Ramuda;
    break;

  case BPS_STRLRNMODE_ISHIKAWA:
    /* constant-size step whose sign opposes the weight's sign */
    penalty = (enablePoint->Weight < 0) ?  bps_cont.bps_sp.Ramuda
                                        : -bps_cont.bps_sp.Ramuda;
    break;

  case BPS_STRLRNMODE_YASUI:
    /* Walk the incoming links and tempWeight[] in step.  The sign
       test reads the link's current Weight; the amount added or
       subtracted is the matching tempWeight[] entry. */
    for (in_link = Getintoplist(bps_cont.BPNet[enableLay][enableUnit].CellNode);
         in_link != NULL;
         in_link = Getinfwdlist(in_link), tempWeight++) {
      if (in_link->Weight > 0)
        penalty += *tempWeight;
      else
        penalty -= *tempWeight;
    }
    if (enablePoint->Weight < 0)
      penalty = (penalty + enablePoint->Weight) * bps_cont.bps_sp.Ramuda;
    else
      penalty = -(penalty - enablePoint->Weight) * bps_cont.bps_sp.Ramuda;
    break;
  }

  return penalty;
}

/***********************************************
  weight snapshot for Yasui structure learning
  (modified by higashi)

  Copies the Weight of every link entering cell
  bps_cont.BPNet[lay][unit] into a freshly emalloc'ed array,
  in list order, and returns that array.  The callers in this
  file release it with efree().
  *********************************************/
static double *temp_wgt(int lay, int unit)
{
  bps_ilin_t *in_link;
  double     *snapshot;
  int         n_links = 0;
  int         idx;

  /* first pass: count the incoming links */
  for (in_link = Getintoplist(bps_cont.BPNet[lay][unit].CellNode);
       in_link != NULL;
       in_link = Getinfwdlist(in_link))
    n_links++;

  snapshot = (double*)emalloc(sizeof(double)*n_links);

  /* second pass: copy the weights in the same order */
  idx = 0;
  for (in_link = Getintoplist(bps_cont.BPNet[lay][unit].CellNode);
       in_link != NULL;
       in_link = Getinfwdlist(in_link))
    snapshot[idx++] = in_link->Weight;

  return snapshot;
}


/************************************************
  steep method

  For every incoming link of every cell:
    AdjWgt  = LearnRate * WgtWork
    Weight += AdjWgt
  ************************************************/

void steep()
{
  int         layer, cell;
  bps_ilin_t *in_link;

  for (layer = 0; layer < bps_cont.NumOfLayer; layer++) {
    /* cells of a layer are indexed 1..NumOfCell[layer] */
    for (cell = 1; cell <= bps_cont.NumOfCell[layer]; cell++) {
      for (in_link = Getintoplist(bps_cont.BPNet[layer][cell].CellNode);
           in_link != NULL;
           in_link = Getinfwdlist(in_link)) {
        in_link->AdjWgt = bps_cont.LearnRate * in_link->WgtWork;
        in_link->Weight += in_link->AdjWgt;
      }
    }
  }
}


/************************************************
  momentum method (with structure learning term,
  modified by higashi)

  For every incoming link of every cell:
    AdjWgt  = LearnRate * WgtWork + Momentum * AdjWgt
    Weight += AdjWgt + structlearn(...)

  A snapshot of the cell's incoming weights is taken with
  temp_wgt() before any of them is changed, handed to
  structlearn(), and released once the cell is done.
  ************************************************/
void momentum1()
{
  int         layer, cell;
  bps_ilin_t *in_link;
  double     *weight_snapshot;
  double      struct_term;

  for (layer = 0; layer < bps_cont.NumOfLayer; layer++) {
    for (cell = 1; cell <= bps_cont.NumOfCell[layer]; cell++) {
      weight_snapshot = temp_wgt(layer, cell);

      for (in_link = Getintoplist(bps_cont.BPNet[layer][cell].CellNode);
           in_link != NULL;
           in_link = Getinfwdlist(in_link)) {
        in_link->AdjWgt = bps_cont.LearnRate * in_link->WgtWork
          + bps_cont.Momentum * in_link->AdjWgt;

        /* the step and the structure-learning term are summed
           first, then added to the weight in one operation */
        struct_term = structlearn(layer, cell, in_link, weight_snapshot);
        in_link->Weight += in_link->AdjWgt + struct_term;
      }

      efree(weight_snapshot);
    }
  }
}


/************************************************
  momentum rule

  For every incoming link of every cell:
    AdjWgt  = LearnRate * WgtWork + Momentum * AdjWgt
    Weight += AdjWgt
  The step stays in AdjWgt, which is what reject() subtracts.
  ************************************************/
static void basis_renew()
{
  int         layer, cell;
  bps_ilin_t *in_link;

  for (layer = 0; layer < bps_cont.NumOfLayer; layer++) {
    for (cell = 1; cell <= bps_cont.NumOfCell[layer]; cell++) {
      for (in_link = Getintoplist(bps_cont.BPNet[layer][cell].CellNode);
           in_link != NULL;
           in_link = Getinfwdlist(in_link)) {
        in_link->AdjWgt = bps_cont.LearnRate * in_link->WgtWork
          + bps_cont.Momentum * in_link->AdjWgt;
        in_link->Weight += in_link->AdjWgt;
      }
    }
  }
}


/************************************************
  rejection

  Subtracts AdjWgt from Weight on every incoming link of
  every cell.  AdjWgt itself is left unchanged.
  ************************************************/
static void reject()
{
  int         layer, cell;
  bps_ilin_t *in_link;

  for (layer = 0; layer < bps_cont.NumOfLayer; layer++) {
    for (cell = 1; cell <= bps_cont.NumOfCell[layer]; cell++) {
      for (in_link = Getintoplist(bps_cont.BPNet[layer][cell].CellNode);
           in_link != NULL;
           in_link = Getinfwdlist(in_link))
        in_link->Weight -= in_link->AdjWgt;
    }
  }
}


/************************************************
  Vogl method

  Applies one momentum step (basis_renew), calls
  workspace_initialize() and set_learn(0), and compares the
  value returned by set_learn(0) with bps_cont.SumOfErr:

    relative increase > VoglThresh :
        LearnRate *= ReductFact, Momentum = 0,
        and the step is undone with reject()
    otherwise :
        LearnRate += IncreaseFact, Momentum = InitMoment

  NOTE(review): the relative increase divides by
  bps_cont.SumOfErr with no guard against zero - confirm the
  callers never reach here with SumOfErr == 0.
  ************************************************/
void vogl()
{
  double  trial_err;
  double  rel_increase;

  basis_renew();
  workspace_initialize();
  trial_err = set_learn(0);

  rel_increase = (trial_err - bps_cont.SumOfErr) / bps_cont.SumOfErr;

  if (rel_increase > bps_cont.VoglThresh) {
    /* step rejected */
    bps_cont.Momentum = 0;
    bps_cont.LearnRate *= bps_cont.ReductFact;
    reject();
  } else {
    /* step accepted */
    bps_cont.Momentum = bps_cont.InitMoment;
    bps_cont.LearnRate += bps_cont.IncreaseFact;
  }
}

/************************************************
  Jacob's method

  Per-link adaptive learning coefficient.  For every incoming
  link of every cell:

    nxt_adj = (1 - InitMoment) * WgtWork + InitMoment * AdjWgt

    nxt_adj and the previous AdjWgt have the same sign :
        CoefLearn += IncreaseFact
    they have opposite signs :
        CoefLearn *= ReductFact
    their product is zero :
        CoefLearn unchanged

    AdjWgt  = nxt_adj
    Weight += CoefLearn * AdjWgt

  The step is accumulated into Weight, as in every other
  renewal rule in this file.  Assigning it would discard the
  weight learned so far on every call.
************************************************/
void jacobs()
{
  int      lay, unit;
  double   nxt_adj;
  double   agreement;
  bps_ilin_t  *link_pt;

  for (lay = 0; lay < bps_cont.NumOfLayer; lay++) {
    for (unit = 1; unit <= bps_cont.NumOfCell[lay]; unit++) {
      for (link_pt = Getintoplist(bps_cont.BPNet[lay][unit].CellNode);
           link_pt != NULL;
           link_pt = Getinfwdlist(link_pt)) {
        nxt_adj = (1.0 - bps_cont.InitMoment) *
          link_pt->WgtWork + bps_cont.InitMoment * link_pt->AdjWgt;

        /* sign of the product tells whether the new averaged
           step agrees with the previous one */
        agreement = nxt_adj * link_pt->AdjWgt;
        if (agreement > 0)
          link_pt->CoefLearn += bps_cont.IncreaseFact;
        else if (agreement < 0)
          link_pt->CoefLearn *= bps_cont.ReductFact;

        link_pt->AdjWgt = nxt_adj;
        link_pt->Weight += link_pt->CoefLearn * link_pt->AdjWgt;
      }
    }
  }
}


/************************************************
  Vogl's coefficient

  Fills bps_cont.VglCoef[0..NumOfLayer], scanning from the
  last layer down to layer 0.

    VglCoef[NumOfLayer] = 1
    layer whose first cell has CharFunc == BPS_BIAS_SIGMOID :
        n = number of such layers seen so far, this one included
        VglCoef[layer] = prod_{k=1..n} (n + k) / k  *  (2n + 1)
    any other layer :
        VglCoef[layer] = VglCoef[layer + 1]

  Only cell 1 of each layer is inspected.
  ************************************************/
void vgl2_coe()
{
  bps_cel_t *first_cell;
  int        layer, k;
  int        sigmoid_layers = 0;

  bps_cont.VglCoef[bps_cont.NumOfLayer] = 1;

  layer = bps_cont.NumOfLayer;
  while (--layer >= 0) {
    first_cell = bps_cont.BPNet[layer][1].CellNode;

    if (first_cell->CharFunc != BPS_BIAS_SIGMOID) {
      /* inherit the coefficient of the layer above */
      bps_cont.VglCoef[layer] = bps_cont.VglCoef[layer + 1];
      continue;
    }

    sigmoid_layers++;
    bps_cont.VglCoef[layer] = 1;
    /* multiply then divide at each step, in this order */
    for (k = 1; k <= sigmoid_layers; k++) {
      bps_cont.VglCoef[layer] *= sigmoid_layers + k;
      bps_cont.VglCoef[layer] /= k;
    }
    bps_cont.VglCoef[layer] *= 2 * sigmoid_layers + 1;
  }
}


/************************************************
  momentum Vogl's coefficient method

  Momentum rule with the learning rate scaled per layer by
  bps_cont.VglCoef[layer].  For every incoming link of every
  cell:
    AdjWgt  = VglCoef[layer] * LearnRate * WgtWork
              + Momentum * AdjWgt
    Weight += AdjWgt
************************************************/
void momentum2()
{
  int         layer, cell;
  bps_ilin_t *in_link;

  for (layer = 0; layer < bps_cont.NumOfLayer; layer++) {
    for (cell = 1; cell <= bps_cont.NumOfCell[layer]; cell++) {
      for (in_link = Getintoplist(bps_cont.BPNet[layer][cell].CellNode);
           in_link != NULL;
           in_link = Getinfwdlist(in_link)) {
        in_link->AdjWgt = (bps_cont.VglCoef[layer]*
                           bps_cont.LearnRate*(in_link->WgtWork)
                           + bps_cont.Momentum * (in_link->AdjWgt));
        in_link->Weight += in_link->AdjWgt;
      }
    }
  }
}


/*********************************************************************
 * Ochiai's method
 *
 * Weight renewal in up to four passes over every incoming link of
 * every cell:
 *
 *   1. adapt each link's CoefLearn from the sign of
 *      dltold * (-WgtWork), then update dltold as a Theta-weighted
 *      average of -WgtWork and its previous value;
 *   2. apply a momentum step using the per-link CoefLearn, add the
 *      structure-learning term, and accumulate in cof the sum of
 *      dltworkold * dltwork (dltwork being the change of -WgtWork
 *      since the previous call);
 *   3. only if that sum is negative: recompute cof as
 *      -sum(dltwork * AdjWgt) / (2 * sum(dltwork^2));
 *   4. add cof * dltwork to both AdjWgt and Weight of every link.
 *********************************************************************/
void Ochi()
{
  int      lay, unit;
  double   delta, cof, norm;
  bps_ilin_t  *link_pt;
  double  *temp_yasui;
  /* pass 1: adapt the per-link learning coefficient */
  for (lay = 0; lay < bps_cont.NumOfLayer; lay++) {
    for (unit = 1; unit <= bps_cont.NumOfCell[lay]; unit++) {
      link_pt = Getintoplist(bps_cont.BPNet[lay][unit].CellNode);
      while (link_pt != NULL) {
        /* same sign: grow additively; opposite sign: shrink
           multiplicatively; zero product: leave unchanged */
        if (link_pt->dltold * (-link_pt->WgtWork) > 0.0)
          link_pt->CoefLearn += bps_cont.IncreaseFact;
        else {
          if (link_pt->dltold * (-link_pt->WgtWork) < 0.0)
            link_pt->CoefLearn *= bps_cont.ReductFact;
        }
        link_pt->dltold = (1.0 - bps_cont.Theta) * (-link_pt->WgtWork)
          + bps_cont.Theta * link_pt->dltold;
        link_pt = Getinfwdlist(link_pt);
      }
    }
  }

  /* pass 2: momentum step, structure-learning term, and the sum of
     dltworkold * dltwork over all links (kept in cof) */
  cof = 0.0;
  for (lay = 0; lay < bps_cont.NumOfLayer; lay++) {
    for (unit = 1; unit <= bps_cont.NumOfCell[lay]; unit++) {
      /* modified by higashi for Yasui structure learning:
         snapshot of this cell's incoming weights, taken before
         any of them is changed below */
      temp_yasui = temp_wgt(lay,unit);
      /*   end of modified     oct 1999                   */
      link_pt = Getintoplist(bps_cont.BPNet[lay][unit].CellNode);
      while (link_pt != NULL) {
        /* keep the previous step for the momentum term */
        link_pt->dltwgtold = link_pt->AdjWgt;

        /* AdjWgt is assigned first, then added to Weight */
        link_pt->Weight += link_pt->AdjWgt
          = link_pt->CoefLearn * link_pt->WgtWork
          + bps_cont.Momentum * link_pt->dltwgtold;
        
        /**** modified by higashi: structure-learning term *******/
        link_pt->Weight += structlearn(lay,unit,link_pt,temp_yasui);
        /*******************************/

        /* dltwork = change of -WgtWork since the previous call;
           cof must read dltworkold before it is overwritten */
        link_pt->dltwork = -link_pt->WgtWork - link_pt->wgtworkold;
        link_pt->wgtworkold = -link_pt->WgtWork;
        cof += link_pt->dltworkold * link_pt->dltwork;
        link_pt->dltworkold = link_pt->dltwork;
        link_pt = Getinfwdlist(link_pt);
      }
      efree(temp_yasui);
    }
  }

  /* passes 3 and 4 run only when the accumulated sum is negative */
  if (cof < 0.0) {
    /* pass 3: cof = -sum(dltwork * AdjWgt), norm = sum(dltwork^2) */
    cof = norm = 0.0;
    for (lay = 0; lay < bps_cont.NumOfLayer; lay++) {
      for (unit = 1; unit <= bps_cont.NumOfCell[lay]; unit++) {
        link_pt = Getintoplist(bps_cont.BPNet[lay][unit].CellNode);
        while (link_pt != NULL) {
          cof -= link_pt->dltwork * link_pt->AdjWgt;
          norm += link_pt->dltwork * link_pt->dltwork;
          link_pt = Getinfwdlist(link_pt);
        }
      }
    }

    cof = cof / (2.0 * norm);

    /* pass 4: add cof * dltwork to both the stored step and the
       weight of every link */
    for (lay = 0; lay < bps_cont.NumOfLayer; lay++) {
      for (unit = 1; unit <= bps_cont.NumOfCell[lay]; unit++) {
        link_pt = Getintoplist(bps_cont.BPNet[lay][unit].CellNode);
        while (link_pt != NULL) {
          /* delta is assigned first, then added to AdjWgt */
          link_pt->AdjWgt += delta
            = cof * link_pt->dltwork;
          link_pt->Weight += delta;
          link_pt = Getinfwdlist(link_pt);
        }
      }
    }
  }
}

#ifdef __cplusplus
}
#endif
