
from ij.plugin.frame import RoiManager
from ij.io import DirectoryChooser
from ij.gui import GenericDialog, PolygonRoi
from ij.text import TextWindow
from Mito_ER_Interface_release import Calcium_Interface, Radial_Interface
from math import exp
import os
import time


# Defaults pre-filled into the options dialog (kept as strings because they
# are shown in GenericDialog text fields).
scale = '0.8'  # pixels/nm
oversample = '2'
bins = [str(edge) for edge in range(10, 101, 10)]  # Ascending in nm
inclusive = True   # True -> inclusive binning, False -> exclusive
spline = False  # True -> convert ROIs to spline curves before measuring

# Score name -> (score function, oversample correction function).
# The whole pair is handed to Calcium_Interface; the names are offered in
# the "Score method" dialog choice.
score_funcs = {
    'inverse_sq': (lambda d: 100.0 / ((d + 1) ** 2), lambda d: d ** -2),
    'inverse': (lambda d: 1.0 / (d + 1), lambda d: d ** -1),
    'Neher': (lambda d: exp(d / -100000.0) / (d + 1), lambda d: exp(d) / d),
    'dummy': (lambda d: 1.0, lambda d: 1.0)
}


def is_valid_roiset(rois):
    """Return True if *rois* has the shape this script can measure.

    A set is accepted when it holds at least two ROIs and exactly one of
    them has type code 1, 2 or 3 (OVAL, POLYGON or FREEROI in
    ij.gui.Roi). Any number of ROIs of other types may accompany it.

    rois: sequence of ROI objects exposing getType().
    """
    # TODO: narrow possibilities
    area_like = sum(1 for candidate in rois if candidate.getType() in (1, 2, 3))
    return len(rois) >= 2 and area_like == 1


def prepare_rois_batch(file_list, path):
    """Load every RoiSet file in *file_list* and keep the valid ones.

    Each file is opened from directory *path* into the shared ROI Manager.
    The manager is reset before each file, discarding whatever it held.
    Each loaded set is checked with is_valid_roiset(), and invalid sets are
    reported on stdout.

    Args:
        file_list: file names (relative to *path*) of RoiSet archives.
        path: directory containing those files.

    Returns:
        A list of (rois, file_name) tuples for the sets that passed
        validation. If some set failed, the user is asked whether to
        continue: answering Yes returns the valid sets, anything else
        returns []. If no set passed at all, [] is returned without asking.
    """
    all_good = True
    good_rois = []
    roiM = RoiManager.getRoiManager()

    for roiset in file_list:
        roiM.reset()
        # os.path.join works whether or not *path* ends with a separator.
        roiM.runCommand('Open', os.path.join(path, roiset))
        rois = roiM.getRoisAsArray()

        if is_valid_roiset(rois):
            good_rois.append((rois, roiset))  # rois and name tuples
        else:
            all_good = False
            print('Problems with RoiSet: ' + roiset)

    if all_good:
        return good_rois

    print('')
    if not good_rois:
        # Nothing usable was found: either answer to "continue?" would
        # return an empty list, so don't bother asking.
        return []

    yncd = GenericDialog('')
    yncd.addMessage(
        'Problem ROIs detected; continue?')
    yncd.enableYesNoCancel()
    yncd.centerDialog(True)
    yncd.showDialog()
    # With enableYesNoCancel(), wasOKed() is true only for "Yes".
    if yncd.wasOKed():
        return good_rois
    return []


def spline_fit(roi):
    """Return a spline-fitted version of *roi* where that is supported.

    - ROI types 1, 3 and 7 (OVAL, FREEROI, FREELINE in ij.gui.Roi) are not
      converted. A warning is printed and *roi* itself is returned, so the
      caller's mapped ROI list never contains None.
    - An ROI that is already spline-fitted is returned as-is.
    - Otherwise a new PolygonRoi is built from the ROI's float polygon and
      type, then spline-fitted. The input ROI is left unmodified.
    """
    if roi.getType() in [1, 3, 7]:
        print('*** WARNING: Spline doesn\'t work for this ROI type: ' +
              roi.getTypeAsString() + ' ***')
        # Fall back to the unmodified ROI instead of implicitly returning
        # None, which would break the interface calculation downstream.
        return roi
    # NOTE(review): assumes every ROI reaching here exposes isSplineFit();
    # confirm for non-polygon types such as straight lines.
    if roi.isSplineFit():
        return roi
    new_roi = PolygonRoi(roi.getFloatPolygon(), roi.getType())
    new_roi.fitSpline()
    return new_roi


