/* Fit_Polynomial.bsh
 * IJ BAR: https://github.com/tferr/Scripts#scripts
 *
 * Implements polynomial fitting in ImageJ using the apache commons math library.
 * 2016.03, Tiago Ferreira
 *
 * N.B.: CurveFitter will be deprecated after commons-math3.3. It should be replaced by
 *       AbstractCurveFitter once Fiji adopts it.
 * TODO: Plot confidence bands (perhaps using stat.correlation.Covariance to determine the
 *       covariance matrix?)
 */

import ij.IJ;
import ij.ImagePlus;
import ij.WindowManager;
import ij.gui.GenericDialog;
import ij.gui.ImageWindow;
import ij.gui.Overlay;
import ij.gui.Plot;
import ij.gui.PlotWindow;
import ij.gui.Roi;
import ij.gui.TextRoi;
import ij.process.ImageProcessor;
import ij.util.Tools;
import org.apache.commons.math3.analysis.polynomials.PolynomialFunction;
import org.apache.commons.math3.fitting.CurveFitter;
import org.apache.commons.math3.optim.nonlinear.vector.jacobian.LevenbergMarquardtOptimizer;

// Script-wide default settings. Functions in this file read and write them through
// 'super.<name>'. getUserInput() overwrites all of them with the dialog choices, and
// defDegree is later replaced by the guessed degree when 'guess' is active.
boolean verbose = true;	// print details to the log window? (also animates the guessing loop)
boolean guess = true;	// guess a 'polynomial of best fit' (POBF)?
int minDegree = 2;		// the smallest polynomial order to consider when guessing POBF
int maxDegree = 40;		// the highest polynomial order to consider when guessing POBF
int defDegree = 12;		// the default polynomial degree (used as is when not guessing)


// Fail early, with installation instructions, when commons-math3 is not on the
// classpath (e.g., plain ImageJ installations that do not bundle it)
try {
	Class.forName("org.apache.commons.math3.analysis.polynomials.PolynomialFunction");
} catch( ClassNotFoundException e ) {
	String jarLink = "http://jenkins.imagej.net/job/Stable-Fiji/ws/Fiji.app/jars/commons-math3-3.4.1.jar";
	// The Log window gets the full explanation; the dialog only points to it
	String details = "\n**** Fit Polynomial Error: Required library not found:\n"+ e +"\n \n"
		+ "The apache commons math library must be installed in the plugins folder.\n"
		+ "Please install it in plugins/jars/ by double-clicking on the link below:\n"+ jarLink;
	IJ.log(details);
	IJ.error("Dependencies missing! Check the Log window for details.");
	return;
}

/*
 * Accumulates, in a single forward pass over yArray, the sum of its values and the
 * sum of their squares. Returns a two-element array: {sum of y, sum of y^2}.
 * An empty input yields {0.0, 0.0}.
 */
double[] calculateSumYandY2(double[] yArray) {
	double[] sums = new double[2];	// [0]: sum of y, [1]: sum of y^2
	int idx = 0;
	while (idx < yArray.length) {
		double value = yArray[idx++];
		sums[0] += value;
		sums[1] += value*value;
	}
	return sums;
}

/*
 * Returns a readable name for a polynomial of the given degree: "Straight line" for
 * degree 1, otherwise "<ordinal> degree polynomial" (e.g., "2nd", "3rd", "12th",
 * "21st"). The English ordinal suffix follows the last digit, except for values
 * ending in 11, 12 and 13, which always take "th".
 */
String getDegreeString(int degree) {
	if (degree==1)
		return "Straight line";
	int lastTwoDigits = Math.abs(degree % 100);
	int lastDigit = lastTwoDigits % 10;
	String suffix = "th";
	if (lastTwoDigits<11 || lastTwoDigits>13) { // 11th, 12th, 13th are irregular
		if (lastDigit==1)
			suffix = "st";
		else if (lastDigit==2)
			suffix = "nd";
		else if (lastDigit==3)
			suffix = "rd";
	}
	return degree + suffix +" degree polynomial";
}

/*
 * Returns a multi-line description of the fit: the formula (y = ax^0 + bx^1 + ...)
 * followed by one "letter = value" line per coefficient, for degree+1 coefficients.
 * The heading uses the script-level polynomialName.
 *
 * fit:    the fitted polynomial
 * degree: the degree that was requested for the fit
 */
