# coding=utf-8

import os
import sys
sys.path.append(os.path.join(os.path.dirname(__file__), 'img_processing'))
from img_processing.pipeline_preprocessing import *
from img_processing.pipeline_templatematch import *

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cellpose
from skimage.segmentation import find_boundaries
from skimage import measure
from cellpose import io, models

from pathlib import Path
import cv2
import torch
import argparse
import time


def sliding_window_postprocess_masks(full_mask,
                                     window_size,
                                     step,
                                     margin,
                                     edge_tl,
                                     edge_br):
    """Slide a window over a labeled mask and yield edge-filtered tiles.

    Each tile is extracted with an extra `margin` of context on every side,
    passed through `filter_expanded_mask_updated` (which drops objects that
    hug the tile edges beyond the given thresholds), then cropped back down
    to the exact `window_size` inset.

    Parameters
    ----------
    full_mask : np.ndarray
        2-D labeled mask of shape (H, W), e.g. Cellpose output.
    window_size : tuple[int, int]
        (width, height) of each inset window in pixels.
    step : int
        Stride between consecutive window origins, in pixels.
    margin : int
        Context margin in pixels; window origins are kept at least `margin`
        from the image border so the expanded tile always fits.
    edge_tl, edge_br : float
        Top/left and bottom/right edge-proportion thresholds forwarded to
        `filter_expanded_mask_updated`.

    Yields
    ------
    tuple[np.ndarray, tuple[int, int]]
        `(cleaned, (x, y))` — the filtered window and its top-left origin
        in full-image coordinates.
    """
    H, W = full_mask.shape
    win_w, win_h = window_size

    # Valid origin range is [margin, W - margin - win_w] (resp. H/win_h) so
    # the expanded tile (window + margin on each side) stays inside the image.
    last_x = W - margin - win_w
    last_y = H - margin - win_h
    if last_x < margin or last_y < margin:
        # Image too small for even one fully-margined window: yield nothing.
        # (The original code raised IndexError on the empty range here.)
        return

    x_starts = list(range(margin, last_x + 1, step))
    if x_starts[-1] != last_x:
        x_starts.append(last_x)  # always cover the right edge exactly
    y_starts = list(range(margin, last_y + 1, step))
    if y_starts[-1] != last_y:
        y_starts.append(last_y)  # always cover the bottom edge exactly

    for y in y_starts:
        for x in x_starts:
            # 1) pull out the expanded tile (with full `margin` of context)
            exp = full_mask[
                y - margin : y + win_h + margin,
                x - margin : x + win_w + margin
            ]
            if exp.size == 0:
                continue

            # 2) build a simple crop_info (all margins are the full `margin`)
            crop_info = {
                'template_crop_x':  x - margin,
                'template_crop_y':  y - margin,
                'effective_margin_left':   margin,
                'effective_margin_top':    margin,
                'effective_margin_right':  margin,
                'effective_margin_bottom': margin,
                'inset_x':   x,
                'inset_y':   y,
                'inset_width':  win_w,
                'inset_height': win_h
            }

            # 3) run existing edge filter (drops masks touching tile edges)
            filtered_expanded = filter_expanded_mask_updated(
                exp, crop_info,
                edge_br,  # bottom/right threshold
                edge_tl   # top/left threshold
            )

            # 4) carve back out the exact win_w x win_h inset window
            cleaned = revert_expanded_inset_to_inset(filtered_expanded, crop_info)

            yield cleaned, (x, y)

