# Cellpose Exploration

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import cellpose
import cv2
import tifffile as tiff

from collections import Counter
from skimage.segmentation import find_boundaries
from skimage.color import label2rgb
from sklearn.model_selection import GridSearchCV, train_test_split
from skimage.measure import label, regionprops
from sklearn.cluster import KMeans
from cellpose import core, utils, io, models, metrics, plot, train
from cellpose.io import imread
from PIL import Image
from roifile import roiread, ImagejRoi, ROI_TYPE, ROI_OPTIONS
from matplotlib.pyplot import tight_layout
from matplotlib import colors
from itertools import cycle
from pathlib import Path
from glob import glob
from natsort import natsorted
from scipy.spatial import Voronoi, voronoi_plot_2d, cKDTree


def outline_view(img0,maski,color=[1,0,0], mode='inner'):
    """
    Overlay the boundary of a mask onto an image in a given color.

    Parameters:
    img0 (numpy.ndarray): Image to draw on. A 2D (grayscale) image is
                          promoted to 3 channels by replication.
    maski (numpy.ndarray): Binary or labeled mask whose boundaries are drawn.
    color (list, optional): RGB color written onto boundary pixels.
                            Default is red ([1, 0, 0]).
    mode (str, optional): Boundary mode passed to skimage's find_boundaries.
                          Default is 'inner'.

    Returns:
    numpy.ndarray: Copy of the image with boundary pixels recolored.
    """
    if img0.ndim < 3:
        # Promote grayscale to RGB by stacking the single channel three times.
        img0 = np.stack((img0, img0, img0), axis=-1)

    # find_boundaries instead of masks_to_outlines: the latter also outlines
    # the image border.
    boundary_px = find_boundaries(maski, mode=mode)
    rows, cols = np.nonzero(boundary_px)

    overlay = img0.copy()
    overlay[rows, cols] = np.array(color)

    return overlay


def centroids_from_mask(mask: np.ndarray) -> tuple:
    """
    Find centroids of cells based on a given mask.

    The background label 0 is skipped. Centroids are returned in ascending
    label order, matching the order of ``np.unique``.

    Parameters:
    mask (numpy.ndarray): A 2D array of non-negative integer labels where each
                          unique value represents a different cell.

    Returns:
    tuple: Two lists containing the x and y coordinates of the centroids of the
           cells, respectively.
    """
    labels = np.unique(mask)
    labels = labels[labels != 0]  # Skip background
    if labels.size == 0:
        return [], []

    # One pass over the foreground pixels instead of a full-array scan per
    # label: accumulate per-label pixel counts and coordinate sums.
    rows, cols = np.nonzero(mask)
    pix_labels = mask[rows, cols]
    counts = np.bincount(pix_labels)
    sum_rows = np.bincount(pix_labels, weights=rows)  # y-coordinate sums
    sum_cols = np.bincount(pix_labels, weights=cols)  # x-coordinate sums

    x_values = [sum_cols[v] / counts[v] for v in labels]
    y_values = [sum_rows[v] / counts[v] for v in labels]

    return x_values, y_values


def plot_mask(img, mask, x, y, ExpX, ExpY):
    """
    Display Cellpose output next to manual grader annotations.

    Left panel: the image with mask outlines overlaid; right panel: the
    colored masks. Both panels show Cellpose centroids (blue) and the
    manually labeled centroids (red).

    Parameters:
    img (ndarray): The image to be displayed.
    mask (ndarray): The mask to be overlaid on the image.
    x (list or ndarray): X-coordinates of the centroids detected by Cellpose.
    y (list or ndarray): Y-coordinates of the centroids detected by Cellpose.
    ExpX (list or ndarray): X-coordinates of the manually labeled centroids.
    ExpY (list or ndarray): Y-coordinates of the manually labeled centroids.

    Returns:
    None
    """
    fig, (ax_overlay, ax_masks) = plt.subplots(1, 2, sharey=True, tight_layout=True)
    fig.suptitle('Cellpose vs. Grader Output', y=0.8)

    ax_overlay.imshow(img)
    ax_overlay.imshow(outline_view(img, mask, [0,0,255], 'inner'))

    ax_masks.imshow(cellpose.plot.mask_rgb(mask))

    # Same centroid markers on both panels: Cellpose in blue, grader in red.
    for ax in (ax_overlay, ax_masks):
        ax.scatter(x, y, s=2, c='b')
        ax.scatter(ExpX, ExpY, s=2, c='r')


