package jp.sf.amateras.mirage.scala

import collection.JavaConversions._
import jp.sf.amateras.mirage.parser.{SqlContext, Node, SqlParserImpl}
import java.lang.reflect.Field
import jp.sf.amateras.mirage.{IterationCallback, SqlManagerImpl, SqlExecutor}

/**
 * SqlManager wrapper for Scala.
 */
class SqlManager private (sqlManager: jp.sf.amateras.mirage.SqlManagerImpl) {

  /**
   * Executes the given SQL and returns its single result mapped to `clazz`.
   *
   * When `clazz` is `classOf[Map[String, _]]`, the row is fetched as a `java.util.Map`
   * and converted to an immutable Scala `Map`, exactly as [[getResultList]] does.
   *
   * @param clazz the type of the result
   * @param sql the SQL template to execute
   * @param param the parameter object (a Scala `Map` is converted to `java.util.Map`), may be null
   * @return the result, or None when the executor returns null
   */
  def getSingleResult[T](clazz: Class[T], sql: SqlProvider, param: AnyRef = null): Option[T] = {
    val context: SqlContext = createSqlContext(sql, param)

    if (clazz == classOf[Map[String, _]]) {
      // The executor produces java.util.Map rows; convert to scala.Map (same as getResultList).
      Option(getSqlExecutor.getSingleResult(classOf[java.util.Map[String, _]], context.getSql,
        context.getBindVariables)).map { row => row.toMap }.asInstanceOf[Option[T]]
    } else {
      Option(getSqlExecutor.getSingleResult(clazz, context.getSql, context.getBindVariables))
    }
  }

  /**
   * Executes the given SQL and returns all results mapped to `clazz`.
   *
   * When `clazz` is `classOf[Map[String, _]]`, each row is fetched as a `java.util.Map`
   * and converted to an immutable Scala `Map`.
   *
   * @param clazz the type of each result
   * @param sql the SQL template to execute
   * @param param the parameter object (a Scala `Map` is converted to `java.util.Map`), may be null
   * @return the results as a Scala List
   */
  def getResultList[T](clazz: Class[T], sql: SqlProvider, param: AnyRef = null): List[T] = {
    val context: SqlContext = createSqlContext(sql, param)

    if (clazz == classOf[Map[String, _]]) {
      // convert java.util.Map to scala.Map
      getSqlExecutor.getResultList(classOf[java.util.Map[String, _]], context.getSql,
        context.getBindVariables).toList.map { entry => entry.toMap }.asInstanceOf[List[T]]
    } else {
      getSqlExecutor.getResultList(clazz, context.getSql, context.getBindVariables).toList
    }
  }

  /**
   * Executes the given SQL and passes each result, mapped to `clazz`, to `callback`.
   *
   * @param clazz the type of each result
   * @param sql the SQL template to execute
   * @param param the parameter object (a Scala `Map` is converted to `java.util.Map`), may be null
   * @param callback invoked once per result
   * @return the value returned by the underlying SqlExecutor.iterate
   *         (NOTE(review): presumably the last callback result, or null when there are no rows — confirm against Mirage)
   */
  def iterate[T, R](clazz: Class[T], sql: SqlProvider, param: AnyRef = null) (callback: (T) => R): R = {
    val context: SqlContext = createSqlContext(sql, param)
    getSqlExecutor.iterate(clazz, new IterationCallbackAdapter(callback), context.getSql, context.getBindVariables)
  }

  /**
   * Folds over the results of a parameterless SQL, starting from `context`.
   *
   * @param clazz the type of each result
   * @param sql the SQL template to execute
   * @param context the initial accumulator value
   * @param callback receives each result and the current accumulator, returns the next accumulator
   * @return the final accumulator (`context` itself when there are no results)
   */
  def iterate[T, R](clazz: Class[T], sql: SqlProvider, context: R) (callback: (T, R) => R): R = iterate(clazz, sql, null, context) (callback)

  /**
   * Folds over the results of the given SQL, starting from `context`.
   *
   * @param clazz the type of each result
   * @param sql the SQL template to execute
   * @param param the parameter object (a Scala `Map` is converted to `java.util.Map`), may be null
   * @param context the initial accumulator value
   * @param callback receives each result and the current accumulator, returns the next accumulator
   * @return the final accumulator (`context` itself when there are no results)
   */
  def iterate[T, R](clazz: Class[T], sql: SqlProvider, param: AnyRef, context: R) (callback: (T, R) => R): R = {
    var result = context
    iterate(clazz, sql, param){t =>
      result = callback(t, result)
      result
    }
    // Return our own accumulator rather than the executor's return value, so that an
    // empty result set yields the initial `context` instead of whatever the executor returns.
    result
  }

  /**
   * Executes the given update SQL (INSERT / UPDATE / DELETE).
   *
   * @param sql the SQL template to execute
   * @param param the parameter object (a Scala `Map` is converted to `java.util.Map`), may be null
   * @return updated row count
   */
  def executeUpdate(sql: SqlProvider, param: AnyRef = null): Int = {
    val context: SqlContext = createSqlContext(sql, param)
    // no entity is associated with a raw SQL update, hence the null third argument
    getSqlExecutor.executeUpdateSql(context.getSql, context.getBindVariables, null)
  }

