package org.maachang.comet.net.ssl ;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.SocketChannel;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLEngineResult.Status;

import org.maachang.comet.net.nio.NioElement;
import org.maachang.comet.net.nio.ReceiveBuffer;

/**
 * SSLソケット要素.
 *  
 * @version 2008/06/01
 * @author  masahito suzuki
 * @since   MaachangComet 1.1D
 */
public class SslElement {
    
    /**
     * NioElement.
     */
    private NioElement element = null ;
    
    /**
     * SSLEngine.
     */
    private SSLEngine sslEngine = null ;
    
    /**
     * デフォルトバッファサイズ.
     */
    private int defaultSize = -1 ;
    
    /**
     * クライアント受け取りバッファ.
     */
    private ByteBuffer netInBuffer = null ;
    
    /**
     * クライアント出力バッファ.
     */
    private ByteBuffer netOutBuffer = null ;
    
    /**
     * 受信バッファ.
     */
    private ByteBuffer recvBuffer = null ;
    
    /**
     * ハンドシェイクステータス.
     */
    private HandshakeStatus handShakeState = null ;
    
    /**
     * ハンドシェイク終了フラグ.
     */
    private volatile boolean isHandShakeFinish = false ;
    
    /**
     * クローズフラグ.
     */
    private volatile boolean isClose = false ;
    
    /**
     * ハンドシェイクデータ送信完了フラグ.
     */
    private volatile boolean isExitHandShakeSend = false ;
    
    /**
     * コンストラクタ.
     */
    private SslElement() {
        
    }
    
    /**
     * コンストラクタ.
     * @param element 対象のNioElementを設定します.
     * @param sslEngine 対象のSSLエンジンを設定します.
     * @exception Exception 例外.
     */
    public SslElement( NioElement element,SSLEngine sslEngine ) throws Exception {
        if( sslEngine  == null ) {
            throw new IllegalArgumentException( "引数は不正です" ) ;
        }
        int recvBufferSize = sslEngine.getSession().getApplicationBufferSize() ;
        this.defaultSize = ( int )( sslEngine.getSession().getPacketBufferSize() * 1.5f ) ;
        this.recvBuffer = ByteBuffer.allocate( recvBufferSize ) ;
        this.netInBuffer = ByteBuffer.allocate( defaultSize ) ;
        this.netOutBuffer = ByteBuffer.allocate( defaultSize ) ;
        this.netOutBuffer.limit( 0 ) ;
        this.sslEngine = sslEngine ;
        this.element = element ;
        this.isHandShakeFinish = false ;
        this.isExitHandShakeSend = false ;
        this.isClose = false ;
        this.sslEngine.beginHandshake();
        
        String[] supportCiphers = sslEngine.getSupportedCipherSuites() ;
        if( supportCiphers != null ) {
            sslEngine.setEnabledCipherSuites( supportCiphers ) ;
        }
    }
    
    /**
     * デストラクタ.
     */
    protected void finalize() throws Exception {
        destroy() ;
    }
    
    /**
     * オブジェクト破棄.
     */
    public synchronized void destroy() {
        if( sslEngine != null ) {
            try {
                sslEngine.closeOutbound() ;
            } catch( Exception e ) {
            }
        }
        netInBuffer = null ;
        netOutBuffer = null ;
        sslEngine = null ;
        element = null ;
        recvBuffer = null ;
        isClose = true ;
    }
    
    /**
     * オブジェクトクローズ.
     */
    public synchronized void close()
        throws Exception {
        if( element != null && element.isUse() ) {
            if( sslEngine != null ) {
                element.restartWrite() ;
                sslEngine.closeOutbound() ;
                SSLEngineResult res = SslUtil.wrap( sslEngine,SslUtil.EMPTY_BUFFER,netOutBuffer ) ;
                if( res.getStatus() == SSLEngineResult.Status.CLOSED ) {
                    writeQueue( netOutBuffer ) ;
                }
            }
        }
    }
    
    /**
     * 既にクローズ状態であるかチェック.
     * @return boolean [true]の場合、クローズしています.
     */
    public synchronized boolean isClosed() {
        return isClose ;
    }
    
    /**
     * HandShake終了チェック.
     * @return boolean [true]の場合、HandShakeが終了しています.
     */
    public synchronized boolean isHandShake() {
        return isHandShakeFinish ;
    }
    
    /**
     * ハンドシェイク送信完了の場合の呼び出し.
     */
    public synchronized void sendEndHandShake() {
        if( isHandShakeFinish == true ) {
            isExitHandShakeSend = true ;
            netInBuffer = ByteBuffer.allocate( defaultSize ) ;
            netOutBuffer = ByteBuffer.allocate( defaultSize ) ;
            netOutBuffer.limit( 0 ) ;
        }
    }
    
    /**
     * ハンドシェイク送信完了チェック.
     * @return boolean [true]の場合、ハンドシェイク処理は終了しています.
     */
    public synchronized boolean isSendEndHandShake() {
        return isExitHandShakeSend ;
    }
    
