# -*- coding: utf-8 -*-
#
#  Copyright (C) 2001, 2002 by Tamito KAJIYAMA
#  Copyright (C) 2002, 2003 by MATSUMURA Namihiko <nie@counterghost.net>
#  Copyright (C) 2002-2012 by Shyouzou Sugitani <shy@users.sourceforge.jp>
#
#  This program is free software; you can redistribute it and/or modify it
#  under the terms of the GNU General Public License (version 2) as
#  published by the Free Software Foundation.  It is distributed in the
#  hope that it will be useful, but WITHOUT ANY WARRANTY; without even the
#  implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
#  PURPOSE.  See the GNU General Public License for more details.
#

import codecs
import logging
import socket
from collections import OrderedDict

import ninix.entry_db
import ninix.script
import ninix.version

from ninix.sstplib import AsynchronousSSTPServer, BaseSSTPRequestHandler


class SSTPServer(AsynchronousSSTPServer):
    """SSTP server that can defer the reply to a request.

    A request handler may register itself as ``request_handler``; while it is
    set, the connection is kept open (see shutdown_request) until one of the
    ``send_*`` methods below writes the response and releases it.
    """

    def __init__(self, address):
        # Placeholder until set_responsible() installs the real dispatcher;
        # it answers every query with None.
        self.request_parent = lambda *a: None # dummy
        AsynchronousSSTPServer.__init__(self, address, SSTPRequestHandler)
        # Handler whose reply is still pending, or None.
        self.request_handler = None

    def shutdown_request(self, request):
        """Close *request* unless a deferred reply is still pending."""
        if self.request_handler is not None:
            # XXX: send_* methods can be called from outside of the handler
            pass # keep alive
        else:
            AsynchronousSSTPServer.shutdown_request(self, request)

    def set_responsible(self, request_method):
        """Install the callable used to query/notify the application."""
        self.request_parent = request_method

    def send_response(self, code, data=None):
        """Answer the pending request with status *code* and optional *data*.

        Writes the status, the raw bytes *data* (if given), finishes the
        handler and closes the client connection that shutdown_request()
        kept open.  OSError while doing so is ignored (best effort).  The
        pending handler is always cleared afterwards, even on failure, so
        later requests are not rejected as conflicting.

        If no request is pending, a warning is logged and nothing is sent.
        """
        handler = self.request_handler
        if handler is None:
            logging.warning(
                'SSTP: no pending request to answer (status {0})'.format(code))
            return
        try:
            handler.send_response(code)
            if data is not None:
                handler.wfile.write(data)
            handler.force_finish()
            # Release the client connection kept open by shutdown_request().
            # (self.socket is the *listening* socket; shutting it down does
            # not end this client's connection.)
            AsynchronousSSTPServer.shutdown_request(self, handler.request)
        except OSError:
            pass # best effort: write/close errors are ignored
        finally:
            self.request_handler = None

    def send_answer(self, value):
        """Reply 200 with the string *value* in the request's charset.

        The charset comes from the request's ``Charset:`` header (default
        Shift_JIS); unencodable characters are dropped.
        """
        handler = self.request_handler
        if handler is None:
            logging.warning('SSTP: no pending request to answer')
            return
        charset = handler.headers.get('Charset', 'Shift_JIS')
        answer = b''.join((value.encode(charset, 'ignore'), b'\r\n\r\n'))
        self.send_response(200, answer) # OK

    def send_no_content(self):
        """Reply 204 to the pending request."""
        self.send_response(204) # No Content

    def send_sstp_break(self):
        """Reply 210 to the pending request."""
        self.send_response(210) # Break

    def send_timeout(self):
        """Reply 408 to the pending request."""
        self.send_response(408) # Request Timeout

    def close(self):
        """Close the listening socket."""
        self.socket.close()


