#include "baseline.h"

#include "base.h"
#include "command.h"

/* Format "'<base_msg>': <msg>" into error_msg and return ERROR from the
   enclosing function.  Expects variables named base_msg and error_msg to be
   in scope at the point of use.  Wrapped in do { } while (0) so the macro
   behaves as a single statement (e.g. it is safe before an 'else').
   NOTE(review): the size of error_msg is not visible here; a long base_msg
   plus msg could overflow it -- confirm against the caller's buffer. */
#define  RETURN_BASE_ERROR(msg) \
	 do {   sprintf(error_msg, "'%s': %s", base_msg, (msg));  return  ERROR;   } while (0)

/* Unless the baseline points for this code were supplied by a 'base_points'
   file (have_base_points[code]), search for them with find_baseline().  If
   the search fails, return from the enclosing (void) function so no fit is
   done.  Expects variables named code and data to be in scope.  Wrapped in
   do { } while (0) so it behaves as a single statement. */
#define  FIND_BASELINE_POINTS \
	 do {   if (!(have_base_points[code])) \
	     {   if (!find_baseline(npoints[code], width[code], \
			avg_min_chisq+code, nchisq+code, data+first[code], \
					nbaseline+code, baseline+code)) \
	     return;   }   } while (0)

/* Index of the next free per-command slot in the arrays below; it is passed
   to setup_command() as the command's code. */
static int ncodes = 0;

/* Part of the row handled by each code. */
static int first[MAX_NCODES];		/* 0-based offset of the fitted range */
static int npoints[MAX_NCODES];		/* #points in the fitted range */
static int npoints_orig[MAX_NCODES];	/* #points in the whole row */

/* Fit settings, plus running statistics handed to find_baseline()
   (both statistics are zeroed by init_base). */
static int width[MAX_NCODES];		/* window width, 2*w2 + 1 */
static int order[MAX_NCODES];		/* order argument given to the fits */
static float avg_min_chisq[MAX_NCODES];
static int nchisq[MAX_NCODES];

/* Baseline points.  base_points_code is -1 outside a 'base_points' ...
   'end_base_points' block, otherwise the code of the open 'base_points'
   command; nbase_points is the row size that command was set up with. */
static int base_points_code = -1;
static int nbase_points;
static int nbaseline[MAX_NCODES];	/* #baseline points per code */
static int *baseline[MAX_NCODES];	/* point list per code (presumably indices) */
static Bool have_base_points[MAX_NCODES];	/* TRUE: points came from a file */

/* Values read by 'base_subtract', per code. */
static float *base_values[MAX_NCODES];

/* Per-row routine for a constant baseline on command slot code.  Unless the
   baseline points were supplied by a 'base_points' file, find_baseline() is
   called on every invocation to locate them; if that fails the routine
   returns without calling fit_const_baseline().  The fit works on the range
   starting at data+first[code] with npoints[code] points.
   (The former static call counter was never read and has been removed.) */
static void do_base_const(int code, float *data)
{
    FIND_BASELINE_POINTS;

    fit_const_baseline(npoints[code],
			data+first[code], nbaseline[code], baseline[code]);
}

static void do_base_poly(int code, float *data)
{
    FIND_BASELINE_POINTS;

    fit_poly_baseline(npoints[code], order[code],
			data+first[code], nbaseline[code], baseline[code]);
}

static void do_base_trig(int code, float *data)
{
    FIND_BASELINE_POINTS;

    fit_trig_baseline(npoints[code], npoints_orig[code], order[code],
			data+first[code], nbaseline[code], baseline[code]);
}

/* Allocate space for count baseline points in the slot of the command being
   set up (baseline[ncodes]).  Callers test the result against ERROR;
   presumably MALLOC returns ERROR from here on failure -- confirm. */
static Status alloc_base_pnts_memory(int count)
{
    int slot = ncodes;

    MALLOC(baseline[slot], int, count);

    return  OK;
}

/* Common set-up for the base_const / base_poly / base_trig commands.

   base_msg      command text; used as the prefix of error messages
   do_base_proc  per-row routine registered through setup_command()
   have_range    if TRUE only points f..l (1-based, inclusive) are fitted
   w2            half width; width = 2*w2 + 1 (w2 and width are only
                 checked when no 'base_points' block is open)
   o             order passed to the fit routines, must be >= 1
   error_msg     receives the message when ERROR is returned

   Inside a 'base_points' block the baseline points come from the file
   (restricted to the range when one is given) instead of being searched
   for.  On success the current code slot is claimed and OK is returned. */
