# Robert Bryson-Richardson and Albert Cardona 2010-10-08 at Estoril, Portugal
# EMBO Developmental Imaging course by Gabriel Martins
#
# Register time frames (stacks) to each other using Stitching_3D library
# to compute translations only, in all 3 spatial axes.
# Operates on a virtual stack.
# 23/1/13 -
# added user dialog to make use of virtual stack an option
# 10/01/16-
# Christian Tischer (tischitischer@gmail.com)
# major changes and additions:
# - it now also works for 2D time-series (used to be 3D only)
# - option: measure drift on multiple timescales (this allows to also find slow drift components of less than 1 pixel per frame)
# - option: correct sub-pixel drift computing the shifted images using TransformJ
# - option: if a ROI is put on the image, only this part of the image is considered for drift computation
#           the ROI is moved along with the detected drift thereby tracking the structure of interest
# - macro recording is compatible with previous version

from ij import VirtualStack, IJ, CompositeImage, ImageStack, ImagePlus
from ij.process import ColorProcessor
from ij.io import DirectoryChooser, FileSaver
from ij.gui import GenericDialog, YesNoCancelDialog, Roi
from mpicbg.imglib.image import ImagePlusAdapter
from mpicbg.imglib.algorithm.fft import PhaseCorrelation
from org.scijava.vecmath import Point3i  #from javax.vecmath import Point3i # java6
from org.scijava.vecmath import Point3f  #from javax.vecmath import Point3f # java6
from java.io import File, FilenameFilter
from java.lang import Integer
import math

# sub-pixel translation using imglib2
from net.imagej.axis import Axes
from net.imglib2.img.display.imagej import ImageJFunctions
from net.imglib2.realtransform import RealViews, Translation3D
from net.imglib2.view import Views
from net.imglib2.img.imageplus import ImagePlusImgs
from net.imglib2.converter import Converters
from net.imglib2.converter.readwrite import RealFloatSamplerConverter
from net.imglib2.interpolation.randomaccess import NLinearInterpolatorFactory

def translate_single_stack_using_imglib2(imp, dx, dy, dz):
  # wrap into a float imglib2 and translate
  #   conversion into float is necessary due to "overflow of n-linear interpolation due to accuracy limits of unsigned bytes"
  #   see: https://github.com/fiji/fiji/issues/136#issuecomment-173831951
  img = ImagePlusImgs.from(imp.duplicate())
  extended = Views.extendBorder(img)
  converted = Converters.convert(extended, RealFloatSamplerConverter())
  interpolant = Views.interpolate(converted, NLinearInterpolatorFactory())
  transformed = RealViews.affine(interpolant, Translation3D(dx, dy, dz))
  cropped = Views.interval(transformed, img)
  # wrap back into bit depth of input image and return
  bd = imp.getBitDepth()
  if bd==8:
    return(ImageJFunctions.wrapUnsignedByte(cropped,"imglib2"))
  elif bd == 16:
    return(ImageJFunctions.wrapUnsignedShort(cropped,"imglib2"))
  elif bd == 32:
    return(ImageJFunctions.wrapFloat(cropped,"imglib2"))
  else:
    return None

def compute_stitch(imp1, imp2):
  """ Compute a Point3i that expresses the translation of imp2 relative to imp1.

  Uses imglib1 phase correlation; the arguments 5 and True are presumably the number
  of peaks to test and whether to verify them by cross-correlation (TODO confirm
  against the mpicbg PhaseCorrelation API).
  For 2D images the z component of the result is 0.
  Raises ValueError if the measured shift is neither 2D nor 3D.
  """
  phc = PhaseCorrelation(ImagePlusAdapter.wrap(imp1), ImagePlusAdapter.wrap(imp2), 5, True)
  phc.process()
  p = phc.getShift().getPosition()
  if len(p) == 3: # 3D data
    return Point3i(p)
  if len(p) == 2: # 2D data: add zero shift in z
    return Point3i([p[0], p[1], 0])
  # previously this fell through to an UnboundLocalError
  raise ValueError("Expected a 2D or 3D shift, got %d dimensions" % len(p))

