#include "match.h"

#include "output.h"
#include "shift.h"

/*  number of boxes along the second box axis when a column set has one column  */
#define  STEPS1			200
/*  number of boxes along each box axis when a column set has two columns  */
#define  STEPS2			50

/*  shifts at or below this value are rejected by acceptable_shift();
    parenthesised so the negative literal is safe in any expression  */
#define  UNKNOWN_SHIFT		(-99)

/*  minimum number of entries allocated for a match list  */
#define  MIN_MATCH_ALLOC	1000  /* should be plenty! */

/*  One entry of a box: a shift together with the per-column window
    that do_fill() copies from the file-scope low[] / high[] arrays.  */
typedef struct
{
    Shift *shift;		/* the shift this entry refers to */
    float low[MAX_COLUMNS];	/* lower end of the window, per column */
    float high[MAX_COLUMNS];	/* upper end of the window, per column */
}   Shift_data;

/*  One cell of the box grid.  nshifts is first used as a plain counter
    (do_count), then reset to 0 by alloc_box_memory() once data has been
    allocated, and finally counts the entries stored by do_fill().  */
typedef struct
{
    int nshifts;		/* number of entries (see above) */
    Shift_data *data;		/* entries of this box, NULL until allocated */
}   Shift_box;

/*  Box grid for one column set.  */
typedef struct
{
    Columns *columns;		/* column set the grid was built for */
    Shift_box **box;		/* grid of nsteps[0] x nsteps[1] boxes */
    int nsteps[MAX_COLUMNS];	/* grid dimensions, set by init_boxes() */
}   Box_data;

/*  Callback applied by do_box_func() to every box that a shift falls
    into; do_count() and do_fill() are the two callbacks in this file.  */
typedef Status (*Box_func)(int ncolumns, Shift_box *box,
					Shift *shift, String error_msg);

/*  box grid for each column set  */
static Box_data box_data[NCOLUMNS];

/*  scratch results of the last acceptable_shift() call that returned TRUE:
    box index range and aliased shift window for each column  */
static int low_box[MAX_COLUMNS];
static int high_box[MAX_COLUMNS];
static float low[MAX_COLUMNS];
static float high[MAX_COLUMNS];

/*  number of shifts accepted by acceptable_shift() since it was last reset  */
static int nacceptable;

/*  match list, its length and its allocated capacity for each column set;
    nmatches_alloc is sized by NCOLUMNS like its siblings (static storage
    starts out as zero) instead of by a hard-coded two-element initialiser  */
static int nmatches[NCOLUMNS];
static int nmatches_alloc[NCOLUMNS];
static Atom **matches[NCOLUMNS];

/*  statistics accumulated by find_matches(), reported by find_all_matches()  */
static int excluded_match;
static int no_match;
static int one_match;
static int total_match;

/*  Release the match list of column set c, if one is allocated, and
    record that its capacity is now zero.  */
static void free_match_memory(int c)
{
    if (nmatches_alloc[c] <= 0)
	return;

    FREE(matches[c], Atom *);
    nmatches_alloc[c] = 0;
}

/*  Make sure the match list of column set c can hold at least n entries.
    An existing list that is too small is freed and replaced (its contents
    are not kept); the new capacity is never below MIN_MATCH_ALLOC.
    Returns OK when the list is large enough.

    error_msg is filled in before allocating, as the other allocation sites
    in this file do.  NOTE(review): this presumes MALLOC leaves the function
    with an error Status on failure -- confirm against the macro definition.  */
static Status alloc_match_memory(int c, int n, String error_msg)
{
    if (n <= nmatches_alloc[c])
	return  OK;

    free_match_memory(c);

    n = MAX(n, MIN_MATCH_ALLOC); /* do not really need this now */

    sprintf(error_msg, "allocating match memory");

    MALLOC(matches[c], Atom *, n);

    nmatches_alloc[c] = n;

    return  OK;
}

