"""
******************************************************************************************
			Written by Jim Rowe in collaboration with Alexander Jones
								Started: 2022-08-01		
							 		@BotanicalJim
							james.rowe at sheffield.ac.uk
									Version v2.01

******************************************************************************************
"""


# *******************************import libraries*****************************************
from ij 			import IJ, ImageStack, ImagePlus,CompositeImage
from fiji.util.gui 	import GenericDialogPlus
from ij.process 	import ImageProcessor, StackStatistics, ImageConverter, FloatProcessor, ImageConverter
from ij.measure 	import ResultsTable
from array 			import array, zeros
from java.lang 		import Thread
from ij.plugin 		import Slicer, ZProjector
from os.path 		import exists
import json
import math
from java.awt.event	import ActionListener 
from java.awt import GridLayout
# *******************************functions************************************************

class buttonClick(ActionListener):
	"""ActionListener shared by the four buttons of the FRETENATOR2 settings dialog.

	It never edits the dialog itself. It records what happened in flag attributes
	(bypass, quickload, load, save; each 0 or 1) plus the associated payload: the preset
	attributes for the quickload buttons, self.options for a loaded .json file, or
	self.saveFilepath for a save request. The polling loop in previewDialog reads the flags,
	rebuilds the dialog accordingly and resets them.
	"""

	# Values applied by the "Quickload ..." buttons; every key becomes an attribute of the listener
	PRESETS = {
		"Quickload nuclei settings": {
			"thresholdMethod": 'Otsu', "gaussianSigma": 1.2, "largeDoGSigma": 6.0, "DoG": True,
			"manualSegment": False, "dilation": 0, "sizeExclude": False, "watershed": True,
			"pixelByPixel": False,
		},
		"Quickload diffuse settings": {
			"thresholdMethod": 'Otsu', "gaussianSigma": 2, "largeDoGSigma": 6.0, "DoG": False,
			"manualSegment": False, "dilation": 0, "sizeExclude": False, "watershed": False,
			"pixelByPixel": True,
		},
	}

	def __init__(self):
		self.bypass = 0
		self.quickload = 0
		self.save = 0
		self.load = 0

	def actionPerformed(self, event):
		"""Record which button (identified by its label) was pressed and stash its values."""
		self.bypass = 0
		label = event.getSource().label

		if label in self.PRESETS:
			for name, value in self.PRESETS[label].items():
				setattr(self, name, value)
			self.bypass = 1
			self.quickload = 1
			# both presets cancel a pending load (previously only the diffuse preset did)
			self.load = 0
		elif label == "Load settings file":
			self._loadSettingsFile()
		elif label == "Save settings file":
			saveFilepath = self.fileSelectDialog()
			if saveFilepath is None:
				# file chooser cancelled: nothing to do
				return
			self.saveFilepath = saveFilepath
			self.save = 1
			self.bypass = 1
			self.quickload = 0
			print(self.saveFilepath)

	def _loadSettingsFile(self):
		"""Ask for a settings .json file and store its parsed contents in self.options.

		self.load (and self.bypass) are set to 1 only when the file was read and parsed,
		so the caller never looks for self.options after a failed load.
		"""
		loadFilepath = self.fileSelectDialog()
		if loadFilepath is None:
			return
		if not exists(loadFilepath):
			print('failed to load')
			return
		print('exists')
		try:
			with open(loadFilepath, 'rb') as sf:
				options = json.load(sf)
		except (IOError, ValueError):
			# unreadable file or malformed JSON
			print('failed to load')
			return
		self.options = options
		print('options loaded')
		self.bypass = 1
		self.load = 1
		self.quickload = 0

	def fileSelectDialog(self):
		"""Ask the user for a settings file path.

		Returns the chosen path string, or None if the chooser was cancelled or left empty.
		"""
		gd = GenericDialogPlus("Select file location")
		gd.addFileField("Select settings file .json file", "")
		gd.showDialog()

		if gd.wasCanceled():
			return None
		settingFilePath = gd.getNextString()
		return settingFilePath or None
			
def globalBackSub(labelGFX, quantGFX, otherGFX):
	"""Subtract the global background mean from quantGFX on the GPU and return quantGFX.

	labelGFX -- label image; row 0 of its CLIJ2 statistics (label 0) is taken as the background
	quantGFX -- image to correct; overwritten with the result and returned
	otherGFX -- sacrificial scratch image, overwritten
	The background mean is also written to the ImageJ log.
	"""
	backgroundStats = ResultsTable()
	clij2.statisticsOfBackgroundAndLabelledPixels(quantGFX, labelGFX, backgroundStats)
	backgroundMean = backgroundStats.getValue("MEAN_INTENSITY", 0)
	IJ.log(str(backgroundMean))
	# add the negated mean into scratch, then copy back so quantGFX holds the corrected image
	clij2.addImageAndScalar(quantGFX, otherGFX, -backgroundMean)
	clij2.copy(otherGFX, quantGFX)
	return quantGFX
	
	
def createSubtractionLabels(labelGFX, gfx2, outputGFX):
	"""Build a label image of the region each label gains when dilated by 4, and return it.

	labelGFX -- label image (read only)
	gfx2 -- sacrificial scratch image; receives the dilated labels
	outputGFX -- receives (dilated labels - original labels) and is returned
	Presumably this yields a ring of background around each object, assuming dilateLabels
	only grows labels into zero pixels -- confirm against the CLIJ2 docs.
	"""
	dilatedGFX = gfx2
	clij2.dilateLabels(labelGFX, dilatedGFX, 4)
	clij2.subtractImages(dilatedGFX, labelGFX, outputGFX)
	return outputGFX
	
	
def localLabelBackSub(dilatedLabelGFX, labelGFX, quantGFX, otherGFX1, otherGFX2):
	"""Per-label background subtraction on the GPU; returns quantGFX, overwritten with the result.

	For each label, a per-label value measured on quantGFX within dilatedLabelGFX (the surrounding
	region from createSubtractionLabels) is painted onto that label in labelGFX and subtracted
	from quantGFX.
	otherGFX1, otherGFX2 -- sacrificial scratch images, overwritten.
	"""
	surroundStats = ResultsTable()
	clij2.statisticsOfBackgroundAndLabelledPixels(quantGFX, dilatedLabelGFX, surroundStats)
	# column 12 of the CLIJ2 statistics table -- presumably MEAN_INTENSITY; confirm the column order
	surroundValues = surroundStats.getColumn(12)
	valueProcessor = FloatProcessor(len(surroundValues), 1, surroundValues, None)
	valuesGFX = clij2.push(ImagePlus("IntensitiesImp", valueProcessor))
	# paint each label with its surrounding value, then free the lookup image
	clij2.replaceIntensities(labelGFX, valuesGFX, otherGFX1)
	valuesGFX.close()
	clij2.subtractImages(quantGFX, otherGFX1, otherGFX2)
	clij2.copy(otherGFX2, quantGFX)
	return quantGFX
			
							
