#!/usr/bin/env python
# -*- coding: utf-8 -*-
import sys
import os
import tensorflow as tf
import time

# 入力画像の幅・高さ・チャネル数
INPUT_WIDTH = 100
INPUT_HEIGHT = 100
INPUT_CHANNELS = 3
INPUT_SIZE = INPUT_WIDTH * INPUT_HEIGHT * INPUT_CHANNELS

# 1つめの畳み込み-プーリング層のパラメータ
CONV1_SIZE = 5              # 畳み込みフィルタのサイズ
CONV1_STRIDE = [1, 1, 1, 1] # 畳み込みフィルタのストライド
CONV1_CHANNELS = 32         # 畳み込み層の出力チャネル数
POOL1_SIZE = [1, 2, 2, 1]   # プーリング層のウィンドウサイズ
POOL1_STRIDE = [1, 2, 2, 1] # プーリングのストライド

# 2つめの畳み込み-プーリング層のパラメータ
CONV2_SIZE = 5              # 畳み込みフィルタのサイズ
CONV2_STRIDE = [1, 1, 1, 1] # 畳み込みフィルタのストライド
CONV2_CHANNELS = 32         # 畳み込み層の出力チャネル数
POOL2_SIZE = [1, 2, 2, 1]   # プーリング層のウィンドウサイズ
POOL2_STRIDE = [1, 2, 2, 1] # プーリングのストライド

# 全結合層のサイズ
W5_SIZE=25 * 25 * CONV2_CHANNELS

# 出力サイズ
OUTPUT_SIZE = 3
LABEL_SIZE = OUTPUT_SIZE

TEACH_FILES = ["../data2/teach_cat.tfrecord",
               "../data2/teach_dog.tfrecord",
               "../data2/teach_monkey.tfrecord"]
TEST_FILES = ["../data2/test_cat.tfrecord",
              "../data2/test_dog.tfrecord",
              "../data2/test_monkey.tfrecord"]
MODEL_FILE = "./cnn_model"


# 結果をそろえるために乱数の種を指定
tf.set_random_seed(1111)

## 入力と計算グラフを定義
with tf.variable_scope('model') as scope:

    # 入力（＝第1層）および正答を入力するプレースホルダを定義
    x1 = tf.placeholder(dtype=tf.float32, name="x1")
    y = tf.placeholder(dtype=tf.float32, name="y")

    # ドロップアウト設定用のプレースホルダ
    enable_dropout = tf.placeholder_with_default(0.0, [], name="enable_dropout")

    # ドロップアウト確率
    prob_one = tf.constant(1.0, dtype=tf.float32)

    # enable_dropoutが0の場合、キープ確率は1。そうでない場合、一定の確率に設定する
    x5_keep_prob = prob_one - enable_dropout * 0.5

    # 第2層（畳み込み処理）
    W1 = tf.get_variable("W1",
                         shape=[CONV1_SIZE, CONV1_SIZE, INPUT_CHANNELS, CONV1_CHANNELS],
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(stddev=0.01))

    b1 = tf.get_variable("b1",
                         shape=[CONV1_CHANNELS],
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(stddev=0.01))
    u1 = tf.nn.bias_add(tf.nn.conv2d(x1, W1, CONV1_STRIDE, "SAME"), b1, name="u1")
    x2 = tf.nn.relu(u1, name="x2")

    # 第3層（プーリング層）
    x3 = tf.nn.max_pool(x2, POOL1_SIZE, POOL1_STRIDE, "SAME", name="x3")

    # 第4層（畳み込み処理）
    W3 = tf.get_variable("W3",
                         shape=[CONV2_SIZE, CONV2_SIZE, CONV1_CHANNELS, CONV2_CHANNELS],
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(stddev=0.01))

    b3 = tf.get_variable("b3",
                         shape=[CONV2_CHANNELS],
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(stddev=0.01))
    u3 = tf.nn.bias_add(tf.nn.conv2d(x3, W3, CONV1_STRIDE, "SAME"), b3, name="u3")
    x4 = tf.nn.relu(u3, name="x4")

    # 第5層（プーリング層）
    x5 = tf.nn.max_pool(x4, POOL2_SIZE, POOL2_STRIDE, "SAME", name="x5")
    x5_ = tf.reshape(x5, [-1, W5_SIZE], name="x5_")
    x5_drop = tf.nn.dropout(x5_, x5_keep_prob, name="x5_drop")

    # 第6層（出力層）
    W5 = tf.get_variable("W5",
                         shape=[W5_SIZE, OUTPUT_SIZE],
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(stddev=0.01))
    b5 = tf.get_variable("b5",
                         shape=[OUTPUT_SIZE],
                         dtype=tf.float32,
                         initializer=tf.random_normal_initializer(stddev=0.01))
    x6 = tf.nn.softmax(tf.matmul(x5_drop, W5) + b5, name="x6")

    # コスト関数
    cross_entropy = -tf.reduce_sum(y * tf.log(x6), name="cross_entropy")
    tf.summary.scalar('cross_entropy', cross_entropy)

    # 正答率
    correct = tf.equal(tf.argmax(x6,1), tf.argmax(y, 1), name="correct")
    accuracy = tf.reduce_mean(tf.cast(correct, "float"), name="accuracy")
    tf.summary.scalar('accuracy', accuracy)


    # 最適化アルゴリズムを定義
    global_step = tf.Variable(0, name='global_step', trainable=False)
    optimizer = tf.train.AdamOptimizer(1e-4, name="optimizer")
    minimize = optimizer.minimize(cross_entropy, global_step=global_step, name="minimize")

    # 学習結果を保存するためのオブジェクトを用意
    saver = tf.train.Saver()

