#include "fitter.h"

#include "nonlinear_model.h"

/*  Stopping rule for the iteration loop in have_good_fit(): an iteration
    whose chi^2 decrease is below the (scaled) CHISQ_STOP_CRITERION bumps a
    counter, an increase in chi^2 resets it, and the loop stops once the
    counter reaches MAX_CONDITION or after MAX_MODEL_ITER iterations.  */
#define  CHISQ_STOP_CRITERION		(1.0e-1)
#define  MAX_MODEL_ITER			40
#define  MAX_CONDITION			4

/*  model values with a smaller magnitude than this are not divided by
    when initializing the amplitude  */
#define  SMALL_VALUE			(1.0e-20)

/*  Apply the current script command to buffer z.  Expects the locals
    do_process and code to be in scope.  */
#define  DO_PROCESS(z)		((*do_process)(code, (z)))

/*  Run one stage of nonlinear_model() on the shared fit workspace and
    return FALSE from the enclosing function if it reports a singular
    system.  Expects x, y, w, m, chisq, lambda and singular in scope.
    Wrapped in do { } while (0) so it behaves as a single statement
    (safe inside an unbraced if/else).  */
#define  CHECK_NONLINEAR_MODEL(stage) \
	 do {   nonlinear_model(x, y, w, m, a_fit, covar, alpha, beta, da, ap, \
			dy_da, piv, row, col, tot_fitted_parameters, &chisq, \
			&lambda, model_fit, (stage), &singular); \
	     if (singular)  return  FALSE;   } while (0)

/*  Fit state shared by have_good_fit(), model_fit() and their helpers.
    have_good_fit() points the array pointers below at its caller's arrays
    at the start of every fit.  */

static int nfitted;		/* number of dimensions in the current fit */
static int *npts;		/* per-dimension point count handed to the scripts */
static Script **scripts;	/* per-dimension processing script */
static Bool *complex;		/* per-dimension: TRUE selects run_complex_script() */
static int *npoints_fitted;	/* passed through to find_point() */
static int *base_fitted;	/* passed through to find_point() */
/*
static int *top_fitted;
*/
static int *cum_fitted_points;	/* passed through to find_point() */

/*  per-dimension model index of the current data point: passed to
    find_point() and then read by model_fit()  */
static int point[MAX_NDIM];

/*  One entry per (group, dimension) slot, indexed dim + group*nfitted;
    allocated in alloc_fit_data_memory().  */
static float *freq;		/* oscillator frequency of the slot */
static float **fit;		/* model vector */
static float **phase;		/* derivative of the model w.r.t. phase */
static float **decay;		/* derivative of the model w.r.t. decay */
static float **center;		/* derivative of the model w.r.t. center */

static int fitted_parameters;		/* fitted parameters per group */
static int tot_fitted_parameters;	/* fitted_parameters * grouped */
static int max_fitted_parameters;	/* capacity of the workspace below */
static int grouped;			/* number of peaks fitted together */

static Bool have_phase[MAX_NDIM];	/* phase is fitted (not fixed) in dim */
static Bool have_decay[MAX_NDIM];	/* decay is fitted (not fixed) in dim */

static float *a_all;	/* full (fitted + fixed) parameter vector of one group */
static float *a_fit;	/* fitted parameters of all groups, group after group */
static int *a_mapping;	/* index within a group's fitted params -> a_all index */

/*  workspace passed to nonlinear_model(), sized by max_fitted_parameters  */
static float **covar;
static float **alpha;
static float *beta;
static float *da;
static float *ap;
static float *dy_da;
static int *piv;
static int *row;
static int *col;

/*static float *yy; *//* test purpose only */

/*
static void print_float(float x)
{
    printf("%4.3f\t", x);
}

static void print_vector(float *y, int m)
{
    int i;

    for (i = 0; i < m; i++)
	print_float(y[i]);

    printf("\n");
}
*/

