/*
 * Copyright (C) 2008 u6k.yu1@gmail.com, All Rights Reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 *    1. Redistributions of source code must retain the above copyright
 *       notice, this list of conditions and the following disclaimer.
 *
 *    2. Redistributions in binary form must reproduce the above copyright
 *       notice, this list of conditions and the following disclaimer in the
 *       documentation and/or other materials provided with the distribution.
 *
 *    3. Neither the name of Clarkware Consulting, Inc. nor the names of its
 *       contributors may be used to endorse or promote products derived
 *       from this software without prior written permission. For written
 *       permission, please contact clarkware@clarkware.com.
 *
 * THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESSED OR IMPLIED WARRANTIES,
 * INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
 * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL
 * CLARKWARE CONSULTING OR ITS CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
 * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA,
 * OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
 * LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
 * NEGLIGENCE OR OTHERWISE) ARISING IN  ANY WAY OUT OF THE USE OF THIS SOFTWARE,
 * EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */

package jp.gr.java_conf.u6k.simplenn;

import java.applet.Applet;
import java.awt.Button;
import java.awt.Color;
import java.awt.Graphics;
import java.awt.event.ActionEvent;
import java.awt.event.ActionListener;
import java.awt.event.MouseEvent;
import java.awt.event.MouseListener;
import java.awt.event.MouseMotionListener;
import java.io.BufferedReader;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.ObjectOutputStream;

/**
 * <p>
 * バックプロパゲーション法で学習するニューラル・ネットワークのデモ・アプレットです。このソースコードは、<a href="http://codezine.jp/">CodeZine</a>の記事「<a href="http://codezine.jp/a/article/aid/372.aspx">ニューラルネットワークを用いたパターン認識</a>」を参考にしています。
 * </p>
 * 
 * @version $Id$
 * @see http://codezine.jp/a/article/aid/372.aspx
 */
@SuppressWarnings("serial")
public final class SimpleNNApplet extends Applet implements MouseListener, MouseMotionListener, ActionListener {

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int X0           = 10;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int X1           = 125;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int Y0           = 55;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int Y1           = 70;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int Y2           = 160;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int Y3           = 240;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int Y4           = 305;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int RX0          = 30;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int RX1          = 60;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int RX2          = 210;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int RX3          = 260;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int RY0          = 225;

    /**
     * <p>
     * 文字列とかの表示基準座標。
     * </p>
     */
    private static final int RY1          = 240;

    /**
     * <p>
     * 入力データの幅。
     * </p>
     */
    private static final int WIDTH        = 7;

    /**
     * <p>
     * 入力データの高さ。
     * </p>
     */
    private static final int HEIGHT       = 11;

    /**
     * <p>
     * 入力層の数(入力データ数)。
     * </p>
     */
    private static final int INPUT        = SimpleNNApplet.WIDTH * SimpleNNApplet.HEIGHT;

    /**
     * <p>
     * 隠れ層の数。
     * </p>
     */
    private static final int HIDDEN       = 16;

    /**
     * <p>
     * パターンの種類。
     * </p>
     */
    private static final int PATTERN      = 10;

    /**
     * <p>
     * 出力層の数(出力データ数)。
     * </p>
     */
    private static final int OUTPUT       = SimpleNNApplet.PATTERN;

    /**
     * <p>
     * 外部サイクル(一連のパターンの繰返し学習)の回数。
     * </p>
     */
    private static final int OUTER_CYCLES = 100;

    /**
     * <p>
     * 内部サイクル(同一パターンの繰返し学習)の回数。
     * </p>
     */
    private static final int INNER_CYCLES = 100;

    /**
     * <p>
     * 「再学習」ボタン。
     * </p>
     */
    private Button           button1;

    /**
     * <p>
     * 「学習終了」ボタン。
     * </p>
     */
    private Button           button2;

    /**
     * <p>
     * 「入力クリア」ボタン。
     * </p>
     */
    private Button           button3;

    /**
     * <p>
     * 「認識」ボタン。
     * </p>
     */
    private Button           button4;

