import numpy as np
from qutip import *

# 12-level Hilbert space: indices 0-4 are the F=2 sublevels m = -2..+2,
# indices 5-11 are the F=3 sublevels m = -3..+3.
F2_m2, F2_m1, F2_0, F2_p1, F2_p2 = (basis(12, idx) for idx in range(0, 5))

F3_m3, F3_m2, F3_m1, F3_0, F3_p1, F3_p2, F3_p3 = (
    basis(12, idx) for idx in range(5, 12)
)

# Population projectors |F, m><F, m| for every sublevel (name: <F>_<m><m>).
F2_m2m2, F2_m1m1, F2_00, F2_p1p1, F2_p2p2 = (
    ket * ket.dag() for ket in (F2_m2, F2_m1, F2_0, F2_p1, F2_p2)
)

F3_m3m3, F3_m2m2, F3_m1m1, F3_00, F3_p1p1, F3_p2p2, F3_p3p3 = (
    ket * ket.dag() for ket in (F3_m3, F3_m2, F3_m1, F3_0, F3_p1, F3_p2, F3_p3)
)

# F=3 step operators |3, m><3, m+1| between neighbouring sublevels
# (name: F3_<m><m+1>); qubit_integrate combines them into the Fx operator.
F3_m3m2, F3_m2m1, F3_m10, F3_0p1, F3_p1p2, F3_p2p3 = (
    lower * upper.dag()
    for lower, upper in zip(
        (F3_m3, F3_m2, F3_m1, F3_0, F3_p1, F3_p2),
        (F3_m2, F3_m1, F3_0, F3_p1, F3_p2, F3_p3),
    )
)

# Transition operators |F=2, m_g><F=3, m_e| for every pair with
# m_e in {m_g - 1, m_g, m_g + 1}; name: sm_<m_g><m_e>.
sm_m2m3, sm_m2m2, sm_m2m1 = (F2_m2 * exc.dag() for exc in (F3_m3, F3_m2, F3_m1))
sm_m1m2, sm_m1m1, sm_m10 = (F2_m1 * exc.dag() for exc in (F3_m2, F3_m1, F3_0))
sm_0m1, sm_00, sm_0p1 = (F2_0 * exc.dag() for exc in (F3_m1, F3_0, F3_p1))
sm_p10, sm_p1p1, sm_p1p2 = (F2_p1 * exc.dag() for exc in (F3_0, F3_p1, F3_p2))
sm_p2p1, sm_p2p2, sm_p2p3 = (F2_p2 * exc.dag() for exc in (F3_p1, F3_p2, F3_p3))


# Relative dipole coefficients for |F=2, m_g> <-> |F=3, m_e> (name: d_<m_g><m_e>),
# normalised so the stretched transitions |2, -2> <-> |3, -3> and
# |2, +2> <-> |3, +3> equal 1.  For every F=3 sublevel the squares of its
# outgoing coefficients sum to 1 (e.g. 1/15 + 8/15 + 2/5 for m_e = -1).
d_m2m3 = 1
d_m2m2 = np.sqrt(1/6) * np.sqrt(2)
d_m2m1 = np.sqrt(1/30) * np.sqrt(2)

d_m1m2 = np.sqrt(1/3) * np.sqrt(2)
d_m1m1 = np.sqrt(4/15) * np.sqrt(2)
d_m10 = np.sqrt(1/10) * np.sqrt(2)

d_0m1 = np.sqrt(1/5) * np.sqrt(2)
d_00 = np.sqrt(3/10) * np.sqrt(2)

# Mirror symmetry (m_g, m_e) -> (-m_g, -m_e): reuse the coefficients above.
d_0p1 = d_0m1

d_p10 = d_m10
d_p1p1 = d_m1m1
d_p1p2 = d_m1m2

d_p2p1 = d_m2m1
d_p2p2 = d_m2m2
d_p2p3 = d_m2m3

# Decay rate of the F=3 sublevels in angular units (2*pi * 6.07 MHz).
gamma1 = 6.07e6 * 2 * np.pi

# Collapse operators, one per decay channel |F=3, m_e> -> |F=2, m_g|:
# d * |F=2, m_g><F=3, m_e| * sqrt(gamma1), grouped by ground sublevel.
c_op_list = [
    coeff * jump * np.sqrt(gamma1)
    for coeff, jump in (
        (d_m2m3, sm_m2m3), (d_m2m2, sm_m2m2), (d_m2m1, sm_m2m1),
        (d_m1m2, sm_m1m2), (d_m1m1, sm_m1m1), (d_m10, sm_m10),
        (d_0m1, sm_0m1), (d_00, sm_00), (d_0p1, sm_0p1),
        (d_p10, sm_p10), (d_p1p1, sm_p1p1), (d_p1p2, sm_p1p2),
        (d_p2p1, sm_p2p1), (d_p2p2, sm_p2p2), (d_p2p3, sm_p2p3),
    )
]