/*  Allocate freq and the fit/phase/decay/center vectors for nmax groups of
    ndim dimensions.  Slot s belongs to dimension s % ndim, so its vectors
    get n[s % ndim] floats each.  error_msg is set up front; allocation
    failure handling is up to the MALLOC macro (defined elsewhere).  */
static Status alloc_fit_data_memory(int ndim, int nmax, int *n,
							String error_msg)
{
    int slot, len, nslots;

    nslots = ndim * nmax;

    sprintf(error_msg, "allocating fit data memory");

    /*  one entry per (group, dimension) slot  */
    MALLOC(freq, float, nslots);
    MALLOC(fit, float *, nslots);
    MALLOC(phase, float *, nslots);
    MALLOC(decay, float *, nslots);
    MALLOC(center, float *, nslots);

    /*  slot = dim + group*ndim, so the vector length depends on dim only  */
    for (slot = 0; slot < nslots; slot++)
    {
	len = n[slot % ndim];

	MALLOC(fit[slot], float, len);
	MALLOC(phase[slot], float, len);
	MALLOC(decay[slot], float, len);
	MALLOC(center[slot], float, len);
    }

    return  OK;
}

/*  Allocate the parameter vectors and the nonlinear_model() workspace.
    a_all holds one group's complete parameter set for ndim dimensions;
    every other array is sized for m fitted parameters (m x m for the
    two matrices).  */
static Status alloc_fit_coeff_memory(int ndim, int m, String error_msg)
{
    int nall;

    nall = GLOBAL_PARAMETERS + ndim*DIMENSION_PARAMETERS;

    sprintf(error_msg, "allocating fit coefficient memory");

    /*  full parameter vector of one group  */
    MALLOC(a_all, float, nall);

    /*  fitted parameters and per-parameter workspace  */
    MALLOC(a_fit, float, m);
    MALLOC(a_mapping, int, m);
    MALLOC(beta, float, m);
    MALLOC(da, float, m);
    MALLOC(ap, float, m);
    MALLOC(dy_da, float, m);

    /*  square workspace matrices  */
    MALLOC2(covar, float, m, m);
    MALLOC2(alpha, float, m, m);

    /*  integer workspace  */
    MALLOC(piv, int, m);
    MALLOC(row, int, m);
    MALLOC(col, int, m);

    return  OK;
}

/*  Allocate everything a fit needs: the data vectors for nmax groups of
    ndim dimensions (lengths from npts_max), then the coefficient workspace
    sized by max_fitted_parameters.  The first failing step's status is
    propagated via CHECK_STATUS.  */
static Status alloc_fit_memory(int ndim, int nmax, int *npts_max,
							String error_msg)
{
    CHECK_STATUS(alloc_fit_data_memory(ndim, nmax, npts_max, error_msg));

    CHECK_STATUS(alloc_fit_coeff_memory(ndim, max_fitted_parameters,
							error_msg));

    return  OK;
}

/*  Record the workspace capacity (nparams fitted parameters for each of up
    to group_max grouped peaks) and allocate the memory later used by
    have_good_fit() and calculate_fit().  */
Status initialize_fit(int ndim, int nparams, int *npts_max, int group_max,
							String error_msg)
{
    int capacity = nparams * group_max;

    max_fitted_parameters = capacity;

    CHECK_STATUS(alloc_fit_memory(ndim, group_max, npts_max, error_msg));

    return  OK;
}

/*  Decide which per-dimension parameters are fitted and which are fixed.

    Fixed phase/decay values are written straight into a_all; every fitted
    parameter gets an entry in a_mapping (fitted index -> a_all index).  The
    a_all slots are computed from the same layout constants used by
    run_scripts(), have_good_fit() and calculate_fit(), so the mapping
    cannot drift from them.  The fitted order (amplitude, then per
    dimension: phase?, decay?, center) matches the order in which
    model_fit() writes dy_da.

    Note: the parameters phase and decay shadow the file-scope vectors of
    the same name inside this function.  */
