/*
 * This software is distributed under following license based on modified BSD
 * style license.
 * ----------------------------------------------------------------------
 * 
 * Copyright 2009 The Nimbus2 Project. All rights reserved.
 * 
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions are met:
 * 
 * 1. Redistributions of source code must retain the above copyright notice,
 *    this list of conditions and the following disclaimer. 
 * 2. Redistributions in binary form must reproduce the above copyright notice,
 *    this list of conditions and the following disclaimer in the documentation
 *    and/or other materials provided with the distribution.
 * 
 * THIS SOFTWARE IS PROVIDED BY THE NIMBUS PROJECT ``AS IS'' AND ANY EXPRESS
 * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN
 * NO EVENT SHALL THE NIMBUS PROJECT OR CONTRIBUTORS BE LIABLE FOR ANY
 * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
 * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 * 
 * The views and conclusions contained in the software and documentation are
 * those of the authors and should not be interpreted as representing official
 * policies, either expressed or implied, of the Nimbus2 Project.
 */
package jp.ossc.nimbus.net;

import java.util.*;
import java.io.File;
import java.io.InputStream;
import java.io.FileInputStream;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.UndeclaredThrowableException;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.security.KeyStore;
import java.security.SecureRandom;
import java.security.Principal;
import java.security.PrivateKey;
import javax.net.ssl.SSLServerSocket;
import javax.net.ssl.SSLContext;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509KeyManager;
import java.security.cert.X509Certificate;

import jp.ossc.nimbus.beans.*;

/**
 * SSL server socket factory.<p>
 * Builds an {@link SSLContext} from a {@link KeyStore} (and, when configured, a trust store) and creates {@link SSLServerSocket}s through it.<br>
 * Properties registered beforehand with {@link #setServerSocketProperty(String, Object)} are applied to every server socket created by this factory.<br>
 * The underlying factory is built lazily on the first call that needs it; configuration changed after that point is not picked up.<br>
 *
 * @author M.Takata
 */
public class SSLServerSocketFactory extends javax.net.ssl.SSLServerSocketFactory{

    /**
     * Default secure socket protocol.<p>
     */
    public static final String DEFAULT_PROTOCOL = "TLS";

    /**
     * Default keystore type.<p>
     */
    public static final String DEFAULT_KEYSTORE_TYPE = "JKS";

    /**
     * Default algorithm passed to javax.net.ssl.KeyManagerFactory.<p>
     */
    public static final String DEFAULT_ALGORITHM = "SunX509";

    /**
     * Delegate factory obtained from the initialized {@link SSLContext}; null until {@link #init()} succeeds.
     */
    protected javax.net.ssl.SSLServerSocketFactory serverSocketFactory;

    /**
     * Properties applied to each created server socket, in registration order.
     */
    protected Map<Property, Object> serverSocketProperties;

    /**
     * Properties handed to {@code SSLServerSocketWrapper.setSocketProperty(String, Object)} for each created server socket.
     * This class declares no setter for it, so it stays null unless a subclass assigns it.
     */
    protected Map<String, Object> socketProperties;

    protected String protocol = DEFAULT_PROTOCOL;

    protected String keyAlias;
    protected String keyStoreType = DEFAULT_KEYSTORE_TYPE;
    protected String keyStoreAlgorithm = DEFAULT_ALGORITHM;
    protected String keyStoreFile = System.getProperty("user.home") + "/.keystore";
    protected String keyStorePassword = "changeit";
    protected String keyPassword = "";

    protected String trustKeyStoreType = DEFAULT_KEYSTORE_TYPE;
    protected String trustKeyStoreAlgorithm = DEFAULT_ALGORITHM;
    protected String trustKeyStoreFile;
    protected String trustKeyStorePassword;

    /**
     * Set to true once {@link #init()} has completed.
     * Declared volatile because it is read outside the synchronized {@link #init()} by the factory methods.
     */
    protected volatile boolean initialized = false;

    /**
     * Sets the secure socket protocol to use.<p>
     * The default is {@link #DEFAULT_PROTOCOL}.<br>
     *
     * @param protocol secure socket protocol name
     */
    public void setProtocol(String protocol){
        this.protocol = protocol;
    }

    /**
     * Gets the secure socket protocol to use.<p>
     *
     * @return secure socket protocol name
     */
    public String getProtocol(){
        return protocol;
    }

    /**
     * Sets the keystore type.<p>
     * The default is {@link #DEFAULT_KEYSTORE_TYPE}.<br>
     *
     * @param storeType keystore type
     */
    public void setKeyStoreType(String storeType){
        keyStoreType = storeType;
    }

