package jp.sfjp.armadillo.io;

import java.io.*;

/**
 * A bit-by-bit writable OutputStream.
 * <p>
 * Bits are packed most-significant-bit first: the first bit written becomes
 * the highest bit of the first byte passed to the underlying stream.
 * Pending bits are held in an internal 32-bit buffer and forwarded only as
 * whole bytes; {@link #close()} pads a trailing partial byte with zero bits.
 * This class performs no synchronization.
 */
public final class BitOutputStream extends FilterOutputStream {

    private static final int INT_BITSIZE = 32;
    private static final int OUTPUT_BITSIZE = 8;
    /** Widest chunk stored directly by writeBits; wider values are split. */
    private static final int MAX_CHUNK_BITSIZE = 16;

    private boolean closed;
    // Pending bits, left-aligned: the next bit to be emitted is bit 31.
    // All bits below the pending ones are zero.
    private int buffer;
    // Number of valid pending bits in buffer (at most 31 between calls).
    private int buffered;

    /**
     * Creates a bit output stream on top of the specified stream.
     * @param out the underlying output stream
     */
    public BitOutputStream(OutputStream out) {
        super(out);
    }

    private void ensureOpen() throws IOException {
        if (closed)
            throw new IOException("stream closed");
    }

    /**
     * Writes the low-order {@code bitLength} bits of the specified int,
     * starting with the most significant of those bits.
     * Higher bits of {@code b} are ignored.
     * @param b int value whose low-order bits are written
     * @param bitLength number of bits to write (1 to 32)
     * @throws IOException if this stream is closed or an I/O error occurs
     * @throws IllegalArgumentException if bitLength is outside 1..32
     */
    public void writeBits(int b, int bitLength) throws IOException {
        ensureOpen();
        if (1 <= bitLength && bitLength <= MAX_CHUNK_BITSIZE) {
            // flushBuffer() leaves fewer than 8 pending bits, so a chunk of at
            // most 16 bits always fits into the 32-bit buffer afterwards.
            if (buffered + bitLength >= INT_BITSIZE)
                flushBuffer();
            // Left-align the chunk, then place it right after the pending bits.
            int value = b << (INT_BITSIZE - bitLength);
            value >>>= buffered;
            buffer |= value;
            buffered += bitLength;
        }
        else if (MAX_CHUNK_BITSIZE < bitLength && bitLength <= INT_BITSIZE) {
            // Upper part first so that the overall bit order is preserved.
            writeBits(b >>> MAX_CHUNK_BITSIZE, bitLength - MAX_CHUNK_BITSIZE);
            writeBits(b, MAX_CHUNK_BITSIZE);
        }
        else
            throw new IllegalArgumentException("value=" + b + ", bit length=" + bitLength);
    }

    /**
     * Writes the low-order 8 bits of the specified int.
     * @param b value whose low-order 8 bits are written
     * @throws IOException if this stream is closed or an I/O error occurs
     */
    @Override
    public void write(int b) throws IOException {
        ensureOpen();
        writeBits(b, OUTPUT_BITSIZE);
    }

    /**
     * Writes {@code len} bytes of {@code b}, starting at {@code off}, 8 bits each.
     * @param b source bytes
     * @param off start offset in {@code b}
     * @param len number of bytes to write
     * @throws IOException if this stream is closed or an I/O error occurs
     * @throws IndexOutOfBoundsException if off or len is negative, or
     *         off + len exceeds b.length; nothing is written in that case
     */
    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        ensureOpen();
        // Validate up front (as OutputStream.write specifies) so that an invalid
        // range does not leave a partially written prefix in the bit buffer.
        if (off < 0 || len < 0 || off > b.length - len)
            throw new IndexOutOfBoundsException("off=" + off + ", len=" + len + ", length=" + b.length);
        for (int i = off; i < off + len; i++)
            writeBits(b[i], OUTPUT_BITSIZE);
    }

    /**
     * Forwards every complete pending byte to the underlying stream and
     * flushes it. Fewer than 8 pending bits stay buffered until more bits
     * arrive or the stream is closed.
     * @throws IOException if this stream is closed or an I/O error occurs
     */
    @Override
    public void flush() throws IOException {
        ensureOpen();
        flushBuffer();
        super.flush();
    }

    // Emits whole bytes from the top of buffer until fewer than 8 bits remain.
    private void flushBuffer() throws IOException {
        while (buffered >= OUTPUT_BITSIZE) {
            int value = buffer >>> (INT_BITSIZE - OUTPUT_BITSIZE);
            super.write(value);
            buffer <<= OUTPUT_BITSIZE;
            buffered -= OUTPUT_BITSIZE;
        }
    }

    /**
     * Writes any pending bits (the last byte padded with zero bits), flushes,
     * and closes the underlying stream. Closing an already closed stream has
     * no effect.
     * @throws IOException if an I/O error occurs
     */
    @Override
    public void close() throws IOException {
        if (closed)
            return;
        try {
            // Round the pending bit count up to a whole byte so flushBuffer()
            // also emits the final partial byte; its unused low bits are zero.
            if (buffered > 0)
                buffered += OUTPUT_BITSIZE - 1;
            flush();
        }
        finally {
            buffer = 0;
            buffered = 0;
            closed = true;
            // Close the target directly: FilterOutputStream.close() would call
            // flush() again, which fails ensureOpen() now that closed is set,
            // and since Java 8 that failure is propagated instead of ignored.
            out.close();
        }
    }

}