def extractChannel(imp, nChannel, nFrame):
	"""Copy all z-slices of one channel at one frame of imp into a new ImagePlus.

	nChannel, nFrame -- 1-based channel and frame numbers. ImageJ's getStackIndex clamps
	out-of-range values, so callers passing 0 get the first frame.
	Returns a duplicated ImagePlus built from a stack titled "Channel <nChannel>", with its display
	range set to the stack's min/max. The image is not shown.
	"""
	stack = imp.getImageStack()
	ch = ImageStack(imp.width, imp.height)
	# z-slices are 1-based: the former 0-based loop duplicated slice 1 (0 is clamped to 1)
	# and silently dropped the last slice
	for z in range(1, imp.getNSlices() + 1):
		index = imp.getStackIndex(nChannel, z, nFrame)
		ch.addSlice(str(z), stack.getProcessor(index))
	imp3 = ImagePlus("Channel " + str(nChannel), ch).duplicate()
	stats = StackStatistics(imp3)
	IJ.setMinAndMax(imp3, stats.min, stats.max)
	return imp3

def extractFrame(imp, nFrame):
	"""Copy every channel and z-slice of one frame of imp into a new CompositeImage.

	Processors keep the source bit depth. The copy is given nChannels x nSlices x 1 dimensions,
	wrapped in a COMPOSITE-mode CompositeImage, shown, and returned.
	"""
	sourceStack = imp.getImageStack()
	nChannels = imp.getNChannels()
	nSlices = imp.getNSlices()
	frameStack = ImageStack(imp.width, imp.height)
	# ImageJ hyperstack order: channels vary fastest within each z-slice
	for z in range(1, nSlices + 1):
		for c in range(1, nChannels + 1):
			frameStack.addSlice(str(z), sourceStack.getProcessor(imp.getStackIndex(c, z, nFrame)))
	frameImp = ImagePlus("Frame " + str(nFrame), frameStack).duplicate()
	frameImp.setDimensions(nChannels, nSlices, 1)
	composite = CompositeImage(frameImp, CompositeImage.COMPOSITE)
	composite.show()
	return composite

def errorDialog(message):
	"""Show message to the end user in a dialog titled "Error" (returns None)."""
	dialog = GenericDialogPlus("Error")
	dialog.addMessage(message)
	dialog.showDialog()
	
def concatStacks(masterStack, impToAdd):
	"""Append every slice of impToAdd to masterStack in place and return masterStack.

	A slice that cannot be fetched or added is skipped with a console message,
	so one bad slice does not abort the whole concatenation.
	"""
	sourceStack = impToAdd.getImageStack()
	lastSlice = impToAdd.getNSlices()
	sliceNumber = 1
	while sliceNumber <= lastSlice:
		try:
			masterStack.addSlice(sourceStack.getProcessor(sliceNumber))
		except:
			print("FAILED To addto stack for: " + impToAdd.getTitle() + " " + str(sliceNumber))
		sliceNumber += 1
	return masterStack

# Order of the 20 settings in the options sequence taken/returned by previewDialog and stored in settings .json files
_OPTION_NAMES = ("segmentChannel", "donorChannel", "acceptorChannel", "acceptorChannel2",
	"thresholdMethod", "maxIntensity", "gaussianSigma", "largeDoGSigma", "DoG",
	"manualSegment", "manualThreshold", "makeNearProj", "dilation", "sizeExclude",
	"minSize", "maxSize", "watershed", "backsubVal", "pixelByPixel", "saveSettings")

# Settings whose change triggers a new segmentation preview
_PREVIEW_KEYS = ("segmentChannel", "thresholdMethod", "maxIntensity", "gaussianSigma",
	"largeDoGSigma", "DoG", "manualSegment", "manualThreshold", "dilation", "sizeExclude",
	"minSize", "maxSize", "watershed")

# Listener attributes copied into the settings when a "Quickload ..." button is pressed
_QUICKLOAD_KEYS = ("thresholdMethod", "gaussianSigma", "largeDoGSigma", "DoG",
	"manualSegment", "dilation", "sizeExclude", "watershed", "pixelByPixel")

_CHANNEL_KEYS = ("segmentChannel", "donorChannel", "acceptorChannel", "acceptorChannel2")
_CHANNEL_LABELS = ["Segmentation", "Donor (denominator)", "Acceptor (FRET) (numerator)", "Acceptor"]
_THRESHOLD_METHODS = ["Otsu","Default", "Huang", "Intermodes", "IsoData", "IJ_IsoData", "Li", "MaxEntropy", "Mean", "MinError", "Minimum", "Moments", "Percentile", "RenyiEntropy", "Shanbhag", "Triangle", "Yen"]
_DILATION_OPTIONS = ["0", "1", "2", "3", "4", "5", "6"]
_SATURATION_OPTIONS = ["254", "4094", "65534"]
_BACKSUB_OPTIONS = ["Off", "Local label based", "Global mean"]

# Slider order in the settings dialog (0: Gaussian sigma, 1: DoG 2 sigma, 2: manual threshold, 3: min size, 4: max size)
_MIN_SIZE_SLIDER = 3
_MAX_SIZE_SLIDER = 4