String getLogMessage(PolynomialFunction fit, int degree) {
	double[] coefficients = fit.getCoefficients();
	String formula = "\ny = ";
	String parameters = "";
	char cChar = 'a';
	for (int i=0; i<degree+1; i++) {
		// NOTE(review): PolynomialFunction is understood to discard trailing zero
		// coefficients (see its Javadoc), so the array may hold fewer than degree+1
		// values. Terms beyond its end are reported as 0 instead of indexing past it
		double coefficient = (i<coefficients.length) ? coefficients[i] : 0.0;
		parameters += "\n"+ cChar +" = "+ IJ.d2s(coefficient, 5, 9);
		formula += cChar +"x^"+ i;
		if (i<degree)
			formula += " + ";
		cChar++;
		if (i==25)
			cChar = 'A'; // after 'z' continue with 'A'. This won't work for more than 52 parameters
	}
	return "\nPolynomial fit ("+ super.polynomialName +"):"+ formula + parameters +"\n";
}

/*
 * Returns RSquared, http://en.wikipedia.org/wiki/Coefficient_of_determination
 *
 *   R^2 = 1 - SSres/SStot,  SSres = sum (y_i - f_i)^2,  SStot = sum (y_i - mean(y))^2
 *
 * The residual sum of squares is computed directly from the residuals. The shortcut
 * sum(y^2) - sum(f^2) is only valid when the fit is an exact least-squares solution,
 * which an iterative optimizer does not guarantee, and it can report a perfect score
 * for a poor fit.
 *
 * sampledData: the observed y values
 * fittedData:  the fitted y values (expected to hold at least as many values)
 *
 * Returns 0 if the sampled data has no variance (or is empty). The result is negative
 * when the fit is worse than the mean of the sampled data.
 */
double getRSquared(double[] sampledData, double[] fittedData) {
	int n = sampledData.length;
	if (n==0)
		return 0.0;
	double sumY = 0.0;
	for (int i=0; i<n; i++)
		sumY += sampledData[i];
	double meanY = sumY/n;
	double sumMeanDiffSqr = 0.0;
	double sumResidualsSqr = 0.0;
	for (int i=0; i<n; i++) {
		double meanDiff = sampledData[i] - meanY;
		double residual = sampledData[i] - fittedData[i];
		sumMeanDiffSqr += meanDiff*meanDiff;
		sumResidualsSqr += residual*residual;
	}
	double rSquared = 0.0;
	if (sumMeanDiffSqr > 0.0)
		rSquared = 1.0 - sumResidualsSqr/sumMeanDiffSqr;
	return rSquared;
}

/*
 * Returns Adjusted RSquared, http://en.wikipedia.org/wiki/Coefficient_of_determination
 * computed as R^2 - (1 - R^2) * p / (n - p - 1)
 *
 * rSquared: the (unadjusted) coefficient of determination
 * p:        the number of explanatory terms
 * n:        the number of data points
 */
double getAdjustedRSquared(double rSquared, int p, int n) {
	double unexplained = 1-rSquared;
	int residualDegreesOfFreedom = n-p-1;
	double penalty = (unexplained * p) / residualDegreesOfFreedom;
	return rSquared - penalty;
}

/*
 * Tries to guess the polynomial of best fit while providing some visual feedback.
 *
 * Fits every degree in [minDegree, maxDegree] to the sampled data and keeps the degree
 * with the highest R^2. The winning degree is stored in the script-level bestDegree
 * (so that it remains available if a later fit throws) and is also returned.
 * The script-level curDegree and color are updated at each iteration.
 *
 * In verbose mode each fit is logged, added to the plot and labelled through an
 * overlay. Otherwise only the progress bar is updated.
 *
 * pw:       the window in which fits are drawn
 * plot:     the plot receiving the fitted curves
 * xsampled: the sampled x values (double[])
 * ysampled: the sampled y values (double[], same length as xsampled)
 */
