#@ File outputDB("Specify your Database (.db-file)")
#@ String Selection(label="Specify the selection (mask) you want to explore", style="Text Field")
#@ String Group_by(label="Group by which Column?", style="Text Field", required=False)
#@ String plottype(label="Choose threshold method", choices={"Boxplot", "Barplot"})
#@ String parameter(label="Choose the parameter you want to plot", choices={"Number_of_Spines", "Area_per_spine", "Spines_Area", "Spine_per_Area", "Mean", "IntDen", "Area"})
#@ String addParams(label="(Optional) State additional filtering parameters in exact this style, one condition per line -> ColumnName='value'", style="text area", required=False)
#@ String addGrouping(label="(Optional) State additional groupings in exact this style, one condition per line -> ColumnName", style="text area", required=False)

from __future__ import with_statement
from tech.tablesaw.api import Table
from tech.tablesaw.plotly import Plot
from tech.tablesaw.plotly.api import BoxPlot, Histogram, VerticalBarPlot, ScatterPlot
from tech.tablesaw.aggregate import AggregateFunctions
#from tech.tablesaw.plotting import Plot
from java.sql import DriverManager, SQLException
from org.sqlite import SQLiteConfig, JDBC
from java.lang import Class
import sys
from random import gauss
from java.io import File
from ij.measure import ResultsTable


def getConnection(jdbc_driver, DB_URL):
    """
    Open a JDBC connection to DB_URL with SQLite foreign-key enforcement on.

    jdbc_driver is the fully qualified driver class name; it is loaded via
    Class.forName (the usual JDBC idiom for registering a driver). On any
    failure the error is printed and the script exits with status -1;
    otherwise the open connection is returned.
    """
    sqliteConfig = SQLiteConfig()
    sqliteConfig.enforceForeignKeys(True)

    try:
        Class.forName(jdbc_driver).newInstance()
    except Exception as err:
        print(err)
        sys.exit(-1)

    try:
        return DriverManager.getConnection(DB_URL, sqliteConfig.toProperties())
    except SQLException as err:
        print(err)
        sys.exit(-1)

def getArrayAsString(tab, colname):
    """
    Return the values of column `colname` of Tablesaw table `tab` as a list of strings.

    Parses the text of Column.print(): the first line is a header
    ("Column: <name>") and each following line holds one value.
    splitlines() is used instead of split("\n") so that "\r\n" line endings
    (the column text may use the platform separator, e.g. on Windows) do not
    leave a stray "\r" on every value.
    """
    return tab.column(colname).print().splitlines()[1:]


##############################################################################################################
####### Start of the script
##############################################################################################################


pa = "Spine_Analysis"
# Parameters stored directly in Spine_Analysis_Table; every other choice is a
# per-spine measurement that is joined in from Spine_Measurement_Table.
simpleList = ["Number_of_Spines", "Area_per_spine", "Spines_Area", "Spine_per_Area"]

if not Selection:
    Selection = "allSelected"

pathDB = outputDB.getAbsolutePath()
DB_URL = "jdbc:sqlite:" + pathDB
jdbc_driver = "org.sqlite.JDBC"

dbConn = getConnection(jdbc_driver, DB_URL)

# Additional WHERE conditions: one raw SQL condition per line, each prefixed
# with AND. Blank lines (and stray "\r") are ignored, and an empty field becomes
# "" so that it is never formatted into the query as the text "None".
if addParams:
    addParams = " ".join(["AND %s" % cond.strip() for cond in addParams.split("\n") if cond.strip()])
else:
    addParams = ""

# Additional grouping columns: one column name per line, blank lines ignored.
if addGrouping:
    addGrouping = [col.strip() for col in addGrouping.split("\n") if col.strip()]
    strGrouping = " ".join([",Spine_Analysis_Table.%s" % col for col in addGrouping])
else:
    strGrouping = ""

if Group_by:
    strGroup_by = ",Spine_Analysis_Table.%s %s" % (Group_by, strGrouping)
else:
    # Use the grouping columns themselves; splicing a literal string here would
    # be parsed by SQLite as a column alias for the measured parameter.
    strGroup_by = strGrouping
    Group_by = "Selection"