    /**
     * Gets the keystore type.<p>
     *
     * @return keystore type
     */
    public String getKeyStoreType(){
        return keyStoreType;
    }

    /**
     * Sets the algorithm passed to javax.net.ssl.KeyManagerFactory.<p>
     * The default is {@link #DEFAULT_ALGORITHM}.<br>
     *
     * @param algorithm algorithm name
     */
    public void setKeyStoreAlgorithm(String algorithm){
        keyStoreAlgorithm = algorithm;
    }

    /**
     * Gets the algorithm passed to javax.net.ssl.KeyManagerFactory.<p>
     *
     * @return algorithm name
     */
    public String getKeyStoreAlgorithm(){
        return keyStoreAlgorithm;
    }

    /**
     * Sets the path of the keystore file.<p>
     * The default is ".keystore" under the directory named by the "user.home" system property.<br>
     *
     * @param path path of the keystore file
     */
    public void setKeyStoreFile(String path){
        keyStoreFile = path;
    }

    /**
     * Gets the path of the keystore file.<p>
     *
     * @return path of the keystore file
     */
    public String getKeyStoreFile(){
        return keyStoreFile;
    }

    /**
     * Sets the keystore password.<p>
     * The default is "changeit".<br>
     *
     * @param password keystore password
     */
    public void setKeyStorePassword(String password){
        keyStorePassword = password;
    }

    /**
     * Gets the keystore password.<p>
     *
     * @return keystore password
     */
    public String getKeyStorePassword(){
        return keyStorePassword;
    }

    /**
     * Sets the alias of the private key used to authenticate the server side of the secure socket.<p>
     * When no alias is set, the key managers produced by the KeyManagerFactory are used as they are, so the choice of key is left to them.<br>
     *
     * @param alias alias of the private key
     */
    public void setKeyAlias(String alias){
        this.keyAlias = alias;
    }

    /**
     * Gets the alias of the private key used to authenticate the server side of the secure socket.<p>
     *
     * @return alias of the private key
     */
    public String getKeyAlias(){
        return keyAlias;
    }

    /**
     * Sets the password used to read the private key out of the keystore.<p>
     *
     * @param password private key password
     */
    public void setKeyPassword(String password){
        keyPassword = password;
    }

    /**
     * Gets the password used to read the private key out of the keystore.<p>
     *
     * @return private key password
     */
    public String getKeyPassword(){
        return keyPassword;
    }

    /**
     * Sets the type of the keystore used as the trust store.<p>
     * The default is {@link #DEFAULT_KEYSTORE_TYPE}.<br>
     *
     * @param storeType keystore type
     */
    public void setTrustKeyStoreType(String storeType){
        trustKeyStoreType = storeType;
    }

    /**
     * Gets the type of the keystore used as the trust store.<p>
     *
     * @return keystore type
     */
    public String getTrustKeyStoreType(){
        return trustKeyStoreType;
    }

    /**
     * Sets the algorithm passed to javax.net.ssl.TrustManagerFactory.<p>
     * The default is {@link #DEFAULT_ALGORITHM}.<br>
     *
     * @param algorithm algorithm name
     */
    public void setTrustKeyStoreAlgorithm(String algorithm){
        trustKeyStoreAlgorithm = algorithm;
    }

    /**
     * Gets the algorithm passed to javax.net.ssl.TrustManagerFactory.<p>
     *
     * @return algorithm name
     */
    public String getTrustKeyStoreAlgorithm(){
        return trustKeyStoreAlgorithm;
    }

    /**
     * Sets the path of the keystore file used as the trust store.<p>
     * When not set, the system property "javax.net.ssl.trustStore" is used.<br>
     *
     * @param path path of the keystore file
     */
    public void setTrustKeyStoreFile(String path){
        trustKeyStoreFile = path;
    }

    /**
     * Gets the path of the keystore file used as the trust store.<p>
     *
     * @return path of the keystore file
     */
    public String getTrustKeyStoreFile(){
        return trustKeyStoreFile;
    }

    /**
     * Sets the password of the keystore used as the trust store.<p>
     * When not set, the system property "javax.net.ssl.trustStorePassword" is used, and failing that the keystore password.<br>
     *
     * @param password keystore password
     */
    public void setTrustKeyStorePassword(String password){
        trustKeyStorePassword = password;
    }

    /**
     * Gets the password of the keystore used as the trust store.<p>
     *
     * @return keystore password
     */
    public String getTrustKeyStorePassword(){
        return trustKeyStorePassword;
    }

