import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.ticker import MultipleLocator
# Global figure styling: 18 pt default font and sans-serif ('sf') mathtext.
matplotlib.rcParams.update({'font.size': 18, 'mathtext.default': 'sf'})

# Load the area histogram. Row 0 of the file holds the event counts plotted on
# the y axis; row 1 holds the x coordinate of each point (presumably the bin
# centre on the area axis), which is doubled before use.
# NOTE(review): an earlier comment gave the area unit as "mVns", while the saved
# header and the axis label say "nVs" (1 mV*ns = 1 pV*s, not 1 nV*s) -- confirm
# which unit the data actually uses.
pnr_area_disc = np.loadtxt('pnr_area_disc.dat')
frequencies, x_val = pnr_area_disc[0], pnr_area_disc[1] * 2

# Load the three fitted components (drawn below as the n=1, n=2 and n=3
# curves against x_val), one file per component.
g1, g2, g3 = map(np.loadtxt, ('g1_area_component.dat',
                              'g2_area_component.dat',
                              'g3_area_component.dat'))

# Combine data to save in a more gnuplot-friendly format: one row per data
# point, with the columns named in the header.
# np.column_stack builds the 2-D table explicitly. Passing zip(...) here only
# worked on Python 2, where zip returns a list; on Python 3 zip is a lazy
# iterator that np.savetxt turns into a 0-d object array and rejects.
np.savetxt('area_histo_with_fit.dat',
           np.column_stack((x_val, frequencies, np.sqrt(frequencies),
                            g1, g2, g3)),
           header='area(nVs)\tcounts\tsqrt(counts)\tg1\tg2\tg3')

# Plot the measured histogram (black points with sqrt(counts) error bars)
# together with the three fitted components.
counts_err = np.sqrt(frequencies)

# x positions of the a_{t_01} / a_{t_12} markers (presumably the n=0|1 and
# n=1|2 discrimination thresholds). The raw values are doubled with the same
# factor of 2 that is applied to x_val when the histogram is loaded, so they
# were presumably computed on the unscaled axis. Naming each value once keeps
# the arrow tip and its label on the same x.
AREA_THRESHOLD_01 = 2.75653158938 * 2
AREA_THRESHOLD_12 = 7.98319588758 * 2

plt.figure(figsize=(8, 6))
plt.errorbar(x_val, frequencies,
             yerr=counts_err,
             linestyle='',
             ecolor='black',
             color='black',
             # label='$C(a)$'
             )
plt.plot(x_val, g1, color='blue', label='n=1')
plt.plot(x_val, g2, color='red', label='n=2')
plt.plot(x_val, g3, color='orange', label='n=3')

# Hand-placed peak labels, in data coordinates.
plt.text(11.2, 5500, 'n=1', ha='center')
plt.text(22.2, 800, 'n=2', ha='center')

# One downward arrow per threshold, from the label at y=1000 to y=300.
for threshold_label, threshold_area in ((r'$a_{t_{01}}$', AREA_THRESHOLD_01),
                                        (r'$a_{t_{12}}$', AREA_THRESHOLD_12)):
    plt.annotate(threshold_label,
                 xy=(threshold_area, 300),
                 xytext=(threshold_area, 1000),
                 arrowprops=dict(facecolor='black', shrink=0.05, width=2),
                 ha='center',
                 fontsize='large',
                 )

plt.ylim(0, 6000)
plt.xlim(0, 60)

plt.xlabel('area of TES signal (nVs)')
plt.ylabel('number of events')

# Optional output: uncomment to add a legend and/or write the figure to disk.
# plt.legend()
# plt.savefig('area_histo_with_fit.eps')
# plt.savefig('area_histo_with_fit.tiff')

plt.show()