    /**
     * 受信バッファを取得.
     * @return ByteBuffer 受信バッファを取得します.
     */
    public synchronized ByteBuffer getRecvBuffer() {
        return recvBuffer ;
    }
    
    /**
     * HandShake処理.
     * @param recvBuf 受信されたByteBufferを設定します.
     * @exception Exception 例外.
     */
    public synchronized void handShake( ByteBuffer recvBuf ) throws Exception {
        if( isClose == true ) {
            throw new IOException( "オブジェクトは既にクローズしています" ) ;
        }
        // 最初の処理の場合、ハンドシェイクステータスは、unwrapにセットする.
        if( handShakeState == null ) {
            //handShakeState = HandshakeStatus.NEED_UNWRAP ;
            sslEngine.beginHandshake() ;
            this.handShakeState = sslEngine.getHandshakeStatus() ;
        }
        
        // 読み込みデータ内容が存在しない場合は処理しない.
        SSLEngineResult result = null ;
        if( isUseReadData( recvBuf ) ) {
            
            if( isHandShakeFinish == true ) {
                throw new IOException( "ハンドシェイク処理は完了しています" ) ;
            }
            
            // 読み込み処理.
            result = SslUtil.unwrap( sslEngine,recvBuf,netInBuffer ) ;
            this.handShakeState = result.getHandshakeStatus() ;
            // unwrapが正常終了するまでループ.
            while(true) {
                // ステータスOKの場合.
                if( result.getStatus() == Status.OK ) {
                    break ;
                }
                // バッファオーバフローの場合は、Inputバッファサイズを増やして
                // 受信処理を行う.
                else if( result.getStatus() == Status.BUFFER_OVERFLOW ) {
                    netInBuffer = SslUtil.reallocate( netInBuffer ) ;
                    result = SslUtil.unwrap( sslEngine,recvBuf,netInBuffer ) ;
                    this.handShakeState = result.getHandshakeStatus() ;
                }
                // その他.
                else {
                    throw new IOException( "不明なステータス[" +
                        this.handShakeState + "/"+ result.getStatus() + "]が返されました" ) ;
                }
            }
        }
        else {
            return ;
        }
        
        // ハンドシェイクイベント処理.
        boolean roopEnd = true ;
        while( roopEnd ) {
            switch( this.handShakeState ) {
                
                // wrap.
                case NEED_WRAP :
                    // 書き込み処理.
                    result = SslUtil.wrap( sslEngine,SslUtil.EMPTY_BUFFER,netOutBuffer ) ;
                    this.handShakeState = result.getHandshakeStatus() ;
                    // ステータスがOKの場合.
                    if( result.getStatus() == Status.OK ) {
                        writeQueue( netOutBuffer ) ;
                        // 書き込み後の読み込み条件の場合は、一端ハンドシェイクループを抜ける.
                        if( this.handShakeState == HandshakeStatus.NEED_UNWRAP ) {
                            roopEnd = false ;
                            break ;
                        }
                        break ;
                    }
                    throw new IOException( "不明なステータス[" +
                        this.handShakeState + "/"+ result.getStatus() + "]が返されました" ) ;
                
                // unwrap.
                case NEED_UNWRAP :
                    // 続けて読み込む場合.
                    if( isUseReadData( recvBuf ) ) {
                        for( ;; ) {
                            result = SslUtil.unwrap( sslEngine,recvBuf,netInBuffer ) ;
                            this.handShakeState = result.getHandshakeStatus() ;
                            if( result.getStatus() == Status.BUFFER_OVERFLOW ) {
                                netInBuffer = SslUtil.reallocate( netInBuffer ) ;
                                continue ;
                            }
                            break ;
                        }
                    }
                    else if( this.handShakeState == HandshakeStatus.NEED_UNWRAP ) {
                        if( isUseReadData( recvBuf ) == false ) {
                            roopEnd = false ;
                        }
                    }
                    break ;
                    
                // task.
                case NEED_TASK :
                    this.handShakeState = tasks() ;
                    break ;
                
                // finish.
                case FINISHED :
                    isHandShakeFinish = true ;
                    roopEnd = false ;
                    break ;
                
                // その他.
                default :
                    throw new IOException(
                        "不明なステータス[" + this.handShakeState + "]が返されました" ) ;
            }
        }
    }
    
    /**
     * 読み込み可能情報が存在する場合.
     */
    public synchronized boolean isUseReadData( ByteBuffer recvBuf ) {
        boolean ret = ( recvBuf.position() != 0 && recvBuf.hasRemaining() &&
            handShakeState == HandshakeStatus.NEED_UNWRAP ) ;
        if( ret == false ) {
            return ( recvBuf.position() == recvBuf.capacity() ) ;
        }
        return ret ;
    }
    