/*  Allocate the entry array of every box whose nshifts counter is
    positive, sized by that counter, then reset the counter to 0 so the
    box can be filled from the start.  Boxes with no shifts are left alone.  */
static Status alloc_box_memory(Box_data *box_data, String error_msg)
{
    int row, col, count;
    Shift_box *cell;

    sprintf(error_msg, "allocating box memory");

    for (row = 0; row < box_data->nsteps[0]; row++)
    {
	for (col = 0; col < box_data->nsteps[1]; col++)
	{
	    cell = &(box_data->box[row][col]);
	    count = cell->nshifts;

	    if (count <= 0)
		continue;

	    MALLOC(cell->data, Shift_data, count);
	    cell->nshifts = 0;
	}
    }

    return  OK;
}

/*  Attach columns to box_data, choose the grid dimensions and allocate
    an empty grid.  A one-column set gets a 1 x STEPS1 grid (nsteps[0] = 1,
    nsteps[1] = STEPS1); anything else gets STEPS2 x STEPS2.  Every box
    starts with no shifts and no entry array.  */
static Status init_boxes(Box_data *box_data, Columns *columns, String error_msg)
{
    int row, col, nrows, ncols;
    Shift_box *cell;

    box_data->columns = columns;

    nrows = (columns->ncolumns == 1) ? 1 : STEPS2;
    ncols = (columns->ncolumns == 1) ? STEPS1 : STEPS2;

    box_data->nsteps[0] = nrows;
    box_data->nsteps[1] = ncols;

    sprintf(error_msg, "allocating box memory");

    MALLOC2(box_data->box, Shift_box, nrows, ncols);

    for (row = 0; row < nrows; row++)
    {
	for (col = 0; col < ncols; col++)
	{
	    cell = &(box_data->box[row][col]);

	    cell->nshifts = 0;
	    cell->data = (Shift_data *) NULL;
	}
    }

    return  OK;
}

/*  Move v by a whole number of window widths (data->high - data->low)
    towards the window that starts at data->low, and store the result
    in *x.  The number of widths is truncated to an int.  */
static void alias_shift(float *x, Column_data *data, float v)
{
    int nfolds;

    if (v >= data->low)
    {
	/* at or above the lower edge: fold downwards */
	nfolds = (v - data->low) / (data->high - data->low);
	v = v - nfolds * (data->high - data->low);
    }
    else
    {
	/* below the lower edge: fold upwards */
	nfolds = (data->high - v) / (data->high - data->low);
	v = v + nfolds * (data->high - data->low);
    }

    *x = v;
}

/*  Convert x to a box index along an axis split into nsteps boxes
    between data->low and data->high, and store it in *b.  The index
    is clamped to the range 0 .. nsteps-1.  */
static void box_number(int *b, float x, Column_data *data, int nsteps)
{
    int index;

    index = nsteps * (x - data->low) / (data->high - data->low);

/*  protection  */
    index = MAX(index, 0);
    *b = MIN(index, nsteps-1);
}

/*  Decide whether shift can be used with the column set of box_data and,
    if so, work out which boxes it covers.

    A shift is rejected (FALSE) when one of its first ncolumns atoms has a
    shift value <= UNKNOWN_SHIFT, or when the atom type of any of its first
    nshift_columns atoms differs from that of the corresponding column.

    On acceptance (TRUE) the file-scope arrays are filled in per column:
    low[] / high[] hold the aliased window shift -/+ delta, low_box[] /
    high_box[] hold the box range to visit, and nacceptable is incremented.
    high_box may exceed nsteps-1 for a wrapped window; do_box_func()
    reduces the indices modulo nsteps.  */
