#@INPUT  File(label="Working Directory", style="directory", description="Select the directory for the experiment to analyze. It should contain red and/or green channel and cell mask images in a subdirectory for each plate.") expdir
#@INPUT  String(label="Channels", choices={"Red Channel Only", "Green Channel Only", "Red and Green Channels"}, style="listBox", description="Select which channel(s) to analyze.") channels
#@INPUT  String(label="Red Channel Prefix", value="RC", description="The prefix that was used to mark the red channel images when exporting them from the IncuCyte software.") rcpre
#@INPUT  String(label="Green Channel Prefix", value="GC", description="The prefix that was used to mark the green channel images when exporting them from the IncuCyte software.") gcpre
#@INPUT  String(label="Cell Mask Prefix", value="CM", description="The prefix that was used to mark the cell mask images when exporting them from the IncuCyte software.") cmpre
#@       OpService ops
#@       DatasetIOService ds
#@       LogService log

import csv
import os
import sys
from datetime import datetime

from net.imglib2.type.numeric.integer import UnsignedByteType
from net.imglib2.util import Intervals

# Fluorescence / confluence analysis script.
#
# For every leaf directory (one per plate) below the working directory, each
# cell mask image is paired with the red and/or green channel image that
# shares its base name. The fluorescence inside the mask is measured and one
# CSV row per image and channel is written to
# <expdir>/<basename of expdir>.csv.
#
# The names expdir, channels, rcpre, gcpre, cmpre, ops, ds and log are
# supplied by the script parameters declared at the top of the file.

RED_CHANNEL = "Red Channel Only"
GREEN_CHANNEL = "Green Channel Only"

imext = ".tif"

# File names are "<prefix>_<base name><imext>".
rcpre = rcpre + "_"
gcpre = gcpre + "_"
cmpre = cmpre + "_"

expdir = str(expdir)

if not os.path.isdir(expdir):
    raise Exception("Working directory %s does not exist" % expdir)

# (file prefix, channel label written to the CSV) for every selected channel.
if channels == RED_CHANNEL:
    selected = [(rcpre, "R")]
elif channels == GREEN_CHANNEL:
    selected = [(gcpre, "G")]
else:
    selected = [(rcpre, "R"), (gcpre, "G")]


def _basenames(files, prefix):
    """Return the sorted, de-duplicated base names of the image files in
    `files` that start with `prefix`, with prefix and extension removed."""
    return sorted(set(
        f[len(prefix):-len(imext)]
        for f in files
        if f.endswith(imext) and f.startswith(prefix)
    ))


def _parse_basename(bn):
    """Split a base name of the form <well>_<field>_<date>_<time> into
    (well, field, timepoint).

    The column part of the well is zero-padded to two digits ("A1" -> "A01")
    and the timepoint is returned as an ISO 8601 string. Raises IndexError or
    ValueError if the name does not follow that layout.
    """
    info = bn.split("_")
    well = info[0][0] + info[0][1:].zfill(2)
    field = info[1]
    timepoint = datetime.strptime(info[2] + info[3], "%Yy%mm%dd%Hh%Mm").isoformat()
    return well, field, timepoint


def _measure(root, chpre, bn):
    """Open the channel image and the cell mask for `bn` in directory `root`
    and return (tcf, mcf, cma, ima, confluence).

    Raises Exception if the two images differ in size or do not have the
    expected pixel types; errors from the image services propagate unchanged.
    """
    # `root` comes from os.walk(expdir) and already contains expdir.
    data = ds.open(os.path.join(root, chpre + bn + imext))
    mask = ds.open(os.path.join(root, cmpre + bn + imext))

    if not Intervals.equalDimensions(data, mask):
        raise Exception("Size of images do not match.")

    if data.getTypeLabelShort() != "32-bit float":
        raise Exception("Red or green channels should be exported as 32bit TIFF images.")
    if mask.getTypeLabelShort() != "8-bit uint":
        raise Exception("Cell mask should be exported as an 8bit TIFF image.")

    # Binarize the mask with a threshold of 1 and convert it to float32 so it
    # can be multiplied with the channel image.
    mask = ops.run("threshold.apply", mask, UnsignedByteType(1))
    mask = ops.run("convert.float32", mask)

    data = ops.run("math.multiply", mask, data)

    tcf = ops.run("stats.sum", data).get()  # integrated intensity of pixels in cell mask
    cma = ops.run("stats.sum", mask).get()  # cell mask area (in pixels)
    # NOTE(review): for an empty mask cma is presumably 0.0, so this division
    # raises ZeroDivisionError and the image is skipped and logged instead of
    # being recorded with zero confluence -- confirm that this is intended.
    mcf = tcf / cma  # average intensity of pixels in cell mask
    ima = mask.getWidth() * mask.getHeight()  # image area (in pixels)
    confluence = cma / ima

    return tcf, mcf, cma, ima, confluence


experiment = os.path.basename(expdir)

# The Python 2 / Jython csv module requires a file opened in binary mode;
# in text mode the "\r\n" row terminator can be translated a second time
# (giving "\r\r\n" on Windows).
with open(os.path.join(expdir, experiment + ".csv"), "wb") as csvfile:

    output = csv.writer(csvfile)

    output.writerow(["experiment", "plate", "well", "field", "timepoint", "channel", "tcf", "mcf", "cma", "ima", "confluence"])

    for root, dirs, files in os.walk(expdir):
        if dirs:
            # Only directories without subdirectories are treated as plates.
            continue

        cms = _basenames(files, cmpre)

        # Every selected channel must provide exactly the same set of images
        # as the cell mask.
        if any(_basenames(files, chpre) != cms for chpre, _ in selected):
            raise Exception("Not all necessary images found in %s" % root)

        if not cms:
            raise Exception("No images found in %s" % root)

        plate = os.path.basename(root)

        for bn in cms:
            # Reset per image so the error log never reports a stale channel.
            channel = None
            try:
                well, field, timepoint = _parse_basename(bn)

                for chpre, channel in selected:
                    tcf, mcf, cma, ima, confluence = _measure(root, chpre, bn)

                    log.info("experiment = %s, plate = %s, well = %s, field = %s, timepoint = %s, channel = %s, tcf = %s, mcf = %s, cma = %s, ima = %s, confluence = %s" % (experiment, plate, well, field, timepoint, channel, tcf, mcf, cma, ima, confluence))

                    output.writerow([experiment, plate, well, field, timepoint, channel, tcf, mcf, cma, ima, confluence])
            except (KeyboardInterrupt, SystemExit):
                raise
            except:
                # Deliberately bare: under Jython, `except Exception` does not
                # catch Java exceptions thrown by ds.open / ops.run. A failing
                # image is logged together with its cause and skipped; the
                # remaining images are still processed.
                log.error("Skipping image: experiment = %s, plate = %s, image = %s, channel = %s, error = %s" % (experiment, plate, bn, channel, sys.exc_info()[1]))
