# -*- coding: utf-8 -*-

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cellpose
import os
from collections import Counter
from skimage.segmentation import find_boundaries
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import train_test_split
from sklearn.cluster import KMeans
from cellpose import core, utils, io, models, metrics, plot, train

from cellpose.io import imread

from PIL import Image
from roifile import roiread
from matplotlib.pyplot import tight_layout
from matplotlib import colors
from itertools import cycle
from pathlib import Path

import cv2
import tifffile as tiff

import torch

import argparse

from scipy.spatial import Voronoi, voronoi_plot_2d


def renumber_masks(masks):
    """
    Relabel the non-zero values of an integer mask to 1..k in ascending order.

    After cells are removed the remaining label values have gaps; this maps the
    smallest remaining non-zero label to 1, the next to 2, and so on.
    Background (0) stays 0.

    Parameters:
    - masks: numpy label array (any shape), 0 = background.

    Returns:
    - The relabelled array. A numpy array argument is modified in place and the
      same object is returned.
    """
    masks = np.asarray(masks)
    labels = np.unique(masks)
    labels = labels[labels != 0]
    # searchsorted gives each value's index in the sorted label list (0..k-1);
    # +1 makes labels start at 1, and background is forced back to 0.
    masks[...] = np.where(masks != 0, np.searchsorted(labels, masks) + 1, 0)
    return masks


def remove_edges(masks, left_and_bottom: float = 1.0, size_percentile: float = 0.0):
    """
    Remove cells that touch the left or bottom edge of a label image.

    A cell with at least one pixel in the left column or bottom row is removed
    when either:

    * the fraction of its pixels lying on those edges is greater than
      ``left_and_bottom`` (0.0 removes every edge-touching cell, 1.0 removes
      none on this criterion), or
    * its pixel count is below the ``size_percentile`` percentile of all
      non-background cell sizes.

    Cells that do not touch the left or bottom edge are always kept, and the
    default values remove nothing. Remaining labels are renumbered to
    consecutive integers starting at 1; background (0) stays 0.

    Parameters:
    - masks: 2D integer label array, 0 = background.
    - left_and_bottom: edge-pixel fraction above which an edge cell is removed.
    - size_percentile: percentile (0-100) of cell sizes; edge cells smaller
        than this are removed.

    Returns:
    - A new array (the input is not modified) with the selected cells set to 0
      and the rest renumbered.
    """
    masks = np.asarray(masks)

    # Left column plus bottom row. The bottom-left corner pixel belongs to
    # both; skip it in the row so it is not counted twice (double counting
    # pushed a 1-pixel corner cell's fraction to 2.0).
    edge_pixels = np.concatenate((masks[:, 0], masks[-1, 1:]))
    edge_pixels = edge_pixels[edge_pixels != 0]
    edge_labels, edge_counts = np.unique(edge_pixels, return_counts=True)

    labels, sizes = np.unique(masks, return_counts=True)
    cell_size = dict(zip(labels.tolist(), sizes.tolist()))
    cell_sizes = sizes[labels != 0]

    result = np.copy(masks)
    if cell_sizes.size == 0:
        # Background only: nothing to remove or renumber
        return result

    size_threshold = np.percentile(cell_sizes, size_percentile)

    doomed = [
        label
        for label, n_edge in zip(edge_labels.tolist(), edge_counts.tolist())
        if n_edge / cell_size[label] > left_and_bottom
        or cell_size[label] < size_threshold
    ]
    result[np.isin(result, doomed)] = 0

    # Renumber the surviving labels to 1..k (background stays 0)
    kept = np.unique(result)
    kept = kept[kept != 0]
    result[...] = np.where(result != 0, np.searchsorted(kept, result) + 1, 0)
    return result


def remove_masks(masks, size: int):
    """
    Remove cells whose pixel count is smaller than ``size``.

    Remaining labels are renumbered to consecutive integers starting at 1;
    background (0) stays 0.

    Parameters:
    - masks: integer label array, 0 = background.
    - size: minimum cell area in pixels; cells with fewer pixels are removed.

    Returns:
    - A new array (the input is not modified).
    """
    masks_removed = np.copy(masks)

    # One pass to count every label, one pass to zero all undersized labels
    labels, counts = np.unique(masks_removed, return_counts=True)
    too_small = labels[(counts < size) & (labels != 0)]
    masks_removed[np.isin(masks_removed, too_small)] = 0

    # Renumber the surviving labels to 1..k (background stays 0)
    kept = np.unique(masks_removed)
    kept = kept[kept != 0]
    masks_removed[...] = np.where(
        masks_removed != 0, np.searchsorted(kept, masks_removed) + 1, 0
    )
    return masks_removed

