import os
from ij import IJ, WindowManager
from ij.io import OpenDialog
from ij.gui import GenericDialog


# Stitches a grid of MetaMorph tile images with Fiji's "Grid/Collection
# stitching" plugin, once per channel and timepoint, and collects the fused
# images of each channel into one stack.
#
# Workflow:
#   1. Ask for a file inside the data directory and for the grid dimensions.
#   2. Derive the base name and the timepoints from the '*.scan' files.
#   3. Zero-pad the '_s<index>' part of the tile names so every index has the
#      same number of digits (the stitching file pattern needs fixed width).
#   4. Run the stitching plugin for every channel/timepoint combination.

SCAN_SUFFIX = '.scan'
STITCHED_DIR = 'stitched'


def _split_trailing_digits(name):
    """Split ``name`` into ``(prefix, digits)``.

    ``digits`` is the longest all-digit suffix of ``name`` (possibly empty).
    The scan is bounded at index 0, so an all-digit or empty name is safe.
    """
    cut = len(name)
    while cut > 0 and name[cut - 1].isdigit():
        cut -= 1
    return name[:cut], name[cut:]


def _collect_timepoints(filenames):
    """Return ``(base_name, timepoints)`` parsed from the '.scan' files.

    The base name is the stem of the first '.scan' file with its trailing
    digits removed; ``base_name`` is None when there is no '.scan' file.
    Timepoints are returned as the digit strings found in the file names
    (leading zeros kept, so rebuilt file names match the ones on disk),
    sorted numerically.
    """
    base_name = None
    timepoints = []
    for filename in filenames:
        if not filename.endswith(SCAN_SUFFIX):
            continue
        # Slice the suffix off; str.rstrip('.scan') would strip any trailing
        # run of the characters '.', 's', 'c', 'a', 'n' instead.
        stem = filename[:-len(SCAN_SUFFIX)]
        if base_name is None:
            base_name = _split_trailing_digits(stem)[0]
        timepoint = stem[len(base_name):]
        if stem.startswith(base_name) and timepoint.isdigit():
            timepoints.append(timepoint)
        else:
            print('Skipping unrecognised scan file: %s' % filename)
    timepoints.sort(key=int)
    return base_name, timepoints


def _padded_tile_name(filename, digits, area):
    """Return the new name for a tile file, or None to leave it alone.

    Only names of the form '<head>_s<index>_<rest>TIF' are considered:
      * an index shorter than ``digits`` is left-padded with zeros;
      * the file whose index equals ``area + 1`` gets a 'Thumb_' prefix
        (unless 'Thumb' already occurs in its name).
    The last '_s' in the name is used, so a base name that itself contains
    '_s' is kept intact.
    """
    if not filename.endswith('TIF'):
        return None
    head, marker, tail = filename.rpartition('_s')
    if not marker:
        return None
    index_len = tail.find('_')
    if index_len < 1 or not tail[:index_len].isdigit():
        # No '_<digits>_' after the marker: not a tile index.
        return None
    if index_len < digits:
        return head + marker + '0' * (digits - index_len) + tail
    if tail[:index_len] == str(area + 1) and 'Thumb' not in filename:
        return 'Thumb_' + filename
    return None


def _stitching_options(directory, file_pattern, rows, cols):
    """Build the option string passed to the stitching plugin."""
    return ''.join([
        'type=[Grid: column-by-column] ',
        'order=[Down & Right                ] ',
        'grid_size_x=', str(cols),
        ' grid_size_y=', str(rows), ' tile_overlap=10 ',
        'first_file_index_i=1 directory=', directory,
        ' file_names=', file_pattern,
        # os.path.join yields 'stitched\TileConfig.txt' on Windows, and a
        # usable relative path elsewhere.
        ' output_textfile_name=',
        os.path.join(STITCHED_DIR, 'TileConfig.txt'), ' ',
        'fusion_method=[Linear Blending] ',
        'regression_threshold=0.30 ',
        'max/avg_displacement_threshold=2.50 ',
        'absolute_displacement_threshold=3.50 ',
        'computation_parameters=[Save computation time ',
        '(but use more RAM)] ',
        'image_output=[Fuse and display]',
    ])


