#include "base.h"

#include "svd.h"

/*  Tuning constants for the automatic baseline search (see find_baseline)
    and for the SVD based fits.  */

/*  minimum number of baseline points wanted, as a fraction of the number
    of data points  */
#define  BASELINE_FRACTION  0.1
/*  minimum distance wanted between the first and last baseline point, as a
    fraction of the number of data points  */
#define  BASELINE_STRETCH  0.5
/*  the initial chisq threshold is this factor times the larger of the
    supplied average and the smallest chisq found  */
#define  CHISQ_MULTIPLIER  10
/*  factor by which the chisq threshold grows after each search pass  */
#define  CHISQ_SCALER  1.5
/*  cutoff value handed to svd_fit  */
#define  SVD_CUTOFF  0.00001
/*  maximum number of times the threshold is raised before the search
    gives up  */
#define  NLOOPS_MAX  10

/*  sizes for which the work arrays below are currently allocated
    (maintained by alloc_baseline)  */
static int npoints_alloc = 0;
static int order_alloc = 0;

/*  Work arrays shared by the routines in this file; all are allocated in
    alloc_baseline.  Sizes are given in terms of npoints_alloc (n) and
    order_alloc (o).  */

/*  indices of the baseline points found by the search, size n  */
static int *base = NULL;
/*  per-point chisq values filled in by find_chisq, size n  */
static float *chisq = NULL;
/*  matrix of basis function values at the baseline points, n x o  */
static float **u1 = NULL;
/*  o x o work matrix passed to svd and svd_fit  */
static float **u2 = NULL;
/*  work vector of size o passed to svd and svd_fit  */
static float *u3 = NULL;
/*  data values at the baseline points, size n  */
static float *y = NULL;
/*  scratch vector of size o passed to svd and svd_fit  */
static float *tmp = NULL;
/*  fit coefficients produced by svd_fit, size o  */
static float *params = NULL;

/*  The baseline correction algorithm is based on ideas stated in:
    P. Guntert and K. Wuthrich, J. Mag. Reson. 96 (1992) 403 - 407.  */

/*  Make sure the file-level work arrays are big enough for npoints data
    points and order fit parameters.

    The arrays only ever grow: if either request exceeds the current
    allocation, every array is released and then reallocated using the larger
    of the requested and the current size in each dimension.  The recorded
    sizes are zeroed while reallocation is in progress and set again only
    once every allocation has been made.

    Returns OK.  The allocation macros are defined outside this file;
    presumably they return an error status themselves on failure -- confirm
    against their definition.  */

Status alloc_baseline(int npoints, int order)
{
    int new_npoints, new_order;

    /*  nothing to do when the existing arrays already suffice  */
    if ((npoints <= npoints_alloc) && (order <= order_alloc))
	return  OK;

    new_npoints = MAX(npoints, npoints_alloc);
    new_order = MAX(order, order_alloc);

    /*  release using the old dimensions, which the 2D frees need  */
    FREE(base, int);
    FREE(chisq, float);
    FREE2(u1, float, npoints_alloc);
    FREE2(u2, float, order_alloc);
    FREE(u3, float);
    FREE(y, float);
    FREE(tmp, float);
    FREE(params, float);

    npoints_alloc = 0;
    order_alloc = 0;

    MALLOC(base, int, new_npoints);
    MALLOC(chisq, float, new_npoints);
    MALLOC2(u1, float, new_npoints, new_order);
    MALLOC2(u2, float, new_order, new_order);
    MALLOC(u3, float, new_order);
    MALLOC(y, float, new_npoints);
    MALLOC(tmp, float, new_order);
    MALLOC(params, float, new_order);

    npoints_alloc = new_npoints;
    order_alloc = new_order;

    return  OK;
}

/*  Pick out the baseline points lying in the half-open range [first, last)
    and express them relative to first.

    baseline_in holds nbaseline_in point indices and is scanned from the
    front: leading entries smaller than first are skipped, then entries are
    copied (with first subtracted) until one is found that is not smaller
    than last.  The scan therefore relies on baseline_in being in ascending
    order.  The number of entries written to baseline_out is returned in
    *nbaseline_out.  */