def extract_frame(imp, frame, channel):
  """ Collect all z-slices of one channel at one time point of a hyperstack.

  imp: (possibly virtual) hyperstack.
  frame, channel: 1-based time point and channel.
  Returns a new ImageStack with slice labels "1", "2", ... built from the
  processors that imp's stack returns for this frame and channel.
  """
  source = imp.getStack() # multi-time point (possibly virtual) stack
  frame_stack = ImageStack(imp.width, imp.height, None)
  z = 1
  while z <= imp.getNSlices():
    index = imp.getStackIndex(channel, z, frame)
    frame_stack.addSlice(str(z), source.getProcessor(index))
    z += 1
  return frame_stack

def extract_frame_process_roi(imp, frame, channel, process, roi):
  """ Get one channel of one time point as a separate ImagePlus for drift measurement.

  imp: (possibly virtual) hyperstack.
  frame, channel: 1-based time point and channel.
  process: if True, apply "Mean 3D..." (x=1 y=1 z=0) followed by "Find Edges".
  roi: ROI to crop to, or None to keep the whole frame.
  The extracted frame is duplicated first, so cropping and filtering act on a copy.
  """
  frame_imp = ImagePlus("", extract_frame(imp, frame, channel)).duplicate()
  if roi is not None:
    frame_imp.setRoi(roi)
    IJ.run(frame_imp, "Crop", "")
  if not process:
    return frame_imp
  # smoothing plus edge enhancement
  IJ.run(frame_imp, "Mean 3D...", "x=1 y=1 z=0")
  IJ.run(frame_imp, "Find Edges", "stack")
  return frame_imp

def add_Point3f(p1, p2):
  """ Return a new Point3f holding the component-wise sum p1 + p2; the inputs are not modified. """
  total = Point3f(0,0,0)
  for axis in ("x", "y", "z"):
    setattr(total, axis, getattr(p1, axis) + getattr(p2, axis))
  return total

def subtract_Point3f(p1, p2):
  """ Return a new Point3f holding the component-wise difference p1 - p2; the inputs are not modified. """
  difference = Point3f(0,0,0)
  for axis in ("x", "y", "z"):
    setattr(difference, axis, getattr(p1, axis) - getattr(p2, axis))
  return difference

def shift_between_rois(roi2, roi1):
  """ Return the xy offset of roi2's bounding box relative to roi1's, as a Point3f with z = 0.
  """
  bounds2 = roi2.getBounds()
  bounds1 = roi1.getBounds()
  offset = Point3f(0,0,0)
  offset.x = bounds2.x - bounds1.x
  offset.y = bounds2.y - bounds1.y
  offset.z = 0
  return offset

def shift_roi(imp, roi, dr):
  """ Return a rectangular Roi with roi's bounding-box size, moved in x,y by dr.x and dr.y.

  If the shift would move the roi (partly) outside imp, it is moved only as far as
  possible while keeping the width and height of the input roi.
  Fractional shifts are truncated with int().
  Returns None if roi is None.
  """
  if roi is None:
    return None

  def clamped_origin(origin, delta, extent, limit):
    # keep [origin+delta, origin+delta+extent] inside [0, limit]
    if origin + delta < 0:
      return 0
    if origin + delta + extent > limit:
      return int(limit - extent)
    return origin + int(delta)

  bounds = roi.getBounds()
  new_x = clamped_origin(bounds.x, dr.x, bounds.width, imp.width)
  new_y = clamped_origin(bounds.y, dr.y, bounds.height, imp.height)
  return Roi(new_x, new_y, bounds.width, bounds.height)

