from __future__ import division
import numpy as np
from lmfit import Model
import matplotlib.pyplot as plt

# Load the spectra from the whitespace-delimited text files "T_atom" and
# "T_woatom" (names suggest with / without atoms). The fitting code below reads
# columns [:, 0] = x, [:, 1] = y and [:, 2] = standard error of y.
T_atom, T_woatom = np.genfromtxt("T_atom"), np.genfromtxt("T_woatom")

# Define the fitting functions (model line shapes passed to lmfit below)
def lorentzian(x, amp, width, mean):
    """Lorentzian line shape.

    Evaluates to ``amp`` at ``x == mean`` and to ``amp / 2`` at
    ``x == mean +/- width / 2``, i.e. ``width`` is the full width at half
    maximum. Works element-wise when ``x`` is a NumPy array.
    """
    peak_scale = amp * (width**2 / 4)
    half_width = width / 2
    denom = (x - mean)**2 + half_width**2
    return peak_scale / denom

def normal_mode(x, amp, mean, gamma, coupling, kappa, shift):
    """Transmission of a cavity coupled to atoms (coupled-oscillator model).

    Computes, in real arithmetic,

        T = amp * kappa**2 * (Da**2 + gamma**2)
            / |(kappa - 1j*Dc) * (gamma - 1j*Da) + coupling**2|**2

    with
        Da = x - mean      (detuning from the atomic resonance at ``mean``)
        Dc = Da - shift    (detuning from the cavity resonance at ``mean + shift``)

    Parameters
    ----------
    x : float or ndarray
        Probe frequency axis (evaluated element-wise for arrays).
    amp : float
        Overall scale; for ``coupling == 0`` the result is a Lorentzian in Dc
        with peak value ``amp`` and half width ``kappa``.
    mean : float
        Position of the atomic resonance on the ``x`` axis.
    gamma, kappa : float
        Atomic and cavity half-width decay rates, in units of ``x``.
    coupling : float
        Atom-cavity coupling g.
    shift : float
        Cavity resonance minus atomic resonance.
    """
    atom_det = x - mean
    cav_det = atom_det - shift
    numerator = kappa**2 * (atom_det**2 + gamma**2)
    # Expansion of |(kappa - i*Dc)(gamma - i*Da) + g^2|^2. Both detunings use
    # the same sign convention (probe minus resonance), so the cross term is
    # (kappa*gamma - Da*Dc); a '+' here would make the cavity detuning run
    # opposite to the probe frequency and distort the normal-mode splitting.
    denominator = ((cav_det**2 + kappa**2) * (atom_det**2 + gamma**2)
                   + 2 * coupling**2 * (kappa * gamma - atom_det * cav_det)
                   + coupling**4)
    return amp * (numerator / denominator)

def lorentz_fitting(data, init, no_trials=1):
    """Fit ``lorentzian`` to a spectrum with lmfit and print the fit report.

    Parameters
    ----------
    data : ndarray
        Columns used as x (``[:, 0]``), y (``[:, 1]``) and the standard error
        of y (``[:, 2]``).
    init : sequence
        Initial values in the order (amp, width, mean). ``amp`` and ``width``
        are bounded below by 0.
    no_trials : int, optional
        The y standard errors are divided by sqrt(no_trials); the fit weights
        are the reciprocal of the result.

    Returns
    -------
    list
        [amp, amp_std, width, width_std, mean, mean_std, redchi]
    """
    sigma = data[:, 2] / np.sqrt(no_trials)

    model = Model(lorentzian)
    model.set_param_hint('amp', value=init[0], min=0)
    model.set_param_hint('width', value=init[1], min=0)
    model.set_param_hint('mean', value=init[2])
    params = model.make_params()

    fit = model.fit(data[:, 1], params, x=data[:, 0], weights=1/sigma, verbose=False)
    print(fit.fit_report())

    # Interleave each best-fit value with its standard error, then append redchi.
    result_list = []
    for pname in ('amp', 'width', 'mean'):
        par = fit.params[pname]
        result_list += [par.value, par.stderr]
    result_list.append(fit.redchi)
    return result_list

