import glob
import matplotlib.pyplot as plt
import numpy as np

from uncertainties import ufloat
from uncertainties import unumpy
from uncertainties.unumpy import uarray

from lmfit import Model
from lmfit import Parameters

""" useful parameters. Frequencies in MHz"""
Gamma1 = 6.066  # p state natural linewidth
Gamma2 = 0.666  # d state natural linewidth
Delta = 60.  # first pump detuning
delta = -3
lw = 2.7
# a1 = 10.285912696499032
# a2 = 1.7815723392554117

a1 = 18.8249539
a2 = 1.92883123
det_v = 1
power_v = 0

PUMP1 = np.array([.030, .130, .290, .420, .580, .640])


# 5.208125 corresponds to 450uW 780
# 12.54479 corresponds to 15mW 776

# 7 corresponds to 450uW 780
# 7 corresponds to 15mW 776


def data_ext(filename):
    """
    Load one measurement file and derive the quantities used by the fits.

    The file is parsed with ``np.genfromtxt``. Column 0 is converted to the
    second-pump power with a fixed linear calibration; columns 3 to 8 are
    read as the pair, signal and idler rates, each immediately followed by
    its uncertainty. Columns 1 and 2 are not used.

    Returns a 2-D array whose rows are, in order: pump power and its
    uncertainty, pair rate and uncertainty, signal rate and uncertainty,
    idler rate and uncertainty, pair/signal ratio and uncertainty,
    pair/idler ratio and uncertainty.
    """
    def ratio(num, num_err, den, den_err):
        # Quotient with first-order error propagation from `uncertainties`.
        quotient = uarray(num, num_err) / uarray(den, den_err)
        return unumpy.nominal_values(quotient), unumpy.std_devs(quotient)

    table = np.genfromtxt(filename)
    reading = table[:, 0]
    pairs_rate, pairs_unc = table[:, 3], table[:, 4]
    signal_rate, signal_unc = table[:, 5], table[:, 6]
    idler_rate, idler_unc = table[:, 7], table[:, 8]

    # Linear calibration of the raw reading: an offset of 19 is removed and
    # the result scaled by 250 / (4750 * 0.5). Both the reading and the
    # offset are given an uncertainty of 0.1 for the error estimate.
    power = (250 * (reading - 19) / (4750 * 0.5))
    reading_u = unumpy.uarray(reading, .1)
    power_u = (250 * (reading_u - ufloat(19, .1)) / (4750 * 0.5))
    power_unc = unumpy.std_devs(power_u)

    ratio_s, ratio_s_unc = ratio(pairs_rate, pairs_unc,
                                 signal_rate, signal_unc)
    ratio_i, ratio_i_unc = ratio(pairs_rate, pairs_unc,
                                 idler_rate, idler_unc)

    rows = [power, power_unc,
            pairs_rate, pairs_unc,
            signal_rate, signal_unc,
            idler_rate, idler_unc,
            ratio_s, ratio_s_unc,
            ratio_i, ratio_i_unc]
    return np.array(rows)


""" model """