    /**
     * Registers properties to be set on each created {@link SSLServerSocket}.<p>
     * A null or empty map discards every property registered so far; otherwise the entries are added to those already registered.<br>
     *
     * @param props map of property name to value
     */
    public void setServerSocketProperties(Map<String, Object> props){
        if(props == null || props.isEmpty()){
            serverSocketProperties = null;
            return;
        }
        for(Map.Entry<String, Object> entry : props.entrySet()){
            setServerSocketProperty(entry.getKey(), entry.getValue());
        }
    }

    /**
     * Registers a property to be set on each created {@link SSLServerSocket}.<p>
     *
     * @param name property name
     * @param value value to set
     */
    public void setServerSocketProperty(String name, Object value){
        if(serverSocketProperties == null){
            serverSocketProperties = new LinkedHashMap<Property, Object>();
        }
        final Property prop = PropertyFactory.createProperty(name);
        serverSocketProperties.put(prop, value);
    }

    /**
     * Gets the value registered for a property of the {@link SSLServerSocket}.<p>
     *
     * @param name property name
     * @return registered value, or null when no property with that name is registered
     */
    public Object getServerSocketProperty(String name){
        if(serverSocketProperties == null){
            return null;
        }
        for(Map.Entry<Property, Object> entry : serverSocketProperties.entrySet()){
            if(entry.getKey().getPropertyName().equals(name)){
                return entry.getValue();
            }
        }
        return null;
    }

    /**
     * Builds the {@link SSLContext} and obtains the delegate server socket factory from it.<p>
     * Does nothing when initialization has already completed.<br>
     *
     * @exception IOException when the key or trust material cannot be loaded or the context cannot be created; any non-I/O checked exception is wrapped and kept as the cause
     */
    protected synchronized void init() throws IOException{
        if(initialized){
            return;
        }
        try{
            final SSLContext context = SSLContext.getInstance(protocol);
            context.init(
                getKeyManagers(),
                getTrustManagers(),
                new SecureRandom()
            );
            serverSocketFactory = context.getServerSocketFactory();
        }catch(RuntimeException e){
            throw e;
        }catch(IOException e){
            throw e;
        }catch(Exception e){
            // Keep the original exception as the cause instead of printing it,
            // so the caller decides how to report it and no detail is lost.
            throw new IOException(e.toString(), e);
        }
        // Written last so that a reader observing true also sees serverSocketFactory.
        initialized = true;
    }

    /**
     * Creates the key managers from the configured keystore.<p>
     * When a key alias is configured, each X.509 key manager is wrapped so that the server alias is always that alias.<br>
     *
     * @return key managers to hand to {@link SSLContext#init}
     * @exception Exception when the keystore cannot be loaded, the alias is not a key entry, or the key manager factory fails
     */
    protected KeyManager[] getKeyManagers() throws Exception{
        final KeyStore store = getKeyStore();

        if(keyAlias != null && !store.isKeyEntry(keyAlias)){
            throw new IOException("KeyAlias is not entried. " + keyAlias);
        }

        final KeyManagerFactory keyManagerFactory = KeyManagerFactory.getInstance(keyStoreAlgorithm);
        keyManagerFactory.init(store, toCharArray(keyPassword));

        final KeyManager[] keyManagers = keyManagerFactory.getKeyManagers();
        if(keyAlias != null){
            if(DEFAULT_KEYSTORE_TYPE.equals(keyStoreType)){
                // NOTE(review): the JDK's JKS implementation is understood to keep aliases lower-cased;
                // a fixed locale avoids locale-specific case mapping (e.g. the Turkish dotless i).
                keyAlias = keyAlias.toLowerCase(Locale.ENGLISH);
            }
            for(int i = 0; i < keyManagers.length; i++){
                // Only X.509 key managers can be wrapped; anything else is left untouched.
                if(keyManagers[i] instanceof X509KeyManager){
                    keyManagers[i] = new X509KeyManagerWrapper((X509KeyManager)keyManagers[i], keyAlias);
                }
            }
        }

        return keyManagers;
    }

    /**
     * Creates the trust managers from the configured trust store.<p>
     *
     * @return trust managers, or null when no trust store is configured
     * @exception Exception when the trust store cannot be loaded or the trust manager factory fails
     */
    protected TrustManager[] getTrustManagers() throws Exception{
        final KeyStore trustStore = getTrustStore();
        if(trustStore == null){
            return null;
        }
        final TrustManagerFactory trustManagerFactory = TrustManagerFactory.getInstance(trustKeyStoreAlgorithm);
        trustManagerFactory.init(trustStore);
        return trustManagerFactory.getTrustManagers();
    }

