#!/usr/bin/env python
# -*- encoding: utf-8 -*-

# Copyright (c) 2012, tamanegi (tamanegi@users.sourceforge.jp)
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to
# deal in the Software without restriction, including without limitation the
# rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
# sell copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
# 
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
# 
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
# IN THE SOFTWARE.

import sys
import getopt
import numpy as np # tradition?

from rdgmsout import *
from rdmacmolplt import *
from basis import *
from voxel import *

class Grid:
  """Simple (i, j, k) triple, e.g. a grid-point index.

  Each component defaults to 0 and is stored unchanged as an attribute.
  """

  def __init__( self, i = 0, j = 0, k = 0 ):
    self.i, self.j, self.k = i, j, k

def calcGridPos( basepos, grid, i, j, k ):
  """Return a new Vector3D holding the position of grid point (i, j, k).

  Along each axis the result is basepos + spacing * index, where basepos and
  grid are objects with x, y, z attributes (origin and per-axis spacing).
  """
  point = Vector3D()
  point.x, point.y, point.z = ( basepos.x + grid.x * i,
                                basepos.y + grid.y * j,
                                basepos.z + grid.z * k )
  return point

def calcMinMaxPos( orbitals ):
  """Return the component-wise bounding box of the orbital positions.

  orbitals -- iterable of objects providing getPosition().
  Returns (minpos, maxpos), combined pairwise with getMinVector3D /
  getMaxVector3D; both are None when orbitals is empty.
  """
  minpos = None
  maxpos = None
  for orbital in orbitals:
    if minpos is None:
      # Seed both corners from the first position.  Test with `is None`
      # rather than truthiness so a position object that happens to be
      # falsy can never reset the running bounds.  Each corner gets its own
      # getPosition() call so the two corners need not alias one object.
      minpos = orbital.getPosition()
      maxpos = orbital.getPosition()
    else:
      pos = orbital.getPosition()
      minpos = getMinVector3D( minpos, pos )
      maxpos = getMaxVector3D( maxpos, pos )
  return minpos, maxpos

def createVoxelData( lattice, i, j, k ):
  """Return the 8 corner values of the cell whose lowest corner is (i, j, k).

  lattice is indexed as lattice[i][j][k].  The corner order is significant to
  the consumer and is: i offset fastest, then k, then j, i.e.
    (i,j,k), (i+1,j,k), (i,j,k+1), (i+1,j,k+1),
    (i,j+1,k), (i+1,j+1,k), (i,j+1,k+1), (i+1,j+1,k+1)
  """
  return [ lattice[i + di][j + dj][k + dk]
           for dj in ( 0, 1 )
           for dk in ( 0, 1 )
           for di in ( 0, 1 ) ]

def usage():
  """Print the command-line help and the default settings to stderr.

  Uses sys.stderr.write so the output is identical under Python 2 and 3.
  """
  help_lines = [
    "Generate Isosurface of Electron Density",
    "usage: mo2xml.py -t [(string) filetype]",
    "                 -i [(integer) MO index]",
    "                 -v [(float) isosurface criterion]",
    "                 -a [(float) alpha buffer of MO]",
    "                 -c [(string) color of MO with positive value]",
    "                 -C [(string) color of MO with negative value]",
    "                 -g [(float) grid size in bohr]",
    "                 -m [(float) margin/cutoff]",
    "                 [data file]",
    "",
    "available format names are: gamess, macmolplt",
    "default values are:",
    # keep these in sync with the defaults assigned in the __main__ section
    "  filetype: \"gamess\", MO index: 1, isosurface value: 0.04",
    "  color(positive): red, color(negative): blue",
    "  alpha: 0.6 (valid range is 0.0(transparent)-1.0(opaque))",
    "  grid: 0.5 (bohr), margin: 3.0 (bohr)",
    "NOTE: format names are NOT case sensitive.",
  ]
  sys.stderr.write( "\n".join( help_lines ) + "\n" )

