/*
 * Copyright 2011 BitMeister Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package jp.bitmeister.asn1.codec.ber;

import java.io.IOException;
import java.io.InputStream;
import java.math.BigInteger;
import java.util.ArrayList;
import java.util.List;

import jp.bitmeister.asn1.codec.ASN1Decoder;
import jp.bitmeister.asn1.exception.ASN1DecodingException;
import jp.bitmeister.asn1.processor.ASN1Visitor;
import jp.bitmeister.asn1.type.ASN1Module;
import jp.bitmeister.asn1.type.ASN1ModuleManager;
import jp.bitmeister.asn1.type.ASN1TagClass;
import jp.bitmeister.asn1.type.ASN1TagMode;
import jp.bitmeister.asn1.type.ASN1TagValue;
import jp.bitmeister.asn1.type.ASN1Type;
import jp.bitmeister.asn1.type.CollectionType;
import jp.bitmeister.asn1.type.Concatenatable;
import jp.bitmeister.asn1.type.ElementSpecification;
import jp.bitmeister.asn1.type.NamedTypeSpecification;
import jp.bitmeister.asn1.type.StringType;
import jp.bitmeister.asn1.type.StructuredType;
import jp.bitmeister.asn1.type.TimeType;
import jp.bitmeister.asn1.type.TypeSpecification;
import jp.bitmeister.asn1.type.UnknownType;
import jp.bitmeister.asn1.type.builtin.ANY;
import jp.bitmeister.asn1.type.builtin.BIT_STRING;
import jp.bitmeister.asn1.type.builtin.BOOLEAN;
import jp.bitmeister.asn1.type.builtin.CHOICE;
import jp.bitmeister.asn1.type.builtin.ENUMERATED;
import jp.bitmeister.asn1.type.builtin.INTEGER;
import jp.bitmeister.asn1.type.builtin.NULL;
import jp.bitmeister.asn1.type.builtin.OBJECT_IDENTIFIER;
import jp.bitmeister.asn1.type.builtin.OCTET_STRING;
import jp.bitmeister.asn1.type.builtin.REAL;
import jp.bitmeister.asn1.type.builtin.RELATIVE_OID;
import jp.bitmeister.asn1.type.builtin.SEQUENCE;
import jp.bitmeister.asn1.type.builtin.SEQUENCE_OF;
import jp.bitmeister.asn1.type.builtin.SET;
import jp.bitmeister.asn1.type.builtin.SET_OF;

/**
 * BER (Basic Encoding Rules) decoder.
 * 
 * <p>
 * {@code BerDecoder} is an implementation of {@code ASN1Decoder}. It reads a
 * number of bytes from an {@code InputStream} that is specified when a decoder
 * is instantiated, and decodes them to an ASN.1 data with Basic Encoding Rules
 * (BER).
 * </p>
 * <p>
 * Only the definite form of length octets is interpreted. A length octet of
 * {@code 0x80} (the indefinite form) is read as a length of zero, so
 * indefinite-length encodings are not supported.
 * </p>
 * <p>
 * A decoder keeps the state of the element being parsed (the last tag read and
 * the number of bytes consumed) in unsynchronized instance fields, so an
 * instance should not be shared between threads.
 * </p>
 * 
 * @author WATANABE, Jun. <jwat at bitmeister.jp>
 * 
 * @see ASN1Decoder
 * @see DerEncoder
 */
