#Jeff Hardin, Dept. of Integrative Biology
#Univ. of Wisconsin-Madison
#jdhardin@wisc.edu
#Current version July 30, 2025

#Additional colors could be added via org.jfree.chart.ChartColor
#Green = #32a852
#Purple = #7132a8
#Dark red = #9C2005
#Dark brown = #754F04
#Light brown = #9C6905

#This script imports data from two or more CSV files, each holding x[i],y[i] values as paired columns
#(first column = x, second column = y, with one header row).
#Using JFreeChart allows export as an editable SVG file.
#The user is asked to choose legend placement, error-bar type, and curve-fitting options.

#Currently uses the hard-wired java.awt.Color constants (black, red, blue, ...) for the curves.
#Could be expanded to a nicer color palette using hexadecimal colors (see the list above).

from org.jfree.chart import ChartFactory, ChartPanel, JFreeChart, LegendItem  
from org.jfree.chart.title import LegendTitle
from org.jfree.chart.annotations import XYTitleAnnotation
from org.jfree.chart.plot import PlotOrientation, XYPlot
from org.jfree.data.xy import YIntervalSeries, YIntervalSeriesCollection
from org.jfree.chart.plot import PlotOrientation, XYPlot
from org.jfree.data.xy import XYDataset, XYSeries, XYSeriesCollection
from org.jfree.chart.renderer.xy import XYErrorRenderer
from org.jfree.chart.axis import NumberAxis
from org.jfree.data.category import DefaultCategoryDataset
from org.jfree.chart.ui import RectangleEdge,RectangleInsets,RectangleAnchor
from org.jfree.data.statistics import Statistics
from org.jfree.util import ShapeUtilities
from javax.swing import JFrame  
import java.awt.Color as Color
import java.awt.Rectangle as Rectangle
from java.awt import Dimension, BasicStroke
import java.awt.Frame as Frame
import java.awt.Window as Window
from java.io import File as File
from java.lang import System as System
from ij import WindowManager as WindowManager
from ij.plugin.frame import RoiManager as RoiManager
from ij.process import ImageStatistics as ImageStatistics
from ij.measure import Measurements as Measurements
from ij import IJ as IJ
from ij.measure import CurveFitter as CurveFitter
from ij.gui import Plot as Plot
from ij.gui import PlotWindow as PlotWindow
from ij.gui import ImageWindow as ImageWindow
from ij.text import TextWindow
from ij.gui import GenericDialog
from ij import ImagePlus as ImagePlus
from ij.io import FileInfo as FileInfo
from ij.measure import ResultsTable as ResultsTable
from ij import WindowManager as WindowManager
import math
import os
from os import path, mkdir
import csv
from array import array
from java.lang import Double
import jarray

#below from 
#https://stackoverflow.com/questions/736043/checking-if-a-string-can-be-converted-to-float-in-python
def isfloat(value):
  """Return True if value can be converted to a float, False otherwise.

  value -- any object; typically a CSV cell read as a string.

  Blank cells ("") and non-numeric text give False.  Objects that float()
  rejects outright (None, lists, ...) also give False instead of raising.
  """
  try:
    float(value)
    return True
  except (TypeError, ValueError):
    # ValueError: a string that is not a number (including the empty string)
    # TypeError: an object float() cannot accept at all, e.g. None
    return False

def _readXYPairs(csvPath):
	"""Return (xs, ys): the numeric x,y pairs from the first two columns of a CSV file.

	csvPath -- path of the CSV file, as a Python str.

	The first row is treated as a header and skipped.  A row is kept only
	when BOTH of its first two cells convert to float, so xs and ys always
	have the same length and stay paired.
	"""
	xs = []
	ys = []
	with open(csvPath, "r") as read_obj:
		csv_reader = csv.reader(read_obj)
		#skip the header row (the default keeps an empty file from raising)
		next(csv_reader, None)
		for row in csv_reader:
			#ImageJ produces CSV files in which there can be unequal numbers
			#of rows with empty cells; blank lines and short rows are skipped
			if (len(row) < 2):
				continue
			if (isfloat(row[0]) and isfloat(row[1])):
				xs.append(float(row[0]))
				ys.append(float(row[1]))
	return xs, ys