def roi_reader(img_filepath, crop_margin=(0, 0)):
    """
    Read ImageJ point ROIs embedded in a TIFF and adjust for cropping.

    ROIs that fall inside the cropped-away region (left/bottom margins) are
    discarded; the remaining x-coordinates are shifted left by the crop.

    Parameters:
    img_filepath (str): Path to the TIFF image.
    crop_margin (tuple): Margins (crop_bottom, crop_left) to adjust the ROIs.
                         Default is (0, 0).

    Returns:
    tuple: Lists of x coordinates, y coordinates, and reversed y coordinates
           (measured from the bottom of the image).
    """
    crop_bottom, crop_left = crop_margin

    # Image dimensions, needed to bound-check ROI coordinates
    im = Image.open(img_filepath)
    width, height = im.size

    exp_x = []
    exp_y = []
    reverse_exp_y = []

    for roi in roiread(img_filepath):
        # subpixel_coordinates matches ImageJ annotations
        # (integer_coordinates was only correct for Heyex annotations)
        xy = roi.subpixel_coordinates
        x_coords = xy[:, 0].ravel()
        y_coords = xy[:, 1].ravel()

        # Shift x into the cropped coordinate frame
        shifted_x = x_coords - crop_left

        # Keep only points that remain inside the cropped image
        keep = ((shifted_x >= 0) & (shifted_x < (width - crop_left)) &
                (y_coords >= 0) & (y_coords < (height - crop_bottom)))

        shifted_x = shifted_x[keep]
        kept_y = y_coords[keep]

        exp_x.extend(shifted_x)
        exp_y.extend(kept_y)
        reverse_exp_y.extend(abs(height - yv) for yv in kept_y)

    return exp_x, exp_y, reverse_exp_y


def save_image_with_rois(image, x_coords, y_coords, output_path):
    """
    Save an image to TIFF with an embedded ImageJ Point ROI.

    Parameters:
    image (numpy.ndarray): The image read using cv2.
    x_coords (list of floats): List of x coordinates.
    y_coords (list of floats): List of y coordinates.
    output_path (str): Path to save the TIFF image with ROIs.
    """
    # ImagejRoi.frompoints expects an (n, 2) float32 array of (x, y) pairs
    point_array = np.array(list(zip(x_coords, y_coords)), dtype=np.float32)

    point_roi = ImagejRoi.frompoints(point_array)
    point_roi.roitype = ROI_TYPE.POINT
    point_roi.options |= ROI_OPTIONS.SHOW_LABELS

    # imagej=True makes the ROI metadata readable when opened in ImageJ/Fiji
    tiff.imwrite(output_path, image, imagej=True, metadata={'ROI': point_roi.tobytes()})


def renumber_masks(masks):
    """
    Relabel a 2D mask array so foreground labels become the consecutive
    integers 1..K (in ascending order of the original labels), while the
    background value 0 is left unchanged.

    Parameters:
        masks (numpy.ndarray): A 2D array of integer mask values.

    Returns:
        numpy.ndarray: A 2D array with consecutive, renumbered mask values.
    """
    labels = np.unique(masks)
    labels = labels[labels != 0]  # background stays 0

    # Lookup table indexed by old label -> new consecutive label
    lookup = np.zeros(masks.max() + 1, dtype=int)
    lookup[labels] = np.arange(1, labels.size + 1)

    # Fancy indexing applies the relabeling in one shot
    return lookup[masks]


