from ij import IJ, ImagePlus, WindowManager
from ij.plugin import ImageCalculator
import os
from ij.process import FloatProcessor
from ij.io import OpenDialog


path = 'images\\loaner_calib\\'
calib_files = os.listdir(path)

print calib_files

for img in calib_files:
    if img[0] == 'a':
        a = IJ.openImage(path + img)
    elif img[0] == 'n':
        b = IJ.openImage(path + img)
    elif img[0] == 'y':
        y0 = IJ.openImage(path + img)
    elif img[0] == 'm':
        m = IJ.openImage(path + img)
    else:
        raise Exception('Calibration data not found')


def do_correction(image_path, chans='3'):
    img = IJ.openImage(image_path)
    # img.show()
    calc = ImageCalculator()

    img_stack = img.getStack().duplicate()
    img_stack.deleteLastSlice()
    x = ImagePlus('x', img_stack)
    print 'Doing m*x'
    mx = calc.run("Multiply create 32-bit stack", x, m)
    # mx.show()
    print 'Doing m*x + y0 (+ a)'
    calc.run("Add 32-bit stack", mx, y0)
    print 'Doing b*x'
    bx = calc.run("Multiply create 32-bit stack", x, b)
    x.close()
    IJ.freeMemory()
    # x.show()
    print 'Doing exp'
    WindowManager.setTempCurrentImage(bx)
    IJ.run("Exp", "stack")
    print 'Doing a*exp(b*x)'
    calc.run("Multiply 32-bit stack", bx, a)
    # bx.show()
    print 'Doing mx + y0 + a - a*exp(bx)'
    mx = calc.run("Subtract create 32-bit stack", mx, bx)
    bx.close()
    IJ.freeMemory()
    # mx.show()
    print 'Doing add blank'
    corr_stack = mx.getStack()
    blank = FloatProcessor(mx.getWidth(), mx.getHeight())
    corr_stack.addSlice('', blank, 0)
    mx.setStack(corr_stack)
    mx.show()
    print 'Doing correction'
    calc.run("Subtract stack", img, mx)
    mx.close()
    IJ.freeMemory()
    img.show()
    # IJ.run("Deinterleave", "how=" + chans)
    # imgs = WindowManager.getImageTitles()
    # for i in xrange(int(chans)):
    #     for title in imgs:
    #         if '#' + str(i + 1) in title:
    #             WindowManager.setTempCurrentImage(
    #                 WindowManager.getImage(title))
    #             IJ.saveAsTiff(WindowManager.getImage(title),
    #                           image_path.rstrip('.tif') + '_' +
    #                           str(i + 1) + '.tif')
    #             WindowManager.getTempCurrentImage().close()

    IJ.freeMemory()
    print '\nDone\n'


od = OpenDialog('Select Image')
img_path = od.getPath()
do_correction(img_path)