    /**
     * Loads the keystore that holds the server key.<p>
     *
     * @return loaded keystore
     * @exception IOException when the keystore cannot be loaded
     */
    protected KeyStore getKeyStore() throws IOException{
        return getStore(keyStoreType, keyStoreFile, keyStorePassword);
    }

    /**
     * Loads the trust store.<p>
     * An unset file path or password is resolved from the "javax.net.ssl.trustStore" and "javax.net.ssl.trustStorePassword" system properties, and the password finally falls back to the keystore password. The resolved values are stored back into the fields.<br>
     *
     * @return loaded trust store, or null when no file path or no password could be resolved
     * @exception IOException when the trust store cannot be loaded
     */
    protected KeyStore getTrustStore() throws IOException{
        if(trustKeyStoreFile == null){
            trustKeyStoreFile = System.getProperty("javax.net.ssl.trustStore");
        }
        if(trustKeyStorePassword == null){
            trustKeyStorePassword = System.getProperty("javax.net.ssl.trustStorePassword");
        }
        if(trustKeyStorePassword == null){
            trustKeyStorePassword = keyStorePassword;
        }
        if(trustKeyStoreFile == null || trustKeyStorePassword == null){
            return null;
        }
        return getStore(
            trustKeyStoreType,
            trustKeyStoreFile,
            trustKeyStorePassword
        );
    }

    /**
     * Loads a keystore of the given type from a file.
     *
     * @param type keystore type
     * @param path path of the keystore file
     * @param password keystore password; null is passed through to {@link KeyStore#load(InputStream, char[])}
     * @return loaded keystore
     * @exception IOException when reading fails; any other failure is wrapped and kept as the cause
     */
    private KeyStore getStore(
        String type,
        String path,
        String password
    ) throws IOException{
        InputStream is = null;
        try{
            final KeyStore keyStore = KeyStore.getInstance(type);
            is = new FileInputStream(new File(path));
            keyStore.load(is, toCharArray(password));
            // Close explicitly on the success path so that a failing close is reported.
            is.close();
            is = null;
            return keyStore;
        }catch(IOException e){
            throw e;
        }catch(Exception e){
            throw new IOException(
                "Exception trying to load keystore " + path
                    + " : " + e.toString(),
                e
            );
        }finally{
            if(is != null){
                try{
                    is.close();
                }catch(IOException ignored){
                    // Already failing with the primary exception; nothing more to do.
                }
            }
        }
    }

    /**
     * Converts a password to a char array, mapping null to null.
     */
    private static char[] toCharArray(String password){
        return password == null ? null : password.toCharArray();
    }

    /**
     * Returns the delegate factory, initializing this factory first when necessary.
     */
    private javax.net.ssl.SSLServerSocketFactory delegate() throws IOException{
        if(!initialized){
            init();
        }
        return serverSocketFactory;
    }

    /**
     * Wraps a server socket created by the delegate and applies the registered properties to it.
     */
    private ServerSocket wrap(ServerSocket created) throws IOException{
        return applyServerSocketProperties(
            new SSLServerSocketWrapper((SSLServerSocket)created)
        );
    }

    /**
     * Creates an unbound SSL server socket with the registered properties applied.
     *
     * @return created server socket
     * @exception IOException when initialization or socket creation fails
     */
    @Override
    public ServerSocket createServerSocket() throws IOException{
        return wrap(delegate().createServerSocket());
    }

    /**
     * Creates an SSL server socket bound to the given port with the registered properties applied.
     *
     * @param port port to listen on
     * @return created server socket
     * @exception IOException when initialization or socket creation fails
     */
    @Override
    public ServerSocket createServerSocket(int port) throws IOException{
        return wrap(delegate().createServerSocket(port));
    }

    /**
     * Creates an SSL server socket bound to the given port with the given backlog and the registered properties applied.
     *
     * @param port port to listen on
     * @param backlog connection backlog passed to the delegate factory
     * @return created server socket
     * @exception IOException when initialization or socket creation fails
     */
    @Override
    public ServerSocket createServerSocket(int port, int backlog) throws IOException{
        return wrap(delegate().createServerSocket(port, backlog));
    }

