import matplotlib.pyplot as plt
import numpy as np
import os
import re
from lmfit import Model, Parameters
from lmfit.models import ConstantModel

data_folder="g2X_merged_data"
data_file_list = os.listdir(data_folder) #np.flip(np.asarray(os.listdir(data_folder)))
plot_folder="cross_g2_plots"
fit_folder="cross_g2_fit"
report_folder="cross_g2_report"
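
# The three output folders are assumed to already exist; creating them here is
# a small defensive addition (os.makedirs is a no-op when the folder exists).
for folder in (plot_folder, fit_folder, report_folder):
	os.makedirs(folder, exist_ok=True)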

def constant(c):
	"""Constant baseline model (used by the commented-out constant_model fit below)."""
	return c

def two_exponential_g2_model(tau, c, A_0, tau_c, tau_0, tau_delay, alpha):
	"""
	tau: time bin vector
	c: floor offset
	A_0: decay amplitude
	tau_c: decay time constant
	tau_0: timing offset
	tau_delay: relative delay producing the side dips at tau_0 -/+ tau_delay
	alpha: splitting ratio (the side-dip weight is alpha*(1-alpha))
	"""
	# tau_S = 6e-8
	# tau_0 = 241.5*2e-9
	# tau_c = tau_fwhm/(2*np.log(2))
	# Three exponential features: side terms at tau_0 -/+ tau_delay with weight
	# 0.5*A_0*alpha*(1-alpha), and a central term at tau_0 with weight
	# 0.5*A_0*(2*alpha*(1-alpha) - 1).
	return c*(1
		+ 0.5*A_0*alpha*(1-alpha)*(np.exp(-2*np.abs(tau-tau_0+tau_delay)/tau_c)
			+ np.exp(-2*np.abs(tau-tau_0-tau_delay)/tau_c))
		+ 0.5*A_0*(2*alpha*(1-alpha) - 1)*np.exp(-2*np.abs(tau-tau_0)/tau_c))

def one_exponential_g2_model_symmetric_non_piecewise(tau, c, A_0, tau_c_0, tau_0):
	"""
	tau: time bin vector
	c: floor offset
	A_0: decay amplitude (negative for an antibunching dip)
	tau_c_0: decay time constant
	tau_0: timing offset
	"""
	# tau_S = 6e-8
	# tau_0 = 241.5*2e-9
	return c + A_0*np.exp(-2*np.abs(tau-tau_0)/tau_c_0)
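
# Quick sanity check of the model shape (illustrative values, not fitted ones):
# at tau == tau_0 the exponential equals 1, so the model returns c + A_0.
assert np.isclose(one_exponential_g2_model_symmetric_non_piecewise(0.0, 1.0, -0.5, 10.0, 0.0), 0.5)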

alpha_list = []
data_file_dic = {}
for filename in data_file_list:
	# Parse alpha from the last three digits before the extension,
	# e.g. "..._500.dat" -> alpha = 0.500
	alpha = int(filename[-7:-4])/1000
	alpha_list.append(alpha)
	data_file_dic[alpha] = filename

alpha_list.sort()
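
# The slice above assumes filenames end in "<three digits>.<three-char ext>",
# e.g. "..._500.dat"; a regex-based parse (sketch only, not used here) would be:
# m = re.search(r'(\d{3})\.[^.]+$', filename)
# alpha = int(m.group(1))/1000 if m else None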

fitted_alpha_value_list = []
fitted_alpha_err_list = []

for alpha in alpha_list:
	filename = data_file_dic[alpha]
	root_filename = filename[:-4]
	print("Processing: {}".format(filename))

	alpha_guess = int(filename[-7:-4])/1000
	# if alpha_guess != 0.5:
	# 	continue

	# Extract the normalisation factor from the six-line file header
	g2_file = open(os.path.join(data_folder,filename))
	line0 = re.split(' |,',g2_file.readline())
	line1 = re.split(' |,',g2_file.readline())[1:]

	# singles counts on the two detectors
	counts1 = int(line1[2])
	counts2 = int(line1[4])

	line2 = re.split(' |,',g2_file.readline())[1:]
	total_time = int(line2[2])
	time_interval = int(line2[5])
	bin_size = time_interval*0.125 # ns per bin (assuming 0.125 ns hardware time units)

	# For uncorrelated (Poissonian) clicks the expected coincidences per bin are
	# counts1*counts2*time_interval/total_time, so this factor normalises the
	# histogram to g2 = 1 at large delays.
	scaling_factor = total_time/(counts1*counts2*time_interval)

	g2_file.close()

	skip_time = 0 # ns; set to e.g. 30 to skip the detector dead-time region (currently unused)

	# Load the histogram, skipping the header lines
	data_arr = np.genfromtxt(os.path.join(data_folder,filename),skip_header=6).T
	time_vector = data_arr[0]*bin_size # convert bin index to time in ns for fitting
	print(time_vector)

	# MODELS USED IN FIT
	# g2_model = Model(three_peak_single_decay_g2_model_symmetric,nan_policy="propagate")
	one_peak_model = Model(one_exponential_g2_model_symmetric_non_piecewise,nan_policy="propagate")
	# constant_model = Model(constant,nan_policy="propagate")



	# Locate the antibunching dip (the histogram minimum)
	dip_index = np.argmin(data_arr[1])
	time_dip = time_vector[dip_index]
	g2_dip = data_arr[1,dip_index]

	c_guess = 1/scaling_factor # expected baseline in raw coincidence counts
	A_0_guess = c_guess - g2_dip # dip depth (fitted A_0 is its negative, see below)
	tau_0_guess = time_dip

	# Walk right from the dip until the depth has fallen to half A_0_guess;
	# the resulting half width at half depth seeds the tau_c guess
	current_dip_amp = A_0_guess
	tau_c_position = dip_index
	while current_dip_amp > 0.5*A_0_guess and tau_c_position < len(data_arr[1]) - 1:
		tau_c_position += 1
		current_dip_amp = c_guess - data_arr[1,tau_c_position]

	tau_c_0_guess = np.abs(time_vector[tau_c_position]-time_dip) # ns
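
	# Note: for exp(-2|tau|/tau_c) the half-depth point sits at
	# tau_HWHM = tau_c*ln(2)/2, i.e. tau_c = 2*tau_HWHM/ln(2) ~ 2.9*tau_HWHM,
	# so the half width above underestimates tau_c but is a fine starting guess.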

	# start_point = 0#211-3 # bin position
	# end_point = len(data_arr[0])#start_point + 60 #number of bins 
	# time_vector = data_arr[0][start_point:end_point]
	# data_vector = data_arr[1][start_point:end_point]
	# # print(time_vector)

	# To help the fit converge, all times are kept in ns
	fit_params = Parameters()
	fit_params.add('tau_c_0', value=tau_c_0_guess)#, max=5e8, min=0)
	fit_params.add('tau_0', value=tau_0_guess)
	# fit_params.add('tau_delay', value=827)
	# fit_params.add('tau_0', value=241*2e-9, max=243*2e-9, min=239*2e-9)
	fit_params.add('c',value=c_guess)
	fit_params.add('A_0',value = -1*A_0_guess) # negative amplitude: the feature is a dip
	# fit_params.add('alpha',value = alpha_guess)

	data_vector = data_arr[1]
	# Poisson weighting 1/sqrt(N); guard against empty bins to avoid division by zero
	weights_vector = 1/np.sqrt(np.maximum(data_vector,1))

	# Two-exponential (side-peak) model fit, kept for reference:
	# result = g2_model.fit(data_arr[1],tau=data_arr[0]*bin_size, c= c_guess, A_0 = A_0_guess,reci_tau_c_0 = reci_tau_c_0_guess,A_S = A_S_guess, reci_tau_c_S = reci_tau_c_S_guess,weights=1/np.sqrt(data_arr[1]))
	peak_result = one_peak_model.fit(data_vector,fit_params,tau=time_vector, weights=weights_vector,nan_policy="propagate")
	# constant_result = ConstantModel().fit(data_vector,x=time_vector*bin_size, c= c_guess,nan_policy="propagate",weights=1/np.sqrt(data_vector))

	print(peak_result.fit_report())
	# print(constant_result.fit_report())
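
	# Basic convergence check (an addition, not in the original workflow):
	# lmfit exposes success and reduced chi-square on the result object.
	if not peak_result.success:
		print("WARNING: fit did not converge for {} (redchi = {:.3g})".format(filename, peak_result.redchi))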



	plt.clf()
	plt.rcParams['font.size'] = '22'
	fig, axs = plt.subplots(figsize=(11.7,8.3),constrained_layout=True)


	axs.set_ylim([0.5,1.5])
	axs.set_ylabel(r"$g^{(2)}$")	
	axs.set_xlabel("time [ns]")
	# axs.set_ylim([0.95,2])#1.25])

	# scaling_factor = 1#scaling_factor#1/peak_result.params['c'].value
	# plot raw data converted to g2 units, centred on the fitted dip position

	axs.errorbar(time_vector-peak_result.params['tau_0'].value,data_vector*scaling_factor,yerr=np.sqrt(data_arr[1])*scaling_factor,fmt='r.')
	# plot fit
	axs.plot(time_vector-peak_result.params['tau_0'].value,peak_result.best_fit*scaling_factor,"k--")

	# plot coherence time
	peak_time = 0#peak_result.params['tau_0'].value
	arrow_length = peak_result.params['tau_c_0'].value/2
	arrow_g2_value = (peak_result.params['A_0'].value)/2 + peak_result.params['c'].value

	# axs.arrow(peak_time, arrow_g2_value,-arrow_length,0,length_includes_head=True,linestyle='--',color='blue')
	# axs.arrow(peak_time, arrow_g2_value,+arrow_length,0,length_includes_head=True,linestyle='--',color='blue')

	modified_time_vector = time_vector-peak_result.params['tau_0'].value
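
	# Alternative sketch for marking the coherence time with a double-headed
	# annotation (hypothetical; the axs.arrow calls above are kept disabled):
	# axs.annotate("", xy=(peak_time - arrow_length, arrow_g2_value),
	# 	xytext=(peak_time + arrow_length, arrow_g2_value),
	# 	arrowprops=dict(arrowstyle="<->", color="blue", linestyle="--"))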

	# np.savetxt(os.path.join(fit_folder,root_filename+"_fitted.dat"),np.asarray([modified_time_vector,data_vector*scaling_factor,np.sqrt(data_vector)*scaling_factor,peak_result.best_fit*scaling_factor]),header="time[ns],raw g(2),error bar, fit value")


	# secondary y-axis showing the same data in raw coincidence counts
	axs2 = axs.twinx()
	axs2.set_ylim(np.asarray(axs.get_ylim())/scaling_factor) # map the g2 limits back to counts
	axs2.set_ylabel("coincidences")

	# plt.xlim([-1,1])
	# plt.show()
	plt.savefig(os.path.join(plot_folder,root_filename+".png"),format="png",dpi=600)
	plt.close(fig) # free the figure; a new one is created for every file

	# Evaluate the fitted model on a dense, centred time axis for export
	fit_time_vector = np.linspace(-500,500,10000) # ns
	fit_coincidence_vector = one_exponential_g2_model_symmetric_non_piecewise(fit_time_vector, peak_result.params['c'].value, peak_result.params['A_0'].value, peak_result.params['tau_c_0'].value,0)
	fit_g2_vector = fit_coincidence_vector*scaling_factor # the fit is in counts; scaling converts to g2

	np.savetxt(os.path.join(fit_folder,root_filename+"_fitted.dat"),np.asarray([fit_time_vector,fit_coincidence_vector,fit_g2_vector]).T,header="time[ns], coincidences, g(2)")
	# Write the fit report as plain text instead of abusing np.savetxt's header argument
	with open(os.path.join(report_folder,root_filename+'_report.txt'),'w') as report_file:
		report_file.write(peak_result.fit_report())

	# fitted_alpha_value_list.append(peak_result.params['alpha'].value)
	# fitted_alpha_err_list.append(peak_result.params['alpha'].stderr)

# np.savetxt("alpha_actual_vs_fit.dat",np.asarray([alpha_list,fitted_alpha_value_list,fitted_alpha_err_list]).T,header="#DFB 780nm laser\n# set alpha, measured alpha, error")
