#include "fit.h"

#include "fitter.h"
#include "data.h"

/*  threshold on the magnitude of the peak value used in have_good_peaks()
    when forming the data scale factor  */
#define  SMALL_VALUE		(1.0e-20)

/*  data set size, copied from size_info in fit_peaks()  */
static int ndim;
static int *npoints;

/*  fit settings, copied from fit_param in fit_peaks();
    dim_fitted[i] is the data dimension of fitted dimension i,
    width and periodic are indexed by data dimension  */
static float chisq;
static Bool subtract;
static int nfitted;
static int *dim_fitted;
static int *width;
static Bool *periodic;

/*  npoints2 .. freq_b are filled in fit_peaks() and indexed by fitted
    dimension (entry i is taken from data dimension dim_fitted[i]);
    min_posn and max_posn are set in min_max_posn() and read, indexed by
    data dimension, in extract_data_to_fit()  */
static int npoints2[MAX_NDIM];
static Script *script[MAX_NDIM];
static int npts[MAX_NDIM];
static Bool complex[MAX_NDIM];
static float freq_a[MAX_NDIM];
static float freq_b[MAX_NDIM];
static int min_posn[MAX_NDIM];
static int max_posn[MAX_NDIM];

/*  filled in fit_peaks() (indexed by fitted dimension) and handed to
    initialize_fixed_params()  */
static Bool fixed_phase[MAX_NDIM];
static float phase[MAX_NDIM];
static Bool fixed_decay[MAX_NDIM];
static float decay[MAX_NDIM];

/*  work buffers (re)allocated in alloc_fit_memory();
    fitted_pnts[i] is initialised to i there  */
static float *fitted_pnts;
static float *fitted_data;
static float *fitted_wgt;

/*  per-group work arrays allocated in alloc_position_memory() and
    filled in have_good_peaks()  */
static int **peak_posns;
static Fit_peak **peaks;

/*  count of peaks recorded by record_fit(), reset in calculate_peaks()  */
static int npeaks;

/*  fitted_points is the number of points in the ndim-dimensional region
    set up by extract_data_to_fit(); nparams comes from fit_param  */
static int fitted_points;
static int nparams;

/*  point is scratch space for the current point; the remaining arrays
    describe the region set up in extract_data_to_fit(), where names
    without a suffix are indexed by data dimension and names ending in 2
    are indexed by fitted dimension  */
static int point[MAX_NDIM];
static int cum_fitted_points[MAX_NDIM];
static int cum_fitted_points2[MAX_NDIM];
static int base_fitted[MAX_NDIM];
static int base_fitted2[MAX_NDIM];
static int top_fitted[MAX_NDIM];
static int top_fitted2[MAX_NDIM];
static int points_fitted[MAX_NDIM];
static int points_fitted2[MAX_NDIM];

/*
 *  Allocates the per-group work arrays: peak_posns (nmax rows, each used
 *  with nfitted entries in have_good_peaks()) and peaks (nmax pointers).
 *
 *  error_msg is filled in before allocating, as the other allocators in
 *  this file do, so that a failure is not reported with a stale message.
 *  NOTE(review): assumes MALLOC/MALLOC2 return from the function on
 *  failure -- confirm against the macro definitions.
 *
 *  Returns OK when both allocations succeed.
 */
static Status alloc_position_memory(int nmax, String error_msg)
{
    sprintf(error_msg, "allocating position memory");

    MALLOC2(peak_posns, int, nmax, nfitted);
    MALLOC(peaks, Fit_peak *, nmax);

    return  OK;
}

/*
 *  Makes sure fitted_pnts, fitted_data and fitted_wgt can each hold at
 *  least n floats.  The buffers are kept between calls and only replaced
 *  when a larger size is requested; nalloc records the current capacity.
 *
 *  fitted_pnts[i] is set to i for every allocated element, so the values
 *  remain valid when an existing (larger) buffer is reused.
 *
 *  Returns OK on success.
 *  NOTE(review): assumes MALLOC returns from the function on failure --
 *  confirm against the macro definition.
 */
static Status alloc_fit_memory(int n, String error_msg)
{
    int i;
    static int nalloc = 0;

    if (n <= nalloc)
	return  OK;

    if (nalloc > 0)
    {
	FREE(fitted_pnts, float);
	FREE(fitted_data, float);
	FREE(fitted_wgt, float);

	/*  nothing is held any more, in case an allocation below fails  */
	nalloc = 0;
    }

    sprintf(error_msg, "allocating fit memory");

    MALLOC(fitted_pnts, float, n);
    MALLOC(fitted_data, float, n);
    MALLOC(fitted_wgt, float, n);

    for (i = 0; i < n; i++)
	fitted_pnts[i] = i;

    /*  remember the capacity; without this every call allocated afresh
	and the old buffers were never released  */
    nalloc = n;

    return  OK;
}

