#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# This file is part of Karesansui.
#
# Copyright (C) 2009-2010 HDE, Inc.
#
# This program is free software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#

import os
import sys
import traceback

from sqlalchemy import MetaData

from karesansui.db import get_engine, get_metadata, get_session
from karesansui.lib.crypt import sha1encrypt
from karesansui.db.model.user import User
from karesansui.db.model.notebook import Notebook
from karesansui.db.model.tag import Tag
from karesansui.db.model.machine import Machine
from karesansui.lib.const import MACHINE_ATTRIBUTE, MACHINE_HYPERVISOR

import karesansui.db._2pysilhouette

import logging
logger = logging.getLogger("karesansui.initdb")

# Process environment and the installer log path taken from $LOGFILE
# (None when the variable is unset).  os.environ is used directly: the
# private os._Environ constructor has a version-specific signature, and a
# module-level "global" statement is a no-op.
env = os.environ
logfile = env.get("LOGFILE")

from installer.const import DATABASE_MODULES

if os.path.basename(sys.argv[0]) == "karesansui-checkenv":
    from installer.checkenv import CheckEnvError as DatabaseScriptError
else:
    from installer.install import InstallError   as DatabaseScriptError

def import_database_module(type=None):
    """Check that at least one module listed for a database type is importable.

    type -- key into DATABASE_MODULES (the name shadows the builtin but is
            kept so keyword callers keep working).

    Raises DatabaseScriptError when *type* is not a key of DATABASE_MODULES,
    or when none of the module names listed for it can be imported.
    """
    try:
        module_names = DATABASE_MODULES[type]
    except (KeyError, TypeError):
        # TypeError covers unhashable keys, which the previous bare except
        # also reported as an unknown type.
        raise DatabaseScriptError("Unknown database type %s is specified" % (type,))

    found = False
    missing = []
    for module_name in module_names:
        try:
            # __import__ avoids building and exec-ing source text.
            __import__(module_name)
        except ImportError:
            missing.append(module_name)
        else:
            found = True
            break

    if not found:
        raise DatabaseScriptError("ERROR: No module named %s" % " or ".join(missing))

def _karesansui_prepare_schema(drop):
    """Drop/create the Karesansui tables; log and re-raise any failure."""
    # The engine object itself is not used here; get_engine() is still
    # called, presumably for its setup side effects - confirm before removing.
    get_engine()
    metadata = get_metadata()

    try:
        if drop is True:
            metadata.drop_all()
        else:
            # Update mode: only machine2jobgroup is rebuilt; create_all()
            # below then adds any tables that do not exist yet.
            table = metadata.tables['machine2jobgroup']
            table.drop()
            table.create()
        metadata.create_all()
    except Exception:
        # sys.exc_info() keeps this valid for both old Python 2 and Python 3.
        e = sys.exc_info()[1]
        # Stringify each arg so non-string args (ints, tuples) cannot raise
        # TypeError here and hide the original error.
        logger.info('Initializing/Updating a database error - %s'
                    % ''.join(["%s" % (arg,) for arg in e.args]))
        logger.info(traceback.format_exc())
        raise


def _karesansui_seed_data(opts):
    """Create/update the administrator, then add the default tag and host machine."""
    session = get_session()
    try:
        try:
            (password, salt) = sha1encrypt(u"%s" % opts.password1)
            user = session.query(User).filter(User.email == opts.mailaddr).first()

            if user is None:
                # User Table set.
                new_user  = User(u"%s" % opts.mailaddr,
                              unicode(password),
                              unicode(salt),
                              u"Administrator",
                              u"%s" % opts.lang,
                              )
                session.save(new_user)
            else:
                user.password  = password
                user.salt      = salt
                user.languages = opts.lang
                session.update(user)
            session.commit()

            # Tag Table set.
            tag = Tag(u"default")
            session.save(tag)
            session.commit()

            # Machine Table set.  Re-query so a freshly created user is
            # loaded from the database.
            user     = session.query(User).filter(User.email == opts.mailaddr).first()
            notebook = Notebook(u"", u"")
            machine  = Machine(user,
                               user,
                               u"%s" % opts.uuid,
                               u"%s" % opts.fqdn,
                               MACHINE_ATTRIBUTE['HOST'],
                               MACHINE_HYPERVISOR['REAL'],
                               notebook,
                               [tag],
                               u"%s" % opts.fqdn,
                               u'icon-guest1.png',
                               False,
                               None,
                               )
            session.save(machine)
            session.commit()
        except Exception:
            logger.info(traceback.format_exc())
            raise
    finally:
        # Close on both success and failure so the session is never leaked.
        session.close()


def karesansui_database_init(opts,drop=True):
    """Create (or refresh) the Karesansui schema and seed its initial rows.

    opts -- installer options; password1, mailaddr, lang, uuid and fqdn are read.
    drop -- True: drop and recreate every table.  False: only rebuild the
            'machine2jobgroup' table and create any missing tables.

    Seeding (administrator user, "default" tag, host machine) is skipped
    when opts.password1 is the empty string.  Errors are logged and re-raised.
    """
    _karesansui_prepare_schema(drop)

    # NOTE(review): a None password1 is not skipped here and would be stored
    # as u"None"; confirm callers always pass a string.
    if opts.password1 != "":
        _karesansui_seed_data(opts)

