import matplotlib.pyplot as plt
import numpy as np
import scipy.odr as odr

from pathlib import Path
from qitdevices.g2lib import g2_lib
from uncertainties import ufloat
from uncertainties import unumpy


# Coincidence window / accidental-coincidence time constant. Dt * 1e9 is used
# below as a number of histogram bins, so bins are presumably 1 ns wide --
# TODO confirm against g2_lib.
Dt = 30e-9
# Placeholder singles offsets; both names are rebound from the measured data
# further down before they are first used.
ds = 165
di = 508

# Sorted so the run order (and therefore which file is treated as the first,
# reference measurement) does not depend on filesystem enumeration order.
# A list rather than a generator, so it can be inspected (e.g. printed)
# without being exhausted before the data are loaded.
data_files = sorted(Path('030713b').glob('rf_*.dat'))

# Histogram bin range [start, stop) summed as coincidences. round() instead of
# int(): Dt * 1e9 evaluates to 29.999999999999996, which int() would truncate.
coinc_win = [325, 325 + round(Dt * 1e9)]
duty_cycle = 16.03


def coinc(vec, t):
    """Return the coincidence rate summed over the ``coinc_win`` bins of *vec*.

    The summed count N gets a Poisson uncertainty sqrt(N). It is then scaled
    by ``duty_cycle / t * 1e9``. *t* is presumably the acquisition time in ns,
    giving a per-second rate -- TODO confirm.
    """
    lo, hi = coinc_win[0], coinc_win[1]
    counts = sum(vec[lo:hi])
    counts_u = ufloat(counts, np.sqrt(counts))
    return counts_u * duty_cycle / t * 1e9


# Load every run. As used below, each g2_extr result k provides:
# k[0] the coincidence histogram, k[1] / k[2] the signal / idler singles
# counts and k[-1] the acquisition time (the "* 1e9" scaling suggests ns --
# TODO confirm against g2_lib).
a = [g2_lib.g2_extr(k, bins=600, max_range=600) for k in data_files]
if len(a) < 2:
    # a[0] supplies the singles offsets and a[1:] the efficiency estimate, so
    # fewer than two runs cannot be analysed (and would otherwise fail later
    # with an opaque IndexError or an empty mean).
    raise SystemExit(
        'expected at least 2 data files, found {}'.format(len(a)))


def _singles_rate(counts, t):
    """Return counts scaled like coinc(), with Poisson (sqrt N) uncertainty."""
    return ufloat(counts, np.sqrt(counts)) * duty_cycle / t * 1e9


Rp_u = np.array([coinc(k[0], k[-1]) for k in a])
rs_u = np.array([_singles_rate(k[1], k[-1]) for k in a])
ri_u = np.array([_singles_rate(k[2], k[-1]) for k in a])

# Singles offsets come from the first run. Reuse the rates computed above
# instead of re-deriving the identical expression.
ds = rs_u[0]
di = ri_u[0]

print(ds, di)
# Coincidence-to-accidental ratio, and the pair rate with the accidental
# estimate rs * ri * Dt subtracted.
CAR_u = Rp_u / (rs_u * ri_u * Dt)
rp_u = Rp_u - (rs_u * ri_u * Dt)

# Initial efficiency estimates, averaged over all runs after the first.
eta_s_u = np.mean(rp_u[1:] / (rs_u[1:] - ds))
eta_i_u = np.mean(rp_u[1:] / (ri_u[1:] - di))
print(eta_s_u, eta_i_u)

CAR = unumpy.nominal_values(CAR_u)
rp = unumpy.nominal_values(rp_u)
# Zero uncertainties would turn into infinite fit weights; use 1 instead.
rp_err = unumpy.std_devs(rp_u)
rp_err[rp_err == 0] = 1
CAR_err = unumpy.std_devs(CAR_u)
CAR_err[CAR_err == 0] = 1


def CAR_func(p, x):
    """ODR model: coincidence-to-accidental ratio versus corrected pair rate.

    p = (eff_s, eff_i, s, i). The signal and idler singles are modelled as
    x / eff_s + s and x / eff_i + i. The accidental rate is
    A = R_signal * R_idler * Dt, and the result is (x + A) / A.
    *x* may be a scalar or a numpy array (operations are elementwise).
    """
    eta_sig, eta_idl, bg_sig, bg_idl = p
    # Same evaluation order as (R_signal * R_idler) * Dt.
    accidentals = ((x / eta_sig) + bg_sig) * ((x / eta_idl) + bg_idl) * Dt
    return (x + accidentals) / accidentals


# Export corrected pair rates and CAR, each with its 1-sigma uncertainty.
with open('CAR.dat', 'w', encoding='utf-8') as f:
    f.write('#Corrected_pairs\tC_pairs_err\tCAR\tCAR_error\n')
    for pairs, car in zip(rp_u, CAR_u):
        f.write('{:.3f}\t{:.3f}\t{:.3f}\t{:.3f}\n'.format(
            pairs.n, pairs.s, car.n, car.s))

# Orthogonal-distance regression with errors on both axes. RealData converts
# the standard deviations sx / sy into ODR weights 1 / sigma**2; odr.Data's
# wd / we expect those weights directly, so 1 / err would be the square root
# of the correct weighting.
beta0 = [eta_s_u.n, eta_i_u.n, ds.n, di.n]
CAR_model = odr.Model(CAR_func)
data = odr.RealData(rp, CAR, sx=rp_err, sy=CAR_err)
odr_fit = odr.ODR(data, CAR_model, beta0=beta0)
out = odr_fit.run()
print('beta: {}\nResidual variance: {}'.format(out.beta, out.res_var))

# Measured points with error bars, ODR fit (green) and initial guess (blue).
plt.figure('odr')
plt.errorbar(rp, CAR, yerr=CAR_err, xerr=rp_err, fmt='o')
# NOTE(review): np.logspace takes exponents, so this grid starts at
# 10**0.01 ~= 1.02 rather than 0.01 -- confirm that is the intended bound.
x_fit = np.logspace(.01, np.log10(max(rp)), 1000)
plt.plot(x_fit, CAR_func(out.beta, x_fit), 'g')
plt.plot(x_fit, CAR_func(beta0, x_fit), 'b')
plt.xscale('log')
plt.show()
