#@ File outputDB(label="Specify your Database (.db-file)")
#@ Boolean Normalize_To_Selection_Area(value="False")
#@ Boolean Normalize_To_Mask_Area(value="False")
#@ Boolean Inside_Mask(value="True")
#@ String Channel_Name(label="Specify the channel you want to explore", style="Text Field")
#@ String Second_Channel(label="Specify the channel you want to explore", style="Text Field")
#@ String m2(label="Choose threshold method used for the second channel", choices={'Default', 'Huang', 'Intermodes', 'IsoData', 'IJ_IsoData', 'Li', 'MaxEntropy', 'Mean', 'MinError', 'Minimum', 'Moments', 'Otsu', 'Percentile', 'RenyiEntropy', 'Shanbhag', 'Triangle', 'Yen'})
#@ String Selection(label="Specify the selection (mask) you want to explore", style="Text Field")
#@ String Group_by(style="Text Field")
#@ String plottype(label="Choose threshold method", choices={"Boxplot", "Barplot"})
#@ String parameter(label="Choose the parameter you want to plot", choices={"Number_of_Particles", "Area", "IntDen", "Mean"})

from __future__ import with_statement
from tech.tablesaw.api import Table
from tech.tablesaw.plotly import Plot
from tech.tablesaw.plotly.api import BoxPlot, Histogram, VerticalBarPlot, ScatterPlot
from tech.tablesaw.aggregate import AggregateFunctions
#from tech.tablesaw.plotting import Plot
from java.sql import DriverManager, SQLException
from org.sqlite import SQLiteConfig, JDBC
from java.lang import Class
import sys
from random import gauss
from java.io import File
from ij.measure import ResultsTable



def getConnection(jdbc_driver, DB_URL):
    """
    Open and return a JDBC connection to DB_URL.

    The driver class named by jdbc_driver is loaded first. The connection is
    opened with SQLite foreign-key enforcement switched on. If either step
    fails, the error is printed and the script exits with status -1.
    """
    sqlite_settings = SQLiteConfig()
    sqlite_settings.enforceForeignKeys(True)

    try:
        Class.forName(jdbc_driver).newInstance()
    except Exception as err:
        print(err)
        sys.exit(-1)

    try:
        return DriverManager.getConnection(DB_URL, sqlite_settings.toProperties())
    except SQLException as err:
        print(err)
        sys.exit(-1)

def getArrayAsString(tab, colname):
    """
    Return every cell of column `colname` of Tablesaw table `tab` as a string.

    Cells are formatted with the column's own getString(row), which is the
    per-row text Tablesaw's Column.print() is built from. Reading the cells
    directly, rather than splitting print() output on "\\n", has two benefits:
    values carry no stray "\\r" on platforms whose line separator is "\\r\\n",
    and a cell containing a newline stays a single entry.
    An empty column yields an empty list.
    """
    column = tab.column(colname)
    return [column.getString(row) for row in range(column.size())]

##############################################################################################################
####### Start of the script
##############################################################################################################

pa = "Coloc_Analysis"

# Filter value for the INorOUT column, chosen by the Inside_Mask checkbox.
if Inside_Mask:
    inorout = "Inside"
else:
    inorout = "Outside"

pathDB = outputDB.getAbsolutePath()
DB_URL = "jdbc:sqlite:" + pathDB
jdbc_driver = "org.sqlite.JDBC"

dbConn = getConnection(jdbc_driver, DB_URL)

# An empty selection field falls back to the "allSelected" selection name.
if not Selection:
    Selection = "allSelected"

# Table and column names cannot be bound as JDBC parameters, so they are still
# interpolated. `parameter` comes from a fixed choice list; Group_by must be a
# plain column name. Every user-entered *value* is bound with "?", so quotes in
# channel/selection names can no longer break (or alter) the query.
if parameter == "Number_of_Particles":
    query = ("SELECT * FROM %s_Table WHERE Channel_Name = ? and Selection = ? "
             "and Second_Channel = ? and INorOUT = ? and Method2 = ?" % pa)