def remove_edges(masks, right_and_bottom: float = 1.0, size_percentile: float = 0.0):
    """
    Removes the masks that are on the edges of the image based on the 
    proportion of pixels that are on the edge for each cell and the size 
    percentile threshold. (e.g. 0.0 means all pixels on edge are removed and 
    1.0 means no pixels on edge for each cell are removed). Cells smaller than 
    the specified percentile will be removed if they touch the right or bottom 
    edges. Default values mean no cells are removed.
    
    Parameters:
    masks (numpy.ndarray): The input mask array where each cell is represented 
                           by a unique integer.
    right_and_bottom (float, optional): Proportion threshold for removing cells 
                                        touching the right or bottom edges.
                                        Default is 1.0, meaning no cells are 
                                        removed based on edge proportion.
    size_percentile (float, optional): Percentile threshold for removing small 
                                       cells touching the right or bottom edges.
                                       Default is 0.0, meaning no cells are 
                                       removed based on size.
    Returns:
    numpy.ndarray: The mask array with cells removed based on the specified 
    thresholds.
    """
    # Handle completely empty mask
    if np.all(masks == 0):
        return masks

    # Labels present on the right edge (last column) and bottom edge (last row)
    edge_labels = np.concatenate((masks[:, -1], masks[-1, :]))
    edge_labels = edge_labels[edge_labels != 0]
    edge_counts = Counter(edge_labels)

    # Total pixel count per label (includes background key 0)
    mask_counts = Counter(masks.ravel())

    # Size threshold from the distribution of foreground cell sizes
    cell_sizes = [count for cell, count in mask_counts.items() if cell != 0]
    size_threshold = np.percentile(cell_sizes, size_percentile)

    # Decide removal once per edge-touching label instead of per pixel
    # (the previous implementation scanned every pixel and did an O(E)
    # membership test against an ndarray for each one).
    to_remove = [cell for cell in edge_counts
                 if edge_counts[cell] / mask_counts[cell] > right_and_bottom
                 or mask_counts[cell] < size_threshold]

    edges_removed_mask = np.copy(masks)
    if to_remove:
        edges_removed_mask[np.isin(edges_removed_mask, to_remove)] = 0

    return renumber_masks(edges_removed_mask)


def remove_edges_updated(masks, rb_threshold: float = 1.0, rb_size_percentile: float = 0.0, 
                         lt_threshold: float = 1.0, lt_size_percentile: float = 0.0):
    """
    Removes cells based on their edge contact. Cells touching the right or bottom edges are removed 
    if the proportion of their pixels on those edges is greater than rb_threshold or if their total 
    size is below the rb_size_percentile threshold. Similarly, cells touching the left or top edges 
    are removed if the proportion of their pixels on those edges is greater than lt_threshold or if 
    their size is below the lt_size_percentile threshold.

    Parameters:
    masks (numpy.ndarray): The input mask array where each cell is represented by a unique integer.
    rb_threshold (float, optional): Proportion threshold for removing cells touching the right or bottom edges.
                                    Default is 1.0 (no removal based on edge proportion).
    rb_size_percentile (float, optional): Percentile threshold for removing small cells touching the right or bottom edges.
                                          Default is 0.0 (no removal based on size).
    lt_threshold (float, optional): Proportion threshold for removing cells touching the left or top edges.
                                    Default is 1.0 (no removal based on edge proportion).
    lt_size_percentile (float, optional): Percentile threshold for removing small cells touching the left or top edges.
                                          Default is 0.0 (no removal based on size).

    Returns:
    numpy.ndarray: The mask array with cells removed based on the specified thresholds.
    """
    # Handle completely empty mask
    if np.all(masks == 0):
        return masks

    # Labels on the right (last column) and bottom (last row) borders
    rb_edge = np.concatenate((masks[:, -1], masks[-1, :]))
    rb_edge = rb_edge[rb_edge != 0]
    rb_counts = Counter(rb_edge)

    # Labels on the left (first column) and top (first row) borders
    lt_edge = np.concatenate((masks[:, 0], masks[0, :]))
    lt_edge = lt_edge[lt_edge != 0]
    lt_counts = Counter(lt_edge)

    # Total pixel count per label (includes the background key 0)
    mask_counts = Counter(masks.ravel())

    # Size thresholds from the distribution of foreground cell sizes
    cell_sizes = [count for cell, count in mask_counts.items() if cell != 0]
    rb_size_threshold = np.percentile(cell_sizes, rb_size_percentile)
    lt_size_threshold = np.percentile(cell_sizes, lt_size_percentile)

    # Collect the labels to drop, testing each border group against its
    # own proportion and size thresholds
    cells_to_remove = set()
    for cell, edge_count in rb_counts.items():
        if edge_count / mask_counts[cell] > rb_threshold or mask_counts[cell] < rb_size_threshold:
            cells_to_remove.add(cell)
    for cell, edge_count in lt_counts.items():
        if edge_count / mask_counts[cell] > lt_threshold or mask_counts[cell] < lt_size_threshold:
            cells_to_remove.add(cell)

    # Zero out all removed labels in one vectorized pass instead of one
    # boolean scan of the whole array per removed cell
    edges_removed_mask = np.copy(masks)
    if cells_to_remove:
        edges_removed_mask[np.isin(edges_removed_mask, list(cells_to_remove))] = 0

    # Renumber the masks so the remaining cells are consecutive integers
    return renumber_masks(edges_removed_mask)