# Script entry point. Steps:
#   1. Ask for options.
#   2. Gather ROI sets: the current ROI Manager contents, or every .zip
#      RoiSet in a chosen directory when "Batch Process?" is ticked.
#   3. Measure each set with the chosen interface class.
#   4. Append one tab-separated row per set to an "Interfaces" text window.
dialog = GenericDialog('Set options')
dialog.addChoice('Method', ['Calcium', 'Radial'], 'Calcium')
dialog.addCheckbox('Batch Process?', False)
dialog.addCheckbox('Spline', False)
dialog.addStringField('Scale in pixels/nm', scale)
dialog.addStringField('Oversampling', oversample)
dialog.addStringField('Bins', ','.join(bins))
dialog.addChoice('Binning method', ['inclusive', 'exclusive'], 'inclusive')
dialog.addChoice('Score method (Calcium only)',
                 score_funcs.keys(), 'inverse_sq')
dialog.centerDialog(True)
dialog.showDialog()

if dialog.wasOKed():
    print('')
    # Widgets of each kind come back in the order they were added above:
    # choices = [Method, Binning method, Score method],
    # checks = [Batch, Spline], strings = [Scale, Oversampling, Bins].
    choices = [c.getSelectedItem() for c in dialog.getChoices()]
    checks = [c.getState() for c in dialog.getCheckboxes()]
    strings = [f.getText() for f in dialog.getStringFields()]

    calcium_method = (choices[0] == 'Calcium')
    inclusive = (choices[1] == 'inclusive')
    # (score function, oversample correction func); only used by Calcium.
    score_func = score_funcs[choices[2]]
    batch, spline = checks[0], checks[1]
    scale, oversample = float(strings[0]), float(strings[1])
    # Ignore blank entries so a stray or trailing comma (e.g. "10,20,")
    # does not raise ValueError from float('').
    bins = sorted(float(tok) for tok in strings[2].split(',') if tok.strip())

    heads = 'File\t' if batch else ''
    if calcium_method:
        heads += 'Mito_perim\tER_len\t#_ERs\tscore\tmin_dist'
    else:
        heads += 'Mito_perim\tER_len\t#_ERs\t#_segments\tmean_dist\tmin_dist'
    for binn in bins:
        heads += '\t<' + str(binn) + 'nm'
    if not inclusive:
        heads += '\ttotal'

    text = TextWindow('Interfaces', heads, '', 500, 350)

    if batch:
        dc = DirectoryChooser('Choose directory')
        path = dc.getDirectory()
        if path is None:
            # The user cancelled the chooser; os.listdir(None) would crash.
            print('No directory chosen; nothing to process.')
            roisets = []
        else:
            file_names = sorted(name for name in os.listdir(path)
                                if name.endswith('.zip'))
            roisets = prepare_rois_batch(file_names, path)
    else:
        rois = RoiManager.getRoiManager().getRoisAsArray()
        if is_valid_roiset(rois):
            roisets = [(rois, '')]
        else:
            raise Exception('ROIs are invalid')

    for set_rois, set_name in roisets:

        print('Processing ' + set_name)

        if spline:
            rois = [spline_fit(r) for r in set_rois]
        else:
            rois = set_rois

        if calcium_method:
            interface = Calcium_Interface(
                rois, score_func, bins, scale, oversample, inclusive)
        else:
            interface = Radial_Interface(
                rois, bins, scale, oversample, inclusive)

        start = time.clock()
        interface.calc_min_dists()
        print(str(round(time.clock() - start, 2)) + ' seconds to process\n')

        # Column order must match the header assembled above.
        new_line = []
        if batch:
            new_line.append(set_name)
        new_line += [interface.mito_len(), interface.er_len(),
                     interface.er_num]

        if calcium_method:
            new_line.append(interface.get_norm_score())
        else:
            new_line.append(interface.segments)
            new_line.append(interface.mean_dist)

        new_line.append(interface.abs_min_dist)
        new_line.extend(interface.get_calib_bin_counts())

        text.append('\t'.join([str(value) for value in new_line]))

print('Done!')
