/*
 * Copyright (C) 2008 u6k.yu1@gmail.com, All Rights Reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *    1. Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *
 *    2. Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *
 *    3. Neither the name of Clarkware Consulting, Inc. nor the names of its
 *       contributors may be used to endorse or promote products derived
 *       from this software without prior written permission. For written
 *       permission, please contact clarkware@clarkware.com.
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL
 * CLARKWARE CONSULTING OR ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
 * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN  ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package jp.gr.java_conf.u6k.simplenn;

import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;

/**
 * <p>
 * 簡単に使用できるニューラル・ネットワークの実装クラスです。学習方法はバックプロパゲーション法です。
 * </p>
 * <p>
 * このソースコードは、<a href="http://codezine.jp/">CodeZine</a>の記事「<a href="http://codezine.jp/a/article/aid/372.aspx">ニューラルネットワークを用いたパターン認識</a>」を参考にしています。
 * </p>
 * 
 * @version $Id$
 * @see http://codezine.jp/a/article/aid/372.aspx
 */
public final class SimpleNN implements Externalizable {

    /**
     * <p>
     * ネットワークを流れる値の上限(1)の半分。
     * </p>
     */
    private static final double VALUE_HALF = 0.5;

    /**
     * <p>
     * 入力層のニューロン数。
     * </p>
     */
    private int                 inputNumber;

    /**
     * <p>
     * 隠れ層のニューロン数。
     * </p>
     */
    private int                 hiddenNumber;

    /**
     * <p>
     * 出力層のニューロン数。
     * </p>
     */
    private int                 outputNumber;

    /**
     * <p>
     * 学習係数(learning coefficient)。
     * </p>
     */
    private double              learningCoefficient;

    /**
     * <p>
     * 入力層と隠れ層の間の重み係数。
     * </p>
     */
    private double[]            weightInHidden;

    /**
     * <p>
     * 隠れ層の閾値。
     * </p>
     */
    private double[]            thresholdHidden;

    /**
     * <p>
     * 隠れ層と出力層の間の重み係数。
     * </p>
     */
    private double[]            weightHiddenOut;

    /**
     * <p>
     * 出力層の閾値。
     * </p>
     */
    private double[]            thresholdOut;

    /**
     * <p>
     * このコンストラクタは呼び出さないでください。デシリアライズのために定義しています。
     * </p>
     */
    public SimpleNN() {
    }

    /**
     * <p>
     * ニューラル・ネットワークの状態(閾値、重み)を初期化します。
     * </p>
     * 
     * @param inputNumber
     *            入力層のニューロン数。
     * @param hiddenNumber
     *            隠れ層のニューロン数。
     * @param outputNumber
     *            出力層のニューロン数。
     * @param learningCoefficient
     *            学習係数。
     * @throws IllegalArgumentException
     *             inputNumber引数、hiddenNumber引数、outputNumber引数、learningCoefficient引数が0以下の場合。
     */
    public SimpleNN(int inputNumber, int hiddenNumber, int outputNumber, double learningCoefficient) {
        /*
         * 引数を確認します。
         */
        if (inputNumber <= 0) {
            throw new IllegalArgumentException("inputNumber <= 0");
        }
        if (hiddenNumber <= 0) {
            throw new IllegalArgumentException("hiddenNumber <= 0");
        }
        if (outputNumber <= 0) {
            throw new IllegalArgumentException("outputNumber <= 0");
        }
        if (learningCoefficient <= 0) {
            throw new IllegalArgumentException("learningCoefficient <= 0");
        }

        /*
         * ニューラル・ネットワークの状態を初期化します。
         */
        this.thresholdHidden = new double[hiddenNumber];
        this.weightInHidden = new double[inputNumber * hiddenNumber];
        this.thresholdOut = new double[outputNumber];
        this.weightHiddenOut = new double[hiddenNumber * outputNumber];
        for (int i = 0; i < hiddenNumber; i++) {
            this.thresholdHidden[i] = Math.random() - SimpleNN.VALUE_HALF;
            for (int j = 0; j < inputNumber; j++) {
                this.weightInHidden[j * this.hiddenNumber + i] = Math.random() - SimpleNN.VALUE_HALF;
            }
        }
        for (int i = 0; i < outputNumber; i++) {
            this.thresholdOut[i] = Math.random() - SimpleNN.VALUE_HALF;
            for (int j = 0; j < hiddenNumber; j++) {
                this.weightHiddenOut[j * this.outputNumber + i] = Math.random() - SimpleNN.VALUE_HALF;
            }
        }

        this.inputNumber = inputNumber;
        this.hiddenNumber = hiddenNumber;
        this.outputNumber = outputNumber;
        this.learningCoefficient = learningCoefficient;
    }