# 読み込んだデータの変換用関数
def map_dataset(serialized):
    features = {
        'label':     tf.FixedLenFeature([], tf.int64),
        'height':    tf.FixedLenFeature([], tf.int64),
        'width':     tf.FixedLenFeature([], tf.int64),
        'raw_image': tf.FixedLenFeature([INPUT_SIZE], tf.float32),
    }
    parsed = tf.parse_single_example(serialized, features)

    # 読み込んだデータを変換する
    raw_label = tf.cast(parsed['label'], tf.int32)
    label = tf.reshape(tf.slice(tf.eye(LABEL_SIZE),
                                [raw_label, 0],
                                [1, LABEL_SIZE]),
                       [LABEL_SIZE])

    image = tf.reshape(parsed['raw_image'], tf.stack([parsed['height'], parsed['width'], 3]))
    return (image, label, raw_label)

## データセットの読み込み
# 読み出すデータは各データ200件ずつ×3で計600件
dataset_size = tf.placeholder(shape=[], dtype=tf.int64)
dataset = tf.data.TFRecordDataset(TEACH_FILES)\
                 .map(map_dataset)\
                 .repeat()\
                 .shuffle(600)\
                 .batch(dataset_size)

# データにアクセスするためのイテレータを作成
iterator = dataset.make_initializable_iterator()
next_dataset = iterator.get_next()

# セッションの作成
sess = tf.Session()

# 変数の初期化を実行する
sess.run(tf.global_variables_initializer())

 # 学習結果を保存したファイルが存在するかを確認し、
 # 存在していればそれを読み出す
latest_filename = tf.train.latest_checkpoint("./")
if latest_filename:
    print("load saved model {}".format(latest_filename))
    saver.restore(sess, latest_filename)

# サマリを取得するための処理
summary_op = tf.summary.merge_all()
summary_writer = tf.summary.FileWriter('data', graph=sess.graph)

# 教師データに対する正答率の取得用に学習用データを読み出しておく
sess.run(iterator.initializer, {dataset_size: 600})
(dataset_all_x, dataset_all_y, values_all_y) = sess.run(next_dataset)


## テスト用データセットを読み出す
# テストデータは50×3＝150件
dataset2 = tf.data.TFRecordDataset(TEST_FILES)\
                  .map(map_dataset)\
                  .batch(150)
iterator2 = dataset2.make_one_shot_iterator()
item2 = iterator2.get_next()
(testdataset_x, testdataset_y, testvalues_y) = sess.run(item2)

test_summary = tf.summary.scalar('test_result', accuracy)

steps = tf.train.global_step(sess, global_step)
if steps == 0:
    # 初期状態を記録
    xe, acc, summary = sess.run([cross_entropy, accuracy, summary_op],
                                { x1: dataset_all_x, y: dataset_all_y })
    print("CROSS ENTROPY({}): {}".format(0, xe))
    print("     ACCURACY({}): {}".format(0, acc))
    summary_writer.add_summary(summary, global_step=0)

# 学習を開始
start_time = time.time()
sess.run(iterator.initializer, {dataset_size: 100})
for i in range(90):
    for j in range(10):
        (dataset_x, dataset_y, values_y) = sess.run(next_dataset)
        sess.run(minimize, {x1: dataset_x, y: dataset_y, enable_dropout: 1.0})
    # 途中経過を取得・保存
    xe, acc, summary = sess.run([cross_entropy, accuracy, summary_op],
                                {x1: dataset_all_x, y: dataset_all_y})
    acc2, summary2 = sess.run([accuracy, test_summary],
                              {x1: testdataset_x, y: testdataset_y})
    print("CROSS ENTROPY({}): {}".format(steps + 10 * (i+1), xe))
    print("     ACCURACY({}): {}".format(steps + 10 * (i+1), acc))
    print("  TEST RESULT({}): {}".format(steps + 10 * (i+1), acc2))
    summary_writer.add_summary(summary, global_step=tf.train.global_step(sess, global_step))
    summary_writer.add_summary(summary2, global_step=tf.train.global_step(sess, global_step))

# 学習終了
print ("time: {} sec".format(time.time() - start_time))

save_path = saver.save(sess, MODEL_FILE)
print("Model saved to {}".format(save_path))

## 結果の出力
 
# 学習に使用したデータを入力した場合の
# 正答率を計算する
print("----result with teaching data----")
 
print("assumed label:")
print(sess.run(tf.argmax(x6, 1), {x1: dataset_all_x}))
print("real label:")
print(sess.run(tf.argmax(y, 1), {y: dataset_all_y}))
print("accuracy:", sess.run(accuracy, {x1: dataset_all_x, y: dataset_all_y}))
 
 
# テスト用データを入力した場合の
# 正答率を計算する
print("----result with test data----")
 
 
# 正答率を出力
print("assumed label:")
print(sess.run(tf.argmax(x6, 1), {x1: testdataset_x}))
print("real label:")
print(sess.run(tf.argmax(y, 1), feed_dict={y: testdataset_y}))
print("accuracy:", sess.run(accuracy, {x1: testdataset_x, y: testdataset_y}))