    /**
     * <p>
     * 「状態出力」ボタン。
     * </p>
     */
    private Button           button5;

    /**
     * <p>
     * 学習用入力。
     * </p>
     */
    private double[]         sampleIn     = new double[SimpleNNApplet.INPUT];

    /**
     * <p>
     * 認識用手書き入力。
     * </p>
     */
    private double[]         writtenIn    = new double[SimpleNNApplet.INPUT];

    /**
     * <p>
     * 認識出力(出力層の出力)。
     * </p>
     */
    private double[]         recognizeOut = new double[SimpleNNApplet.OUTPUT];

    /**
     * <p>
     * 教師信号。
     * </p>
     */
    private double[]         teach        = new double[SimpleNNApplet.PATTERN];

    /**
     * <p>
     * 「学習モード」フラグ。
     * </p>
     */
    private boolean          learningFlag;

    /**
     * <p>
     * 学習用入力データの基となるパターン。
     * </p>
     */
    private double[][]       sampleArray;

    /**
     * <p>
     * パターンと出力すべき教師信号の比較表。
     * </p>
     */
    private double[][]       teachArray   = new double[SimpleNNApplet.PATTERN][SimpleNNApplet.OUTPUT];

    /**
     * <p>
     * 手書き文字入力用座標。
     * </p>
     */
    private int              xNew;

    /**
     * <p>
     * 手書き文字入力用座標。
     * </p>
     */
    private int              yNew;

    /**
     * <p>
     * 手書き文字入力用座標。
     * </p>
     */
    private int              xOld;

    /**
     * <p>
     * 手書き文字入力用座標。
     * </p>
     */
    private int              yOld;

    /**
     * <p>
     * ニューラル・ネットワーク。
     * </p>
     */
    private SimpleNN         simpleNN     = new SimpleNN(SimpleNNApplet.INPUT, SimpleNNApplet.HIDDEN, SimpleNNApplet.OUTPUT, 1.2);

    /**
     * <p>
     * 新しいインスタンスを初期化します。
     * </p>
     */
    public SimpleNNApplet() {
    }

    /**
     * <p>
     * アプレットを初期化します。
     * </p>
     */
    public void init() {
        // 学習用入力データの元となるパターンの読み込み
        this.sampleArray = new double[SimpleNNApplet.PATTERN][SimpleNNApplet.WIDTH * SimpleNNApplet.HEIGHT];
        for (int i = 0; i < SimpleNNApplet.PATTERN; i++) {
            BufferedReader r = new BufferedReader(new InputStreamReader(this.getClass().getClassLoader().getResourceAsStream(i + ".txt")));
            try {
                try {
                    int j = 0;
                    String line;
                    while ((line = r.readLine()) != null) {
                        for (char c : line.toCharArray()) {
                            this.sampleArray[i][j] = Integer.parseInt(Character.toString(c));
                            j++;
                        }
                    }
                } finally {
                    r.close();
                }
            } catch (IOException e) {
                throw new RuntimeException(e);
            }
        }

        this.setBackground(Color.gray);

        // ボタンの設定
        this.button1 = new Button("  再学習  ");
        this.button2 = new Button(" 学習終了 ");
        this.button3 = new Button("入力クリヤ");
        this.button4 = new Button("  認  識  ");
        this.button5 = new Button(" 状態出力 ");
        this.add(this.button1);
        this.add(this.button2);
        this.add(this.button3);
        this.add(this.button4);
        this.add(this.button5);
        this.button1.addActionListener(this);
        this.button2.addActionListener(this);
        this.button3.addActionListener(this);
        this.button4.addActionListener(this);
        this.button5.addActionListener(this);

        // マウスの設定
        this.addMouseListener(this);
        this.addMouseMotionListener(this);

        // 教師信号の設定
        for (int q = 0; q < SimpleNNApplet.PATTERN; q++) {
            for (int k = 0; k < SimpleNNApplet.OUTPUT; k++) {
                if (q == k) {
                    this.teachArray[q][k] = 1;
                } else {
                    this.teachArray[q][k] = 0;
                }
            }
        }

        // モードの初期設定
        this.learningFlag = true;
    }