void initialize_fixed_params(int ndim, Bool *fixed_phase, float *phase,
					Bool *fixed_decay, float *decay)
{
    int i, j, base;

    /*  fitted slot 0 is the amplitude  */
    a_mapping[0] = AMPLITUDE_PARAMETER;

    j = 1;  /* model_fit() writes the per-dimension derivatives from 1 on */
    for (i = 0; i < ndim; i++)
    {
	base = GLOBAL_PARAMETERS + i * DIMENSION_PARAMETERS;

	if (fixed_phase[i])
	    a_all[base+PHASE_PARAMETER] = phase[i];
	else
	    a_mapping[j++] = base + PHASE_PARAMETER;

	if (fixed_decay[i])
	    a_all[base+DECAY_PARAMETER] = decay[i];
	else
	    a_mapping[j++] = base + DECAY_PARAMETER;

	a_mapping[j++] = base + CENTER_PARAMETER;  /* center is always fitted */

	have_phase[i] = !fixed_phase[i];
	have_decay[i] = !fixed_decay[i];
    }
}

static void setup_oscillator(float *z, float w, int n)
{
    int i;
    double t;

/*
    printf("oscillator: w = %4.3f\n", w);
*/

    for (i = 0; i < n; i++)
    {
        t = i;
        z[2*i] = cos(w*t);
        z[2*i+1] = - sin(w*t);
    }
}

static void setup_decay(float *z, float d, int n)
{
    int i;
    float r, s;

/*
    printf("decay: d = %4.3f\n", d);
*/

    r = 1;  s = exp(-(double) d);
    for (i = 0; i < n; i++)
    {
	z[2*i] *= r;
	z[2*i+1] *= r;
	r *= s;
    }
}

static void setup_phase(float *z, float phi, int n)
{
    int i;
    float c, s, x, y;

/*
    printf("phase: phi = %4.3f\n", phi);
*/

    c = cos((double) phi);
    s = sin((double) phi);

    for (i = 0; i < n; i++)
    {
	x = z[2*i];
	y = z[2*i+1];

	z[2*i] = c*x + s*y;
	z[2*i+1] = - s*x + c*y;
    }
}

/*  Build the raw (unprocessed) complex model for one dimension in y:
    n interleaved samples of exp(-i*w*t) * exp(-decay*t) * exp(-i*phase),
    where w = freq + center offset and a holds that dimension's
    parameters.  The three factors are applied in this fixed order.  */
static void setup_data(float freq, float *a, float *y, int n)
{
    float w = freq + a[CENTER_PARAMETER];
    float d = a[DECAY_PARAMETER];
    float phi = a[PHASE_PARAMETER];

    setup_oscillator(y, w, n);
    setup_decay(y, d, n);
    setup_phase(y, phi, n);
}

/*  Compute the processed model for one dimension whose data are real.

    fit receives the model built by setup_data() and then run through every
    command of script.  When all is set, the derivatives w.r.t. phase, decay
    and center are first formed analytically on the raw complex samples and
    then run through the same commands (phase/decay only when do_phase /
    do_decay say they are fitted; center always).  */
static void run_real_script(float *a, Script *script, int npts, float freq,
			float *fit, float *phase, float *decay, float *center,
			Bool all, Bool do_phase, Bool do_decay)
{
    int pt, cmd, code;
    float re, im;
    Do_process do_process;

    setup_data(freq, a, fit, npts);

    if (all)
    {
	for (pt = 0; pt < npts; pt++)
	{
	    re = fit[2*pt];
	    im = fit[2*pt+1];

	    /*  d/dphi of exp(-i phi) z  is  -i z  */
	    phase[2*pt] = im;
	    phase[2*pt+1] = - re;

	    /*  d/dd of exp(-d t) z  is  -t z  */
	    decay[2*pt] = - pt * re;
	    decay[2*pt+1] = - pt * im;

	    /*  d/dw of exp(-i w t) z  is  -i t z  =  i * (decay derivative)  */
	    center[2*pt] = - decay[2*pt+1];
	    center[2*pt+1] = decay[2*pt];
	}
    }

    for (cmd = 0; cmd < script->ncommands; cmd++)
    {
	do_process = script->commands[cmd].do_process;
	code = script->commands[cmd].code;

	DO_PROCESS(fit);

	if (!all)
	    continue;

	if (do_phase)
	    DO_PROCESS(phase);

	if (do_decay)
	    DO_PROCESS(decay);

	DO_PROCESS(center);
    }
}