public class BerDecoder implements ASN1Decoder,
		ASN1Visitor<Void, ASN1DecodingException> {

	/**
	 * The module used to map a decoded tag to an ASN.1 type when the caller
	 * does not supply the type to decode.
	 */
	private Class<? extends ASN1Module> module;

	/** The stream the encoded octets are read from. */
	private InputStream in;

	/** The tag number of the most recently read identifier octets. */
	private int tagNumber;

	/** The tag class of the most recently read identifier octets. */
	private ASN1TagClass tagClass;

	/** Whether the most recently read identifier octets had the constructed bit set. */
	private boolean isConstructed;

	/** The total number of bytes consumed from {@link #in} so far. */
	private int count;

	/**
	 * Instantiates a {@code BerDecoder}.
	 * 
	 * @param in
	 *            The {@code InputStream} to be read.
	 */
	public BerDecoder(InputStream in) {
		this.in = in;
	}

	/**
	 * Instantiates a {@code BerDecoder}.
	 * 
	 * @param module
	 *            The ASN.1 module used for decoding.
	 * @param in
	 *            The {@code InputStream} to be read.
	 */
	public BerDecoder(Class<? extends ASN1Module> module, InputStream in) {
		this(in);
		this.module = module;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see jp.bitmeister.asn1.codec.ASN1Decoder#decode()
	 */
	public ASN1Type decode() throws ASN1DecodingException {
		return decodeImpl(null);
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see jp.bitmeister.asn1.codec.ASN1Decoder#decode(java.lang.Class)
	 */
	public <T extends ASN1Type> T decode(Class<T> type)
			throws ASN1DecodingException {
		T data = ASN1Type.instantiate(type);
		if (module == null) {
			// Fall back to the module that declares the requested type.
			module = data.specification().module();
		}
		return decodeImpl(data);
	}

	/**
	 * Returns how many bytes were read from the {@code InputStream}.
	 * 
	 * @return size of read bytes.
	 */
	public int count() {
		return count;
	}

	/**
	 * Decodes the source data read from the {@code InputStream} to the ASN.1
	 * data.
	 * 
	 * @param data
	 *            The empty ASN.1 data, or {@code null} to instantiate the type
	 *            registered in {@link #module} for the decoded tag.
	 * @return A decoded ASN.1 data.
	 * @throws ASN1DecodingException
	 *             When an error occurred while the decoding process.
	 */
	@SuppressWarnings("unchecked")
	private <T extends ASN1Type> T decodeImpl(T data)
			throws ASN1DecodingException {
		readTag();
		if (data == null) {
			data = (T) ASN1ModuleManager.instantiate(module, tagClass, tagNumber);
		}
		// Walk the chain of type references. Every tagged level must match the
		// tag just read; an EXPLICIT (non-IMPLICIT) level wraps the inner
		// encoding, so its length is skipped and the inner tag is read next.
		TypeSpecification specification = data.specification();
		do {
			if (specification.tag() != null) {
				ASN1TagValue tag = specification.tag();
				if (tag.tagClass() != tagClass || tag.tagNumber() != tagNumber) {
					ASN1DecodingException ex = new ASN1DecodingException();
					ex.setMessage("Decoded tag '" + tagClass + " " + tagNumber
							+ "' does not match given type.", null,
							data.getClass(), null, null);
					throw ex;
				}
				if (tag.tagMode() == ASN1TagMode.IMPLICIT) {
					break;
				}
				readLength();
				readTag();
			}
			specification = specification.reference();
		} while (specification != null);
		try {
			data.accept(this);
		} catch (ASN1DecodingException ex) {
			throw ex;
		} catch (Exception e) {
			// Runtime failures (e.g. malformed contents) are reported uniformly.
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage("Exception thrown while decoding process.", e,
					data.getClass(), null, data);
			throw ex;
		}
		data.validate();
		return data;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.builtintype
	 * .BOOLEAN)
	 */
	public Void visit(BOOLEAN data) throws ASN1DecodingException {
		// Any non-zero first contents octet is TRUE.
		data.set(readStream(readLength())[0] != 0x00);
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .IntegerType)
	 */
	public Void visit(INTEGER data) throws ASN1DecodingException {
		// Contents octets are a big-endian two's complement integer.
		data.set(new BigInteger(readStream(readLength())));
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .EnumeratedType)
	 */
	public Void visit(ENUMERATED data) throws ASN1DecodingException {
		visit((INTEGER) data);
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .RealType)
	 */
	public Void visit(REAL data) throws ASN1DecodingException {
		byte[] octets = readStream(readLength());
		if (octets.length == 0) {
			// no contents octets.
			data.set(0.0);
		} else if ((octets[0] & 0x80) != 0) {
			// binary encoding
			data.set(decodeBinaryReal(data, octets));
		} else if ((octets[0] & 0x40) == 0) {
			// ISO6093 decimal encoding
			data.set(Double.valueOf(new String(octets, 1, octets.length - 1)));
		} else {
			// special value.
			switch (octets[0]) {
			case 0x40:
				data.set(Double.POSITIVE_INFINITY);
				break;
			case 0x41:
				data.set(Double.NEGATIVE_INFINITY);
				break;
			case 0x42:
				data.set(Double.NaN);
				break;
			case 0x43:
				data.set(-0.0);
				break;
			default:
				ASN1DecodingException ex = new ASN1DecodingException();
				ex.setMessage("'0x" + Integer.toHexString(octets[0] & 0xff)
						+ "' Invalid special value.", null, data.getClass(),
						null, null);
				throw ex;
			}
		}
		return null;
	}

	/**
	 * Decodes the contents octets of a REAL value in binary encoding. The
	 * value is {@code N * 2^F * B^E}, negated when the sign bit is set, where
	 * N is the unsigned mantissa, F the scaling factor, B the base and E the
	 * two's complement exponent.
	 * 
	 * @param data
	 *            The REAL being decoded (used for error reporting only).
	 * @param octets
	 *            The non-empty contents octets; bit 8 of the first is set.
	 * @return The decoded value.
	 * @throws ASN1DecodingException
	 *             When the base or the exponent octets are invalid.
	 */
	private double decodeBinaryReal(REAL data, byte[] octets)
			throws ASN1DecodingException {
		int sign = (octets[0] & 0x40) == 0 ? 1 : -1;
		int base;
		switch (octets[0] & 0x30) {
		case 0x00:
			base = 2;
			break;
		case 0x10:
			base = 8;
			break;
		case 0x20:
			base = 16;
			break;
		default:
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage("'0x" + Integer.toHexString(octets[0] & 0xff)
					+ "' Invalid base B' value.", null, data.getClass(),
					null, null);
			throw ex;
		}
		int scaling = (octets[0] & 0x0c) >> 2;
		int index = 1;
		// Bits 2-1: 00, 01 and 10 select a 1, 2 or 3 octet exponent; 11 means
		// the next octet holds the number of exponent octets.
		int exponentLength = (octets[0] & 0x03) + 1;
		if (exponentLength == 4 && index < octets.length) {
			exponentLength = octets[index++] & 0xff;
		}
		// An int holds at most 4 exponent octets; larger exponents would
		// overflow or underflow a double anyway.
		if (exponentLength == 0 || exponentLength > 4
				|| index + exponentLength > octets.length) {
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage("'0x" + Integer.toHexString(octets[0] & 0xff)
					+ "' Invalid exponent octets.", null, data.getClass(),
					null, null);
			throw ex;
		}
		// The first exponent octet is sign-extended: the exponent is two's
		// complement.
		int exponent = octets[index++];
		for (int i = 1; i < exponentLength; i++) {
			exponent = (exponent << 8) | (octets[index++] & 0xff);
		}
		byte[] tmp = new byte[octets.length - index];
		System.arraycopy(octets, index, tmp, 0, tmp.length);
		// The mantissa is an unsigned binary integer of any length.
		double mantissa = new BigInteger(1, tmp).doubleValue();
		return mantissa * Math.pow(2, scaling) * Math.pow(base, exponent)
				* sign;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .BitStringType)
	 */
	public Void visit(BIT_STRING data) throws ASN1DecodingException {
		if (isConstructed) {
			processConcatenatable(data);
			return null;
		}
		byte[] octets = readStream(readLength());
		if (octets.length == 0) {
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage("The contents octets of BitString shall not be empty.",
					null, data.getClass(), null, null);
			throw ex;
		}
		if (octets.length == 1) {
			if (octets[0] == 0x00) {
				data.set(new boolean[0]);
				return null;
			}
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage("'0x" + Integer.toHexString(octets[0] & 0xff)
					+ "'The initial octet of empty BitString shall be zero.",
					null, data.getClass(), null, null);
			throw ex;
		}
		// The initial octet is the number of unused bits in the final octet.
		int unusedBits = octets[0] & 0xff;
		if (unusedBits > 7) {
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage("'0x" + Integer.toHexString(unusedBits)
					+ "' The number of unused bits shall be 0 to 7.",
					null, data.getClass(), null, null);
			throw ex;
		}
		boolean[] value = new boolean[(octets.length - 1) * 8 - unusedBits];
		int mask = 0x80;
		int index = 1;
		for (int i = 0; i < value.length; i++) {
			value[i] = (octets[index] & mask) > 0;
			mask >>>= 1;
			if (mask == 0) {
				mask = 0x80;
				index++;
			}
		}
		data.set(value);
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .OctetStringType)
	 */
	public Void visit(OCTET_STRING data) throws ASN1DecodingException {
		if (isConstructed) {
			processConcatenatable(data);
		} else {
			data.set(readStream(readLength()));
		}
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .NullType)
	 */
	public Void visit(NULL data) throws ASN1DecodingException {
		if (readLength() != 0) {
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage("The contents octets shall not contain any octets.",
					null, data.getClass(), null, null);
			throw ex;
		}
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .SequenceType)
	 */
	public Void visit(final SEQUENCE data) throws ASN1DecodingException {
		final ElementSpecification[] elements = data.getElementTypeList();
		processMultipleOctets(readLength(), new ElementProcessor() {

			// Elements appear in declaration order, so the search for the
			// next element resumes after the last matched one.
			private int index = 0;

			public void process() throws ASN1DecodingException {
				readTag();
				for (; index < elements.length; index++) {
					if (elements[index].matches(BerDecoder.this.tagClass,
							BerDecoder.this.tagNumber)) {
						processComponent(data, elements[index]);
						index++;
						return;
					}
				}
				// The element's contents were not consumed; continuing would
				// misparse the rest of the stream.
				ASN1DecodingException ex = new ASN1DecodingException();
				ex.setMessage("Decoded tag '" + BerDecoder.this.tagClass + " "
						+ BerDecoder.this.tagNumber
						+ "' does not match any element of the type.", null,
						data.getClass(), null, null);
				throw ex;
			}

		});
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .SequenceOfType)
	 */
	public Void visit(final SEQUENCE_OF<? extends ASN1Type> data)
			throws ASN1DecodingException {
		processCollection(data);
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .SetType)
	 */
	public Void visit(final SET data) throws ASN1DecodingException {
		processMultipleOctets(readLength(), new ElementProcessor() {

			public void process() throws ASN1DecodingException {
				readTag();
				// Elements of a SET may appear in any order.
				for (ElementSpecification e : data.getElementTypeList()) {
					if (e.matches(BerDecoder.this.tagClass,
							BerDecoder.this.tagNumber)) {
						processComponent(data, e);
						return;
					}
				}
				// The element's contents were not consumed; continuing would
				// misparse the rest of the stream.
				ASN1DecodingException ex = new ASN1DecodingException();
				ex.setMessage("Decoded tag '" + BerDecoder.this.tagClass + " "
						+ BerDecoder.this.tagNumber
						+ "' does not match any element of the type.", null,
						data.getClass(), null, null);
				throw ex;
			}

		});
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .SetOfType)
	 */
	public Void visit(SET_OF<? extends ASN1Type> data)
			throws ASN1DecodingException {
		processCollection(data);
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .ChoiceType)
	 */
	public Void visit(CHOICE data) throws ASN1DecodingException {
		// The tag already read selects the alternative.
		processComponent(data, data.alternative(tagClass, tagNumber));
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .ObjectIdentifierType)
	 */
	public Void visit(OBJECT_IDENTIFIER data) throws ASN1DecodingException {
		int length = readLength();
		if (length == 0) {
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage(
					"The contents octets of OBJECT IDENTIFIER shall not be empty.",
					null, data.getClass(), null, null);
			throw ex;
		}
		int start = count;
		List<Integer> oids = new ArrayList<Integer>();
		// The first subidentifier packs the first two arcs as (X * 40) + Y.
		// It is base-128 encoded like any other subidentifier, and Y is
		// unbounded when X is 2.
		int first = readMultipleOctets();
		if (first < 40) {
			oids.add(0);
			oids.add(first);
		} else if (first < 80) {
			oids.add(1);
			oids.add(first - 40);
		} else {
			oids.add(2);
			oids.add(first - 80);
		}
		decodeOids(oids, length - (count - start));
		data.set(oids);
		return null;
	}

	/* (non-Javadoc)
	 * @see jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type.builtin.RELATIVE_OID)
	 */
	public Void visit(RELATIVE_OID data) throws ASN1DecodingException {
		List<Integer> oids = new ArrayList<Integer>();
		decodeOids(oids, readLength());
		data.set(oids);
		return null;
	}

	/**
	 * Reads base-128 subidentifiers until {@code length} bytes are consumed,
	 * appending each one to {@code oids}.
	 * 
	 * @param oids
	 *            The list the decoded subidentifiers are appended to.
	 * @param length
	 *            The number of bytes to consume.
	 * @throws ASN1DecodingException
	 *             When an error occurred while reading the stream.
	 */
	private void decodeOids(final List<Integer> oids, int length) throws ASN1DecodingException {
		processMultipleOctets(length, new ElementProcessor() {
			public void process() throws ASN1DecodingException {
				oids.add(readMultipleOctets());
			}
		});
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .base.StringType)
	 */
	public Void visit(StringType data) throws ASN1DecodingException {
		visit((OCTET_STRING) data);
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .TimeType)
	 */
	public Void visit(TimeType data) throws ASN1DecodingException {
		data.set(new String(readStream(readLength())));
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.builtintype
	 * .ANY)
	 */
	public Void visit(ANY data) throws ASN1DecodingException {
		if (data.value() == null) {
			// Resolve the concrete type from the tag already read.
			data.set(ASN1ModuleManager.instantiate(module, tagClass, tagNumber));
		}
		data.value().accept(this);
		return null;
	}

	/*
	 * (non-Javadoc)
	 * 
	 * @see
	 * jp.bitmeister.asn1.processor.ASN1Visitor#visit(jp.bitmeister.asn1.type
	 * .UnknownType)
	 */
	public Void visit(UnknownType data) throws ASN1DecodingException {
		data.set(readStream(readLength()));
		return null;
	}

	/**
	 * Processes the {@code CollectionType} data.
	 * 
	 * @param data
	 *            The data to be processed. Any existing components are
	 *            cleared first.
	 * @throws ASN1DecodingException
	 *             When an error occurred while the decoding process.
	 */
	private <T extends ASN1Type> void processCollection(
			final CollectionType<T> data) throws ASN1DecodingException {
		if (data.size() > 0) {
			data.clear();
		}
		processMultipleOctets(readLength(), new ElementProcessor() {

			public void process() throws ASN1DecodingException {
				data.collection().add(
						decodeImpl(ASN1Type.instantiate(data.componentType())));
			}

		});
	}

	/**
	 * Processes the data which are encoded in constructed form. Each nested
	 * segment is decoded as a new instance of the same class and concatenated
	 * onto {@code data}.
	 * 
	 * @param data
	 *            The {@code Concatenatable} data.
	 * @throws ASN1DecodingException
	 *             When an error occurred while the decoding process.
	 */
	private <T extends ASN1Type & Concatenatable<T>> void processConcatenatable(
			final T data) throws ASN1DecodingException {

		processMultipleOctets(readLength(), new ElementProcessor() {

			public void process() throws ASN1DecodingException {
				@SuppressWarnings("unchecked")
				T component = (T) ASN1Type.instantiate(data.getClass());
				data.concatenate(decodeImpl(component));
			}

		});
	}

	/**
	 * Decodes an element of {@code StructuredType}. The element's tag must
	 * already have been read.
	 * 
	 * @param enclosure
	 *            The instance of enclosing type.
	 * @param namedType
	 *            The specification of the element.
	 * @throws ASN1DecodingException
	 *             When an error occurred while the process.
	 */
	private void processComponent(StructuredType enclosure,
			NamedTypeSpecification namedType) throws ASN1DecodingException {
		ASN1Type component = namedType.instantiate();
		if (namedType.tag() != null
				&& namedType.tag().tagMode() != ASN1TagMode.IMPLICIT) {
			// Explicitly tagged: skip the wrapper's length and decode the
			// inner encoding, which carries its own tag.
			readLength();
			decodeImpl(component);
		} else {
			component.accept(this);
		}
		enclosure.set(namedType, component);
	}

	/**
	 * The interface used for processing multiple elements.
	 * 
	 * @author WATANABE, Jun. <jwat at bitmeister.jp>
	 */
	private interface ElementProcessor {

		/**
		 * Processes one element, consuming its bytes from the stream.
		 * 
		 * @throws ASN1DecodingException
		 *             When an error occurred while the process.
		 */
		public void process() throws ASN1DecodingException;

	}

	/**
	 * Repeatedly invokes {@code processor} until {@code length} bytes have
	 * been consumed from the stream.
	 * 
	 * @param length
	 *            The length of octets.
	 * @param processor
	 *            The processor that processes each element.
	 * @throws ASN1DecodingException
	 *             When the {@code ElementProcessor} throws exception, or when
	 *             the elements consume more bytes than {@code length}.
	 */
	private void processMultipleOctets(int length, ElementProcessor processor)
			throws ASN1DecodingException {
		int start = this.count;
		while (this.count - start < length) {
			processor.process();
		}
		if (this.count - start > length) {
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage("length = '" + length
					+ "' Contents octets exceed the length octets.", null,
					null, null, null);
			throw ex;
		}
	}

	/**
	 * Decodes an ASN.1 tag octets and sets the result to the {@code tagClass},
	 * {@code isConstructed} and {@code tagNumber} fields.
	 * 
	 * @throws ASN1DecodingException
	 *             When an {@code IOException} thrown from the stream.
	 */
	private void readTag() throws ASN1DecodingException {
		byte[] octets = readStream(1);
		// Bits 8-7: tag class.
		switch (octets[0] & 0xc0) {
		case 0x00:
			tagClass = ASN1TagClass.UNIVERSAL;
			break;
		case 0x40:
			tagClass = ASN1TagClass.APPLICATION;
			break;
		case 0x80:
			tagClass = ASN1TagClass.CONTEXT_SPECIFIC;
			break;
		case 0xc0:
			tagClass = ASN1TagClass.PRIVATE;
			break;
		}
		// Bit 6: primitive / constructed.
		isConstructed = (octets[0] & 0x20) > 0;
		// Bits 5-1: the tag number, or all ones when it follows in
		// subsequent octets.
		if ((octets[0] & 0x1f) == 0x1f) {
			tagNumber = readMultipleOctets();
		} else {
			tagNumber = octets[0] & 0x1f;
		}
	}

	/**
	 * Decodes a length octets.
	 * 
	 * <p>
	 * The indefinite form ({@code 0x80}) is read as a length of zero.
	 * </p>
	 * 
	 * @return A length read from octets.
	 * @throws ASN1DecodingException
	 *             When an {@code IOException} thrown from the stream, when the
	 *             reserved value {@code 0xff} is read, or when the length does
	 *             not fit in an {@code int}.
	 */
	private int readLength() throws ASN1DecodingException {
		int initial = readStream(1)[0] & 0xff;
		if ((initial & 0x80) == 0) {
			// Short form: the octet itself is the length (0 to 127).
			return initial;
		}
		if (initial == 0xff) {
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage("'0xff' Reserved value of length octets.", null,
					null, null, null);
			throw ex;
		}
		// Long form: bits 7-1 give the number of subsequent length octets.
		// BER allows leading zero octets, so overflow is checked per octet.
		int length = 0;
		for (byte b : readStream(initial & 0x7f)) {
			if ((length & 0xff800000) != 0) {
				ASN1DecodingException ex = new ASN1DecodingException();
				ex.setMessage("Length octets exceed the supported range.",
						null, null, null, null);
				throw ex;
			}
			length = (length << 8) | (b & 0xff);
		}
		return length;
	}

	/**
	 * Decodes a number encoded in base-128 with the high bit of each octet
	 * as a continuation flag. Tag numbers and OID subidentifiers are encoded
	 * in this form.
	 * 
	 * @return A number read from octets.
	 * @throws ASN1DecodingException
	 *             When an {@code IOException} thrown from the stream, or when
	 *             the number does not fit in an {@code int}.
	 */
	private int readMultipleOctets() throws ASN1DecodingException {
		int result = 0;
		while (true) {
			int octet = readStream(1)[0] & 0xff;
			result |= (octet & 0x7f);
			if ((octet & 0x80) == 0) {
				return result;
			}
			// Shifting by 7 must not push bits into the sign bit.
			if (result > (Integer.MAX_VALUE >> 7)) {
				ASN1DecodingException ex = new ASN1DecodingException();
				ex.setMessage("Value encoded in multiple octets is too large.",
						null, null, null, null);
				throw ex;
			}
			result <<= 7;
		}
	}

	/**
	 * Reads exactly the specified number of bytes from the
	 * {@code InputStream}, blocking until they are available, and adds them to
	 * {@link #count}.
	 * 
	 * @param length
	 *            The size to read.
	 * @return An array of byte read from the stream.
	 * @throws ASN1DecodingException
	 *             When the stream ends before {@code length} bytes are read,
	 *             or when an {@code IOException} thrown from the stream.
	 */
	private byte[] readStream(int length) throws ASN1DecodingException {
		if (length == 0) {
			return new byte[0];
		}
		byte[] octets = new byte[length];
		int offset = 0;
		try {
			// InputStream.read may return fewer bytes than requested, so loop
			// until the buffer is full or the stream ends.
			while (offset < length) {
				int read = in.read(octets, offset, length - offset);
				if (read < 0) {
					ASN1DecodingException ex = new ASN1DecodingException();
					ex.setMessage("length = '" + length
							+ "' Incorrect length octets.", null, null, null, null);
					throw ex;
				}
				offset += read;
			}
		} catch (IOException e) {
			ASN1DecodingException ex = new ASN1DecodingException();
			ex.setMessage("IOException thrown while decoding process.", e,
					null, null, null);
			throw ex;
		}
		count += octets.length;
		return octets;
	}

}
