import matplotlib.pyplot as plt
import numpy as np

from uncertainties import unumpy
from lmfit.models import LinearModel
from lmfit import Parameters
# Model

import scipy.odr as odr


def sat_f(p, x):
    """Saturating exponential ``amplitude * (1 - exp(-x / scale))``.

    Written in the ``fcn(beta, x)`` form expected by ``scipy.odr.Model``.

    Parameters
    ----------
    p : sequence of two numbers
        ``(amplitude, scale)``: the value approached for large ``x`` (when
        ``scale > 0``) and the characteristic x-scale of the rise.
    x : numpy.ndarray or float
        Points at which the model is evaluated.

    Returns
    -------
    numpy.ndarray or float
        Model values; exactly 0 at ``x = 0``.
    """
    amplitude, scale = p
    rise = 1 - np.exp(-x / scale)
    return amplitude * rise


# ODR wrapper around the saturating exponential, shared by the fits below.
sat_model = odr.Model(sat_f)

# ---------------------------------------------------------------------------
# Data import and column names
# ---------------------------------------------------------------------------
raw_data = np.genfromtxt('fwm_OD.dat')

# Column indices in fwm_OD.dat (columns 0, 3 and 4 are not used here).
OD, OD_err = raw_data[:, 1], raw_data[:, 2]
pairs, pairs_err = raw_data[:, 5], raw_data[:, 6]
signal, signal_err = raw_data[:, 7], raw_data[:, 8]
idler, idler_err = raw_data[:, 9], raw_data[:, 10]
tau, tau_err = raw_data[:, 11], raw_data[:, 12]

# Efficiencies = pair rate / singles rate; the `uncertainties` package
# propagates the errors of numerator and denominator.
# NOTE(review): pairs and singles enter as independent quantities. If the
# pair counts are a subset of the singles counts they are correlated, which
# these error bars do not account for -- confirm.
eff_s_u = unumpy.uarray(pairs, pairs_err) / unumpy.uarray(signal, signal_err)
eff_i_u = unumpy.uarray(pairs, pairs_err) / unumpy.uarray(idler, idler_err)
eff_s, eff_s_err = unumpy.nominal_values(eff_s_u), unumpy.std_devs(eff_s_u)
eff_i, eff_i_err = unumpy.nominal_values(eff_i_u), unumpy.std_devs(eff_i_u)


# ---------------------------------------------------------------------------
# Plot of the singles
# ---------------------------------------------------------------------------

# Line through the origin, OD = slope * rate: the intercept is held at zero
# and only the slope is fitted.
p_line = Parameters()
p_line.add('slope', 1 / 1e3)
p_line.add('intercept', 0, vary=False)

# Note the roles: OD is the dependent variable of these linear fits and the
# singles rate the independent one. lmfit multiplies the residuals by
# `weights`, so 1 / OD_err gives the usual chi-square weighting.
result_s = LinearModel().fit(OD, p_line, x=signal, weights=1 / OD_err)
# print(result_s.fit_report())

result_i = LinearModel().fit(OD, p_line, x=idler, weights=1 / OD_err)
# print(result_i.fit_report())

# Saturation fit of the signal rate versus OD with errors in both variables.
# ODR weights are inverse variances; RealData converts the standard
# deviations sx/sy into 1 / sigma**2 (passing 1 / sigma directly as wd/we
# would weight each point by its inverse standard deviation instead).
data = odr.RealData(OD, signal, sx=OD_err, sy=signal_err)
odr_fit = odr.ODR(data, sat_model, beta0=[max(signal), 8])
out_s = odr_fit.run()
# print('beta: {}\nResidual variance: {}'.format(out_s.beta, out_s.res_var))

f, ax = plt.subplots()
# The linear fits give OD as a function of rate, so evaluate them on a rate
# grid and plot (OD, rate) to share the axes of the data points.
xes = np.linspace(np.min(signal), np.max(signal), int(1e3))
plt.plot(result_s.eval(x=xes), xes)
xes = np.linspace(np.min(idler), np.max(idler), int(1e3))
plt.plot(result_i.eval(x=xes), xes)

x_vec = np.linspace(0, np.max(OD), int(1e3))
plt.plot(x_vec, sat_f(out_s.beta, x_vec))

plt.errorbar(OD, signal,
             yerr=signal_err,
             xerr=OD_err,
             fmt='o')
plt.errorbar(OD, idler,
             yerr=idler_err,
             xerr=OD_err,
             fmt='o')

# Hard-coded ticks and axis ranges for this data set.
plt.yticks(np.arange(0, 40001, int(1e4)))
plt.ylim(0, 42000)
plt.xlim(0, 36)
plt.xlabel('optical density')
plt.ylabel('rate (1/s)')
plt.savefig("singles_vs_OD.pdf", format="pdf")


# ---------------------------------------------------------------------------
# Plot for the pair rate
# ---------------------------------------------------------------------------
f, ax = plt.subplots()

# Pair (coincidence) rate versus OD; no model is fitted to these points.
plt.errorbar(OD, pairs,
             yerr=pairs_err,
             xerr=OD_err,
             fmt='o')

# Hard-coded ticks and axis ranges; the x axis starts at OD = 4 here.
plt.yticks(np.arange(0, 5000.1, 2500))
plt.ylim(0, 6000)
plt.xlim(4, 36)
plt.xlabel('optical density')
plt.ylabel('coincidences (1/s)')

plt.savefig("pairs_vs_OD.pdf", format="pdf")


# ---------------------------------------------------------------------------
# Plot for the efficiency
# ---------------------------------------------------------------------------
# Independent saturation fits of the signal and idler efficiencies
# (pairs / singles) versus OD, with errors in both variables. ODR weights
# are inverse variances; RealData converts the standard deviations sx/sy
# into 1 / sigma**2 (1 / sigma passed directly as wd/we would be wrong).
data = odr.RealData(OD, eff_s, sx=OD_err, sy=eff_s_err)
odr_fit = odr.ODR(data, sat_model, beta0=[.2, 8])
out = odr_fit.run()
# print('beta: {}\nResidual variance: {}'.format(out.beta, out.res_var))
out.pprint()

data = odr.RealData(OD, eff_i, sx=OD_err, sy=eff_i_err)
odr_fit = odr.ODR(data, sat_model, beta0=[.2, 8])
out_i = odr_fit.run()
# print('beta: {}\nResidual variance: {}'.format(out_i.beta, out_i.res_var))
out_i.pprint()

f, ax = plt.subplots()
x_vec = np.linspace(0, np.max(OD), int(1e3))
plt.plot(x_vec, sat_f(out.beta, x_vec))
plt.plot(x_vec, sat_f(out_i.beta, x_vec))

plt.errorbar(OD, eff_s,
             yerr=eff_s_err,
             xerr=OD_err,
             fmt='o')

plt.errorbar(OD, eff_i,
             yerr=eff_i_err,
             xerr=OD_err,
             fmt='o')

# Hard-coded ticks and axis ranges for this data set.
plt.yticks(np.arange(0, .21, .05))
plt.ylim(0, .21)
plt.xlim(0, 36)
plt.xlabel('optical density')
plt.ylabel('efficiency')

plt.savefig("eff_vs_OD.pdf", format="pdf")


def sat_f2D(p, x):
    """Joint saturation model for the signal and idler efficiencies.

    Both curves share the scale ``b`` but have their own amplitude::

        eff_s(x) = a1 * (1 - exp(-x / b))
        eff_i(x) = a2 * (1 - exp(-x / b))

    Parameters
    ----------
    p : sequence of three numbers
        ``(a1, a2, b)``.
    x : numpy.ndarray
        Optical-density values, shape ``(n,)``.

    Returns
    -------
    numpy.ndarray
        Shape ``(2, n)``: row 0 is the signal model, row 1 the idler model.
        Used as a two-dimensional ODR response, every row/point residual is
        squared separately, so signal and idler deviations of opposite sign
        cannot cancel each other (as they would in a single summed residual).
    """
    a1, a2, b = p
    rise = 1 - np.exp(-x / b)
    return np.vstack((a1 * rise, a2 * rise))


# Joint fit: signal and idler efficiencies form a two-row response that
# shares the OD errors. RealData converts standard deviations into the ODR
# inverse-variance weights; sy of shape (2, n) gives per-row, per-point
# weights, and the response length follows the data instead of a fixed count.
sat_model2D = odr.Model(sat_f2D)
data2D = odr.RealData(OD, np.vstack((eff_s, eff_i)),
                      sx=OD_err, sy=np.vstack((eff_s_err, eff_i_err)))
odr_fit2D = odr.ODR(data2D, sat_model2D, beta0=[.18, .14, 9])
out2D = odr_fit2D.run()
# out2D.pprint()

plt.figure()
# Draw the fitted curves on a smooth, sorted OD grid.
x_vec = np.linspace(0, np.max(OD), int(1e3))
eff_s_fit, eff_i_fit = sat_f2D(out2D.beta, x_vec)
plt.plot(x_vec, eff_s_fit)
plt.plot(x_vec, eff_i_fit)

plt.errorbar(OD, eff_s,
             yerr=eff_s_err,
             xerr=OD_err,
             fmt='o')

plt.errorbar(OD, eff_i,
             yerr=eff_i_err,
             xerr=OD_err,
             fmt='o')

# """
# plot for the tau vs pairs
# """

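# NOTE(review): the commented-out plots in this section and the next use
# `red`, `blue` and `sns` (seaborn), none of which are defined or imported
# in this file -- define/import them before re-enabling this code.
# f, ax = plt.subplots()
# f.set_size_inches(14, 9)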

# plt.errorbar(tau, pairs,
#              yerr=pairs_err,
#              xerr=tau_err,
#              fmt='o',
#              color=red)

# ax.tick_params(labelsize=26)
# ax.xaxis.set_tick_params(width=3)
# ax.yaxis.set_tick_params(width=3)
# for axis in ['bottom', 'left']:
#     ax.spines[axis].set_linewidth(3)
# sns.despine(trim=False, offset=20)
# ax.spines['bottom'].set_position('zero')
# # plt.yticks(np.arange(0, .21, .05))
# # plt.xticks(np.arange(-40, 40.1, 40))
# # plt.ylim(0, .21)
# # plt.xlim(4, 36)
# plt.xlabel('coherence time (ns)', x=1, fontsize=32)
# plt.ylabel('pair rate',
#            fontsize=32, rotation=0,
#            y=1.02, labelpad=-60)

# plt.tight_layout()

# plt.savefig("pairs_vs_tau.pdf", format="pdf")


# """
# plot for the tau vs OD
# """

# f, ax = plt.subplots()
# f.set_size_inches(14, 9)

# ax.errorbar(OD, tau,
#             yerr=tau_err,
#             xerr=OD_err,
#             fmt='o',
#             color=red)

# ax.tick_params(labelsize=26)
# ax.xaxis.set_tick_params(width=3)
# ax.yaxis.set_tick_params(width=3)
# for axis in ['bottom', 'left']:
#     ax.spines[axis].set_linewidth(3)
# # sns.despine(trim=False, offset=20)
# # ax.spines['bottom'].set_position('zero')
# # plt.yticks(np.arange(0, .21, .05))
# # plt.xticks(np.arange(-40, 40.1, 40))
# plt.ylim(0, 27)
# plt.xlim(0, 31)
# plt.xlabel('optical density', x=.5, fontsize=32)
# plt.ylabel('coherence\ntime (ns)',
#            fontsize=32, rotation=0,
#            y=1.02, labelpad=-60)


# ax2 = ax.twinx()
# ax2.errorbar(OD, pairs,
#              yerr=pairs_err,
#              xerr=OD_err,
#              fmt='o',
#              color=blue)
# ax2.tick_params(labelsize=26)
# ax2.xaxis.set_tick_params(width=3)
# ax2.yaxis.set_tick_params(width=3)
# # sns.axes_style( )
# for axis in ['right']:
#     ax2.spines[axis].set_linewidth(3)
# for axis in ['top']:
#     ax2.spines[axis].set_linewidth(0)
# sns.despine(trim=False, right=True)
# # ax2.spines['bottom'].set_position('zero')
# # plt.yticks(np.arange(0, .21, .05))
# # plt.xticks(np.arange(-40, 40.1, 40))
# # plt.ylim(0, 27)
# # plt.xlim(0, 31)
# plt.ylabel('pair rate',
#            fontsize=32, rotation=0,
#            y=1.1, labelpad=-60)

# plt.tight_layout()

# plt.savefig("tau_vs_OD.pdf", format="pdf")

# Display every figure created above.
plt.show()