def getTxMatrix(psi0, tlist, A_sigma_m, trapdepth, bfieldlist, freqlist ):
    """Scan magnetic field and detuning and return the emitted photon number.

    For each field in ``bfieldlist`` and each detuning in ``freqlist`` the
    master equation is integrated by ``qubit_integrate``.  The number of
    photons emitted during ``tlist`` is estimated as
    ``gamma1 * integral(P_F3(t) dt)``, where ``P_F3`` is the total population
    of the F=3 manifold.

    Parameters
    ----------
    psi0 : initial state handed to ``mesolve`` on the 12-level space.
    tlist : evaluation times (any start, any spacing), same time units as
        ``1 / gamma1``.
    A_sigma_m : sigma-minus drive amplitude (enters the Hamiltonian as
        ``A_sigma_m / 2`` times the coupling operator).
    trapdepth : scale factor; the ground-state light shift is
        ``-20.827 MHz * trapdepth`` (angular units).
    bfieldlist : magnetic fields; the Zeeman coefficients are
        ``0.7 MHz * bfield`` (F=2) and ``0.93 MHz * bfield`` (F=3) per unit m.
        Presumably the field is given in Gauss — TODO confirm.
    freqlist : detunings applied as a ``-delta`` offset on every F=3 level.

    Returns
    -------
    numpy.ndarray of shape ``(len(freqlist), len(bfieldlist))`` holding the
    emitted photon number for each (detuning, field) pair.

    Side effect: prints each field value once its detuning scan is finished
    (progress output).
    """
    # Parameters for the 851 nm trap: excited-state scalar and tensor shifts
    # are fixed ratios of the ground-state shift.
    GS_lightshift = -20.827e6 * trapdepth * 2 * np.pi
    ES_lightshift_scalar = -0.741753493 * GS_lightshift
    ES_lightshift_tensor = 0.071636601 * GS_lightshift

    # Projector onto the whole F=3 manifold.  expect() is linear, so a single
    # call gives the sum of the seven sublevel populations.
    P_F3 = F3_m3m3 + F3_m2m2 + F3_m1m1 + F3_00 + F3_p1p1 + F3_p2p2 + F3_p3p3
    times = np.asarray(tlist, dtype=float)
    dts = np.diff(times)

    tx_spectrum = np.zeros((len(freqlist), len(bfieldlist)))
    for h, bfield in enumerate(bfieldlist):
        g_F2 = 0.7e6 * bfield * 2 * np.pi
        g_F3 = 0.93e6 * bfield * 2 * np.pi
        for k, delta in enumerate(freqlist):
            states = qubit_integrate(delta, A_sigma_m, g_F2, g_F3, GS_lightshift,
                                     ES_lightshift_scalar, ES_lightshift_tensor,
                                     psi0, tlist)
            p_F3 = np.real(expect(P_F3, states))
            # Trapezoidal integral over the actual time grid.  The previous
            # estimate sum(p) * tlist[-1] / len(tlist) assumed tlist[0] == 0
            # and counted N samples as N intervals (there are N - 1).
            tx_spectrum[k, h] = gamma1 * np.sum(0.5 * (p_F3[1:] + p_F3[:-1]) * dts)
        print(bfield)
    return tx_spectrum

def qubit_integrate(delta, A_sigma_m, g_F2, g_F3, GS_lightshift, ES_lightshift_scalar, ES_lightshift_tensor, psi0, tlist  ):
    """Evolve the 12-level system for one (field, detuning) point.

    The Hamiltonian contains:
      * linear Zeeman shifts ``m * g_F2`` (F=2) and ``m * g_F3`` (F=3),
      * a ``-delta`` offset on every F=3 sublevel,
      * scalar light shifts ``GS_lightshift`` (F=2) and
        ``ES_lightshift_scalar`` (F=3),
      * a tensor light shift ``ES_lightshift_tensor * Fx**2`` on F=3,
      * a sigma-minus drive ``A_sigma_m / 2`` coupling |2, m> and |3, m-1>
        with the ``d_*`` weights, plus its Hermitian conjugate.

    The system is integrated with ``mesolve`` using the module-level
    ``c_op_list``.  Returns ``output.states``, the states at the times in
    ``tlist``.
    """
    P_ground = F2_m2m2 + F2_m1m1 + F2_00 + F2_p1p1 + F2_p2p2
    P_excited = F3_m3m3 + F3_m2m2 + F3_m1m1 + F3_00 + F3_p1p1 + F3_p2p2 + F3_p3p3

    # Diagonal part: Zeeman ladders of both manifolds plus the detuning on F=3.
    m_ground = -2 * F2_m2m2 - F2_m1m1 + F2_p1p1 + 2 * F2_p2p2
    m_excited = (-3 * F3_m3m3 - 2 * F3_m2m2 - F3_m1m1
                 + F3_p1p1 + 2 * F3_p2p2 + 3 * F3_p3p3)
    H_diag = g_F2 * m_ground + g_F3 * m_excited - delta * P_excited

    shift_ground = GS_lightshift * P_ground
    shift_excited = ES_lightshift_scalar * P_excited

    # sigma-minus coupling |2, m><3, m-1| weighted by the dipole coefficients.
    coupling = (d_m2m3 * sm_m2m3 + d_m1m2 * sm_m1m2 + d_0m1 * sm_0m1
                + d_p10 * sm_p10 + d_p2p1 * sm_p2p1)
    drive = A_sigma_m / 2.0 * coupling

    # Fx of the F=3 manifold: (F+ + F-)/2 with elements sqrt(F(F+1) - m(m+1))/2.
    fx_half = (np.sqrt(3/2) * F3_m3m2 + np.sqrt(5/2) * F3_m2m1
               + np.sqrt(3) * F3_m10 + np.sqrt(3) * F3_0p1
               + np.sqrt(5/2) * F3_p1p2 + np.sqrt(3/2) * F3_p2p3)
    Fx = fx_half + fx_half.dag()
    shift_tensor = ES_lightshift_tensor * Fx**2

    hamiltonian = H_diag + shift_ground + shift_excited + drive + drive.dag() + shift_tensor
    return mesolve(hamiltonian, psi0, tlist, c_op_list, []).states
 