int guessDegree(PlotWindow pw, Plot plot, xsampled, ysampled) {
	ImagePlus imp = pw.getImagePlus();
	imp.setHideOverlay(false);
	// Declared with types so that they remain local to this function
	int n = xsampled.length;
	double[] yfitted = new double[n];
	int nDegrees = super.maxDegree - super.minDegree + 1;
	double maxRSquared = 0; //double maxARSquared = 0;
	for (int i=super.minDegree; i<=super.maxDegree; i++) {
		super.curDegree = i;
		super.color = getColor(i);
		CurveFitter fitter = new CurveFitter(new LevenbergMarquardtOptimizer());
		for (int j=0; j<n; j++)
			fitter.addObservedPoint(xsampled[j], ysampled[j]);
		// The degree is deduced from the length of the initial guess (all zeros)
		double[] best = fitter.fit(new PolynomialFunction.Parametric(), new double[i+1]);
		PolynomialFunction fitted = new PolynomialFunction(best);
		for (int k=0; k<n; k++)
			yfitted[k] = fitted.value(xsampled[k]);
		double RSquared = getRSquared(ysampled, yfitted);
		// A polynomial of degree i has i explanatory terms (x, x^2, ..., x^i)
		double ARSquared = getAdjustedRSquared(RSquared, i, n);
		if (RSquared>maxRSquared) { //(ARSquared>maxARSquared)
			super.bestDegree = i;
			maxRSquared = RSquared; //maxARSquared = ARSquared;
		}
		if (super.verbose) {
			IJ.log(getDegreeString(i) +": R^2= "+ RSquared +" Adj. R^2= "+ ARSquared);
			// NOTE(review): the label reads "R²" but displays the adjusted value
			Roi roi = new TextRoi(0.0, 0.0, "Degree: "+ i + " (R\u00b2=" + IJ.d2s(ARSquared, 4) +")");
			Overlay overlay = new Overlay();
			overlay.add(roi);
			overlay.setFillColor(Color.GRAY);
			overlay.setStrokeColor(super.colors[super.color]);
			imp.setOverlay(overlay);
			plot.setColor(super.colors[super.color]);
			plot.addPoints(xsampled, yfitted, Plot.LINE);
			pw.drawPlot(plot);
			if (i==super.maxDegree)
				overlay.clear(); // remove the label once the last degree has been drawn
		} else
			// Progress is the number of degrees processed out of those to be tested
			IJ.showProgress(i - super.minDegree + 1, nDegrees);
	}
	return super.bestDegree;
}

/*
 * Returns an array of 256 'heatmap' colors. Each channel follows a downward parabola
 * of height 255: blue peaks at index 0, green at index 127.5 and red at index 255.
 * Negative values of the red and blue parabolas are clipped to 0.
 */
Color[] parabolicColors() {
	double curvature = -255/(127.5*127.5); // shared by the three parabolas
	Color[] ramp = new Color[256];
	for (int i=0; i<ramp.length; i++) {
		int red = (int)Math.max(0, curvature * (i-255)*(i-255) + 255);
		int green = (int)(curvature * (i-127.5)*(i-127.5) + 255);
		int blue = (int)Math.max(0, curvature * i*i + 255);
		ramp[i] = new Color(red, green, blue);
	}
	return ramp;
}

/*
 * Returns the 8-bit position (0-255) of index relatively to [minDegree maxDegree].
 *
 * The result is always a valid index into the 256-color ramp: positions of indices
 * outside the range are clamped to 0 or 255, and 0 is returned when the range is
 * empty (minDegree==maxDegree), which would otherwise cause a division by zero.
 */
int getColor(int index) {
	int span = super.maxDegree - super.minDegree;
	if (span==0)
		return 0;
	int position = (255*(index-super.minDegree))/span;
	return Math.max(0, Math.min(255, position));
}

/*
 * Adds legend and color ramp to plot.
 *
 * plot:       the plot being annotated (drawing is done on its image processor)
 * upperLabel: the text drawn at the top-left corner
 * ramp:       whether the 256-color 'degree' ramp should be drawn at the bottom. The
 *             entry matching the script-level color index is drawn in black
 */