    /**
     * <p>
     * 入力データと教師信号を用いて、ニューラル・ネットワークの状態を更新します(学習します)。
     * </p>
     * 
     * @param input
     *            入力データ。
     * @param teach
     *            教師信号。
     * @throws NullPointerException
     *             input引数、teach引数がnullの場合。
     * @throws IllegalArgumentException
     *             input引数の配列要素数が入力層のニューロン数と異なる場合。teach引数の配列要素数が出力層のニューロン数と異なる場合。
     */
    public void learn(double[] input, double[] teach) {
        double[] output = new double[this.outputNumber];
        double[] hiddenOutput = new double[this.hiddenNumber];

        this.calcForward(input, output, hiddenOutput);
        this.calcBackward(input, output, hiddenOutput, teach);
    }

    /**
     * <p>
     * 入力データをニューラル・ネットワークを用いて計算します。
     * </p>
     * 
     * @param input
     *            入力データ。
     * @return 計算結果の出力データ。
     * @throws NullPointerException
     *             input引数がnullの場合。
     * @throws IllegalArgumentException
     *             input引数の配列要素数が入力層のニューロン数と異なる場合。
     */
    public double[] calculate(double[] input) {
        double[] output = new double[this.outputNumber];

        this.calcForward(input, output, new double[this.hiddenNumber]);

        return output;
    }

    /**
     * <p>
     * 入力データから導き出される出力データと教師信号とのずれを表す、二乗誤差を算出します。
     * </p>
     * 
     * @param input
     *            入力データ。
     * @param teach
     *            教師信号。
     * @return 二乗誤差。
     * @throws NullPointerException
     *             input引数、teach引数がnullの場合。
     * @throws IllegalArgumentException
     *             input引数の配列要素数が入力層のニューロン数と異なる場合。teach引数の配列要素数が出力層のニューロン数と異なる場合。
     */
    public double reportError(double[] input, double[] teach) {
        double[] output = this.calculate(input);
        double err = this.calcError(output, teach);

        return err;
    }

    /**
     * <p>
     * 順方向演算を行います。
     * </p>
     * 
     * @param input
     *            入力データ。
     * @param output
     *            順方向演算の結果を格納する配列。
     * @param hiddenOutput
     *            順方向演算の過程の隠れ層出力を格納する配列。
     * @throws NullPointerException
     *             input引数、output引数、hiddenOutput引数がnullの場合。
     * @throws IllegalArgumentException
     *             input引数の配列要素数が入力層のニューロン数と異なる場合。output引数の配列要素数が出力層のニューロン数と異なる場合。hiddenOutput引数の配列要素数が隠れ層のニューロン数と異なる場合。
     */
    private void calcForward(double[] input, double[] output, double[] hiddenOutput) {
        /*
         * 引数を確認します。
         */
        if (input == null) {
            throw new NullPointerException("input == null");
        }
        if (output == null) {
            throw new NullPointerException("output == null");
        }
        if (hiddenOutput == null) {
            throw new NullPointerException("hiddenOutput == null");
        }
        if (input.length != this.inputNumber) {
            throw new IllegalArgumentException("input.length != inputNumber");
        }
        if (output.length != this.outputNumber) {
            throw new IllegalArgumentException("output.length != outputNumber");
        }
        if (hiddenOutput.length != this.hiddenNumber) {
            throw new IllegalArgumentException("hiddenOutput.length != hiddenNumber");
        }

        /*
         * 隠れ層の出力を計算します。
         */
        for (int i = 0; i < hiddenOutput.length; i++) {
            hiddenOutput[i] = -this.thresholdHidden[i];
            for (int j = 0; j < input.length; j++) {
                hiddenOutput[i] += input[j] * this.weightInHidden[j * this.hiddenNumber + i];
            }
            hiddenOutput[i] = this.sigmoid(hiddenOutput[i]);
        }

        /*
         * 出力層の出力を計算します。
         */
        for (int i = 0; i < output.length; i++) {
            output[i] = -this.thresholdOut[i];
            for (int j = 0; j < hiddenOutput.length; j++) {
                output[i] += hiddenOutput[j] * this.weightHiddenOut[j * this.outputNumber + i];
            }
            output[i] = this.sigmoid(output[i]);
        }
    }