def filter_expanded_mask_updated(mask, crop_info, edge_threshold_bottom_right, edge_threshold_top_left):
    """
    Given a mask segmented on an expanded inset, apply the proportion threshold
    logic to remove masks that don't meet the specified criteria near truncated
    edges, and remove any cells completely outside the inset region.

    A cell is dropped when the fraction of its area lying inside the original
    inset falls below the threshold of an edge it crosses. Note that a cell
    crossing both a top/left margin and a bottom/right margin is tested
    against both thresholds independently.

    Parameters:
    - mask: numpy array of integer masks (e.g., Cellpose masks) corresponding 
      to the expanded_inset.
    - crop_info: Dictionary returned by crop_template_to_expanded_inset 
      containing:
        {
            'template_crop_x': ...,
            'template_crop_y': ...,
            'effective_margin_top': ...,
            'effective_margin_bottom': ...,
            'effective_margin_left': ...,
            'effective_margin_right': ...,
            'inset_x': ...,
            'inset_y': ...,
            'inset_width': ...,
            'inset_height': ...
        }
    - edge_threshold_bottom_right: Proportion threshold for keeping a mask that
      touches the bottom or right edge.
    - edge_threshold_top_left: Proportion threshold for keeping a mask that
      touches the top or left edge.

    Returns:
    - filtered_mask: Mask after removing cells that don't meet the proportion 
      threshold criteria, including those completely outside the inset.
    """

    # Extract info from crop_info
    template_crop_x = crop_info['template_crop_x']
    template_crop_y = crop_info['template_crop_y']
    effective_margin_top = crop_info['effective_margin_top']
    effective_margin_bottom = crop_info['effective_margin_bottom']
    effective_margin_left = crop_info['effective_margin_left']
    effective_margin_right = crop_info['effective_margin_right']
    inset_x = crop_info['inset_x']
    inset_y = crop_info['inset_y']
    inset_width = crop_info['inset_width']
    inset_height = crop_info['inset_height']

    # Determine where the original inset lies within the expanded inset
    # (inset_* are coordinates in the template; template_crop_* is the
    # expanded inset's origin in the template)
    inset_y_in_expanded = inset_y - template_crop_y
    inset_x_in_expanded = inset_x - template_crop_x

    # Identify the regions corresponding to truncated margins
    # Top margin region:    [0 : inset_y_in_expanded, :]
    # Left margin region:   [:, 0 : inset_x_in_expanded]
    # Bottom margin region: [inset_y_in_expanded+inset_height : end, :]
    # Right margin region:  [:, inset_x_in_expanded+inset_width : end]

    height, width = mask.shape
    # A margin region is only considered when that margin actually exists
    # (effective_margin_* > 0); otherwise it is None and skipped below.
    top_region = (0, inset_y_in_expanded) if effective_margin_top > 0 else None
    left_region = (0, inset_x_in_expanded) if effective_margin_left > 0 else None
    bottom_region = (inset_y_in_expanded + inset_height, height) if effective_margin_bottom > 0 else None
    right_region = (inset_x_in_expanded + inset_width, width) if effective_margin_right > 0 else None

    filtered_mask = np.copy(mask)
    unique_masks = np.unique(filtered_mask)
    unique_masks = unique_masks[unique_masks != 0]

    for m in unique_masks:
        original_mask = (filtered_mask == m)
        original_area = np.sum(original_mask)
        if original_area == 0:
            continue

        # Area inside the original inset region
        inset_region_mask = original_mask[
            inset_y_in_expanded:inset_y_in_expanded + inset_height,
            inset_x_in_expanded:inset_x_in_expanded + inset_width
        ]
        cropped_area = np.sum(inset_region_mask)

        # Remove masks completely outside the inset region
        if cropped_area == 0:
            filtered_mask[filtered_mask == m] = 0
            continue

        # Check each edge: does this cell have any pixel in a margin region?
        crosses_top = False
        if top_region is not None:
            y_start, y_end = top_region
            if y_end > 0 and np.any(original_mask[y_start:y_end, :]):
                crosses_top = True

        crosses_left = False
        if left_region is not None:
            x_start, x_end = left_region
            if x_end > 0 and np.any(original_mask[:, x_start:x_end]):
                crosses_left = True

        crosses_bottom = False
        if bottom_region is not None:
            y_start, y_end = bottom_region
            if y_start < height and np.any(original_mask[y_start:y_end, :]):
                crosses_bottom = True

        crosses_right = False
        if right_region is not None:
            x_start, x_end = right_region
            if x_start < width and np.any(original_mask[:, x_start:x_end]):
                crosses_right = True

        # Apply threshold logic with independent parameters:
        # keep the cell only if the in-inset fraction meets the threshold
        # of every margin group it crosses (zeroing twice is harmless).
        if (crosses_top or crosses_left):
            if cropped_area / original_area < edge_threshold_top_left:
                filtered_mask[filtered_mask == m] = 0

        if (crosses_bottom or crosses_right):
            if cropped_area / original_area < edge_threshold_bottom_right:
                filtered_mask[filtered_mask == m] = 0

    # Optionally renumber masks if desired
    filtered_mask = renumber_masks(filtered_mask)

    return filtered_mask