void makeLegend(Plot plot, String upperLabel, boolean ramp) {
	ImageProcessor ip = plot.getProcessor();
	int lineHeight = ip.getFontMetrics().getHeight();
	ip.drawString(upperLabel, 0, lineHeight);
	if (!ramp)
		return;

	// Ramp geometry: 256 one-pixel-wide vertical strokes, 5 pixels tall
	int rampLeft = (int)(0.5 * (Plot.LEFT_MARGIN + ip.getWidth() - 256));
	int rampBottom = ip.getHeight()-4;
	int rampTop = rampBottom-5;
	for (int offset=0; offset<256; offset++) {
		if (offset==super.color)
			ip.setColor(Color.BLACK);
		else
			ip.setColor(super.colors[offset]);
		ip.drawLine(rampLeft+offset, rampBottom, rampLeft+offset, rampTop);
	}

	// Centered labels: title in the middle, range limits at both ends
	ip.setColor(Color.BLACK);
	ip.setJustification(ip.CENTER_JUSTIFY);
	ip.drawString("Degree", rampLeft+128, rampTop);
	ip.drawString(""+super.minDegree, rampLeft, rampTop);
	ip.drawString(""+super.maxDegree, rampLeft+256, rampTop);
}

/*
 * Prompts user for settings. Returns false if dialog is dismissed, in which case the
 * script-level settings are left untouched.
 *
 * On confirmation, updates the script-level defDegree (at least 1), minDegree,
 * maxDegree, guess and verbose. The degree range is read from a "min-max" string:
 * a limit that is missing or cannot be parsed keeps its current value, limits are
 * never smaller than 1, and a reversed range (e.g., "40-2") is put in ascending order.
 */
boolean getUserInput() {
	GenericDialog gd = new GenericDialog("Fit Polynomial to Curve...");
	gd.addSlider("Polynomial degree:", 1, super.maxDegree, super.defDegree);
	gd.addCheckbox("Guess \"best fit\" (ignores specified degree above)", super.guess);
	gd.addStringField("          Degree range:", super.minDegree +"-"+ super.maxDegree, 12);
	gd.addCheckbox("Log details (animates retrieval of \"best fit\")", super.verbose);
	gd.addHelp("http://imagej.net/Sholl_Analysis#Complementary_Tools");
	gd.showDialog();
	if (gd.wasCanceled())
		return false;

	super.defDegree = (int)Math.max(1, gd.getNextNumber());

	// Keep the parsed limits as doubles until they have been validated: casting NaN
	// (the value returned for invalid text) to int would silently turn it into 0
	String[] minAndMax = Tools.split(gd.getNextString(), " -");
	double parsedMin = minAndMax.length>=1 ? gd.parseDouble(minAndMax[0]) : Double.NaN;
	double parsedMax = minAndMax.length==2 ? gd.parseDouble(minAndMax[1]) : Double.NaN;
	int min = super.minDegree;
	int max = super.maxDegree;
	if (!Double.isNaN(parsedMin))
		min = (int)Math.max(1.0, parsedMin);
	if (!Double.isNaN(parsedMax))
		max = (int)Math.max(1.0, parsedMax);
	if (max<min) { // reversed range: swap limits
		int tmp = min;
		min = max;
		max = tmp;
	}
	super.minDegree = min; super.maxDegree = max;

	super.guess = gd.getNextBoolean();
	super.verbose = gd.getNextBoolean();
	return true;
}


// Script-level state shared with the functions of this file (which access it
// through 'super')
int color = 1;				// index of the ramp color in use
int curDegree = 1;			// degree being fitted by the guessing loop
int bestDegree = 1;			// best degree found so far by the guessing loop
String polynomialName = "";	// readable name of the final polynomial
Color[] colors;				// the 256-color ramp
PlotWindow pw;				// the window in which fits are drawn

// The frontmost image must exist...
ImagePlus imp = WindowManager.getCurrentImage();
if (imp==void || imp==null) {
	IJ.error("There are no plots open.");
	return;
}

// ... and it must be displayed in a plot window
ImageWindow win = imp.getWindow();
boolean isPlotWindow = win!=void && win!=null && win instanceof PlotWindow;
if (!isPlotWindow) {
	IJ.error(imp.getTitle() +" is not a plot window.");
	return;
}

// Retrieve the plotted data
pw = (PlotWindow)win;
double[] xsampled = Tools.toDouble(pw.getXValues());
double[] ysampled = Tools.toDouble(pw.getYValues());

// Retrieve settings. Exit if the dialog was dismissed
boolean proceed = getUserInput();
if (!proceed)
	return;
long startTime = System.currentTimeMillis();

// Only points having both an x and a y value can be fitted: if the arrays differ in
// size, trim them to the length of the shortest one
int xsize = xsampled.length;
int ysize = ysampled.length;
int nPoints = Math.min(xsize, ysize);
if (xsize!=ysize) {
	xsampled = Arrays.copyOf(xsampled, nPoints);
	ysampled = Arrays.copyOf(ysampled, nPoints);
}
if (nPoints==0) {
	IJ.error("Could not retrieve values from plot");
	return;
}