    /**
     * <p>
     * 逆方向演算を行います。
     * </p>
     * 
     * @param input
     *            順方向演算の入力データ。
     * @param output
     *            順方向演算の結果。
     * @param hiddenOutput
     *            順方向演算の過程の隠れ層出力を格納する配列。
     * @param teach
     *            教師信号。
     * @throws NullPointerException
     *             input引数、output引数、hiddenOutput引数、teach引数がnullの場合。
     * @throws IllegalArgumentException
     *             input引数の配列要素数が入力層のニューロン数と異なる場合。output引数の配列要素数が出力層のニューロン数と異なる場合。hiddenOutput引数の配列要素数が隠れ層のニューロン数と異なる場合。teach引数の配列要素数が出力層のニューロン数と異なる場合。
     */
    private void calcBackward(double[] input, double[] output, double[] hiddenOutput, double[] teach) {
        /*
         * 引数を確認します。
         */
        if (input == null) {
            throw new NullPointerException("input == null");
        }
        if (output == null) {
            throw new NullPointerException("output == null");
        }
        if (hiddenOutput == null) {
            throw new NullPointerException("hiddenOutput == null");
        }
        if (teach == null) {
            throw new NullPointerException("teach == null");
        }
        if (input.length != this.inputNumber) {
            throw new IllegalArgumentException("input.length != inputNumber");
        }
        if (output.length != this.outputNumber) {
            throw new IllegalArgumentException("output.length != outputNumber");
        }
        if (hiddenOutput.length != this.hiddenNumber) {
            throw new IllegalArgumentException("hiddenOutput.length != hiddenNumber");
        }
        if (teach.length != this.outputNumber) {
            throw new IllegalArgumentException("teach.length != outputNumber");
        }

        /*
         * 出力層の誤差を計算します。
         */
        double[] outputError = new double[output.length];
        for (int i = 0; i < outputError.length; i++) {
            outputError[i] = (teach[i] - output[i]) * output[i] * (1.0 - output[i]);
        }

        /*
         * 隠れ層の誤差を計算します。
         */
        double[] hiddenError = new double[hiddenOutput.length];
        for (int i = 0; i < hiddenError.length; i++) {
            double err = 0;
            for (int j = 0; j < output.length; j++) {
                err += outputError[j] * this.weightHiddenOut[i * this.outputNumber + j];
            }
            hiddenError[i] = hiddenOutput[i] * (1.0 - hiddenOutput[i]) * err;
        }

        /*
         * 重みを補正します。
         */
        for (int i = 0; i < outputError.length; i++) {
            for (int j = 0; j < hiddenOutput.length; j++) {
                this.weightHiddenOut[j * this.outputNumber + i] += this.learningCoefficient * outputError[i] * hiddenOutput[j];
            }
        }
        for (int i = 0; i < hiddenError.length; i++) {
            for (int j = 0; j < input.length; j++) {
                this.weightInHidden[j * this.hiddenNumber + i] += this.learningCoefficient * hiddenError[i] * input[j];
            }
        }

        /*
         * 閾値を補正します。
         */
        for (int i = 0; i < this.thresholdOut.length; i++) {
            this.thresholdOut[i] -= this.learningCoefficient * outputError[i];
        }
        for (int i = 0; i < this.thresholdHidden.length; i++) {
            this.thresholdHidden[i] -= this.learningCoefficient * hiddenError[i];
        }
    }

