from ij import WindowManager, IJ
from ij.gui import GenericDialog
from ij.plugin.frame import RoiManager, Fitter
from ij.measure import Measurements, CurveFitter


def trimStack(image, start):
    """Remove every slice that precedes slice number ``start`` from ``image``.

    Slices are numbered from 1, so ``start - 1`` slices are dropped from the
    front of the stack; a ``start`` of 1 or less removes nothing. The stack
    object obtained from ``image.getStack()`` is edited in place and returned
    so the caller can pass it straight to ``image.setStack``.
    """
    imageStack = image.getStack()
    for _ in range(start - 1):
        # Deleting slice 1 shifts the remaining slices down, so the next
        # unwanted slice is always at position 1 again.
        imageStack.deleteSlice(1)
    return imageStack

def _measureAreas(image, roi, lower=127, upper=255):
    """Return the thresholded area inside ``roi`` for each time point of ``image``.

    The time axis is the frame dimension when ``image`` has more than one
    frame, otherwise the slice dimension. For every position the threshold
    [lower, upper] is applied to the current processor and the area of the
    pixels inside that threshold (Measurements.LIMIT) is collected.

    Returns a list of floats, one per frame/slice, in order.
    """
    image.setRoi(roi)
    # Measurement flags are bit masks, so combine them with bitwise OR.
    measurements = Measurements.AREA | Measurements.LIMIT
    count = image.getNFrames()
    if count <= 1:
        count = image.getNSlices()

    areas = []
    for i in range(1, count + 1):
        image.setSliceWithoutUpdate(i)
        # Re-apply per slice so the LIMIT measurement never depends on the
        # threshold surviving a slice change. 0 == ImageProcessor.RED_LUT.
        image.getProcessor().setThreshold(lower, upper, 0)
        areas.append(float(image.getStatistics(measurements).area))
    return areas


def main():
    """Fit a curve to the thresholded area of the fused result over time.

    Steps:
      1. Split the active image into per-channel windows.
      2. Ask which windows are the green and red channels, the first image
         to keep for the fit, and the fit model (exponential or linear).
      3. Validate every input before touching any image, then drop the
         images before the chosen start from both channels, rename them
         'Green'/'Red' and apply the last ROI in the ROI Manager list.
      4. Run ``Auto-Fusion_macro.ijm`` and measure the thresholded area
         (127-255) inside the ROI on every time point of the image titled
         "Result of Green".
      5. Fit area against time (frame interval if calibrated, else frame
         index), print the formula and first two parameters, and plot it.

    Problems with the inputs are reported via ``IJ.error`` and abort the run;
    cancelling the dialog prints 'Cancelled'.
    """
    title = 'Select Green and Red images'

    IJ.run('Split Channels')
    imageList = list(WindowManager.getImageTitles())
    if len(imageList) < 2:
        IJ.error(title, 'At least two open images are needed after Split Channels.')
        return

    g = GenericDialog(title)
    g.addChoice('Green:', imageList, imageList[0])
    g.addChoice('Red:', imageList, imageList[1])
    g.addStringField('Fit start image #:', '7')
    g.addChoice('Fit type', ['Exponential', 'Linear'], 'Exponential')
    g.centerDialog(True)
    g.showDialog()
    if not g.wasOKed():
        print('Cancelled')
        return

    # GenericDialog hands values back in the order the fields were added.
    greenName = g.getNextChoice()
    redName = g.getNextChoice()
    fitName = g.getNextChoice()
    startText = g.getNextString()

    try:
        start = int(startText.strip())
    except ValueError:
        IJ.error(title, 'Fit start image # must be a whole number, got: ' + startText)
        return
    if greenName == redName:
        # Trimming the same image twice would delete 2 * (start - 1) slices.
        IJ.error(title, 'Green and Red must be different images.')
        return

    green = WindowManager.getImage(greenName)
    red = WindowManager.getImage(redName)
    # trimStack removes start - 1 slices; at least one slice must remain.
    if start > min(green.getStackSize(), red.getStackSize()):
        IJ.error(title, 'Fit start image # exceeds the number of images in the stack.')
        return

    manager = RoiManager.getInstance()
    rois = manager.getRoisAsArray() if manager is not None else []
    if len(rois) == 0:
        IJ.error(title, 'Add the region to analyse to the ROI Manager first.')
        return
    roi = rois[-1]  # the last ROI in the ROI Manager list

    # All inputs are valid: now it is safe to modify the images.
    green.setStack(trimStack(green, start))
    red.setStack(trimStack(red, start))
    green.setTitle('Green')
    red.setTitle('Red')
    green.updateAndDraw()
    red.updateAndDraw()
    green.setRoi(roi)
    red.setRoi(roi)

    IJ.runMacroFile("Auto-Fusion_macro.ijm")

    # The macro is expected to leave an image titled "Result of Green".
    result = WindowManager.getImage("Result of Green")
    if result is None:
        IJ.error(title, 'Auto-Fusion_macro.ijm did not produce a "Result of Green" image.')
        return

    areas = _measureAreas(result, roi)
    if len(areas) < 2:
        IJ.error(title, 'At least two time points are needed for a fit.')
        return

    # Use the calibrated frame interval when present, else frame indices.
    interval = green.getCalibration().frameInterval
    if not interval > 0:
        interval = 1.0
    times = [interval * t for t in range(len(areas))]

    fitCodes = {'Exponential': CurveFitter.EXPONENTIAL,
                'Linear': CurveFitter.STRAIGHT_LINE}
    fitter = CurveFitter(times, areas)
    fitter.doFit(fitCodes[fitName])
    params = fitter.getParams()
    print(fitter.getFormula())
    print(params[0])
    print(params[1])
    print('')
    # Fitter.plot is static: no need to build (and then close) a Curve
    # Fitter window just to draw the plot.
    Fitter.plot(fitter)


main()
