#@INPUT  File(label="Working Directory", style="directory", description="Select the directory for the experiment to analyze. It should contain red and/or green channel and cell mask images in a subdirectory for each plate.") expdir
#@INPUT  String(label="Channels", choices={"Red Channel Only", "Green Channel Only", "Red and Green Channels"}, style="listBox", description="Select which channel(s) to analyze.") channels
#@INPUT  String(label="Red Channel Prefix", value="RC_", description="The prefix that was used to mark the red channel images when exporting them from the IncuCyte software.") rcpre
#@INPUT  String(label="Green Channel Prefix", value="GC_", description="The prefix that was used to mark the green channel images when exporting them from the IncuCyte software.") gcpre
#@INPUT  String(label="Cell Mask Prefix", value="CM_", description="The prefix that was used to mark the cell mask images when exporting them from the IncuCyte software.") cmpre
#@       OpService ops
#@       DatasetIOService ds
#@       LogService log

import os
import csv
from datetime import datetime
from net.imglib2.util import Intervals
from net.imglib2.type.numeric.integer import UnsignedByteType

# Option strings offered by the "Channels" script parameter; any other
# selection (i.e. "Red and Green Channels") means both channels are analyzed.
RED_CHANNEL, GREEN_CHANNEL = "Red Channel Only", "Green Channel Only"

# File extension shared by every channel and cell-mask image.
imext = ".tif"

# The "Working Directory" script parameter is declared as a File; the
# os.path helpers used below operate on plain strings.
expdir = str(expdir)

def _basenames(files, prefix):
    """Return the sorted, de-duplicated image basenames in `files` for `prefix`.

    A file qualifies when its name starts with `prefix` and ends with `imext`;
    its basename is the name with both the prefix and the extension removed.
    """
    return sorted(set(f[len(prefix):-len(imext)] for f in files
                      if f.endswith(imext) and f.startswith(prefix)))


def _selected_channels():
    """Return (prefix, label) pairs for the channel(s) chosen in `channels`.

    Labels are tied to the selection itself rather than derived by comparing
    prefix strings, so identical red/green prefixes still get distinct labels.
    """
    if channels == RED_CHANNEL:
        return [(rcpre, "R")]
    if channels == GREEN_CHANNEL:
        return [(gcpre, "G")]
    return [(rcpre, "R"), (gcpre, "G")]


def _masked_mfi(data, mask):
    """Return the mean intensity of `data` inside the thresholded `mask`.

    The mask is binarised with threshold.apply against UnsignedByteType(1),
    converted to float32, and the result is sum(mask * data) / sum(mask).
    When no mask pixel survives the threshold, a warning is logged and NaN is
    returned instead of dividing by zero.

    Raises Exception if `data` and `mask` have different dimensions.
    """
    if not Intervals.equalDimensions(data, mask):
        raise Exception("Dimensions of images do not match.")

    binmask = ops.run("threshold.apply", mask, UnsignedByteType(1))
    binmask = ops.run("convert.float32", binmask)

    # Take the denominator before multiplying so it cannot be affected by
    # how math.multiply treats its first argument.
    masksum = ops.run("stats.sum", binmask).get()
    if masksum == 0:
        log.warn("Cell mask is empty after thresholding; reporting mfi = nan")
        return float("nan")

    masked = ops.run("math.multiply", binmask, data)
    return ops.run("stats.sum", masked).get() / masksum


experiment = os.path.basename(expdir)
chpairs = _selected_channels()

with open(os.path.join(expdir, experiment + ".csv"), "w") as csvfile:

    output = csv.writer(csvfile)

    output.writerow(["experiment", "plate", "well", "field", "timepoint", "channel", "mfi"])

    for root, dirs, files in os.walk(expdir):
        # Sorting in place makes os.walk visit plates in a stable order, so
        # the CSV rows are reproducible across filesystems.
        dirs.sort()
        if dirs:
            # Only leaf directories (one per plate) are expected to hold images.
            continue

        # Every selected channel must have exactly the same basenames as the
        # cell masks; channels that were not selected are ignored.
        cms = _basenames(files, cmpre)
        if any(_basenames(files, chpre) != cms for chpre, _ in chpairs):
            raise Exception("Not all necessary images found in %s" % root)

        if not cms:
            raise Exception("No images found in %s" % root)

        plate = os.path.basename(root)

        for bn in cms:
            # Basename layout consumed here: <well>_<field>_<date>_<time>.
            info = bn.split("_")
            timepoint = datetime.strptime(info[2] + info[3], "%Yy%mm%dd%Hh%Mm").isoformat()
            # Zero-pad the well column number, e.g. "A1" -> "A01".
            well = info[0][0] + info[0][1:].zfill(2)
            field = info[1]

            # os.walk already yields root prefixed with expdir, so join onto
            # root alone (joining expdir again breaks relative paths). The
            # mask is read once and shared by all selected channels.
            mask = ds.open(os.path.join(root, cmpre + bn + imext))

            for chpre, channel in chpairs:
                data = ds.open(os.path.join(root, chpre + bn + imext))

                mfi = _masked_mfi(data, mask)

                log.info("experiment = %s, plate = %s, well = %s, field = %s, timepoint = %s, channel = %s, mfi = %s" % (experiment, plate, well, field, timepoint, channel, mfi))

                output.writerow([experiment, plate, well, field, timepoint, channel, mfi])