def s_33(Oma, Omb, Gammaa, Gammab, delta, Delta):
    """
    33 element of the steady state density matrix.

    Closed-form rational expression built from plain arithmetic, so scalar
    and NumPy-array arguments are both accepted and broadcast together.

    Parameters
    ----------
    Oma, Omb : couplings of the two pumps. Callers in this file pass
        ``ma * sqrt(x)`` and ``mb * sqrt(y)`` -- presumably Rabi
        frequencies (TODO confirm).
    Gammaa, Gammab : decay rates; callers pass ``Gamma1`` and ``Gamma2``.
    delta, Delta : detunings; callers pass the fitted ``delta`` and the
        module-level ``Delta``.

    Returns
    -------
    The value of the expression below (real for real inputs).
    """
    # NOTE(review): the expression looks machine-generated; edit with care.
    return (Oma**2*Omb**2*(Gammaa*Gammab*((delta - Delta)**2 + (Gammaa + Gammab)**2) + Gammaa*(Gammaa + Gammab)*Oma**2 + (Gammaa + Gammab)**2*Omb**2))/(Delta**4*Gammaa*Gammab**3 + delta**4*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2) - 2*delta**3*Delta*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2 + Omb**2) + (Gammab*(Gammaa + Gammab) + Oma**2 + Omb**2)*(Gammaa**2*Gammab + 2*Gammab*Oma**2 + Gammaa*Omb**2)*(Gammaa*(Gammab*(Gammaa + Gammab) + Oma**2) + (Gammaa + Gammab)*Omb**2) + Delta**2*Gammab*(Gammaa*(Gammab**2*(2*Gammaa**2 + 2*Gammaa*Gammab + Gammab**2) + 2*Gammab*(Gammaa + 2*Gammab)*Oma**2 + Oma**4) + Gammab*(Gammaa*(3*Gammaa + Gammab) + Oma**2)*Omb**2 + Gammaa*Omb**4) + delta**2*(Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Gammaa*Gammab + 2*Gammab**2 - 2*Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) + (Delta**2*Gammaa*(Gammaa + 5*Gammab) + Gammaa**2*(Gammaa**2 + Gammaa*Gammab + 2*Gammab**2) + 2*(Gammaa + Gammab)**2*Oma**2)*Omb**2 + Gammaa*Gammab*Omb**4) + 2*delta*Delta*(Gammaa*Gammab*(-Gammab**2 + Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) - Gammab*(Gammaa*(Delta**2 + Gammaa**2 + 4*Gammaa*Gammab + Gammab**2) + Gammab*Oma**2)*Omb**2 - Gammaa*(Gammaa + 2*Gammab)*Omb**4))


def s_31(Oma, Omb, Gammaa, Gammab, delta, Delta):
    """
    31 element of the steady state density matrix,
    corresponding to coherent effects.

    Closed-form rational expression built from plain arithmetic, so scalar
    and NumPy-array arguments are both accepted and broadcast together.
    The numerator contains explicit ``1j`` terms, so the result is complex;
    ``co_f`` in this file takes its squared modulus.

    Parameters
    ----------
    Oma, Omb : couplings of the two pumps. Callers in this file pass
        ``ma * sqrt(x)`` and ``mb * sqrt(y)`` -- presumably Rabi
        frequencies (TODO confirm).
    Gammaa, Gammab : decay rates; callers pass ``Gamma1`` and ``Gamma2``.
    delta, Delta : detunings; callers pass the fitted ``delta`` and the
        module-level ``Delta``.
    """
    # NOTE(review): the expression looks machine-generated; edit with care.
    return (Oma*Omb*(delta**3*(Delta - 1j*Gammaa)*Gammaa*Gammab - 1j*Delta**3*Gammaa*Gammab**2 - Delta**2*Gammaa*Gammab*(Gammaa*Gammab - Oma**2 + Omb**2) - delta**2*Gammaa*Gammab*((Delta - 1j*Gammaa)*(2*Delta + 1j*Gammab) + Oma**2 + Omb**2) - 1j*Delta*Gammaa*Gammab*(Gammaa + Gammab)*(Gammab*(Gammaa + Gammab) + 2*Oma**2 + Omb**2) - (Gammaa*Gammab*(Gammaa + Gammab) - Gammab*Oma**2 + Gammaa*Omb**2)*(Gammaa*(Gammab*(Gammaa + Gammab) + Oma**2) + (Gammaa + Gammab)*Omb**2) + delta*Gammaa*((Delta - 1j*Gammaa)*Gammab*(Delta**2 + 2j*Delta*Gammab + (Gammaa + Gammab)**2) + 2j*Gammab*(Gammaa + Gammab)*Oma**2 + ((-1j)*Gammaa*(Gammaa + Gammab) + Delta*(Gammaa + 3*Gammab))*Omb**2)))/(Delta**4*Gammaa*Gammab**3 + delta**4*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2) - 2*delta**3*Delta*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2 + Omb**2) + (Gammab*(Gammaa + Gammab) + Oma**2 + Omb**2)*(Gammaa**2*Gammab + 2*Gammab*Oma**2 + Gammaa*Omb**2)*(Gammaa*(Gammab*(Gammaa + Gammab) + Oma**2) + (Gammaa + Gammab)*Omb**2) + Delta**2*Gammab*(Gammaa*(Gammab**2*(2*Gammaa**2 + 2*Gammaa*Gammab + Gammab**2) + 2*Gammab*(Gammaa + 2*Gammab)*Oma**2 + Oma**4) + Gammab*(Gammaa*(3*Gammaa + Gammab) + Oma**2)*Omb**2 + Gammaa*Omb**4) + delta**2*(Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Gammaa*Gammab + 2*Gammab**2 - 2*Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) + (Delta**2*Gammaa*(Gammaa + 5*Gammab) + Gammaa**2*(Gammaa**2 + Gammaa*Gammab + 2*Gammab**2) + 2*(Gammaa + Gammab)**2*Oma**2)*Omb**2 + Gammaa*Gammab*Omb**4) + 2*delta*Delta*(Gammaa*Gammab*(-Gammab**2 + Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) - Gammab*(Gammaa*(Delta**2 + Gammaa**2 + 4*Gammaa*Gammab + Gammab**2) + Gammab*Oma**2)*Omb**2 - Gammaa*(Gammaa + 2*Gammab)*Omb**4))