static Bool acceptable_shift(Box_data *box_data, Shift *shift)
{
    int i, ncolumns = box_data->columns->ncolumns, *nsteps = box_data->nsteps;
    int nshift_columns = box_data->columns->nshift_columns;
    float d;

    for (i = 0; i < nshift_columns; i++)
    {
/* changed 3 Apr 2002 from == to <= (because some people use -9999.9) */
	if ((i < ncolumns) && (shift->atom[i].shift <= UNKNOWN_SHIFT))
	    return  FALSE;

	if (box_data->columns->data[i].atom_type != shift->atom[i].atom_type)
	    return  FALSE;
    }

    for (i = 0; i < ncolumns; i++)
    {
	d = 2*shift->atom[i].delta - box_data->columns->data[i].high
					+ box_data->columns->data[i].low;

	if (d >= 0) /* extremely unlikely, means uncertainty >= sw */
	{
	    low[i] = box_data->columns->data[i].low;
	    high[i] = box_data->columns->data[i].high;
	}
	else
	{
	    alias_shift(low+i, box_data->columns->data+i,
				shift->atom[i].shift - shift->atom[i].delta);

	    alias_shift(high+i, box_data->columns->data+i,
				shift->atom[i].shift + shift->atom[i].delta);
	}
    }

    for (i = 0; i < ncolumns; i++)
    {
	box_number(low_box+i, low[i], box_data->columns->data+i, nsteps[i]);
	box_number(high_box+i, high[i], box_data->columns->data+i, nsteps[i]);

	if (low_box[i] > high_box[i])
	{
	    /* wrapped window: run past the last box and wrap round */
	    high_box[i] += nsteps[i];
	}
	else if (low[i] > high[i])
	{
	    /* wrapped window whose two ends fall in the same box: it covers
	       every other box as well (shift_in_range() accepts everything
	       outside [high, low]), so visit each of the nsteps[i] boxes
	       exactly once instead of only this one */
	    high_box[i] += nsteps[i] - 1;
	}
    }

    nacceptable++;

    return  TRUE;
}

/*  Call func once for every box covered by the range that the last
    acceptable_shift() call left in low_box[] / high_box[].  Indices are
    reduced modulo nsteps so that a range running past the last box wraps
    round to the first.  Stops at the first call that CHECK_STATUS rejects.

    NOTE(review): for a one-column set init_boxes() sets nsteps[0] to 1,
    so i % nsteps[0] below is always 0 and every shift ends up in
    box[0][0]; the STEPS1 boxes of the second axis look unused.  Confirm
    whether nsteps[1] was intended for the single-column case.  */
static Status do_box_func(Box_data *box_data, Shift *shift,
					Box_func func, String error_msg)
{
    int i, j, row, col, *nsteps = box_data->nsteps;

    if (box_data->columns->ncolumns != 1) /* two columns */
    {
	for (i = low_box[0]; i <= high_box[0]; i++)
	{
	    row = i % nsteps[0];

	    for (j = low_box[1]; j <= high_box[1]; j++)
	    {
		col = j % nsteps[1];

		CHECK_STATUS((*func)(2, box_data->box[row] + col, shift, error_msg));
	    }
	}

	return  OK;
    }

    /* one column: only the first row of the grid is used */
    for (i = low_box[0]; i <= high_box[0]; i++)
    {
	col = i % nsteps[0];
	CHECK_STATUS((*func)(1, box_data->box[0] + col, shift, error_msg));
    }

    return  OK;
}

/*  Run func, via do_box_func(), over the boxes of every shift in
    shifts[0 .. nshifts-1] that acceptable_shift() accepts.  */
static Status do_boxes(Box_data *box_data, int nshifts, Shift **shifts,
					Box_func func, String error_msg)
{
    int s;

    for (s = 0; s < nshifts; s++)
    {
	if (!acceptable_shift(box_data, shifts[s]))
	    continue;

	CHECK_STATUS(do_box_func(box_data, shifts[s], func, error_msg));
    }

    return  OK;
}

/*  Box_func callback: count one more shift for this box.  Only box is
    used; the other arguments exist to match the Box_func signature.  */
static Status do_count(int ncolumns, Shift_box *box, Shift *shift,
							String error_msg)
{
    box->nshifts += 1;

    return  OK;
}

/*  Box_func callback: store shift as the next entry of box, together
    with the first ncolumns values of the file-scope low[] / high[]
    arrays, and advance the box's entry count.  The entry array must
    already have room for it.  */
