# @ ImagePlus imp

from ij import *
from ij import IJ, ImagePlus
from ij.gui import Roi, PointRoi, GenericDialog, Overlay
from ij.plugin import Duplicator
from ij.plugin.frame import RoiManager
from ij.io import FileSaver
from ConfigParser import SafeConfigParser

from datetime import datetime
import subprocess
import shlex
import os
import glob
import csv
import math
import shutil

from ij.process import ImageStatistics


def error_dialog(title, message):
    """ Shows an OK-only dialog box with a message, then aborts by raising

    Args:
        title::str
            Window title of the dialog box
        message::str
            Text shown in the dialog body; also used as the exception message

    Raises:
        Exception carrying `message`, always, once the dialog is dismissed.
        The function never returns normally.
    """
    notice = GenericDialog(title)
    notice.addMessage(message)
    # Only an OK button is offered: the user acknowledges, there is nothing to cancel.
    notice.hideCancelButton()
    notice.showDialog()
    raise Exception(message)

def adjust_count(input_count, slope, intercept):
    """ Adjusts a cell count by inverting the linear equation count = slope*x + intercept

    Args:
        input_count::int
            The number of cells in a given ROI
        slope::float
            The slope of the linear equation to adjust with (must be non-zero)
        intercept::float
            The intercept of the linear equation to adjust with

    Returns:
        Float of the adjusted number of cells in an ROI

    Raises:
        ZeroDivisionError if slope is 0
    """
    # float() forces true division even when every argument is an int (Python 2 / Jython).
    return (input_count - intercept) / float(slope)

def path_convert(path):
    """ Normalises a path to forward slashes when running on Windows

    Args:
        path::str
            The path to normalise

    Returns:
        On Windows, `path` with every backslash replaced by a forward slash;
        on any other platform, `path` unchanged.
    """
    if not IJ.isWindows():
        return path
    return path.replace("\\", "/")
    
def config_setup(event=None):
    """ Opens a dialog box and saves the plugin's default settings

    Reads `plugin_config.ini` from the plugin's scripts folder if it exists,
    pre-fills the dialog with the stored values (falling back to built-in
    defaults for any value that is missing or unreadable), and writes the
    user's choices back to the file when the dialog is confirmed.
    Nothing is written if the dialog is cancelled.

    Args:
        event
            Ignored. Accepted so the function can be passed directly as the
            callback of GenericDialog.addButton.

    Returns:
        None
    """
    plugin_wd = path_convert(IJ.getDirectory("plugins")) + "InVitro_Cellpose/scripts/"
    config_path = plugin_wd + 'plugin_config.ini'
    section = 'default_settings'

    # read() silently skips a missing file, so no placeholder file is created.
    # Creating an empty one here meant that cancelling this dialog left an empty
    # config behind, which later runs treated as "already configured" and then
    # failed on when reading the settings.
    config = SafeConfigParser()
    config.read(config_path)
    if not config.has_section(section):
        config.add_section(section)

    # Built-in defaults, overridden individually by whatever the file provides.
    existing_model_path = plugin_wd + "models/Pass3"
    existing_python_path = ""
    existing_logging_bool = True
    if config.has_option(section, 'model_path'):
        existing_model_path = config.get(section, 'model_path')
    if config.has_option(section, 'python_path'):
        existing_python_path = config.get(section, 'python_path')
    if config.has_option(section, 'logging'):
        try:
            existing_logging_bool = config.getboolean(section, 'logging')
        except ValueError:
            # Unrecognised boolean text in the file: keep the default (True)
            # and let the user overwrite it from the dialog.
            pass

    config_dialog = GenericDialog("Configure Plugin Defaults")
    config_dialog.addFileField("Model Path", existing_model_path)
    config_dialog.addFileField("Python Path", existing_python_path)
    config_dialog.addCheckbox("Logging", existing_logging_bool)
    config_dialog.showDialog()
    if config_dialog.wasCanceled():
        return

    model_path_input = path_convert(config_dialog.getNextString())
    python_path_input = path_convert(config_dialog.getNextString())
    # Stored as "True"/"False", which getboolean() parses back.
    logging_bool_input = str(config_dialog.getNextBoolean())

    config.set(section, 'model_path', model_path_input)
    config.set(section, 'python_path', python_path_input)
    config.set(section, 'logging', logging_bool_input)
    with open(config_path, 'w') as configfile:
        config.write(configfile)

    return
    

