# coding=utf-8

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), 'img_processing'))
from img_processing.pipeline_preprocessing import *
from img_processing.pipeline_templatematch import *

import numpy as np
import pandas as pd
import cellpose
from cellpose import io, models

from pathlib import Path
import cv2
import torch
import argparse


def crop_expanded_mask(mask, crop_px, edge_threshold):
    """
    Crop an integer label mask by crop_px pixels on each edge and remove labels
    that straddle the new interior edges, using asymmetric rules:

    1. If a label touches the top or right edge band, remove it when the
        proportion of its area remaining inside the cropped region is less
        than edge_threshold.
    2. If a label touches the bottom or left edge band, remove it when the
        proportion of its area remaining inside the cropped region is less
        than 1 - edge_threshold.

    Parameters:
    - mask: 2D numpy array of integer masks (e.g., Cellpose masks). The input
        array is not modified.
    - crop_px: Number of pixels to crop from each edge.
    - edge_threshold: Proportion threshold for determining whether to keep or
        remove a mask near an edge.

    Returns:
    - cropped_mask: New cropped mask array with the filtered labels set to 0
        and the surviving labels renumbered.
    """

    # Get mask dimensions
    height, width = mask.shape

    # Copy the cropped slice: a plain numpy slice is a *view*, so zeroing
    # labels below would otherwise write through into the caller's mask.
    cropped_mask = mask[crop_px:height - crop_px, crop_px:width - crop_px].copy()

    # Get the unique mask labels (excluding the 0 background)
    unique_masks = np.unique(mask)
    unique_masks = unique_masks[unique_masks != 0]

    # Iterate over each unique mask label
    for m in unique_masks:
        original_mask = mask == m

        # Area of the label before and after cropping
        original_area = np.sum(original_mask)
        cropped_area = np.sum(cropped_mask == m)

        # Skip labels that vanished entirely (or were empty to begin with)
        if cropped_area == 0 or original_area == 0:
            continue

        # Label touches the top or right edge band
        if np.any(original_mask[:crop_px, :]) or np.any(original_mask[:, -crop_px:]):
            if cropped_area / original_area < edge_threshold:
                cropped_mask[cropped_mask == m] = 0

        # Label touches the bottom or left edge band
        if np.any(original_mask[-crop_px:, :]) or np.any(original_mask[:, :crop_px]):
            if cropped_area / original_area < (1 - edge_threshold):
                cropped_mask[cropped_mask == m] = 0

    # Renumber surviving labels consecutively (project helper from
    # img_processing, brought in via the wildcard imports at the top of file)
    cropped_mask = renumber_masks(cropped_mask)

    return cropped_mask

def crop_expanded_mask_updated(mask, crop_px, lt_edge_threshold, rb_edge_threshold):
    """
    Crop an integer label mask by crop_px pixels on each edge and remove labels
    that straddle the new interior edges, using per-side thresholds:

    1. If a label touches the top or left edge band, remove it when the
        proportion of its area remaining inside the cropped region is less
        than lt_edge_threshold.
    2. If a label touches the bottom or right edge band, remove it when the
        proportion of its area remaining inside the cropped region is less
        than rb_edge_threshold.

    Parameters:
    - mask: 2D numpy array of integer masks (e.g., Cellpose masks). The input
        array is not modified.
    - crop_px: Number of pixels to crop from each edge.
    - lt_edge_threshold: Proportion threshold for labels on the left/top edges.
    - rb_edge_threshold: Proportion threshold for labels on the right/bottom
        edges.

    Returns:
    - cropped_mask: New cropped mask array with the filtered labels set to 0
        and the surviving labels renumbered.
    """

    # Get mask dimensions
    height, width = mask.shape

    # Copy the cropped slice: a plain numpy slice is a *view*, so zeroing
    # labels below would otherwise write through into the caller's mask.
    cropped_mask = mask[crop_px:height - crop_px, crop_px:width - crop_px].copy()

    # Get the unique mask labels (excluding the 0 background)
    unique_masks = np.unique(mask)
    unique_masks = unique_masks[unique_masks != 0]

    # Iterate over each unique mask label
    for m in unique_masks:
        original_mask = mask == m

        # Area of the label before and after cropping
        original_area = np.sum(original_mask)
        cropped_area = np.sum(cropped_mask == m)

        # Skip labels that vanished entirely (or were empty to begin with)
        if cropped_area == 0 or original_area == 0:
            continue

        # Label touches the top or left edge band
        if np.any(original_mask[:crop_px, :]) or np.any(original_mask[:, :crop_px]):
            if cropped_area / original_area < lt_edge_threshold:
                cropped_mask[cropped_mask == m] = 0

        # Label touches the bottom or right edge band
        if np.any(original_mask[-crop_px:, :]) or np.any(original_mask[:, -crop_px:]):
            if cropped_area / original_area < rb_edge_threshold:
                cropped_mask[cropped_mask == m] = 0

    # Renumber surviving labels consecutively (project helper from
    # img_processing, brought in via the wildcard imports at the top of file)
    cropped_mask = renumber_masks(cropped_mask)

    return cropped_mask