static Status do_fill(int ncolumns, Shift_box *box, Shift *shift,
							String error_msg)
{
    int col;
    Shift_data *entry = box->data + box->nshifts;

    entry->shift = shift;

    for (col = 0; col < ncolumns; col++)
    {
	entry->low[col] = low[col];
	entry->high[col] = high[col];
    }

    box->nshifts++;

    return  OK;
}

/*  Build the box grid of every column set.  For each set: make sure a
    match list exists, create an empty grid, count the shifts per box
    (first pass), report how many shifts were acceptable, allocate the
    boxes, and store the shifts in them (second pass).  */
static Status setup_matches(int nshifts, Shift **shifts, Columns *columns,
							String error_msg)
{
    int c;
    Box_data *grid;

    printf("setting up matches\n");

    for (c = 0; c < NCOLUMNS; c++)
    {
	grid = box_data + c;

	CHECK_STATUS(alloc_match_memory(c, MIN_MATCH_ALLOC, error_msg));

	/* counting pass; nacceptable is what gets reported below */
	nacceptable = 0;

	CHECK_STATUS(init_boxes(grid, columns+c, error_msg));
	CHECK_STATUS(do_boxes(grid, nshifts, shifts, do_count, error_msg));

	printf("\t%s column(s): acceptable shifts found = %d (%1.0f%%)\n",
			(c == 0) ? "first" : "second", nacceptable,
			(nshifts == 0) ? 0.0 : (100.0*nacceptable)/nshifts);

	/* filling pass */
	CHECK_STATUS(alloc_box_memory(grid, error_msg));
	CHECK_STATUS(do_boxes(grid, nshifts, shifts, do_fill, error_msg));
    }

    return  OK;
}

/*  Test whether shift lies inside the window (low, high).
    When low <= high the window is the open interval between them.
    When low > high the window wraps round: everything is inside except
    the closed interval [high, low].  */
static Bool shift_in_range(float low, float high, float shift)
{
    if (low <= high) /* normal case */
	return  ((shift > low) && (shift < high)) ? TRUE : FALSE;

    /* (low > high), wrap around */
    return  ((shift >= high) && (shift <= low)) ? FALSE : TRUE;
}

/*  TRUE when, for each of the first ncolumns columns, shifts[i] lies
    inside the window data->low[i] .. data->high[i] as judged by
    shift_in_range(); FALSE as soon as one column fails.  */
static Bool shifts_in_range(int ncolumns, Shift_data *data, float *shifts)
{
    int col = 0;

    while (col < ncolumns)
    {
	if (!shift_in_range(data->low[col], data->high[col], shifts[col]))
	    return  FALSE;

	col++;
    }

    return  TRUE;
}

/*  TRUE when residue falls inside at least one of the inclusive ranges
    residue1[k] .. residue2[k], k = 0 .. nresidue_ranges-1.  Returns
    FALSE when there are no ranges at all.  */
static Bool residue_in_range(int residue, int nresidue_ranges,
					int *residue1, int *residue2)
{
    int k;

    for (k = 0; k < nresidue_ranges; k++)
    {
	if (residue < residue1[k])
	    continue;

	if (residue <= residue2[k])
	    return  TRUE;
    }

    return FALSE;
}

/*  Collect the matches of column set c found in box [m][n].
    Every entry of that box whose shift passes the residue filter (only
    applied when the column set has residue ranges) and whose window
    contains shifts[] is added to matches[c]; nmatches[c] receives the
    number collected.  The match list is enlarged first if the box holds
    more entries than it can take.  */