# Simple parameters are columns of Spine_Analysis_Table itself; every other
# parameter is a per-spine measurement joined in from Spine_Measurement_Table.
# The selection is bound as a query parameter ("?") instead of being quoted
# into the SQL, so selection names containing an apostrophe still work.
# addParams / strGroup_by are raw SQL fragments typed by the user and are
# spliced in as-is by design.
if parameter in simpleList:
    query = "SELECT * FROM %s_Table WHERE Selection = ? %s" % (pa, addParams or "")
else:
    query = """SELECT
    Spine_Analysis_Table.Spine_ID as Spine_ID,
    Spine_Analysis_Table.Selection as Selection,
    Spine_Measurement_Table.%s
    %s
    FROM Spine_Analysis_Table INNER JOIN Spine_Measurement_Table ON Spine_Measurement_Table.Spine_ID = Spine_Analysis_Table.Spine_ID
    WHERE Spine_Analysis_Table.Selection = ? %s""" % (parameter, strGroup_by, addParams or "")

stmt = None
try:
    print(query)
    stmt = dbConn.prepareStatement(query)
    stmt.setString(1, Selection)
    results = stmt.executeQuery()
    # Tablesaw tables live fully in memory, so the statement and connection
    # can be released as soon as the result set has been read.
    tab = Table.read().db(results, "Spine Analysis Table")
except SQLException as msg:
    print(msg)
    sys.exit(-1)
finally:
    if stmt is not None:
        stmt.close()
    # The database is not queried again after this point.
    dbConn.close()

print(tab)
rt = ResultsTable()
for colName in tab.columnNames():
    for row, value in enumerate(tab.column(colName).asObjectArray()):
        rt.setValue(colName, row, value)
rt.show("Spine Analysis Table")

if addGrouping:
    # One row per combination of Group_by and every additional grouping
    # column, matching the grouping columns added to the SELECT list.
    groupCols = [Group_by] + [g.strip() for g in addGrouping if g.strip()]
    tab = tab.summarize(parameter, "Spine_ID", AggregateFunctions.mean).by(*groupCols)
    tab.column("Mean [%s]" % parameter).setName(parameter)
    tab.column("Mean [Spine_ID]").setName("Spine_ID")
    print(tab)

def _showResultsTable(table, title):
    """Copy every column of a Tablesaw table into an ImageJ ResultsTable and show it."""
    rt = ResultsTable()
    for colName in table.columnNames():
        for row, value in enumerate(table.column(colName).asObjectArray()):
            rt.setValue(colName, row, value)
    rt.show(title)


def _pointLabels(table, groupCol):
    """Return a list literal (as a string) of '<group>_<Spine_ID>' hover labels, one per row."""
    ids = getArrayAsString(table, "Spine_ID")
    groups = getArrayAsString(table, groupCol)
    return str([str(g + "_" + p) for p, g in zip(ids, groups)])


def _boxPlotWithPoints(title, table, groupCol, valueCol):
    """Show a box plot of valueCol per groupCol with all points drawn and labelled."""
    bx = BoxPlot.create(title, table, groupCol, valueCol)
    codeString = bx.asJavascript("target")
    # Inject plotly trace options: show every point, jittered over the box,
    # with hover text identifying each spine.
    codeString = codeString.replace(
        "\ntype: 'box',\nname: '',\n};\n\n",
        "\ntype: 'box',\nname: '',\n jitter:1,\npointpos: 0,\nboxpoints: 'all',\ntext: %s,\n};\n\n" % _pointLabels(table, groupCol))
    # HACK: the modified JavaScript is passed through Plot.show's div-name
    # argument so that it ends up in the generated page; this depends on
    # Tablesaw's page template - re-check when upgrading Tablesaw.
    Plot.show(bx, "target' ></div>" + codeString)