def laser(x, lw):
    """
    Gaussian profile of standard deviation `lw`, normalised to unit area,
    evaluated at offset `x` (scalar or array).
    """
    gaussian = np.exp(-x**2 / (2 * lw**2))
    normalisation = lw * np.sqrt(2 * np.pi)
    return gaussian / normalisation


# Module-level grid of 1000 frequency offsets covering +/- 10 * lw around
# zero, together with the spacing between neighbouring points.
nu, step = np.linspace(-10 * lw, 10 * lw, num=1000, retstep=True)


def co_f(x, y, ma, mb, delta):
    """
    Squared modulus of `s_31` for pump settings `x` and `y`.

    The two couplings handed to `s_31` are ``ma * sqrt(x)`` and
    ``mb * sqrt(y)``; the decay rates and the first detuning are the
    module-level `Gamma1`, `Gamma2` and `Delta`.
    """
    coupling_a = ma * np.sqrt(x)
    coupling_b = mb * np.sqrt(y)
    amplitude = s_31(coupling_a, coupling_b, Gamma1, Gamma2, delta, Delta)
    return np.abs(amplitude)**2


def co_f_lw(x, y, ma, mb, delta, lw):
    """
    `co_f` averaged over a Gaussian spread of the detuning `delta`.

    Evaluates ``co_f`` at ``delta + n`` for 1000 offsets ``n`` spanning
    +/- 10 * lw, weights each value with ``laser(n, lw)`` and sums the
    products times the grid spacing (rectangle-rule integral).

    Parameters
    ----------
    x, y : pump settings forwarded to `co_f` (scalars or arrays).
    ma, mb : scale factors forwarded to `co_f`.
    delta : central detuning.
    lw : width of the Gaussian profile; also sets the integration window.

    Returns
    -------
    Array (or scalar) with the shape of ``co_f(x, y, ma, mb, delta)``.
    """
    # Build the grid from the `lw` argument so that the integration window
    # always matches the kernel width, rather than relying on a grid that
    # was precomputed from the module-level `lw`.
    offsets, spacing = np.linspace(-10 * lw, 10 * lw, int(1e3), retstep=True)
    weighted = [laser(n, lw) * co_f(x, y, ma, mb, delta + n)
                for n in offsets]
    return np.sum(weighted, 0) * spacing


def single_f(x, y, ma, mb, delta):
    """
    `s_33` evaluated for pump settings `x` and `y`.

    The two couplings handed to `s_33` are ``ma * sqrt(x)`` and
    ``mb * sqrt(y)``; the decay rates and the first detuning are the
    module-level `Gamma1`, `Gamma2` and `Delta`.
    """
    coupling_a = ma * np.sqrt(x)
    coupling_b = mb * np.sqrt(y)
    return s_33(coupling_a, coupling_b, Gamma1, Gamma2, delta, Delta)


