#@ File (label = "Input directory (leave blank to process front-most image)", style = "directory",  required=false) src_dir
#@ String (label = "Save location", choices={"Parent folder", "Same folder"}) save_location
#@ Boolean (label = "Z-Project", description="Preprocess images with a z-projection. If not selected, only the current slice will be used.") z_project
#@ String (label = "Channels to display", description="The channel scheme to use in the final montage. You can use a sequence of:\n    - A channel number to display this channel only in grayscale (e.g: \"2\")\n    - A combination of channel numbers to display a composite of those channels only (e.g. : \"(13)\")\n    - \"c\" to display all channels together, it is a shortcut for (123) on a 3-channel image.\n\nExample : 123(12)c will display:\n    - Channel 1 in grayscale\n    - Channel 2 in grayscale\n    - Channel 3 in grayscale\n    - A composite of channels 1 & 2\n    - A composite of all channels\n    \nIf you want several output images, you can separate your channels schemes with \",\", e.g : 12,123c") separated_channels
#@ Integer (label = "Length of scale bar in µm (0 to disable)", min = 0) scale_length
#@ Float (label = "Pixel size im µm (0 to use existing calibration)", min = 0) pixel_size
#@ Boolean (label = "Process as film") process_as_film

import os
import re

from ij import IJ, ImagePlus, ImageStack
from ij.gui import NonBlockingGenericDialog
from ij.plugin import HyperStackConverter, MontageMaker, ChannelArranger, ZProjector, Duplicator, RGBStackConverter
from ij.plugin.frame import RoiManager

# DESCRIPTION z_project
"Preprocess images with a z-projection. If not selected, only the current slice will be used."

# DESCRIPTION separated_channels
r"""
The channel scheme to use in the final montage. You can use a sequence of:
    - A channel number to display this channel only in grayscale (e.g: "2")
    - A combination of channel numbers to display a composite of those channels only (e.g. : "(13)")
    - "c" to display all channels together, it is a shortcut for (123) on a 3-channel image.

Example : 123(12)c will display:
    - Channel 1 in grayscale
    - Channel 2 in grayscale
    - Channel 3 in grayscale
    - A composite of channels 1 & 2
    - A composite of all channels
    
If you want several output images, you can separate your channels schemes with ",", e.g : 12,123c
"""

# For shutting the mouth up of the IDE
src_dir = src_dir
save_location = save_location
z_project = z_project
separated_channels = separated_channels
scale_length = scale_length
pixel_size = pixel_size
process_as_film = process_as_film

d = Duplicator()

def apply_mix_and_max(imp):
    """Restore the per-channel display ranges of *imp*.

    For every channel, the range previously chosen by the user (stored in the global
    ``mins_and_maxs`` under the 0-based channel index) is re-applied. Channels without a
    stored range get their display range reset to the data extremes. The image is left
    positioned on channel 1.
    """
    channel_count = imp.getDimensions()[2]
    for channel_index in range(channel_count):
        imp.setC(channel_index + 1)
        if channel_index not in mins_and_maxs:
            IJ.resetMinAndMax(imp)
            continue
        low, high = mins_and_maxs[channel_index]
        imp.setDisplayRange(low, high)
    imp.setC(1)


