import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

from uncertainties import unumpy

# Module-wide colour palettes: seaborn's current default palette, and the
# sequential "Reds" palette (used as the hue palette of the power-scan grid).
c_palette = sns.color_palette(palette=None)
red_palette = sns.color_palette(palette="Reds")


def eff(pairs, singles):
    """Return the efficiency ``pairs / singles`` as an uncertain array.

    Each input gets a sqrt(N) uncertainty, and the ratio is propagated by
    ``uncertainties.unumpy``, which treats the two inputs as independent.

    NOTE(review): pairs are presumably a subset of the singles, so the two
    inputs are correlated and the independent-error propagation likely
    overestimates the ratio's uncertainty. Confirm whether that is intended.

    Parameters
    ----------
    pairs, singles : array-like
        Coincidence and single-channel values; must broadcast together.

    Returns
    -------
    numpy.ndarray of uncertainties objects
        Use ``unumpy.nominal_values`` / ``unumpy.std_devs`` to split it.
    """
    def _with_counting_error(counts):
        # sqrt(N) uncertainty attached to every element.
        return unumpy.uarray(counts, np.sqrt(counts))

    return _with_counting_error(pairs) / _with_counting_error(singles)


def aom2freq(x, center=68, factor=2):
    """Convert an AOM frequency reading into a detuning: ``(x - center) * factor``.

    The result is used as the ``delta`` column, which is labelled in MHz.

    Parameters
    ----------
    x : scalar or array-like
        AOM frequency; presumably in MHz (TODO confirm against the data file).
    center : float, optional
        Frequency mapped to zero detuning. Defaults to 68, the value this
        script has always used.
    factor : float, optional
        Scale applied to the offset. Defaults to 2; presumably the AOM is
        used in double pass (TODO confirm).

    Returns
    -------
    Same type as ``x``
        The detuning.
    """
    return (x - center) * factor


def duty_cycle(x):
    """Return ``x`` divided by the constant 1.875.

    NOTE(review): the meaning of 1.875 (presumably a cycle period in the same
    units as ``x``) is not stated anywhere in this file. Confirm it.
    """
    divisor = 1.875
    return x / divisor


def brig(pairs, tau):
    """Return the spectral brightness ``pairs * 2*pi*tau * 1e-3``.

    With ``tau`` in ns and ``pairs`` in 1/s (the axis labels used in this
    script), the 1e-3 factor rescales ``tau`` so the result matches the
    "brightness (MHz * s)^-1" axis label.

    Works element-wise on scalars, numpy arrays and pandas Series.
    """
    # Keep the multiplication order identical: ((pairs*2)*pi)*tau*1e-3.
    scaled = pairs * 2 * np.pi
    return scaled * tau * 1e-3


# Power scan: whitespace-separated table. `sep=r'\s+'` is the documented
# replacement for the deprecated `delim_whitespace=True`.
pscan = pd.read_csv('powerscans/power_scan2D.dat', sep=r'\s+')
pscan['Brig'] = brig(pscan.Pairs, pscan.tau)


# Detuning scan: load the whitespace-separated table (`sep=r'\s+'` replaces
# the deprecated `delim_whitespace=True`) and derive the analysis columns.
dscan = pd.read_csv('detunings/plot_data_tau.dat', sep=r'\s+')

dscan['delta'] = aom2freq(dscan['AOM_freq'])
# Constant columns; presumably the settings held fixed during this scan.
dscan['Delta'] = 60
dscan['Pump1'] = .45
dscan['Pump2'] = 15
# sqrt(N) uncertainties. NOTE(review): only Poisson-correct if these columns
# hold raw counts; Pairs is labelled a rate (1/s) elsewhere. Confirm.
dscan['Pairs_err'] = np.sqrt(dscan['Pairs'])
dscan['Signal_err'] = np.sqrt(dscan['Signal'])
dscan['Idler_err'] = np.sqrt(dscan['Idler'])
# Efficiencies with propagated uncertainties, split into value/error columns.
eff_i_u = eff(dscan['Pairs'], dscan['Idler'])
eff_s_u = eff(dscan['Pairs'], dscan['Signal'])
dscan['eff_i'] = unumpy.nominal_values(eff_i_u)
dscan['eff_s'] = unumpy.nominal_values(eff_s_u)
dscan['eff_s_err'] = unumpy.std_devs(eff_s_u)
dscan['eff_i_err'] = unumpy.std_devs(eff_i_u)
dscan['OD'] = 29