    /**
     * Creates an SSL server socket bound to the given address and port with the registered properties applied.
     *
     * @param port port to listen on
     * @param backlog connection backlog passed to the delegate factory
     * @param bindAddr local address to bind to
     * @return created server socket
     * @exception IOException when initialization or socket creation fails
     */
    @Override
    public ServerSocket createServerSocket(int port, int backlog, InetAddress bindAddr) throws IOException{
        return wrap(delegate().createServerSocket(port, backlog, bindAddr));
    }

    /**
     * Gets the cipher suites enabled by default on the delegate factory.
     *
     * @return default cipher suites, or an empty array when initialization fails with an IOException
     */
    @Override
    public String[] getDefaultCipherSuites(){
        try{
            return delegate().getDefaultCipherSuites();
        }catch(IOException e){
            // This method cannot throw a checked exception; report "none" instead.
            return new String[0];
        }
    }

    /**
     * Gets the cipher suites supported by the delegate factory.
     *
     * @return supported cipher suites, or an empty array when initialization fails with an IOException
     */
    @Override
    public String[] getSupportedCipherSuites(){
        try{
            return delegate().getSupportedCipherSuites();
        }catch(IOException e){
            // This method cannot throw a checked exception; report "none" instead.
            return new String[0];
        }
    }

    /**
     * Applies the registered socket properties and server socket properties to a newly created server socket.<p>
     * When applying a property fails, the server socket is closed before the exception propagates, because the caller never receives a reference it could close.<br>
     *
     * @param serverSocket newly created server socket
     * @return the same server socket
     * @exception IOException when a property setter throws an IOException
     */
    protected ServerSocket applyServerSocketProperties(
        SSLServerSocketWrapper serverSocket
    ) throws IOException{
        boolean applied = false;
        try{
            if(socketProperties != null && socketProperties.size() != 0){
                for(Map.Entry<String, Object> entry : socketProperties.entrySet()){
                    serverSocket.setSocketProperty(
                        entry.getKey(),
                        entry.getValue()
                    );
                }
            }
            if(serverSocketProperties != null && serverSocketProperties.size() != 0){
                for(Map.Entry<Property, Object> entry : serverSocketProperties.entrySet()){
                    entry.getKey().setProperty(serverSocket, entry.getValue());
                }
            }
            applied = true;
        }catch(InvocationTargetException e){
            // Unwrap the reflective wrapper so callers see the setter's own exception.
            final Throwable target = e.getTargetException();
            if(target instanceof IOException){
                throw (IOException)target;
            }else if(target instanceof RuntimeException){
                throw (RuntimeException)target;
            }else if(target instanceof Error){
                throw (Error)target;
            }else{
                throw new UndeclaredThrowableException(target);
            }
        }catch(NoSuchPropertyException e){
            throw new UndeclaredThrowableException(e);
        }finally{
            if(!applied){
                closeQuietly(serverSocket);
            }
        }
        return serverSocket;
    }

    /**
     * Closes a server socket, ignoring an IOException raised by the close itself.
     */
    private static void closeQuietly(ServerSocket socket){
        try{
            socket.close();
        }catch(IOException ignored){
            // A more relevant exception is already propagating.
        }
    }

    /**
     * Key manager that delegates to another {@link X509KeyManager} but always answers a fixed alias when asked to choose the server alias.
     */
    private static class X509KeyManagerWrapper implements X509KeyManager{

        private final X509KeyManager keyManager;
        private final String serverKeyAlias;

        /**
         * @param mgr key manager to delegate to
         * @param serverKeyAlias alias returned from {@link #chooseServerAlias}
         */
        public X509KeyManagerWrapper(X509KeyManager mgr, String serverKeyAlias){
            keyManager = mgr;
            this.serverKeyAlias = serverKeyAlias;
        }

        @Override
        public String chooseClientAlias(
            String[] keyType,
            Principal[] issuers,
            Socket socket
        ){
            return keyManager.chooseClientAlias(keyType, issuers, socket);
        }

        /**
         * Returns the configured alias regardless of the arguments.
         */
        @Override
        public String chooseServerAlias(
            String keyType,
            Principal[] issuers,
            Socket socket
        ){
            return serverKeyAlias;
        }

        @Override
        public X509Certificate[] getCertificateChain(String alias){
            return keyManager.getCertificateChain(alias);
        }

        @Override
        public String[] getClientAliases(String keyType, Principal[] issuers){
            return keyManager.getClientAliases(keyType, issuers);
        }

        @Override
        public String[] getServerAliases(String keyType, Principal[] issuers){
            return keyManager.getServerAliases(keyType, issuers);
        }

        @Override
        public PrivateKey getPrivateKey(String alias){
            return keyManager.getPrivateKey(alias);
        }
    }
}