    /**
     * アプリケーションデータを読み込み.
     * @exception Exception 例外.
     */
    public synchronized boolean read( ByteBuffer recvBuf ) throws Exception {
        if( isClose == true ) {
            throw new IOException( "オブジェクトは既にクローズしています" ) ;
        }
        if( isExitHandShakeSend == false ) {
            throw new IOException( "ハンドシェイク処理が終了していません" ) ;
        }
        ReceiveBuffer emtBuf = element.getBuffer() ;
        SocketChannel channel = element.getChannel() ;
        boolean flg = false ;
        for( ;; ) {
            int len = channel.read( recvBuf ) ;
            if( len <= 0 ) {
                if( len <= -1 ) {
                    try {
                        sslEngine.closeInbound();
                    } catch (IOException ex) {
                    }
                    return false ;
                }
                break ;
            }
            readToSslConvert( emtBuf,recvBuf ) ;
            flg = true ;
        }
        if( flg == true ) {
            element.update() ;
        }
        return flg ;
    }
    
    /**
     * ハンドシェイク完了後に対する受信データ残りが存在する場合に、受信処理を行う.
     */
    public synchronized boolean readTo( ByteBuffer recvBuf ) throws Exception {
        boolean ret = false ;
        if( recvBuf.position() != 0 && recvBuf.hasRemaining() ) {
            ReceiveBuffer emtBuf = element.getBuffer() ;
            while( recvBuf.position() != 0 && recvBuf.hasRemaining() ) {
                int ln = readToSslConvert( emtBuf,recvBuf ) ;
                ret = true ;
                if( ln <= 0 ) {
                    break ;
                }
            }
        }
        return ret ;
    }
    
    /**
     * アプリケーションデータ書き込み.
     * @param buffer 書き込み対象データを設定します.
     * @exception Exception 例外.
     */
    public synchronized void write( ByteBuffer buf ) throws Exception {
        if( isClose == true ) {
            throw new IOException( "オブジェクトは既にクローズしています" ) ;
        }
        if( isExitHandShakeSend == false ) {
            throw new IOException( "ハンドシェイク処理が終了していません" ) ;
        }
        int written = 0;
        SSLEngineResult result = SslUtil.wrap( sslEngine,buf,netOutBuffer ) ;
        written = result.bytesConsumed();
        if (result.getStatus() == Status.OK) {
            if (result.getHandshakeStatus() == HandshakeStatus.NEED_TASK) {
                tasks() ;
            }
        } else {
            throw new IOException("不正なステータス[" + result.getStatus() + "]を検知しました" ) ;
        }
        
        // ソケットにデータを書き込む.
        element.getChannel().write( netOutBuffer ) ;
        
        // バッファ長と比較して、残りデータが存在する場合は、
        // 次の書き込み処理にまわす.
        if( buf.limit() != written ) {
            buf.position( written ) ;
            buf.compact() ;
            element.getWriteBuffer().appendHead( buf ) ;
        }
    }
    
    /**
     * タスク処理.
     */
    private HandshakeStatus tasks() {
        Runnable r = null ;
        while ( ( r = sslEngine.getDelegatedTask() ) != null ) {
            r.run() ;
        }
        return sslEngine.getHandshakeStatus() ;
    }
    
    /**
     * キューに書き込みデータを出力.
     */
    private void writeQueue( ByteBuffer buf ) {
        if( element != null && element.isUse() ) {
            element.getWriteBuffer().append( SslUtil.copy( netOutBuffer ) ) ;
        }
    }
    
    /**
     * SSL処理を解析して受信バッファにセットする.
     */
    private int readToSslConvert( ReceiveBuffer emtBuf,ByteBuffer recvBuf )
        throws Exception {
        SSLEngineResult unwrap ;
        netInBuffer.clear() ;
        int recv = 0 ;
        do {
            // SSLデータ解析.
            unwrap = SslUtil.unwrap( sslEngine,recvBuf,netInBuffer ) ;
            // 処理結果を判別.
            if( unwrap.getStatus() == Status.OK || unwrap.getStatus() == Status.BUFFER_UNDERFLOW ) {
                recv += unwrap.bytesProduced() ;
                // OKの場合はデータを対象スレッド宛に送信.
                if( unwrap.getStatus() == Status.OK ) {
                    // 受信データを書き込む.
                    netInBuffer.flip() ;
                    emtBuf.put( netInBuffer ) ;
                    netInBuffer.clear() ;
                }
                // タスク要求の場合.
                if( unwrap.getHandshakeStatus() == HandshakeStatus.NEED_TASK ) {
                    tasks() ;
                }
                // データ解析失敗.
                if( unwrap.getStatus() == Status.BUFFER_UNDERFLOW ) {
                    break ;
                }
            } else if( ( unwrap.getStatus()==Status.BUFFER_OVERFLOW && recv > 0 ) || unwrap.getStatus()==Status.CLOSED ) {
                break ;
            } else if( unwrap.getStatus()==Status.CLOSED ) {
                break ;
            } else {
                throw new IOException("不正なステータス[" + unwrap.getStatus() + "]を検知しました" ) ;
            }
        } while( ( recv != 0 ) ) ;
        return recv ;
    }
}

