import matplotlib.pyplot as plt
import numpy as np


import scipy.odr as odr


# Coherence time at zero optical density. 1000 / (2*pi*6.067) ~= 26.23;
# NOTE(review): presumably 6.067 is a linewidth in MHz so tau_0 is in ns
# (matching the 'coherence time (ns)' axis label used below) — confirm.
tau_0 = 1000.0 / (2.0 * np.pi * 6.067)


def tau_f(beta, x):
    """Return the model coherence time for optical density ``x``.

    The signature follows the scipy.odr model convention ``fcn(beta, x)``.

    Parameters
    ----------
    beta : sequence of float
        Fit parameters; only ``beta[0]`` (the linear OD coefficient) is used.
    x : float or numpy.ndarray
        Optical density value(s).

    Returns
    -------
    float or numpy.ndarray
        ``tau_0 / (1 + beta[0] * x)``, broadcast like ``x``.
    """
    coeff = beta[0]
    denom = 1 + coeff * x
    return tau_0 / denom


# ODR model wrapping tau_f (one free parameter, beta[0]).
fit_model = odr.Model(fcn=tau_f)

"""
data import and column names
"""
raw_data = np.genfromtxt('fwm_OD.dat')

OD = raw_data[:, 1]
OD_err = raw_data[:, 2]
pairs = raw_data[:, 5]
pairs_err = raw_data[:, 6]
signal = raw_data[:, 7]
signal_err = raw_data[:, 8]
idler = raw_data[:, 9]
idler_err = raw_data[:, 10]
tau = raw_data[:, 11]
tau_err = raw_data[:, 12]


""" fitting with ODR"""
data = odr.Data(OD, tau, wd=1 / OD_err, we=1 / tau_err)
odr_fit = odr.ODR(data, fit_model, beta0=[.1])
out = odr_fit.run()
out.pprint()


"""
plotting
"""
pt_size = 15
f, ax = plt.subplots()
# f.set_size_inches(14, 9)

od_v = np.linspace(0, max(OD), 1000)
ax.plot(od_v, tau_f(out.beta, od_v))
ax.plot(od_v, tau_f([.0808], od_v))
ax.errorbar(OD, tau,
            yerr=tau_err,
            xerr=OD_err,
            fmt='o',
            color='b',
            # markersize=pt_size,
            capsize=0,
            elinewidth=3)

# ax.tick_params(labelsize=26)
# ax.xaxis.set_tick_params(width=3)
# ax.yaxis.set_tick_params(width=3, color='b')

# plt.ylim(0, 29)
# plt.xlabel('optical density', x=.5, fontsize=32)

# # Make the y-axis label and tick labels match the line color.
# plt.ylabel('coherence\ntime (ns)',
#            fontsize=32,
#            rotation=0,
#            y=1.02,
#            labelpad=-60,
#            color='b')
# for tl in ax.get_yticklabels():
#     tl.set_color('b')

# ax2 = ax.twinx()
# ax2.errorbar(OD, pairs,
#              yerr=pairs_err,
#              xerr=OD_err,
#              fmt='s',
#              color='r',
#              markersize=pt_size,
#              capsize=0,
#              elinewidth=3)

# ax2.tick_params(labelsize=26)
# ax2.yaxis.set_tick_params(width=3, color='r')
# plt.ylabel('pair rate (1/s)',
#            fontsize=32,
#            rotation=0,
#            y=1.1,
#            labelpad=-60,
#            color='r')

# for axis in ['bottom', 'left', 'right', 'top']:
#     ax.spines[axis].set_linewidth(3)
#     ax2.spines[axis].set_linewidth(3)
# ax2.spines['left'].set_color('b')
# ax2.spines['right'].set_color('r')
# ax.spines['top'].set_visible(False)
# ax2.spines['top'].set_visible(False)
# ax.yaxis.set_ticks_position('left')
# ax.xaxis.set_ticks_position('bottom')
# ax2.yaxis.set_ticks_position('right')
# for tl in ax2.get_yticklabels():
#     tl.set_color('r')

# plt.xlim(0, 33)
# plt.tight_layout()

# ax.plot([0, 1], [27, 27], '-b', linewidth=3)
# ax.text(-1.2, 27, r'$\tau_0$', fontsize=32, color='b')

# plt.savefig("tau_vs_OD.pdf", format="pdf")

plt.show()