if __name__ == '__main__':
    # Pipeline: load image -> CLAHE preprocess -> Cellpose inference on the
    # whole image -> sliding-window cell counts -> pick windows at given
    # percentiles of cell count -> export ROIs, crops, and a comparison plot.
    # Set up argument parser
    parser = argparse.ArgumentParser(description='Run Cellpose on a set of images')
    parser.add_argument('--input_image', dest='input_image', 
                        type=str, help='Path to the image file')
    parser.add_argument('--model_path', dest='model_path', 
                        type=str, help='Path to the model file')
    parser.add_argument('--box_size', dest='box_size',
                        type=int, help='Size of inset in pixels')
    parser.add_argument('--margin_size', dest='margin_size',
                        type=int, help='Size of margin in pixels')
    parser.add_argument('--flow_threshold', dest='flow_threshold', 
                        type=float, help='Flow threshold')
    parser.add_argument('--cellprob_threshold', dest='cellprob_threshold', 
                        type=float, help='Cell probability threshold')
    parser.add_argument('--lt_edge_threshold', dest='lt_edge_threshold', type=float, 
                        help='Left/top Edge proportion for cutoff used for edge expansion')
    parser.add_argument('--rb_edge_threshold', dest='rb_edge_threshold', type=float, 
                    help='Right/bottom edge proportion for cutoff used for edge expansion')
    parser.add_argument('--percentile', dest='percentile', type=int,
                        help='Percentile to use for thresholding')
    

    args = parser.parse_args()

    start = time.time()

    # Use GPU/MPS if available (MPS checked first, so Apple GPUs win on Macs)
    if torch.backends.mps.is_available():
        device = torch.device('mps')
        gpu = True
    elif torch.cuda.is_available():
        device = torch.device('cuda')
        gpu = True
    else:
        device = torch.device('cpu')
        gpu = False
    print("Device: ", device)
    # NOTE(review): sets a module-level attribute on cellpose.models;
    # presumably redundant with the device= passed to CellposeModel below — confirm.
    cellpose.models.device = device

    input_image = Path(args.input_image)
    # Path stem used as the prefix for every output file written below.
    output_roi_stem = input_image.with_name(input_image.stem)
    img = cv2.imread(args.input_image, cv2.IMREAD_UNCHANGED)
    # img_gray_clahe comes from pipeline_preprocessing; appears to return a
    # 3-channel CLAHE-equalized image (it is converted from BGR just below).
    img = img_gray_clahe(img)
    gray_2d_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

    box_size = int(args.box_size)
    margin_size = int(args.margin_size)
    # 50% window overlap between adjacent sliding windows.
    step_size = int(box_size/2)

    # Persist the preprocessed image alongside the input for inspection.
    cv2.imwrite(str(output_roi_stem) + "_clahe.tif", img)
    
    model_path = args.model_path
    pretrained_model = models.CellposeModel(pretrained_model=model_path, gpu=gpu, device=device)
    flow_threshold = args.flow_threshold
    cellprob_threshold = args.cellprob_threshold
    lt_edge_threshold = args.lt_edge_threshold
    rb_edge_threshold = args.rb_edge_threshold
    percentile_arg = int(args.percentile)

    # Initialize matrices to store the sum of evaluations and counts.
    # heatmap_sum / count_matrix accumulate per-pixel window coverage;
    # single_count_matrix stores the cell count at each window ORIGIN only.
    heatmap_sum = np.zeros(gray_2d_img.shape, dtype=np.float32)
    count_matrix = np.zeros_like(heatmap_sum, dtype=np.float32)
    single_count_matrix = np.zeros_like(heatmap_sum, dtype=np.float32)

    # windows = list(sliding_window(img, window_size=(box_size, box_size), step=step_size))
    # windows_2d_gray = list(sliding_window(gray_2d_img, window_size=(box_size, box_size), step=step_size))
    # print(windows[0])
    # window_images = [win for win, _ in windows]

    window_time = time.time()
    print("Windows generated, elapsed: ", window_time - start)

    # Run Cellpose on the whole image once (windowing happens on the masks).
    masks, flows, styles = pretrained_model.eval(
        # window_images,  # Changed to do sliding window on the whole image
        img,
        channels=[0, 0], 
        diameter=None, 
        flow_threshold=flow_threshold, 
        cellprob_threshold=cellprob_threshold
    )

    inference_time = time.time()
    print("Inference complete, elapsed: ", inference_time - start)

    # Tile the full-image mask into overlapping windows with edge filtering.
    windows_masks = list(sliding_window_postprocess_masks(masks, window_size=(box_size, box_size), step=step_size, margin=margin_size, edge_tl=lt_edge_threshold, edge_br=rb_edge_threshold))

    # Look up a window's cleaned mask by its (x, y) origin.
    window_dict = {(x, y): win for win, (x, y) in windows_masks}
    window_count_dict = {}

    for idx, (win, (x, y)) in enumerate(windows_masks):
        # result = masks[idx].max()  # Maximum value from the mask of the current window.
        # Count distinct labels in the window (background label 0 excluded).
        unique_labels, counts = np.unique(win, return_counts=True)
        label_counts = dict(zip(unique_labels, counts))
        if 0 in label_counts:
            del label_counts[0]
        
        result = len(label_counts)  # Number of unique labels in the current window.
        window_count_dict[(x, y)] = result
        heatmap_sum[y:y + box_size, x:x + box_size] += result
        count_matrix[y:y + box_size, x:x + box_size] += 1
        single_count_matrix[y, x] = result

    # Always report 25/50/75, plus the user-requested percentile.
    # NOTE(review): if --percentile is 25, 50 or 75 it will be duplicated.
    percentiles = [25, 50, 75]
    percentiles.append(percentile_arg)
    # Percentiles are computed over window origins only (nonzero counts).
    nonzero_values = single_count_matrix[single_count_matrix != 0]
    percentile_values = np.percentile(nonzero_values, percentiles)

    percentile_coordinates = []

    # Allow +/-1 cell when matching a window count to a percentile value,
    # since np.percentile can return a value between observed counts.
    tolerance = 1
    for percentile, percentile_value in zip(percentiles, percentile_values):
        coords = np.argwhere(np.isclose(single_count_matrix, percentile_value, atol=tolerance))
        if coords.size == 0:
            print(f"No coordinate found for the {percentile}th percentile value {percentile_value}")
            continue
        # y_coord, x_coord = coords[0]
        # Pick the median coordinate if multiple coordinates are found
        if len(coords) > 1:
            y_coord, x_coord = np.median(coords, axis=0).astype(int)
        else:
            y_coord, x_coord = coords[0]

        # # Ensure the patch fits within the image boundaries.
        # if x_coord + box_size > gray_2d_img.shape[1]:
        #     x_coord = gray_2d_img.shape[1] - box_size
        #     print("Adjusted x value to fit within image")
        # if y_coord + box_size > gray_2d_img.shape[0]:
        #     y_coord = gray_2d_img.shape[0] - box_size
        #     print("Adjusted y value to fit within image")

        print(f"Coordinate for {percentile}th percentile:", y_coord, x_coord)
        percentile_coordinates.append((x_coord, y_coord))

    int_percentile_val = [int(val) for val in percentile_values]

    # Separate the x and y coordinates for each percentile
    x_coords = [coord[0] for coord in percentile_coordinates]
    y_coords = [coord[1] for coord in percentile_coordinates]

    output_coordinates = zip(percentiles, x_coords, y_coords)
    # Save percentile_coordinates as CSV
    df = pd.DataFrame(output_coordinates, columns=['Percentile','XCoord', 'YCoord'])
    df.to_csv(str(output_roi_stem) + "_percentile_coordinates.csv", index=False)

    complete_time = time.time()
    print("Complete, elapsed: ", complete_time - start)

    # Save the masks of the selected percentile
    for percentile, x, y in df.itertuples(index=False):
        print(f"Processing {percentile}th percentile at coordinates: ({x}, {y})")
        try:
            percentile_mask = window_dict[(x, y)]
        except KeyError:
            # fallback to raw slice if something went wrong
            percentile_mask = masks[y:y + box_size, x:x + box_size]
        # percentile_mask = renumber_masks(percentile_mask)
        # NOTE(review): uint8 cast silently wraps if >255 labels in a window.
        percentile_mask = np.uint8(percentile_mask)
        # io.save_rois(percentile_mask, str(output_roi_stem))

        # centroids_from_mask comes from the img_processing helpers;
        # x, y (the loop variables) are deliberately reused here — they are
        # reassigned from itertuples on the next iteration.
        x, y = centroids_from_mask(percentile_mask)
        roi_df = pd.DataFrame({'X': x, 'Y': y})
        roi_df.to_csv(str(output_roi_stem) + f"_{percentile}_percentile_rois.csv", index=False)

        io.save_rois(percentile_mask, str(output_roi_stem) + f"_{percentile}_percentile_mask")

    # Make a matplotlib subplot comparing the 25, 50, and 75th percentile masks
    fig, axs = plt.subplots(1, len(percentiles), figsize=(15, 5))
    for i, (percentile, x, y) in enumerate(df.itertuples(index=False)):
        try:
            percentile_mask = window_dict[(x, y)]
        except KeyError:
            # fallback to raw slice if something went wrong
            percentile_mask = masks[y:y + box_size, x:x + box_size]
        cropped_image = img[y:y + box_size, x:x + box_size]
        cv2.imwrite(str(output_roi_stem) + f"_{percentile}_cropped.tif", cropped_image)

        axs[i].imshow(percentile_mask, cmap='gray')
        # Outline the masks
        # NOTE(review): `boundaries` is computed but never displayed (the
        # overlay imshow below is commented out).
        boundaries = find_boundaries(percentile_mask, mode='thick')
        # Show the cropped image 
        for cells in range(1, percentile_mask.max() + 1):
            # Draw contours around the masks
            contours = measure.find_contours(percentile_mask, cells)
            for contour in contours:
                axs[i].plot(contour[:, 1], contour[:, 0], linewidth=2, color='red')
        # The cropped image is drawn last, on top of the mask imshow above.
        axs[i].imshow(cropped_image, cmap='gray')
        # axs[i].imshow(boundaries, cmap='jet', alpha=0.25)
        axs[i].set_title(f"{percentile}th Percentile\n{int_percentile_val[i]} cells")
        axs[i].axis('off')
    plt.tight_layout()
    plt.savefig(str(output_roi_stem) + "_percentile_plot.png")

    # # Remove masks touching left and bottom edges - TODO: Set left/bottom proportion to 0, use new function
    # if args.remove_edges == True:
    #     masks = remove_edges(masks, 0.0, 0.0)


    # # TODO: Fix later, masks for selected inset?
    # # For Debugging Masks
    # np_masks = np.asarray(masks)
    # np.savetxt(str(output_roi_stem) + "_masks.csv", np_masks, delimiter=",")

    # # Output point ROIs as CSV
    # x, y = centroids_from_mask(masks)
    # df = pd.DataFrame({'X': x, 'Y': y})
    # df.to_csv(str(output_roi_stem) + "_rois.csv", index=False)

    # io.save_rois(masks, output_roi_stem)
    
