#! /usr/bin/env python
# -*- coding: utf-8 -*-

import sys
import math
from model import Model

class Minim(object):
    """
    Main minimization class.

    Instantiation: Minim(random, model=None, new_cae=None)

        'random' is a random number source exposing ``choice`` and
            ``uniform`` (for example the ``random`` module or a
            ``random.Random`` instance).
        'model' supplies the configuration read in ``__init__``:
            allocations, prob_method, distance_measure, groups, variables,
            variables_weight, allocation_ratio, high_prob, min_group and
            arms_weight. It is required.
        'new_cae' is accepted for backward compatibility and is not used.

    'allocations' is a list of cases.
        Each case is a dictionary with 'levels', 'allocation' and optionally
        other info.
        'levels' is a list whose i-th element is the level of the i-th
        variable.
        'allocation' is the group index of the case: 0 to n-1,
        -1 = unallocated.

    A new case handed to ``enroll`` is one case, with allocation = -1.
    """

    def __init__(self, random, model=None, new_cae=None):
        if model is None:
            # Every attribute below is read from the model; fail clearly
            # instead of with an AttributeError on None.
            raise ValueError('Minim requires a model')
        self.pref_group = None
        self.selected_probs = None
        self.random = random
        # Probability matrix; filled by build_probs() (lazily by enroll()).
        self._probs = None
        self._allocations = model.allocations
        self._prob_method = model.prob_method
        self._distance_measure = model.distance_measure
        self._groups = model.groups
        self._variables = model.variables
        self._variables_weight = model.variables_weight
        self._allocation_ratio = model.allocation_ratio
        self._high_prob = model.high_prob
        self._min_group = model.min_group
        self._arms_weight = model.arms_weight

    @property
    def allocations(self):
        return self._allocations

    @allocations.setter
    def allocations(self, allocations):
        self._allocations = allocations

    @property
    def groups(self):
        return self._groups

    @groups.setter
    def groups(self, groups):
        self._groups = groups

    def build_probs(self, min_high_prob=None, min_group=0):
        """
        Build the square matrix ``self._probs`` of allocation probabilities.

        Row ``r`` holds the probabilities used to allocate a case when
        group ``r`` is the preferred group.

        min_high_prob: the high probability assigned to ``min_group``;
            defaults to the model's ``high_prob``.
        min_group: index of the group with the lowest allocation ratio.

        With the BCM method, the diagonal of every other row is scaled from
        ``min_high_prob`` by the allocation ratios. The remaining
        probability is split across the other groups in proportion to
        their ratios. With the NM (naive) method, every row uses
        ``min_high_prob`` on the diagonal and splits the remainder evenly,
        so all rows look alike.
        """
        if min_high_prob is None:
            min_high_prob = self._high_prob
        n_groups = len(self._groups)
        ratio = self._allocation_ratio
        self._probs = [[0] * n_groups for _ in range(n_groups)]
        self._probs[min_group][min_group] = min_high_prob
        # Sum of the ratios of every group except min_group (row-invariant).
        others_than_min = sum(ratio[i] for i in range(n_groups) if i != min_group)
        # Diagonal: the high probability of each preferred group.
        for r in self._groups:
            if r == min_group:
                continue
            if self._prob_method == Model.BCM:
                others_than_r = sum(ratio[i] for i in range(n_groups) if i != r)
                self._probs[r][r] = 1.0 - (1.0 * others_than_r / others_than_min) * (1.0 - min_high_prob)
            elif self._prob_method == Model.NM:
                self._probs[r][r] = min_high_prob
        # Off-diagonal: share what is left of each row among the other groups.
        for r in self._groups:
            high = self._probs[r][r]
            others_than_r = sum(ratio[i] for i in range(n_groups) if i != r)
            for c in self._groups:
                if r == c:
                    continue
                if self._prob_method == Model.BCM:
                    self._probs[r][c] = (1.0 * ratio[c] / others_than_r) * (1.0 - high)
                elif self._prob_method == Model.NM:
                    self._probs[r][c] = 1.0 * (1.0 - min_high_prob) / (n_groups - 1.0)

    def enroll(self, case, freq_table=None):
        """
        Allocate ``case`` to a group and return the chosen group index.

        case: dict with a 'levels' list (one level index per variable); its
            'allocation' key is set by this method.
        freq_table: optional frequency table shaped like the one built by
            ``build_freq_table``; when falsy, the table is rebuilt from the
            current allocations.

        Side effects: sets ``case['allocation']``, ``self.pref_group``,
        ``self.selected_probs`` and rebuilds ``self.freq_table``. Builds
        ``self._probs`` from the model's high_prob/min_group if
        ``build_probs`` was never called.
        """
        if freq_table:
            self.freq_table = freq_table
        else:
            self.build_freq_table()
        new_levels = case['levels']
        # level_count[v][g]: number of cases already in group g that share
        # the new case's level of variable v.
        level_count = [[] for _ in self._variables]
        for row in self.freq_table:
            for variable, level in enumerate(new_levels):
                level_count[variable].append(row[variable][level])
        # Group sizes (summed over the levels of the first variable) divided
        # by each group's allocation ratio; 1.0 * avoids integer division.
        adj_count = [1.0 * sum(self.freq_table[i][0]) / self._allocation_ratio[i] for i in self._groups]
        max_adj = max(adj_count)
        max_list = [i for i, cnt in enumerate(adj_count) if cnt == max_adj]
        w_arms = self._arms_weight
        scores = []
        for g in self._groups:
            # Groups that currently have the largest adjusted count use the
            # larger of the variable/arms weights, the others the smaller.
            func = max if g in max_list else min
            scores.append(sum([1.0 * func(self._variables_weight[v], w_arms) * self.get_imbalance_score(level_count[v], g) for v in range(len(self._variables))]))
        # Indices of the minimum score values.
        min_indices = self.get_min_ties_index(scores)
        if len(min_indices) == len(self._groups):
            # All groups have the same score, so allocate according to the
            # allocation ratio.
            probs = self._allocation_ratio
            self.pref_group = None
        else:
            # Randomly pick one of the preferred (lowest-score) groups.
            pt = self.random.choice(min_indices)
            self.pref_group = pt
            if getattr(self, '_probs', None) is None:
                self.build_probs(self._high_prob, self._min_group)
            probs = self._probs[pt]
        self.selected_probs = probs
        case['allocation'] = self.get_rand_biased(probs)
        self.build_freq_table()
        return case['allocation']

    def get_rand_biased(self, probs):
        """
        Return an index drawn with probability proportional to ``probs``.

        probs: sequence of non-negative weights (they need not sum to 1).

        Raises ValueError if no weight is positive.
        """
        p = self.random.uniform(0, sum(probs))
        cumulative = 0
        last_positive = None
        for i, weight in enumerate(probs):
            cumulative += weight
            if weight > 0:
                last_positive = i
            if p < cumulative:
                return i
        if last_positive is None:
            raise ValueError('probs must contain at least one positive weight')
        # random.uniform may return its upper bound (floating-point
        # rounding), so p can equal the total: pick the last index that has
        # a positive weight.
        return last_positive

    def get_min_ties_index(self, lst):
        """Return the indices of all values equal (within epsilon) to min(lst)."""
        lowest = min(lst)
        return [idx for idx, item in enumerate(lst)
                if abs(item - lowest) < sys.float_info.epsilon]

    def get_marginal_balance(self, count):
        """
        Sum of pairwise absolute differences of ``count`` divided by
        (len(count) - 1) * sum(count); 0.0 when that denominator is zero.
        """
        numerator = sum([abs(count[i] - count[j]) for i in range(len(count) - 1) for j in range(i + 1, len(count))])
        denominator = (len(count) - 1) * sum(count)
        if denominator == 0:
            return 0.0
        return (1.0 * numerator) / denominator

    def get_imbalance_score(self, count, group, enroll=True):
        """
        Imbalance of ``count`` (per-group counts) under the model's distance
        measure, after dividing each count by its allocation ratio.

        When ``enroll`` is true, the score is computed as if one more case
        were in ``group``. ``count`` is restored before returning, even if
        scoring raises.

        Raises ValueError for an unknown distance measure.
        """
        if enroll:
            count[group] += 1
        try:
            adj_count = [(1.0 * c) / self._allocation_ratio[i] for i, c in enumerate(count)]
            if self._distance_measure == Model.MB:
                return self.get_marginal_balance(adj_count)
            if self._distance_measure == Model.rng:
                return max(adj_count) - min(adj_count)
            if self._distance_measure == Model.var:
                return self.get_variance(adj_count)
            if self._distance_measure == Model.SD:
                return self.get_standard_deviation(adj_count)
            raise ValueError('unknown distance measure: %r' % (self._distance_measure,))
        finally:
            if enroll:
                count[group] -= 1

    def get_standard_deviation(self, count):
        """Sample standard deviation of ``count`` (0.0 for fewer than 2 values)."""
        return math.sqrt(self.get_variance(count))

    def get_variance(self, count):
        """Sample variance (n - 1 denominator) of ``count``; 0.0 for fewer than 2 values."""
        n = len(count)
        if n < 2:
            # Variance is undefined here; treat a single group as balanced,
            # consistent with get_marginal_balance's zero-denominator case.
            return 0.0
        mean = 1.0 * sum(count) / n
        sq_terms = sum([(i - mean) ** 2 for i in count])
        return 1.0 * sq_terms / (n - 1.0)

    def build_freq_table(self):
        """
        Rebuild ``self.freq_table[group][variable][level]`` from the
        allocated cases.

        Cases without an 'allocation' key, or with a None or negative
        allocation (-1 = unallocated), are skipped rather than counted.
        """
        table = [[[0 for _ in levels] for levels in self._variables] for _ in self._groups]
        self.freq_table = table
        for case in self._allocations or ():
            group = case.get('allocation')
            if group is None or group < 0:
                # Not allocated yet: a -1 would otherwise index the last
                # group via table[-1].
                continue
            for variable, level in enumerate(case['levels']):
                table[group][variable][level] += 1