static Status check_matches(int m, int n, int c, float *shifts,
							String error_msg)
{
    Columns *cols = box_data[c].columns;
    Shift_box *box = &(box_data[c].box[m][n]);
    Shift_data *entry;
    int i, nfound = 0;

    if (box->nshifts > 0)
    {
	CHECK_STATUS(alloc_match_memory(c, box->nshifts, error_msg));

	for (i = 0; i < box->nshifts; i++)
	{
	    entry = box->data + i;

/* with no residue ranges every residue is accepted; residue_in_range()
   would return FALSE in that case, so it must not be asked */
	    if (((cols->nresidue_ranges <= 0) ||
		 residue_in_range(entry->shift->atom->residue,
				cols->nresidue_ranges,
				cols->residue1, cols->residue2)) &&
		shifts_in_range(cols->ncolumns, entry, shifts))
	    {
		matches[c][nfound] = entry->shift->atom+LIGHT_ATOM;
		nfound++;
	    }
	}
    }

    nmatches[c] = nfound;

    return  OK;
}

/*  TRUE when the column has exclusion switched on and shift lies
    strictly closer than data->delta to data->shift; FALSE otherwise.  */
static Bool exclude_crosspeak(float shift, Column_data *data)
{
    float dist;

    if (data->exclusion)
    {
	dist = data->shift - shift;
	dist = ABS(dist);

	if (dist < data->delta)
	    return  TRUE;
    }

    return  FALSE;
}

/*  Find the matches of crosspeak for column set c.
    nmatches[c] is cleared first.  If any column of the set excludes the
    crosspeak, *excluded is set to TRUE and the search is abandoned
    (still returning OK).  Otherwise the crosspeak's shifts are aliased,
    turned into a box position, and that box is searched by
    check_matches().  *excluded is never set to FALSE here.  */
static Status get_matches(int c, Crosspeak *crosspeak, Bool *excluded,
							String error_msg)
{
    Columns *cols = box_data[c].columns;
    Column_data *col;
    int *nsteps = box_data[c].nsteps;
    int i, peak_col, m, n, ncolumns = cols->ncolumns;
    int box[MAX_COLUMNS];
    float shifts[MAX_COLUMNS];

    nmatches[c] = 0;

    for (i = 0; i < ncolumns; i++)
    {
	col = cols->data + i;
	peak_col = col->column;

	if (exclude_crosspeak(crosspeak->atom[peak_col].shift, col))
	{
	    *excluded = TRUE;

	    return  OK;
	}

	alias_shift(shifts+i, col, crosspeak->atom[peak_col].shift);
	box_number(box+i, shifts[i], col, nsteps[i]);
    }

    /* one column: the grid has a single row; otherwise use both indices */
    m = (ncolumns == 1) ? 0 : box[0];
    n = (ncolumns == 1) ? box[0] : box[1];

    CHECK_STATUS(check_matches(m, n, c, shifts, error_msg));

    return  OK;
}

/*  TRUE when at least one crosspeak atom referred to by the column set
    is not assigned according to atom_assigned(); FALSE when all are.  */
static Bool need_to_find_matches(Crosspeak *crosspeak, Columns *columns)
{
    int i;

    for (i = 0; i < columns->ncolumns; i++)
    {
	if (atom_assigned(crosspeak->atom + columns->data[i].column))
	    continue;

	return  TRUE;
    }

    return  FALSE;
}

/*  Find and output the matches of one crosspeak.

    The column sets are processed in order until one excludes the
    crosspeak or yields no match.  A set whose crosspeak atoms are all
    assigned already is not searched: its light atom is the single match,
    unless the set has residue ranges and the atom's residue is outside
    them.  n is the product of the per-set match counts.

    Updates the excluded / no-match / one-match / total statistics, calls
    output_null_match() when n is 0 and output_matches() whenever the
    crosspeak is not excluded.  */