# OD scan: load the whitespace-separated table (`sep=r'\s+'` replaces the
# deprecated `delim_whitespace=True`).
dty = pd.read_csv('density/fwm_OD.dat', sep=r'\s+')
# Evaluate each uncertain ratio once, then split it into value/error columns.
dty_eff_s_u = eff(dty['Pairs'], dty['Signal'])
dty_eff_i_u = eff(dty['Pairs'], dty['Idler'])
dty['eff_s'] = unumpy.nominal_values(dty_eff_s_u)
dty['eff_i'] = unumpy.nominal_values(dty_eff_i_u)
dty['eff_s_err'] = unumpy.std_devs(dty_eff_s_u)
dty['eff_i_err'] = unumpy.std_devs(dty_eff_i_u)
# Constant columns; presumably the settings held fixed during this scan.
dty['delta'] = 12
dty['Delta'] = 60
dty['Pump2'] = 15
dty['Pump1'] = .3

# All three scans stacked row-wise into one table (columns missing from a
# scan are NaN).
new_t = pd.concat((pscan, dscan, dty))
# print(pscan.columns.values, dscan.columns.values, dty.columns.values)
# print(new_t.columns.values)
# print(new_t)

# plt.plot(new_t['Pump2'], new_t['eff_i'], 'o-')
# plt.plot(new_t['Pump2'], new_t['eff_s'], 'o-')

# pivots

# Bin tau into 5 equal-width intervals: `a` holds the per-row interval,
# `b` the bin edges.
a, b = pd.cut(pscan['tau'], bins=5, retbins=True)
pscan['tau_r'] = a
# NOTE(review): DataFrame.pivot raises ValueError if any (eff_i, tau_r) pair
# occurs more than once; pivot_table would aggregate instead.
ppiv = pscan.pivot(index='eff_i', columns='tau_r', values='Pairs')
# ppiv.fillna(method='ffill').plot(style='o-')
# ppiv.sort_valuesort_index().dropna(axis=1, how='all').plot(style='o-')

#
# plt.figure('Pair rate vs Tau')
# plt.errorbar(pscan.tau, pscan.Pairs,
#              xerr=pscan.tau_err, yerr=pscan.Pairs_err,
#              fmt='o-', label='Power scan')
# plt.errorbar(dscan.tau, dscan.Pairs,
#              xerr=dscan.tau_err, yerr=dscan.Pairs_err,
#              fmt='o', label='$\delta$ scan')
# plt.errorbar(dty.tau, dty.Pairs,
#              xerr=dty.tau_err, yerr=dty.Pairs_err,
#              fmt='o', label='OD scan')
# plt.legend()
# plt.xlabel(r'$\tau$ (ns)')
# plt.ylabel(r'pair rate (1/s)')


# plt.figure('Eff_s vs Tau')
# plt.errorbar(pscan.tau, pscan.eff_s,
#              xerr=pscan.tau_err, yerr=pscan.eff_s_err,
#              fmt='o', label='Power scan')
# plt.errorbar(dscan.tau, dscan.eff_s,
#              xerr=dscan.tau_err, yerr=dscan.eff_s_err,
#              fmt='o', label='$\delta$ scan')
# plt.errorbar(dty.tau, dty.eff_s,
#              xerr=dty.tau_err, yerr=dty.eff_s_err,
#              fmt='o', label='OD scan')
# plt.legend()

