# MIT License
# Copyright (c) 2017 dcolam
from __future__ import with_statement, division
import sys, time, os, traceback, random, time, ConfigParser, csv, math, fnmatch, locale
from ij import IJ, ImagePlus, WindowManager, CompositeImage
from org.sqlite import SQLiteConfig
from java.lang import Class, System, Double
from java.awt import Color, Font
from loci.plugins.util import WindowTools as wt
from java.sql import DriverManager, SQLException, Types, Statement
from ij.gui import GenericDialog, WaitForUserDialog, Roi, ShapeRoi, Overlay, TextRoi
from ij.process import ImageProcessor, AutoThresholder
from ij.plugin import ChannelSplitter, ImageCalculator, RGBStackMerge, ZProjector, Duplicator, StackEditor, \
    Concatenator, RoiEnlarger, RoiRotator
from sc.fiji.analyzeSkeleton import AnalyzeSkeleton_
from ij.plugin import ZProjector as zp
from fiji.stacks import Hyperstack_rearranger as hyr
from ij.plugin.frame.RoiManager import multiMeasure
from ij.plugin.filter import EDM, ParticleAnalyzer, Calibrator, Filler, Analyzer, PlugInFilterRunner
from ij.measure import Measurements as ms
from loci.plugins import BF
from ij.plugin.filter import ThresholdToSelection as tts
from ij.measure import ResultsTable, Calibration
from ij.io import RoiDecoder
import org.scijava.command.Command
from org.scijava.util import ColorRGB