/*
 *  Releases everything allocated by alloc_peak_memory() for this
 *  fit_info: the nfitted per-dimension fit arrays, the phase, decay and
 *  center arrays, the array of fit pointers and the peak itself.
 *
 *  Does nothing if fit_info has no peak.  On return fit_info->peak is a
 *  null pointer, so the test in subtract_previous_peaks() skips it.
 */
static void free_peak_memory(Fit_info *fit_info)
{
    int i;

    if (!fit_info->peak)
	return;

    for (i = 0; i < nfitted; i++)
	FREE(fit_info->peak->fit[i], float);

    FREE(fit_info->peak->phase, float);
    FREE(fit_info->peak->decay, float);
    FREE(fit_info->peak->center, float);

    FREE(fit_info->peak->fit, float *);

    FREE(fit_info->peak, Fit_peak);

    /*  FREE may or may not reset its argument, so clear it explicitly  */
    fit_info->peak = (Fit_peak *) 0;
}

/*
 *  Builds a new Fit_peak for fit_info: phase, decay and center arrays of
 *  nfitted floats each, plus one fit array per fitted dimension sized by
 *  the number of points in the corresponding data dimension.
 *
 *  The peak is attached to fit_info only after every allocation has
 *  been made.  Returns OK on success.
 */
static Status alloc_peak_memory(Fit_info *fit_info, String error_msg)
{
    int d;
    Fit_peak *new_peak;

    sprintf(error_msg, "allocating peak memory");

    MALLOC(new_peak, Fit_peak, 1);

    MALLOC(new_peak->phase, float, nfitted);
    MALLOC(new_peak->decay, float, nfitted);
    MALLOC(new_peak->center, float, nfitted);

    MALLOC(new_peak->fit, float *, nfitted);

    /*  fitted dimension d corresponds to data dimension dim_fitted[d]  */
    for (d = 0; d < nfitted; d++)
    {
	MALLOC(new_peak->fit[d], float, npoints[dim_fitted[d]]);
    }

    fit_info->peak = new_peak;

    return  OK;
}

/*
 *  For every point of the extracted region, subtracts from fitted_data
 *  the contribution (as given by peak_contribution()) of each of the
 *  first n entries of fit_info that has a peak.
 *
 *  only called if ndim = nfitted, so assume this below
 */
static void subtract_previous_peaks(int n, Fit_info **fit_info)
{
    int pt, k;
    Fit_peak *prev;

    for (pt = 0; pt < fitted_points; pt++)
    {
	/*  fills point[] for this region index  */
	find_point(ndim, pt, point, cum_fitted_points,
					base_fitted, npoints, TRUE);

	for (k = 0; k < n; k++)
	{
	    prev = fit_info[k]->peak;

	    if (!prev)
		continue;  /*  this entry has no peak to subtract  */

	    fitted_data[pt] -= peak_contribution(nfitted, point, prev);
	}
    }
}

/*  sets the "after" position of fit_info equal to its "before" position,
    over all ndim dimensions  */
static void copy_position(Fit_info *fit_info)
{
    int d;

    for (d = 0; d < ndim; d++)
	fit_info->after.position[d] = fit_info->before.position[d];
}

/*
 *  Sets the "after" value of fit_info to the current fitted_data entry
 *  at its "before" position, and the "after" position to the "before"
 *  position.
 *
 *  only called if ndim = nfitted, so assume this below
 */
static void new_position_and_value(Fit_info *fit_info)
{
    int offset;

/*  should do more than below, in case extremum has moved  */

    /*  position relative to the base of the extracted region, then the
	matching index into fitted_data  */
    SUBTRACT_VECTORS(point, fit_info->before.position, base_fitted, ndim);
    INDEX_OF_ARRAY(offset, point, cum_fitted_points, ndim);

    fit_info->after.value = fitted_data[offset];

/*  should calculate after position, for now leave equal to before position  */
    copy_position(fit_info);
}

/*
 *  Converts index n of the full (ndim-dimensional) extracted region into
 *  the index of the corresponding element of the fitted sub-array, which
 *  only spans the nfitted fitted dimensions.  As a side effect point[]
 *  is filled in by find_point() for index n.
 *
 *  Returns the index into fitted_data.
 */