static Status init_base(String base_msg, Do_process do_base_proc,
	Bool have_range, int w2, int o, int f, int l, String error_msg)
{
    int code = ncodes;
    int from_file = (base_points_code >= 0);
    int type, npts;
    int start, count;

    if (!from_file && (w2 < 1))
	RETURN_BASE_ERROR("half width must be >= 1");

    if (setup_command(&type, &npts, code, base_msg,
					do_base_proc, error_msg) == ERROR)
	return  ERROR;

    if (type == COMPLEX_DATA)
	RETURN_BASE_ERROR("data must be real");

    if (o < 1)
	RETURN_BASE_ERROR("order must be >= 1");

    /* work out which part of the row is fitted (whole row by default) */
    start = 0;
    count = npts;

    if (have_range)
    {
	if (f < 1)
	    RETURN_BASE_ERROR("first point must be >= 1");

	if (l < f)
	    RETURN_BASE_ERROR("last point must be >= first point");

	if (l > npts)
	    RETURN_BASE_ERROR("last point must be <= #points");

	start = f - 1;
	count = l - f + 1;
    }

    first[code] = start;
    npoints[code] = count;

    if (from_file)
    {
	int src = base_points_code;

	if (!have_range)
	{
	    /* share the file's point list as it is */
	    nbaseline[code] = nbaseline[src];
	    baseline[code] = baseline[src];
	}
	else
	{
	    /* the file's point count bounds the size of the subset */
	    if (alloc_base_pnts_memory(nbaseline[src]) == ERROR)
		RETURN_BASE_ERROR("allocating memory");

	    range_base_points(start, l, nbaseline[src], baseline[src],
					nbaseline+code, baseline[code]);
	}

	if (nbaseline[code] < o)
	{
	    sprintf(error_msg, 
	      "'%s': from file have %d baseline points, must have at least %d",
					base_msg, nbaseline[code], o);
	    return  ERROR;
	}

	if (npts != nbase_points)
	    RETURN_BASE_ERROR("inconsistent #points compared to 'base_points'");
    }

    npoints_orig[code] = npts;
    width[code] = 2*w2 + 1;
    order[code] = o;

    avg_min_chisq[code] = 0;
    nchisq[code] = 0;

    have_base_points[code] = from_file ? TRUE : FALSE;

    if (!from_file && (count < width[code]))
	RETURN_BASE_ERROR("width too large for given #points");

    if (count < o)
	RETURN_BASE_ERROR("order too large for given #points");

    if (alloc_baseline(count, o) == ERROR)
	RETURN_BASE_ERROR("allocating memory");

    CHECK_STATUS(end_command(type, npts, base_msg, error_msg));

    ncodes++;

    return  OK;
}

/* 'base_const w2': constant baseline over the whole row (order 1). */
Status init_base_const(Generic_ptr *param, String error_msg)
{
    int half_width = *((int *) param[0]);
    Line msg;

    sprintf(msg, "base_const %d", half_width);

    return  init_base(msg, do_base_const, FALSE,
			half_width, 1, 0, 0, error_msg);
}

/* 'base_const2 w2 f l': constant baseline (order 1) fitted only over
   points f..l (1-based, inclusive; validated by init_base). */
Status init_base_const2(Generic_ptr *param, String error_msg)
{
    int half_width = *((int *) param[0]);
    int first_pt = *((int *) param[1]);
    int last_pt = *((int *) param[2]);
    Line msg;

    sprintf(msg, "base_const2 %d %d %d", half_width, first_pt, last_pt);

    return  init_base(msg, do_base_const, TRUE,
			half_width, 1, first_pt, last_pt, error_msg);
}

/* 'base_poly w2 degree': polynomial baseline over the whole row.  The order
   handed to init_base is degree + 1.  When that order is not above 1 the
   constant-baseline routine is used instead (init_base itself rejects an
   order below 1, i.e. a negative degree). */
Status init_base_poly(Generic_ptr *param, String error_msg)
{
    int half_width = *((int *) param[0]);
    int degree = *((int *) param[1]);
    int fit_order = degree + 1;
    Line msg;

    sprintf(msg, "base_poly %d %d", half_width, degree);

    return  init_base(msg, (fit_order > 1) ? do_base_poly : do_base_const,
			FALSE, half_width, fit_order, 0, 0, error_msg);
}

/* 'base_poly2 w2 degree f l': as base_poly, but fitted only over points
   f..l (1-based, inclusive; validated by init_base). */