class SSTPRequestHandler(BaseSSTPRequestHandler):
    """Per-connection handler for SSTP requests.

    The ``do_<METHOD>_<major>_<minor>`` methods handle the individual request
    types (presumably dispatched by BaseSSTPRequestHandler from the request
    line).  All communication with the rest of the application goes through
    ``self.server.request_parent(kind, name, *args)``.
    """

    def handle(self):
        """Handle one request, or reject it with 512 if the ghost cannot talk.

        When ``get_sakura_cantalk`` is false, the request line is parsed and,
        if valid, answered with error 512; otherwise the base class handles
        the request normally.
        """
        if not self.server.request_parent('GET', 'get_sakura_cantalk'):
            self.error = self.version = None
            if not self.parse_request(self.rfile.readline()):
                return
            self.send_error(512)
        else:
            BaseSSTPRequestHandler.handle(self)

    def force_finish(self):
        """Run the base-class finish() even while a reply is pending."""
        BaseSSTPRequestHandler.finish(self)

    def finish(self):
        """Finish unless the server keeps this request open for a later reply."""
        if self.server.request_handler is None:
            BaseSSTPRequestHandler.finish(self)

    # SEND
    def do_SEND_1_0(self):
        self.handle_send(1.0)

    def do_SEND_1_1(self):
        self.handle_send(1.1)

    def do_SEND_1_2(self):
        self.handle_send(1.2)

    def do_SEND_1_3(self):
        self.handle_send(1.3)

    def do_SEND_1_4(self):
        self.handle_send(1.4)

    def handle_send(self, version):
        """Validate a SEND request and queue its script(s).

        SEND/1.3 additionally requires an ``HWnd:`` header; SEND/1.2 and
        later read ``Entry:`` headers.  Each helper answers the client itself
        when validation fails, so this method simply returns in that case.
        """
        if not self.check_decoder():
            return
        sender = self.get_sender()
        if sender is None:
            return
        handle = None
        if version == 1.3:
            handle = self.get_handle()
            if handle is None:
                return
        script_odict = self.get_script_odict()
        entry_db = None
        if script_odict is not None and version >= 1.2:
            entry_db = self.get_entry_db()
        if script_odict is None or (version >= 1.2 and entry_db is None):
            # Rejected: release the HWnd socket instead of leaking it.
            if handle is not None:
                handle.close()
            return
        self.enqueue_request(sender, None, handle, script_odict, entry_db)

    # NOTIFY
    def do_NOTIFY_1_0(self):
        self.handle_notify(1.0)

    def do_NOTIFY_1_1(self):
        self.handle_notify(1.1)

    def handle_notify(self, version):
        """Validate a NOTIFY request and queue its event.

        NOTIFY/1.1 also reads ``Script:`` and ``Entry:`` headers; NOTIFY/1.0
        reads neither and forwards an empty script mapping.
        """
        if not self.check_decoder():
            return
        sender = self.get_sender()
        if sender is None:
            return
        event = self.get_event()
        if event is None:
            return
        if version == 1.0:
            # Same value get_script_odict() yields when no Script: header is
            # present (previously this name was left unbound -> crash).
            script_odict = OrderedDict()
            entry_db = None
        else:
            script_odict = self.get_script_odict()
            if script_odict is None:
                return
            entry_db = self.get_entry_db()
            if entry_db is None:
                return
        self.enqueue_request(sender, event, None, script_odict, entry_db)

    def enqueue_request(self, sender, event, handle, script_odict, entry_db):
        """Acknowledge a request and forward it to ``enqueue_request``.

        Without ``Entry:`` data the client gets 200 immediately.  With
        ``Entry:`` data the reply is deferred: the server object is forwarded
        so the application can answer later through it, and this handler is
        registered as the server's pending request.  Only one request can be
        pending at a time; any other gets 409.
        """
        if isinstance(self.client_address, tuple):
            address = self.client_address[0] # (host, port, ...) -> host
        else:
            address = self.client_address # e.g. an AF_UNIX peer path
        if entry_db is None or entry_db.is_empty():
            self.send_response(200) # OK
            show_sstp_marker, use_translator = self.get_options()
            self.server.request_parent(
                'NOTIFY', 'enqueue_request',
                event, script_odict, sender, handle,
                address, show_sstp_marker, use_translator,
                entry_db, None)
        elif self.server.request_handler:
            self.send_response(409) # Conflict
            if handle is not None:
                handle.close() # the request is dropped; don't leak HWnd
        else:
            show_sstp_marker, use_translator = self.get_options()
            self.server.request_parent(
                'NOTIFY', 'enqueue_request',
                event, script_odict, sender, handle,
                address, show_sstp_marker, use_translator,
                entry_db, self.server)
            self.server.request_handler = self # keep alive

    # Script tags a non-local client may not use.
    PROHIBITED_TAGS = [r'\j', r'\-', r'\+', r'\_+', r'\!']

    def check_script(self, script):
        """Return 1 (after replying 400) if a remote script uses a prohibited tag.

        Scripts from local clients are not checked.  Returns 0 when the
        script is acceptable.
        """
        if self.local_request():
            return 0
        parser = ninix.script.Parser()
        nodes = []
        while True:
            try:
                nodes.extend(parser.parse(script))
            except ninix.script.ParserError as e:
                # Keep what was parsed and continue with the rest
                # (presumably the unparsed remainder carried by the error).
                done, script = e
                nodes.extend(done)
            else:
                break
        for node in nodes:
            if node[0] == ninix.script.SCRIPT_TAG and \
               node[1] in self.PROHIBITED_TAGS:
                self.send_response(400) # Bad Request
                self.log_error('Script: tag {0} not allowed'.format(node[1]))
                return 1
        return 0

    def get_script(self):
        """Return the first ``Script:`` header, or None after replying 400."""
        script = self.headers.get('Script', None)
        if script is None:
            self.send_response(400) # Bad Request
            self.log_error('Script: header field not found')
            return None
        return script

    def get_script_odict(self):
        """Collect ``Script:`` headers keyed by their preceding ``IfGhost:``.

        A Script: immediately preceded by IfGhost: is stored under that
        ghost name; otherwise it is stored under ''.  Returns None (after
        replying 400) if any script fails check_script().
        """
        script_odict = OrderedDict()
        if_ghost = None
        for name, value in self.headers.items():
            if name != 'Script':
                # Only an IfGhost: directly before a Script: applies to it.
                if_ghost = value if name == 'IfGhost' else None
                continue
            script = str(value)
            if self.check_script(script):
                return None
            if if_ghost is None:
                script_odict[''] = script
            else:
                script_odict[if_ghost] = script
            if_ghost = None
        return script_odict

    def get_script_if_ghost(self, current=0):
        """Pick the ``(script, if_ghost)`` pair matching a ghost.

        With *current* true (NOTIFY) a Script: whose IfGhost: equals the
        current ghost name is chosen; otherwise (SEND) the first one the
        application accepts via ``if_ghost``.  Falls back to the first
        IfGhost-qualified script, then to ``(Script: header, None)``.
        """
        default = None
        if_ghost = None
        for header, value in self.headers.items():
            if header == 'IfGhost':
                if_ghost = value.strip()
                continue
            if header != 'Script':
                if_ghost = None
                continue
            if if_ghost is None:
                continue
            script = value.strip()
            if current: # NOTIFY
                ghost = self.server.request_parent('GET', 'get_ghost_name')
                if ghost == if_ghost:
                    return script, if_ghost
            else: # SEND
                if self.server.request_parent('GET', 'if_ghost', if_ghost):
                    return script, if_ghost
            if default is None:
                default = script, if_ghost
        if default is None:
            default = self.headers.get('Script', None), None
        return default

    def get_entry_db(self):
        """Build an EntryDatabase from ``Entry: key, value`` headers.

        Returns None (after replying 400) if an Entry: lacks a comma.
        """
        entry_db = ninix.entry_db.EntryDatabase()
        for value in self.headers.get_all('Entry', failobj=[]):
            entry = value.split(',', 1)
            if len(entry) != 2:
                self.send_response(400) # Bad Request
                self.log_error('Invalid Entry: header field')
                return None
            entry_db.add(entry[0].strip(), entry[1].strip())
        return entry_db

    def get_event(self):
        """Return ``(Event, Reference0, ..., Reference7)``.

        Missing references are None.  Returns None (after replying 400) if
        there is no Event: header.
        """
        event = self.headers.get('Event', None)
        if event is None:
            self.send_response(400) # Bad Request
            self.log_error('Event: header field not found')
            return None
        references = [self.headers.get('Reference{0}'.format(i), None)
                      for i in range(8)]
        return tuple([event] + references)

    def get_sender(self):
        """Return the ``Sender:`` header, or None after replying 400."""
        sender = self.headers.get('Sender', None)
        if sender is None:
            self.send_response(400) # Bad Request
            self.log_error('Sender: header field not found')
            return None
        return sender

    def get_handle(self):
        """Connect to the Unix socket named by ``HWnd:``.

        Returns the connected socket, or None (after replying 400) if the
        header is missing or the connection fails.
        """
        path = self.headers.get('HWnd', None)
        if path is None:
            self.send_response(400) # Bad Request
            self.log_error('HWnd: header field not found')
            return None
        handle = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        try:
            handle.connect(path)
        except OSError:
            handle.close() # don't leak the descriptor of the failed socket
            logging.error('cannot open Unix socket: {0}'.format(path))
            self.send_response(400) # Bad Request
            self.log_error('Invalid HWnd: header field')
            return None
        return handle

    def check_decoder(self):
        """Return 1 if the ``Charset:`` codec exists, else reply 420 and return 0."""
        charset = str(self.headers.get('Charset', 'Shift_JIS'))
        try:
            codecs.lookup(charset)
        except LookupError:
            self.send_response(420, 'Refuse (unsupported charset)')
            self.log_error('Unsupported charset {0}'.format(repr(charset)))
            return 0
        return 1

    def get_options(self):
        """Return ``(show_sstp_marker, use_translator)`` from ``Option:``.

        ``nodescript`` is honoured only for local clients; ``notranslate``
        is honoured for anyone.
        """
        show_sstp_marker = use_translator = 1
        for option in self.headers.get('Option', '').split(','):
            option = option.strip()
            if option == 'nodescript' and self.local_request():
                show_sstp_marker = 0
            elif option == 'notranslate':
                use_translator = 0
        return show_sstp_marker, use_translator

    def local_request(self):
        """Return True if the client is on this machine.

        A client whose address is not a tuple (e.g. an AF_UNIX peer, whose
        address is a path) is local; a TCP client is local only when it
        connects from 127.0.0.1.  (The previous version returned true for
        every client, disabling all remote-client restrictions.)
        """
        address = self.client_address
        if not isinstance(address, tuple):
            return True
        return address[0] == '127.0.0.1'

    # EXECUTE
    def do_EXECUTE_1_0(self):
        self.handle_command()

    def do_EXECUTE_1_2(self):
        self.handle_command()

    def do_EXECUTE_1_3(self):
        """EXECUTE/1.3 is accepted only from local clients (420 otherwise)."""
        if not self.local_request():
            host = self.client_address[0] # also works for IPv6 4-tuples
            self.send_response(420)
            self.log_error(
                'Unauthorized EXECUTE/1.3 request from {0}'.format(host))
            return
        self.handle_command()

    def handle_command(self):
        """Run the ``Command:`` of an EXECUTE request and write its result.

        Text results are encoded in the request charset, one value per
        CRLF-terminated line, followed by an empty line.  Unknown commands
        get 501.
        """
        if not self.check_decoder():
            return
        sender = self.get_sender()
        if sender is None:
            return
        command = self.get_command()
        charset = str(self.headers.get('Charset', 'Shift_JIS'))
        if command is None:
            return
        elif command == 'getname':
            self.send_response(200)
            name = self.server.request_parent('GET', 'get_ghost_name')
            self.wfile.write(b''.join((name.encode(charset, 'ignore'), b'\r\n')))
            self.wfile.write(b'\r\n')
        elif command == 'getversion':
            self.send_response(200)
            self.wfile.write(b''.join((b'ninix-aya ',
                                       ninix.version.VERSION.encode(charset),
                                       b'\r\n')))
            self.wfile.write(b'\r\n')
        elif command == 'quiet':
            self.send_response(200)
            self.server.request_parent('NOTIFY', 'keep_silence', True)
        elif command == 'restore':
            self.send_response(200)
            self.server.request_parent('NOTIFY', 'keep_silence', False)
        elif command == 'getnames':
            self.send_response(200)
            for name in self.server.request_parent('GET', 'get_ghost_names'):
                self.wfile.write(
                    b''.join((name.encode(charset, 'ignore'), b'\r\n')))
            self.wfile.write(b'\r\n')
        elif command == 'checkqueue':
            self.send_response(200)
            count, total = self.server.request_parent(
                'GET', 'check_request_queue', sender)
            self.wfile.write(b''.join((str(count).encode(charset), b'\r\n')))
            self.wfile.write(b''.join((str(total).encode(charset), b'\r\n')))
            self.wfile.write(b'\r\n')
        else:
            self.send_response(501) # Not Implemented
            self.log_error('Not Implemented ({0})'.format(command))

    def get_command(self):
        """Return the lower-cased ``Command:``, or None after replying 400."""
        command = self.headers.get('Command', None)
        if command is None:
            self.send_response(400) # Bad Request
            self.log_error('Command: header field not found')
            return None
        return command.lower()

    def do_COMMUNICATE_1_1(self):
        """Reply 200 and raise an ``OnCommunicate`` event with the sentence."""
        if not self.check_decoder():
            return
        sender = self.get_sender()
        if sender is None:
            return
        sentence = self.get_sentence()
        if sentence is None:
            return
        self.send_response(200) # OK
        self.server.request_parent(
            'NOTIFY', 'enqueue_event', 'OnCommunicate', sender, sentence)

    def get_sentence(self):
        """Return the ``Sentence:`` header, or None after replying 400."""
        sentence = self.headers.get('Sentence', None)
        if sentence is None:
            self.send_response(400) # Bad Request
            self.log_error('Sentence: header field not found')
            return None
        return sentence
