import matplotlib.pyplot as plt
import numpy as np

from uncertainties import ufloat
from uncertainties import unumpy

from lmfit import Model
from lmfit import Parameters
from lmfit.models import ConstantModel

"""
Load the data file and arrange its columns into a meaningful structure.
"""
# Data file to analyse. Exactly one assignment should be active; the other
# measurement files are kept as commented-out alternatives.
# filename = '0_4297_20141023.dat'
# filename = '0_0_20141023.dat'
filename = '0_4297_20141021.dat'
# filename = '0_8594_20141021.dat'
raw_data = np.genfromtxt(filename)

# The acquisition time per point is 5000 ms, stored here in seconds.
int_time = 5

"""
Assign each data column to its own vector, attaching the associated
Poissonian (sqrt(N)) error.
"""


def c2u(arr):
    """Turn an array of raw counts into a ``uarray`` with Poissonian errors.

    Every entry N gets an uncertainty of sqrt(N). Entries whose uncertainty
    would be exactly 0 (zero counts) get an uncertainty of 1 instead, so no
    point carries a zero error.
    """
    sigma = np.sqrt(arr)
    # Replace zero uncertainties in place with 1.
    np.putmask(sigma, sigma == 0, 1)
    return unumpy.uarray(arr, sigma)


# Column layout of raw_data used below (0-based):
#   0 = HWP angle,
#   2 / 3 = singles A0 / A1, 5 = single B0, 9 = single B1,
#   6 / 7 = coincidences A0B0 / A1B0, 10 / 11 = coincidences A0B1 / A1B1.
angles = raw_data[:, 0]

single_a0, single_a1, single_b0, single_b1 = (
    c2u(raw_data[:, col]) for col in (2, 3, 5, 9))
singles = [single_a0, single_a1, single_b0, single_b1]

c_a0b0, c_a0b1, c_a1b0, c_a1b1 = (
    c2u(raw_data[:, col]) for col in (6, 10, 7, 11))
coinc = [c_a0b0, c_a0b1, c_a1b0, c_a1b1]

# Each coincidence count divided by the geometric mean of the two singles
# involved (an efficiency-like normalization), in the order
# A0B0, A0B1, A1B0, A1B1.
coinc_n = [c / unumpy.sqrt(s_a * s_b)
           for c, s_a, s_b in ((c_a0b0, single_a0, single_b0),
                               (c_a0b1, single_a0, single_b1),
                               (c_a1b0, single_a1, single_b0),
                               (c_a1b1, single_a1, single_b1))]

# [plt.errorbar(angles,
#               unumpy.nominal_values(c),
#               yerr=unumpy.std_devs(c),
#               fmt='o-')
#  for c in coinc_n]

"""
Fitting of the visibilities
"""


def hwp_osc(x, V, x_offset, period):
    """Normalized oscillation ``1 + V * cos(pi / period * (x - x_offset))``.

    The cosine completes one full cycle every ``2 * period`` in *x*, so
    *period* is the distance from a maximum to the following minimum.

    Parameters
    ----------
    x : float or numpy.ndarray
        Independent variable (angle).
    V : float
        Oscillation depth (visibility).
    x_offset : float
        Position of a maximum when ``V > 0``.
    period : float
        Half of the full oscillation period, in units of *x*.
    """
    shifted = x - x_offset
    return 1 + V * np.cos(np.pi / period * shifted)


def vis_fit(angle_range, sig, weights=None):
    """Fit the HWP oscillation model to *sig* and extract its visibility.

    The fitted lmfit model is ``Amp_c * (hwp_osc(x) + bg_c)``. ``bg_c`` is
    fixed at 0 and ``period`` is fixed at 45, so the fitted curve repeats
    every 90 units of *angle_range*. ``Amp_c`` (>= 0), ``V`` (>= 0) and
    ``x_offset`` are free.

    Parameters
    ----------
    angle_range : numpy.ndarray
        Independent variable (HWP angles), same length as *sig*.
    sig : numpy.ndarray
        Signal to fit.
    weights : array_like or None, optional
        One-sigma uncertainties of *sig*; the fit is weighted by
        ``1 / weights``. ``None`` (default) or the legacy value ``''``
        gives an unweighted fit.

    Returns
    -------
    visibility : ufloat
        Fitted ``V``.
    x_offset : ufloat
        Fitted ``x_offset`` (position of the maximum, since ``V >= 0``).
    Amplitude : ufloat
        ``2 * Amp_c``, i.e. the model's peak value when ``V == 1``.
    result : lmfit ModelResult
        The complete fit result (``result.best_fit`` etc.).

    The uncertainties are lmfit's ``stderr``. When lmfit could not estimate
    them (``stderr is None``), the returned ufloat carries a NaN error
    instead of raising.
    """
    osc_model = ConstantModel(prefix='Amp_') * \
        (Model(hwp_osc) + ConstantModel(prefix='bg_'))

    p0 = Parameters()
    # Overall scale: start from the mean of the signal.
    p0.add('Amp_c',
           np.mean(sig),
           min=0,
           vary=True)
    # Visibility: start from the contrast (max - min) / (max + min).
    p0.add('V',
           (np.max(sig) - np.min(sig)) / (np.max(sig) + np.min(sig)),
           min=0,
           vary=True)
    # Phase: start at the angle where the signal peaks.
    p0.add('x_offset',
           angle_range[np.argmax(sig)],
           vary=True)
    # Background pinned to zero.
    p0.add('bg_c',
           0,
           min=0, vary=False)
    # Fixed: cos(pi / 45 * x) repeats every 90 units of x.
    p0.add('period',
           45,
           vary=False)

    # '' was the original "no weights" sentinel and is still accepted. It is
    # tested by value, not with `is` (identity comparison with a literal is
    # unreliable and a SyntaxWarning on Python >= 3.8). The isinstance guard
    # avoids an element-wise comparison when weights is an array.
    unweighted = weights is None or (isinstance(weights, str) and weights == '')
    if unweighted:
        result = osc_model.fit(sig,
                               x=angle_range,
                               params=p0)
    else:
        result = osc_model.fit(sig,
                               x=angle_range,
                               params=p0,
                               weights=1 / np.asarray(weights, dtype=float))

    def _to_ufloat(name):
        # lmfit leaves stderr as None when it cannot estimate the parameter
        # uncertainties; ufloat() would fail on None, so use NaN instead.
        par = result.params[name]
        err = par.stderr if par.stderr is not None else np.nan
        return ufloat(par.value, err)

    visibility = _to_ufloat('V')
    x_offset = _to_ufloat('x_offset')
    Amplitude = 2 * _to_ufloat('Amp_c')

    return visibility, x_offset, Amplitude, result