    /**
     * <p>
     * クリックしたボタンごとの処理を行います。
     * </p>
     * 
     * @param ae
     *            イベント情報。
     */
    public void actionPerformed(ActionEvent ae) {
        if (ae.getSource() == this.button1) {
            // 「再学習」
            this.learningFlag = true;
            this.repaint();
        }
        if (ae.getSource() == this.button2) {
            // 「学習終了」
            this.learningFlag = false;
            this.repaint();
        }
        if (ae.getSource() == this.button3) {
            // 「入力クリヤ」
            if (!this.learningFlag) {
                this.repaint();
            }
        }
        if (ae.getSource() == this.button4) {
            // 「認識」
            if (!this.learningFlag) {
                this.recognizeCharacter();
            }
        }
        if (ae.getSource() == this.button5) {
            // 「状態出力」
            System.out.println("状態出力開始。");
            try {
                ObjectOutputStream oout = new ObjectOutputStream(new FileOutputStream("C:/simplenn.dmp"));
                try {
                    oout.writeObject(this.simpleNN);
                } finally {
                    oout.close();
                }
            } catch (IOException e) {
                e.printStackTrace();
            }
            System.out.println("状態出力完了。");
        }
    }

    /**
     * <p>
     * 手書き入力を開始したときの処理を行います。
     * </p>
     * 
     * @param me
     *            イベント情報。
     */
    public void mousePressed(MouseEvent me) {
        int x = me.getX();
        int y = me.getY();
        if (!this.learningFlag && x >= SimpleNNApplet.RX1 && x <= SimpleNNApplet.RX1 + SimpleNNApplet.WIDTH * SimpleNNApplet.PATTERN && y >= SimpleNNApplet.RY1 && y <= SimpleNNApplet.RY1 + SimpleNNApplet.HEIGHT * SimpleNNApplet.PATTERN) {
            this.xOld = me.getX();
            this.yOld = me.getY();
            this.writtenIn[(this.yOld - SimpleNNApplet.RY1) / SimpleNNApplet.PATTERN * SimpleNNApplet.WIDTH + (this.xOld - SimpleNNApplet.RX1) / SimpleNNApplet.PATTERN] = 1;
        }
    }

    /**
     * <p>
     * 何もしません。
     * </p>
     * 
     * @param me
     *            イベント情報。
     */
    public void mouseClicked(MouseEvent me) {
    }

    /**
     * <p>
     * 何もしません。
     * </p>
     * 
     * @param me
     *            イベント情報。
     */
    public void mouseEntered(MouseEvent me) {
    }

    /**
     * <p>
     * 何もしません。
     * </p>
     * 
     * @param me
     *            イベント情報。
     */
    public void mouseExited(MouseEvent me) {
    }

    /**
     * <p>
     * 何もしません。
     * </p>
     * 
     * @param me
     *            イベント情報。
     */
    public void mouseReleased(MouseEvent me) {
    }

    /**
     * <p>
     * 手書き入力中にマウスを動かしたときの処理を行います。
     * </p>
     * 
     * @param me
     *            イベント情報。
     */
    public void mouseDragged(MouseEvent me) {
        int x = me.getX();
        int y = me.getY();
        if (!this.learningFlag && x >= SimpleNNApplet.RX1 && x <= SimpleNNApplet.RX1 + SimpleNNApplet.WIDTH * SimpleNNApplet.PATTERN && y >= SimpleNNApplet.RY1 && y <= SimpleNNApplet.RY1 + SimpleNNApplet.HEIGHT * SimpleNNApplet.PATTERN) {
            Graphics g = this.getGraphics();
            this.xNew = me.getX();
            this.yNew = me.getY();
            g.drawLine(this.xOld, this.yOld, this.xNew, this.yNew);
            this.xOld = this.xNew;
            this.yOld = this.yNew;
            this.writtenIn[(this.yOld - SimpleNNApplet.RY1) / SimpleNNApplet.PATTERN * SimpleNNApplet.WIDTH + (this.xOld - SimpleNNApplet.RX1) / SimpleNNApplet.PATTERN] = 1;
        }
    }