def remove_masks(masks, size: int):
    """
    Remove small masks from the input array and renumber the remaining masks to
    have consecutive integer values.
    
    Parameters:
    masks (numpy.ndarray): A 2D array where each unique integer represents a 
                           different mask.
    size (int): The minimum size (number of pixels) a mask must have to be 
                retained.
    
    Returns:
    numpy.ndarray: A 2D array with small masks removed and remaining masks 
    renumbered to have consecutive integer values.
    """
    masks_removed = np.copy(masks)

    labels, counts = np.unique(masks_removed, return_counts=True)

    # Zero out all undersized labels in one vectorized pass instead of one
    # full-array boolean scan per label. (Label 0 may appear in `small`
    # when the background is tiny; setting 0 to 0 is a no-op, matching the
    # original behavior.)
    small = labels[counts < size]
    if small.size:
        masks_removed[np.isin(masks_removed, small)] = 0

    return renumber_masks(masks_removed)


def remove_unmatched_masks(img, masks):
    """
    Removes Cellpose masks that do not intersect with marked ROIs in the image.

    NOTE: the `masks` argument is also modified in place (unmatched labels are
    zeroed) in addition to the renumbered copy being returned.

    Parameters:
        img (str): Path to the input image from which manual ROIs are extracted.
        masks (numpy.ndarray): The Cellpose masks to be filtered.
    
    Returns:
        tuple:
            - numpy.ndarray: The filtered masks with unmatched masks removed and renumbered consecutively.
            - numpy.ndarray: An array of points that correspond to the remaining masks.
    """
    # Get unique mask labels
    unique_masks = np.unique(masks)

    # Extract manual labels from the image
    ExpX, ExpY, Reverse_ExpY = roi_reader(img)
    x = np.array(ExpX)
    y = np.array(ExpY)
    points = np.column_stack((x, y))  # Points are (x, y)

    to_keep = set()  # Masks to keep
    corresponding_points = []  # Points corresponding to kept masks

    for n in unique_masks[unique_masks > 0]:
        # Create a binary mask for the current label
        mask_n = (masks == n).astype(np.uint8)
        contours, _ = cv2.findContours(mask_n, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        
        if contours:
            # NOTE(review): only the first external contour is tested; a label
            # split into several disconnected parts could be missed — assumes
            # each label is a single connected component (confirm upstream).
            contour = contours[0]
            # Check if any point is inside the contour
            for point in points:
                x_pt, y_pt = point
                result = cv2.pointPolygonTest(contour, (int(x_pt), int(y_pt)), False)
                if result >= 0:  # Point is inside or on the edge of the contour
                    to_keep.add(n)
                    corresponding_points.append(point)
                    break  # No need to check other points for this mask
        else:
            print(f"No contours found for mask {n}")

    # Remove masks not in 'to_keep' using boolean indexing
    masks_to_keep = np.isin(masks, list(to_keep))
    masks[~masks_to_keep] = 0  # Set unmatched masks to zero (mutates caller's array)

    # Renumber the remaining masks to have consecutive integer values
    unique_labels = np.array(sorted(to_keep))
    unique_labels = unique_labels.astype(int)  # Ensure integer type

    new_labels = np.arange(1, len(unique_labels) + 1)
    # NOTE(review): label_mapping is unused — label_mapping_array below does
    # the actual relabeling.
    label_mapping = dict(zip(unique_labels, new_labels))

    # Create a mapping array for efficient relabeling
    max_label = int(masks.max())
    label_mapping_array = np.zeros(max_label + 1, dtype=int)
    label_mapping_array[unique_labels] = new_labels
    masks_renumbered = label_mapping_array[masks]

    return masks_renumbered, np.array(corresponding_points)


def fill_array_voronoi(coordinates, shape):
    """
    Fills an array using the Voronoi diagram approach.

    Every pixel is labeled with the index of its nearest coordinate
    (Euclidean distance, ties broken toward the lowest index, matching
    np.argmin). Coordinates are (x, y) pairs, i.e. (column, row).

    Parameters:
    coordinates (np.ndarray): An array of shape (n, 2) containing the 
                              coordinates of the points.
    shape (tuple): A tuple (height, width) representing the shape of the output
                   array.
    
    Returns:
    np.ndarray: A 2D array of the given shape where each element is the index 
                of the nearest coordinate.
    """
    coordinates = np.asarray(coordinates)

    # Broadcast squared distances from every pixel to every coordinate in one
    # vectorized pass instead of a Python loop over each pixel.
    grid_rows, grid_cols = np.indices(shape)
    dx = grid_cols[..., np.newaxis] - coordinates[:, 0]  # x = column index
    dy = grid_rows[..., np.newaxis] - coordinates[:, 1]  # y = row index

    # Squared distance preserves the argmin (and its first-index tie-break)
    return np.argmin(dx * dx + dy * dy, axis=-1)


def fill_gaps(mask, min_size):
    """
    Fill gaps (regions of 0s) within a mask that are at least `min_size` pixels in area,
    provided they do not touch the border of the mask. Each filled gap is assigned a unique integer value,
    starting from the next unused integer in the mask.

    The input mask is modified in place and also returned.

    Parameters:
        mask (np.ndarray): 2D array of integers where 0 indicates background.
        min_size (int): Minimum number of pixels a gap must have to be filled.
        
    Returns:
        np.ndarray: The modified mask with gaps filled.
    """
    next_value = mask.max() + 1

    # Label 4-connected components of the background (connectivity=1)
    gap_labels = label(mask == 0, connectivity=1)

    n_rows, n_cols = mask.shape
    for gap in regionprops(gap_labels):
        # Bounding box: (min_row, min_col, max_row, max_col)
        r0, c0, r1, c1 = gap.bbox

        # Gaps that reach any image border are left untouched — they are
        # open background, not enclosed holes.
        touches_border = r0 == 0 or c0 == 0 or r1 == n_rows or c1 == n_cols
        if touches_border or gap.area < min_size:
            continue

        mask[gap_labels == gap.label] = next_value
        next_value += 1  # fresh label for each filled gap

    return mask


def create_masks_from_labels(image, masks, corresponding_points, radius=10):
    """
    Adds masks for unmatched manual ROIs using a Voronoi diagram approach.

    Each unmatched point gets a filled circle of the given radius; the
    circles are partitioned by a Voronoi assignment so neighboring circles
    never merge, and any overlap with existing masks is discarded.

    Parameters:
        image (str): Path to the TIFF image containing manual labels.
        masks (np.ndarray): Existing masks array.
        corresponding_points (np.ndarray): Points that correspond to existing masks.
        radius (int): Radius of the circular masks to be drawn.

    Returns:
        np.ndarray: Updated masks array with new masks added for unmatched points.
    """
    # Extract manual labels from the image
    ExpX, ExpY, _ = roi_reader(image)
    x = np.array(ExpX)
    y = np.array(ExpY)
    all_points = np.column_stack((x, y))

    # Convert points to integers to avoid floating-point precision issues
    # when comparing coordinates as set members
    all_points_int = np.round(all_points).astype(int)
    corresponding_points_int = np.round(corresponding_points).astype(int)

    # Identify unmatched points (manual ROIs without an existing mask)
    all_points_set = set(map(tuple, all_points_int))
    corresponding_points_set = set(map(tuple, corresponding_points_int))
    unmatched_points_set = all_points_set - corresponding_points_set
    unmatched_points = np.array(list(unmatched_points_set))

    if len(unmatched_points) == 0:
        return masks

    # Start mask labels from the next available integer
    mask_value_start = masks.max() + 1

    # Create circular masks for unmatched points.
    # (Loop variable renamed from `label`, which shadowed the builtin and
    # the module-level skimage.measure.label import.)
    circle_masks = np.zeros_like(masks, dtype=np.int32)
    for i, point in enumerate(unmatched_points):
        new_label = mask_value_start + i
        cv2.circle(circle_masks, (int(point[0]), int(point[1])), int(radius), int(new_label), -1)

    # Generate Voronoi labels for the entire image
    voronoi_array = fill_array_voronoi(unmatched_points, masks.shape)

    # Adjust Voronoi labels to match mask labels
    voronoi_array = mask_value_start + voronoi_array

    # Keep only Voronoi regions within the circular masks
    new_masks = np.where(circle_masks > 0, voronoi_array, 0)

    # Remove overlaps with existing masks
    overlap = (new_masks > 0) & (masks > 0)
    new_masks[overlap] = 0

    # Combine new masks with existing masks
    masks = np.maximum(masks, new_masks)

    # Renumber masks to ensure consecutive labels
    return renumber_masks(masks)


def split_masks(img, masks):
    """
    Splits masks with multiple regions of interest (ROIs) into separate masks.

    The function performs the following steps:
    1. Extracts unique masks from the input mask image.
    2. Reads manual ROI labels from the input image.
    3. Identifies masks that have matching manual ROIs.
    4. For masks with multiple ROIs, splits them into separate masks using 
       nearest neighbor clustering.
    5. Renumbers the masks to ensure consecutive integer values.

    NOTE: the `masks` argument is modified in place during splitting, in
    addition to the renumbered array being returned.

    Args:
        img (str): Path to input image (TIFF) from which manual labels are extracted.
        masks (numpy.ndarray): The mask image where each unique value represents a different mask.

    Returns:
        numpy.ndarray: The modified mask image with split masks.
    """

    unique_masks = np.unique(masks)

    # Extract manual labels from TIFF image
    ExpX, ExpY, Reverse_ExpY = roi_reader(img)
    x = np.array(ExpX)
    y = np.array(ExpY)
    points = np.column_stack((x, y))  # Points are (x, y)

    to_keep = set()  # Set of masks with matching manual ROI
    mask_to_point = {i: [] for i in unique_masks[unique_masks > 0]}

    # For each mask, find if any manual points are inside its contour
    for n in unique_masks[unique_masks > 0]:
        mask_n = np.uint8(masks == n)
        # NOTE(review): only the first external contour is tested; assumes
        # each label is a single connected component — confirm upstream.
        contours, _ = cv2.findContours(mask_n, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        if contours:
            contour = contours[0]
            for point in points:
                x_pt, y_pt = point
                # pointPolygonTest: +1 inside, 0 on edge, -1 outside. Only
                # strictly-inside points count here (edge points are ignored,
                # unlike remove_unmatched_masks which accepts result >= 0).
                result = cv2.pointPolygonTest(contour, (int(x_pt), int(y_pt)), False)
                if result == 1:  # Point is inside the contour
                    to_keep.add(n)
                    mask_to_point[n].append((x_pt, y_pt))
        else:
            print(f"No contours found for mask {n}")

    # Identify masks that have multiple manual ROIs
    multiple_roi = [mask for mask in mask_to_point if len(mask_to_point[mask]) > 1]

    # Collect all points corresponding to masks with multiple ROIs
    flat_duplicate_roi = [pt for mask in multiple_roi for pt in mask_to_point[mask]]

    # Ensure that new labels do not overlap with existing labels
    mask_start_val = masks.max() + 1

    # Create mapping from points to new mask labels
    unique_points = list(set(flat_duplicate_roi))
    mapping_dict = {tuple(point): (mask_start_val + i) for i, point in enumerate(unique_points)}

    # Split masks with multiple ROIs
    for individual_mask in multiple_roi:
        # Rebinds the outer `points` variable; it is not used after this loop
        points = mask_to_point[individual_mask]  # List of (x, y) points
        # Create a binary mask for the individual mask
        separate_mask = (masks == individual_mask).astype(np.uint8)

        # Get coordinates of all pixels in the mask
        indices = np.column_stack(np.where(separate_mask > 0))  # (row, col)
        if len(indices) == 0:
            continue

        # Build KD-Tree for the manual points
        tree = cKDTree(points)  # Points are (x, y)

        # Get pixel coordinates in (x, y) order
        pixel_coords = indices[:, [1, 0]]  # Swap columns to get (col, row) -> (x, y)

        # Query the nearest manual point for each pixel
        _, idx = tree.query(pixel_coords)
        closest_points = np.array(points)[idx]  # Get the corresponding points

        # Map each closest point to a new label
        new_labels = np.array([mapping_dict[tuple(p)] for p in closest_points])

        # Assign new labels to the pixels (in-place on the input array)
        masks[indices[:, 0], indices[:, 1]] = new_labels

    # Renumber the masks to ensure consecutive integer values
    masks = renumber_masks(masks)
    return masks


def mask_rgba(masks, alpha=128):
    """
    Sets 0 masks to fully transparent and adjustable transparency for non-zero 
    masks.

    Args:
        masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
        alpha (int, optional): Transparency level for non-zero mask values, in 
                               the range 0-255.

    Returns:
        RGBA (uint8, 3D array): Array with transparency applied; the RGB
        channels are left as zeros.
    """
    # Initialize RGBA array with the same width and height as masks, and 4 channels (RGBA)
    RGBA = np.zeros((masks.shape[0], masks.shape[1], 4), np.uint8)

    # Bug fix: the condition was previously (masks == 0), which applied alpha
    # to the background and left the labeled cells transparent — the inverse
    # of the documented behavior.
    RGBA[:, :, 3] = (masks != 0) * alpha

    return RGBA


def grayscale_clahe(img):
    """
    Convert a BGR image to 3-channel CLAHE-equalized grayscale.

    The input is converted to grayscale, contrast-limited adaptive histogram
    equalization is applied, and the single channel is replicated three times.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    equalized = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)).apply(gray)
    # Replicate the equalized channel to restore a 3-channel image
    return np.stack((equalized, equalized, equalized), axis=-1)


def img_gray_clahe(image):
    """
    Normalize an image to 3-channel CLAHE-equalized grayscale.

    Accepts channel-first (3, H, W), single-channel, or BGR input; all are
    reduced to grayscale, CLAHE-equalized, and replicated to 3 channels.
    """
    # Channel-first (3, H, W) -> channel-last (H, W, 3)
    if image.shape[0] == 3:
        image = image.transpose(1, 2, 0)
    # Promote single-channel input to 3-channel BGR before the gray conversion
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

    gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    equalized = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8)).apply(gray)
    # Replicate the equalized channel to restore a 3-channel image
    return np.stack((equalized, equalized, equalized), axis=-1)