def compute_and_update_frame_translations_dt(imp, channel, dt, process, shifts = None):
  """ imp contains a hyper virtual stack, and we want to compute
  the X,Y,Z translation between every t and t+dt time points in it
  using the given preferred channel.
  if shifts were already determined at other (lower) dt
  they will be used and updated.
  """
  nt = imp.getNFrames()
  # get roi (could be None)
  roi = imp.getRoi()
  if roi:
    print "ROI is at", roi.getBounds()
  # init shifts
  if shifts == None:
    shifts = []
    for t in range(nt):
      shifts.append(Point3f(0,0,0))
  # compute shifts
  IJ.showProgress(0)
  for t in range(dt, nt+dt, dt):
    if t > nt-1: # together with above range till nt+dt this ensures that the last data points are not missed out
      t = nt-1 # nt-1 is the last shift (0-based)
    IJ.log("      between frames "+str(t-dt+1)+" and "+str(t+1))
    # get (cropped and processed) image at t-dt
    roi1 = shift_roi(imp, roi, shifts[t-dt])
    imp1 = extract_frame_process_roi(imp, t+1-dt, channel, process, roi1)
    # get (cropped and processed) image at t-dt
    roi2 = shift_roi(imp, roi, shifts[t])
    imp2 = extract_frame_process_roi(imp, t+1, channel, process, roi2)
    if roi:
      print "ROI at frame",t-dt+1,"is",roi1.getBounds()
      print "ROI at frame",t+1,"is",roi2.getBounds()
    # compute shift
    local_new_shift = compute_stitch(imp2, imp1)
    if roi: # total shift is shift of rois plus measured drift
      print "correcting measured drift of",local_new_shift,"for roi shift:",shift_between_rois(roi2, roi1)
      local_new_shift = add_Point3f(local_new_shift, shift_between_rois(roi2, roi1))
    # determine the shift that we knew alrady
    local_shift = subtract_Point3f(shifts[t],shifts[t-dt])
    # compute difference between new and old measurement (which come from different dt)
    add_shift = subtract_Point3f(local_new_shift,local_shift)
    print "++ old shift between %s and %s: dx=%s, dy=%s, dz=%s" % (int(t-dt+1),int(t+1),local_shift.x,local_shift.y,local_shift.z)
    print "++ add shift between %s and %s: dx=%s, dy=%s, dz=%s" % (int(t-dt+1),int(t+1),add_shift.x,add_shift.y,add_shift.z)
    # update shifts from t-dt to the end (assuming that the measured local shift will presist till the end)
    for i,tt in enumerate(range(t-dt,nt)):
      # for i>dt below expression basically is a linear drift predicition for the frames at tt>t
      # this is only important for predicting the best shift of the ROI
      # the drifts for i>dt will be corrected by the next measurements
      shifts[tt].x += 1.0*i/dt * add_shift.x
      shifts[tt].y += 1.0*i/dt * add_shift.y
      shifts[tt].z += 1.0*i/dt * add_shift.z
      print "updated shift till frame",tt+1,"is",shifts[tt].x,shifts[tt].y,shifts[tt].z
    IJ.showProgress(1.0*t/(nt+1))

  IJ.showProgress(1)
  return shifts

def convert_shifts_to_integer(shifts):
  """ Round every component of every shift to the nearest integer.
  Returns a new list with one Point3i per input shift; the input is not modified. """
  def to_int(value):
    return int(round(value))
  return [Point3i(to_int(s.x), to_int(s.y), to_int(s.z)) for s in shifts]

def compute_min_max(shifts):
  """ Find the component-wise minimum and maximum of all shifts, i.e. the top left up
  corner and the bottom right down corner that bound the new stack.
  Expects absolute shifts.
  Returns (minx, miny, minz, maxx, maxy, maxz); for an empty list the minima are
  Integer.MAX_VALUE and the maxima -Integer.MAX_VALUE. """
  lows = [Integer.MAX_VALUE] * 3
  highs = [-Integer.MAX_VALUE] * 3
  for shift in shifts:
    for axis, value in enumerate((shift.x, shift.y, shift.z)):
      lows[axis] = min(lows[axis], value)
      highs[axis] = max(highs[axis], value)
  return lows[0], lows[1], lows[2], highs[0], highs[1], highs[2]

def zero_pad(num, digits):
  """ Left-pad str(num) with '0' up to `digits` characters: for 34, 4 --> '0034'.
  Strings that are already long enough are returned unchanged. """
  return str(num).rjust(digits, '0')