def single_f_lw(x, y, ma, mb, delta, lw):
    """
    `single_f` averaged over a Gaussian spread of the detuning `delta`.

    Evaluates ``single_f`` at ``delta + n`` for 1000 offsets ``n`` spanning
    +/- 10 * lw, weights each value with ``laser(n, lw)`` and sums the
    products times the grid spacing (rectangle-rule integral).

    Parameters
    ----------
    x, y : pump settings forwarded to `single_f` (scalars or arrays).
    ma, mb : scale factors forwarded to `single_f`.
    delta : central detuning.
    lw : width of the Gaussian profile; also sets the integration window.

    Returns
    -------
    Array (or scalar) with the shape of ``single_f(x, y, ma, mb, delta)``.
    """
    # Build the grid from the `lw` argument so that the integration window
    # always matches the kernel width, rather than relying on a grid that
    # was precomputed from the module-level `lw`.
    offsets, spacing = np.linspace(-10 * lw, 10 * lw, int(1e3), retstep=True)
    weighted = [laser(n, lw) * single_f(x, y, ma, mb, delta + n)
                for n in offsets]
    return np.sum(weighted, 0) * spacing


def pairs_f(x, y, Amp_p, ma, mb, delta):
    """
    Pair-rate model: `Amp_p` times the sum of the linewidth-averaged
    `co_f` term and the square of the linewidth-averaged `single_f` term
    scaled by 30e-9.

    The module-level `lw` is used as the profile width. The 30e-9 factor is
    presumably a coincidence window in seconds -- TODO confirm.
    """
    coherent_part = co_f_lw(x, y, ma, mb, delta, lw)
    single_part = single_f_lw(x, y, ma, mb, delta, lw)
    return Amp_p * (coherent_part + single_part**2 * 30e-9)


def signal_f(x, y, Amp_s, ma, mb, delta, off):
    """
    Single-rate model: `Amp_s` times the linewidth-averaged `single_f`
    term plus a contribution `off * x` linear in the first pump setting.

    The module-level `lw` is used as the profile width.
    """
    single_part = single_f_lw(x, y, ma, mb, delta, lw)
    return Amp_s * single_part + off * x


def eff_f(x, y, eta, ma, mb, delta, off):
    """
    Efficiency model: `eta` times the ratio of `pairs_f` to `signal_f`,
    both evaluated with unit amplitude.
    """
    pair_part = pairs_f(x, y, 1, ma, mb, delta)
    single_part = signal_f(x, y, 1, ma, mb, delta, off)
    return eta * (pair_part / single_part)


# lmfit models for the three observables. In each one `x` and `y` are the
# independent variables; every other argument of the wrapped function
# becomes a fit parameter.
fit_pair = Model(pairs_f, independent_vars=['x', 'y'])
fit_sig = Model(signal_f, independent_vars=['x', 'y'])
fit_eff = Model(eff_f, independent_vars=['x', 'y'])


def fit_pairs2D(pump1, pump2, signal, error, p=''):
    """
    Fit the pair-rate model `pairs_f` to `signal`.

    Parameters
    ----------
    pump1, pump2 : arrays passed to the model as `x` and `y`.
    signal : measured values to fit.
    error : uncertainties of `signal`; the fit is weighted by ``1 / error``.
    p : optional lmfit ``Parameters``. When omitted (``''`` or ``None``) a
        fresh set is built from the module-level `a1`, `a2` and `delta`,
        with `power_v` / `det_v` deciding which of them vary. When given,
        its 'ma', 'mb' and 'delta' entries are frozen. In both cases an
        'Amp_p' entry is added; a supplied `p` is modified in place.

    Returns
    -------
    The lmfit result of ``fit_pair.fit``.
    """
    # Compare by value: `p is ''` only worked through string interning and
    # triggers a SyntaxWarning on Python >= 3.8.
    if p is None or (isinstance(p, str) and p == ''):
        p = Parameters()
        p.add('ma', a1, min=0, vary=power_v)
        p.add('mb', a2, min=0, vary=power_v)
        p.add('delta', delta, vary=det_v)
    else:
        p['ma'].set(vary=0)
        p['mb'].set(vary=0)
        p['delta'].set(vary=0)
    p.add('Amp_p', 5 * np.max(signal), min=0)
    result = fit_pair.fit(signal,
                          x=pump1, y=pump2,
                          params=p,
                          weights=1 / error
                          )
    return result