def crop_expanded_mask(mask, crop_px, edge_threshold):
    """
    Crop ``crop_px`` pixels from every side of an integer label mask and drop
    cells that lie mostly in the cut-off border.

    For every cell that still has pixels inside the cropped region, let
    ``ratio = (cell pixels inside the cropped region) / (total cell pixels)``.

    1. If the cell has pixels in the top or right border band that is cut off,
        it is set to 0 when ``ratio < edge_threshold``.
    2. If the cell has pixels in the bottom or left border band that is cut
        off, it is set to 0 when ``ratio < 1 - edge_threshold``.

    Remaining labels are renumbered to consecutive integers starting at 1;
    background (0) stays 0.

    Parameters:
    - mask: 2D numpy array of integer masks (e.g., Cellpose masks), 0 = background.
    - crop_px: Number of pixels to crop from each edge.
    - edge_threshold: Proportion threshold for determining whether to keep or
        remove a mask near an edge.

    Returns:
    - cropped_mask: New cropped mask array. The input ``mask`` is not modified.
    """
    height, width = mask.shape

    # .copy(): a plain slice is a view, and zeroing/renumbering it would
    # silently modify the caller's array.
    cropped_mask = mask[crop_px:height - crop_px, crop_px:width - crop_px].copy()

    labels, sizes = np.unique(mask, return_counts=True)
    total_area = dict(zip(labels.tolist(), sizes.tolist()))

    # Border bands removed by the crop. Explicit start indices (instead of
    # -crop_px) keep the bands empty when crop_px == 0.
    top_right = np.concatenate(
        (mask[:crop_px, :].ravel(), mask[:, width - crop_px:].ravel())
    )
    bottom_left = np.concatenate(
        (mask[height - crop_px:, :].ravel(), mask[:, :crop_px].ravel())
    )
    touches_top_right = set(np.unique(top_right).tolist())
    touches_bottom_left = set(np.unique(bottom_left).tolist())

    # Only labels still present after cropping can need removal
    inside_labels, inside_sizes = np.unique(cropped_mask, return_counts=True)
    removed = []
    for label, inside in zip(inside_labels.tolist(), inside_sizes.tolist()):
        if label == 0:
            continue
        ratio = inside / total_area[label]
        if label in touches_top_right and ratio < edge_threshold:
            removed.append(label)
        elif label in touches_bottom_left and ratio < 1 - edge_threshold:
            removed.append(label)
    cropped_mask[np.isin(cropped_mask, removed)] = 0

    # Renumber the surviving labels to 1..k (background stays 0)
    kept = np.unique(cropped_mask)
    kept = kept[kept != 0]
    cropped_mask[...] = np.where(
        cropped_mask != 0, np.searchsorted(kept, cropped_mask) + 1, 0
    )
    return cropped_mask

def img_gray_clahe(image):
    """
    Convert an image to grayscale, apply CLAHE, and return it as 3 identical
    channels.

    Accepted input layouts:
    - (H, W) or (H, W, 1) grayscale
    - (H, W, 3) colour, converted with OpenCV's COLOR_BGR2GRAY (BGR order)
    - (3, H, W) channel-first colour, transposed to (H, W, 3) first

    CLAHE is applied with clipLimit=2.0 and an 8x8 tile grid.

    Returns:
    - (H, W, 3) array holding the equalized grayscale image in every channel.
    """
    # Channel-first colour -> channel-last. Require ndim == 3 so a 2D
    # grayscale image that happens to have exactly 3 rows is not passed to a
    # 3-axis transpose (which raised).
    if image.ndim == 3 and image.shape[0] == 3:
        image = image.transpose(1, 2, 0)

    # Grayscale input is used directly; the former gray->BGR->gray round trip
    # reproduced the same values.
    if image.ndim == 2:
        gray = image
    elif image.shape[2] == 1:
        gray = image[:, :, 0]
    else:
        gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    equalized = clahe.apply(np.ascontiguousarray(gray))

    # (H, W) -> (H, W, 3) with the same values in every channel
    return np.repeat(equalized[:, :, np.newaxis], 3, axis=2)