def _buildSettingsDialog(imp, settings):
	"""Create and show the non-modal FRETENATOR2 settings dialog, pre-filled from a settings dict.

	If any stored channel number exceeds imp's channel count, fixed fallback channels are offered
	instead. Returns (gd, buttonListener); the listener records button presses for previewDialog.
	"""
	gd = GenericDialogPlus("FRETENATOR2: 2FRET2FURIOUSLY")
	#create a list of the channels in the provided imagePlus
	channelNames = [str(i) for i in range(1, imp.getNChannels() + 1)]
	channels = [settings[key] for key in _CHANNEL_KEYS]
	if max(channels) <= len(channelNames):
		channelDefaults = [str(c) for c in channels]
	elif len(channelNames) > 2:
		channelDefaults = [channelNames[2], channelNames[0], channelNames[1], channelNames[2]]
	else:
		channelDefaults = [channelNames[-1], channelNames[0], channelNames[-1], channelNames[-1]]

	gd.setInsets(0,0,0)
	gd.addMessage("""Channel choices:""")
	#user can pick which channel to base the segmentation on
	for label, default in zip(_CHANNEL_LABELS, channelDefaults):
		gd.setInsets(0,0,0)
		gd.addChoice(label, channelNames, default)

	gd.setInsets(0,0,0)
	gd.addMessage("""Segmentation settings:""")
	gd.setInsets(0,15,0)
	gd.addCheckbox("Difference of Gaussian instead of Gaussian?", settings["DoG"])
	gd.addSlider("Gaussian / DoG 1 sigma", 0.5, 10, settings["gaussianSigma"], 0.1)
	gd.addSlider("DoG 2 sigma", 0.5, 20, settings["largeDoGSigma"], 0.1)
	gd.setModal(False)
	gd.addChoice("Autosegmentation method", _THRESHOLD_METHODS, settings["thresholdMethod"])
	gd.addCheckbox("Manually set threshold? ", settings["manualSegment"])
	gd.addSlider("Manual threshold", 10, 65534, settings["manualThreshold"], 1)
	gd.addChoice("Dilation?", _DILATION_OPTIONS, str(settings["dilation"]))
	gd.addCheckbox("Size exclusion of ROI? ", settings["sizeExclude"])
	gd.addSlider("Minimum ROI size", 0, 9999, settings["minSize"], 1)
	gd.addSlider("Maximum ROI size", 1, 10000, settings["maxSize"], 1)
	gd.addCheckbox("Watershed object splitting? ", settings["watershed"])

	gd.setInsets(0,0,0)
	gd.addMessage("""Analysis settings:""")
	gd.addChoice("Saturation removal threshold", _SATURATION_OPTIONS, str(settings["maxIntensity"]))
	gd.addChoice("Background subtraction", _BACKSUB_OPTIONS, _BACKSUB_OPTIONS[settings["backsubVal"]])
	gd.addCheckbox("""Use pixel by pixel analysis?""", settings["pixelByPixel"])
	gd.addCheckbox("Create nearest point projection? ", settings["makeNearProj"])
	gd.addCheckbox("Save settings as default? ", False)
	buttonListener = buttonClick()
	for buttonLabel in ("Quickload nuclei settings", "Quickload diffuse settings", "Load settings file", "Save settings file"):
		gd.addButton(buttonLabel, buttonListener)
	gd.setInsets(0,0,0)
	gd.addMessage("""For tutorials, click 'help'.
	
	https://doi.org/10.1007/978-1-0716-2297-1_17
	https://doi.org/10.1038/s41477-023-01447-4 """)
	gd.addHelp("https://github.com/JimageJ/FRETENATOR2/blob/main/README.md")
	gd.setLocation(0,0)
	gd.showDialog()
	return gd, buttonListener


def _readSettingsDialog(gd):
	"""Return the current widget values of a dialog from _buildSettingsDialog as a dict keyed by _OPTION_NAMES."""
	choices = gd.getChoices()
	sliders = gd.getSliders()
	checkboxes = gd.getCheckboxes()
	backSub = choices.get(7).getSelectedItem()
	backsubVal = 0
	if backSub == "Local label based":
		backsubVal = 1
	if backSub == "Global mean":
		backsubVal = 2
	return {
		"segmentChannel": int(choices.get(0).getSelectedItem()),
		"donorChannel": int(choices.get(1).getSelectedItem()),
		"acceptorChannel": int(choices.get(2).getSelectedItem()),
		"acceptorChannel2": int(choices.get(3).getSelectedItem()),
		"thresholdMethod": choices.get(4).getSelectedItem(),
		"dilation": int(choices.get(5).getSelectedItem()),
		"maxIntensity": int(choices.get(6).getSelectedItem()),
		"backsubVal": backsubVal,
		# 0.1-step sliders store ten times the displayed value
		"gaussianSigma": sliders.get(0).getValue()/10.0,
		"largeDoGSigma": sliders.get(1).getValue()/10.0,
		"manualThreshold": sliders.get(2).getValue(),
		"minSize": sliders.get(_MIN_SIZE_SLIDER).getValue(),
		"maxSize": sliders.get(_MAX_SIZE_SLIDER).getValue(),
		"DoG": checkboxes.get(0).getState(),
		"manualSegment": checkboxes.get(1).getState(),
		"sizeExclude": checkboxes.get(2).getState(),
		"watershed": checkboxes.get(3).getState(),
		"pixelByPixel": checkboxes.get(4).getState(),
		"makeNearProj": checkboxes.get(5).getState(),
		"saveSettings": checkboxes.get(6).getState(),
	}


def _saveSettingsFile(filepath, settings):
	"""Write settings to filepath as a JSON list in _OPTION_NAMES order (the format previewDialog loads)."""
	with open(filepath, 'w') as sf:
		json.dump([settings[key] for key in _OPTION_NAMES], sf)


def _pushSegmentChannel(imp, segmentChannel):
	"""Push the segmentation channel of imp to the GPU and allocate the preview work buffers.

	Returns (gfx1, gfx2, gfx3, gfx4, gfx5, gfx7): gfx1 is the pushed channel, gfx2-gfx5 are
	same-sized scratch images and gfx7 is a 2D image for the Z projection. One retry is made
	after freeing GPU memory; if that also fails an error dialog is shown and the error is re-raised.
	"""
	segmentImp=extractChannel(imp, segmentChannel, 0)

	def allocate():
		gfx1=clij2.push(segmentImp)
		gfx2=clij2.create(gfx1)
		gfx3=clij2.create(gfx1)
		gfx4=clij2.create(gfx1)
		gfx5=clij2.create(gfx1)
		gfx7=clij2.create([imp.getWidth(), imp.getHeight()])
		return gfx1, gfx2, gfx3, gfx4, gfx5, gfx7

	# bare excepts kept from the original so Java exceptions raised through Jython are also caught
	try:
		return allocate()
	except:
		try:
			#free any buffers left by the partial first attempt, then retry once
			clij2.clear()
			Thread.sleep(500)
			buffers = allocate()
			print("Succeeded to sending to graphics card on the second time...")
			return buffers
		except:
			errorDialog("""Could not send image to graphics card, it may be too large!
		
			Easy solutions: Try	processing as 8-bit, cropping or scaling the image, or
			select a different CLIJ2 GPU.

			This issue is often intermittent, so trying again may also work! 

			See the "Big Images on x graphics cards' notes at:
			https://clij2.github.io/clij2-docs/troubleshooting for more solutions
			
			"""	+ str(clij2.reportMemory()) )
			# stop here with the real error instead of failing later on undefined buffers
			raise