static Status find_matches(Crosspeak *crosspeak, Columns *columns,
							String error_msg)
{
    int i, c, n;
    Bool excluded = FALSE;
    Atom *atom;
    Columns *cols;

    n = 1;
    for (i = 0; (i < NCOLUMNS) && !excluded && (n > 0); i++)
    {
	if (need_to_find_matches(crosspeak, columns+i))
	{
	    CHECK_STATUS(get_matches(i, crosspeak, &excluded, error_msg));
	}
	else
	{
	    cols = columns + i;

/*  line below assumes that first column entry is LIGHT_ATOM, dangerous  */
/*  (currently this is enforced in script.c) */
	    c = cols->data[LIGHT_ATOM].column;
	    atom = crosspeak->atom + c;

	    if ((cols->nresidue_ranges > 0) &&
		!residue_in_range(atom->residue, cols->nresidue_ranges,
					cols->residue1, cols->residue2))
	    {
		nmatches[i] = 0;
	    }
	    else
	    {
		nmatches[i] = 1;
		matches[i][0] = atom;
	    }
	}

	n *= nmatches[i];
    }

/*  the loop stops early on exclusion or when a column set has no match;
    clear the counts it never reached so that values left over from the
    previous crosspeak are not handed to output_matches()  */
    for ( ; i < NCOLUMNS; i++)
	nmatches[i] = 0;

    if (excluded)
    {
	excluded_match++;
    }
    else if (n == 0)
    {
	CHECK_STATUS(output_null_match(crosspeak, error_msg));
	no_match++;
    }
    else if (n == 1)
    {
	one_match++;
    }

    total_match += n;

    if (!excluded)
	CHECK_STATUS(output_matches(nmatches, matches, crosspeak, error_msg));

    return  OK;
}

/*  Match every crosspeak against the shifts and print summary statistics.

    The box grids are set up from shifts, then find_matches() is run on
    crosspeaks[0 .. ncrosspeaks-1] with a progress line every 100
    crosspeaks.  Afterwards the numbers of excluded, unmatched and
    uniquely matched crosspeaks and the total number of matches are
    printed; percentages are reported as 0 when their base is 0.

    All four counters are reset on entry, including excluded_match, so
    the statistics of a second call are not inflated by an earlier one.  */
Status find_all_matches(int nshifts, Shift **shifts, int ncrosspeaks,
		Crosspeak **crosspeaks, Columns *columns, String error_msg)
{
    int i;
    float p;

    CHECK_STATUS(setup_matches(nshifts, shifts, columns, error_msg));

    excluded_match = no_match = one_match = total_match = 0;

    printf("finding matches\n");

    for (i = 0; i < ncrosspeaks; i++)
    {
	if (!(i % 100))
	    printf("\t...crosspeak %d\n", i);

	CHECK_STATUS(find_matches(crosspeaks[i], columns, error_msg));
    }

    printf("\texcluded crosspeaks = %d (%1.0f%%), non-excluded crosspeaks = %d (%1.0f%%)\n",
	excluded_match,
		(ncrosspeaks == 0) ? 0.0 : (100.0*excluded_match)/ncrosspeaks,
	ncrosspeaks-excluded_match,
		(ncrosspeaks == excluded_match) ? 0.0 :
			(100.0*(ncrosspeaks-excluded_match))/ncrosspeaks);

    /* the remaining statistics are relative to the non-excluded crosspeaks */
    ncrosspeaks -= excluded_match;
    printf("\tstatistics for non-excluded crosspeaks:\n");

    printf("\tno matches = %d (%1.0f%%), one match = %d (%1.0f%%), some match = %d (%1.0f%%)\n",
	no_match, (ncrosspeaks == 0) ? 0.0 : (100.0*no_match)/ncrosspeaks,
	one_match, (ncrosspeaks == 0) ? 0.0 : (100.0*one_match)/ncrosspeaks,
	ncrosspeaks - no_match,
	(ncrosspeaks == 0) ? 0.0 : (100.0*(ncrosspeaks-no_match))/ncrosspeaks);

    /* average number of matches per crosspeak that has at least one */
    p = (ncrosspeaks == no_match) ? 0.0 :
		((float) total_match)/((float) (ncrosspeaks-no_match));

    if (p > 9.5)
	printf("\ttotal number of matches = %d (average per matched = %1.0f)\n",
							total_match, p);
    else
	printf("\ttotal number of matches = %d (average per matched = %2.1f)\n",
							total_match, p);

    return  OK;
}
