import numpy as np
import scipy.odr as odr

import matplotlib.pyplot as plt

from lmfit import Parameters
from lmfit.models import LinearModel
from lmfit.models import QuadraticModel

# Data import: columns 1-4 of plot_data.dat are OD, its uncertainty, g_max and
# its uncertainty (column 0 is not used here).
raw_data = np.genfromtxt('plot_data.dat')
OD = raw_data[:, 1]
OD_err = raw_data[:, 2]
g_max = raw_data[:, 3]
g_err = raw_data[:, 4]
# Dense grid spanning the measured OD range, used to draw the fitted curves.
x_fit = np.linspace(OD.min(), OD.max(), 1000)

# Both lmfit fits use the same y-weighting so their reports and curves are
# comparable. lmfit multiplies the residuals by `weights`, so 1/sigma (not
# 1/sigma**2) yields the usual chi-square. lmfit cannot use OD_err; the ODR
# section below accounts for the x uncertainties as well.
fit_model = LinearModel()
p = Parameters()
p.add('slope', 20)
p.add('intercept', -25)
result = fit_model.fit(g_max, x=OD, params=p, weights=1 / g_err)
print(result.fit_report())

# QuadraticModel is a*x**2 + b*x + c; holding b and c fixed at 0 leaves the
# pure quadratic a*x**2.
fit_model2 = QuadraticModel()
p2 = Parameters()
p2.add('a', 1)
p2.add('b', 0, vary=False)
p2.add('c', 0, vary=False)
result_2 = fit_model2.fit(g_max, x=OD, params=p2, weights=1 / g_err)
print(result_2.fit_report())

plt.figure('lmfit')
plt.errorbar(OD, g_max, yerr=g_err, xerr=OD_err, fmt='o', label='data')
plt.plot(x_fit, result.eval(x=x_fit), label='linear fit')
plt.plot(x_fit, result_2.eval(x=x_fit), label='quadratic fit')
plt.legend()


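"""
ODR section: orthogonal distance regression (scipy.odr), which uses the
uncertainties of both OD (x) and g_max (y) when fitting.
"""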


def lin_func(p, x):
    """Straight line in the ``fcn(beta, x)`` form used by ``scipy.odr.Model``.

    p : two-element sequence ``(slope, intercept)``; unpacking raises
        ValueError if it does not have exactly two entries.
    x : scalar or array of abscissa values.

    Returns ``slope * x + intercept``.
    """
    slope, intercept = p
    return intercept + slope * x


def quad_func(p, x):
    """Pure quadratic through the origin, in ``scipy.odr`` callback form.

    p : the curvature coefficient, used as given (not unpacked), so the
        length-1 beta array supplied by ODR broadcasts against ``x``.
    x : scalar or array of abscissa values.

    Returns ``p * x**2``.
    """
    x_squared = x ** 2
    return p * x_squared


quad_model = odr.Model(quad_func)
lin_model = odr.Model(lin_func)

# ODRPACK weights are inverse variances (1/sigma**2). RealData takes the
# standard deviations directly and builds those weights itself; passing
# 1/sigma as wd/we to odr.Data would mis-weight every point.
data = odr.RealData(OD, g_max, sx=OD_err, sy=g_err)

odr_fit_lin = odr.ODR(data, lin_model, beta0=[25, -10])
out_lin = odr_fit_lin.run()
print('beta: {}\nResidual variance: {}'.format(out_lin.beta, out_lin.res_var))

odr_fit = odr.ODR(data, quad_model, beta0=[.7])
out2 = odr_fit.run()
print('beta: {}\nResidual variance: {}'.format(out2.beta, out2.res_var))

plt.figure('odr')
plt.errorbar(OD, g_max, yerr=g_err, xerr=OD_err, fmt='o', label='data')
plt.plot(x_fit, lin_func(out_lin.beta, x_fit), label='linear ODR fit')
plt.plot(x_fit, quad_func(out2.beta, x_fit), label='quadratic ODR fit')
plt.legend()

plt.show()
