import matplotlib.pyplot as plt
import numpy as np
from qutip import *
from lintrap_sim_hamiltonian import *
# ---------------------------------------------------------------------------
# Sweep the trap depth and record the peak normalized transmission signal at
# three probe drive strengths (nominal, amplitude**2 x1.1, amplitude**2 x0.9).
# `gamma1` and `getTxMatrix` come from the lintrap_sim_hamiltonian star import.
# ---------------------------------------------------------------------------

# Magnetic field: a single current-driver control voltage, converted to Gauss.
voltagelist_to_currentdriver = np.array([9])
volt_to_B = 1.604  # Gauss/V
bfieldlist = voltagelist_to_currentdriver * volt_to_B  # in Gauss

# Trap depths to sweep: 0.0 up to (but excluding) 3.0 in steps of 0.05.
trapdepth_list = np.arange(0.0, 3.0, 0.05)

# Simulation time grid and initial state.
psi0 = basis(12, 0)  # basis state |0> of a 12-dimensional space
tsteps = 10000
tfinal = 1000e-6
tlist = np.linspace(0.0, tfinal, tsteps)
tstep = tlist[1] - tlist[0]  # not used below; kept for reference

# Depth-independent part of the probe-frequency grid: 26 points from 4 to 11,
# scaled by 1e6 * 2*pi.  Hoisted out of the sweep loop; the per-depth offset is
# added inside _scan_trap_depth with the same operation order as before.
_BASE_FREQS = np.linspace(4, 11, 26) * 1e6 * 2 * np.pi


def _max_photons(drive):
    """Return the photon count used to normalize the spectrum for *drive*.

    Evaluates (drive**2/4) / (gamma1**2/4 + drive**2/2) * gamma1 * tfinal.
    NOTE(review): this has the form of the on-resonance steady-state
    excited-state population of a two-level system times gamma1 * tfinal,
    assuming *drive* is a Rabi frequency in the same units as gamma1 -- confirm.
    """
    return (drive**2 / 4) / (gamma1**2 / 4 + drive**2 / 2) * gamma1 * tfinal


def _scan_trap_depth(drive, out_path):
    """Sweep every depth in ``trapdepth_list`` at probe amplitude *drive*.

    For each trap depth, the probe-frequency grid is shifted by
    (trap - 0.88) * 32e6 * 2*pi and passed to ``getTxMatrix``. The result is
    divided by ``_max_photons(drive)``, and its maximum is written to
    *out_path* as one "<trap> <eta> " line.

    Returns the normalization photon count that was used.
    """
    max_ph = _max_photons(drive)
    print('max photons: ' + str(max_ph))
    # `with` guarantees the file is closed (and buffered partial results are
    # flushed) even if a simulation call raises part-way through the sweep.
    with open(out_path, 'w') as outfile:
        for trap in trapdepth_list:
            print(trap)
            freqlist = _BASE_FREQS + (trap - 0.88) * 32e6 * 2 * np.pi
            emission_spec = getTxMatrix(psi0, tlist, drive, trap, bfieldlist, freqlist)
            spec = emission_spec / max_ph
            eta = np.max(spec)
            print(spec)
            # Same byte layout as the original per-item writes: "trap eta \n".
            outfile.write("%s %s \n" % (trap, eta))
    return max_ph


# Nominal probe amplitude.
A_sigma_m = 0.0515 * gamma1
max_ph = _scan_trap_depth(A_sigma_m, 'sim_tx_over_trap_depth.dat')

# ----------- high probe power: amplitude scaled by sqrt(1.1) (amplitude**2 x1.1)
A_sigma_high = A_sigma_m * np.sqrt(1.1)
max_ph = _scan_trap_depth(A_sigma_high, 'sim_tx_over_trap_depth_highpower.dat')

# ----------- low probe power: amplitude scaled by sqrt(0.9) (amplitude**2 x0.9)
A_sigma_low = A_sigma_m * np.sqrt(0.9)
max_ph = _scan_trap_depth(A_sigma_low, 'sim_tx_over_trap_depth_lowpower.dat')
