#@ File outputDB(label="Specify your Database (.db-file)")
#@ Boolean Normalize_To_Selection_Area(value="False")
#@ Boolean Normalize_To_Mask_Area(value="False")
#@ Boolean Inside_Mask(value="True")
#@ String Channel_Name(label="Specify the channel you want to explore", style="Text Field")
#@ String Second_Channel(label="Specify the channel you want to explore", style="Text Field")
#@ String m2(label="Choose threshold method used for the second channel", choices={'Default', 'Huang', 'Intermodes', 'IsoData', 'IJ_IsoData', 'Li', 'MaxEntropy', 'Mean', 'MinError', 'Minimum', 'Moments', 'Otsu', 'Percentile', 'RenyiEntropy', 'Shanbhag', 'Triangle', 'Yen'})
#@ String Selection(label="Specify the selection (mask) you want to explore", style="Text Field")
#@ String Group_by(style="Text Field")
#@ String parameter(label="Choose the parameter you want to plot", choices={"Number_of_Particles", "Area", "IntDen", "Mean"})

from __future__ import with_statement
from tech.tablesaw.api import Table
from tech.tablesaw.plotly import Plot
from tech.tablesaw.plotly.api import BoxPlot, Histogram
from tech.tablesaw.aggregate import AggregateFunctions
#from tech.tablesaw.plotting import Plot
from java.sql import DriverManager, SQLException
from org.sqlite import SQLiteConfig, JDBC
from java.lang import Class
import sys
from java.io import File
from ij.measure import ResultsTable



def getConnection(jdbc_driver, DB_URL):
    """
    Open a JDBC connection to the SQLite database at DB_URL.

    The driver class named by jdbc_driver is loaded first, then the
    connection is opened with foreign-key enforcement switched on.
    On failure the error is printed and the script exits with status -1.

    Returns the open java.sql.Connection.
    """
    sqliteConfig = SQLiteConfig()
    sqliteConfig.enforceForeignKeys(True)

    # Load the driver class before asking DriverManager for a connection.
    try:
        Class.forName(jdbc_driver).newInstance()
    except Exception as err:
        print(err)
        sys.exit(-1)

    try:
        connection = DriverManager.getConnection(DB_URL, sqliteConfig.toProperties())
    except SQLException as err:
        print(err)
        sys.exit(-1)

    return connection

##############################################################################################################
####### Start of the script
##############################################################################################################

pa = "Coloc_Analysis"

# Value stored in the INorOUT column for the chosen mask side.
inorout = "Inside" if Inside_Mask else "Outside"

pathDB = outputDB.getAbsolutePath()
DB_URL = "jdbc:sqlite:" + pathDB
jdbc_driver = "org.sqlite.JDBC"

dbConn = getConnection(jdbc_driver, DB_URL)

if not Selection:
    Selection = "allSelected"

# Filter values come straight from the user dialog, so they are bound as JDBC
# parameters instead of being pasted into the SQL text (a quote in a channel
# name would otherwise break the query). Order matches the '?' placeholders.
# Column/table identifiers (pa, Group_by, parameter) cannot be bound and are
# still interpolated.
filterValues = [Channel_Name, Selection, Second_Channel, inorout, m2]

if parameter == "Number_of_Particles":
    query = "SELECT * FROM %s_Table WHERE Channel_Name = ? and Selection = ? and Second_Channel = ? and INorOUT = ? and Method2 = ?" % pa
else:
    query = """SELECT 
    Coloc_Analysis_Table.COLOC_ID as COLOC_ID,
    Coloc_Analysis_Table.Channel_Name as Channel_Name,
    Coloc_Analysis_Table.Selection as Selection,
    Coloc_Analysis_Table.Method2 as Method2,
    Coloc_Analysis_Table.%s,
    Coloc_Measurement_Tables.%s
    FROM Coloc_Analysis_Table INNER JOIN Coloc_Measurement_Tables ON Coloc_Measurement_Tables.COLOC_ID = Coloc_Analysis_Table.COLOC_ID
    WHERE Coloc_Analysis_Table.Channel_Name = ? and Coloc_Analysis_Table.Selection = ? and Coloc_Analysis_Table.Second_Channel = ? and Coloc_Analysis_Table.INorOUT = ? and Coloc_Analysis_Table.Method2 = ?""" % (Group_by, parameter)

stmt = None
try:
    stmt = dbConn.prepareStatement(query)
    for position, value in enumerate(filterValues, 1):
        stmt.setString(position, value)
    results = stmt.executeQuery()
    # Table.read().db copies the whole result set into memory, so the
    # statement and connection can be released right afterwards.
    tab = Table.read().db(results, "Colocalisation Analysis Table")
except SQLException as msg:
    print(msg)
    sys.exit(-1)
finally:
    if stmt is not None:
        stmt.close()
    dbConn.close()

def showAsResultsTable(table, title):
    """
    Copy every column of a tablesaw Table into a new ImageJ ResultsTable
    (column by column, row index = position in the column) and show it
    under the given window title.
    """
    rt = ResultsTable()
    for c in table.columnNames():
        for i, a in enumerate(table.column(c).asObjectArray()):
            rt.setValue(c, i, a)
    rt.show(title)


if parameter == "Number_of_Particles":
    # Optionally normalise the particle count by one of the area columns.
    if Normalize_To_Selection_Area:
        areaName = "Selection_Area"
    elif Normalize_To_Mask_Area:
        areaName = "Mask_Area"
    else:
        areaName = None

    if areaName is None:
        norm = "Number_of_Particles"
    else:
        paColumn = tab.column("Number_of_Particles").asDoubleColumn()
        normColumn = paColumn.divide(tab.column(areaName))
        normColumn.setName("Norm")
        tab.addColumns([normColumn])
        norm = "Norm"

    # Built after the normalisation choice so `bx` exists in every branch
    # (previously a tab indent tied it to the un-normalised branch only).
    bx = BoxPlot.create("BoxPlot of Colocalisation Results between %s and %s in %s" % (Channel_Name, Second_Channel, Selection), tab, Group_by, norm)

else:
    # Average the chosen measurement per group and COLOC_ID; tablesaw names
    # the resulting column "Mean [<parameter>]".
    norm = "Mean [%s]" % parameter
    summary = tab.summarize(parameter, AggregateFunctions.mean).by(Group_by, "COLOC_ID")

    print(summary)

    bx = BoxPlot.create("BoxPlot of Particle Analysis Results of channel %s, selection %s" % (Channel_Name, Selection), summary, Group_by, norm)

    showAsResultsTable(summary, "Summary of ResultTable")


print(tab)

# Patch the generated plotly trace so every point is drawn with jitter
# next to the box.
codeString1 = bx.asJavascript("target")
codeString1 = codeString1.replace("\ntype: 'box',\nname: '',\n};\n\n", "\ntype: 'box',\nname: '',\n jitter:1,\npointpos: 0,\nboxpoints: 'all',\n};\n\n")

# NOTE(review): the patched script is smuggled into the page through
# Plot.show's div-name argument by closing the div early; this depends on
# tablesaw's HTML template - confirm against the tablesaw version in use.
codeString1 = "target' ></div>" + codeString1

print(codeString1)

Plot.show(bx, codeString1)

showAsResultsTable(tab, "ResultTable of Query")
    


