import sys
import numpy as np
import cv2

from .pipeline_preprocessing import renumber_masks

def filter_expanded_mask(mask, crop_info, edge_threshold):
    """
    Drop labels that lie mostly outside the original inset from a mask that
    was segmented on an expanded inset.

    A label is only considered if it overlaps the original inset. If it also
    reaches into a non-empty margin strip (one whose effective margin is > 0),
    the fraction of its pixels lying inside the original inset is compared to
    a threshold:
      - top / left margins:     fraction must be >= edge_threshold
      - bottom / right margins: fraction must be >= 1 - edge_threshold
    A label touching both a top/left and a bottom/right margin must pass both
    tests. Labels that do not overlap the original inset at all are left in
    place (they fall outside the region that revert_expanded_inset_to_inset
    keeps). The asymmetric thresholds presumably ensure that a cell straddling
    the boundary between two adjacent insets is kept by only one of them --
    confirm against the tiling caller.

    Parameters:
    - mask: 2D numpy array of integer labels (e.g., Cellpose masks, 0 =
      background) aligned with the expanded inset.
    - crop_info: Dictionary returned by crop_template_to_expanded_inset
      containing:
        {
            'template_crop_x': ...,
            'template_crop_y': ...,
            'effective_margin_top': ...,
            'effective_margin_bottom': ...,
            'effective_margin_left': ...,
            'effective_margin_right': ...,
            'inset_x': ...,
            'inset_y': ...,
            'inset_width': ...,
            'inset_height': ...
        }
    - edge_threshold: Minimum inside-inset area fraction for labels touching
      the top/left margins; bottom/right margins use 1 - edge_threshold.

    Returns:
    - filtered_mask: Copy of `mask` with rejected labels set to 0, passed
      through renumber_masks (presumably relabels to consecutive ids).
    """
    # Extract info from crop_info (every key is required).
    template_crop_x = crop_info['template_crop_x']
    template_crop_y = crop_info['template_crop_y']
    effective_margin_top = crop_info['effective_margin_top']
    effective_margin_bottom = crop_info['effective_margin_bottom']
    effective_margin_left = crop_info['effective_margin_left']
    effective_margin_right = crop_info['effective_margin_right']
    inset_x = crop_info['inset_x']
    inset_y = crop_info['inset_y']
    inset_width = crop_info['inset_width']
    inset_height = crop_info['inset_height']

    # Position of the original inset inside the expanded inset.
    inset_y_in_expanded = inset_y - template_crop_y
    inset_x_in_expanded = inset_x - template_crop_x
    inset_y_end = inset_y_in_expanded + inset_height
    inset_x_end = inset_x_in_expanded + inset_width

    height, width = mask.shape

    # Total pixel count per foreground label, in a single pass over the mask
    # (instead of one full-image comparison per label).
    labels, total_area = np.unique(mask, return_counts=True)
    is_foreground = labels != 0
    labels = labels[is_foreground]
    total_area = total_area[is_foreground]

    # Pixel count per label inside the original inset. np.unique returns
    # sorted values, and every inset label also occurs in `labels`, so
    # searchsorted gives each inset label's exact slot in `labels`.
    inset_labels, inset_counts = np.unique(
        mask[inset_y_in_expanded:inset_y_end, inset_x_in_expanded:inset_x_end],
        return_counts=True,
    )
    inset_is_foreground = inset_labels != 0
    cropped_area = np.zeros_like(total_area)
    slots = np.searchsorted(labels, inset_labels[inset_is_foreground])
    cropped_area[slots] = inset_counts[inset_is_foreground]

    # Which labels reach into each margin strip. Strips are only checked when
    # the corresponding effective margin is non-zero (i.e. the crop was not
    # truncated to nothing on that side by the template boundary).
    crosses_top_left = np.zeros(labels.shape, dtype=bool)
    crosses_bottom_right = np.zeros(labels.shape, dtype=bool)
    if effective_margin_top > 0 and inset_y_in_expanded > 0:
        crosses_top_left |= np.isin(labels, mask[:inset_y_in_expanded, :])
    if effective_margin_left > 0 and inset_x_in_expanded > 0:
        crosses_top_left |= np.isin(labels, mask[:, :inset_x_in_expanded])
    if effective_margin_bottom > 0 and inset_y_end < height:
        crosses_bottom_right |= np.isin(labels, mask[inset_y_end:, :])
    if effective_margin_right > 0 and inset_x_end < width:
        crosses_bottom_right |= np.isin(labels, mask[:, inset_x_end:])

    # Fraction of each label's area lying inside the original inset
    # (total_area is always > 0 for labels present in the mask).
    inside_fraction = cropped_area / total_area
    reject = (cropped_area > 0) & (
        (crosses_top_left & (inside_fraction < edge_threshold))
        | (crosses_bottom_right & (inside_fraction < 1 - edge_threshold))
    )

    filtered_mask = np.copy(mask)
    if reject.any():
        filtered_mask[np.isin(filtered_mask, labels[reject])] = 0

    # Renumbering is always applied to the filtered result.
    filtered_mask = renumber_masks(filtered_mask)

    return filtered_mask