def _placeLegend(chart, plot, legendPosition, maxWidth):
	"""Position the legend of chart according to the user's choice.

	chart          -- the JFreeChart whose autogenerated legend is adjusted
	plot           -- the XYPlot of that chart
	legendPosition -- "JFree default" keeps the standard legend (moved to the top);
	                  any other value draws the legend inside the plot area
	maxWidth       -- maximum width of the in-plot legend, as passed to
	                  XYTitleAnnotation.setMaxWidth
	"""
	legend = chart.getLegend()
	if (legendPosition == "JFree default"):
		legend.setPosition(RectangleEdge.TOP)
		return None
	#hide autogenerated legend
	legend.setVisible(False)
	#Draw legend inside the plot, since the default position is outside the plot
	#(from JFreeChart sample code)
	insideLegend = LegendTitle(plot)
	insideLegend.setPosition(RectangleEdge.BOTTOM)
	insideLegend.setBorder(1,1,1,1)
	insideLegend.setBackgroundPaint(Color.WHITE)
	annotation = XYTitleAnnotation(0.98, 0.02, insideLegend, RectangleAnchor.BOTTOM_RIGHT)
	annotation.setMaxWidth(maxWidth)
	plot.addAnnotation(annotation)

def _showChartFrame(title, chartPanel, background=None):
	"""Show chartPanel in a new JFrame titled title.

	The right-click menu of the panel enables saving to PNG or SVG and
	adjusting properties.  background, if given, is set as the frame background.
	"""
	frame = JFrame(title)
	if (background is not None):
		frame.setBackground(background)
	frame.getContentPane().add(chartPanel)
	frame.pack()
	frame.setVisible(True)
	return frame

def _addFitSeries(meanCollection, errorRenderer, xs, fitValues):
	"""Add the fitted curve as series 1 of the mean chart and style it as a blue line."""
	fitSeries = YIntervalSeries("Fit")
	for m in range(len(xs)):
		#the fit has no error, so the interval collapses onto the curve itself
		#(bounds of 0,0 would place the degenerate error bar at y = 0)
		fitSeries.add(xs[m], fitValues[m], fitValues[m], fitValues[m])
	meanCollection.addSeries(fitSeries)
	errorRenderer.setSeriesPaint(1, Color.blue)
	errorRenderer.setSeriesStroke(1,BasicStroke(float(2.0)))
	errorRenderer.setSeriesLinesVisible(1, True)
	errorRenderer.setSeriesShapesVisible(1, False)

def _halfLife(rate):
	"""Return ln(2)/rate, or infinity when the fitted rate constant is 0."""
	if (rate == 0):
		return float("inf")
	return math.log(2) / rate