else:
    query = """SELECT 
    Coloc_Analysis_Table.COLOC_ID as COLOC_ID,
    Coloc_Analysis_Table.Channel_Name as Channel_Name,
    Coloc_Analysis_Table.Selection as Selection,
    Coloc_Analysis_Table.Method2 as Method2,
    Coloc_Analysis_Table.%s,
    Coloc_Measurement_Tables.%s
    FROM Coloc_Analysis_Table INNER JOIN Coloc_Measurement_Tables ON Coloc_Measurement_Tables.COLOC_ID = Coloc_Analysis_Table.COLOC_ID
    WHERE Coloc_Analysis_Table.Channel_Name = ? and Coloc_Analysis_Table.Selection = ? and Coloc_Analysis_Table.Second_Channel = ? and Coloc_Analysis_Table.INorOUT = ? and Coloc_Analysis_Table.Method2 = ?""" % (Group_by, parameter)

# Bound in the same order as the "?" placeholders in both queries above.
query_values = (Channel_Name, Selection, Second_Channel, inorout, m2)

stmt = None
results = None
try:
    stmt = dbConn.prepareStatement(query)
    for position, value in enumerate(query_values):
        stmt.setString(position + 1, value)  # JDBC parameter indices are 1-based
    results = stmt.executeQuery()
    print(DB_URL)
    # NOTE: assumes Table.read().db() copies the full ResultSet into the Table
    # (Tablesaw's reader iterates every row), so the JDBC handles can be
    # released right afterwards.
    tab = Table.read().db(results, "Colocalisation Analysis Table")
except SQLException as msg:
    print(msg)
    sys.exit(-1)
finally:
    # Release the result set, statement and connection (runs on sys.exit too).
    for handle in (results, stmt, dbConn):
        if handle is not None:
            try:
                handle.close()
            except SQLException as msg:
                print(msg)

# ---------------------------------------------------------------------------
# Plotting. Tablesaw's generated Plotly JavaScript is patched by exact string
# replacement, so the literal fragments below must match Tablesaw's output
# byte for byte.
# ---------------------------------------------------------------------------

def _show_results_table(table, title):
    """Copy every cell of Tablesaw `table` into a new ImageJ ResultsTable and show it as `title`."""
    results_table = ResultsTable()
    for name in table.columnNames():
        for row, value in enumerate(table.column(name).asObjectArray()):
            results_table.setValue(name, row, value)
    results_table.show(title)


def _point_ids(table):
    """
    Return one hover label "<Group_by value>_<COLOC_ID>" per row of `table`.

    The labels are rendered as a Python list literal string, which is spliced
    into the generated JavaScript as an array.
    """
    ids = getArrayAsString(table, "COLOC_ID")
    groups = getArrayAsString(table, Group_by)
    return str([str(group + "_" + coloc_id) for coloc_id, group in zip(ids, groups)])


def _box_plot(table, value_col, point_ids=None):
    """
    Build a box plot of `value_col` grouped by Group_by and return (figure, javascript).

    The JavaScript is patched so that every data point is drawn, jittered, on
    top of the boxes. `point_ids` (a string from _point_ids) is optional; if it
    is given, it becomes the points' hover text.
    """
    figure = BoxPlot.create(
        "BoxPlot of Particle Analysis Results of channel %s, selection %s" % (Channel_Name, Selection),
        table, Group_by, value_col)
    text_line = "" if point_ids is None else "text: %s,\n" % point_ids
    code = figure.asJavascript("target").replace(
        "\ntype: 'box',\nname: '',\n};\n\n",
        "\ntype: 'box',\nname: '',\n jitter:1,\npointpos: 0,\nboxpoints: 'all',\n" + text_line + "};\n\n")
    # NOTE(review): this prefix presumably completes the target <div> markup
    # expected by Plot.show; it is kept verbatim.
    return figure, "target' ></div>" + code


def _scatter_trace_as_trace1(js):
    """
    Cut the "var trace0 = ... };" definition out of Tablesaw scatter JavaScript
    and rename it to "var trace1 =" so it can be appended to another figure.
    """
    kept = []
    inside = False
    for line in js.split("\n"):
        if not inside and line == "var trace0 =":
            inside = True
        if line == "};":
            # Every "};" line is kept; the one preceding the trace is folded
            # away by the rename below.
            inside = False
            kept.append(line + "\n")
        if inside:
            kept.append(line)
    return "\n".join(kept).replace("};\n\nvar trace0 =", "var trace1 =")