def find_inset_coordinates(inset, template):
    """
    Locate the inset inside the template using normalized cross-correlation.

    Parameters:
    - inset: The smaller image to search for.
    - template: The larger image expected to contain the inset.

    Returns:
    - (inset_x, inset_y): Top-left corner of the best-scoring match, exactly
      as reported by cv2.minMaxLoc (column first, then row).
    """
    # Note the argument roles: OpenCV calls the searched image "image" and the
    # patch "templ", so our `template` goes first and `inset` second.
    score_map = cv2.matchTemplate(template, inset, cv2.TM_CCOEFF_NORMED)

    # minMaxLoc -> (min_val, max_val, min_loc, max_loc); keep the peak location.
    best_location = cv2.minMaxLoc(score_map)[-1]
    return best_location


def crop_template_to_expanded_inset(inset, template, margin):
    """
    Given an inset and its parent template, compute an expanded_inset by adding
    a margin around the inset. If the margin extends beyond the template
    boundaries, it is truncated.

    Parameters:
    - inset: The original inset image.
    - template: The larger template image.
    - margin: The desired non-negative margin in pixels around the inset.

    Returns:
    - expanded_inset: The cropped portion of the template containing the inset
      plus margins (a slice of `template`).
    - crop_info: A dictionary with information needed to revert back to the
      original inset.
        {
            'template_crop_x': ...,       # left edge of the crop in template
            'template_crop_y': ...,       # top edge of the crop in template
            'effective_margin_top': ...,  # margin actually applied per side
            'effective_margin_bottom': ...,
            'effective_margin_left': ...,
            'effective_margin_right': ...,
            'inset_x': ...,               # inset top-left in template
            'inset_y': ...,
            'inset_width': ...,
            'inset_height': ...
        }

    Raises:
    - ValueError: If margin is negative.
    """
    inset_height, inset_width = inset.shape[:2]
    template_height, template_width = template.shape[:2]

    # A negative margin would start the crop after the inset origin, making
    # the inset offsets derived from crop_info negative and silently breaking
    # the slicing done by filter_expanded_mask / revert_expanded_inset_to_inset.
    if margin < 0:
        raise ValueError(f"margin must be non-negative, got {margin}")

    # Locate the inset in the template
    inset_x, inset_y = find_inset_coordinates(inset, template)

    # Intended coordinates of the expanded inset
    intended_top = inset_y - margin
    intended_bottom = inset_y + inset_height + margin
    intended_left = inset_x - margin
    intended_right = inset_x + inset_width + margin

    # Clamp to template boundaries
    crop_top = max(intended_top, 0)
    crop_bottom = min(intended_bottom, template_height)
    crop_left = max(intended_left, 0)
    crop_right = min(intended_right, template_width)

    # Calculate how much margin was effectively applied on each side
    effective_margin_top = inset_y - crop_top
    effective_margin_left = inset_x - crop_left
    effective_margin_bottom = crop_bottom - (inset_y + inset_height)
    effective_margin_right = crop_right - (inset_x + inset_width)

    # Extract the expanded inset from the template
    expanded_inset = template[crop_top:crop_bottom, crop_left:crop_right]

    crop_info = {
        'template_crop_x': crop_left,
        'template_crop_y': crop_top,
        'effective_margin_top': effective_margin_top,
        'effective_margin_bottom': effective_margin_bottom,
        'effective_margin_left': effective_margin_left,
        'effective_margin_right': effective_margin_right,
        'inset_x': inset_x,
        'inset_y': inset_y,
        'inset_width': inset_width,
        'inset_height': inset_height
    }

    return expanded_inset, crop_info