void range_base_points(int first, int last, int nbaseline_in, int *baseline_in,
					int *nbaseline_out, int *baseline_out)
{
    int src = 0, count = 0;

    /*  skip everything before the start of the range  */
    while ((src < nbaseline_in) && (baseline_in[src] < first))
	src++;

    /*  copy, shifted, until the end of the range or of the input  */
    while ((src < nbaseline_in) && (baseline_in[src] < last))
    {
	baseline_out[count] = baseline_in[src] - first;
	count++;
	src++;
    }

    *nbaseline_out = count;
}

/*  routine below assumes that baseline allocated of size >= npoints  */

/*  Read a list of baseline point numbers from file.

    The file holds integers in the range 1 to npoints, in any order, with no
    repeats.  Reading stops at the first item that does not scan as an
    integer (normally end of file).  On success baseline[0 .. *nbaseline-1]
    holds the points as zero-based indices in ascending order and OK is
    returned.

    baseline must have room for at least npoints entries, since it is first
    used as a table with one slot per data point.

    On failure ERROR is returned and error_msg describes the problem.  */

Status read_base_points(String file, int npoints, int *nbaseline,
					int *baseline, String error_msg)
{
    int pos, nkept, point, count;
    FILE *fp;

    if (OPEN_FOR_READING(fp, file))
    {
	sprintf(error_msg, "opening '%s' for reading", file);
	return  ERROR;
    }

    /*  later messages are appended after this prefix  */
    sprintf(error_msg, "reading '%s': ", file);
    error_msg += strlen(error_msg);

    /*  slot p-1 records which entry of the file (1, 2, ...) named point p;
	0 means the point has not been seen  */
    for (pos = 0; pos < npoints; pos++)
	baseline[pos] = 0;

    count = 0;
    while (fscanf(fp, "%d", &point) == 1)
    {
	count++;

	if ((point < 1) || (point > npoints))
	{
	    fclose(fp);

	    sprintf(error_msg, "point #%d is %d, must be in range (1, %d)",
							count, point, npoints);
	    return  ERROR;
	}

	if (baseline[point-1] != 0)
	{
	    fclose(fp);

	    sprintf(error_msg, "point #%d is %d, repeats point #%d",
						count, point, baseline[point-1]);
	    return  ERROR;
	}

	baseline[point-1] = count;
    }

    fclose(fp);

    if (count == 0)
	RETURN_ERROR_MSG("no baseline points");

    /*  compact the table in place into a sorted list of zero-based indices;
	the write position never overtakes the read position  */
    nkept = 0;
    for (pos = 0; pos < npoints; pos++)
    {
	if (baseline[pos] > 0)
	    baseline[nkept++] = pos;
    }

    *nbaseline = count;

    return  OK;
}

/*  routine below assumes that base_values allocated of size >= npoints  */

/*  Read npoints floating point values from file into base_values, which
    must have room for at least npoints entries.  Anything in the file after
    the first npoints values is not looked at.

    Returns OK on success.  If the file cannot be opened, or fewer than
    npoints values can be scanned, ERROR is returned and error_msg describes
    the problem.  */

Status read_base_values(String file, int npoints, float *base_values,
							String error_msg)
{
    int idx;
    float value;
    FILE *fp;

    if (OPEN_FOR_READING(fp, file))
    {
	sprintf(error_msg, "opening '%s' for reading", file);
	return  ERROR;
    }

    /*  later messages are appended after this prefix  */
    sprintf(error_msg, "reading '%s': ", file);
    error_msg += strlen(error_msg);

    for (idx = 0; idx < npoints; idx++)
    {
	if (fscanf(fp, "%f", &value) == 1)
	{
	    base_values[idx] = value;
	    continue;
	}

	/*  ran out of (readable) values; report the 1-based position  */
	fclose(fp);

	sprintf(error_msg, "missing value #%d", idx+1);
	return  ERROR;
    }

    fclose(fp);

    return  OK;
}

/*  Fill the file-level chisq[] array with a local straight-line misfit for
    each of the n points in data.

    For each centre point i a window of 2*(w/2)+1 points, i-w/2 .. i+w/2, is
    used, with indices wrapped round so that data is treated as periodic.
    With w2 = w/2 and, over the window,

	S1 = sum d,   S3 = sum d*d,   T = sum (k-i)*d

    the value stored is

	chisq[i] = (S3 - S1*S1/w - 3*T*T/(w2*(w2+1)*w)) / w

    which for odd w is the residual sum of squares about the least-squares
    straight line through the window, divided by the window size.  The sums
    are updated incrementally as the window slides along the data.

    Indices are wrapped so that they stay inside [0, n) even when the
    half-window is longer than the data; nothing is done if n < 1.

    NOTE(review): for even w the window holds w+1 points while the scale
    factors still use w.  w < 2 makes the scale factors divide by zero, so
    callers are assumed to pass w >= 2.  chisq is assumed to have been
    allocated with at least n elements (see alloc_baseline).  */