# Fonction appelée pour chaque image trouvée dans le dossier
def process_image(imp, output_basename):
    """Let the user select regions on *imp* and build a montage for each of them.

    Called once per image to process (front-most image or each file of the input folder).

    :param imp: ImagePlus to process. It may be closed or replaced by this function
        (z-projection, ROI cropping).
    :param output_basename: path prefix (without extension) used to save the montages, or
        None to only display them.
    :return: True if the user cancelled a dialog and the whole script must stop, False otherwise
        (including when the image was skipped).
    """
    global roi_popup_location

    apply_mix_and_max(imp)

    n_slices = imp.getDimensions()[3]
    # For a z-stack, start on the middle slice so the user sees a representative plane.
    if n_slices > 1:
        imp.setZ((n_slices + 1) // 2)

    # Optional maximum-intensity z-projection ("all" projects every time frame).
    if z_project and n_slices > 1:
        projected = ZProjector.run(imp, "max all")
        imp.close()
        imp = projected

    imp.show()
    IJ.selectWindow(imp.getID())

    # Ask the user to draw rectangular ROIs around the cells of interest.
    IJ.setTool("rectangle")
    rm = RoiManager.getRoiManager()
    rm.runCommand(imp, "Show All with labels")

    gui = NonBlockingGenericDialog("ROI selection")
    gui.addMessage(
        "Select region of interest (or add many regions with \"t\") \n Warning : Dont' forget to click before selecting, when using Shift-selection")
    gui.addCheckbox("Skip", False)
    if roi_popup_location is not None:
        gui.setLocation(roi_popup_location.x, roi_popup_location.y)
    gui.showDialog()

    # Remember where the user moved the dialog so the next one opens at the same place.
    roi_popup_location = gui.getLocation()

    # Dialog closed/cancelled: tell the caller to stop the whole script.
    if not gui.wasOKed():
        return True

    # "Skip" ticked: move on to the next image without processing this one.
    if gui.getNextBoolean():
        # Drop the ROIs drawn on this image so they do not carry over to the next one,
        # and close the working copy (e.g. the z-projection) instead of leaking its window.
        rm.reset()
        rm.close()
        imp.changes = False
        imp.close()
        return False

    rois = rm.getRoisAsArray()
    rm.reset()
    rm.close()

    # No ROI in the manager: use the current selection if any, otherwise the whole image.
    if not rois:
        if imp.getRoi() is not None:
            IJ.run(imp, "Crop", "")
        # Propagate a cancellation from the threshold dialog to the caller.
        return process_image_cropped(imp, output_basename)

    # Otherwise crop one image per ROI (whole stack) and process each of them.
    cropped_array = imp.crop(rois, "stack")
    imp.close()

    for index, cropped in enumerate(cropped_array):
        cropped.show()
        cropped_slices = cropped.getDimensions()[3]
        if cropped_slices > 1:
            cropped.setZ((cropped_slices + 1) // 2)
        roi_basename = None if output_basename is None else output_basename + "_" + str(index + 1)
        if process_image_cropped(cropped, roi_basename):
            return True

    return False


# Fonction appelée pour chaque ROI sélectionnée sur l'image source
def process_image_cropped(imp, output_basename):
    """Let the user adjust display ranges on *imp*, then build and save its montage(s).

    Called once per selected region of the source image.

    :param imp: cropped ImagePlus (shown in a window when called).
    :param output_basename: path prefix (without extension) for the saved files, or None to
        only display the montages without saving them.
    :return: True if the user cancelled the dialog and the script must stop, False otherwise.
    """
    global threshold_popup_location

    apply_mix_and_max(imp)

    # Ask the user to set the display levels with the Brightness/Contrast tool.
    imp.setDisplayMode(IJ.COLOR)
    IJ.selectWindow(imp.getID())
    IJ.run("Brightness/Contrast...")

    gui = NonBlockingGenericDialog("Thresholding")
    gui.addMessage("Set your thresholds, then click OK")
    if threshold_popup_location is not None:
        gui.setLocation(threshold_popup_location.x, threshold_popup_location.y)
    gui.showDialog()

    threshold_popup_location = gui.getLocation()

    # Dialog closed/cancelled: tell the caller to stop the whole script.
    if not gui.wasOKed():
        return True

    imp.hide()

    # Store the chosen display ranges so they are re-applied to the next images.
    dim_c = imp.getDimensions()[2]
    for channel_index in range(dim_c):
        imp.setC(channel_index + 1)
        mins_and_maxs[channel_index] = (imp.getDisplayRangeMin(), imp.getDisplayRangeMax())
    imp.setC(1)

    current_z = imp.getSlice()
    current_t = imp.getFrame()

    # Not a film: keep only the current z/t position (all channels). Duplicator is used instead
    # of "Reduce Dimensionality..." + IJ.getImage(), which depended on window focus and left
    # an extra window open.
    if not process_as_film and (imp.getDimensions()[3] > 1 or imp.getDimensions()[4] > 1):
        reduced = Duplicator().run(imp, 1, dim_c, current_z, current_z, current_t, current_t)
        imp.changes = False
        imp.close()
        imp = reduced
        current_z = current_t = 1

    if separated_channels == '':
        # Default scheme: one montage of the composite, one of every channel + composite.
        _chan_groups = [['c'], list(range(1, dim_c + 1)) + ['c']]
    else:
        _chan_groups = chan_groups

    dim_t = imp.getDimensions()[4]

    for n, chan_group in enumerate(_chan_groups):

        # Only build a path when saving: output_basename is None for unsaved images.
        output_path = None
        if output_basename is not None:
            output_path = output_basename + "_montage"
            if len(_chan_groups) > 1:
                output_path += "_" + str(n + 1)

        if process_as_film:
            montage_stack = None
            for t in range(1, dim_t + 1):
                # One frame at the current z position, as stated in the z_project description.
                frame_imp = Duplicator().run(imp, 1, dim_c, current_z, current_z, t, t)
                montage_frame = generate_montage_frame(frame_imp, chan_group)
                frame_imp.close()
                # The scale bar is drawn in the overlay, which getProcessor() ignores: burn it in.
                if montage_frame.getOverlay() is not None:
                    flattened = montage_frame.flatten()
                    montage_frame.close()
                    montage_frame = flattened
                if montage_stack is None:
                    montage_stack = ImageStack(montage_frame.getWidth(), montage_frame.getHeight())
                montage_stack.addSlice(montage_frame.getProcessor())
                montage_frame.close()
            montage = ImagePlus("", montage_stack)
            if output_path is not None:
                # ImageJ option values containing paths must be wrapped in [...] (quotes are not parsed).
                IJ.run(montage, "AVI... ", "compression=JPEG frame=15 save=[" + output_path + ".avi]")

        else:
            montage = generate_montage_frame(imp, chan_group)
            if output_path is not None:
                IJ.saveAs(montage, "PNG", output_path + ".png")

        if src_dir is None:
            montage.show()
        else:
            montage.close()

    return False


def generate_montage_frame(imp, chan_group):
    """Build a single-row RGB montage of the current plane of *imp*.

    :param imp: multi-channel ImagePlus; each panel is rendered from its current position
        (presumably a single z/t position — callers reduce it beforehand).
    :param chan_group: list of panel descriptors, in display order:
        - int ``k``: channel k alone in grayscale (ignored if not in 1..nChannels),
        - list of ints: composite of those channels only,
        - ``'c'``: composite of all channels.
    :return: the montage ImagePlus, carrying *imp*'s calibration, with an optional scale bar
        added as an overlay.
    :raises ValueError: if *chan_group* yields no panel at all.
    """
    n_channels = imp.getDimensions()[2]
    stack = ImageStack(imp.getWidth(), imp.getHeight())

    def add_rgb_panel(duplicate):
        # Render the duplicate with its current display settings and append it as a panel.
        RGBStackConverter.convertToRGB(duplicate)
        stack.addSlice(duplicate.getProcessor())
        duplicate.close()

    def composite_frame(chan_string):
        # chan_string is a "1"/"0" mask of active channels, e.g. "101".
        duplicate = imp.duplicate()
        duplicate.setDisplayMode(IJ.COMPOSITE)
        duplicate.setActiveChannels(chan_string)
        add_rgb_panel(duplicate)

    for chan_element in chan_group:
        if isinstance(chan_element, list):
            composite_frame(''.join(['1' if c in chan_element else '0' for c in range(1, n_channels + 1)]))

        elif chan_element == 'c':
            composite_frame('1' * n_channels)

        elif isinstance(chan_element, int) and 1 <= chan_element <= n_channels:
            duplicate = imp.duplicate()
            duplicate.setDisplayMode(IJ.GRAYSCALE)
            duplicate.setC(chan_element)
            add_rgb_panel(duplicate)

    if stack.getSize() == 0:
        raise ValueError("No valid channel to display in channel group " + str(chan_group))

    # Build the montage from the stack. The stack image must carry the source calibration,
    # otherwise MontageMaker produces an uncalibrated montage and the scale bar is in pixels.
    stack_imp = ImagePlus("", stack)
    stack_imp.setCalibration(imp.getCalibration())
    montage = MontageMaker().makeMontage2(stack_imp, stack.getSize(), 1, 1, 1, stack.getSize(), 1, 0, False)
    stack_imp.close()

    # Add the scale bar if requested, optionally overriding the pixel size.
    if scale_length != 0:
        if pixel_size != 0:
            IJ.run(montage, "Set Scale...", "distance=1 known=" + str(pixel_size) + " unit=µm")
        IJ.run(montage, "Scale Bar...",
               "width=" + str(
                   scale_length) + " height=4 font=6 color=White background=None location=[Lower Right] hide overlay")

    return montage


## MAIN ##

# Dialog positions remembered between images so the pop-ups do not jump around.
threshold_popup_location = None
roi_popup_location = None
# Display range chosen by the user per 0-based channel index: {index: (min, max)}.
mins_and_maxs = {}

# One token of a channel scheme: "(13)" composite, "2" single channel, or "c" all channels.
_CHANNEL_TOKEN_RE = re.compile(r'\(([1-9]+)\)|([1-9])|(c)')


def _parse_channel_groups(text):
    """Parse the "Channels to display" string into one channel group per output image.

    Each group is a list whose elements are an int (one grayscale channel), a list of ints
    (composite of those channels) or 'c' (composite of all channels). Whitespace is ignored.
    An empty input yields [[]]; the default scheme is then chosen per image.

    :raises ValueError: on any character that is not part of a valid token (including channel
        0), or on an empty scheme such as the one after a trailing comma.
    """
    if not text:
        return [[]]
    groups = []
    for scheme in text.split(','):
        scheme = ''.join(scheme.split())
        if not scheme:
            raise ValueError("Error in channels input: empty channel scheme in %r" % text)
        group = []
        pos = 0
        while pos < len(scheme):
            token = _CHANNEL_TOKEN_RE.match(scheme, pos)
            if token is None:
                raise ValueError("Error in channels input: unexpected %r at position %d of %r"
                                 % (scheme[pos], pos, scheme))
            composite, single, _ = token.groups()
            if composite:
                group.append([int(ch) for ch in composite])
            elif single:
                group.append(int(single))
            else:
                group.append('c')
            pos = token.end()
        groups.append(group)
    return groups


def _has_existing_output(output_basename):
    """Return True if a montage for *output_basename* already exists next to it.

    Only names written by process_image_cropped are recognised:
    <base>[_<roi>]_montage[_<n>].png|avi. A plain prefix test would wrongly match
    e.g. "img10.tif" for base "img1".
    """
    out_dir, base_name = os.path.split(output_basename)
    if not os.path.isdir(out_dir):
        return False
    pattern = re.compile(re.escape(base_name) + r'(_\d+)?_montage(_\d+)?\.(png|avi)$')
    _, _, existing = next(os.walk(out_dir))
    return any(pattern.match(f) for f in existing)


# The final channel list, with an element for each image to generate.
# These elements are lists whose elements can be either ints, lists of ints, or 'c' (empty list for default)
chan_groups = _parse_channel_groups(separated_channels)


# The image to process is already open: derive the output location from its folder.
if src_dir is None:
    imp_original = IJ.getImage()
    image_dir = IJ.getDir("image")
    imp_original.hide()

    try:
        if image_dir:
            # IJ.getDir returns a path ending with a separator; normalise it so that dirname()
            # really returns the parent folder.
            image_dir = os.path.normpath(image_dir)
            output_dir = os.path.dirname(image_dir) if save_location == "Parent folder" else image_dir
            output_basename = os.path.join(output_dir, os.path.splitext(imp_original.getTitle())[0])
        else:
            # Unsaved image: montages are only displayed.
            output_basename = None

        imp = imp_original.duplicate()
        process_image(imp, output_basename)
    finally:
        # Always bring the original back, even if processing failed.
        imp_original.show()

# Otherwise process every TIFF found in the input folder and its sub-folders.
else:
    file_paths = []
    for root, dirnames, filenames in os.walk(src_dir.getAbsolutePath()):
        for filename in filenames:
            if filename.lower().endswith((".tif", ".tiff")):
                file_paths.append(os.path.join(root, filename))

    # Natural sort ("img2" before "img10").
    dre = re.compile(r'(\d+)')
    file_paths.sort(key=lambda l: [int(s) if s.isdigit() else s.lower() for s in dre.split(l)])

    for file_path in file_paths:
        # Save either next to the containing folder (named after it) or next to the file.
        if save_location == "Parent folder":
            output_basename = os.path.split(file_path)[0]
        else:
            output_basename = os.path.splitext(file_path)[0]

        # Already processed in a previous run: skip it.
        if _has_existing_output(output_basename):
            continue

        imp = IJ.openImage(file_path)
        if imp is None:
            IJ.log("Could not open " + file_path + ", skipping it")
            continue

        interrupt = process_image(imp, output_basename)
        imp.close()

        if interrupt:
            break