def revert_expanded_inset_to_inset(expanded_inset, crop_info):
    """
    Cut the original inset window back out of an expanded inset.

    Parameters:
    - expanded_inset: Array aligned with the expanded inset (the inset plus
      whatever margin was applied).
    - crop_info: The dictionary returned by crop_template_to_expanded_inset;
      its inset/template offsets locate the inset inside expanded_inset.

    Returns:
    - original_inset_region: The inset-sized slice of expanded_inset, passed
      through renumber_masks (so this is presumably meant for label masks
      rather than raw images).
    """
    # Offsets of the inset's top-left corner relative to the expanded crop.
    row_start = crop_info['inset_y'] - crop_info['template_crop_y']
    col_start = crop_info['inset_x'] - crop_info['template_crop_x']

    row_span = slice(row_start, row_start + crop_info['inset_height'])
    col_span = slice(col_start, col_start + crop_info['inset_width'])

    return renumber_masks(expanded_inset[row_span, col_span])


def _window_starts(length, window, step):
    """Return start offsets along one axis, the last one flush with the far edge."""
    starts = list(range(0, length - window + 1, step))
    # Add a final, irregularly spaced window so the far edge is always covered.
    if starts[-1] != length - window:
        starts.append(length - window)
    return starts


def sliding_window(image, window_size=(272, 272), step=136):
    """
    Generate sliding windows from the image with irregular tiling at the right
    and bottom edges.

    Windows start every `step` pixels; if that grid does not end exactly at the
    right/bottom edge, one extra window aligned with that edge is added, so
    every window has the full requested size.

    Parameters:
        image (np.array): Image array. Only the first two axes (height, width)
            are tiled, so (H, W) and (H, W, C) arrays both work.
        window_size (tuple): (width, height) of the window.
        step (int): Step size in pixels; must be positive.

    Yields:
        tuple: (window, (x, y)) where window is the cropped sub-image and (x, y)
        is its top-left coordinate. Rows (y) are the outer loop.

    Raises:
        ValueError: If step or the window size is not positive, or the window
            is larger than the image. Because this is a generator, the error
            surfaces when iteration starts.
    """
    win_w, win_h = window_size
    image_h, image_w = image.shape[:2]

    if step <= 0:
        raise ValueError(f"step must be positive, got {step}")
    if win_w <= 0 or win_h <= 0:
        raise ValueError(f"window_size must be positive, got {window_size}")
    if win_w > image_w or win_h > image_h:
        raise ValueError(
            f"window (w={win_w}, h={win_h}) is larger than image "
            f"(w={image_w}, h={image_h})"
        )

    x_positions = _window_starts(image_w, win_w, step)
    y_positions = _window_starts(image_h, win_h, step)

    for y in y_positions:
        for x in x_positions:
            yield image[y:y + win_h, x:x + win_w], (x, y)

# # Example usage

# # Expand inset dimensions
# expanded_inset, crop_info = crop_template_to_expanded_inset(inset, template, margin)

# # Example: Running Cellpose (hypothetical function call)
# mask = run_cellpose(expanded_inset)  
# # mask is now aligned with expanded_inset dimensions

# # Filter expanded mask (remove cells near edges that don't meet threshold)
# filtered_mask = filter_expanded_mask(mask, crop_info, edge_threshold)

# # Crop expanded mask to original inset dimensions
# final_inset_mask = revert_expanded_inset_to_inset(filtered_mask, crop_info)