#include "convolve.h"

#include "command.h"
#include "conv.h"
#include "ref.h"
#include "script.h" /* for get_reference() */

/* Per-command state, indexed by "code".  Each successful command set up by
   check_convolve()/check_convolve2() fills slot ncodes and then advances
   ncodes; ncodes is also passed to setup_command(), presumably so the
   registered do_* callback is later invoked with that slot as its code. */

/* number of slots in use so far */
static int ncodes = 0;
/* number of data values the command operates on */
static int npoints[MAX_NCODES];
/* stride between filtered values: 2 for complex data, 1 otherwise */
static int step[MAX_NCODES];
/* filter length, 2*half_width+1 */
static int width[MAX_NCODES];
/* filter coefficients, width[code] entries */
static float *filter[MAX_NCODES];
/* width[code] entries, filled by the calculate_*_filter / read_and_sum_filter helpers */
static float *sum_filter[MAX_NCODES];
/* scratch buffer (npoints[code] entries) receiving the convolved data */
static float *convolve[MAX_NCODES];
/* cos/sin of the phase angle i*2*pi*offset, npoints[code]/2 entries each;
   only allocated for the check_convolve2() family (conv_sine_ppm) */
static float *cos_mult[MAX_NCODES];
static float *sin_mult[MAX_NCODES];

static void do_convolve(int code, float *data)
{
    int i;

    calculate_convolve(npoints[code], data, convolve[code],
		width[code], filter[code], sum_filter[code], step[code]);

/*
    for (i = 0; i < npoints[code]; i++)
	printf("%d\t%1.0f\t%1.0f\t%1.0f\n",
		i, data[i], convolve[code][i], data[i]-convolve[code][i]);
*/

    for (i = 0; i < npoints[code]; i++)
	data[i] -= convolve[code][i];
}

/* Rotate, in place, each interleaved (re, im) pair of data by the angle
   whose cos/sin are stored in cos_mult[code]/sin_mult[code].  dirn = +1
   and dirn = -1 rotate in opposite senses, so applying +1 then -1 restores
   the data up to rounding.  npoints[code]/2 pairs are processed. */
static void do_multiply(int dirn, int code, float *data)
{
    const float *cs = cos_mult[code];
    const float *sn = sin_mult[code];
    int npairs = npoints[code] / 2;
    int m;

    for (m = 0; m < npairs; m++)
    {
	float *re = data + 2*m;
	float *im = re + 1;
	float rotated_re = cs[m] * *re + dirn * sn[m] * *im;

	*im = - dirn * sn[m] * *re + cs[m] * *im;
	*re = rotated_re;
    }
}

/* Command callback for the check_convolve2() family: rotate the complex
   data by the stored phase, convolve and subtract exactly as do_convolve()
   does, then apply the inverse rotation. */
static void do_convolve2(int code, float *data)
{
    do_multiply(1, code, data);
    do_convolve(code, data);	/* convolve-and-subtract in the rotated frame */
    do_multiply(-1, code, data);
}

/* Allocate the filter arrays for slot code (width[code] floats each for
   filter and sum_filter) and the convolution scratch buffer
   (npoints[code] floats).  The only explicit return is OK, so an ERROR
   result can only come from inside the MALLOC macro (presumably it returns
   ERROR on allocation failure -- confirm in the project's macros).
   NOTE(review): previous contents of these slots are never freed; if a
   command fails after this succeeds, ncodes is not advanced and the next
   command re-allocates the same slot, which appears to leak. */
static Status check_allocation(int code)
{
    MALLOC(filter[code], float, width[code]);
    MALLOC(sum_filter[code], float, width[code]);

    MALLOC(convolve[code], float, npoints[code]);
	/* do not really need separate allocations for each code: this buffer
	   is only used transiently inside do_convolve(), so one buffer sized
	   for the largest npoints could be shared by all codes */

    return  OK;
}

/* Allocate the phase-rotation tables used by do_multiply() for slot code:
   npoints[code]/2 floats each for cos_mult and sin_mult.  As with
   check_allocation(), any ERROR result must come from the MALLOC macro. */
static Status check_allocation2(int code)
{
    /* one table entry per (re, im) pair */
    int n = npoints[code] / 2;

    MALLOC(cos_mult[code], float, n);
    MALLOC(sin_mult[code], float, n);

    return  OK;
}

/* Write "'cmd': msg" into the caller's error_msg and return ERROR from the
   enclosing function.  Only usable where an error_msg buffer is in scope
   and the function returns Status.  Wrapped in do { } while (0) so that
   "CONVOLVE_ERROR(...);" is a single statement, safe in an unbraced
   if/else (a bare { } block followed by ';' breaks a following else). */