def _renderPreview(buffers, settings, pixelAspect, originalTitle):
	"""Segment the GPU-resident channel with the current settings and show a max-Z label preview.

	Returns (buffers, labelPrevImp). segment() returns its buffers in a different order, so
	callers must keep using the returned tuple.
	"""
	gfx1, gfx2, gfx3, gfx4, gfx5, gfx7 = buffers
	gfx1, gfx2, gfx3, gfx4, gfx5 = segment(gfx1, gfx2, gfx3, gfx4, gfx5,
		settings["gaussianSigma"], settings["thresholdMethod"], settings["maxIntensity"],
		settings["largeDoGSigma"], pixelAspect, originalTitle, settings["DoG"],
		settings["manualSegment"], settings["manualThreshold"], settings["dilation"],
		settings["sizeExclude"], settings["minSize"], settings["maxSize"], settings["watershed"])
	clij2.maximumZProjection(gfx5, gfx7)
	labelPrevImp = clij2.pull(gfx7)
	IJ.setMinAndMax(labelPrevImp, 0, clij2.getMaximumOfAllPixels(gfx7))
	labelPrevImp.setTitle("Preview segmentation")
	labelPrevImp.show()
	IJ.run("glasbey_inverted")
	return (gfx1, gfx2, gfx3, gfx4, gfx5, gfx7), labelPrevImp


def previewDialog(imp, options):
	"""Generates the settings dialog and preview window, which live updates dependent on chosen settings.

	imp -- the image to analyse.
	options -- sequence of 20 default settings in _OPTION_NAMES order (e.g. read from a settings .json file).

	The dialog is non-modal and polled every 200 ms. Whenever a setting in _PREVIEW_KEYS changes,
	the segmentation channel is re-segmented on the GPU and the "Preview segmentation" max-Z
	projection is refreshed. Quickload/Load buttons rebuild the dialog with the new values.
	Save writes the current values to the chosen .json file.

	Returns the chosen settings as a tuple in _OPTION_NAMES order (backsubVal: 0 off,
	1 local label based, 2 global mean). Cancelling clears GPU memory and calls IJ.exit().
	"""
	settings = dict(zip(_OPTION_NAMES, options))
	gd, buttonListener = _buildSettingsDialog(imp, settings)

	cal = imp.getCalibration()
	pixelAspect = (cal.pixelDepth/cal.pixelWidth)
	originalTitle = imp.getTitle()

	settings = _readSettingsDialog(gd)
	clij2.clear()
	buffers = _pushSegmentChannel(imp, settings["segmentChannel"])
	buffers, labelPrevImp = _renderPreview(buffers, settings, pixelAspect, originalTitle)
	# independent copy: settings may be mutated in place below (quickload, size fix-ups)
	previous = dict(settings)
	buttonListener.bypass = 0

	while (not gd.wasCanceled()) and (not gd.wasOKed()):
		#IF A PRESET OR SETTINGS FILE IS CHOSEN (OR SETTINGS ARE SAVED), REMAKE THE DIALOG TO APPLY THEM
		if buttonListener.bypass == 1:
			if buttonListener.quickload == 1:
				for key in _QUICKLOAD_KEYS:
					settings[key] = getattr(buttonListener, key)
			loaded = getattr(buttonListener, "options", None)
			if buttonListener.load == 1 and loaded:
				# shorter (older) files only override the entries they contain
				settings.update(zip(_OPTION_NAMES, loaded))
			gd.dispose()
			saveFilepath = getattr(buttonListener, "saveFilepath", None)
			if buttonListener.save == 1 and saveFilepath:
				_saveSettingsFile(saveFilepath, settings)
			gd, buttonListener = _buildSettingsDialog(imp, settings)

		settings = _readSettingsDialog(gd)

		#keep minimum ROI size below maximum, adjusting whichever slider the user did not just move
		if settings["minSize"] != previous["minSize"] and settings["minSize"] >= settings["maxSize"]:
			settings["maxSize"] = settings["minSize"] + 1
			gd.getSliders().get(_MAX_SIZE_SLIDER).setValue(settings["maxSize"])
		if settings["maxSize"] != previous["maxSize"] and settings["minSize"] >= settings["maxSize"]:
			settings["minSize"] = settings["maxSize"] - 1
			gd.getSliders().get(_MIN_SIZE_SLIDER).setValue(settings["minSize"])

		if any(settings[key] != previous[key] for key in _PREVIEW_KEYS):
			if settings["segmentChannel"] != previous["segmentChannel"]:
				clij2.clear()
				buffers = _pushSegmentChannel(imp, settings["segmentChannel"])
			labelPrevImp.close()
			buffers, labelPrevImp = _renderPreview(buffers, settings, pixelAspect, originalTitle)

		previous = dict(settings)
		buttonListener.bypass = 0
		buttonListener.quickload = 0
		buttonListener.save = 0
		buttonListener.load = 0
		Thread.sleep(200)

	labelPrevImp.close()
	if gd.wasCanceled():
		clij2.clear()
		IJ.exit()

	#final read so options changed just before OK are not lost
	settings = _readSettingsDialog(gd)
	return tuple(settings[key] for key in _OPTION_NAMES)
	