def _plugin_scripts_dir():
    """Return the plugin's scripts folder (forward slashes on Windows, trailing '/')."""
    return path_convert(IJ.getDirectory("plugins")) + "InVitro_Cellpose/scripts/"


def _ensure_dir(path):
    """Create directory `path` (including parents) if it does not exist yet."""
    if not os.path.isdir(path):
        os.makedirs(path)


def _read_plugin_config(config_path):
    """Load the plugin defaults written by config_setup.

    Returns:
        Tuple (model_path, python_path, logging_bool), or None when the file is
        missing or lacks any of the required options.
    """
    config = SafeConfigParser()
    config.read(config_path)  # a missing file is silently ignored
    section = 'default_settings'
    for option in ('model_path', 'python_path', 'logging'):
        if not config.has_option(section, option):
            return None
    return (config.get(section, 'model_path'),
            config.get(section, 'python_path'),
            config.getboolean(section, 'logging'))


def _clamp_origin(centroid, safe_zone, extent, box_size, expand):
    """Place a `box_size` box centred on `centroid` along one image axis.

    The origin is clamped to [expand, extent - expand - box_size], so the box
    plus an `expand`-pixel margin on each side stays inside the image. When the
    image is too small for both limits, the upper limit wins.

    Returns:
        Tuple (origin, adjusted); `adjusted` is True when clamping moved the
        origin away from centroid - safe_zone.
    """
    ideal = centroid - safe_zone
    origin = min(max(ideal, expand), extent - expand - box_size)
    return origin, origin != ideal


def _cellpose_command(python_path, script_path, input_image, model_path,
                      flow_threshold, cellprob_threshold, edge_prop,
                      expand_flag, remove_edges_flag):
    """Build the argv list that runs run_invitro_cellpose.py on one image.

    Each path is passed as a single argument, so paths containing spaces
    (e.g. "C:/Program Files/Fiji.app/...") are not split apart.
    """
    if os.path.isfile(python_path):
        command = [python_path]
    else:
        # Not a file (e.g. "python" on PATH, or a launcher command with its own
        # arguments): split it shell-style as the original command line did.
        command = shlex.split(python_path)
    command.extend([
        script_path,
        "--input_image", input_image,
        "--model_path", str(model_path),
        "--flow_threshold", str(flow_threshold),
        "--cellprob_threshold", str(cellprob_threshold),
        "--edge_prop", str(edge_prop),
        # The flags carry a leading space because they are also used in window titles.
        expand_flag.strip(),
        remove_edges_flag.strip(),
    ])
    return command


def _save_tagged_tiff(image, properties, path):
    """Set each (name, value) pair in `properties` on `image`, then save it as TIFF at `path`."""
    for name, value in properties:
        image.setProp(name, value)
    FileSaver(image).saveAsTiff(path)