#define  CONVOLVE_ERROR(cmd, msg) \
	 do {   sprintf(error_msg, "'%s': %s", (cmd), (msg));  return  ERROR;   } while (0)

/* Validate a convolve-family command and register it via setup_command() /
   end_command().
   param[0] points to the half width w2 (int); the filter width is 2*w2+1.
   cmd is the command text used in messages; error_msg receives a
   description whenever ERROR is returned.
   On success slot ncodes holds the width, point count, step (2 for complex
   data, 1 otherwise) and freshly allocated arrays, and ncodes is advanced;
   the caller then fills filter[ncodes-1] and sum_filter[ncodes-1].
   NOTE(review): nothing here checks ncodes < MAX_NCODES; presumably
   setup_command() (which is given ncodes) rejects a full table -- confirm. */
static Status check_convolve(Generic_ptr *param, String cmd, String error_msg)
{
    int type, npts, w2, nstep;
    Line msg;

    w2 = *((int *) param[0]);

    if (w2 < 1)
        CONVOLVE_ERROR(cmd, "half width must be >= 1");

    /* bounded: cmd may already be close to a full Line (e.g. conv_file) */
    snprintf(msg, sizeof(msg), "%s %d", cmd, w2);
    if (setup_command(&type, &npts, ncodes, msg,
					do_convolve, error_msg) == ERROR)
        return  ERROR;

    if ((type == COMPLEX_DATA) && (npts % 2))
	CONVOLVE_ERROR(cmd, "complex data but odd number of points");

    nstep = (type == COMPLEX_DATA) ? 2 : 1;

    /* Same test as (npts/nstep < 2*w2+1), but done without forming
       2*w2+1, which overflows int for a huge user-supplied w2. */
    if (w2 > (npts/nstep - 1) / 2)
	CONVOLVE_ERROR(cmd, "half width too large for given number of points");

    width[ncodes] = 2*w2+1;
    npoints[ncodes] = npts;
    step[ncodes] = nstep;

    if (check_allocation(ncodes) == ERROR)
	CONVOLVE_ERROR(cmd, "allocating memory");

    CHECK_STATUS(end_command(type, npts, cmd, error_msg));

    ncodes++;

    return  OK;
}

/* Fill cos_mult[code] and sin_mult[code] with cos/sin of the angle
   i * 2*pi*offset for i = 0 .. npoints[code]/2 - 1, i.e. offset is in
   cycles per (re, im) pair.  The angle is computed in double precision:
   forming i*offset in float loses accuracy as i grows, which appears as a
   phase error that increases towards the end of the data. */
static void setup_multiplier(int code, float offset)
{
    int i, n = npoints[code] / 2;
    double radians_per_point = TWOPI * (double) offset;
    double angle;

    for (i = 0; i < n; i++)
    {
        angle = i * radians_per_point;
        cos_mult[code][i] = (float) cos(angle);
        sin_mult[code][i] = (float) sin(angle);
    }
}

/* Validate and register a phase-rotated convolve command (complex data
   only), registering do_convolve2() as the callback.
   param[0] points to the half width w2 (int), param[1] to the offset
   (float), which is converted with fractional_ref_offset(ref_type, ...)
   against the current reference before the rotation tables are built.
   On success slot ncodes is filled and allocated (including the cos/sin
   tables) and ncodes is advanced; the caller fills the filter arrays.
   error_msg receives a description whenever ERROR is returned. */
static Status check_convolve2(Generic_ptr *param, String cmd, int ref_type,
							String error_msg)
{
    int type, npts, w2;
    float offset;
    Line msg;
    Ref_info *ref = get_reference();

    w2 = *((int *) param[0]);
    offset = *((float *) param[1]);

    if (w2 < 1)
        CONVOLVE_ERROR(cmd, "half width must be >= 1");

    /* bounded: a large offset or long cmd must not overrun msg */
    snprintf(msg, sizeof(msg), "%s %d %3.2f", cmd, w2, offset);
    if (setup_command(&type, &npts, ncodes, msg,
					do_convolve2, error_msg) == ERROR)
        return  ERROR;

    if (type != COMPLEX_DATA)
	CONVOLVE_ERROR(cmd, "must be complex data");

    /* do_multiply() works on npts/2 (re, im) pairs, so an odd count would
       leave an unpaired value; reject it as check_convolve() does */
    if (npts % 2)
	CONVOLVE_ERROR(cmd, "complex data but odd number of points");

    /* Same test as (npts/2 < 2*w2+1), but without forming 2*w2+1,
       which overflows int for a huge user-supplied w2. */
    if (w2 > (npts/2 - 1) / 2)
	CONVOLVE_ERROR(cmd, "half width too large for given number of points");

    width[ncodes] = 2*w2+1;
    npoints[ncodes] = npts;
    step[ncodes] = 2;

    if (check_allocation(ncodes) == ERROR)
	CONVOLVE_ERROR(cmd, "allocating memory");

    if (check_allocation2(ncodes) == ERROR)
	CONVOLVE_ERROR(cmd, "allocating memory");

    offset = fractional_ref_offset(ref_type, ref, offset);
    setup_multiplier(ncodes, offset);

    CHECK_STATUS(end_command(type, npts, cmd, error_msg));

    ncodes++;

    return  OK;
}