    /**
     * <p>
     * 順方向演算の結果と教師信号とのずれを表す二乗誤差を計算します。
     * </p>
     * 
     * @param output
     *            順方向演算の結果。
     * @param teach
     *            教師信号。
     * @return 二乗誤差。
     * @throws NullPointerException
     *             output引数、teach引数がnullの場合。
     */
    private double calcError(double[] output, double[] teach) {
        /*
         * 引数を確認します。
         */
        if (output == null) {
            throw new NullPointerException("output == null");
        }
        if (teach == null) {
            throw new NullPointerException("teach == null");
        }

        /*
         * 二乗誤差を計算します。
         */
        double error = 0;
        for (int i = 0; i < output.length; i++) {
            error += (teach[i] - output[i]) * (teach[i] - output[i]);
        }

        return error;
    }

    /**
     * <p>
     * シグモイド関数です。
     * </p>
     * 
     * @param x
     *            引数。
     * @return 計算結果。
     */
    private double sigmoid(double x) {
        return 1.0 / (1.0 + Math.exp(-x));
    }

    /**
     * {@inheritDoc}
     */
    public void readExternal(ObjectInput in) throws IOException, ClassNotFoundException {
        this.inputNumber = in.readInt();
        this.hiddenNumber = in.readInt();
        this.outputNumber = in.readInt();
        this.learningCoefficient = in.readDouble();

        this.weightInHidden = new double[this.inputNumber * this.hiddenNumber];
        for (int i = 0; i < this.weightInHidden.length; i++) {
            this.weightInHidden[i] = in.readDouble();
        }

        this.thresholdHidden = new double[this.hiddenNumber];
        for (int i = 0; i < this.thresholdHidden.length; i++) {
            this.thresholdHidden[i] = in.readDouble();
        }

        this.weightHiddenOut = new double[this.hiddenNumber * this.outputNumber];
        for (int i = 0; i < this.weightHiddenOut.length; i++) {
            this.weightHiddenOut[i] = in.readDouble();
        }

        this.thresholdOut = new double[this.outputNumber];
        for (int i = 0; i < this.thresholdOut.length; i++) {
            this.thresholdOut[i] = in.readDouble();
        }
    }

    /**
     * {@inheritDoc}
     */
    public void writeExternal(ObjectOutput out) throws IOException {
        out.writeInt(this.inputNumber);
        out.writeInt(this.hiddenNumber);
        out.writeInt(this.outputNumber);
        out.writeDouble(this.learningCoefficient);

        for (int i = 0; i < this.weightInHidden.length; i++) {
            out.writeDouble(this.weightInHidden[i]);
        }

        for (int i = 0; i < this.thresholdHidden.length; i++) {
            out.writeDouble(this.thresholdHidden[i]);
        }

        for (int i = 0; i < this.weightHiddenOut.length; i++) {
            out.writeDouble(this.weightHiddenOut[i]);
        }

        for (int i = 0; i < this.thresholdOut.length; i++) {
            out.writeDouble(this.thresholdOut[i]);
        }
    }

    /**
     * {@inheritDoc}
     */
    public String toString() {
        StringBuffer s = new StringBuffer();

        s.append(super.toString() + "[\r\n");

        s.append("\tinputNumber=" + this.inputNumber + "\r\n");
        s.append("\thiddenNumber=" + this.hiddenNumber + "\r\n");
        s.append("\toutputNumber=" + this.outputNumber + "\r\n");
        s.append("\tlearningCoefficient=" + this.learningCoefficient + "\r\n");

        for (int i = 0; i < this.weightInHidden.length; i++) {
            s.append("\tweightInHidden[" + i + "]=" + this.weightInHidden[i] + "\r\n");
        }

        for (int i = 0; i < this.thresholdHidden.length; i++) {
            s.append("\tthresholdHidden[" + i + "]=" + this.thresholdHidden[i] + "\r\n");
        }

        for (int i = 0; i < this.weightHiddenOut.length; i++) {
            s.append("\tweightHiddenOut[" + i + "]=" + this.weightHiddenOut[i] + "\r\n");
        }

        for (int i = 0; i < this.thresholdOut.length; i++) {
            s.append("\tthresholdOut[" + i + "]=" + this.thresholdOut[i] + "\r\n");
        }

        s.append("]");

        return s.toString();
    }

}