# plt.xlabel(r'$\tau$ (ns)')
# plt.ylabel(r'$\eta_s$')


# plt.figure('Eff_s vs spectral brightness')
# plt.errorbar(pscan.Pairs * 2 * np.pi * pscan.tau * 1e-3, pscan.eff_s,
#              # xerr=pscan.tau_err,
#              yerr=pscan.eff_s_err,
#              fmt='o', label='Power scan')
# plt.errorbar(dscan.Pairs * 2 * np.pi * dscan.tau * 1e-3, dscan.eff_s,
#              # xerr=dscan.tau_err,
#              yerr=dscan.eff_s_err,
#              fmt='o', label='$\delta$ scan')
# plt.errorbar(dty.Pairs * 2 * np.pi * dty.tau * 1e-3, dty.eff_s,
#              # xerr=dty.tau_err,
#              yerr=dty.eff_s_err,
#              fmt='o', label='OD scan')
# plt.legend()
# plt.xlabel(r'brightness (MHz * s)$^{-1}$')
# plt.ylabel(r'$\eta_s$')

# plt.figure('Eff_s vs pair rate')
# plt.errorbar(pscan.Pairs, pscan.eff_s,
#              # xerr=pscan.tau_err,
#              yerr=pscan.eff_s_err,
#              fmt='o', label='Power scan')
# plt.errorbar(dscan.Pairs, dscan.eff_s,
#              # xerr=dscan.tau_err,
#              yerr=dscan.eff_s_err,
#              fmt='o', label='$\delta$ scan')
# plt.errorbar(dty.Pairs, dty.eff_s,
#              # xerr=dty.tau_err,
#              yerr=dty.eff_s_err,
#              fmt='o', label='OD scan')
# plt.legend()
# plt.xlabel(r'pairs rate s$^{-1}$')
# plt.ylabel(r'$\eta_s$')


# plt.figure('spectral brightness vs power')
# plt.errorbar(pscan.Pump2, pscan.Pairs * 2 * np.pi * pscan.tau * 1e-3,
#              # xerr=pscan.tau_err, yerr=pscan.eff_s_err,
#              fmt='o', label='Power scan')

# plt.legend()
# plt.xlabel(r'$P_2$ (mW)')
# plt.ylabel(r'brightness (MHz * s)$^{-1}$')


# plt.figure('spectral brightness vs OD')
# plt.errorbar(dty.OD, dty.Pairs * 2 * np.pi * dty.tau * 1e-3,
#              # xerr=pscan.tau_err, yerr=pscan.eff_s_err,
#              fmt='o', label='Power scan')

# plt.legend()
# plt.xlabel(r'OD')
# plt.ylabel(r'brightness (MHz * s)$^{-1}$')


# plt.figure('spectral brightness vs detuning')
# plt.errorbar(dscan.delta, dscan.Pairs * 2 * np.pi * dscan.tau * 1e-3,
#              # xerr=pscan.tau_err, yerr=pscan.eff_s_err,
#              fmt='o', label='Power scan')
# plt.legend()
# plt.xlabel(r'$\delta$ (MHz)')
# plt.ylabel(r'brightness (MHz * s)$^{-1}$')

# PairGrid labels axes with raw column names; map them to LaTeX labels.
replacements = {'Pump2': r'$P_2$ (mW)',
                'Brig': r'brightness (MHz * s)$^{-1}$',
                'tau': r'$\tau$ (ns)',
                'eff_s': r'$\eta_s$',
                'delta': r'$\delta$ (MHz)'}

# Power scan: export the data table, then draw a lower-triangle pair plot
# coloured by Pump1.
pscan[['Pump2', 'Brig', 'tau', 'eff_s', 'eff_i',
       'tau_err', 'eff_s_err',
       'eff_i_err', 'Pump1']].to_csv('Power_data_table.dat', sep='\t')