def fit_signal2D(pump1, pump2, signal, error, p=''):
    """
    Fit the single-rate model `signal_f` to `signal`.

    Parameters
    ----------
    pump1, pump2 : arrays passed to the model as `x` and `y`.
    signal : measured values to fit.
    error : uncertainties of `signal`; the fit is weighted by ``1 / error``.
    p : optional lmfit ``Parameters``. When omitted (``''`` or ``None``) a
        fresh set is built from the module-level `a1`, `a2` and `delta`.
        When given, its 'ma', 'mb' and 'delta' entries are frozen. In both
        cases 'Amp_s' and 'off' entries are added; a supplied `p` is
        modified in place.

    Returns
    -------
    The lmfit result of ``fit_sig.fit``.
    """
    # Compare by value: `p is ''` only worked through string interning and
    # triggers a SyntaxWarning on Python >= 3.8.
    if p is None or (isinstance(p, str) and p == ''):
        p = Parameters()
        # NOTE(review): the sibling fits use `min=0, vary=power_v` here;
        # `min=power_v` leaves 'ma'/'mb' free with a lower bound of
        # power_v. Kept as is -- confirm whether this is intended.
        p.add('ma', a1, min=power_v)
        p.add('mb', a2, min=power_v)
        p.add('delta', delta, vary=det_v)
    else:
        p['ma'].set(vary=0)
        p['mb'].set(vary=0)
        p['delta'].set(vary=0)
    p.add('Amp_s', 10 * np.max(signal))
    p.add('off', np.min(signal), vary=1)
    result = fit_sig.fit(signal,
                         x=pump1, y=pump2,
                         params=p,
                         weights=1 / error
                         )
    return result


def fit_eff2D(pump1, pump2, signal, error, p=''):
    """
    Fit the efficiency model `eff_f` to `signal`.

    Parameters
    ----------
    pump1, pump2 : arrays passed to the model as `x` and `y`.
    signal : measured values to fit.
    error : uncertainties of `signal`; the fit is weighted by ``1 / error``.
    p : optional lmfit ``Parameters``. When omitted (``''`` or ``None``) a
        fresh set is built from the module-level `a1`, `a2` and `delta`.
        When given, its 'ma', 'mb' and 'delta' entries are frozen. In both
        cases 'off' is (re)set to 0.001 and left free, and an 'eta' entry
        is added; a supplied `p` is modified in place.

    Returns
    -------
    The lmfit result of ``fit_eff.fit``.
    """
    # Compare by value: `p is ''` only worked through string interning and
    # triggers a SyntaxWarning on Python >= 3.8.
    if p is None or (isinstance(p, str) and p == ''):
        p = Parameters()
        p.add('ma', a1, min=0, vary=power_v)
        p.add('mb', a2, min=0, vary=power_v)
        p.add('delta', delta, vary=det_v)
    else:
        p['ma'].set(vary=0)
        p['mb'].set(vary=0)
        p['delta'].set(vary=0)
    # `eff_f` needs an 'off' parameter. It used to be set only when the
    # supplied Parameters already carried one, so the default path (and a
    # supplied set without 'off') could not be fitted.
    if 'off' in p:
        p['off'].set(0.001, vary=1)
    else:
        p.add('off', 0.001, vary=1)
    p.add('eta', np.max(signal))
    result = fit_eff.fit(signal,
                         x=pump1, y=pump2,
                         params=p,
                         weights=1 / error
                         )
    return result


def plot_shade_6(x_data, y_data, y_err, fit_result, title="", y_label="",
                 x1_data=None):
    """
    Plot six data sets with error bars and the fitted curve for each.

    The inputs are cut into six consecutive chunks with ``np.array_split``;
    each chunk is drawn as one set and overlaid with ``fit_result.eval`` on
    a 1000-point grid that runs from the chunk's smallest second-pump value
    to three times its largest.

    Parameters
    ----------
    x_data : second-pump values (horizontal axis).
    y_data, y_err : measured values and their uncertainties.
    fit_result : object with an ``eval(x=..., y=...)`` method.
    title : figure name and plot title.
    y_label : label of the vertical axis.
    x1_data : first-pump values matching `x_data`. When omitted the
        module-level `pump1` array is used, as before.

    Returns
    -------
    The matplotlib figure.
    """
    if x1_data is None:
        # Backward compatibility: the first-pump values used to be read
        # from the module-level array unconditionally.
        x1_data = pump1
    fig = plt.figure(title)
    chunks = zip(np.array_split(x_data, 6),
                 np.array_split(y_data, 6),
                 np.array_split(y_err, 6),
                 np.array_split(x1_data, 6))
    for idx, (p2, pa, err, p1) in enumerate(chunks):
        plt.errorbar(p2, pa, yerr=err, fmt='o')
        g1 = np.linspace(np.min(p1), np.max(p1), 1000)
        g2 = np.linspace(np.min(p2), np.max(p2) * 3, 1000)
        fi = fit_result.eval(x=g1, y=g2)

        # Same position in the colour cycle as the error bars drawn above.
        plt.plot(g2, fi, 'C{}'.format(idx % 6))

    plt.xlabel(r'$P_2$ power (mW)')
    plt.ylabel(y_label)
    plt.title(title)
    plt.tight_layout()
    return fig