/* "convolve" command: param[0] -> half width.  Registers the command and
   fills the new slot with the sine filter (the same filter conv_sine uses). */
Status init_convolve(Generic_ptr *param, String error_msg)
{
    int last;

    CHECK_STATUS(check_convolve(param, "convolve", error_msg));

    last = ncodes - 1;	/* slot just claimed by check_convolve() */
    calculate_sine_filter(width[last], filter[last], sum_filter[last]);

    return  OK;
}

/* "conv_sine" command: param[0] -> half width.  Registers the command and
   fills the new slot with the sine filter. */
Status init_conv_sine(Generic_ptr *param, String error_msg)
{
    int last;

    CHECK_STATUS(check_convolve(param, "conv_sine", error_msg));

    last = ncodes - 1;	/* slot just claimed by check_convolve() */
    calculate_sine_filter(width[last], filter[last], sum_filter[last]);

    return  OK;
}

/* "conv_box" command: param[0] -> half width.  Registers the command and
   fills the new slot with the box filter. */
Status init_conv_box(Generic_ptr *param, String error_msg)
{
    int last;

    CHECK_STATUS(check_convolve(param, "conv_box", error_msg));

    last = ncodes - 1;	/* slot just claimed by check_convolve() */
    calculate_box_filter(width[last], filter[last], sum_filter[last]);

    return  OK;
}

/* "conv_triangle" command: param[0] -> half width.  Registers the command
   and fills the new slot with the triangle filter. */
Status init_conv_triangle(Generic_ptr *param, String error_msg)
{
    int last;

    CHECK_STATUS(check_convolve(param, "conv_triangle", error_msg));

    last = ncodes - 1;	/* slot just claimed by check_convolve() */
    calculate_triangle_filter(width[last], filter[last], sum_filter[last]);

    return  OK;
}

/* "conv_gaussian" command: param[0] -> half width (int), param[1] -> end
   value (float), which must lie strictly inside (0, 1).  The end value is
   checked before check_convolve(), so a rejected command never reaches
   setup_command()/end_command() and does not allocate a slot. */
Status init_conv_gaussian(Generic_ptr *param, String error_msg)
{
    Line msg;
    float end = *((float *) param[1]);

    /* written as "not inside" so that a NaN end value is rejected too */
    if (!((end > 0) && (end < 1)))
    {
	snprintf(msg, sizeof(msg), "end value = %f, must be in range (0, 1)", end);
	CONVOLVE_ERROR("conv_gaussian", msg);
    }

    CHECK_STATUS(check_convolve(param, "conv_gaussian", error_msg));

    calculate_gaussian_filter(width[ncodes-1], end, filter[ncodes-1],
							sum_filter[ncodes-1]);

    return  OK;
}

/* "conv_file" command: param[0] -> file name, param[1] -> half width.
   Registers the command, then fills the new slot's filter arrays via
   read_and_sum_filter() from that file (file format defined by that
   helper), returning its status. */
Status init_conv_file(Generic_ptr *param, String error_msg)
{
    String file = (char *) param[0];
    Line cmd;

    /* bounded: a long file name must not overrun cmd */
    snprintf(cmd, sizeof(cmd), "conv_file %s", file);

    /* +1 so that check_convolve() finds the half width at index 0 */
    CHECK_STATUS(check_convolve(param+1, cmd, error_msg));

    return  read_and_sum_filter(file, width[ncodes-1], filter[ncodes-1],
					sum_filter[ncodes-1], error_msg);
}

/* "conv_sine_ppm" command: param[0] -> half width, param[1] -> offset,
   converted using the REF_PPM reference type.  The complex data are rotated
   by the resulting phase ramp around a sine-filter convolution. */
Status init_conv_sine_ppm(Generic_ptr *param, String error_msg)
{
    int last;

    CHECK_STATUS(check_convolve2(param, "conv_sine_ppm", REF_PPM, error_msg));

    last = ncodes - 1;	/* slot just claimed by check_convolve2() */
    calculate_sine_filter(width[last], filter[last], sum_filter[last]);

    return  OK;
}