# Main Dialog manager
class Dialoger(object):
    """Collect the analysis options for the four image channels.

    Constructing a Dialoger runs the whole option workflow:

    1. ``getOptions`` reads the ``ChannelOptions`` config section (through the
       module-level ``cp`` object). Unless ``headless`` is set, it lets the
       user edit the values in a GenericDialog and saves the confirmed values
       back with ``cp.update``.
    2. ``loadfilenames`` scans the input folder for matching images and
       creates the output folders.
    3. ``getParticleAnalyzerOptions`` is run for every channel with particle
       analysis enabled. It is run in "coloc" mode for every channel that
       another channel selected for colocalisation but that has no particle
       analysis of its own.
    """
    autoMethods = AutoThresholder.getMethods()
    allMethods = ["Manual"]
    allMethods += autoMethods

    # Horizontal rules used to separate the sections of the dialogs.
    _CHANNEL_SEPARATOR = "__________________________________________________________________________________________________________________________________________________"
    _TEST_SEPARATOR = "_________________________________________________________________________________"
    _PA_SEPARATOR = "___________________________________________________________________________________"
    # Checkbox labels of the per-channel options, in the order of the
    # c<N>Opt_boolList config values.
    _CHANNEL_OPTION_LABELS = ["Background Substraction", "Adjust Brightness/Contrast automatically?",
                              "Adjust Brightness/Contrast manually?", "Particle Analysis"]

    def __init__(self):
        self.input_path_dir = ''
        self.output_path_dir = ''
        self.ext = ''
        self.delimiter = "_"
        self.filenames = []  # every matching image path
        self.groupedFiles = {}  # parent-folder name -> list of image paths (non-empty groups only)
        self.zStack = False
        self.test = False
        self.c1 = Channel()
        self.c2 = Channel()
        self.c3 = Channel()
        self.c4 = Channel()
        self.output_path_dict = {}  # group name / "output_table_path" -> output folder

        self.channels = [self.c1, self.c2, self.c3, self.c4]
        self.overwriteDB = False
        self.getOptions()
        self.loadfilenames()

        for index, channel in enumerate(self.channels):
            if channel.pa:
                self.getParticleAnalyzerOptions(index)

        # Channels used as colocalisation partners need threshold settings
        # too, even when they have no particle analysis of their own.
        for channel in self.channels:
            for index, selected in enumerate(channel.list_1whichChannel):
                if selected and not self.channels[index].pa:
                    self.getParticleAnalyzerOptions(index, "coloc")

    def loadfilenames(self):
        """Scan the input folder for images and prepare the output folders.

        Every file below ``self.input_path_dir`` whose extension equals
        ``self.ext`` (exact, case-sensitive match; a leading dot is added if
        missing) is collected. Files are grouped by the name of their
        immediate parent folder.

        A ``Particle_Analysis`` folder is created in the input folder. Inside
        it, one sub-folder is created per scanned folder, plus an
        ``Output_Table`` folder. Their paths are stored in
        ``self.output_path_dict``.

        Exits the script (after a message dialog) when no file matches.
        """
        # startswith also copes with an empty extension (ext[0] raised IndexError).
        if not self.ext.startswith("."):
            self.ext = "." + self.ext

        filenames = []
        groupedfiles = {}
        for root, dirs, files in os.walk(self.input_path_dir):
            group = groupedfiles.setdefault(os.path.split(root)[1], [])
            for name in files:
                path = os.path.join(root, name)
                if os.path.splitext(path)[1] == self.ext:
                    group.append(path)
                    filenames.append(path)

        if not filenames:
            WaitForUserDialog("No files have been found. Please, check for correct file-extension (file-type) or for presence of images in the folder").show()
            sys.exit("Analysis cancelled!")

        self.output_path_dir = os.path.join(self.input_path_dir, "Particle_Analysis")
        if not os.path.isdir(self.output_path_dir):
            os.makedirs(self.output_path_dir)

        for group_name in groupedfiles:
            g_path = os.path.join(self.output_path_dir, group_name)
            if not os.path.isdir(g_path):
                os.makedirs(g_path)
            self.output_path_dict[group_name] = g_path

        output_table = os.path.join(self.output_path_dir, "Output_Table")
        if not os.path.isdir(output_table):
            os.makedirs(output_table)
        self.output_path_dict["output_table_path"] = output_table

        self.groupedFiles = dict((k, v) for k, v in groupedfiles.items() if v)
        self.filenames = filenames

    def _applyChannel(self, index, name, background, brightness_auto, brightness_man, pa, radius, gaussian):
        """Store one channel's options on ``self.channels[index]``."""
        self.channels[index].setInfo(channel_name=name, background_substraction=background,
                                     background_radius=radius, brightness_auto=brightness_auto,
                                     brightness_man=brightness_man, pa=pa, gaussian_blur=gaussian)

    def getOptions(self):
        """Read the general and per-channel options and apply them.

        Defaults come from the ``ChannelOptions`` config section.

        GUI mode: the values confirmed in the dialog are applied and saved
        back with ``cp.update``.

        Headless mode: the config values are applied unchanged, the input
        folder is taken from ``expath2``, an existing database is always
        overwritten and test mode is disabled.

        Exits the script if the dialog is cancelled.
        """
        section = "ChannelOptions"
        ext = cp.cp.get(section, "ext")
        delimiter = cp.cp.get(section, "delimiter")
        zStackBool = cp.cp.getboolean(section, "zStackBool")

        # Per-channel settings; list index 0..3 corresponds to channels 1..4.
        names, optLists, bgRadii, sigmas = [], [], [], []
        for k in range(1, 5):
            names.append(cp.cp.get(section, "c%dName" % k))
            # NOTE(review): eval() of config text - acceptable only while the ini file is trusted.
            optLists.append(eval(cp.cp.get(section, "c%dOpt_boolList" % k)))
            bgRadii.append(cp.cp.getfloat(section, "backgroundRadc%d" % k))
            sigmas.append(cp.cp.getfloat(section, "sigmaC%d" % k))
        testBool = cp.cp.getboolean(section, "testBool")

        if headless:
            self.input_path_dir = expath2
            self.zStack = zStackBool
            self.ext = ext
            self.overwriteDB = True
            self.delimiter = delimiter
            for i in range(4):
                # Each channel uses its own name (channel 4 previously reused c3Name).
                background, brightness_auto, brightness_man, pa = optLists[i][:4]
                self._applyChannel(i, names[i], background, brightness_auto, brightness_man, pa,
                                   bgRadii[i], sigmas[i])
            self.test = False
            return

        gd = GenericDialog("Options")
        gd.addMessage("Input Folder: %s" % expath)
        gd.addCheckboxGroup(1, 2, ["Z-project?", "Overwrite old database if it already exists?"],
                            [zStackBool, True])
        gd.addStringField("File extension", ext, 10)
        gd.addStringField("Title separator", delimiter, 10)
        for i in range(4):
            gd.addMessage(self._CHANNEL_SEPARATOR)
            gd.addMessage("Set details for Channel %d" % (i + 1))
            gd.addStringField("Channel %d" % (i + 1), names[i], 8)
            gd.addCheckboxGroup(1, 4, self._CHANNEL_OPTION_LABELS, optLists[i])
            # Every channel shows its saved radius (channel 4 was hard-coded to 50).
            gd.addNumericField("Background radius:", bgRadii[i], 0)
            gd.addNumericField("Gaussian Blur (0 if not, otherwise state the radius)", sigmas[i], 2)
        gd.addMessage(self._TEST_SEPARATOR)
        gd.addCheckbox("Test parameters on random pictures?", testBool)
        wt.addScrollBars(gd)

        gd.showDialog()

        if gd.wasCanceled():
            print("User canceled dialog!")
            sys.exit("Analysis was cancelled")

        # expath may be a plain (unicode) path or a java.io.File.
        if isinstance(expath, basestring):
            input_path_dir = expath
        else:
            input_path_dir = expath.getAbsolutePath()

        # The getNext* calls must follow the order in which the fields were added.
        zStackBool = gd.getNextBoolean()
        ext = gd.getNextString()
        delimiter = gd.getNextString()
        self.overwriteDB = gd.getNextBoolean()

        for i in range(4):
            names[i] = gd.getNextString()
            background = gd.getNextBoolean()
            brightness_auto = gd.getNextBoolean()
            brightness_man = gd.getNextBoolean()
            pa = gd.getNextBoolean()
            bgRadii[i] = gd.getNextNumber()
            sigmas[i] = gd.getNextNumber()

            # Automatic brightness adjustment overrides the manual one.
            if brightness_auto:
                brightness_man = False

            optLists[i] = [background, brightness_auto, brightness_man, pa]
            self._applyChannel(i, names[i], background, brightness_auto, brightness_man, pa,
                               bgRadii[i], sigmas[i])

        self.test = testBool = gd.getNextBoolean()

        values = {"expath": expath, "ext": ext, "delimiter": delimiter,
                  "zStackBool": zStackBool, "testBool": testBool}
        for i in range(4):
            k = i + 1
            values["c%dName" % k] = names[i]
            values["c%dOpt_boolList" % k] = optLists[i]
            values["backgroundRadc%d" % k] = bgRadii[i]
            values["sigmaC%d" % k] = sigmas[i]
        cp.update(section, dict((key, str(value)) for key, value in values.items()))

        self.input_path_dir = input_path_dir
        self.zStack = zStackBool
        self.ext = ext
        self.delimiter = delimiter

    def _thresholdList(self, method, additional):
        """Return ``[method]`` followed by each valid threshold named in *additional*.

        *additional* is a whitespace-separated string. Names not in
        ``self.allMethods`` are reported on stdout and skipped.
        """
        thresholds = [method]
        # split() without an argument ignores repeated/leading/trailing blanks.
        for name in (additional or "").split():
            if name in self.allMethods:
                thresholds.append(name)
            else:
                print(name + " is not a Threshold!")
        return thresholds

    def getParticleAnalyzerOptions(self, channel_number, coloc=''):
        """Gather the particle-analysis options for one channel.

        :param channel_number: zero-based channel index. It also selects the
            ``ParticleAnalysisOptions<channel_number>`` config section.
        :param coloc: ``"coloc"`` when the channel is only needed as a
            colocalisation partner. In that case the inside/outside,
            colocalisation-channel and mask-enlargement options are neither
            asked for nor stored.

        Channel 1 uses the ``paSizeA1``/``paSizeB1``/``addMeth1``/``watershed1``
        keys; all other channels use the ``...2`` keys. In GUI mode the
        confirmed values are saved back with ``cp.update``.

        Exits the script if the dialog is cancelled.
        """
        section = "ParticleAnalysisOptions%s" % channel_number
        is_coloc = coloc == "coloc"
        suffix = "1" if channel_number == 0 else "2"

        # NOTE(review): eval() of config text - acceptable only while the ini file is trusted.
        inOutList = eval(cp.cp.get(section, "paInOutBool_list"))
        colocList = eval(cp.cp.get(section, "paColocBool_list"))
        enlarge = cp.cp.getfloat(section, "paEnlarge")
        sizeLow = cp.cp.getfloat(section, "paSizeA" + suffix)
        sizeHigh = cp.cp.getfloat(section, "paSizeB" + suffix)
        circ1 = cp.cp.getfloat(section, "paCirc1")
        circ2 = cp.cp.getfloat(section, "paCirc2")
        method = cp.cp.get(section, "paMethod")
        addMeth = cp.cp.get(section, "addMeth" + suffix)
        watershed = cp.cp.getboolean(section, "watershed" + suffix)

        if not headless:
            if is_coloc:
                title = "Options for Channel %s colocalized Particle Analysis" % (channel_number + 1)
            else:
                title = "Options for Channel %s Particle Analysis" % (channel_number + 1)
            gd = GenericDialog(title)
            gd.addMessage("Set details for Channel %s" % (channel_number + 1))
            gd.addMessage(self._PA_SEPARATOR)

            if not is_coloc:
                gd.addMessage("Colocalisation Options")
                gd.addCheckboxGroup(1, 2, ["Inside mask?", "Or outside?"], inOutList)
                gd.addCheckboxGroup(1, 4, ["C1", "C2", "C3", "C4"], colocList)
                gd.addNumericField("Enlarge mask in [um]? (For shrinkage put negative numbers)", enlarge, 2)
                gd.addMessage(self._PA_SEPARATOR)

            gd.addMessage("Particle Analysis Options")
            gd.addNumericField("Lower Particle Size:", sizeLow, 3)
            gd.addNumericField("Higher Particle Size:", sizeHigh, 3)
            gd.addNumericField("Circularity bottom:", circ1, 1)
            gd.addNumericField("Circularity top:", circ2, 1)
            gd.addChoice("Binary Threshold Method", self.allMethods, method)
            gd.addStringField("Do you want to test additional thresholds? (Separate only by space)", addMeth, 8)
            gd.addCheckbox("Watershed?", watershed)

            gd.showDialog()
            if gd.wasCanceled():
                print("User canceled dialog!")
                sys.exit("Analysis was cancelled")

            # The getNext* calls must follow the order in which the fields were added.
            if not is_coloc:
                inOutList = [gd.getNextBoolean(), gd.getNextBoolean()]
                colocList = [gd.getNextBoolean() for _ in range(4)]
                enlarge = gd.getNextNumber()
            sizeLow = gd.getNextNumber()
            sizeHigh = gd.getNextNumber()
            circ1 = gd.getNextNumber()
            circ2 = gd.getNextNumber()
            method = gd.getNextChoice()
            addMeth = gd.getNextString()
            watershed = gd.getNextBoolean()

        thresholds = self._thresholdList(method, addMeth)
        channel = self.channels[channel_number]
        if is_coloc:
            channel.setInfo(lowerSize=sizeLow, higherSize=sizeHigh, circ1=circ1, circ2=circ2,
                            method=thresholds, watershed=watershed)
        else:
            channel.setInfo(lowerSize=sizeLow, higherSize=sizeHigh, circ1=circ1, circ2=circ2,
                            method=thresholds, list_1whichChannel=colocList, watershed=watershed,
                            pa_inside=inOutList[0], pa_outside=inOutList[1],
                            pa_enlarge_mask=enlarge)

        if not headless:
            values = {"paSizeA" + suffix: sizeLow, "paSizeB" + suffix: sizeHigh,
                      "paCirc1": circ1, "paCirc2": circ2, "paMethod": method,
                      "addMeth" + suffix: addMeth, "watershed" + suffix: watershed}
            if not is_coloc:
                values["paInOutBool_list"] = inOutList
                values["paEnlarge"] = enlarge
                values["paColocBool_list"] = colocList
            cp.update(section, dict((key, str(value)) for key, value in values.items()))

