import numpy as np
# from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
import seaborn as sns
# from uncertainties import ufloat
from lmfit import Model, Parameters
# from lmfit.models import ConstantModel

# Fixed decay constants read by beat_osc, and the beat angular frequency used
# as the (fixed) 'fr' fit parameter.  Units are presumably those of the delay
# column (ns?) -- TODO confirm.
tau0 = 5.6
tau1 = 13.2
fr = 2*np.pi*0.266


def _counting_errors(counts):
    """Return sqrt(counts) uncertainties with zero entries replaced by 1.

    Negative values (possible in processed data -- TODO confirm) are clipped
    to 0 before the square root, so they also get an error of 1 instead of
    NaN.  This keeps the fit weights ``1/err`` finite everywhere.
    """
    err = np.sqrt(np.clip(counts, 0, None))
    err[err == 0] = 1
    return err


filename1 = 'out_HV_new_4ns.dat'
filename2 = 'pump_HV_out_VH_processed.dat'
filename3 = 'pump_HV_out_LL_processed.dat'

# Row window applied to every file: rows min_idx .. max_idx-1 (Python slice).
min_idx = 888
max_idx = 1144
rawdata1 = np.genfromtxt(filename1, skip_header=2)
rawdata2 = np.genfromtxt(filename2, skip_header=2)
rawdata3 = np.genfromtxt(filename3, skip_header=2)

# NOTE(review): the delay axis (column 0) is taken from file 1 only; this
# assumes all three files share the same delay column -- confirm.
delay = rawdata1[min_idx:max_idx, 0]

counts1 = rawdata1[min_idx:max_idx, 1]
counts1_err = _counting_errors(counts1)

counts2 = rawdata2[min_idx:max_idx, 1]
counts2_err = _counting_errors(counts2)

counts3 = rawdata3[min_idx:max_idx, 1]
counts3_err = _counting_errors(counts3)


def beat_osc(x, g, r, x0, phi, fr, y0, tau_r):
    """Two-component exponential decay with a cosine cross (beat) term.

    With ``t = x - x0`` the model is, for ``t > 0``::

        g * (2 r exp(-t (1/(2 tau0) + 1/(2 tau1))) cos(fr t + phi)
             + exp(-t / tau0) + r**2 exp(-t / tau1))

    and for ``t <= 0`` an exponential rise::

        g * (2 r cos(phi) + r**2 + 1) * exp(t / tau_r)

    whose amplitude equals the decay branch at ``t = 0``, so the curve is
    continuous.  ``y0`` is added everywhere.

    The decay constants ``tau0`` and ``tau1`` are read from module-level
    globals; they are deliberately not arguments, because lmfit turns every
    argument after ``x`` into a fit parameter.

    Parameters
    ----------
    x : array_like
        Independent variable (delay axis).
    g : float
        Overall amplitude.
    r : float
        Relative amplitude of the second component.
    x0 : float
        Time zero.
    phi : float
        Phase of the cosine term, in degrees.
    fr : float
        Angular frequency of the cosine term (per unit of ``x``).
    y0 : float
        Constant offset.
    tau_r : float
        Time constant of the rise for ``t <= 0``.

    Returns
    -------
    Model values with the same shape as ``x``.
    """
    phase = np.deg2rad(phi)
    t = x - x0

    # Clip each branch's exponent argument to the side of t = 0 where that
    # branch is used.  The masked-out branch then stays finite: exp() of a
    # large positive argument would overflow to inf, and inf * 0 is NaN.
    t_decay = np.maximum(t, 0.0)
    t_rise = np.minimum(t, 0.0)

    decay = (t > 0) * g * (
        2*r*np.exp(-t_decay*(1/(2*tau0) + 1/(2*tau1)))*np.cos(fr*t_decay + phase)
        + np.exp(-t_decay/tau0)
        + r**2*np.exp(-t_decay/tau1)
    )
    rise = (t <= 0) * g * (2*r*np.cos(phase) + r**2 + 1) * np.exp(t_rise/tau_r)

    # The rise term was previously computed but left out of the result.
    return decay + rise + y0


# def exp_2(x, g, x0, y0, tau_r, tau_d):
#     x = x - x0
#     decay = (x > 0) * g * np.exp(-x/tau_d)
#     rise = (x <= 0) * g * np.exp(x/tau_r)
#     return decay+rise+y0

gmod = Model(beat_osc)

# (name, initial value, vary flag) for the parameters shared by all fits.
# 'r' and 'phi' are added separately right before each individual fit.
_shared_param_specs = (
    ('tau_r', .631, 0),
    ('x0', 459.2, 0),
    ('fr', fr, 0),
    ('g', 400, 1),
    ('y0', 6.4, 1),
)

pars = Parameters()
for _name, _start, _vary in _shared_param_specs:
    pars.add(_name, value=_start, vary=_vary)

# Dense delay grid spanning the data range, used to draw smooth model curves.
x_lin = np.linspace(delay[0], delay[-1], 3000)
plt.figure(1)
sns.set_style("white")

# x-axis window shown in every subplot.
_XLIM = (450, 510)


def _fit_and_plot(counts, counts_err, subplot, r, phi, vary_r=False):
    """Fit one data set with ``gmod``, print the report and plot it.

    Adds this data set's ``r`` (free only if ``vary_r``) and ``phi`` (always
    fixed) to the shared module-level ``pars``, mutating it in place.  It then
    fits ``counts`` against ``delay`` with weights ``1/counts_err``, prints
    the fit report, and draws the data points plus the model curve into
    ``plt.subplot(subplot)``.

    Returns the lmfit fit result.
    """
    pars.add('r', value=r, vary=vary_r)
    pars.add('phi', value=phi, vary=False)

    fit = gmod.fit(counts, x=delay, params=pars, weights=1/counts_err)
    print(fit.fit_report())

    plt.subplot(subplot)
    plt.errorbar(delay, counts, yerr=counts_err, fmt='o')
    plt.plot(x_lin, fit.eval(x=x_lin), '-r')
    plt.xlim(*_XLIM)
    return fit


# One fit per data set, in order.  Only data set 3 lets 'r' vary.
results = [
    _fit_and_plot(counts1, counts1_err, 311, r=1./35, phi=180),
    _fit_and_plot(counts2, counts2_err, 312, r=1./.7, phi=0),
    _fit_and_plot(counts3, counts3_err, 313, r=1./2, phi=180, vary_r=True),
]
# Keep the module-level name `result` bound to the last fit, as before.
result = results[-1]

# Final styling and display.
sns.despine(trim=True)
plt.show()