def segment(gfx1,gfx2,gfx3,gfx4,gfx5, gaussianSigma, thresholdMethod, maxIntensity, largeDoGSigma, pixelAspect, originalTitle, DoG,  manualSegment, manualThreshold, dilation, sizeExclude, minSize, maxSize, watershed):
	"""Segmentation based on user settings.

	Steps:
	1. Filter gfx1 with a Gaussian or difference of Gaussians.
	2. Threshold it with a manual value or an automatic method.
	3. Optionally watershed-split the mask, recovering objects that the watershed removed entirely.
	4. Label the objects.
	5. Optionally drop labels outside [minSize, maxSize], then optionally dilate labels by `dilation`.
	6. Close gaps in the label indices.

	gfx1 -- input image on the GPU (read, not modified)
	gfx2..gfx5 -- same-sized scratch images, all overwritten
	pixelAspect -- z/xy voxel aspect (previewDialog passes pixelDepth/pixelWidth); divides the z sigmas
	maxIntensity, originalTitle -- accepted but not used in this function

	Returns (gfx1, gfx2, gfx3, gfx5, gfx4). NOTE: the last two are swapped relative to the
	arguments. When the caller unpacks the result into gfx1..gfx5, its gfx5 is the final label
	image and its gfx4 is scratch. gfx3 holds the binary threshold mask.

	NOTE(review): the Gaussian z sigma 1+(gaussianSigma-1)/pixelAspect can become <= 0 for
	gaussianSigma < 1 with a small pixelAspect -- confirm this range is never used.
	"""
	# DoG filter for background normalised segmentation. NB. Kernel is Z-normalised to pixel aspect ratio
	if DoG == True :	
		clij2.differenceOfGaussian3D(gfx1, gfx2, gaussianSigma, gaussianSigma, 1+(gaussianSigma-1)/pixelAspect, largeDoGSigma, largeDoGSigma,largeDoGSigma/pixelAspect)
	else:
		clij2.gaussianBlur3D(gfx1, gfx2, gaussianSigma,gaussianSigma, 1+(gaussianSigma-1)/pixelAspect)

	# gfx3 = binary threshold mask of the filtered image (kept unchanged until the return)
	if manualSegment == True :
		clij2.threshold(gfx2, gfx3, manualThreshold)
	else:
		#auto threshold and watershed to seed the object splitting
		clij2.automaticThreshold(gfx2, gfx3, thresholdMethod)

	# gfx2 = watershed-split mask (or an unsplit copy of the mask)
	if watershed:
		clij2.watershed(gfx3,gfx2)
	else:
		clij2.copy(gfx3,gfx2)
	
	# add watershed to original threshold, and then use this to generate a binary image of any ROI lost in watershed process
	# (gfx5 ends up 1 only where thresholded pixels are not connected to any watershed-kept region)
	clij2.addImages(gfx2, gfx3, gfx5)
	clij2.floodFillDiamond(gfx5, gfx4, 1, 2)
	clij2.replaceIntensity(gfx4, gfx5, 2, 0)
	
	#label watershed image
	clij2.connectedComponentsLabelingDiamond(gfx2, gfx4)

	#dilate all the labeled watershed ROI out (only onto zero labeled pixels), then multiply this by original binary map, to get labeled ROI
	clij2.dilateLabels(gfx4, gfx2, 6)

	clij2.multiplyImages(gfx2,gfx3, gfx4)
	
	#label the missed ROI then add on the largest value from the other labelled image (so they can be combined)
	watershedLabelMax =clij2.getMaximumOfAllPixels(gfx4)
	clij2.connectedComponentsLabelingDiamond(gfx5, gfx2)
	clij2.addImageAndScalar(gfx2, gfx5, (1 + watershedLabelMax))
	
	#delete the background and combine the two images
	clij2.replaceIntensity(gfx5, gfx2,(1 + watershedLabelMax), 0)
	clij2.maximumImages(gfx4, gfx2, gfx5)
	
	#remove labeled objects that are too big or too small
	if sizeExclude:
		clij2.excludeLabelsOutsideSizeRange(gfx5, gfx4, minSize, maxSize)
		clij2.copy(gfx4, gfx5)
		
	#dilate images
	if dilation:
		clij2.dilateLabels(gfx5, gfx4, dilation)
		clij2.copy(gfx4, gfx5)
	#gfx3 = threshold channel, gfx5 = label image, gfx1=original image, gfx2 & gfx4  are junk
	# final labels go into gfx4, which is why gfx4 and gfx5 are swapped in the return below
	clij2.closeIndexGapsInLabelMap(gfx5,gfx4)
	return gfx1,gfx2,gfx3,gfx5, gfx4