static int find_fitted_point(int n)
{
    int i, j, p, m;

    find_point(ndim, n, point, cum_fitted_points, base_fitted, npoints, TRUE);

    p = 0;
    for (i = 0; i < nfitted; i++)
    {
	j = dim_fitted[i];
	m = (point[j] - base_fitted[j]) % npoints[j];

	/*  in C the remainder takes the sign of the dividend, so a point
	    lying below the base (presumably possible when find_point
	    wraps periodic dimensions -- confirm) would give a negative
	    offset; fold it back into 0 .. npoints[j]-1  */
	if (m < 0)
	    m += npoints[j];

	p += m * cum_fitted_points2[i];
    }

    return  p;
}

/*
 *  Sets up the region of data to fit and accumulates it into fitted_data.
 *
 *  In each data dimension the region runs from min_posn - width to
 *  max_posn + width (inclusive); for dimensions that are not periodic it
 *  is clipped to 0 .. npoints.  The region description is then copied
 *  for the fitted dimensions into the arrays whose names end in 2.
 *
 *  fitted_data is sized and zeroed for the fitted sub-array, and the
 *  value of every point of the full region is added to the sub-array
 *  element given by find_fitted_point().
 *
 *  Returns OK, or the failing status of alloc_fit_memory()/data_value().
 */
static Status extract_data_to_fit(String error_msg)
{
    int d, k, src, dest, lo, hi, nsub;
    float value;

    /*  bounds and cumulative sizes of the region, per data dimension  */
    fitted_points = 1;
    for (d = 0; d < ndim; d++)
    {
	lo = min_posn[d] - width[d];
	hi = max_posn[d] + width[d] + 1;

	if (!periodic[d])
	{
	    lo = MAX(0, lo);
	    hi = MIN(npoints[d], hi);
	}

	base_fitted[d] = lo;
	top_fitted[d] = hi;
	points_fitted[d] = hi - lo;

	cum_fitted_points[d] = fitted_points;
	fitted_points *= points_fitted[d];
    }

    /*  the same description restricted to the fitted dimensions  */
    nsub = 1;
    for (k = 0; k < nfitted; k++)
    {
	d = dim_fitted[k];

	base_fitted2[k] = base_fitted[d];
	top_fitted2[k] = top_fitted[d];
	points_fitted2[k] = points_fitted[d];

	cum_fitted_points2[k] = nsub;
	nsub *= points_fitted2[k];
    }

    CHECK_STATUS(alloc_fit_memory(nsub, error_msg));

    ZERO_VECTOR(fitted_data, nsub);

    for (src = 0; src < fitted_points; src++)
    {
	/*  also leaves the coordinates of this point in point[]  */
	dest = find_fitted_point(src);

	CHECK_STATUS(data_value(&value, point, error_msg));

	fitted_data[dest] += value;
    }

    return  OK;
}

/*  returns the squared Euclidean distance between the n-dimensional
    integer points pnt and posn (0 when n is not positive)  */
static float distance2(int n, int *pnt, int *posn)
{
    int axis;
    float delta;
    float total = 0;

    for (axis = 0; axis < n; axis++)
    {
	delta = posn[axis] - pnt[axis];
	total += delta * delta;
    }

    return  total;
}

/*  returns the smallest squared distance from the n-dimensional point
    pnt to any of the ngroup positions in posns; posns[0] is always
    read, so ngroup is expected to be at least 1  */
static float smallest_distance2(int n, int *pnt, int ngroup, int **posns)
{
    int g;
    float best, candidate;

    best = distance2(n, pnt, posns[0]);

    for (g = 1; g < ngroup; g++)
    {
	candidate = distance2(n, pnt, posns[g]);

	if (candidate < best)
	    best = candidate;
    }

    return  best;
}

/*
 *  Prepares and runs the fit for the group of ngroup peaks that starts
 *  at fit_info[n].
 *
 *  Steps: optionally subtract the peaks of the first n entries from the
 *  extracted data, set the "after" value/position of each peak in the
 *  group, check that the fitted region has at least ngroup*nparams
 *  points, scale the data, compute a weight for every point and call
 *  have_good_fit().
 *
 *  Returns TRUE if have_good_fit() reports success, FALSE if there are
 *  too few points or the fit is rejected.
 */