g = sns.PairGrid(pscan, vars=['Pump2', 'Brig', 'tau', 'eff_s'],
                 hue='Pump1',
                 palette=red_palette)
g.map_lower(plt.scatter)
for ax in g.axes.flat:
    xlabel = ax.get_xlabel()
    ylabel = ax.get_ylabel()
    if xlabel in replacements:
        ax.set_xlabel(replacements[xlabel])
    if ylabel in replacements:
        ax.set_ylabel(replacements[ylabel])

# Hide (rather than delete) the diagonal and upper triangle; presumably so the
# legend, anchored at (.7, .6) below, sits over that empty region.
for i, j in zip(*np.triu_indices_from(g.axes, 0)):
    # plt.delaxes(g.axes[i, j])
    g.axes[i, j].axis('off')
g.add_legend()
# add_legend() attaches a figure-level legend; this lookup assumes it is the
# figure's last child. Look it up once instead of on every access.
legend = g.fig.get_children()[-1]
for text in legend.texts:
    # Legend entries are the Pump1 hue levels; append the unit.
    text.set_text(text.get_text() + ' mW')
legend.set_bbox_to_anchor((.7, 0.6, 0, 0))
plt.tight_layout()


# Detuning scan: brightness, a two-level hue flag, data export, and a
# lower-triangle pair plot.
dscan['Brig'] = brig(dscan.Pairs, dscan.tau)
# Flag = 1 where the detuning exceeds the detuning at the pair-rate maximum.
# Series.argmax() is positional, so index it with .iloc (plain [] is
# label-based and picks the wrong row for a non-default index).
delta_at_max = dscan['delta'].iloc[dscan['Pairs'].argmax()]
dscan['color'] = np.int_(dscan.delta > delta_at_max)
dscan[['delta', 'Brig', 'tau', 'eff_s', 'color', 'eff_i',
       'tau_err', 'eff_s_err',
       'eff_i_err']].to_csv('det_data_table.dat', sep='\t',
                            encoding='utf-8')
d = sns.PairGrid(dscan, vars=['delta', 'Brig', 'tau', 'eff_s'],
                 hue='color', palette="Set1", hue_order=[0, 1])
d.map_lower(plt.scatter)
for ax in d.axes.flat:
    xlabel = ax.get_xlabel()
    ylabel = ax.get_ylabel()
    if xlabel in replacements:
        ax.set_xlabel(replacements[xlabel])
    if ylabel in replacements:
        ax.set_ylabel(replacements[ylabel])

# Remove the diagonal and upper-triangle axes entirely.
for i, j in zip(*np.triu_indices_from(d.axes, 0)):
    plt.delaxes(d.axes[i, j])
plt.tight_layout()


# OD scan: brightness, data export, a lower-triangle pair plot, and a PDF.
dty['Brig'] = brig(dty.Pairs, dty.tau)
dty[['OD', 'Brig', 'tau', 'eff_s', 'eff_i',
     'tau_err', 'eff_s_err',
     'eff_i_err']].to_csv('OD_data_table.dat', sep='\t',
                          encoding='utf-8')
r = sns.PairGrid(dty, vars=['OD', 'Brig', 'tau', 'eff_s'])
r.map_lower(plt.scatter)
for ax in r.axes.flat:
    xlabel = ax.get_xlabel()
    ylabel = ax.get_ylabel()
    if xlabel in replacements:
        ax.set_xlabel(replacements[xlabel])
    if ylabel in replacements:
        ax.set_ylabel(replacements[ylabel])

# Remove the diagonal and upper-triangle axes entirely.
for i, j in zip(*np.triu_indices_from(r.axes, 0)):
    plt.delaxes(r.axes[i, j])
# tight_layout() recomputes the subplot parameters, so no separate
# subplots_adjust() call is needed.
plt.tight_layout()
plt.savefig('od_table.pdf')

# plt.show()