def fretCalculations(imp1, nFrame, donorChannel, acceptorChannel, acceptorChannel2, table, gfx1, gfx2, gfx3, gfx4, gfx5, originalTitle, backSub, pixelByPixel):
	"""Perform FRET calculations (Sorry! Complex function!)

	For one frame, measures donor, FRET-acceptor and acceptor emission per label. It appends one
	results row per label whose donor and FRET-acceptor sums are both > 0, and builds an
	emission-ratio (FRET acceptor / donor) image.

	imp1 -- source image; channels are taken with extractChannel(imp1, channel, nFrame)
	nFrame -- frame to analyse (also written to the "Frame (Time)" column)
	donorChannel, acceptorChannel, acceptorChannel2 -- channel numbers of the donor (ratio
		denominator), FRET acceptor (ratio numerator) and acceptor channels
	table -- ResultsTable that rows are appended to; shown as "Results of " + originalTitle
	gfx1 -- GPU label image (row/label 0 of the statistics is treated as background)
	gfx2 -- GPU scratch image (overwritten)
	gfx3 -- GPU binary mask from which saturated pixels are removed; presumably the
		segmentation threshold mask -- TODO confirm against the caller
	gfx4, gfx5 -- immediately rebound to freshly pushed donor / FRET-acceptor images, so the
		buffers passed in are not used (they are freed by the final clij2.clear())
	originalTitle -- written to the "File name" column and used in the table title
	backSub -- 0 off, 1 local label-based, 2 global mean background subtraction
	pixelByPixel -- falsy: one ratio per label, painted onto the labels;
		truthy: per-pixel ratio of the Gaussian-blurred, label-masked FRET acceptor and donor

	Returns (table, FRETimp2, FRETProjImp, labelImp): the results table, the ratio stack, its
	average-intensity Z projection, and the label image pulled from gfx1.

	NOTE(review): reads the module-level globals `maxIntensity` (saturation threshold) and `cal`
	(calibration for Volume and x/y/z) rather than taking them as parameters; the calling script
	must define both first.
	NOTE: ends with clij2.clear(), which frees all CLIJ2 GPU memory.
	"""
	
	#Extract appropriate channels
	donorImp=extractChannel(imp1, donorChannel, nFrame)
	acceptorImp=extractChannel(imp1, acceptorChannel, nFrame)
	acceptorImp2=extractChannel(imp1, acceptorChannel2, nFrame)
	
	#push donor and acceptor channels to gpu threshold
	# from here on: gfx4 = donor, gfx5 = FRET acceptor, gfx6 = scratch
	gfx4=clij2.push(donorImp)
	gfx5=clij2.push(acceptorImp)
	gfx6=clij2.create(gfx5)

	#thresholds to create a mask to remove saturated pixels from donor and acceptor images	
	# gfx2 flags saturated donor, then saturated acceptor pixels (per clij2.threshold with maxIntensity);
	# they are subtracted from gfx3 in turn, and gfx6 ends up as the binary "unsaturated mask"
	clij2.threshold(gfx4,gfx2, maxIntensity)
	clij2.binarySubtract(gfx3, gfx2, gfx6)
	clij2.threshold(gfx5,gfx2, maxIntensity)
	clij2.binarySubtract(gfx6, gfx2, gfx3)
	clij2.threshold(gfx3,gfx6, 0.5)
	
	#Apply appropriate background subtraction if required (must be after the saturated pixel removal). NB. Local label uses more GPU memory, which may be a limiting factor on some computers
	if backSub==1:
		# gfx3 is reused for the surrounding-region labels; gfx7 is an extra scratch image
		gfx7=clij2.create(gfx1)
		gfx3=createSubtractionLabels(gfx1, gfx2, gfx3)
		gfx4=localLabelBackSub(gfx3, gfx1, gfx4, gfx2, gfx7)
		gfx5=localLabelBackSub(gfx3, gfx1, gfx5, gfx2, gfx7)
	if backSub==2:
		gfx4=globalBackSub(gfx1, gfx4, gfx2)
		gfx5=globalBackSub(gfx1, gfx5, gfx2)
	
	
	#Mask to remove the saturated pixels
	#donor is gfx2, acceptor FRET is gfx4, segment channel (acceptor normal) is gfx6, gfx3 is triple threshold
	# (after these two lines; gfx6 only becomes the acceptor image below, and gfx3 holds the
	# surrounding-region labels instead when backSub==1)
	clij2.mask(gfx4,gfx6, gfx2)
	clij2.mask(gfx5,gfx6, gfx4)
			
	#NB have to push the acceptor image now...
	# gfx5 (unmasked FRET acceptor) is free from here and serves as scratch
	gfx6=clij2.push(acceptorImp2)
	if backSub==1:
		gfx6=localLabelBackSub(gfx3, gfx1, gfx6, gfx5, gfx7)
		gfx7.close()
	if backSub==2:
		gfx6=globalBackSub(gfx1, gfx6, gfx5)

	#extract the intensity of each nucleus for each channel
	# column 13 of the CLIJ2 statistics -- presumably SUM_INTENSITY (it feeds the "Sum ..." columns below)
	results=ResultsTable()
	clij2.statisticsOfBackgroundAndLabelledPixels(gfx2, gfx1, results)
	donorChIntensity=results.getColumn(13)
	results2=ResultsTable()
	clij2.statisticsOfBackgroundAndLabelledPixels(gfx4, gfx1, results2)
	acceptorChIntensity=results2.getColumn(13)
	results3=ResultsTable()
	clij2.statisticsOfBackgroundAndLabelledPixels(gfx6, gfx1, results3)
	
	#calculate the fret ratios, removing any ROI with intensity of zero
	# NOTE(review): row 0 is label 0 (background); if its masked sums are > 0 it is also
	# written to the table as "Label" 0 -- confirm this cannot happen or filter it
	FRET =[]
	
	for i in xrange(len(acceptorChIntensity)):
		if (acceptorChIntensity[i]>0) and (donorChIntensity[i]>0):
			#don't write in the zeros to the results
			FRET.append((float(acceptorChIntensity[i])/float(donorChIntensity[i])))
			table.incrementCounter()
			#frame, label and ER
			table.addValue("Frame (Time)", nFrame)
			table.addValue("Label", i)
			table.addValue("Emission ratio", acceptorChIntensity[i]/donorChIntensity[i])

			#mean emission
			table.addValue("Mean donor emission", results.getValue("MEAN_INTENSITY",i))
			table.addValue("Mean acceptor emission (FRET)", results2.getValue("MEAN_INTENSITY",i))
			table.addValue("Mean acceptor emission", results3.getValue("MEAN_INTENSITY",i))
			
			#sum emission
			table.addValue("Sum donor emission", donorChIntensity[i])
			table.addValue("Sum acceptor emission (FRET)", acceptorChIntensity[i])
			table.addValue("Sum acceptor emission", results3.getValue("SUM_INTENSITY",i))
			#shape and location descriptors
			# `cal` is a module-level global (see docstring)
			table.addValue("Volume", cal.pixelWidth * cal.pixelHeight * cal.pixelDepth * results.getValue("PIXEL_COUNT",i))
			table.addValue("Pixel count", results.getValue("PIXEL_COUNT",i))
			table.addValue("x", cal.pixelWidth*results.getValue("CENTROID_X",i))
			table.addValue("y", cal.pixelHeight*results.getValue("CENTROID_Y",i))
			table.addValue("z", cal.pixelDepth*results.getValue("CENTROID_Z",i))
			
			#File name for traceability 
			table.addValue("File name", originalTitle)
		else:
			#must write in the zeros as this array is used to generate the map of emission ratios
			FRET.append(0)
			
	table.show("Results of " + originalTitle)
	#export label image
	labelImp = clij2.pull(gfx1)
	if pixelByPixel==0:
		#write all the emission ratios to an array, push to an GFX image, use this to map emission ratios
		# label 0 (background) is mapped to NaN so it stays blank in the ratio image
		FRET[0]= float('nan')
		#print FRET
		FRETarray= array( "f", FRET)
		fp= FloatProcessor(len(FRET), 1, FRETarray, None)
		FRETImp= ImagePlus("FRETImp", fp)
		gfx4=clij2.push(FRETImp)
		gfx5.close()
		gfx5=clij2.create(gfx1.getDimensions(), clij2.Float)
		clij2.replaceIntensities(gfx1, gfx4, gfx5)

		FRETimp2=clij2.pull(gfx5)
		
		
		project = ZProjector()
		project.setMethod(ZProjector.AVG_METHOD)
		project.setImage(FRETimp2) #imageplus
		project.doProjection()
		FRETProjImp = project.getProjection()
		
	else:
		
		#donor is gfx2, acceptor FRET is gfx4, segment channel (acceptor normal) is gfx6, threshold image:gfx3
		#blur the acceptor and donor channels, then remask to prevent NaN/infinite values after division
		clij2.gaussianBlur3D(gfx4, gfx5, 1.1, 1.1, 1.1)
		clij2.mask(gfx5, gfx1, gfx4)
		clij2.gaussianBlur3D(gfx2, gfx5, 1.1, 1.1, 1.1)
		clij2.mask(gfx5, gfx1, gfx2)
		