def plot_eff_f(x_data, y_data, y_err, fit_result1, fit_result2, title="",
               y_label="", x1_data=None):
    """
    Plot six data sets with error bars and the ratio of two fitted curves.

    The inputs are cut into six consecutive chunks with ``np.array_split``;
    each chunk is drawn as one set and overlaid with
    ``fit_result1.eval / fit_result2.eval`` on a 1000-point grid that runs
    from the chunk's smallest second-pump value to three times its largest.

    Parameters
    ----------
    x_data : second-pump values (horizontal axis).
    y_data, y_err : measured values and their uncertainties.
    fit_result1, fit_result2 : objects with an ``eval(x=..., y=...)``
        method; numerator and denominator of the plotted curve.
    title : figure name and plot title.
    y_label : label of the vertical axis.
    x1_data : first-pump values matching `x_data`. When omitted the
        module-level `pump1` array is used, as before.

    Returns
    -------
    The matplotlib figure.
    """
    if x1_data is None:
        # Backward compatibility: the first-pump values used to be read
        # from the module-level array unconditionally.
        x1_data = pump1
    fig = plt.figure(title)
    chunks = zip(np.array_split(x_data, 6),
                 np.array_split(y_data, 6),
                 np.array_split(y_err, 6),
                 np.array_split(x1_data, 6))
    for idx, (p2, pa, err, p1) in enumerate(chunks):
        plt.errorbar(p2, pa, yerr=err, fmt='o')
        g1 = np.linspace(np.min(p1), np.max(p1), 1000)
        g2 = np.linspace(np.min(p2), np.max(p2) * 3, 1000)
        fi = fit_result1.eval(x=g1, y=g2) / fit_result2.eval(x=g1, y=g2)

        # Same position in the colour cycle as the error bars drawn above.
        plt.plot(g2, fi, 'C{}'.format(idx % 6))

    plt.xlabel(r'$P_2$ power (mW)')
    plt.ylabel(y_label)
    plt.title(title)
    plt.tight_layout()
    return fig


def save_fit(pump1, pump2, result, f_name):
    """
    Write the fitted curves of six data sets, then the fit report, to a file.

    `pump1` and `pump2` are cut into six consecutive chunks with
    ``np.array_split``. For each chunk the fit is evaluated on a 1000-point
    grid (first pump from the chunk's minimum to its maximum, second pump
    from the chunk's minimum to 18) and written as tab-separated rows,
    followed by two blank lines. ``result.fit_report()`` is appended last.

    Parameters
    ----------
    pump1, pump2 : arrays of pump settings.
    result : object with ``eval(x=..., y=...)`` and ``fit_report()``.
    f_name : path of the output file (overwritten).
    """
    with open(f_name, 'w') as f:
        # The header names exactly the three columns written below and
        # ends the line, so it no longer runs into the first data row.
        f.write('#Pump1\tPump2\tFitValue\n')
        for p1, p2 in zip(np.array_split(pump1, 6),
                          np.array_split(pump2, 6)):
            g1 = np.linspace(np.min(p1), np.max(p1), 1000)
            g2 = np.linspace(np.min(p2), 18, 1000)
            fi = result.eval(x=g1, y=g2)

            f.writelines('{}\t{}\t{}\n'.format(a, b, c)
                         for a, b, c in zip(g1, g2, fi))
            f.write('\n\n')
        f.write(result.fit_report())