static void find_chisq(int n, int w, float *data)
{
    int i_old, i_new, i, j_old, j_new, j, w2;
    float t, scale1, scale2;
    double d, d_old, d_new, sum1, sum2, sum3;

    if (n < 1)  /* no data, and the wrapping below would be modulo zero */
	return;

    w2 = w / 2;

    scale1 = w;
    scale2 = w2 * (w2+1) * w;

    scale1 = 1 / scale1;
    scale2 = 3 / scale2;

    sum1 = sum2 = sum3 = 0;

    /*  prime the sums with the window that sits one point to the left of
	the first centre, so that the main loop can slide it into place  */
    for (i = -w2-1; i < w2; i++)
    {
	/*  % keeps the sign of its left operand, so reduce first and then
	    shift into [0, n); (i+n) % n alone goes negative if i < -n  */
	j = ((i % n) + n) % n;
	d = data[j];

	sum1 += d;
	sum2 += i * d;
	sum3 += d * d;
    }

    for (i = 0, i_old = -w2-1, i_new = w2; i < n; i++, i_old++, i_new++)
    {
	j_old = ((i_old % n) + n) % n;  /* i_old can be below -n */
	j_new = i_new % n;  /* i_new is never negative */

	d_old = data[j_old];
	d_new = data[j_new];

	/*  slide the window: drop the leftmost point, add the next one  */
	sum1 += d_new - d_old;
	sum2 += i_new*d_new - i_old*d_old;
	sum3 += d_new*d_new - d_old*d_old;

	/*  sum2 uses absolute positions; shift to be relative to centre i  */
	t = sum2 - i*sum1;
	chisq[i] = scale1 * (sum3 - scale1*sum1*sum1 - scale2*t*t);
    }
}

static void find_baseline_points(int n, float value, int *nbaseline)
{
    int i, nbase;

    nbase = 0;
    for (i = 0; i < n; i++)
    {
	if (chisq[i] < value)
	    base[nbase++] = i;
    }

    *nbaseline = nbase;
}

/*  Decide whether the current set of baseline points in base[] is still
    inadequate.  Returns TRUE if there are fewer than nbase_needed points,
    or if the distance between the first and last point is less than
    stretch_needed; otherwise returns FALSE.  */

static Bool find_more_baseline(int nbase_needed, int nbase, int stretch_needed)
{
    int span;

    if (nbase < nbase_needed)
	return  TRUE;

    /*  nbase == 0 is not tested for separately: find_baseline always asks
	for at least one point, so that case has already returned above.
	Be careful if that ever changes, since base[nbase-1] is read next.  */

    span = base[nbase-1] - base[0];

    return  (span < stretch_needed) ? TRUE : FALSE;
}

Bool find_baseline(int n, int w, float *avg_min_chisq, int *nchisq,
				float *data, int *nbaseline, int **baseline)
{
    int i, m, nloops, nchi, nbase, stretch;
    float avg, value, min_chisq;

    avg = *avg_min_chisq;
    nchi = *nchisq;

    find_chisq(n, w, data);

    min_chisq = chisq[0];
    for (i = 1; i < n; i++)
	min_chisq = MIN(min_chisq, chisq[i]);

    value = CHISQ_MULTIPLIER * MAX(avg, min_chisq);
    m = BASELINE_FRACTION * n;
    m = MAX(m, 1);
    stretch = BASELINE_STRETCH * n;

    nloops = 0;
    do 
    {
	find_baseline_points(n, value, &nbase);
	value *= CHISQ_SCALER;
    }   while (find_more_baseline(m, nbase, stretch)
					&& (++nloops < NLOOPS_MAX));

    if (nloops == NLOOPS_MAX)
	return  FALSE;

    avg = nchi*avg + min_chisq;
    nchi++;

/*  comment out for now since peak uses this (makes derivative wrong)
    *avg_min_chisq = avg / ((float) nchi);
*/
    *nchisq = nchi;

    *nbaseline = nbase;
    *baseline = base;

    return  TRUE;
}