#		#create Z sum projected donor and acceptor images for Z-proj ratio calc -> may replace with a different technique later
#		donorSum=clij2.create(gfx4.getWidth(), gfx4.getHeight(), 1)
#		acceptorFSum=clij2.create(gfx4.getWidth(), gfx4.getHeight(), 1)
#		clij2.sumZProjection(gfx2, donorSum)
#		clij2.sumZProjection(gfx4, acceptorFSum)
#		
#		#Divide Z proj Acceptor by Z proj Donor and pull image
#		maxProj=clij2.create(gfx4.getWidth(), gfx4.getHeight(), 1)
#		clij2.divideImages(acceptorFSum, donorSum, maxProj)
#		clij2.multiplyImageAndScalar(maxProj, donorSum,1000)
#		
#		
#		FRETProjImp=clij2.pull(donorSum)
		
		#pull acceptor and donor stacks to convert to 32 bit
		# NB: donorImp / acceptorImp are rebound here, so the close() calls at the end close these
		# pulled copies, not the images returned by extractChannel above
		acceptorImp=clij2.pull(gfx4)
		donorImp=clij2.pull(gfx2)
		ImageConverter(acceptorImp).convertToGray32()
		ImageConverter(donorImp).convertToGray32()
		
		#clean up GPU memory
		clij2.clear()
		
		#push 32bit images and perform ratio calc
		gfx4=clij2.push(acceptorImp)
		gfx2=clij2.push(donorImp)
		gfx1=clij2.create(gfx2.getDimensions(), clij2.Float)
		clij2.divideImages(gfx4, gfx2, gfx1)
		# NOTE(review): the x1000 result is written to gfx2, but the ratio is pulled from gfx1 below,
		# so this scaling has no effect on the output -- confirm which is intended
		clij2.multiplyImageAndScalar(gfx1, gfx2,1000)
		#pull ratio stack
		FRETimp2=clij2.pull(gfx1)
		project = ZProjector()
		project.setMethod(ZProjector.AVG_METHOD)
		project.setImage(FRETimp2) #imageplus
		project.doProjection()
		FRETProjImp = project.getProjection()

	#clean up
	clij2.clear()
	donorImp.close()
	acceptorImp.close()
	acceptorImp2.close()
	
	return table, FRETimp2, FRETProjImp, labelImp
	
	
def nearestZProject(imp1):
	"""Build a "nearest point" projection of a 32-bit stack.

	For every (x, y), keep the first non-NaN value met along z. The stack is resliced with
	ImageJ's Slicer defaults. Each resliced slice then holds one y row as an x-by-z plane, whose
	columns are scanned from the plane's first row until a non-NaN pixel is found. Positions with
	no non-NaN value stay NaN.

	Returns a new, unshown ImagePlus upscaled 2x with the 'none' interpolation option.
	"""
	reslicedImp = Slicer().reslice(imp1)
	reslicedStack = reslicedImp.getStack()
	width = imp1.getWidth()
	height = imp1.getHeight()
	# use the resliced plane's own geometry: Slicer may rescale z, so its height need not equal imp1's slice count
	planeWidth = reslicedImp.getWidth()
	zExtent = reslicedImp.getHeight()
	nRows = min(reslicedStack.getSize(), height)
	nColumns = min(planeWidth, width)

	topPixels = array('f', [float('nan')] * (width * height))

	# resliced slice y+1 (1-based) holds row y; the old loop skipped the last slice and wrote
	# each row one line too low, leaving row 0 empty
	for y in xrange(nRows):
		pixels = reslicedStack.getPixels(y + 1)
		rowOffset = y * width
		for x in xrange(nColumns):
			# walk down column x of the x-by-z plane, including its last row
			for pixel in xrange(x, x + planeWidth * zExtent, planeWidth):
				value = pixels[pixel]
				if not math.isnan(value):
					topPixels[rowOffset + x] = value
					#stop at the first non-NaN pixel -> avoids scanning the rest of the column
					break

	ip2 = FloatProcessor(width, height, topPixels, None)
	imp2 = ImagePlus("Nearest point proj", ip2)
	imp3 = imp2.resize(imp2.getWidth()*2, imp2.getHeight()*2, 'none')
	return imp3



def outline(imp3, originalTitle):
	"""Blank region boundaries of every slice to NaN so regions appear outlined.

	A pixel keeps its value only if it equals its left neighbour and the
	neighbour directly above it (each slice is read as a row-major
	width*height array). Otherwise it is set to NaN. Pixels in the first
	column have no left neighbour, and pixels in the first row have no upper
	neighbour, so those comparisons are skipped rather than wrapping around
	the image. A pixel whose left or upper neighbour is NaN becomes NaN.

	(An alternative CLIJ implementation using detectLabelEdges was noted by the
	author as giving thicker lines.)

	imp3 -- ImagePlus whose stack slices are processed independently.
	originalTitle -- used to build the output title.
	Returns the new ImagePlus; it has already been shown.
	"""
	stack1 = imp3.getStack()
	width = imp3.getWidth()
	height = imp3.getHeight()
	nPixels = width * height
	stack2 = ImageStack(width, height)

	for sliceIndex in xrange(1, stack1.getSize() + 1):
		pixels1 = stack1.getPixels(sliceIndex)
		pixels2 = array('f', [float('nan')] * nPixels)
		for j in xrange(nPixels):
			value = pixels1[j]
			# Equality is tested by subtraction, as originally: NaN and infinities
			# never produce 0, so such pixels are always blanked.
			sameAsLeft = (j % width == 0) or (value - pixels1[j - 1] == 0)
			sameAsAbove = (j < width) or (value - pixels1[j - width] == 0)
			if sameAsLeft and sameAsAbove:
				pixels2[j] = value
		stack2.addSlice(FloatProcessor(width, height, pixels2, None))

	imp2 = ImagePlus("Nearest point emission ratios of "+ originalTitle, stack2)
	imp2.show()
	return imp2


# *****************************body of code starts****************************************