def _show_bar_plot_with_points(points, value_col):
    """
    Show the mean of `value_col` per Group_by group as bars.

    Each bar carries a standard-deviation error bar, and the individual rows of
    `points` are overlaid as a scatter trace at jittered x positions. Bars are
    moved to integer positions 1..k, and the x axis is relabelled with the
    group names. Returns the mean/std-dev summary table.
    """
    summary = points.summarize(value_col, AggregateFunctions.mean, AggregateFunctions.stdDev).by(Group_by)
    print(summary)
    figure = VerticalBarPlot.create("Bar", summary, Group_by, "Mean [%s]" % value_col)
    code = figure.asJavascript("target")

    error_string = str([float(e) for e in getArrayAsString(summary, "Std. Deviation [%s]" % value_col)])

    group_names = [str(x) for x in getArrayAsString(summary, Group_by)]
    point_groups = [str(x) for x in getArrayAsString(points, Group_by)]

    # 1-based bar position of each group, in summary order.
    positions = {}
    for name in group_names:
        if name not in positions:
            positions[name] = len(positions) + 1

    point_x = [float(gauss(positions[g], 0.15)) for g in point_groups]  # horizontal jitter
    tick_positions = [str(positions[name]) for name in group_names]
    print(positions)
    point_y = [float(v) for v in getArrayAsString(points, value_col)]

    scatter = ScatterPlot.create("Hello", Group_by, point_x, value_col, point_y)
    trace1 = _scatter_trace_as_trace1(scatter.asJavascript("target"))
    trace1 = trace1.replace("'scatter',\nname: '',\n};\n",
                            "'scatter',\nname: '',\ntext: %s,\n};\n\n" % _point_ids(points))
    print([trace1])

    error = "error_y: {type: 'data',\narray: %s,\nvisible: true\n}" % error_string
    code = code.replace("type: 'bar',\nname: '',\n};\n\n",
                        "\ntype: 'bar',\nname: '',\n %s,\n}\n\n%s" % (error, trace1))
    code = code.replace("var data = [ trace0];", "var data = [ trace0, trace1];")
    # NOTE(review): assumes the bar x values appear in the JS exactly as
    # str(group_names) without spaces after commas; they are swapped for the
    # numeric positions.
    code = code.replace(str(group_names).replace(", ", ","), str(tick_positions).replace(", ", ","))
    code = code.replace("\nvisible: true\n},\n}\n\nvar trace1",
                        "\nvisible: true\n},\ntext: %s,\n}\n\nvar trace1" % (group_names,))
    code = "target' ></div>" + code
    code = code.replace(
        "var layout = {\n    title: 'Bar',\n    height: 700,\n    width: 900,\n\n};",
        "var layout = {\n    title: 'Bar',\n    height: 700,\n    width: 900,\nxaxis: {tickvals:%s,ticktext : %s, showticks: true},\n\n};" % (tick_positions, group_names))
    print(code)
    Plot.show(figure, code)
    return summary


if parameter == "Number_of_Particles":
    print(parameter)
    # Optionally divide the particle count by an area column into "Norm".
    # Selection_Area takes precedence when both boxes are ticked.
    if Normalize_To_Selection_Area or Normalize_To_Mask_Area:
        area_name = "Selection_Area" if Normalize_To_Selection_Area else "Mask_Area"
        normColumn = tab.column("Number_of_Particles").asDoubleColumn().divide(tab.column(area_name))
        normColumn.setName("Norm")
        tab.addColumns([normColumn])
        norm = "Norm"
    else:
        norm = "Number_of_Particles"

    if plottype == "Boxplot":
        bx, codeString1 = _box_plot(tab, norm, _point_ids(tab))
        Plot.show(bx, codeString1)
        _show_results_table(tab, "ResultTable of Query")
    else:
        _show_bar_plot_with_points(tab, norm)

else:
    norm = "Mean [%s]" % parameter
    # One row per (Group_by, COLOC_ID) pair: the mean of `parameter` over the
    # joined measurement rows of that COLOC_ID.
    per_coloc = tab.summarize(parameter, AggregateFunctions.mean).by(Group_by, "COLOC_ID")
    print(per_coloc)

    if plottype == "Boxplot":
        # NOTE(review): two box plots are shown (the first without hover
        # labels), as in the original script; drop the first if unintended.
        bx, codeString1 = _box_plot(per_coloc, norm)
        print(codeString1)
        Plot.show(bx, codeString1)

        bx, codeString1 = _box_plot(per_coloc, norm, _point_ids(per_coloc))
        Plot.show(bx, codeString1)

    elif plottype == "Barplot":
        _show_bar_plot_with_points(per_coloc, norm)

    # Shown for both plot types. The Boxplot path previously crashed here with
    # a NameError, because the table was only assigned in the Barplot branch.
    _show_results_table(per_coloc, "Summary of ResultTable")
    