def pysilhouette_database_init(opts,drop=True):
    """Create the pysilhouette tables, optionally dropping them first.

    opts -- unused; kept for signature parity with karesansui_database_init.
    drop -- when True, drop all tables before recreating them.

    Errors are logged and re-raised.
    """
    # The engine object is not used here; get_engine() is still called,
    # presumably for its setup side effects - confirm before removing.
    karesansui.db._2pysilhouette.get_engine()
    metadata = karesansui.db._2pysilhouette.get_metadata()
    try:
        if drop is True:
            metadata.drop_all()
        metadata.create_all()
    except Exception:
        # sys.exc_info() keeps this valid for both old Python 2 and Python 3.
        e = sys.exc_info()[1]
        # Stringify each arg so non-string args cannot raise TypeError here
        # and hide the original error.
        logger.info('Initializing a database error - %s'
                    % ''.join(["%s" % (arg,) for arg in e.args]))
        logger.info(traceback.format_exc())
        raise

def _capture_output(flag):
    """Redirect sys.stdout/sys.stderr to in-memory buffers when flag is True.

    Returns the original (stdout, stderr) pair for _release_output(), or
    None when no redirection took place.
    """
    if flag is not True:
        return None
    import StringIO
    saved = (sys.stdout, sys.stderr)
    sys.stdout = StringIO.StringIO()
    sys.stderr = StringIO.StringIO()
    return saved


def _release_output(saved):
    """Undo _capture_output() and return the captured (stdout, stderr) text.

    Returns None when saved is None (nothing was captured).
    """
    if saved is None:
        return None
    out_buf, err_buf = sys.stdout, sys.stderr
    sys.stdout, sys.stderr = saved
    return (out_buf.getvalue(), err_buf.getvalue())


def _append_to_logfile(captured):
    """Best-effort append of captured output to the $LOGFILE log.

    Does nothing when nothing was captured or LOGFILE is unset.  I/O and
    encoding errors are ignored so logging can never abort the installer.
    """
    if captured is None or not logfile:
        return
    stdout_result, stderr_result = captured
    try:
        logf = open(logfile, "a")
        try:
            logf.write("database.start()\n")
            logf.write("### stdout_result ###: '%s'\n" % (stdout_result))
            logf.write("### stderr_result ###: '%s'\n" % (stderr_result))
        finally:
            logf.close()
    except (IOError, OSError, UnicodeError):
        pass


def _best_effort(name, func, opts, drop):
    """Call func(opts, drop=drop); log and swallow any Exception it raises."""
    try:
        func(opts, drop=drop)
    except Exception:
        # Deliberately non-fatal so the remaining steps still run; the init
        # functions log their own tracebacks for most failures.
        logger.info("database %s step failed - %s", name, sys.exc_info()[1])


def initialize(opts,flag=True,drop=True):
    """Initialise the Karesansui and pysilhouette databases.

    opts -- installer options, passed to the *_database_init functions.
    flag -- when True, stdout/stderr are captured during the run and the
            captured text is appended to $LOGFILE afterwards.
    drop -- when True, existing tables are dropped before being recreated.

    A failure in either step is logged and otherwise ignored.
    """
    saved = _capture_output(flag)
    try:
        logger.info("init database - karesansui [start]")
        _best_effort("karesansui", karesansui_database_init, opts, drop)
        logger.info("init database karesansui [end]")

        logger.info("init database pysilhouette [start]")
        _best_effort("pysilhouette", pysilhouette_database_init, opts, drop)
        logger.info("init database pysilhouette [end]")
    finally:
        # Always restore the real streams, even if something escapes.
        captured = _release_output(saved)
    _append_to_logfile(captured)


def update(opts,flag=True):
    """Update the Karesansui database in place (no full drop).

    opts -- installer options, passed to karesansui_database_init.
    flag -- when True, stdout/stderr are captured during the run and the
            captured text is appended to $LOGFILE afterwards.

    A failure is logged and otherwise ignored.
    """
    saved = _capture_output(flag)
    try:
        logger.info("update database - karesansui [start]")
        _best_effort("karesansui", karesansui_database_init, opts, False)
        logger.info("update database karesansui [end]")
    finally:
        captured = _release_output(saved)
    _append_to_logfile(captured)

def is_connect(url):
    """Return True if a connection to the database at *url* can be opened.

    Raises DatabaseScriptError (carrying the driver's error text) when the
    connection attempt fails with sqlalchemy's OperationalError.  Other
    errors (e.g. a malformed URL or a missing driver) propagate unchanged.
    """
    # Imported outside the try: if these fail, an ImportError should surface
    # instead of an UnboundLocalError from the except clause below.
    from sqlalchemy import create_engine
    from sqlalchemy.exc import OperationalError

    engine = None
    try:
        try:
            engine = create_engine(url, echo=False, convert_unicode=True)
            connection = engine.connect()
            connection.close()
        except OperationalError:
            # sys.exc_info() keeps this valid for both old Python 2 and Python 3.
            e = sys.exc_info()[1]
            raise DatabaseScriptError("%s" % ''.join(["%s" % (arg,) for arg in e.args]))
    finally:
        # Release the pooled connection(s) of this throwaway engine.
        if engine is not None:
            engine.dispose()
    return True
