import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
import glob


from uncertainties import unumpy

from lmfit import Parameters, Model

# Constant subtracted from both the signal and the reference columns before
# their ratio is formed in data_extract.
# NOTE(review): presumably a background/offset count related to the AOM;
# confirm its origin and units.
AOM_offset = 16


def data_extract(fname, *, freq_scale=2, freq_shift=267.11, count_err=1,
                 skip=1):
    """Load a transmission data file; return frequency and transmission.

    The file is read with ``np.genfromtxt`` (default whitespace delimiter)
    and must have at least three columns:

    * column 0 is mapped to the frequency axis as
      ``freq_scale * col0 - freq_shift``;
    * columns 1 (signal) and 2 (reference) each have ``AOM_offset``
      subtracted and are given an absolute 1-sigma uncertainty
      ``count_err``; the transmission is signal / reference, with the
      uncertainty propagated by ``uncertainties`` (the two are treated as
      independent variables).

    Parameters
    ----------
    fname : str
        Path of the data file.
    freq_scale, freq_shift : float, keyword-only
        Linear conversion applied to column 0. The defaults (2, 267.11)
        reproduce the original hard-coded conversion.
    count_err : float, keyword-only
        Uncertainty assigned to every signal and reference value (default 1).
    skip : int, keyword-only
        Number of leading rows to discard (default 1).

    Returns
    -------
    freq, tra_s, tra_err : numpy.ndarray
        Frequency, nominal transmission and its standard deviation for each
        retained row.
    """
    # invalid_raise=False makes genfromtxt skip rows with an inconsistent
    # number of columns instead of raising. For a one-row file genfromtxt
    # returns a 1-D array; atleast_2d keeps the column indexing below valid.
    raw_data = np.atleast_2d(np.genfromtxt(fname, invalid_raise=False))

    # NOTE(review): the first row is discarded by default; possibly a header
    # line, which genfromtxt would parse as NaN -- confirm.
    rows = raw_data[skip:]

    freq = freq_scale * rows[:, 0] - freq_shift

    sig = unumpy.uarray(rows[:, 1] - AOM_offset, count_err)
    ref = unumpy.uarray(rows[:, 2] - AOM_offset, count_err)
    tra = sig / ref

    return freq, unumpy.nominal_values(tra), unumpy.std_devs(tra)


def lorentz(x, amplitude, gamma, x0):
    """Lorentzian peak of height ``amplitude`` centred on ``x0``.

    ``gamma`` is the half width at half maximum: the value falls to
    ``amplitude / 2`` at ``x = x0 +/- gamma``. Operates elementwise when
    ``x`` is a NumPy array.
    """
    width_sq = gamma ** 2
    offset = x - x0
    return amplitude * width_sq / (offset ** 2 + width_sq)


def profile(x, OD, gamma, x0):
    """Transmission model ``exp(-OD * L(x))``.

    ``L`` is the unit-height Lorentzian ``lorentz(x, 1, gamma, x0)``, so the
    transmission at the line centre ``x0`` equals ``exp(-OD)`` and tends to
    1 far from it.
    """
    exponent = -OD * lorentz(x, 1, gamma, x0)
    return np.exp(exponent)


# lmfit model of the transmission. lmfit takes the first argument of
# `profile` (x) as the independent variable and the rest (OD, gamma, x0)
# as fit parameters.
fit_model = Model(profile)

p = Parameters()
# Line centre and width are held fixed; only the optical depth is fitted.
# NOTE(review): in `lorentz`, gamma is the half width at half maximum.
# Confirm that 6.0666 (MHz, per the plot's axis label) is intended as the
# HWHM rather than the full linewidth.
p.add('x0', 0, vary=False)
p.add('gamma', 6.0666, vary=False)
p.add('OD', 5)  # initial guess for the fitted optical depth


# Fit every transmission file in the working directory. glob.glob returns
# matches in filesystem-dependent order, so sort them to make the order of
# od_v / od_err and of the plotted curves reproducible.
filelist = sorted(glob.glob('tra*.dat'))
od_v = np.empty(len(filelist))
od_err = np.empty(len(filelist))

for j, fname in enumerate(filelist):
    freq, tra_s, tra_err = data_extract(fname)

    # Weight each residual by the inverse of its propagated uncertainty.
    result = fit_model.fit(tra_s, p, x=freq, weights=1 / tra_err)

    od_param = result.params['OD']
    od_v[j] = od_param.value
    # lmfit leaves stderr as None when it cannot estimate uncertainties;
    # record NaN explicitly in that case.
    od_err[j] = np.nan if od_param.stderr is None else od_param.stderr

    # Overlay the fitted curve (on a dense grid) and the measured points;
    # all files are drawn on the same figure.
    x_vec = np.linspace(np.min(freq), np.max(freq), 1000)
    plt.plot(x_vec, result.eval(x=x_vec))
    plt.errorbar(freq, tra_s, yerr=tra_err, fmt='o')

    print(result.fit_report())

# Summary: one "file: OD +/- error" line per fitted file.
for fname, od, err in zip(filelist, od_v, od_err):
    print('{}: {} +/- {}'.format(fname, od, err))

plt.xlabel('detuning (MHz)')
plt.ylabel('transmission')
plt.savefig('OD_fit_2cases.pdf')
plt.show()