  /**
   * Returns the number of rows the given SQL would produce.
   *
   * @param sql the SQL template whose results are counted
   * @param param the parameter object (a Scala `Map` is converted to `java.util.Map`), may be null
   * @return the row count
   */
  def getCount(sql: SqlProvider, param: AnyRef = null): Int = {
    val context: SqlContext = createSqlContext(sql, param)
    // let the dialect derive a count query from the generated SQL
    val countSql: String = sqlManager.getDialect.getCountSql(context.getSql)
    getSqlExecutor.getSingleResult(classOf[java.lang.Integer], countSql, context.getBindVariables).intValue
  }

  /**
   * Finds the entity by the given primary key.
   * @param clazz the type of entity
   * @param id primary keys
   * @return the entity. If the entity which corresponds to the given primary key is not found, this method returns None.
   */
  def findEntity[T](clazz: Class[T], id: Any*): Option[T] = {
    Option(sqlManager.findEntity(clazz, id.map {_.asInstanceOf[AnyRef]}: _*))
  }

  /**
   * Inserts the given entity.
   * @param entity the entity to insert
   * @return updated row count
   */
  def insertEntity(entity: AnyRef): Int = sqlManager.insertEntity(entity)

  /**
   * Inserts given entities with batch mode.
   * @param entities entities to insert
   * @return updated row count
   */
  def insertBatch(entities: AnyRef*): Int = sqlManager.insertBatch(entities: _*)

  /**
   * Updates the given entity.
   * @param entity the entity to update
   * @return updated row count
   */
  def updateEntity(entity: AnyRef): Int = sqlManager.updateEntity(entity)

  /**
   * Updates given entities with batch mode.
   * @param entities entities to update
   * @return updated row count
   */
  def updateBatch(entities: AnyRef*): Int = sqlManager.updateBatch(entities: _*)

  /**
   * Deletes the given entity.
   * @param entity the entity to delete
   * @return updated row count
   */
  def deleteEntity(entity: AnyRef): Int = sqlManager.deleteEntity(entity)

  /**
   * Deletes given entities with batch mode.
   * @param entities entities to delete
   * @return updated row count
   */
  def deleteBatch(entities: AnyRef*): Int = sqlManager.deleteBatch(entities: _*)

  /**
   * Parses the SQL template and applies it to a context prepared from `param`.
   * @return the context holding the generated SQL and its bind variables
   */
  private def createSqlContext(sql: SqlProvider, param: AnyRef): SqlContext = {
    val node: Node = new SqlParserImpl(sql.getSql()).parse()
    val context: SqlContext = prepareSqlContext(convertParam(param))
    node.accept(context)
    context
  }

  /**
   * Converts to java.util.Map if param is a Scala Map (immutable or mutable).
   */
  private def convertParam(param: AnyRef): AnyRef = param match {
    // type arguments are erased at runtime, so any Scala Map is treated as Map[String, _]
    case map: scala.collection.Map[String, _] @unchecked => (map: java.util.Map[String, _])
    case params => params
  }

  // Reflective handle on SqlManagerImpl's `sqlExecutor` field; looked up once, read on every call.
  private lazy val sqlExecutorField: Field = {
    val field: Field = classOf[SqlManagerImpl].getDeclaredField("sqlExecutor")
    field.setAccessible(true)
    field
  }

  // TODO reflection is not so good...
  private def getSqlExecutor: SqlExecutor = sqlExecutorField.get(sqlManager).asInstanceOf[SqlExecutor]

  // Reflective handle on SqlManagerImpl.prepareSqlContext(Object); looked up once.
  private lazy val prepareSqlContextMethod: java.lang.reflect.Method = {
    val method = classOf[SqlManagerImpl].getDeclaredMethod("prepareSqlContext", classOf[Object])
    method.setAccessible(true)
    method
  }

  // TODO reflection is not so good but it can't extend SqlManagerImpl, so it can't call prepareSqlContext() directly.
  private def prepareSqlContext(param: AnyRef): SqlContext =
    prepareSqlContextMethod.invoke(sqlManager, param).asInstanceOf[SqlContext]

  /**
   * Adapter for that callback function that would be given to iterate().
   */
  private class IterationCallbackAdapter[T, R](callback: (T) => R) extends IterationCallback[T, R] {
    def iterate(entity: T): R = callback(entity)
  }

}

object SqlManager {

  // Part of the companion object's initializer: it runs once, when the object is
  // first accessed (and therefore before the first call to apply).
  BeanDescFactoryInitializer.initialize()

  /**
   * Creates a Scala-friendly wrapper around the given Mirage SqlManagerImpl.
   *
   * @param sqlManager the underlying SqlManagerImpl to delegate to
   * @return a new SqlManager wrapping `sqlManager`
   */
  def apply(sqlManager: SqlManagerImpl): SqlManager =
    new SqlManager(sqlManager)
}