# One vis_fit per detector pair (same order as coinc_n), each weighted by
# the propagated uncertainty of the normalized coincidences.
D = []
for c_norm in coinc_n:
    nominal = unumpy.nominal_values(c_norm)
    sigma = unumpy.std_devs(c_norm)
    D.append(vis_fit(angles, nominal, sigma))


"""
Efficiency and rates
"""
# A1 efficiency estimated against each B detector: coincidences C(A1, Bx)
# divided by the singles S(Bx). Both curves are drawn on the current pyplot
# figure and fitted with vis_fit. Each fit keeps its own name so neither
# result is lost.
eta_A1_B1 = c_a1b1 / single_b1
eta_A1_B0 = c_a1b0 / single_b0

plt.errorbar(angles,
             unumpy.nominal_values(eta_A1_B1),
             yerr=unumpy.std_devs(eta_A1_B1),
             label='C(A1,B1) / S(B1)')
A_B1 = vis_fit(angles, unumpy.nominal_values(eta_A1_B1),
               unumpy.std_devs(eta_A1_B1))

plt.errorbar(angles,
             unumpy.nominal_values(eta_A1_B0),
             yerr=unumpy.std_devs(eta_A1_B0),
             label='C(A1,B0) / S(B0)')
A_B0 = vis_fit(angles, unumpy.nominal_values(eta_A1_B0),
               unumpy.std_devs(eta_A1_B0))
plt.legend()

# Keep the historical module-level names: `eta_A1` and `A` refer to the
# C(A1,B0) / S(B0) curve and its fit.
eta_A1 = eta_A1_B0
A = A_B0
"""
Plotting
"""
# 2x2 grid with one panel per detector pair, in the order of coinc_n and D
# (A0B0, A0B1, A1B0, A1B1). Columns share the x axis, rows share the y axis.
f, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2, sharex='col', sharey='row')
f.set_size_inches(14, 9)
axes = [ax1, ax2, ax3, ax4]

for ax, c_norm, fit in zip(axes, coinc_n, D):
    # Measured normalized coincidences with their propagated error bars...
    ax.errorbar(angles,
                unumpy.nominal_values(c_norm),
                yerr=unumpy.std_devs(c_norm),
                fmt='o-')
    # ...overlaid with the fitted curve (fit[3] is the lmfit result).
    ax.plot(angles, fit[3].best_fit, 'r-')

plt.suptitle('Visibilities for ' + filename, size=16)
ax1.set_ylim(0)
ax3.set_ylim(0)
ax1.set_ylabel('Efficiency')
ax3.set_ylabel('Efficiency')
ax3.set_xlabel('Alice HWP angle')
ax4.set_xlabel('Alice HWP angle')

seq = ['A0B0', 'A0B1', 'A1B0', 'A1B1']


def textstr(det, vis, eff):
    """Build the three-line annotation shown in one detector-pair panel.

    *det* is the detector-pair label (a string). *vis* and *eff* are
    formatted with the ``.2u`` spec, which is the ``uncertainties`` package
    format for a value with its error, so they are expected to be ufloats.
    """
    head = 'Detector pair ' + det + '\n'
    body = ('Visibility: {0:.2u} %\n'.format(vis),
            'Efficiency: {0:.2u} %'.format(eff))
    return head + ''.join(body)

# Semi-transparent rounded box behind each panel's annotation.
props = dict(boxstyle='round', facecolor='white', alpha=0.5)

# Annotate each panel (top-left corner, axes coordinates) with its detector
# pair, the fitted visibility and Amplitude, both scaled by 100. `fit` is the
# (visibility, x_offset, Amplitude, result) tuple returned by vis_fit. The
# distinct loop name avoids shadowing the module-level `A`.
for ax, fit, det in zip(axes, D, seq):
    ax.text(0.05,
            0.95,
            textstr(det, 100 * fit[0], 100 * fit[2]),
            transform=ax.transAxes,
            fontsize=14,
            verticalalignment='top',
            bbox=props)

plt.tight_layout()
# Leave room at the top of the figure for the suptitle.
plt.subplots_adjust(top=0.95)
# NOTE(review): the output keeps the data extension, e.g. '<name>.dat.pdf'.
plt.savefig(filename + ".pdf", format="pdf")
plt.show()