/*  Compute the processed model for one dimension whose data are complex.

    Only fit and (when all is set) the raw decay derivative (-t * z) are run
    through the script.  Afterwards the first script->npts[0]/2 entries of
    each output vector are filled in place:
	fit[i]    = real part of processed fit
	phase[i]  = imaginary part of processed fit
	decay[i]  = real part of processed decay derivative
	center[i] = minus imaginary part of processed decay derivative
    This recovers the phase and center derivatives without processing them
    separately.  NOTE(review): that shortcut presumably relies on the
    script commands being linear in the complex data - confirm.

    The compaction loop overwrites fit and decay in place; it is correct
    only because index i is written before indices 2i and 2i+1 are needed
    again (2i+1 > i), so keep the statement order as is.  */
static void run_complex_script(float *a, Script *script, int npts, float freq,
		float *fit, float *phase, float *decay, float *center, Bool all)
{
    int i, j, k, code;
    Do_process do_process;

    setup_data(freq, a, fit, npts);

    /*  raw decay derivative: d/dd of exp(-d t) z  is  -t z  */
    if (all)
    {
	for (i = 0; i < npts; i++)
	{
	    decay[2*i] = - i * fit[2*i];
	    decay[2*i+1] = - i * fit[2*i+1];
	}
    }

    for (i = 0; i < script->ncommands; i++)
    {
	do_process = script->commands[i].do_process;
	code = script->commands[i].code;

	DO_PROCESS(fit);

	if (all)
	    DO_PROCESS(decay);
    }

    /*  in-place compaction from interleaved complex to real vectors  */
    for (i = 0, j = 0, k = 1; i < script->npts[0]/2; i++, j += 2, k += 2)
    {
	fit[i] = fit[j];

	if (all)
	{
	    phase[i] = fit[k];
	    decay[i] = decay[j];
	    center[i] = - decay[k];
	}
    }
}

/*  Scatter one group's fitted parameters a[0 .. fitted_parameters-1] into
    their a_all slots via a_mapping.  Fixed parameters already in a_all are
    left untouched.  */
void fill_in_parameters(float *a)
{
    int n = 0;

    while (n < fitted_parameters)
    {
	a_all[a_mapping[n]] = a[n];
	n++;
    }
}

/*  Recompute the model vectors (and, if all is set, their derivatives) of
    every (group, dimension) slot.  a holds the fitted parameters of all
    groups, fitted_parameters per group.  */
static void run_scripts(float *a, Bool all)
{
    int g, dim, base, slot;
    float *group_a;

    for (g = 0; g < grouped; g++)
    {
	group_a = a + g*fitted_parameters;

	/*  expand this group's parameters into a_all  */
	fill_in_parameters(group_a);

	for (dim = 0; dim < nfitted; dim++)
	{
	    base = GLOBAL_PARAMETERS + dim * DIMENSION_PARAMETERS;
	    slot = dim + g*nfitted;

	    if (!complex[dim])
		run_real_script(a_all+base, scripts[dim], npts[dim],
			freq[slot], fit[slot], phase[slot], decay[slot],
			center[slot], all, have_phase[dim], have_decay[dim]);
	    else
		run_complex_script(a_all+base, scripts[dim], npts[dim],
			freq[slot], fit[slot], phase[slot], decay[slot],
			center[slot], all);
	}
    }
}