static Bool have_good_peaks(int n, int ngroup, Fit_info **fit_info)
{
    int i, j, m, g;
    float scale;
    double d2, wgt_power;

    if (subtract)
    {
	subtract_previous_peaks(n, fit_info);

	/*  values are re-read from the data after the subtraction  */
	for (g = 0; g < ngroup; g++)
	    new_position_and_value(fit_info[n+g]);
    }
    else
    {
	for (g = 0; g < ngroup; g++)
	{
	    fit_info[n+g]->after.value = fit_info[n+g]->before.value;
	    copy_position(fit_info[n+g]);
	}
    }

/*  still need to check that have enough points in fitted to do fit  */
/*  do it here rather than before subtraction in case fitted has moved  */

    /*  m becomes the number of points in the fitted sub-region;
	peak_posns[g] gets the position of peak g in the fitted dimensions  */
    m = 1;
    for (i = 0; i < nfitted; i++)
    {
	j = dim_fitted[i];

	m *= top_fitted[j] - base_fitted[j];

	for (g = 0; g < ngroup; g++)
	    peak_posns[g][i] = fit_info[n+g]->after.position[j];
    }

    if (m < (ngroup*nparams))  /* might want to loosen this some day */
	return  FALSE;

    for (g = 0; g < ngroup; g++)
	peaks[g] = fit_info[n+g]->peak;

    /*  scale is based on the largest (signed) "after" value in the group  */
    scale = fit_info[n]->after.value;
    for (g = 1; g < ngroup; g++)
	scale = MAX(scale, fit_info[n+g]->after.value);

    /*  NOTE(review): when the value is tiny the multiplier becomes
	SMALL_VALUE itself rather than a reciprocal, which shrinks the
	data further -- confirm this is intended  */
    if (ABS(scale) < SMALL_VALUE)
	scale = SMALL_VALUE;
    else
	scale = 1.0 / scale;

    SCALE_VECTOR(fitted_data, fitted_data, scale, m);
				/* should help numerical stability */

    wgt_power = - nfitted;

    /*  weight of a point is d2 to the power -nfitted, where d2 is the
	squared distance to the nearest peak of the group  */
    for (i = 0; i < m; i++)
    {
	find_point(nfitted, i, point, cum_fitted_points2,
					base_fitted2, npoints2, TRUE);

	d2 = smallest_distance2(nfitted, point, ngroup, peak_posns);
	d2 = MAX(d2, 0.5);  /* floor on d2 */

	fitted_wgt[i] = pow(d2, wgt_power);

/*
	printf("pnt %1.0f, data %2.1e, d2 %2.1lf, wgt %2.1e\n",
			fitted_pnts[i], fitted_data[i], d2, fitted_wgt[i]);
*/
    }

/*  now try to do the actual fit  */

    printf("working on fitting %d peak%s:", ngroup, (ngroup == 1) ? "" : "s");

    for (g = 0; g < ngroup; g++)
        printf(" %d", fit_info[n+g]->n);

    printf("\n");

    if (!have_good_fit(fitted_pnts, fitted_data, fitted_wgt,
			m, nfitted, nparams, ngroup,
			chisq, base_fitted2, top_fitted2, cum_fitted_points2,
			script, npoints2, npts, complex, freq_a, freq_b,
			peak_posns, scale, peaks))
	    return  FALSE;

    return  TRUE;
}

/*  passes the peak of fit_info, together with the per-fitted-dimension
    settings and the position posn, to calculate_fit() and counts the
    peak in npeaks; always returns OK (error_msg is not used)  */
static Status record_fit(Fit_info *fit_info, int *posn, String error_msg)
{
    Fit_peak *fitted = fit_info->peak;

    calculate_fit(nfitted, complex, script, npts, freq_a, freq_b,
							posn, fitted);

    ++npeaks;

    return  OK;
}

/*
 *  Computes, for every data dimension, the smallest and largest "before"
 *  position over the ngroup peaks in fit_info, storing the results in
 *  min_posn and max_posn.
 *
 *  Positions are indexed by data dimension (elsewhere in this file they
 *  are read as position[dim_fitted[i]]) and extract_data_to_fit() reads
 *  min_posn/max_posn for all ndim dimensions, so all ndim entries are
 *  filled here, not just the first nfitted.
 */
static void min_max_posn(int ngroup, Fit_info **fit_info)
{
    int i, g;

    COPY_VECTOR(min_posn, fit_info[0]->before.position, ndim);
    COPY_VECTOR(max_posn, fit_info[0]->before.position, ndim);

    for (g = 1; g < ngroup; g++)
    {
	for (i = 0; i < ndim; i++)
	{
	    min_posn[i] = MIN(min_posn[i], fit_info[g]->before.position[i]);
	    max_posn[i] = MAX(max_posn[i], fit_info[g]->before.position[i]);
	}
    }
}

