#@ UpdateService updateService

import os
from ij import IJ, ImagePlus, ImageStack
from ij.plugin import Duplicator, HyperStackConverter
from ij.gui import WaitForUserDialog

# StackReg is distributed through the BIG-EPFL update site. Stop early with a visible message
# when that site is either not listed at all (getUpdateSite gives None) or not activated.
big_epfl_site = updateService.getUpdateSite("BIG-EPFL")
if big_epfl_site is None or not big_epfl_site.isActive():
	error_msg = "This plugin uses the StackReg plugin from BIG-EPFL. Please activate the BIG-EPFL Update site !"
	WaitForUserDialog('Error', error_msg).show()
	raise Exception(error_msg)

imp = IJ.getImage()
dimX, dimY, dimC, dimZ, dimT = imp.getDimensions()
refC = imp.getC()  # The StackReg script is run on the currently selected channel
final_name = os.path.splitext(imp.getTitle())[0] + "-StackReg"

# If we have a z-stack, we process it without modifications
if dimZ > 1 and dimT == 1:
	dimStack = dimZ
	refStack = imp.getZ()

# If we have a film, we reorder the image dimensions to convert it into a z-stack:
# toHyperStack(imp, channels, slices, frames) is called with dimT slices and dimZ (== 1) frames,
# so the rest of the script only has to iterate over z.
elif dimZ == 1 and dimT > 1:
	dimStack = dimT
	refStack = imp.getT()
	imp = HyperStackConverter.toHyperStack(imp.duplicate(), dimC, dimT, dimZ)

# A single 2D image (no slices, no frames) has nothing to register
elif dimZ == 1 and dimT == 1:
	error_msg = "Cannot process an image without z-stack or frames !"
	WaitForUserDialog('Error', error_msg).show()
	raise Exception(error_msg)

# Otherwise the image has both slices and frames, which this script does not handle
else:
	error_msg = "Cannot process an image with both z-stack and frames !"
	WaitForUserDialog('Error', error_msg).show()
	raise Exception(error_msg)

# Register the reference channel with the original StackReg plugin.
# Duplicator.run(imp, firstC, lastC, firstZ, lastZ, firstT, lastT): keep only channel refC,
# every slice of the (possibly reordered) stack, and the single frame.
first_slice, last_slice = 1, dimStack
imp_ref = Duplicator().run(imp, refC, refC, first_slice, last_slice, 1, 1)

# Point the duplicate at the slice that was selected in the input before running StackReg,
# which takes the currently displayed slice as its reference.
imp_ref.setZ(refStack)
imp_ref.show()
registration_options = "transformation=Translation"
IJ.run(imp_ref, "StackReg", registration_options)

final_stack = ImageStack(dimX, dimY)


def _black_margin(pixel_at, length):
	"""Return the signed width of the zero-valued margin along one line of pixels.

	pixel_at(i) must return the first getPixel() component at position i (0 <= i < length).
	A positive result is the width of the margin at the start of the line (content moved
	towards higher coordinates); a negative result is minus the width of the margin at the
	end of the line (content moved towards lower coordinates). Returns 0 when there is no
	margin, and also when the whole line is zero-valued, since no shift can be measured then.
	"""
	start = 0
	# Bounded scan: out-of-image pixels read as 0, so an all-zero line would never stop.
	while start < length and pixel_at(start) == 0:
		start += 1
	if start == length:
		return 0
	if start > 0:
		return start
	# pixel_at(0) is non-zero here, so this scan always stops at index 0 at the latest.
	end = 0
	while pixel_at(length - 1 + end) == 0:
		end -= 1
	return end


# The margins are probed along the middle row / middle column of each registered slice.
# NOTE(review): this assumes the zero-valued border comes only from StackReg's translation
# padding; genuine zero pixels at the image edge would be mistaken for a shift.
middle_x = dimX // 2
middle_y = dimY // 2

# For each slice of the image ...
for z in range(1, dimStack + 1):
	imp_ref.setZ(z)

	# ... we detect the width of the black margin on left / right of the aligned reference channel ...
	offset_x = _black_margin(lambda x: imp_ref.getPixel(x, middle_y)[0], dimX)
	# ... we also detect the height of the black margin on top / bottom ...
	offset_y = _black_margin(lambda y: imp_ref.getPixel(middle_x, y)[0], dimY)

	# ... and we apply those shifts to all channels (channel index varies fastest in final_stack).
	translate_options = "x=%d y=%d interpolation=None" % (offset_x, offset_y)
	for c in range(1, dimC + 1):
		if c == refC:
			# The reference channel has already been aligned by StackReg.
			final_stack.addSlice(imp_ref.getProcessor())
		else:
			channel_slice = Duplicator().run(imp, c, c, z, z, 1, 1)
			IJ.run(channel_slice, "Translate...", translate_options)
			final_stack.addSlice(channel_slice.getProcessor())

imp_ref.close()

# We construct the final hyperstack, with correct dimensions no matter the potential z-t reorder:
# final_stack holds dimC * dimStack slices with the channel varying fastest, and passing the
# original dimZ / dimT turns a reordered film back into frames.
final_imp = ImagePlus(final_name, final_stack)
final_imp = HyperStackConverter.toHyperStack(final_imp, dimC, dimZ, dimT)

# We duplicate the LUTs and display ranges from the initial images.
final_imp.setLuts(imp.getLuts())
for c in range(1, dimC + 1):
	imp.setC(c)
	final_imp.setC(c)
	final_imp.setDisplayRange(imp.getDisplayRangeMin(), imp.getDisplayRangeMax())

# The loop above moved imp off the channel the user had selected (for a z-stack, imp is the
# user's own image): restore it, and open the result on that same reference channel.
imp.setC(refC)
final_imp.setC(refC)

final_imp.show()