def invert_shifts(shifts):
  """ Negate x, y and z of every shift in place so that measured drifts become corrections.
  Returns the same (mutated) list.
  """
  for shift in shifts:
    shift.x, shift.y, shift.z = -shift.x, -shift.y, -shift.z
  return shifts

def register_hyperstack(imp, channel, shifts, target_folder, virtual):
  """ Takes the imp, determines the x,y,z drift for each pair of time points, using the preferred given channel,
  and outputs as a hyperstack."""
  # Compute bounds of the new volume,
  # which accounts for all translations:
  minx, miny, minz, maxx, maxy, maxz = compute_min_max(shifts)
  # Make shifts relative to new canvas dimensions
  # so that the min values become 0,0,0
  for shift in shifts:
    shift.x -= minx
    shift.y -= miny
    shift.z -= minz
  #print "shifts relative to new dimensions:"
  #for s in shifts:
  #  print s.x, s.y, s.z
  # new canvas dimensions:r
  width = imp.width + maxx - minx
  height = maxy - miny + imp.height
  slices = maxz - minz + imp.getNSlices()

  print "New dimensions:", width, height, slices
  # Prepare empty slice to pad in Z when necessary
  empty = imp.getProcessor().createProcessor(width, height)

  # if it's RGB, fill the empty slice with blackness
  if isinstance(empty, ColorProcessor):
    empty.setValue(0)
    empty.fill()
  # Write all slices to files:
  stack = imp.getStack()

  if virtual is False:
    registeredstack = ImageStack(width, height, imp.getProcessor().getColorModel())
  names = []

  for frame in range(1, imp.getNFrames()+1):

    shift = shifts[frame-1]

    print "frame",frame,"correcting drift",-shift.x-minx,-shift.y-miny,-shift.z-minz
    IJ.log("    frame "+str(frame)+" correcting drift "+str(-shift.x-minx)+","+str(-shift.y-miny)+","+str(-shift.z-minz))

    fr = "t" + zero_pad(frame, len(str(imp.getNFrames())))
    # Pad with empty slices before reaching the first slice
    for s in range(shift.z):
      ss = "_z" + zero_pad(s + 1, len(str(slices))) # slices start at 1
      for ch in range(1, imp.getNChannels()+1):
        name = fr + ss + "_c" + zero_pad(ch, len(str(imp.getNChannels()))) +".tif"
        names.append(name)

        if virtual is True:
          currentslice = ImagePlus("", empty)
          currentslice.setCalibration(imp.getCalibration().copy())
          currentslice.setProperty("Info", imp.getProperty("Info"))
          FileSaver(currentslice).saveAsTiff(target_folder + "/" + name)
        else:
          empty = imp.getProcessor().createProcessor(width, height)
          registeredstack.addSlice(str(name), empty)

    # Add all proper slices
    stack = imp.getStack()
    for s in range(1, imp.getNSlices()+1):
      ss = "_z" + zero_pad(s + shift.z, len(str(slices)))
      for ch in range(1, imp.getNChannels()+1):
         ip = stack.getProcessor(imp.getStackIndex(ch, s, frame))
         ip2 = ip.createProcessor(width, height) # potentially larger
         ip2.insert(ip, shift.x, shift.y)
         name = fr + ss + "_c" + zero_pad(ch, len(str(imp.getNChannels()))) +".tif"
         names.append(name)

         if virtual is True:
           currentslice = ImagePlus("", ip2)
           currentslice.setCalibration(imp.getCalibration().copy())
           currentslice.setProperty("Info", imp.getProperty("Info"));
           FileSaver(currentslice).saveAsTiff(target_folder + "/" + name)
         else:
           registeredstack.addSlice(str(name), ip2)

    # Pad the end
    for s in range(shift.z + imp.getNSlices(), slices):
      ss = "_z" + zero_pad(s + 1, len(str(slices)))
      for ch in range(1, imp.getNChannels()+1):
        name = fr + ss + "_c" + zero_pad(ch, len(str(imp.getNChannels()))) +".tif"
        names.append(name)

        if virtual is True:
          currentslice = ImagePlus("", empty)
          currentslice.setCalibration(imp.getCalibration().copy())
          currentslice.setProperty("Info", imp.getProperty("Info"))
          FileSaver(currentslice).saveAsTiff(target_folder + "/" + name)
        else:
          registeredstack.addSlice(str(name), empty)

  if virtual is True:
      # Create virtual hyper stack with the result
      registeredstack = VirtualStack(width, height, None, target_folder)
      for name in names:
        registeredstack.addSlice(name)
      registeredstack_imp = ImagePlus("registered time points", registeredstack)
      registeredstack_imp.setDimensions(imp.getNChannels(), len(names) / (imp.getNChannels() * imp.getNFrames()), imp.getNFrames())
      registeredstack_imp.setCalibration(imp.getCalibration().copy())
      registeredstack_imp.setOpenAsHyperStack(True)
  else:
    registeredstack_imp = ImagePlus("registered time points", registeredstack)
    registeredstack_imp.setCalibration(imp.getCalibration().copy())
    registeredstack_imp.setProperty("Info", imp.getProperty("Info"))
    registeredstack_imp.setDimensions(imp.getNChannels(), len(names) / (imp.getNChannels() * imp.getNFrames()), imp.getNFrames())
    registeredstack_imp.setOpenAsHyperStack(True)
    if 1 == registeredstack_imp.getNChannels():
      return registeredstack_imp
  #IJ.log("\nHyperstack dimensions: time frames:" + str(registeredstack_imp.getNFrames()) + ", slices: " + str(registeredstack_imp.getNSlices()) + ", channels: " + str(registeredstack_imp.getNChannels()))

  # Else, as composite
  mode = CompositeImage.COLOR;
  if isinstance(imp, CompositeImage):
    mode = imp.getMode()
  else:
    return registeredstack_imp
  return CompositeImage(registeredstack_imp, mode)