if __name__ == "__main__":

	# CLIJ2 does the GPU work below. If it cannot be imported, show install
	# instructions and stop, instead of crashing later with a NameError on CLIJ2.
	clijAvailable = True
	try:
		from net.haesleinhuepf.clij2 import CLIJ2
	except:
		# Deliberately broad: a broken CLIJ install may surface as a Java error
		# rather than ImportError under Jython (assumption - confirm).
		clijAvailable = False
		errorDialog("""This plugin requires clij2 to function. 
		
		To install please follow these instructions: 
		
		1. Click Help>Update> Manage update sites
		2. Make sure the "clij2" update site is selected.
		3. Click Close> Apply changes.
		4. Close and reopen ImageJ""")

	if clijAvailable:
		clij2 = CLIJ2.getInstance()
		clij2.clear()

		#get the current image
		imp1 = IJ.getImage()

		# Reuse the last saved dialog settings if present (written further down
		# when saveSettings is ticked), otherwise start from the defaults.
		settingsPath = 'FRETENATOR2SegSettings.json'
		if exists(settingsPath):
			with open(settingsPath, 'rb') as sf:
				options = json.load(sf)
		else:
			options = (3, 1, 2, 3, 'Otsu', 65534, 0.8, 4.0, True, False, 3000, True, 0, False, 10, 10000, True, 0, False, True)

		options = previewDialog(imp1, options)

		#get the pixel aspect for use in zscaling kernels for filters
		cal = imp1.getCalibration()
		pixelAspect = (cal.pixelDepth/cal.pixelWidth)
		originalTitle = imp1.getTitle()

		# header lists all 20 option names, in the same order as the unpacking below
		IJ.log(originalTitle +" settings:")
		IJ.log("segmentChannel, donorChannel, acceptorChannel, acceptorChannel2, thresholdMethod, maxIntensity, gaussianSigma, largeDoGSigma, DoG, manualSegment, manualThreshold, makeNearProj, dilation, sizeExclude, minSize, maxSize, watershed, backSub, pixelByPixel, saveSettings:")
		IJ.log(str(options))

		segmentChannel, donorChannel, acceptorChannel, acceptorChannel2, thresholdMethod, maxIntensity, gaussianSigma, largeDoGSigma, DoG, manualSegment, manualThreshold, makeNearProj, dilation, sizeExclude, minSize, maxSize, watershed, backSub, pixelByPixel, saveSettings = options
		if saveSettings:
			with open(settingsPath, 'wb') as sf:
				json.dump(options, sf)
		# the nearest-point projection is switched off in pixel-by-pixel mode
		if pixelByPixel:
			makeNearProj = 0
		totalFrames = imp1.getNFrames() + 1

		#table is the final results table
		table = ResultsTable()

		conThresholdStack = ImageStack(imp1.width, imp1.height)
		conFRETImp2Stack = ImageStack(imp1.width, imp1.height)
		conFRETProjImpStack = ImageStack(imp1.width, imp1.height)
		conlabelImpStack = ImageStack(imp1.width, imp1.height)
		# nearestZProject returns images upscaled x2 in each dimension
		conNearZStack = ImageStack(imp1.width*2, imp1.height*2)
		for nFrame in xrange(1, totalFrames):
			# free GPU buffers from the previous frame (or from the preview dialog)
			clij2.clear()
			segmentImp = extractChannel(imp1, segmentChannel, nFrame)
			gfx1 = clij2.push(segmentImp)
			gfx2 = clij2.create(gfx1)
			gfx3 = clij2.create(gfx1)
			gfx4 = clij2.create(gfx1)
			gfx5 = clij2.create(gfx1)
			gfx1, gfx2, gfx3, gfx4, gfx5 = segment(gfx1, gfx2, gfx3, gfx4, gfx5, gaussianSigma, thresholdMethod, maxIntensity, largeDoGSigma, pixelAspect, originalTitle, DoG, manualSegment, manualThreshold, dilation, sizeExclude, minSize, maxSize, watershed)

			thresholdImp = clij2.pull(gfx3)
			IJ.setMinAndMax(thresholdImp, 0,1)
			thresholdImp.setCalibration(cal)
			thresholdImp.setTitle("Binary mask of "+originalTitle)

			table, FRETimp2, FRETProjImp, labelImp = fretCalculations(imp1, nFrame, donorChannel, acceptorChannel, acceptorChannel2, table, gfx5, gfx2, gfx3, gfx4, gfx1, originalTitle, backSub, pixelByPixel)

			if makeNearProj:
				nearZImp = nearestZProject(FRETimp2)
				conNearZStack = concatStacks(conNearZStack, nearZImp)
				nearZImp.close()

			#add the images to concatenated stacks
			conThresholdStack = concatStacks(conThresholdStack, thresholdImp)
			conFRETImp2Stack = concatStacks(conFRETImp2Stack, FRETimp2)
			conFRETProjImpStack = concatStacks(conFRETProjImpStack, FRETProjImp)
			conlabelImpStack = concatStacks(conlabelImpStack, labelImp)

			thresholdImp.close()
			FRETimp2.close()
			FRETProjImp.close()
			labelImp.close()

		# Build the output hyperstacks; display ranges are set after the
		# CompositeImage conversion for the images that are shown.

		# binary mask (assembled but not currently shown)
		conThresholdImp = ImagePlus( "Threshold image for "+ originalTitle, conThresholdStack)
		conThresholdImp.setDimensions(1, imp1.getNSlices(), imp1.getNFrames())
		IJ.setMinAndMax(conThresholdImp, 0,1)
		conThresholdImp.setCalibration(cal)
		conThresholdImp = CompositeImage(conThresholdImp, CompositeImage.COMPOSITE)
		#conThresholdImp.show()

		conlabelImp = ImagePlus("Label map "+ originalTitle, conlabelImpStack)
		conlabelImp.setDimensions(1, imp1.getNSlices(), imp1.getNFrames())
		conlabelImp.setCalibration(cal)
		# display range spans all labels across the whole stack
		stats = StackStatistics(conlabelImp)
		conlabelImp = CompositeImage(conlabelImp, CompositeImage.COMPOSITE)
		IJ.setMinAndMax(conlabelImp, 0,stats.max)
		conlabelImp.show()
		IJ.run("glasbey_inverted")

		conFRETImp2 = ImagePlus( "Emission ratios of "+ originalTitle, conFRETImp2Stack)
		conFRETImp2.setDimensions(1, imp1.getNSlices(), imp1.getNFrames())
		conFRETImp2.setCalibration(cal)
		conFRETImp2 = CompositeImage(conFRETImp2, CompositeImage.COMPOSITE)
		IJ.setMinAndMax(conFRETImp2, 0, 5)
		conFRETImp2.show()
		IJ.run("mpl-inferno")

		conFRETProjImp = ImagePlus( "Mean Z  projection of emission ratios X1000 of "+ originalTitle, conFRETProjImpStack)
		conFRETProjImp.setDimensions(1, 1, imp1.getNFrames())
		conFRETProjImp.setCalibration(cal)
		conFRETProjImp = CompositeImage(conFRETProjImp, CompositeImage.COMPOSITE)
		IJ.setMinAndMax(conFRETProjImp, 0, 5)
		conFRETProjImp.show()
		IJ.run("mpl-inferno")

		if makeNearProj:
			conNearZImp = ImagePlus("Nearest Z proj of  ratios of"+ originalTitle, conNearZStack)
			nearZImpOutlines = outline(conNearZImp, originalTitle)
			IJ.setMinAndMax(nearZImpOutlines, 1, 5)
			nearZImpOutlines.show()
			IJ.run("mpl-inferno")

		clij2.clear()