def img_gray_clahe(image):
    """
    Convert an image to grayscale, apply CLAHE contrast enhancement, and
    return the enhanced plane replicated into a 3-channel array.

    Parameters:
    - image: numpy array; channel-first (3, H, W), single-channel, or BGR.

    Returns:
    - (H, W, 3) array whose three planes are identical CLAHE-enhanced copies.
    """
    # Channel-first (3, H, W) arrays are rearranged to channel-last (H, W, 3)
    if image.shape[0] == 3:
        image = image.transpose(1, 2, 0)

    # Promote single-channel inputs to 3 channels so the BGR->GRAY conversion
    # below handles every input uniformly
    single_channel = len(image.shape) == 2 or (len(image.shape) == 3 and image.shape[2] == 1)
    if single_channel:
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    enhancer = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = enhancer.apply(gray)

    # Replicate the enhanced plane three times along a new last axis
    return np.repeat(enhanced[:, :, np.newaxis], 3, axis=2)


if __name__ == '__main__':
    # Set up argument parser
    parser = argparse.ArgumentParser(description='Run Cellpose on a set of images')
    parser.add_argument('--input_image', dest='input_image', 
                        type=str, help='Path to the image file')
    parser.add_argument('--model_path', dest='model_path', 
                        type=str, help='Path to the model file')
    parser.add_argument('--flow_threshold', dest='flow_threshold', 
                        type=float, help='Flow threshold')
    parser.add_argument('--cellprob_threshold', dest='cellprob_threshold', 
                        type=float, help='Cell probability threshold')
    parser.add_argument('--expand_size', dest='expand_size',
                        type=float, help='Size of expanded edge in pixels')
    parser.add_argument('--lt_edge_threshold', dest='lt_edge_threshold',
                        type=float, help='Left and top edge threshold for removing masks')
    parser.add_argument('--rb_edge_threshold', dest='rb_edge_threshold',
                        type=float, help='Right and bottom edge threshold for removing masks')

    args = parser.parse_args()

    # Use GPU/MPS if available; MPS (Apple Silicon) is preferred over CUDA here
    if torch.backends.mps.is_available():
        device = torch.device('mps')
        gpu = True
    elif torch.cuda.is_available():
        device = torch.device('cuda')
        gpu = True
    else:
        device = torch.device('cpu')
        gpu = False
    print("Device: ", device)
    # NOTE(review): this sets an attribute on the cellpose.models module;
    # unclear whether Cellpose reads it — the device is also passed explicitly
    # to CellposeModel below. Confirm against the cellpose version in use.
    cellpose.models.device = device

    input_image = Path(args.input_image)
    # Path to the input image with its extension stripped; used as the stem
    # for all output files written below.
    output_roi_stem = input_image.with_name(input_image.stem)
    # IMREAD_UNCHANGED preserves the original bit depth and channel count
    img = cv2.imread(args.input_image, cv2.IMREAD_UNCHANGED)
    img = img_gray_clahe(img)

    # Save the CLAHE-enhanced image alongside the input for inspection
    cv2.imwrite(str(output_roi_stem) + "_clahe.tif", img)
    
    model_path = args.model_path
    pretrained_model = models.CellposeModel(pretrained_model=model_path, gpu=gpu, device=device)
    flow_threshold = args.flow_threshold
    cellprob_threshold = args.cellprob_threshold
    # expand_size arrives as float from argparse but is used as a pixel count
    expand_size = int(args.expand_size)
    lt_edge_threshold = args.lt_edge_threshold
    rb_edge_threshold = args.rb_edge_threshold

    # Run Cellpose; diameter=None lets the model estimate cell diameter,
    # channels=[0, 0] means grayscale segmentation with no second channel
    masks, flows, styles = pretrained_model.eval(img, diameter=None, flow_threshold=flow_threshold, cellprob_threshold=cellprob_threshold, channels=[0, 0])

    # Filter expanded masks: crop expand_size px from each edge and drop
    # labels straddling the crop per the left/top and right/bottom thresholds
    masks = crop_expanded_mask_updated(masks, expand_size, lt_edge_threshold=lt_edge_threshold, rb_edge_threshold=rb_edge_threshold)

    # # For Debugging Masks
    # np_masks = np.asarray(masks)
    # np.savetxt(str(output_roi_stem) + "_masks.csv", np_masks, delimiter=",")

    # Output point ROIs as CSV; centroids_from_mask comes from the
    # img_processing wildcard imports at the top of the file
    x, y = centroids_from_mask(masks)
    df = pd.DataFrame({'X': x, 'Y': y})
    df.to_csv(str(output_roi_stem) + "_rois.csv", index=False)

    # Save masks as ImageJ-compatible ROI archive via Cellpose's io helper
    io.save_rois(masks, output_roi_stem)
    

