#@ File outputDB("Specify your Database (.db-file)")
#@ Boolean Normalize_To_Area(value="False")
#@ String Channel_Name(label="Specify the channel you want to explore", style="Text Field")
#@ String m(label="Choose threshold method", choices={'Default', 'Huang', 'Intermodes', 'IsoData', 'IJ_IsoData', 'Li', 'MaxEntropy', 'Mean', 'MinError', 'Minimum', 'Moments', 'Otsu', 'Percentile', 'RenyiEntropy', 'Shanbhag', 'Triangle', 'Yen'})
#@ String Selection(label="Specify the selection (mask) you want to explore", style="Text Field")
#@ String Group_by(style="Text Field")
#@ String plottype(label="Choose plot type", choices={"Boxplot", "Barplot"})
#@ String parameter(label="Choose the parameter you want to plot", choices={"Number_of_Particles", "Area", "IntDen", "Mean"})

from __future__ import with_statement
from tech.tablesaw.api import Table
from tech.tablesaw.plotly import Plot
from tech.tablesaw.plotly.api import BoxPlot, Histogram, VerticalBarPlot
from tech.tablesaw.aggregate import AggregateFunctions
#from tech.tablesaw.plotting import Plot
from java.sql import DriverManager, SQLException
from org.sqlite import SQLiteConfig, JDBC
from java.lang import Class
import sys
from java.io import File
from ij.measure import ResultsTable


def getConnection(jdbc_driver, DB_URL):
        """
        Get Connection to DB and returns connection handler
        """

        config = SQLiteConfig()
        config.enforceForeignKeys(True)

        try:
            Class.forName(jdbc_driver).newInstance()
        except Exception, msg:
            print msg
            sys.exit(-1)
        try:
            dbConn = DriverManager.getConnection(DB_URL, config.toProperties())
        except SQLException, msg:
            print msg
            sys.exit(-1)
        return dbConn


##############################################################################################################
####### Start of the script
##############################################################################################################


pa = "Particle_Analysis"


if not Selection:
    Selection = "allSelected"

pathDB = outputDB.getAbsolutePath()
DB_URL = "jdbc:sqlite:" + pathDB
jdbc_driver = "org.sqlite.JDBC"

dbConn = getConnection(jdbc_driver, DB_URL)

try: 
    stmt = dbConn.createStatement()

    if parameter == "Number_of_Particles":
        query = "SELECT * FROM %s_Table WHERE Channel_Name = '%s' and Selection = '%s' and Method='%s'" %(pa, Channel_Name, Selection, m)

    else:
        query = """SELECT 
        Particle_Analysis_Table.PA_ID as PA_ID,
        Particle_Analysis_Table.Channel_Name as Channel_Name,
        Particle_Analysis_Table.Selection as Selection,
        Particle_Analysis_Table.Method as Method,
        Particle_Analysis_Table.%s,
        PA_Measurement_Tables.%s
        FROM Particle_Analysis_Table INNER JOIN PA_Measurement_Tables ON PA_Measurement_Tables.PA_ID = Particle_Analysis_Table.PA_ID
        WHERE Particle_Analysis_Table.Channel_Name = '%s' and Particle_Analysis_Table.Selection = '%s' and Particle_Analysis_Table.Method='%s'"""% (Group_by, parameter, Channel_Name, Selection, m) #WHERE Channel_Name = '%s' and Selection = '%s' and Method='%s' 
    try:
        print query
        results = stmt.executeQuery(query)
        tab = Table.read().db(results, "Particle Analysis Table")
    except SQLException, msg:
        print msg
        sys.exit(-1)
except SQLException, msg:
    print msg
    sys.exit(-1)

if parameter == "Number_of_Particles":
    if Normalize_To_Area:
        paColumn = tab.column("Number_of_Particles").asDoubleColumn()
        areaColumn = tab.column("Selection_Area")
        
        normColumn = paColumn.divide(areaColumn)
        normColumn.setName("Norm")
        tab.addColumns([normColumn])
        norm = "Norm"
    else: 
        norm = "Number_of_Particles"

    if plottype == "Boxplot":
        bx = BoxPlot.create("BoxPlot of Particle Analysis Results of channel %s, selection %s" %(Channel_Name, Selection), tab, Group_by, norm)
        codeString1 = bx.asJavascript("target")

        codeString1 = codeString1.replace("\ntype: 'box',\nname: '',\n};\n\n", "\ntype: 'box',\nname: '',\n jitter:1,\npointpos: 0,\nboxpoints: 'all',\n};\n\n")
        
        codeString1 = "target' ></div>" + codeString1
        
        print codeString1
        
        Plot.show(bx, codeString1)
        
    else:
        summary = tab.summarize(norm, AggregateFunctions.mean, AggregateFunctions.stdDev).by(Group_by)
        print summary
        bx = VerticalBarPlot.create("Bar", summary, Group_by, "Mean [%s]"%norm)

        codeString1 = bx.asJavascript("target")


        error_string =  summary.column("Std. Deviation [%s]"%norm).print()

        print [error_string]

        error_string = error_string.replace("\n", " ")
        error_string = error_string.replace("Column: Std. Deviation [Norm] ", "[")
        error_string += "]"

        error_string = error_string.replace(" ", ", ")
        

        
        error = "error_y: {type: 'data',\narray: %s\nvisible: true\n}," %error_string 



        codeString1 = codeString1.replace("type: 'bar',\nname: '',\n};\n\n", "\ntype: 'bar',\nname: '',\n %s,\n}\n\n" %error)
        
        codeString1 = "target' ></div>" + codeString1

        print codeString1
        
        Plot.show(bx, codeString1)
        
else:
    norm = "Mean [%s]"%parameter
    summary = tab.summarize(parameter, AggregateFunctions.mean).by(Group_by, "PA_ID")

    print summary
    
    if plottype == "Boxplot":
        bx =  BoxPlot.create("BoxPlot of Particle Analysis Results of channel %s, selection %s" %(Channel_Name, Selection), summary, Group_by, norm)
        codeString1 = bx.asJavascript("target")

        codeString1 = codeString1.replace("\ntype: 'box',\nname: '',\n};\n\n", "\ntype: 'box',\nname: '',\n jitter:1,\npointpos: 0,\nboxpoints: 'all',\n};\n\n")
        
        codeString1 = "target' ></div>" + codeString1
        
        print codeString1
        
        Plot.show(bx, codeString1)

    colnames = summary.columnNames()
    rt = ResultsTable()
    for c in colnames:
        ar = summary.column(c).asObjectArray()
        for i, a in enumerate(ar):
            rt.setValue(c,i,a)

    rt.show("Summary of ResultTable")

print tab

    
"""    
codeString1 = bx.asJavascript("target")

codeString1 = codeString1.replace("\ntype: 'box',\nname: '',\n};\n\n", "\ntype: 'box',\nname: '',\n jitter:1,\npointpos: 0,\nboxpoints: 'all',\n};\n\n")

codeString1 = "target' ></div>" + codeString1

print codeString1

Plot.show(bx, codeString1)

colnames = tab.columnNames()
rt = ResultsTable()
for c in colnames:
    ar = tab.column(c).asObjectArray()
    for i, a in enumerate(ar):
        rt.setValue(c,i,a)


rt.show("ResultTable of Query")
"""