def _ask_dimensions():
    """Show the dimensions dialog.

    Returns ``(rows, cols, channels)`` as positive ints, or None when the
    dialog was not OKed or the input is invalid.
    """
    gd = GenericDialog('Dimensions')
    gd.addStringField('Rows:', '4')
    gd.addStringField('Columns:', '4')
    gd.addStringField('Channels', '2')
    gd.showDialog()
    if not gd.wasOKed():
        return None

    # Fields are read back in the order they were added.
    try:
        rows, cols, channels = [int(gd.getNextString()) for _ in range(3)]
    except ValueError:
        print('Rows, columns and channels must be whole numbers.')
        return None
    if rows < 1 or cols < 1 or channels < 1:
        print('Rows, columns and channels must all be at least 1.')
        return None
    return rows, cols, channels


def _stitch_channel(directory, base_name, timepoints, channel,
                    rows, cols, digits):
    """Stitch every timepoint of one channel into a single image stack.

    The first fused image stays open and receives the slices of all later
    timepoints; the windows of the later fused images are closed.
    """
    stack = None
    stack_image = None
    for timepoint in timepoints:
        # '{ii...}' is the plugin's fixed-width tile index placeholder.
        file_pattern = (base_name + timepoint + '_w' + str(channel) +
                        '_s{' + 'i' * digits + '}_t1.TIF')
        IJ.run("Grid/Collection stitching",
               _stitching_options(directory, file_pattern, rows, cols))

        image = WindowManager.getCurrentImage()
        # No current image, or still the image we are collecting into:
        # presumably the plugin produced no output for this pattern.
        if image is None or image == stack_image:
            print('No stitched image for: %s' % file_pattern)
            continue

        if stack is None:
            stack_image = image
            stack = image.getStack()
        else:
            stack.addSlice(image.getProcessor())
            window = image.getWindow()
            if window is not None:
                window.close()

    if stack_image is not None:
        stack_image.setStack(stack)


def main():
    """Run the interactive rename-and-stitch workflow."""
    dialog = OpenDialog("Choose a file in the directory")
    dialog.setDefaultDirectory(dialog.getLastDirectory())
    path = dialog.getDirectory()
    if path is None:
        # No directory available (e.g. the dialog was cancelled).
        return
    if ' ' in path:
        # The directory is embedded unquoted in the plugin's option string.
        print('\n***Please remove spaces from the file path:***\n***'
              + path + '***')
        return

    stitched_dir = os.path.join(path, STITCHED_DIR)
    try:
        os.mkdir(stitched_dir)
    except OSError:
        # An already existing directory is fine; anything else is reported
        # but not fatal, matching the original best-effort behaviour.
        if not os.path.isdir(stitched_dir):
            print('Could not create directory: %s' % stitched_dir)

    dimensions = _ask_dimensions()
    if dimensions is None:
        return
    rows, cols, channels = dimensions
    area = rows * cols
    digits = len(str(area))

    filenames = os.listdir(path)
    base_name, timepoints = _collect_timepoints(filenames)
    for filename in filenames:
        new_name = _padded_tile_name(filename, digits, area)
        if new_name is not None:
            os.rename(os.path.join(path, filename),
                      os.path.join(path, new_name))

    print('Timepoints: %s' % (timepoints,))
    print('Base name: %s' % base_name)
    if not timepoints:
        print('No usable %s files found in: %s' % (SCAN_SUFFIX, path))
        return

    # The plugin is given the directory without its trailing separator.
    if path.endswith('/') or path.endswith('\\'):
        directory = path[:-1]
    else:
        directory = path

    for channel in range(1, channels + 1):
        _stitch_channel(directory, base_name, timepoints, channel,
                        rows, cols, digits)


# Called unconditionally: the original script ran its body at load time.
main()