    /**
     * <p>
     * 何もしません。
     * </p>
     * 
     * @param me
     *            イベント情報。
     */
    public void mouseMoved(MouseEvent me) {
    }

    /**
     * <p>
     * 画面に表示します。
     * </p>
     * 
     * @param g
     *            表示対象。
     */
    public void paint(Graphics g) {
        int i;
        int j;
        int k;
        int p;
        int q;
        int r;
        int x;

        String string;

        // 外部サイクルエラー累計
        double outerError;
        // 内部サイクルエラー累計
        double innerError;

        if (this.learningFlag) {
            // 学習モードの背景
            g.setColor(new Color(255, 255, 192));
            g.fillRect(5, 35, 590, 460);
            g.setColor(Color.black);
            g.drawString("学習モード", 500, 55);
        } else {
            // 認識モードの背景
            g.setColor(new Color(192, 255, 255));
            g.fillRect(5, 35, 590, 460);
            g.setColor(Color.black);
            g.drawString("認識モード", 500, 55);
        }

        // 学習用パターンの表示
        g.drawString("使用している学習用パターン", SimpleNNApplet.X0, SimpleNNApplet.Y0);
        for (q = 0; q < SimpleNNApplet.PATTERN; q++) {
            x = 56 * q;
            for (j = 0; j < SimpleNNApplet.HEIGHT; j++) {
                for (i = 0; i < SimpleNNApplet.WIDTH; i++) {
                    if (this.sampleArray[q][SimpleNNApplet.WIDTH * j + i] == 1) {
                        g.setColor(Color.red);
                    } else {
                        g.setColor(Color.cyan);
                    }
                    g.fillRect(SimpleNNApplet.X0 + x + 6 * i, SimpleNNApplet.Y1 + 6 * j, 5, 5);
                }
            }
        }
        g.setColor(Color.black);

        if (this.learningFlag) {
            // 学習モード

            // 閾値と重みの乱数設定
            this.simpleNN = new SimpleNN(SimpleNNApplet.INPUT, SimpleNNApplet.HIDDEN, SimpleNNApplet.OUTPUT, 1.2);

            // -------------------------- 学習 --------------------------
            for (p = 0; p < SimpleNNApplet.OUTER_CYCLES; p++) {
                // 外部サイクル

                // 外部二乗誤差のクリヤー
                outerError = 0;

                for (q = 0; q < SimpleNNApplet.PATTERN; q++) {
                    // パターンの切り替え

                    // パターンに対応した入力と教師信号の設定
                    this.sampleIn = this.sampleArray[q];
                    this.teach = this.teachArray[q];

                    for (r = 0; r < SimpleNNApplet.INNER_CYCLES; r++) {
                        // 内部サイクル

                        this.simpleNN.learn(this.sampleIn, this.teach);
                    }

                    // 内部二乗誤差の計算
                    innerError = this.simpleNN.reportError(this.sampleIn, this.teach);

                    // 外部二乗誤差への累加算
                    outerError += innerError;

                }

                // 外部サイクルの回数と外部二乗誤差の表示
                g.drawString("実行中の外部サイクルの回数と二乗誤差", SimpleNNApplet.X0, SimpleNNApplet.Y2);
                g.setColor(new Color(255, 255, 192));
                g.fillRect(SimpleNNApplet.X0 + 5, SimpleNNApplet.Y2 + 10, 200, 50);
                g.setColor(Color.black);
                g.drawString("OuterCycles=" + String.valueOf(p), SimpleNNApplet.X0 + 10, SimpleNNApplet.Y2 + 25);
                g.drawString("TotalSquaredError=" + String.valueOf(outerError), SimpleNNApplet.X0 + 10, SimpleNNApplet.Y2 + 45);

            }

            // --------------------- 学習結果の確認 ---------------------
            g.drawString("学習結果の確認", SimpleNNApplet.X0, SimpleNNApplet.Y3);
            for (k = 0; k < SimpleNNApplet.OUTPUT; k++) {
                g.drawString("Output", SimpleNNApplet.X1 + 45 * k, SimpleNNApplet.Y3 + 25);
                g.drawString("  [" + String.valueOf(k) + "]", SimpleNNApplet.X1 + 5 + 45 * k, SimpleNNApplet.Y3 + 40);
            }

            for (q = 0; q < SimpleNNApplet.PATTERN; q++) {

                // 入力パターンの設定
                this.sampleIn = this.sampleArray[q];

                // 順方向演算
                this.recognizeOut = this.simpleNN.calculate(this.sampleIn);

                // 結果の表示
                g.setColor(Color.black);
                g.drawString("TestPattern[" + String.valueOf(q) + "]", SimpleNNApplet.X0 + 10, SimpleNNApplet.Y4 + 20 * q);
                for (k = 0; k < SimpleNNApplet.OUTPUT; k++) {
                    if (this.recognizeOut[k] > 0.99) {
                        // 99% より大は、赤で YES と表示
                        g.setColor(Color.red);
                        string = "YES";
                    } else if (this.recognizeOut[k] < 0.01) {
                        // 1% より小は、青で NO と表示
                        g.setColor(Color.blue);
                        string = "NO ";
                    } else {
                        // 1% 以上 99% 以下は、黒で ? と表示
                        g.setColor(Color.black);
                        string = " ? ";
                    }
                    g.drawString(string, SimpleNNApplet.X1 + 10 + 45 * k, SimpleNNApplet.Y4 + 20 * q);
                }

            }
        } else {
            // 認識モード
            g.setColor(Color.black);
            g.drawString("マウスで数字を描いて下さい", SimpleNNApplet.RX0, SimpleNNApplet.RY0);
            // 外枠
            g.drawRect(SimpleNNApplet.RX1 - 1, SimpleNNApplet.RY1 - 1, SimpleNNApplet.WIDTH * 10 + 2, SimpleNNApplet.HEIGHT * 10 + 2);
            g.setColor(Color.gray);
            for (j = 1; j < SimpleNNApplet.HEIGHT; j++) {
                // 横方向区切り
                g.drawLine(SimpleNNApplet.RX1, SimpleNNApplet.RY1 + 10 * j, SimpleNNApplet.RX1 + SimpleNNApplet.WIDTH * 10, SimpleNNApplet.RY1 + 10 * j);
            }
            for (i = 1; i < SimpleNNApplet.WIDTH; i++) {
                // 縦方向区切り
                g.drawLine(SimpleNNApplet.RX1 + 10 * i, SimpleNNApplet.RY1, SimpleNNApplet.RX1 + 10 * i, SimpleNNApplet.RY1 + SimpleNNApplet.HEIGHT * 10);
            }
            for (i = 0; i < SimpleNNApplet.INPUT; i++) {
                // 手書き入力データのクリヤ
                this.writtenIn[i] = 0;
            }
        }
    }

    /**
     * <p>
     * 手書き入力された文字を認識します。
     * </p>
     */
    public void recognizeCharacter() {
        Graphics g = this.getGraphics();

        // 順方向演算
        this.recognizeOut = this.simpleNN.calculate(this.writtenIn);

        // 結果の表示
        for (int k = 0; k < SimpleNNApplet.OUTPUT; k++) {
            g.setColor(Color.black);
            g.drawString(String.valueOf(k) + "である", SimpleNNApplet.RX2, SimpleNNApplet.RY1 + 20 * k);
            if (this.recognizeOut[k] > 0.8) {
                g.setColor(Color.red);
            } else {
                g.setColor(Color.black);
            }

            g.fillRect(SimpleNNApplet.RX3, SimpleNNApplet.RY1 - 10 + 20 * k, (int) (200 * this.recognizeOut[k]), 10);
            g.drawString(String.valueOf((int) (100 * this.recognizeOut[k] + 0.5)) + "%", SimpleNNApplet.RX3 + (int) (200 * this.recognizeOut[k]) + 10, SimpleNNApplet.RY1 + 20 * k);
        }
    }

}