/*
 *  Fits the nfit peaks in fit_info group by group.  Group j consists of
 *  the entries from the end of the previous group (0 for the first) up
 *  to, but not including, upper[j].
 *
 *  For each group: peak memory is allocated for every member, the data
 *  region is extracted and have_good_peaks() is called.  On success the
 *  fit of every member is recorded; otherwise the peak memory of the
 *  group is released again.  Progress is printed roughly every 5% and a
 *  summary at the end.
 *
 *  Returns OK, or the first failing status from a helper.
 */
static Status calculate_peaks(int nfit, Fit_info **fit_info, int *upper,
							String error_msg)
{
    int i, j, g, p, ngroup;

    npeaks = 0;

    /*  progress is reported when i is a multiple of p  */
    p = nfit / 20;
    p = MAX(p, 1);

    for (i = j = 0; i < nfit; i = upper[j++])
    {
	if (!(i % p))
	    printf("\t...fitting peaks (%1.0f%% done)\n", (100.0*i)/nfit);

	ngroup = upper[j] - i;

	for (g = 0; g < ngroup; g++)
	    CHECK_STATUS(alloc_peak_memory(fit_info[i+g], error_msg));

	min_max_posn(ngroup, fit_info+i);

	CHECK_STATUS(extract_data_to_fit(error_msg));

	if (have_good_peaks(i, ngroup, fit_info))
	{
	    for (g = 0; g < ngroup; g++)
		CHECK_STATUS(record_fit(fit_info[i+g], peak_posns[g], error_msg));
	}
	else
	{
	    for (g = 0; g < ngroup; g++)
		free_peak_memory(fit_info[i+g]);
	}
    }

    /*  avoid dividing by zero when there was nothing to fit  */
    printf("number of peaks fitted = %d (%1.0f%%)\n",
		npeaks, (nfit > 0) ? (100.0*npeaks)/nfit : 0.0);

    return  OK;
}

/*
 *  Entry point: fits the nfit peaks described by fit_info.
 *
 *  Copies the data size (size_info) and fit settings (fit_param) into
 *  the file-level state, builds the per-fitted-dimension copies of the
 *  settings, checks the request, and then allocates the group work
 *  arrays and runs calculate_peaks() over the groups in group_info.
 *
 *  Returns ERROR (with error_msg set) if the product over the fitted
 *  dimensions of 2*width+1 is smaller than nparams, or if subtraction
 *  is requested while not fitting in all dimensions; otherwise the
 *  status of the allocation and fitting steps.
 */
Status fit_peaks(Size_info *size_info, Fit_param *fit_param, int nfit,
		Fit_info **fit_info, Group_info *group_info, String error_msg)
{
    int k, dim, window_points;

    /*  data set size  */
    ndim = size_info->ndim;
    npoints = size_info->npoints;

    /*  fit settings  */
    chisq = fit_param->chisq;
    subtract = fit_param->subtract;
    nfitted = fit_param->nfitted;
    dim_fitted = fit_param->dim_fitted;
    width = fit_param->width;
    periodic = fit_param->periodic;
    nparams = fit_param->nparams;

    /*  entry k of each per-fitted-dimension array comes from data
	dimension dim_fitted[k]  */
    window_points = 1;
    for (k = 0; k < nfitted; k++)
    {
	dim = dim_fitted[k];

	npoints2[k] = npoints[dim];
	script[k] = &(fit_param->script[dim]);
	npts[k] = fit_param->npts[dim];
	complex[k] = fit_param->complex[dim];
	freq_a[k] = fit_param->freq_a[dim];
	freq_b[k] = fit_param->freq_b[dim];

	fixed_phase[k] = fit_param->fixed_phase[dim];
	phase[k] = fit_param->phase[dim];
	fixed_decay[k] = fit_param->fixed_decay[dim];
	decay[k] = fit_param->decay[dim];

	window_points *= 2*width[dim] + 1;
    }

    /*  need at least as many points as parameters  */
    if (window_points < nparams)
    {
	sprintf(error_msg, "product of widths = %d, must be at least %d",
						window_points, nparams);
	return  ERROR;
    }

    if (subtract && (nfitted != ndim))
	RETURN_ERROR_MSG("to subtract must fit peaks in all dimensions");

    initialize_fixed_params(nfitted, fixed_phase, phase, fixed_decay, decay);

    CHECK_STATUS(alloc_position_memory(group_info->group_max, error_msg));

    CHECK_STATUS(calculate_peaks(nfit, fit_info, group_info->upper, error_msg));

    return  OK;
}