def run_script():
    """ Detects cells with Cellpose in every tile drawn on the active image

    Each overlay ROI on the active image marks one tile. For every tile this
    crops a box_size square centred on the ROI's contour centroid (plus an
    optional margin), runs run_invitro_cellpose.py on it in a subprocess, and
    opens inset windows showing the detected outlines and/or centroids.
    Per-tile counts, plus their mean and sample standard deviation when there
    is more than one tile, are written to the ImageJ log. When logging is
    enabled, the inputs and outputs are archived in the plugin's logs folder.

    Returns:
        None. Returns early if the user cancels a dialog.

    Raises:
        Exception (via error_dialog) if there is no selection or the plugin is
        not configured; RuntimeError if the Cellpose subprocess fails or a
        temporary image cannot be written.
    """
    session_stamp = datetime.now().strftime("%y%m%d%H%M%S")
    imp = IJ.getImage()

    # Images that were never saved have no original file info.
    file_info = imp.getOriginalFileInfo()
    if file_info is not None and file_info.directory and file_info.fileName:
        imp_path = file_info.directory + file_info.fileName
    else:
        imp_path = ""
    print(imp_path)

    # Copy of the current plane without overlays, so crops do not carry the tile ROIs.
    stripped_imp = ImagePlus("duplicate", imp.duplicate().getProcessor())
    dims = imp.getDimensions()
    width, height = dims[0], dims[1]

    # Tiles are the overlay ROIs. Check for them before touching the overlay.
    imp_overlay = imp.getOverlay()
    if imp_overlay is None or imp_overlay.size() == 0:
        error_dialog('Warning', 'Make a selection first and re-run the plugin.')
    number_tiles = imp_overlay.size()

    # Many tiles usually means the wrong window is active: ask before proceeding.
    if number_tiles >= 6:
        warn_dialog = GenericDialog("Warning")
        warn_dialog.addMessage("There are currently {} tiles selected. \nConfirm the active window is correct. \nAre you sure you want to proceed?".format(number_tiles))
        warn_dialog.setOKLabel("Continue")
        warn_dialog.showDialog()
        if warn_dialog.wasCanceled():
            return

    plugin_wd = _plugin_scripts_dir()
    temp_folder = plugin_wd + "temp/"
    log_folder = plugin_wd + "logs/"
    config_path = plugin_wd + 'plugin_config.ini'

    # Prompt for defaults when the config is missing OR incomplete (e.g. an
    # empty file left behind by a cancelled setup).
    if _read_plugin_config(config_path) is None:
        config_setup()

    # Parameter Dialog
    parameter_dialog = GenericDialog("Set Parameters")
    parameter_dialog.addNumericField("Box Size (px)", 272)
    parameter_dialog.addSlider("Flow Threshold", 0.0, 3.0, 3.0, 0.1)
    parameter_dialog.addSlider("Cell Probability Threshold", -6.0, 6.0, -6.0, 0.1)
    parameter_dialog.addSlider("Edge Proportion", 0.0, 1.0, 0.25, 0.05)
    parameter_dialog.addChoice("Plot Option", ["Cell Outlines", "Cell Centroids", "Both Outlines and Centroids"], "Both Outlines and Centroids")
    parameter_dialog.addCheckbox("Plot Insets in Grayscale", False)
    parameter_dialog.addCheckbox("Expand Insets", True)
    parameter_dialog.addCheckbox("Remove Left & Bottom Edges", False)
    parameter_dialog.addButton("Configure Defaults", config_setup)

    parameter_dialog.showDialog()
    if parameter_dialog.wasCanceled():
        return

    # Re-read after the dialog: the "Configure Defaults" button may have changed it.
    settings = _read_plugin_config(config_path)
    if settings is None:
        error_dialog('Warning', 'The plugin is not configured. Use "Configure Defaults" to set the model and Python paths, then re-run the plugin.')
    model_path, python_path, logging_bool = settings

    box_size = parameter_dialog.getNextNumber()
    # The raw scrollbar values are integers scaled by 1/step (10 for step 0.1,
    # 20 for step 0.05), hence the divisions.
    flow_threshold = parameter_dialog.getSliders().get(0).getValue() / 10.0
    cellprob_threshold = parameter_dialog.getSliders().get(1).getValue() / 10.0
    edge_prop = parameter_dialog.getSliders().get(2).getValue() / 20.0
    plot_option = parameter_dialog.getNextChoice()
    grayscale_bool = parameter_dialog.getNextBoolean()
    expand_bool = parameter_dialog.getNextBoolean()
    remove_edges = parameter_dialog.getNextBoolean()

    print(flow_threshold, cellprob_threshold, edge_prop, expand_bool)

    # Margin (px) cropped around each tile for Cellpose, and matching CLI flags.
    if expand_bool:
        expand = 20
        expand_bool_arg = " --expand_edges"
        print("expand_bool is true")
    else:
        expand = 0
        expand_bool_arg = " --no-expand_edges"
        print("expand_bool is false")

    if remove_edges:
        remove_edges_arg = " --remove_edges"
    else:
        remove_edges_arg = " --no-remove_edges"

    safe_zone = box_size / 2

    # FileSaver fails silently when the folder is missing, so create it up front.
    _ensure_dir(temp_folder)
    if logging_bool:
        _ensure_dir(log_folder)

    all_counts = []
    try:
        for i, tile in enumerate(imp_overlay.toArray(), 1):
            IJ.showStatus("Processing region " + str(i) + "/" + str(number_tiles))

            # Centre the box on the tile, clamped so box + margin stays inside the image.
            x_centroid, y_centroid = tile.getContourCentroid()
            x_min, x_adjusted = _clamp_origin(x_centroid, safe_zone, width, box_size, expand)
            y_min, y_adjusted = _clamp_origin(y_centroid, safe_zone, height, box_size, expand)
            if x_adjusted or y_adjusted:
                IJ.log("Selection area was adjusted to stay within edge safe zones defined by 'expand'.")

            # Crop including the margin: this is what Cellpose sees.
            cellpose_bounds = Roi(x_min - expand, y_min - expand, box_size + 2*expand, box_size + 2*expand)
            cropped_imp = Duplicator().run(stripped_imp)
            cropped_imp.setRoi(cellpose_bounds)
            cropped_imp = cropped_imp.crop()
            cropped_imp.setTitle("region_%s" % i)

            # Crop without the margin: this is what the user sees.
            original_bounds = Roi(x_min, y_min, box_size, box_size)
            display_imp = Duplicator().run(stripped_imp)
            display_imp.setRoi(original_bounds)
            display_imp = display_imp.crop()
            display_imp.setTitle("region_%s_inset%s%s" % (i, expand_bool_arg, remove_edges_arg))

            # Temporary files exchanged with the Cellpose subprocess.
            base_name = temp_folder + cropped_imp.getTitle()
            input_filename = base_name + ".jpg"
            roi_filename = base_name + "_rois.zip"
            centroids_filename = base_name + "_rois.csv"
            clahe_filename = base_name + "_clahe.tif"
            if not FileSaver(cropped_imp).saveAsJpeg(input_filename):
                raise RuntimeError("Could not write temporary image '{}'".format(input_filename))

            print(roi_filename)
            print(input_filename)

            IJ.showStatus("Identifying cells in region " + str(i) + "/" + str(number_tiles))

            args = _cellpose_command(python_path, plugin_wd + "run_invitro_cellpose.py",
                                     input_filename, model_path, flow_threshold,
                                     cellprob_threshold, edge_prop,
                                     expand_bool_arg, remove_edges_arg)
            print(args)

            try:
                subprocess.check_output(args, stderr=subprocess.STDOUT, shell=False,
                                        cwd=plugin_wd)
            except subprocess.CalledProcessError as e:
                raise RuntimeError("command '{}' return with error (code {}): {}".format(e.cmd, e.returncode, e.output))

            print("Plotting region " + str(i))
            print(roi_filename)

            IJ.showStatus("Plotting cell locations in region " + str(i) + "/" + str(number_tiles))

            # Grayscale insets use the CLAHE image written by the subprocess, minus the margin.
            if grayscale_bool:
                display_imp = IJ.openImage(clahe_filename)
                display_imp.setRoi(Roi(expand, expand, box_size, box_size))
                display_imp = display_imp.crop()
                display_imp.setTitle("gray_region_%s_inset%s%s" % (i, expand_bool_arg, remove_edges_arg))

            # "Both" shows outlines in a second window, so it needs its own copy.
            if plot_option == "Both Outlines and Centroids":
                duplicate_outline_imp = Duplicator().run(display_imp)
                duplicate_outline_imp.setTitle("duplicate_region_%s_inset%s%s" % (i, expand_bool_arg, remove_edges_arg))

            all_points = RoiManager(False)  # hidden manager collecting this tile's ROIs

            if plot_option == "Cell Centroids" or plot_option == "Both Outlines and Centroids":
                with open(centroids_filename, "r") as csvfile:
                    csvreader = csv.reader(csvfile, delimiter=',')
                    next(csvreader)  # Skip header
                    for row in csvreader:
                        point = PointRoi(float(row[0]), float(row[1]))
                        point.setPointType(0)
                        all_points.add(point, -1)

            if plot_option == "Cell Outlines":
                all_points.runCommand("Open", roi_filename)

            all_points.moveRoisToOverlay(display_imp)
            display_imp.show()
            display_imp.getCanvas().zoom100Percent()
            display_imp.getCanvas().setScaleToFit(True)
            all_points.runCommand("Show All")

            if plot_option == "Both Outlines and Centroids":
                duplicate_roi_manager = RoiManager(False)
                duplicate_roi_manager.runCommand("Open", roi_filename)
                duplicate_roi_manager.moveRoisToOverlay(duplicate_outline_imp)
                duplicate_outline_imp.show()
                duplicate_outline_imp.getCanvas().zoom100Percent()
                duplicate_outline_imp.getCanvas().setScaleToFit(True)
                duplicate_roi_manager.runCommand("Show All")

            cell_count = all_points.getCount()
            all_counts.append(cell_count)
            IJ.log("Region {} cell count: {}".format(i, cell_count))

            # Archive this tile's inputs and outputs.
            dt_string = datetime.utcnow().strftime("%y%m%d%H%M%S")
            if logging_bool:
                properties = [
                    ("SessionID", session_stamp),
                    ("Inset_WholeImg", imp_path),
                    ("Inset_XY_TopLeft", "%s, %s" % (x_min, y_min)),
                    ("Inset_Size", box_size + 2*expand),
                    ("Auto_Timestamp", dt_string),
                    ("Auto_Count", str(cell_count)),
                ]
                _save_tagged_tiff(cropped_imp, properties,
                                  "%s%s_auto_tocellpose_region_%s.tiff" % (log_folder, session_stamp, str(i)))
                _save_tagged_tiff(display_imp, properties,
                                  "%s%s_auto_display_region_%s.tiff" % (log_folder, session_stamp, str(i)))
                shutil.copy(roi_filename, "%s%s_auto_region_%s_rois.zip" % (log_folder, session_stamp, str(i)))
                shutil.copy(centroids_filename, "%s%s_auto_region_%s_rois.csv" % (log_folder, session_stamp, str(i)))
    finally:
        # Remove temporary tile files even when a tile fails part-way.
        IJ.showStatus("Cleaning up temporary files")
        for temp_file in glob.glob(temp_folder + '*'):
            os.remove(temp_file)

    if logging_bool:
        if imp_path:
            shutil.copy(imp_path, "%s%s_auto_wholeimg.tiff" % (log_folder, session_stamp))
        else:
            IJ.log("Whole image is not saved on disk; it was not copied to the log folder.")

    # Mean and sample standard deviation (n - 1) of the per-tile counts.
    if len(all_counts) > 1:
        avg_count = sum(all_counts) / float(len(all_counts))
        sum_differences_sq = sum((value - avg_count) ** 2 for value in all_counts)
        stdev_val = math.sqrt(sum_differences_sq / (len(all_counts) - 1.0))
        IJ.log("Average cell count: %.2f" % avg_count)
        IJ.log("Standard deviation (s): %.2f" % stdev_val)
    IJ.log("Session ID: %s \n" % session_stamp)

    # Switch selection back to whole image
    imp.show()
    IJ.selectWindow(imp.getTitle())

    if plot_option == "Cell Centroids" or plot_option == "Both Outlines and Centroids":
        IJ.setTool("point tool")


# Run on import as well as when executed directly: '__builtin__' presumably
# covers how Fiji's Jython script runner names the module (TODO confirm),
# '__main__' a plain interpreter run.
if __name__ in ('__main__', '__builtin__'):
    run_script()