def register_hyperstack_subpixel(imp, channel, shifts, target_folder, virtual):
  """ Takes the imp, determines the x,y,z drift for each pair of time points, using the preferred given channel,
  and outputs as a hyperstack.
  The shifted image is computed using TransformJ allowing for sub-pixel shifts using interpolation.
  This is quite a bit slower than just shifting the image by full pixels as done in above function register_hyperstack().
  However it significantly improves the result by removing pixel jitter.
  """
  # Compute bounds of the new volume,
  # which accounts for all translations:
  minx, miny, minz, maxx, maxy, maxz = compute_min_max(shifts)
  # Make shifts relative to new canvas dimensions
  # so that the min values become 0,0,0
  for shift in shifts:
    shift.x -= minx
    shift.y -= miny
    shift.z -= minz
  # new canvas dimensions:
  width = int(imp.width + maxx - minx)
  height = int(maxy - miny + imp.height)
  slices = int(maxz - minz + imp.getNSlices())

  print "New dimensions:", width, height, slices

  # prepare stack for final results
  stack = imp.getStack()
  if virtual is True:
    names = []
  else:
    registeredstack = ImageStack(width, height, imp.getProcessor().getColorModel())

  # prepare empty slice for padding
  empty = imp.getProcessor().createProcessor(width, height)

  IJ.showProgress(0)

  # get raw data as stack
  stack = imp.getStack()

  # loop across frames
  for frame in range(1, imp.getNFrames()+1):

    IJ.showProgress(frame / float(imp.getNFrames()+1))
    fr = "t" + zero_pad(frame, len(str(imp.getNFrames()))) # for saving files in a virtual stack

    # get and report current shift
    shift = shifts[frame-1]
    print "frame",frame,"correcting drift",-shift.x-minx,-shift.y-miny,-shift.z-minz
    IJ.log("    frame "+str(frame)+" correcting drift "+str(round(-shift.x-minx,2))+","+str(round(-shift.y-miny,2))+","+str(round(-shift.z-minz,2)))

    # loop across channels
    for ch in range(1, imp.getNChannels()+1):

      tmpstack = ImageStack(width, height, imp.getProcessor().getColorModel())

      # get all slices of this channel and frame
      for s in range(1, imp.getNSlices()+1):
        ip = stack.getProcessor(imp.getStackIndex(ch, s, frame))
        ip2 = ip.createProcessor(width, height) # potentially larger
        ip2.insert(ip, 0, 0)
        tmpstack.addSlice("", ip2)

      # Pad the end (in z) of this channel and frame
      for s in range(imp.getNSlices(), slices):
        tmpstack.addSlice("", empty)

      # subpixel translation
      imp_tmpstack = ImagePlus("", tmpstack)
      imp_translated = translate_single_stack_using_imglib2(imp_tmpstack, shift.x, shift.y, shift.z)

      # Add translated frame to final time-series
      translated_stack = imp_translated.getStack()
      for s in range(1, translated_stack.getSize()+1):
        ss = "_z" + zero_pad(s, len(str(slices)))
        ip = translated_stack.getProcessor(s).duplicate() # duplicate is important as otherwise it will only be a reference that can change its content
        if virtual is True:
          name = fr + ss + "_c" + zero_pad(ch, len(str(imp.getNChannels()))) +".tif"
          names.append(name)
          currentslice = ImagePlus("", ip)
          currentslice.setCalibration(imp.getCalibration().copy())
          currentslice.setProperty("Info", imp.getProperty("Info"));
          FileSaver(currentslice).saveAsTiff(target_folder + "/" + name)
        else:
          registeredstack.addSlice("", ip)

  IJ.showProgress(1)

  if virtual is True:
    # Create virtual hyper stack with the result
    registeredstack = VirtualStack(width, height, None, target_folder)
    for name in names:
      registeredstack.addSlice(name)
    registeredstack_imp = ImagePlus("registered time points", registeredstack)
    registeredstack_imp.setDimensions(imp.getNChannels(), slices, imp.getNFrames())
    registeredstack_imp.setCalibration(imp.getCalibration().copy())
    registeredstack_imp.setOpenAsHyperStack(True)
  else:
    registeredstack_imp = ImagePlus("registered time points", registeredstack)
    registeredstack_imp.setCalibration(imp.getCalibration().copy())
    registeredstack_imp.setProperty("Info", imp.getProperty("Info"))
    registeredstack_imp.setDimensions(imp.getNChannels(), slices, imp.getNFrames())
    registeredstack_imp.setOpenAsHyperStack(True)
    if 1 == registeredstack_imp.getNChannels():
      return registeredstack_imp

  #IJ.log("\nHyperstack dimensions: time frames:" + str(registeredstack_imp.getNFrames()) + ", slices: " + str(registeredstack_imp.getNSlices()) + ", channels: " + str(registeredstack_imp.getNChannels()))

  # Else, as composite
  mode = CompositeImage.COLOR;
  if isinstance(imp, CompositeImage):
    mode = imp.getMode()
  else:
    return registeredstack_imp
  return CompositeImage(registeredstack_imp, mode)