if __name__ == "__main__":
  # Script entry point.  Parses the command line and obtains a 3-D grid of MO
  # values: computed from the basis set for GAMESS output, read directly for
  # MacMolPlt output.  Every grid point is classified against +/- the
  # isosurface criterion, and one <SHAPE> element per sign is written to
  # stdout.  Progress and error messages go to stderr.
  #
  # Exit codes: 0 = help shown, 1 = bad command line, 4 = unsupported file
  # type, 8 = inconsistent or empty data.

  # default settings (keep in sync with usage())
  datatype = "gamess"
  mo_index = 1
  isosurface = 0.04
  grid = Vector3D( 0.5, 0.5, 0.5 )
  margin = 3.0
  col_positive = "red"
  col_negative = "blue"
  alpha = "0.6"  # kept as text: it is copied verbatim into the XML attribute
  if len( sys.argv ) <= 1:
    usage()
    sys.exit(0)

  try:
    opts, args = getopt.getopt( sys.argv[1:], "a:c:g:C:i:m:t:v:" )
  except getopt.GetoptError:
    sys.stderr.write( "Error: failed to parse options.\n" )
    sys.stderr.write( "Error: check your command line arguments.\n" )
    usage()
    sys.exit(1)

  # Option names are compared with ==: a parenthesized single string such as
  # ( "-t" ) is not a tuple, so `in` would be a substring test.
  try:
    for o, a in opts:
      if o == "-t":
        datatype = str( a )
      elif o == "-a":
        float( a )  # reject non-numeric alpha; the original text is emitted
        alpha = a
      elif o == "-c":
        col_positive = a
      elif o == "-C":
        col_negative = a
      elif o == "-g":
        size = float( a )
        if size <= 0.0:
          # a zero/negative spacing would divide by zero further down
          raise ValueError( a )
        grid = Vector3D( size, size, size )
      elif o == "-i":
        mo_index = int( a )
      elif o == "-m":
        margin = float( a )
      elif o == "-v":
        isosurface = float( a )
  except ValueError:
    sys.stderr.write( "Error: invalid value for option " + o + ": " + a + "\n" )
    usage()
    sys.exit(1)

  if not args:
    sys.stderr.write( "You should specify filename.\n" )
    usage()
    sys.exit(1)

  # format names are not case sensitive; normalize once
  datatype = datatype.lower()

  # read data file
  orbs = []
  eigenvecs = []
  if datatype == "gamess":
    calculateGridData = True
    mydata = GamessOut()
    orbs, eigenvecs = mydata.read( args[0], mo_index )
  elif datatype == "macmolplt":
    # grid values, dimensions, spacing and origin all come from the file
    calculateGridData = False
    mydata = MacMolPltOut( args[0] )
    val_container = mydata.getGridData()
    gridnum = mydata.getGridNum()
    grid = mydata.getGridSize()
    minpos = mydata.getCellOrigin()
  else:
    # not implemented
    sys.stderr.write( "Error: requested file type, " + datatype + " is not implemented.\n" )
    usage()
    sys.exit(4)

  if datatype == "gamess":
    sys.stderr.write( "Info: # of atoms & eigenvectors are " + str( len( orbs ) ) + ", " + str( len( eigenvecs ) ) + "\n" )
    sys.stderr.write( "Info: check orbital and eigenvec nums\n" )
    # The grid loop below consumes one eigenvector entry per orbital returned
    # by getOrbitals(); orb.size() is assumed to count those -- TODO confirm.
    no = sum( orb.size() for orb in orbs )
    if no != len( eigenvecs ):
      sys.stderr.write( "Error: numbers of orbitals and eigenvecs mismatch.\n" )
      sys.stderr.write( "Error: orbital = " + str( no ) + " eigenvectors = " + str( len( eigenvecs ) ) + "\n" )
      sys.exit(8)
    if not orbs:
      # calcMinMaxPos() would return (None, None) and the grid setup would fail
      sys.stderr.write( "Error: no atoms found in data file.\n" )
      sys.exit(8)

  if calculateGridData:
    sys.stderr.write( "Info: create grid data\n" )
    # bounding box of the atoms, widened by `margin` bohr on every side
    minpos, maxpos = calcMinMaxPos( orbs )
    maxpos = maxpos + Vector3D( margin, margin, margin )
    minpos = minpos - Vector3D( margin, margin, margin )
    # margin in whole grid cells: each basis function is evaluated only
    # within this many cells of its own atom
    mcell = Vector3D( int( margin / grid.x ), int( margin / grid.y ), int( margin / grid.z ) )
    # grid points per axis; the extra 2 * mcell + 1 points are padding on top
    # of the already-widened box
    gridnum = maxpos - minpos
    gridnum.x = int( gridnum.x / grid.x ) + 2 * mcell.x + 1
    gridnum.y = int( gridnum.y / grid.y ) + 2 * mcell.y + 1
    gridnum.z = int( gridnum.z / grid.z ) + 2 * mcell.z + 1
    val_container = np.zeros( ( gridnum.x, gridnum.y, gridnum.z ), dtype=float )
    counter = 0  # running index into eigenvecs
    for orb in orbs:
      mypos = orb.getPosition()
      # grid cell containing this atom
      g = mypos - minpos
      g.x = int( g.x / grid.x )
      g.y = int( g.y / grid.y )
      g.z = int( g.z / grid.z )
      # Clamp the atom's neighbourhood to the grid.  Float rounding in the
      # subtraction above can leave g one cell short of mcell, and a negative
      # index would silently wrap to the opposite face of the numpy array.
      irange = range( max( 0, g.x - mcell.x ), min( gridnum.x, g.x + mcell.x + 1 ) )
      jrange = range( max( 0, g.y - mcell.y ), min( gridnum.y, g.y + mcell.y + 1 ) )
      krange = range( max( 0, g.z - mcell.z ), min( gridnum.z, g.z + mcell.z + 1 ) )
      for o in orb.getOrbitals():
        eigenvec = float( eigenvecs[counter] )
        counter += 1
        for i in irange:
          for j in jrange:
            for k in krange:
              tpos = calcGridPos( minpos, grid, i, j, k )
              val_container[i, j, k] += eigenvec * o.getValueAt( mypos, tpos )

  sys.stderr.write( "Info: prepare contour data\n" )
  # ad[i, j, k] = 2 above +isosurface, 1 below -isosurface, 0 otherwise
  ad = np.zeros( ( gridnum.x, gridnum.y, gridnum.z ), dtype=int )
  for i in range( gridnum.x ):
    for j in range( gridnum.y ):
      for k in range( gridnum.z ):
        value = val_container[i][j][k]
        if value > isosurface:
          ad[i, j, k] = 2
        elif value < -isosurface:
          ad[i, j, k] = 1

  # the raw values are not needed any more; release them before output
  del val_container

  sys.stderr.write( "Info: write isosurface\n" )
  # NOTE(review): CODATA gives 0.529177 angstrom/bohr; 0.529 is kept,
  # presumably to match coordinates written by companion tools -- confirm.
  bohr2angstrom = 0.529
  grid_ang = grid * bohr2angstrom
  # first pass: flag True with col_positive; second: flag False with col_negative
  for positive in ( True, False ):
    color = col_positive if positive else col_negative
    sys.stdout.write( "    <SHAPE color=\"" + color + "\" alpha=\"" + alpha + "\" spec=\"0.0\" diff=\"0.1\" am=\"0.6\" flat=\"1\">\n" )
    for i in range( gridnum.x - 1 ):
      for j in range( gridnum.y - 1 ):
        for k in range( gridnum.z - 1 ):
          voxel = Voxel()
          voxel.setValueFromList( createVoxelData( ad, i, j, k ) )
          voxel.setOrigin( calcGridPos( minpos, grid, i, j, k ) * bohr2angstrom )
          voxel.setGridLength( Vector3D( grid_ang.x, grid_ang.y, grid_ang.z ) )
          voxel.generateXml( sys.stdout, positive )
    sys.stdout.write( "    </SHAPE>\n" )