static void model_fit(float x, float *a, float *y, float *dy_da)
{
    int i, j, k, g, ind, p, xx;
    float v, amp, sum;
    static Bool all = TRUE;

/*
    printf("a[0]=%3.2e, a[1]=%3.2e, a[2]=%3.2e\n", a[0], a[1], a[2]);
*/

/*  this routine assumes given order of parameters  */

    xx = NEAREST_INTEGER(x);

    if (xx == 0)
	run_scripts(a, all);

    find_point(nfitted, xx, point, cum_fitted_points, base_fitted,
							npoints_fitted, TRUE);

    sum = 0;
    for (g = 0; g < grouped; g++)
    {
	amp = a[AMPLITUDE_PARAMETER];
	dy_da[AMPLITUDE_PARAMETER] = 1;

	ind = 1;
	for (i = 0; i < nfitted; i++)
	{
	    k = i + g*nfitted;
	    p = point[i];

	    dy_da[AMPLITUDE_PARAMETER] *= fit[k][p];

	    v = 1;
	    for (j = 0; j < nfitted; j++)
	    {
		if (i == j)
		    continue;

		v *= fit[j+g*nfitted][point[j]];
	    }

	    if (have_phase[i])
		dy_da[ind++] = amp * v * phase[k][p];

	    if (have_decay[i])
		dy_da[ind++] = amp * v * decay[k][p];

	    dy_da[ind++] = amp * v * center[k][p];
	}

	sum += amp * dy_da[AMPLITUDE_PARAMETER];

/*
printf("pnt=(%d,%d,%d): ind=%d, x=%1.0f,yy=%3.2e,y=%3.2e,a=%2.1e,dy_da=%3.2e,dy_dd=%3.2e,dy_dc=%3.2e\n",
	point[0], point[1], point[2], ind, x, yy[xx], sum, amp, dy_da[0], dy_da[1], dy_da[2]);
*/

	a += fitted_parameters;
	dy_da += fitted_parameters;
    }

    *y = sum;

/*
    print_float(*y);
    printf("x=%1.0f, yy=%3.2e, y=%3.2e, pnt=(%d, %d, %d), d=%3.2e, amp=%4.3e\n",
		x, yy[k], *y, point[0], point[1], point[2], *y - yy[k], amp);
    for (i = 0; i < nfitted; i++)
    {
	j = GLOBAL_PARAMETERS + i * DIMENSION_PARAMETERS;
	printf("\ti=%d,j=%d,dphase=, ddecay=%3.2e, dcenter=%3.2e\n", i, j,
			dy_da[j], dy_da[j+1]);
*/
/*
			dy_da[j+PHASE_PARAMETER],
			dy_da[j+DECAY_PARAMETER],
			dy_da[j+CENTER_PARAMETER]);
*/
/*
    }
*/
}

/*  Set the starting amplitude so that the model equals 1 at the peak
    position: a[AMPLITUDE_PARAMETER] = 1 / prod_i f[i][peak_posn[i]].
    The amplitude is left unchanged when that product is not safely
    larger than SMALL_VALUE in magnitude (including NaN).  */
static void initialize_amplitude(float *a, float **f, int *peak_posn)
{
    int dim;
    float peak_model = 1;

    for (dim = 0; dim < nfitted; dim++)
	peak_model *= f[dim][peak_posn[dim]];

    if (!(ABS(peak_model) > SMALL_VALUE))
	return;

    a[AMPLITUDE_PARAMETER] = 1.0 / peak_model;
}