# At the beginning of the Script, this object sets up the SelectionManager and the Dialoger and gathers all parameters
class testParameters(object):
    """Run the optional parameter-test loop before the real analysis.

    ``startScript`` builds the Dialoger (all options) and the
    SelectionManager. When the user enabled test mode, it runs ``Image`` on
    randomly chosen input files so the results can be inspected before the
    experiment starts.
    """

    def __init__(self):
        self.d = ""  # Dialoger instance, set by startScript
        self.s = ""  # SelectionManager instance, set by startScript
        self.another = False  # user asked to test another random image
        self.newparams = False  # user asked to enter new parameters
        self.start = False  # user asked to start the experiment

    def dialog(self):
        """Ask how to continue after a test run and store the answer.

        Resets ``self.another``, ``self.newparams`` and ``self.start``, then
        sets them from the dialog's three checkboxes. Exits the script if the
        dialog is cancelled.
        """
        self.another = False
        self.newparams = False
        self.start = False

        gd = GenericDialog("Test parameter mode - Select just one option")
        gd.addCheckbox("Test another image?", False)
        gd.addCheckbox("Try new parameters?", False)
        gd.addCheckbox("Start Experiment", True)
        gd.showDialog()

        if gd.wasCanceled():
            print("User canceled dialog!")
            sys.exit("Analysis was cancelled")
        self.another = gd.getNextBoolean()
        self.newparams = gd.getNextBoolean()
        self.start = gd.getNextBoolean()

    def _testRandomImage(self):
        """Process one random input image in test mode and ask how to continue."""
        filepath = random.choice(self.d.filenames)
        print("*****************************************************")
        print("Testing Parameters on image: %s \n" % os.path.split(filepath)[1])
        Image(filepath, self.d, self.s, True)
        self.stitch(filepath)
        self.dialog()

    def startScript(self):
        """Gather all parameters, optionally testing them first.

        :returns: ``(dialoger, selection_manager)`` holding the parameters
            the experiment should use.

        If the user asks for new parameters, the whole procedure is repeated
        and the result of that new run is returned. Choosing "Start
        Experiment", or ticking nothing, closes all images and returns the
        current parameters.
        """
        self.d = Dialoger()
        self.s = SelectionManager()
        cp.writeIni()
        cp.readIni()
        if not self.d.test:
            return self.d, self.s

        self._testRandomImage()
        while self.another:
            IJ.run("Close All")
            self._testRandomImage()

        if self.newparams:
            # Propagate the fresh run's result (it used to be discarded, so
            # None could be returned).
            return self.startScript()

        IJ.run("Close All")
        return self.d, self.s

    def stitch(self, filepath):
        """Let the user compare the test results with the original image.

        If more than one image window is open, all open images are
        concatenated into one stack. Each frame is labelled with the title of
        its source window, and the original image at *filepath* is shown
        alongside. Both are closed once the user confirms.

        Otherwise the user inspects the current image, which is then closed.
        """
        if WindowManager.getImageCount() <= 1:
            WaitForUserDialog("Inspect results compared to original image and then proceed").show()
            IJ.getImage().close()
            return

        # Only open the original when it is actually shown (and later closed).
        imp = BF.openImagePlus(filepath)[0]
        titles = WindowManager.getImageTitles()
        count = WindowManager.getImageCount()
        ids = [WindowManager.getNthImageID(i) for i in range(1, count + 1)]
        imps = [WindowManager.getImage(i) for i in ids]
        stack = Concatenator().concatenate(imps, False)
        stack.show()
        for frame, title in enumerate(titles, 1):
            stack.setT(frame)
            # Bracketed value so titles with spaces are kept whole.
            IJ.run("Set Label...", "label=[%s]" % title)
        imp.show()
        WaitForUserDialog("Inspect results compared to original image and then proceed").show()
        stack.close()
        imp.close()