def doPlot():
	"""Plot the CSV datasets in listOfPaths, their mean curve, and an optional FRAP fit.

	Reads the global listOfPaths (at least two CSV files, x in the first
	column and y in the second, one header row).  After an options dialog it
	  * optionally draws all datasets in one JFreeChart window,
	  * fills an ImageJ ResultsTable with the data plus row-wise Mean,
	    Std Dev and Std Err columns,
	  * optionally plots the mean curve with SD or SEM error bars,
	  * optionally fits the means with a single or double exponential
	    recovery and reports half-life and mobile fraction in a TextWindow.

	The row-wise statistics assume all files share the same x values; the x
	values of the first file are used for the mean and fit curves.
	Returns None (also when the dialog is cancelled or the data are unusable).
	"""
	k = len(listOfPaths)
	if (k < 2):
		IJ.showMessage("Incorrect data structure","This script required more than one dataset.")
		return None

	#Ask for options
	gd = GenericDialog("Options")
	gd.addCheckbox("Combine plots", True)
	gd.addCheckbox("Plot Y mean values", True)
	gd.addCheckbox("Show error bars", True)
	gd.addRadioButtonGroup("Error Bars", ["Std err", "Std dev"], 1, 2, "Std err")
	gd.addRadioButtonGroup("Legend position:", ["JFree default", "Inside plot"], 1,2,"JFree default")
	gd.addCheckbox("Curve fit means", True)
	gd.addChoice("Curve fitting method:", ["Single", "Double"], "Single")
	gd.addCheckbox("Force curve fit through origin for single exponential", True)
	gd.showDialog()
	if (not gd.wasOKed()):
		return None
	#values must be fetched in the order the controls were added (per control type)
	combinePlots = gd.getNextBoolean()
	plotMeans = gd.getNextBoolean()
	plotStdErrs = gd.getNextBoolean()
	errorBarType = gd.getNextRadioButton()
	legendPosition = gd.getNextRadioButton()
	fitMeans = gd.getNextBoolean()
	fitMethod = gd.getNextChoice()
	throughOrigin = gd.getNextBoolean()

	myTable = ResultsTable()
	#set the next to True so that cells never written show up as NaN
	myTable.setNaNEmptyCells(True)
	#the X range always includes 0 and grows to cover every dataset
	xMin = 0
	xMax = 0
	plotColor = (Color.black, Color.red, Color.blue, Color.green, Color.magenta, Color.darkGray, Color.gray, Color.lightGray, Color.cyan, Color.orange, Color.pink, Color.yellow)
	dataset = XYSeriesCollection()
	allX = []
	allY = []

	for i in range(0, k):
		#read in FRAP curve data from CSV file
		#have to type cast pathnames to Python str
		csvPath = str(listOfPaths[i])
		x1, y1 = _readXYPairs(csvPath)
		#get file name using the Java File.getName() method
		filename = File(csvPath).getName()
		if (len(x1) == 0):
			#min()/max() below would raise on an empty dataset
			IJ.showMessage("Incorrect data structure","No numeric x,y pairs were found in " + filename)
			return None
		allX.append(x1)
		allY.append(y1)

		#get min and max for X axis
		xMin = min(xMin, min(x1))
		xMax = max(xMax, max(x1))

		#Now have to get arrays into a JFreeChart dataset
		series1 = XYSeries(filename)
		for m in range(len(x1)):
			series1.add(x1[m],y1[m])
		dataset.addSeries(series1)

		#add the data to the ResultsTable
		for m in range(len(x1)):
			if (i == 0):
				#first dataset creates the rows and supplies the only x column
				myTable.incrementCounter()
				myTable.addValue(filename + "-x",x1[m])
				myTable.addValue(filename + "-y",y1[m])
			else:
				myTable.setValue(filename + "-y",m,y1[m])

	if (combinePlots):
		chart = ChartFactory.createXYLineChart("Plot", "X", "Y", dataset)
		plot = chart.getPlot()
		r = plot.getRenderer()
		plot.setBackgroundPaint(Color.white)
		chart.title.setVisible(False)
		chartPanel = ChartPanel(chart)
		chartPanel.setPreferredSize(Dimension(600,400))

		#set color of curves; fall back to black when the palette runs out
		for i in range(0, k):
			if (i < len(plotColor)):
				r.setSeriesPaint(i, plotColor[i])
			else:
				r.setSeriesPaint(i, Color.black)
			r.setSeriesStroke(i,BasicStroke(float(2.0)))

		_placeLegend(chart, plot, legendPosition, 0.5)

		#Manually set the range of the X axis; the default
		#org.jfree.chart.axis.ValueAxis.autoAdjustRange() doesn't give
		#great results, as it pads the X axis too much.
		if (xMax > xMin):
			plot.getDomainAxis().setRange(xMin,xMax)

		_showChartFrame("Plots", chartPanel)

	#Calculate stats row by row over the rows every dataset has
	lengths = [len(ys) for ys in allY]
	nRows = min(lengths)
	if (max(lengths) != nRows):
		IJ.log("Datasets differ in length; statistics use the first " + str(nRows) + " rows only.")
	xRef = allX[0][:nRows]
	yMeans = []
	yStdDevs = []
	yStdErrs = []
	for j in range(nRows):
		yValues = [ys[j] for ys in allY]
		rowMean = Statistics.calculateMean(yValues)
		rowStdDev = Statistics.getStdDev(yValues)
		rowStdErr = rowStdDev/math.sqrt(k)
		yMeans.append(rowMean)
		yStdDevs.append(rowStdDev)
		yStdErrs.append(rowStdErr)
		#Add columns for stats
		myTable.setValue("Mean",j,rowMean)
		myTable.setValue("Std Dev",j,rowStdDev)
		myTable.setValue("Std Err",j,rowStdErr)

	myTable.show("Data")

	if (not plotMeans):
		return None

	#Plot mean values
	meanCollection = YIntervalSeriesCollection()
	seriesName = "X v. Y"
	if (plotStdErrs):
		if (errorBarType == "Std dev"):
			seriesName = "X v. Means +/- SD"
		else:
			seriesName = "X v. Means +/- SEM"
	meanSeries = YIntervalSeries(seriesName)
	if (errorBarType == "Std dev"):
		yErrors = yStdDevs
	else:
		yErrors = yStdErrs
	#now add values to series
	for m in range(nRows):
		meanSeries.add(xRef[m],yMeans[m],yMeans[m] - yErrors[m],yMeans[m] + yErrors[m])
	meanCollection.addSeries(meanSeries)

	errorRenderer = XYErrorRenderer()
	errorRenderer.setCapLength(0)
	meanPlot = XYPlot(meanCollection, NumberAxis("X"), NumberAxis("Y"), errorRenderer)
	meanChart = JFreeChart("Means", meanPlot)
	#series 0 (the means) is drawn as a red line without point markers
	errorRenderer.setSeriesLinesVisible(0, True)
	errorRenderer.setSeriesShapesVisible(0, False)
	errorRenderer.setSeriesPaint(0, Color.red)
	errorRenderer.setSeriesStroke(0,BasicStroke(float(2.0)))
	if (not plotStdErrs):
		errorRenderer.setDrawYError(False)
	meanPlot.setBackgroundPaint(Color.white)
	meanPlot.setDomainGridlinesVisible(False)
	meanPlot.setRangeGridlinesVisible(False)
	#Need to set font size and bold/non-bold for axis labels...

	meanChart.title.setVisible(False)
	#Not sure why, but if the next line isn't added, the edges are gray...
	meanChart.setBackgroundPaint(Color.WHITE)
	meanChartPanel = ChartPanel(meanChart)
	meanChartPanel.setPreferredSize(Dimension(600,400))

	_placeLegend(meanChart, meanPlot, legendPosition, 0.6)

	#Manually set the range of the X axis (see the combined plot above)
	if (xMax > xMin):
		meanPlot.getDomainAxis().setRange(xMin,xMax)

	#setting the frame background didn't change anything, but is kept
	_showChartFrame("Mean", meanChartPanel, Color.WHITE)

	if (not fitMeans):
		return None

	# Fitter
	# Have to create a Java double array with the y means data
	yFit = jarray.zeros(nRows, 'd')
	for i in range(nRows):
		yFit[i] = yMeans[i]
	fitter = CurveFitter(xRef, yFit)
	if (fitMethod == "Single"):
		if throughOrigin:
			fitter.doFit(CurveFitter.EXP_RECOVERY_NOOFFSET)
		else:
			fitter.doFit(CurveFitter.EXP_RECOVERY)
	else:
		eqn = "y = a*(1-exp(-b*x)) +c*(1-exp(-d*x)) + e"
		#doCustomFit returns the number of parameters, or 0 on a macro error
		if (fitter.doCustomFit(eqn, None, False) == 0):
			IJ.showMessage("Curve fit failed","The double exponential fit could not be set up.")
			return None
	param_values = fitter.getParams()

	#tabulate and plot the fitted curve; both fit methods are styled alike
	yFitPlot = []
	for i in range(nRows):
		yFitPlot.append( fitter.f( param_values, xRef[i]) )
		myTable.setValue("Fit",i,yFitPlot[i])
	myTable.show("Data")
	_addFitSeries(meanCollection, errorRenderer, xRef, yFitPlot)

	#Output FRAP parameters
	#need to fix this! (time units are hard-wired)
	time_units = "sec"
	myTextWindow = TextWindow("FRAP Results","",600,300)
	myTextWindow.append("Fit FRAP curve by " + fitter.getFormula() )
	if (fitMethod == "Single"):
		if throughOrigin:
			myTextWindow.append("Fit constrained through origin: TRUE")
		else:
			myTextWindow.append("Fit constrained through origin: FALSE")
		#parameter 0 is the amplitude a, parameter 1 the rate constant b
		mobile_fraction = param_values[0]
		myTextWindow.append( ('Half-life = %.2f ' + time_units) % _halfLife(param_values[1]) )
	else:
		#amplitudes are a (0) and c (2); rate constants are b (1) and d (3)
		mobile_fraction = param_values[0] + param_values[2]
		myTextWindow.append( ('Half-life #1= %.2f ' + time_units) % _halfLife(param_values[1]) )
		myTextWindow.append( ('Half-life #2= %.2f ' + time_units) % _halfLife(param_values[3]) )
	myTextWindow.append( "Mobile fraction = %.1f %%" % (100 * mobile_fraction) )
	myTextWindow.append( "" )
	myTextWindow.append("*******Details********" + fitter.getResultString() )
	return None


		
#main
#The next line is a SciJava script parameter (see https://imagej.net/scripting/parameters),
#not an ordinary comment: it must be kept exactly as written. It is expected to make
#Fiji/ImageJ2 ask for the input files and supply them as the global listOfPaths,
#which doPlot() reads.
#@ File[] listOfPaths (label="select files", style="files")
#Run the whole workflow: options dialog, plots, statistics table and curve fit.
doPlot()