def save_fit_eff(pump1, pump2, result1, result2, f_name):
    """
    Write the ratio of two fitted curves for six data sets to a file.

    `pump1` and `pump2` are cut into six consecutive chunks with
    ``np.array_split``. For each chunk ``result1.eval / result2.eval`` is
    computed on a 1000-point grid (first pump from the chunk's minimum to
    its maximum, second pump from the chunk's minimum to 18) and written as
    tab-separated rows, followed by two blank lines. The fit report of
    `result1` is appended last.

    Parameters
    ----------
    pump1, pump2 : arrays of pump settings.
    result1, result2 : objects with ``eval(x=..., y=...)``; `result1` must
        also provide ``fit_report()``.
    f_name : path of the output file (overwritten).
    """
    with open(f_name, 'w') as f:
        f.write('#Pump1\tPump2\tFitValue\n')
        for p1, p2 in zip(np.array_split(pump1, 6),
                          np.array_split(pump2, 6)):
            g1 = np.linspace(np.min(p1), np.max(p1), 1000)
            g2 = np.linspace(np.min(p2), 18, 1000)
            fi = result1.eval(x=g1, y=g2) / result2.eval(x=g1, y=g2)

            f.writelines('{}\t{}\t{}\n'.format(a, b, c)
                         for a, b, c in zip(g1, g2, fi))
            f.write('\n\n')
        # Use the argument: the body used to reach for a module-level
        # `result`, which only existed when run as part of this script.
        f.write(result1.fit_report())


# --- Load the data ---------------------------------------------------------
# glob returns paths in arbitrary order; sort them so that every run pairs
# the same file with the same PUMP1 entry. NOTE(review): this assumes the
# lexicographic file order matches the order of PUMP1 -- confirm.
filenames = sorted(glob.glob('data/*data_0*.dat'))
if not filenames:
    raise FileNotFoundError("no data files match 'data/*data_0*.dat'")

# zip stops at the shorter input, so surplus files or surplus PUMP1 entries
# are ignored. The plotting and saving helpers expect six equally long sets.
tables = []
pump1 = []
for filename, p1 in zip(filenames, PUMP1):
    table = data_ext(filename)
    tables.append(table)
    # One first-pump value per row of this file.
    pump1.append(np.full(table.shape[1], p1))

# Rows follow the order returned by data_ext; files are joined end to end.
(pump2, p2_error,
 pairs, pairs_err,
 signal, signal_err,
 idler, idler_err,
 eta_s, eta_s_err,
 eta_i, eta_i_err) = np.concatenate(tables, axis=1)
pump1 = np.concatenate(pump1)


# --- Pair rate ---------------------------------------------------------------
result = fit_pairs2D(pump1, pump2, pairs, pairs_err)
print(result.fit_report())
save_fit(pump1, pump2, result, 'fit_pairs.dat')
plot_shade_6(pump2,
             pairs,
             pairs_err,
             result,
             title="Pump power vs Pair rate",
             y_label="Pair rate (1/s)")


# --- Signal rate -------------------------------------------------------------
# The parameters of the pair fit are handed over, so 'ma', 'mb' and 'delta'
# are held fixed in this fit.
result_s = fit_signal2D(pump1, pump2, signal, signal_err, result.params)
print(result_s.fit_report())
save_fit(pump1, pump2, result_s, 'fit_signal.dat')
plot_shade_6(pump2,
             signal,
             signal_err,
             result_s,
             title="Pump power vs Signal rate",
             y_label="Signal rate (1/s)")

# The efficiency curve is the ratio of the pair fit to the signal fit; no
# separate fit_eff2D fit is run.
save_fit_eff(pump1, pump2, result, result_s, 'fit_eff_s.dat')
plot_eff_f(pump2,
           eta_s, eta_s_err,
           result, result_s,
           title="Pump power vs Signal Efficiency",
           y_label="Efficiency")


# --- Idler rate --------------------------------------------------------------
result_i = fit_signal2D(pump1, pump2, idler, idler_err, result.params)
print(result_i.fit_report())
save_fit(pump1, pump2, result_i, 'fit_idler.dat')
plot_shade_6(pump2,
             idler,
             idler_err,
             result_i,
             title="Pump power vs Idler rate",
             y_label="Idler rate (1/s)")


# Ratio of the pair fit to the idler fit, as for the signal above.
save_fit_eff(pump1, pump2, result, result_i, 'fit_eff_i.dat')
plot_eff_f(pump2,
           eta_i, eta_i_err,
           result, result_i,
           title="Pump power vs Idler Efficiency",
           y_label="Efficiency")


plt.show()