def _barPlotWithPoints(points, groupCol, valueCol):
    """
    Show a bar plot of the per-group mean (+/- std. dev.) of valueCol in
    `points`, overlaid with a horizontally jittered scatter of the individual
    values (one point per row of `points`, labelled on hover).
    """
    summary = points.summarize(valueCol, AggregateFunctions.mean, AggregateFunctions.stdDev).by(groupCol)
    print(summary)
    bx = VerticalBarPlot.create("Bar", summary, groupCol, "Mean [%s]" % valueCol)
    codeString = bx.asJavascript("target")
    error_string = str([float(e) for e in getArrayAsString(summary, "Std. Deviation [%s]" % valueCol)])

    # Numeric x position (1, 2, ...) for every group, in bar order.
    xName = [str(x) for x in getArrayAsString(summary, groupCol)]
    dicName = {}
    for x in xName:
        if x not in dicName:
            dicName[x] = len(dicName) + 1
    # Individual points are jittered around their group's position.
    xName2 = [float(gauss(dicName[str(x)], 0.15)) for x in getArrayAsString(points, groupCol)]
    xName3 = [str(dicName[x]) for x in xName]
    print(dicName)
    y = [float(v) for v in getArrayAsString(points, valueCol)]

    scatter = ScatterPlot.create("Hello", groupCol, xName2, valueCol, y)

    # Cut the trace definition out of the scatter figure's JavaScript and
    # rename it trace1 so it can be appended after the bar trace (trace0).
    traceLines = []
    inTrace = False
    for line in scatter.asJavascript("target").split("\n"):
        if not inTrace and line == "var trace0 =":
            inTrace = True
        if line == "};":
            inTrace = False
            traceLines.append(line + "\n")
        if inTrace:
            traceLines.append(line)
    scatterTrace = "\n".join(traceLines)
    scatterTrace = scatterTrace.replace("};\n\nvar trace0 =", "var trace1 =")
    scatterTrace = scatterTrace.replace(
        "'scatter',\nname: '',\n};\n",
        "'scatter',\nname: '',\ntext: %s,\n};\n\n" % _pointLabels(points, groupCol))
    print([scatterTrace])

    error = "error_y: {type: 'data',\narray: %s,\nvisible: true\n}" % error_string

    # Add error bars and the scatter trace, move the bars to the numeric x
    # positions and label the ticks with the group names again.
    codeString = codeString.replace("type: 'bar',\nname: '',\n};\n\n", "\ntype: 'bar',\nname: '',\n %s,\n}\n\n%s" % (error, scatterTrace))
    codeString = codeString.replace("var data = [ trace0];", "var data = [ trace0, trace1];")
    codeString = codeString.replace(str(xName).replace(", ", ","), str(xName3).replace(", ", ","))
    codeString = codeString.replace("\nvisible: true\n},\n}\n\nvar trace1", "\nvisible: true\n},\ntext: %s,\n}\n\nvar trace1" % xName)

    # HACK: see _boxPlotWithPoints - JS is injected via the div-name argument.
    codeString = "target' ></div>" + codeString
    codeString = codeString.replace(
        "var layout = {\n    title: 'Bar',\n    height: 700,\n    width: 900,\n\n};",
        "var layout = {\n    title: 'Bar',\n    height: 700,\n    width: 900,\nxaxis: {tickvals:%s,ticktext : %s, showticks: true},\n\n};" % (xName3, xName))
    print(codeString)
    Plot.show(bx, codeString)


if parameter in simpleList:
    # Spine_Analysis_Table parameters: one value per row of the query result.
    norm = parameter

    if plottype == "Boxplot":
        _boxPlotWithPoints("BoxPlot of Spine Analysis Results of selection %s" % (Selection), tab, Group_by, norm)
        _showResultsTable(tab, "Table used to generate the Plot")
    else:
        _barPlotWithPoints(tab, Group_by, norm)
        _showResultsTable(tab, "ResultTable of Query")

else:
    # Measurement parameters: reduce to one (mean) value per spine, unless the
    # query step already aggregated the table by the extra grouping columns.
    if not addGrouping:
        norm = "Mean [%s]" % parameter
        summary = tab.summarize(parameter, AggregateFunctions.mean).by(Group_by, "Spine_ID")
    else:
        norm = parameter
        summary = tab
    print(summary)

    if plottype == "Boxplot":
        _boxPlotWithPoints("BoxPlot of Particle Analysis Results of channel selection %s" % (Selection), summary, Group_by, norm)
    elif plottype == "Barplot":
        _barPlotWithPoints(summary, Group_by, norm)

    _showResultsTable(tab, "Summary of ResultTable")
    # Per-spine means; defined for every plot type (Boxplot used to fail here).
    _showResultsTable(summary, "Summary of Means ResultTable")

print(tab)




