from PySide2 import QtWidgets as qtw
from PySide2 import QtCore as qtc


class EnrolForm(qtw.QDialog):
    """Dialog for enrolling a new subject or editing an existing one.

    In enrol mode (``edit=False``) the user picks a level for every factor
    and presses *Enrol*; the allocation itself is delegated to
    ``parent.enrol_one``. The result label then shows the assigned treatment
    and can be clicked to show the minimisation detail.

    In edit mode (``edit=True``) an extra combo box lets the user change the
    subject's treatment, and *Save* accepts the dialog. Combo selections are
    written straight back into the ``factors`` / ``treatment`` dicts passed
    in, so the caller can read the choices from there.

    :param parent: main window; must provide ``settings``, ``database`` and
        ``enrol_one`` (used by the slots below).
    :param factors: list of factor dicts with ``id``, ``title``, ``levels``
        (each level a dict with ``id`` and ``title``) and
        ``selected_level_id`` (``-1`` means no level chosen yet).
    :param treatment: dict with ``treatments`` (list of dicts with ``id`` and
        ``title``) and ``selected_treatment_id``; only used in edit mode.
    :param edit: open the dialog in edit mode instead of enrol mode.
    :param identifier: subject identifier shown in the title in edit mode.
    """

    def __init__(self, parent, factors, treatment, edit=False, identifier=None):
        # NOTE(review): parent is not passed to QDialog, so Qt treats this as
        # a top-level window; it is kept only for settings/database/enrol_one.
        # Storing it as self.parent also shadows QObject.parent().
        super().__init__()
        # Allocation details from the last enrol_one call, read by on_detail.
        self.m_treatment = None
        self.p_treatment = None
        self.m_index = None
        self.p_index = None
        self.selected_probs = None
        # True once enrol_subject has stored a result for on_detail to show.
        self._enrolled = False
        self.parent = parent
        # NOTE(review): height -1 presumably lets the layout decide the height.
        self.setGeometry(100, 100, 400, -1)
        self.factors = factors
        self.treatment = treatment
        form_layout = qtw.QFormLayout()
        if edit:
            self.setWindowTitle(self.tr('Editing subject "{}"').format(identifier))
            treatment_combo, treatment_ids = self._make_combo(
                treatment['treatments'], treatment['selected_treatment_id'])
            # Connected only after the pre-selection so that setting up the
            # combo does not fire the handler.
            treatment_combo.currentIndexChanged.connect(self.get_treatment_combo_events(treatment_ids))
            form_layout.addRow(self.tr('Treatment'), treatment_combo)
        else:
            self.setWindowTitle(self.tr('Subject enrol'))

        for f, factor in enumerate(factors):
            levels_combo, level_ids = self._make_combo(
                factor['levels'], factor['selected_level_id'])
            levels_combo.currentIndexChanged.connect(self.get_levels_combo_event(f, level_ids))
            form_layout.addRow(factor['title'], levels_combo)
        self.result_label = qtw.QLabel(self.tr('Select factors, click enrol'))
        self.enrol_button = qtw.QPushButton(self.tr('Enrol'))
        close_button = qtw.QPushButton(self.tr('Close'))
        if edit:
            # Edit mode reuses the enrol button as "Save"; the result label
            # is not added to the layout.
            self.enrol_button.setText(self.tr('Save'))
            form_layout.addRow(self.enrol_button)
            close_button.setText(self.tr('Cancel'))
            self.enrol_button.clicked.connect(self.save_subject)
        else:
            form_layout.addRow(self.enrol_button, self.result_label)
            self.enrol_button.clicked.connect(self.enrol_subject)
        # Clicking the result label shows how the subject was allocated.
        clickable(self.result_label).connect(self.on_detail)
        close_button.clicked.connect(self.close)
        form_layout.addRow(close_button)
        self.setLayout(form_layout)

    @staticmethod
    def _make_combo(items, selected_id):
        """Build a combo box listing ``item['title']`` for each item.

        :param items: list of dicts with ``id`` and ``title``.
        :param selected_id: id of the item to pre-select; if no item has this
            id the combo is left with no selection (index -1).
        :return: ``(combo, ids)`` where ``ids[i]`` is the id shown at row i.
        """
        combo = qtw.QComboBox()
        ids = [item['id'] for item in items]
        combo.addItems([item['title'] for item in items])
        selected_index = -1
        for i, item_id in enumerate(ids):
            if item_id == selected_id:
                selected_index = i
        combo.setCurrentIndex(selected_index)
        return combo, ids

    def on_detail(self):
        """Show a message box explaining how the enrolled subject was allocated.

        Connected to clicks on the result label. Does nothing before a subject
        has been enrolled from this dialog (there is nothing to explain and
        ``selected_probs`` is still None), or when the ``last_trial_id``
        setting is 0, which is its default here.
        """
        if not self._enrolled:
            return
        trial_id = self.parent.settings.value('last_trial_id', 0, type=int)
        if trial_id == 0:
            return
        ns = self.parent.database.get_subject_count(trial_id)
        if ns == 1:
            # Only one subject in the trial: report it as randomised.
            qtw.QMessageBox.information(self.parent, self.tr('Minimisation detail'), self.tr('Enrolled by randomisation'))
            return
        if self.m_treatment != self.p_treatment:
            final_group = self.tr('non-preferred')
            prob = self.selected_probs[self.m_index]
        else:
            final_group = self.tr('preferred')
            prob = self.selected_probs[self.p_index]
        msg = self.tr('Preferred treatment = {}\n').format(self.p_treatment)
        msg += self.tr('Minimised treatment = {}\n').format(self.m_treatment)
        msg += self.tr('Subject assigned to {} treatment with a probability of {:4.2f}').format(final_group, prob)
        qtw.QMessageBox.information(self.parent, self.tr('Minimisation detail'), msg)

    def save_subject(self):
        """Accept the dialog in edit mode.

        The edited selections were already written into ``self.factors`` and
        ``self.treatment`` by the combo-box handlers.
        """
        # accept() hides the dialog and sets the result to Accepted.
        # close() would instead go through QDialog.closeEvent -> reject(),
        # emitting rejected()/finished(Rejected) for a saved edit.
        self.accept()

    def enrol_subject(self):
        """Enrol a subject with the currently selected factor levels.

        For each factor, collects the row of the selected level and the
        ``(factor id, level id)`` pair, passes both lists to
        ``parent.enrol_one``, and shows the returned identifier and
        treatment in the result label. If any factor has no level chosen,
        only a prompt is shown. The Enrol button is then disabled so the
        dialog enrols at most one subject.
        """
        selected_indices = []
        selected_ids = []
        for factor in self.factors:
            if factor['selected_level_id'] == -1:
                self.result_label.setText(self.tr('Select {} level').format(factor['title']))
                return
            for i, level in enumerate(factor['levels']):
                if factor['selected_level_id'] == level['id']:
                    selected_indices.append(i)
                    selected_ids.append((factor['id'], factor['selected_level_id']))
                    break
        # mt is the treatment the subject is enrolled to (shown in the
        # label); pt is the preferred treatment. probs is indexed by
        # mt_index / pt_index in on_detail.
        mt, mt_index, pt, pt_index, identifier, probs = self.parent.enrol_one(selected_indices, selected_ids)
        self.m_index = mt_index
        self.p_index = pt_index
        self.m_treatment = mt
        self.p_treatment = pt
        self.selected_probs = probs
        self._enrolled = True
        # Red highlights an assignment that differs from the preferred one.
        style = 'color: red;' if mt != pt else 'color: blue;'
        self.result_label.setStyleSheet(style)
        self.result_label.setText(self.tr('Subject "{}", Enrolled to {}').format(identifier, mt))
        self.enrol_button.setEnabled(False)

    def get_treatment_combo_events(self, treatment_ids):
        """Return a slot that records the treatment chosen in the combo box.

        :param treatment_ids: treatment ids in combo-row order.
        :return: a callable taking the new combo index.
        """
        def on_treatment_combo_index_changed(index):
            self.treatment['selected_treatment_id'] = treatment_ids[index]
        return on_treatment_combo_index_changed

    def get_levels_combo_event(self, f, level_ids):
        """Return a slot that records the level chosen for factor ``f``.

        :param f: position of the factor in ``self.factors``.
        :param level_ids: level ids in combo-row order.
        :return: a callable taking the new combo index.
        """
        def on_levels_combo_index_changed(index):
            self.factors[f]['selected_level_id'] = level_ids[index]
        return on_levels_combo_index_changed

def clickable(widget):
    """Return a ``clicked`` signal that fires when *widget* is clicked.

    An event filter is installed on *widget*. When a mouse button is released
    while the pointer is still inside the widget's rectangle, the filter
    emits a parameterless ``clicked`` signal and consumes that release event
    by returning True. Every other event passes through untouched.

    The filter object is created with *widget* as its QObject parent.

    :param widget: the QWidget to make clickable.
    :return: the bound ``clicked`` signal; ``connect`` a slot to it.
    """
    class _ReleaseFilter(qtc.QObject):
        clicked = qtc.Signal()

        def eventFilter(self, watched, event):
            if not (watched == widget and event.type() == qtc.QEvent.MouseButtonRelease):
                return False
            if not watched.rect().contains(event.pos()):
                return False
            # Emitted without arguments; declare Signal(QObject) and use
            # emit(watched) instead if the slot needs the widget itself.
            self.clicked.emit()
            return True

    release_filter = _ReleaseFilter(widget)
    widget.installEventFilter(release_filter)
    return release_filter.clicked