class Filter(FilenameFilter):
  """ FilenameFilter that accepts only entries that are not hidden. """
  def accept(self, folder, name):
    candidate = File(folder.getAbsolutePath() + "/" + name)
    hidden = candidate.isHidden()
    return not hidden

def validate(target_folder):
  """ Ask for confirmation before writing into a folder that already has content.

  target_folder: path of the folder chosen for the virtual-stack output.
  Returns True if the folder has no visible (non-hidden) entries or the user confirms
  overwriting, False otherwise.
  """
  entries = File(target_folder).list(Filter())
  # File.list() returns null (None) if the path is not a readable directory;
  # previously this raised a TypeError in len()
  if entries is None or len(entries) == 0:
    return True
  yn = YesNoCancelDialog(IJ.getInstance(), "Warning!", "Target folder is not empty! May overwrite files! Continue?")
  return bool(yn.yesPressed())

def getOptions(imp):
  """ Show the options dialog.

  imp: the image to register; its channel count defines the channel choices.
  Returns (channel, virtual, multi_time_scale, subpixel, process) with a 1-based
  channel, or None if the dialog was canceled.
  """
  gd = GenericDialog("Correct 3D Drift Options")
  channels = [str(ch) for ch in range(1, imp.getNChannels()+1)]
  gd.addChoice("Channel for registration:", channels, channels[0])
  gd.addCheckbox("Multi_time_scale computation for enhanced detection of slow drifts?", False)
  gd.addCheckbox("Sub_pixel drift correction (possibly needed for slow drifts)?", False)
  gd.addCheckbox("Edge_enhance images for possibly improved drift detection?", False)
  gd.addCheckbox("Use virtualstack for saving the results to disk to save RAM?", False)
  gd.addMessage("If you put a ROI, drift will only be computed in this region;\n the ROI will be moved along with the drift to follow your structure of interest.")
  gd.showDialog()
  if gd.wasCanceled():
    return None
  channel = gd.getNextChoiceIndex() + 1  # choice index is zero-based, channels are one-based
  # read the checkboxes in the order in which they were added
  multi_time_scale = gd.getNextBoolean()
  subpixel = gd.getNextBoolean()
  process = gd.getNextBoolean()
  virtual = gd.getNextBoolean()
  # (the dialog has no numeric field, so nothing is read with getNextNumber())
  return channel, virtual, multi_time_scale, subpixel, process