// Prepare the output plot. A source plot whose title starts with "Polynomial fit" is
// treated as an existing output window
IJ.showStatus("Retrieving fits. Please wait...");
colors = parabolicColors();
String plotTitle = imp.getTitle();
boolean plotExists = plotTitle.startsWith("Polynomial fit");
Plot plot = new Plot("Polynomial fit "+ plotTitle, "", "", xsampled, ysampled);

// Main routine: optionally guesses the degree, performs the final fit, and draws the
// fitted curve with its legend. Any exception thrown along the way is reported to the
// user as an optimizer failure and terminates the script.
try {

	if (guess) { // Run fitting loop if degree is to be guessed

		if (verbose) {
			IJ.log("\nGuessing polynomial of 'best fit' for "+ plotTitle +":");
			if (!plotExists) { // Display plot for visual feedback
				// From here on, pw refers to the newly created window
				pw = plot.show();
				plotExists = true;
			}
		}
		// The guessed degree replaces the one specified in the dialog
		defDegree = guessDegree(pw, plot, xsampled, ysampled);
	}

	// Perform final fit
	CurveFitter fitter = new CurveFitter(new LevenbergMarquardtOptimizer());
	for (int i=0; i<nPoints; i++)
		fitter.addObservedPoint(xsampled[i], ysampled[i]);

	// Compute optimal coefficients. The polynomial degree is deduced from the length of
	// the array containing the initial guess for the coefficients (all set to zero here)
	double[] best = fitter.fit(new PolynomialFunction.Parametric(), new double[defDegree+1]);

	// Construct the polynomial that best fits the data
	PolynomialFunction fitted = new PolynomialFunction(best);
	// The fitted curve is evaluated at the sampled x positions (xfitted is a copy of
	// xsampled)
	double[] xfitted = new double[nPoints];
	double[] yfitted = new double[nPoints];
	for (int i=0; i<nPoints; i++) {
		xfitted[i] = xsampled[i];
		yfitted[i] = fitted.value(xsampled[i]);
	}

	// Must be set before getLogMessage() is called, as it reads polynomialName
	polynomialName = getDegreeString(defDegree);
	if (verbose)
		IJ.log("\nFitting "+ plotTitle + ":"+ getLogMessage(fitted, defDegree));

	double RSquared = getRSquared(ysampled, yfitted);
	// Ramp color of the final degree (degrees above maxDegree use the color of maxDegree).
	// makeLegend() reads this index to highlight the matching ramp entry
	color = getColor(Math.min(defDegree,maxDegree));
	// Draw the final fit with a thicker line, then restore the drawing defaults
	plot.setColor(colors[color]);
	plot.setLineWidth(2);
	plot.addPoints(xfitted, yfitted, Plot.LINE);
	plot.setColor(Color.BLACK);
	plot.setLineWidth(1);
	String plotLabel = (guess) ? polynomialName +" (guessed)": polynomialName;
	plotLabel += "  R^2= "+ IJ.d2s(RSquared, 5, 9) +"  "+ nPoints +" data points";
	// The color ramp is only drawn when the degree was guessed
	makeLegend(plot, plotLabel, guess);
	// Reuse the window in use (pw) if there is one, otherwise display a new one
	if (plotExists)
		pw.drawPlot(plot);
	else
		plot.show();

	IJ.showStatus("Done. "+ (System.currentTimeMillis()-startTime)/1000.0 +" seconds");
	IJ.showProgress(0,0);

} catch (Exception e) {

	// When guessing, curDegree holds the last degree attempted by guessDegree()
	int errD = (guess) ? curDegree : defDegree;
	IJ.error("Optimizer failed when dealing with a polynomial of degree\n"
		+ errD +" ["+ nPoints +" data points] and will now terminate (see console\n"
		+ "for details)... Perhaps settings should be revised?" );
	// Details are written to the Log window
	IJ.log("\n*** Exception when parsing data from "+ plotTitle +":\n"+ e);
	if (guess)
		IJ.log(">>> Polynomial of 'best fit' until exception: "+ getDegreeString(bestDegree));
	else
		IJ.log(">>> Specified polynomial: "+ getDegreeString(defDegree));
	return;

}