/*
#define  DA  0.0001
#define  MAX_M  10000

static void test_model(float *x, int m)
{
    int i, j;
    float dyda[MAX_M], y[MAX_M], yp, dydap;

    if (m > MAX_M)
	exit(1);

    for (i = 0; i < fitted_parameters; i++)
    {
	printf("working on fitted parameter %d\n", i);

	ZERO_VECTOR(a, fitted_parameters);
	a[AMPLITUDE_PARAMETER] = 1.0;

	for (j = 0; j < m; j++)
	{
	    model_fit(x[j], a, &y[j], dy_da);
	    dyda[j] = dy_da[i];
	}

	a[i] += DA;

	for (j = 0; j < m; j++)
	{
	    model_fit(x[j], a, &yp, dy_da);
	    printf("%d: dyp = %5.4f, dy = %5.4f, dyda = %5.4f, dydap = %5.4f\n",
				j, yp-y[j], dyda[j]*DA, dyda[j], dy_da[i]);
	}
    }

    if (m <= MAX_M)
	exit(1);
}
*/

Bool have_good_fit(float *x, float *y, float *w,
			int m, int ndim, int nparams, int ngroup,
			float max_chisq, int *base, int *top, int *cum_points,
			Script **s, int *n1, int *n2,
			Bool *cplx, float *freq_a, float *freq_b,
			int **peak_posns, float scale, Fit_peak **peaks)
{
    int i, j, cond, iter, g;
    float lambda, chisq, old_chisq, tot_w, chisq_stop_criterion, chisq_scale;
    float *a;
    Bool singular, all = FALSE;

/*
    yy = y;
*/

    if (m == 0)
	return  FALSE;

    nfitted = ndim;
    base_fitted = base;
/*
    top_fitted = top;
*/
    cum_fitted_points = cum_points;

    scripts = s;
    npoints_fitted = n1;
    npts = n2;
    complex = cplx;

    fitted_parameters = nparams;
    tot_fitted_parameters = nparams * ngroup;
    grouped = ngroup;

/*  done in have_good_peaks now
    scale = 1.0 / peak_value;
    SCALE_VECTOR(y, y, scale, m); */  /* should help numerical stability */

/*
    print_vector(y, m);
*/

    a = a_fit;
    for (g = 0; g < grouped; g++)
    {
	for (i = 0; i < nfitted; i++)
	    freq[i+g*nfitted] = freq_b[i] * (peak_posns[g][i] - freq_a[i]);

	ZERO_VECTOR(a, fitted_parameters);
	a[AMPLITUDE_PARAMETER] = 1.0;

	a += fitted_parameters;
    }

    a = a_fit;
    run_scripts(a, all);

    for (g = 0; g < grouped; g++)
    {
	initialize_amplitude(a, fit+g*nfitted, /*peak_values[g],*/ peak_posns[g]);

	a += fitted_parameters;
    }

/*
    printf("iteration -1\n");
*/

    tot_w = 0;
    for (i = 0; i < m; i++)
	tot_w += w[i];

/*  must guarantee (!) that tot_w > 0, which is true from other code  */
    chisq_scale = 100 / tot_w;
    chisq_stop_criterion = chisq_scale * CHISQ_STOP_CRITERION;

    CHECK_NONLINEAR_MODEL(INITIAL_STAGE);

    for (iter = cond = 0; (iter < MAX_MODEL_ITER) &&
						(cond < MAX_CONDITION); iter++)
    {
/*
	printf("iteration %d\n", iter);
*/

	old_chisq = chisq;

	CHECK_NONLINEAR_MODEL(GENERAL_STAGE);

	if (chisq > old_chisq)
	    cond = 0;
/*
	else if ((old_chisq - chisq) < CHISQ_STOP_CRITERION)
	else if ((old_chisq - chisq) < (CHISQ_STOP_CRITERION*old_chisq))
*/
	else if ((old_chisq - chisq) < chisq_stop_criterion)
	    cond++;

	printf("iter = %d, cond = %d, chisq = %f, old_chisq = %f\n",
			iter, cond, chisq_scale*chisq, chisq_scale*old_chisq);
    }

    if (iter == MAX_MODEL_ITER)
	return  FALSE;

/*  below calculates covar and alpha, so don't need with current algorithm
    CHECK_NONLINEAR_MODEL(FINAL_STAGE);
*/

/*    if (m > tot_fitted_parameters) */
			/* should not be here otherwise, but be safe */
/*
	chisq /= m - tot_fitted_parameters;
*/

    chisq *= chisq_scale;  /* try to scale sensibly */

    printf("final chisq = %f\n", chisq);

    if (chisq > max_chisq)
	return  FALSE;

    a = a_fit;
    for (g = 0; g < grouped; g++)
    {
	fill_in_parameters(a);

	peaks[g]->chisq = chisq;
/*
	peaks[g]->magnitude = a_all[AMPLITUDE_PARAMETER];
*/
	peaks[g]->magnitude = a_all[AMPLITUDE_PARAMETER] / scale;

/*
printf("g=%d, a_all[0]=%3.2e, mag=%3.2e\n", g, a_all[0],  peaks[g]->magnitude);
*/

	for (i = 0; i < nfitted; i++)
	{
	    j = GLOBAL_PARAMETERS + i * DIMENSION_PARAMETERS;

/*
printf("i=%d, j=%d, a_all[j]=%3.2e, a_all[j+1]=%3.2e, a_all[j+2]=%3.2e\n",
				i, j, a_all[j], a_all[j+1], a_all[j+2]);
*/

	    peaks[g]->phase[i] = a_all[j+PHASE_PARAMETER];
	    peaks[g]->decay[i] = a_all[j+DECAY_PARAMETER];
	    peaks[g]->center[i] = peak_posns[g][i] +
					a_all[j+CENTER_PARAMETER]/freq_b[i];
	}

	a += fitted_parameters;
    }

/*
    scale = 1.0 / scale;
    SCALE_VECTOR(y, y, scale, m);
*/
/* undo previous scaling */

    return  TRUE;
}

