import numpy as np
import matplotlib.pyplot as plt
import matplotlib
from matplotlib.ticker import MultipleLocator
# Global text styling for every figure created below: 18 pt base font and
# sans-serif rendering for mathtext (e.g. the $V_{th}$ legend label).
matplotlib.rcParams.update({
    'font.size': 18,
    'mathtext.default': 'sf',
})

# Load the two pre-computed area histograms from whitespace-separated text.
# Downstream code reads row 0 as the y values and row 1 as the x values.
pnr_area_above_th, pnr_area_disc = [
    np.loadtxt(histogram_file)
    for histogram_file in ('pnr_area_above_th.dat', 'pnr_area_disc.dat')
]

plt.figure(figsize=(8, 6))

# Each entry pairs a histogram array with the line options specific to it.
# The first trace uses the default solid line, the second is dashed.
traces = (
    (pnr_area_disc, {'label': 'area (with \ndiscriminator)'}),
    (pnr_area_above_th, {'linestyle': '--', 'label': 'area (above \n$V_{th}$)'}),
)
for histogram, line_options in traces:
    # Row 1 holds the x positions and row 0 the bin contents.
    # NOTE(review): the factor 2 on the x values is not explained in this
    # file -- presumably a unit/scale conversion; confirm before changing.
    plt.step(2 * histogram[1], histogram[0], where='mid', **line_options)

# Disabled overlays, kept for reference:
# plt.step(pnr_area[:,1],pnr_area[:,0], where='mid', label='without pulse identification');
# plt.axvline(5604.84593463,color='black', label=r'$th_{01}$',linestyle='-.')
# plt.axvline(15921.253917,color='black', label=r'$th_{12}$', linestyle='--')

# Restrict the visible window on both axes.
plt.xlim(0, 30)
plt.ylim(0, 7000)

# Arrowed callouts, one row per annotation:
#   (text, point the arrow targets, position of the text, horizontal alignment)
# The two threshold labels sit to the right of their target (same y, so the
# arrow is horizontal); the n=1 / n=2 labels sit above theirs (same x, so the
# arrow is vertical).
callouts = (
    ('single-level threshold', (8, 6500), (10, 6500), 'left'),
    ('two-level threshold', (13, 4000), (15, 4000), 'left'),
    ('n=1', (11.2, 5000), (11.2, 6000), 'center'),
    ('n=2', (22.2, 375), (22.2, 1375), 'center'),
)

for text, target_point, text_position, h_align in callouts:
    plt.annotate(
        text,
        xy=target_point,
        xytext=text_position,
        # A fresh dict per call, so no arrow settings are shared between
        # annotations.
        arrowprops={'facecolor': 'black', 'shrink': 0.05, 'width': 2},
        horizontalalignment=h_align,
        verticalalignment='center',
    )

# Disabled: relabel the x ticks with their values scaled by 1e-3
# (original note: change x axis scaling from mVns to mVus).
# ticks = (plt.gca().get_xticks()*10**-3).astype('int')
# plt.gca().set_xticklabels(ticks)

plt.ylabel('number of events')
plt.xlabel('area of TES signal (nVs)')

# The legend is currently disabled, so the trace labels are not drawn.
# plt.legend()

# Write the figure to each output format before opening the interactive window.
for output_file in ('area_histos_comparison.eps', 'area_histos_comparison.tiff'):
    plt.savefig(output_file)
# Alternative output name, currently unused:
# plt.savefig('area_histos_comparison_with_thresholds.eps')

plt.show()