def centroids_from_mask(mask):
    """
    Find the centroid of every labelled cell in a 2D integer mask.

    Background (label 0) is ignored, whether or not it is present.
    Each centroid is the mean row and column index of the cell's pixels.
    Centroids are ordered by ascending label.

    Returns:
    - (x, y): two tuples of equal length. x holds column coordinates and y
      holds row coordinates. Both are empty when the mask contains no cells.
    """
    mask = np.asarray(mask)
    flat_mask = mask.ravel()
    # Group pixel indices by label: sort once, then split at label boundaries
    order = np.argsort(flat_mask, kind='stable')
    labels, starts = np.unique(flat_mask[order], return_index=True)
    # starts[0] is always 0, so split only at the later boundaries; this
    # yields exactly one index group per label.
    groups = np.split(order, starts[1:])
    points = [
        np.array(np.unravel_index(group, mask.shape)).mean(axis=1)
        for label, group in zip(labels, groups)
        if label != 0
    ]
    if not points:
        return ((), ())
    y, x = zip(*points)  # numpy coordinates are (row, col) = (y, x)

    return (x, y)


if __name__ == '__main__':
    # Set up argument parser
    parser = argparse.ArgumentParser(description='Run Cellpose on a set of images')
    parser.add_argument('--input_image', dest='input_image', required=True,
                        type=str, help='Path to the image file')
    parser.add_argument('--model_path', dest='model_path',
                        type=str, help='Path to the model file')
    parser.add_argument('--flow_threshold', dest='flow_threshold',
                        type=float, help='Flow threshold')
    parser.add_argument('--cellprob_threshold', dest='cellprob_threshold',
                        type=float, help='Cell probability threshold')
    parser.add_argument('--edge_prop', dest='edge_prop', type=float,
                        help='Edge proportion for cutoff used for edge expansion')
    parser.add_argument('--crop_px', dest='crop_px', type=int, default=20,
                        help='Pixels cropped from each side when --expand_edges '
                             'is set (default: 20)')
    parser.add_argument('--expand_edges', dest='expand_edges',
                        action=argparse.BooleanOptionalAction,
                        help='Whether or not insets are expanded')
    parser.add_argument('--remove_edges', dest='remove_edges',
                        action=argparse.BooleanOptionalAction,
                        help='Whether or not left and bottom edges are removed')

    args = parser.parse_args()

    # The edge cropping compares cell ratios against edge_prop, so it must be set
    if args.expand_edges and args.edge_prop is None:
        parser.error('--edge_prop is required when --expand_edges is set')

    input_image = Path(args.input_image)
    # Outputs share the input's directory and stem, e.g. img.tif -> img_rois.csv
    output_roi_stem = input_image.with_name(input_image.stem)

    # cv2.imread returns None (it does not raise) when the file can't be read
    img = cv2.imread(args.input_image, cv2.IMREAD_UNCHANGED)
    if img is None:
        raise SystemExit(f'Could not read image: {args.input_image}')
    img = img_gray_clahe(img)

    cv2.imwrite(str(output_roi_stem) + "_clahe.tif", img)

    pretrained_model = models.CellposeModel(pretrained_model=args.model_path)

    # Run Cellpose
    masks, flows, styles = pretrained_model.eval(
        img,
        diameter=None,
        flow_threshold=args.flow_threshold,
        cellprob_threshold=args.cellprob_threshold,
        channels=[0, 0],
    )

    # Perform expanded edge & crop down to original size
    if args.expand_edges:
        masks = crop_expanded_mask(masks, args.crop_px, args.edge_prop)

    # Remove every cell touching the left or bottom edge
    if args.remove_edges:
        masks = remove_edges(masks, 0.0, 0.0)

    # Output point ROIs (cell centroids) as CSV
    x, y = centroids_from_mask(masks)
    df = pd.DataFrame({'X': x, 'Y': y})
    df.to_csv(str(output_roi_stem) + "_rois.csv", index=False)

    io.save_rois(masks, output_roi_stem)
    