Status init_base_poly2(Generic_ptr *param, String error_msg)
{
    int half_width = *((int *) param[0]);
    int degree = *((int *) param[1]);
    int first_pt = *((int *) param[2]);
    int last_pt = *((int *) param[3]);
    int fit_order = degree + 1;
    Line msg;

    sprintf(msg, "base_poly2 %d %d %d %d",
			half_width, degree, first_pt, last_pt);

    return  init_base(msg, (fit_order > 1) ? do_base_poly : do_base_const,
		TRUE, half_width, fit_order, first_pt, last_pt, error_msg);
}

/* 'base_trig w2 n': trigonometric baseline over the whole row.  The order
   handed to init_base is 2*n + 1.  When that order is not above 1 the
   constant-baseline routine is used instead (init_base rejects an order
   below 1). */
Status init_base_trig(Generic_ptr *param, String error_msg)
{
    int half_width = *((int *) param[0]);
    int param_order = *((int *) param[1]);
    int fit_order = 2*param_order + 1;
    Line msg;

    sprintf(msg, "base_trig %d %d", half_width, param_order);

    return  init_base(msg, (fit_order > 1) ? do_base_trig : do_base_const,
			FALSE, half_width, fit_order, 0, 0, error_msg);
}

/* 'base_trig2 w2 n f l': as base_trig, but fitted only over points f..l
   (1-based, inclusive; validated by init_base). */
Status init_base_trig2(Generic_ptr *param, String error_msg)
{
    int half_width = *((int *) param[0]);
    int param_order = *((int *) param[1]);
    int first_pt = *((int *) param[2]);
    int last_pt = *((int *) param[3]);
    int fit_order = 2*param_order + 1;
    Line msg;

    sprintf(msg, "base_trig2 %d %d %d %d",
			half_width, param_order, first_pt, last_pt);

    return  init_base(msg, (fit_order > 1) ? do_base_trig : do_base_const,
		TRUE, half_width, fit_order, first_pt, last_pt, error_msg);
}

/* Open a 'base_points' block by reading baseline points from file into the
   current code slot.  Until 'end_base_points', init_base uses these points
   instead of searching for its own.  The command registers do_nothing as
   its per-row routine.  Fails if a block is already open. */
static Status init_base_pnts(String file, String base_msg, String error_msg)
{
    int code = ncodes;
    int type, npts;
    Line read_msg;

    /* blocks cannot be nested */
    if (base_points_code >= 0)
	RETURN_BASE_ERROR("missing previous 'end_base_points'");

    if (setup_command(&type, &npts, code, base_msg,
					do_nothing, error_msg) == ERROR)
	return  ERROR;

    if (type == COMPLEX_DATA)
	RETURN_BASE_ERROR("data must be real");

    /* buffer sized for npts points */
    if (alloc_base_pnts_memory(npts) == ERROR)
	RETURN_BASE_ERROR("allocating memory");

    if (read_base_points(file, npts, nbaseline+code, baseline[code],
							read_msg) == ERROR)
	RETURN_BASE_ERROR(read_msg);

    CHECK_STATUS(end_command(type, npts, base_msg, error_msg));

    /* record where the block's points live and the row size they belong to */
    nbase_points = npts;
    base_points_code = code;

    ncodes++;

    return  OK;
}

/* 'base_points file': open a 'base_points' block using the baseline points
   read from file.  The full file name is passed on.  In the command text
   (msg) the name is truncated if necessary, so that a long path cannot
   overflow the fixed-size buffer.  Assumes Line is a char array type, as
   the local declaration requires. */
Status init_base_points(Generic_ptr *param, String error_msg)
{
    String file = (char *) param[0];
    Line msg;
    /* chars left for the file name after the prefix and the NUL
       (sizeof of the literal already counts the NUL) */
    int room = (int) sizeof(msg) - (int) sizeof("base_points ");

    if (room < 0)
	room = 0;

    sprintf(msg, "base_points %.*s", room, file);

    return  init_base_pnts(file, msg, error_msg);
}

/* 'end_base_points': close the open 'base_points' block.  After this,
   init_base marks new commands as searching for their own baseline points.
   Fails if no block is open.  param is not used. */
Status init_end_base_points(Generic_ptr *param, String error_msg)
{
    int block_open = (base_points_code != -1);

    if (!block_open)
	RETURN_ERROR_MSG("'end_base_points': missing previous 'base_points'");

    base_points_code = -1;	/* leave the block */

    return  OK;
}