/*  Remove a constant baseline from data.

    The constant is the mean of the data values at the nbaseline points
    whose indices are listed in baseline; it is subtracted from each of the
    npoints values in data.  Nothing is done if nbaseline < 1.  */

void fit_const_baseline(int npoints, float *data, int nbaseline, int *baseline)
{
    int k;
    float total, mean;

    if (nbaseline < 1)
	return;

    total = 0;
    for (k = 0; k < nbaseline; k++)
	total += data[baseline[k]];

    mean = total / nbaseline;

    for (k = 0; k < npoints; k++)
	data[k] -= mean;
}

void fit_poly_baseline(int npoints, int order, float *data,
						int nbaseline, int *baseline)
{
    int i, j, m1, m2;
    float x, v, s;
    Bool converged;

    if (nbaseline < 2)
	return;

    m1 = nbaseline;
    m2 = MIN(m1, order);

    for (i = 0; i < m1; i++)
    {
	v = 1;
	j = baseline[i];
	x = j;
	y[i] = data[j];

	for (j = 0; j < m2; j++)
	{
	    u1[i][j] = v;
	    v *= x;
	}
    }

    svd(u1, u2, u3, tmp, m1, m2, &converged);

    if (!converged)
	return;

    svd_fit(u1, u2, u3, params, y, tmp, SVD_CUTOFF, m1, m2);

    for (i = 0; i < npoints; i++)
    {
	x = i;
	s = 0;
	v = 1;

	for (j = 0; j < m2; j++)
	{
	    s += params[j] * v;
	    v *= x;
	}

	data[i] -= s;
    }
}

/*  Remove a trigonometric baseline from data.

    The model is a constant plus cosine/sine pairs,

	1, cos(v), sin(v), cos(2v), sin(2v), ...    v = PI * x / npoints_orig

    where x is the point index.  The number of terms is min(nbaseline, order),
    reduced by one if that is even, so that the cosine and sine terms always
    come in pairs.  The model is fitted through the data values at the
    nbaseline points listed in baseline, using svd followed by svd_fit, and
    is then evaluated at each of the npoints points and subtracted from data.

    Nothing is done if nbaseline < 2 or if there would be no terms to fit
    (order < 1), and data is left unchanged if svd reports that it did not
    converge.

    Assumes the work arrays have been allocated for at least nbaseline points
    and order terms (see alloc_baseline).  */

void fit_trig_baseline(int npoints, int npoints_orig, int order,
				float *data, int nbaseline, int *baseline)
{
    int i, j, k, m1, m2;
    float x, s;
    double v;
    Bool converged;

    if (nbaseline < 2)
	return;

    m1 = nbaseline;
    m2 = MIN(m1, order);

    if (!(m2 % 2))  /* force m2 to be odd */
	m2--;

    if (m2 < 1)  /* order < 1 leaves not even the constant term */
	return;

    for (i = 0; i < m1; i++)
    {
	j = baseline[i];
	x = j;
	y[i] = data[j];
	v = (PI * x) / npoints_orig;

	u1[i][0] = 1;
	for (j = 1; j < m2; j += 2)
	{
	    /*  columns (1,2), (3,4), ... hold harmonics 1, 2, ...; using
		j / 2 here would make the first pair harmonic 0, which just
		repeats the constant column and adds a column of zeros  */
	    k = (j + 1) / 2;
	    u1[i][j] = cos(k*v);
	    u1[i][j+1] = sin(k*v);
	}
    }

    svd(u1, u2, u3, tmp, m1, m2, &converged);

    if (!converged)
	return;

    svd_fit(u1, u2, u3, params, y, tmp, SVD_CUTOFF, m1, m2);

    /*  evaluate the fitted series everywhere and take it off; the harmonic
	numbering must match the one used to build u1 above  */
    for (i = 0; i < npoints; i++)
    {
	x = i;
	s = params[0];
	v = (PI * x) / npoints_orig;

	for (j = 1; j < m2; j += 2)
	{
	    k = (j + 1) / 2;
	    s += params[j] * cos(k*v);
	    s += params[j+1] * sin(k*v);
	}

	data[i] -= s;
    }
}