# TODO: add a function that gets the color of each channel: loop over the channels, extract each color model and apply it to the registered image.

def run():

  IJ.log("Correct_3D_Drift")

  imp = IJ.getImage()
  if imp is None:
    return
  #if not imp.isHyperStack():
  #  print "Not a hyper stack!"
  #  return
  if 1 == imp.getNFrames():
    print "There is only one time frame!"
    return
  #if 1 == imp.getNSlices():
  #  print "To register slices of a stack, use 'Register Virtual Stack Slices'"
  #  return

  options = getOptions(imp)
  if options is not None:
    channel, virtual, multi_time_scale, subpixel, process = options
    print "channel="+str(channel)
    print "multi_time_scale="+str(multi_time_scale)
    print "virtual="+str(virtual)
    print "process="+str(process)

  if virtual is True:
    dc = DirectoryChooser("Choose target folder to save image sequence")
    target_folder = dc.getDirectory()
    if target_folder is None:
      return # user canceled the dialog
    if not validate(target_folder):
      return
  else:
    target_folder = None

  # compute shifts
  IJ.log("  computing drifts..."); print("\nCOMPUTING SHIFTS:")

  IJ.log("    at frame shifts of 1");
  dt = 1; shifts = compute_and_update_frame_translations_dt(imp, channel, dt, process)

  # multi-time-scale computation
  if multi_time_scale is True:
    dt_max = imp.getNFrames()-1
    # computing drifts on exponentially increasing time scales 3^i up to 3^6
    # ..one could also do this with 2^i or 4^i
    # ..maybe make this a user choice? did not do this to keep it simple.
    dts = [3,9,27,81,243,729,dt_max]
    for dt in dts:
      if dt < dt_max:
        IJ.log("    at frame shifts of "+str(dt))
        shifts = compute_and_update_frame_translations_dt(imp, channel, dt, process, shifts)
      else:
        IJ.log("    at frame shifts of "+str(dt_max));
        shifts = compute_and_update_frame_translations_dt(imp, channel, dt_max, process, shifts)
        break

  # invert measured shifts to make them the correction
  shifts = invert_shifts(shifts)

  # apply shifts
  IJ.log("  applying shifts..."); print("\nAPPLYING SHIFTS:")
  if subpixel:
    registered_imp = register_hyperstack_subpixel(imp, channel, shifts, target_folder, virtual)
  else:
    shifts = convert_shifts_to_integer(shifts)
    registered_imp = register_hyperstack(imp, channel, shifts, target_folder, virtual)

  if virtual is True:
    if 1 == imp.getNChannels():
      ip=imp.getProcessor()
      ip2=registered_imp.getProcessor()
      ip2.setColorModel(ip.getCurrentColorModel())
      registered_imp.show()
    else:
      registered_imp.copyLuts(imp)
      registered_imp.show()
  else:
    if 1 ==imp.getNChannels():
      registered_imp.show()
    else:
      registered_imp.copyLuts(imp)
      registered_imp.show()

  registered_imp.show()

# run the drift correction on the active image when this script is executed
run()