/* Allocate base_values[ncodes] with room for npoints[ncodes] floats for the
   command being set up.  Callers test the result against ERROR; presumably
   MALLOC returns ERROR from here on failure -- confirm.  Declared with
   (void) so the definition is a prototype and stray arguments are rejected. */
static Status alloc_base_sub_memory(void)
{
    MALLOC(base_values[ncodes], float, npoints[ncodes]);

    return  OK;
}

/* Per-row routine for 'base_subtract': subtract base_values[code] in place
   from the npoints[code] points starting at data+first[code]. */
static void do_base_sub(int code, float *data)
{
    float *segment = data + first[code];

    SUBTRACT_VECTORS(segment, segment, base_values[code], npoints[code]);
}

/* Common set-up for 'base_subtract' / 'base_subtract2'.  Reads
   npoints[ncodes] values from file and registers do_base_sub, which
   subtracts them from the selected range of each row.  Not allowed inside a
   'base_points' block.

   have_range  if TRUE only points f..l (1-based, inclusive) are used,
               otherwise the whole row
   Returns OK on success, otherwise ERROR with error_msg set. */
static Status init_base_sub(String file, String base_msg, Bool have_range,
						int f, int l, String error_msg)
{
    int type, npts;
    Line msg;

    if (base_points_code >= 0)
	RETURN_BASE_ERROR("cannot use this with base_point file");

    if (setup_command(&type, &npts, ncodes, base_msg,
					do_base_sub, error_msg) == ERROR)
        return  ERROR;

    if (type == COMPLEX_DATA)
	RETURN_BASE_ERROR("data must be real");

    if (have_range)
    {
	if (f < 1)
	    RETURN_BASE_ERROR("first point must be >= 1");

	if (l < f)
	    RETURN_BASE_ERROR("last point must be >= first point");

	if (l > npts)
	    RETURN_BASE_ERROR("last point must be <= #points");

	npoints[ncodes] = l - f + 1;
	first[ncodes] = f - 1;
    }
    else
    {
	npoints[ncodes] = npts;
	first[ncodes] = 0;
    }

    if (alloc_base_sub_memory() == ERROR)
	RETURN_BASE_ERROR("allocating memory");

    if (read_base_values(file, npoints[ncodes], base_values[ncodes],
								msg) == ERROR)
	RETURN_BASE_ERROR(msg);

    CHECK_STATUS(end_command(type, npts, base_msg, error_msg));

    /* Claim this code slot, as init_base and init_base_pnts do.  Without
       this, the next command reuses the slot and overwrites the
       npoints/first (and base_values) entries that do_base_sub reads. */
    ncodes++;

    return  OK;
}

/* 'base_subtract file': subtract the values read from file from every row.
   The full file name is passed on.  In the command text (msg) the name is
   truncated if necessary, so that a long path cannot overflow the
   fixed-size buffer.  Assumes Line is a char array type, as the local
   declaration requires. */
Status init_base_subtract(Generic_ptr *param, String error_msg)
{
    String file = (char *) param[0];
    Line msg;
    /* chars left for the file name after the prefix and the NUL
       (sizeof of the literal already counts the NUL) */
    int room = (int) sizeof(msg) - (int) sizeof("base_subtract ");

    if (room < 0)
	room = 0;

    sprintf(msg, "base_subtract %.*s", room, file);

    return  init_base_sub(file, msg, FALSE, 0, 0, error_msg);
}

/* 'base_subtract2 file f l': as base_subtract, but only points f..l
   (1-based, inclusive) are used.  In the command text (msg) the file name
   is truncated if necessary, so that it and both integers fit the
   fixed-size buffer.  Assumes Line is a char array type, as the local
   declaration requires. */
Status init_base_subtract2(Generic_ptr *param, String error_msg)
{
    String file = (char *) param[0];
    int f = *((int *) param[1]);
    int l = *((int *) param[2]);
    Line msg;
    /* worst case for each " %d": a space, a sign and at most 3 decimal
       digits per byte of int (ample for 8-bit bytes) */
    int int_room = 2 + 3 * (int) sizeof(int);
    /* chars left for the file name after the prefix, both integers and
       the NUL (sizeof of the literal already counts the NUL) */
    int room = (int) sizeof(msg) - (int) sizeof("base_subtract2 ")
							- 2*int_room;

    if (room < 0)
	room = 0;

    sprintf(msg, "base_subtract2 %.*s %d %d", room, file, f, l);

    return  init_base_sub(file, msg, TRUE, f, l, error_msg);
}