/*  Recompute the per-dimension model vectors of an already fitted peak and
    copy them into peak->fit.  The peak's phase/decay/center are written
    into a_all, and the dimension's script is run without derivatives.
    Uses the group-0 slots of freq/fit/phase/decay/center.

    Note: the parameters complex, scripts and npts shadow the file-scope
    variables of the same names inside this function.  */
void calculate_fit(int ndim, Bool *complex, Script **scripts, int *npts,
			float *freq_a, float *freq_b,
			int *peak_posn, Fit_peak *peak)
{
    int dim, base, nout;
    float *pa;

    for (dim = 0; dim < ndim; dim++)
    {
	base = GLOBAL_PARAMETERS + dim * DIMENSION_PARAMETERS;
	pa = a_all + base;

	freq[dim] = freq_b[dim] * (peak_posn[dim] - freq_a[dim]);

	pa[PHASE_PARAMETER] = peak->phase[dim];
	pa[DECAY_PARAMETER] = peak->decay[dim];
	pa[CENTER_PARAMETER] =
			freq_b[dim]*(peak->center[dim]-peak_posn[dim]);

	if (complex[dim])
	{
	    run_complex_script(pa, scripts[dim], npts[dim], freq[dim],
			fit[dim], phase[dim], decay[dim], center[dim], FALSE);

	    /*  complex output was compacted to its real part  */
	    nout = scripts[dim]->npts[0] / 2;
	}
	else
	{
	    run_real_script(pa, scripts[dim], npts[dim], freq[dim],
			fit[dim], phase[dim], decay[dim], center[dim],
			FALSE, have_phase[dim], have_decay[dim]);

	    nout = scripts[dim]->npts[0];
	}

	COPY_VECTOR(peak->fit[dim], fit[dim], nout);
    }
}

/*  Value of a fitted peak at grid point `point`: its magnitude times the
    product of its per-dimension model vectors there.  A NULL peak
    contributes 0.  */
float peak_contribution(int ndim, int *point, Fit_peak *peak)
{
    int dim;
    float value;

    if (!peak)
	return  0;

    value = peak->magnitude;
    for (dim = 0; dim < ndim; dim++)
	value *= peak->fit[dim][point[dim]];

    return  value;
}