def normal_mode_fitting(data, init, no_trials = 1):
    """Fit ``normal_mode`` to a spectrum with lmfit and report the cooperativity.

    Parameters
    ----------
    data : ndarray
        Columns used as x (``[:, 0]``), y (``[:, 1]``) and the standard error
        of y (``[:, 2]``).
    init : sequence
        Initial values in the order (amp, mean, gamma, coupling, kappa, shift).
        ``gamma`` is held fixed at ``init[2]``; ``amp``, ``coupling`` and
        ``kappa`` are bounded below by 0.
    no_trials : int, optional
        The y standard errors are divided by sqrt(no_trials); the fit weights
        are the reciprocal of the result.

    Returns
    -------
    list
        [amp, amp_std, mean, mean_std, gamma, gamma_std, redchi,
         coupling, coupling_std, kappa, kappa_std, shift, shift_std]
        Note that ``redchi`` sits at index 6. The ``*_std`` entries are the
        lmfit standard errors as reported (lmfit may leave them as None).

    Prints the lmfit report and the cooperativity
    C = coupling**2 / (2 * kappa * gamma) with its propagated uncertainty.
    """
    # Declare model and initializing data
    nm_model = Model(normal_mode)
    xdata = data[:, 0]
    ydata = data[:, 1]
    ystderr = data[:, 2] / np.sqrt(no_trials)

    # Set parameter hints; gamma is not fitted (vary=False).
    nm_model.set_param_hint('amp', value=init[0], min=0)
    nm_model.set_param_hint('mean', value=init[1])
    nm_model.set_param_hint('gamma', value=init[2], vary=False)
    nm_model.set_param_hint('coupling', value=init[3], min=0)
    nm_model.set_param_hint('kappa', value=init[4], min=0)
    nm_model.set_param_hint('shift', value=init[5])
    pars = nm_model.make_params()

    fit = nm_model.fit(ydata, pars, x=xdata, weights=1/ystderr, verbose=False)
    print(fit.fit_report())

    names = ('amp', 'mean', 'gamma', 'coupling', 'kappa', 'shift')
    est = {name: fit.params[name].value for name in names}
    std = {name: fit.params[name].stderr for name in names}

    # Layout kept exactly as before: callers index into this list.
    result_list = [est['amp'], std['amp'], est['mean'], std['mean'],
                   est['gamma'], std['gamma'], fit.redchi,
                   est['coupling'], std['coupling'],
                   est['kappa'], std['kappa'],
                   est['shift'], std['shift']]

    # Cooperativity with first-order relative-error propagation.
    # lmfit may leave stderr as None (e.g. when it cannot estimate error
    # bars); dividing None would raise TypeError here and lose the results.
    # gamma is fixed by design, so a missing gamma error contributes zero;
    # a missing coupling or kappa error makes the cooperativity error unknown.
    coop_est = est['coupling']**2 / (2 * est['kappa'] * est['gamma'])
    coupling_term = (np.nan if std['coupling'] is None
                     else (2 * std['coupling'] / est['coupling'])**2)
    kappa_term = (np.nan if std['kappa'] is None
                  else (std['kappa'] / est['kappa'])**2)
    gamma_term = (0.0 if std['gamma'] is None
                  else (std['gamma'] / est['gamma'])**2)
    coop_std = coop_est * np.sqrt(coupling_term + kappa_term + gamma_term)
    print("Cooperativity : %s +/- %s" % (coop_est, coop_std))

    return result_list

# Initial guesses, in the parameter orders expected by lorentz_fitting
# (amp, width, mean) and normal_mode_fitting
# (amp, mean, gamma, coupling, kappa, shift).
T_woatom_init = np.array([30, 80, 40])
res_woatom = lorentz_fitting(T_woatom, T_woatom_init)
T_atom_init = np.array([32.7, 30, 6.0659/2, 10, 100, -5])
# The first row of T_atom is excluded from the fit (it is still plotted below).
res_atom = normal_mode_fitting(T_atom[1:], T_atom_init)

# Dense grid on which the fitted curves are evaluated.
data_x = np.linspace(-96, 144, 1000)

# Best-fit curves. Indices pick the parameter values out of the result lists:
# lorentz_fitting -> amp=0, width=2, mean=4;
# normal_mode_fitting -> amp=0, mean=2, gamma=4, coupling=7, kappa=9, shift=11.
woatom_curve = lorentzian(data_x, *(res_woatom[i] for i in (0, 2, 4)))
atom_curve = normal_mode(data_x, *(res_atom[i] for i in (0, 2, 4, 7, 9, 11)))

# Only the atom spectrum is plotted; to inspect the empty-cavity fit, plot
# T_woatom[:, 0], T_woatom[:, 1] and data_x, woatom_curve instead.
plt.plot(T_atom[:, 0], T_atom[:, 1], marker='o')
plt.plot(data_x, atom_curve)
plt.show()

np.savetxt("fit_T_woatom", np.stack((data_x, woatom_curve), axis=1))
np.savetxt("fit_T_atom", np.stack((data_x, atom_curve), axis=1))
