import numpy as np
from math import *
import matplotlib
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

fontsize = 29

# Global matplotlib styling shared by every figure this module produces.
# (pyplot's rcParams is the same object as matplotlib.rcParams.)
# NOTE(review): matplotlib treats the family name 'sans serif' as an alias of
# 'sans-serif'; 'font.sans-serif' below then selects Helvetica when installed.
# NOTE(review): per the matplotlib rcParams docs, 'mathtext.rm' only takes
# effect when 'mathtext.fontset' is 'custom' -- confirm it is still needed.
matplotlib.rcParams.update({
    'font.size': fontsize,
    'font.family': 'sans serif',
    'font.sans-serif': ['Helvetica'],
    'mathtext.fontset': 'cm',
    'mathtext.rm': 'sans',
    'xtick.major.pad': '7',
})

def plot(rho, figname):
    """Save a 3D bar chart of ``|Re(rho)|`` to ``<figname>.pdf``.

    Row ``z`` of ``rho`` is drawn as a row of bars placed at depth ``z``
    along the y axis; each bar's height is the absolute value of the real
    part of the corresponding entry. The figure is closed after saving.

    Args:
        rho: 2D array-like, e.g. a complex matrix (presumably square -- TODO
            confirm against callers).
        figname: Output path without extension; ``'.pdf'`` is appended.
    """
    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.set_zlim3d(-0.7, 0.7)
        for z, height in enumerate(np.abs(np.real(rho))):
            # Derive bar positions from the row length instead of assuming 4.
            ax.bar(np.arange(len(height)), height, zs=z, zdir='y',
                   color='b', alpha=0.6)
        # Save via the figure handle so the output does not depend on
        # pyplot's implicit "current figure".
        fig.savefig(figname + '.pdf')
    finally:
        # Without this, every call leaves an open figure behind.
        plt.close(fig)

def _plot_bar_panel(ax, heights, part_label, axis_text, color, width, tilt, rotate):
    """Draw a real-valued matrix as 3D bars on ``ax`` and style the panel.

    Bar ``(i, j)`` is placed at x=i, y=j with a ``width`` x ``width``
    footprint and height ``heights[i][j]``. The z range is fixed to
    [-0.7, 0.7], and x/y ticks 0..3 are labelled HH, HV, VH, VV (so the
    labels assume a 4x4 matrix).

    Args:
        ax: A 3D axes (created with ``projection='3d'``).
        heights: 2D real array of bar heights.
        part_label: Prefix for the panel annotation, e.g. ``'Re'``.
        axis_text: Text appended to ``part_label`` in the annotation.
        color: Bar color.
        width: Bar footprint along x and y.
        tilt: Elevation passed to ``view_init``.
        rotate: Azimuth passed to ``view_init``.
    """
    labels = ['HH', 'HV', 'VH', 'VV']
    ticks = np.arange(len(labels))

    ax.set_zlim3d(-0.7, 0.7)
    # These settings do not depend on the row, so apply them once.
    ax.grid(False)
    ax.set_zticks([-0.5, 0, 0.5])

    for x, row in enumerate(heights):
        n = len(row)
        ax.bar3d(
            x=np.ones(n) * x,
            y=np.arange(n),
            z=np.zeros(n),
            dx=np.ones(n) * width,
            dy=np.ones(n) * width,
            dz=row,
            color=color,
        )

    # matplotlib accepts a numeric font size directly; no str() needed.
    ax.text(3, -1, 0.7, part_label + axis_text, ha='right', fontsize=fontsize)
    ax.set_xticks(ticks)
    ax.set_xticklabels(labels, minor=False)
    ax.set_yticks(ticks)
    ax.set_yticklabels(labels, minor=False)
    ax.view_init(tilt, rotate)
    # Trailing spaces are deliberate padding between the labels and the axis.
    ax.set_zticklabels(['-0.5  ', '0 ', '0.5 '])


def plot_real_imaginary(rho, ax, ax2, figtitle, axis_text, color='red', width=0.7, tilt=25, rotate=30):
    """Plot ``Re(rho)`` on ``ax`` and ``Im(rho)`` on ``ax2`` as 3D bar charts.

    Args:
        rho: 2D complex array; the panel labels assume 4x4.
        ax: 3D axes for the real part.
        ax2: 3D axes for the imaginary part.
        figtitle: Accepted for backward compatibility; currently unused.
        axis_text: Text appended to the 'Re'/'Im' panel annotations.
        color: Bar color.
        width: Bar footprint along x and y.
        tilt: Elevation passed to ``view_init``.
        rotate: Azimuth passed to ``view_init``.
    """
    _plot_bar_panel(ax, np.real(rho), 'Re', axis_text, color, width, tilt, rotate)
    _plot_bar_panel(ax2, np.imag(rho), 'Im', axis_text, color, width, tilt, rotate)

def plot_real_imaginary_before_after(rho_before, rho_after, figname, row_titles=None):
    """Save a 2x2 figure of Re/Im bar charts for two matrices to ``<figname>.pdf``.

    The top row shows ``rho_before`` and the bottom row shows ``rho_after``,
    with the real part on the left and the imaginary part on the right. Each
    row's caption is drawn on an invisible full-width 2D axes that sits
    behind the 3D panels. The figure is closed after saving.

    Args:
        rho_before: 2D complex array for the top row.
        rho_after: 2D complex array for the bottom row.
        figname: Output path without extension; ``'.pdf'`` is appended.
        row_titles: Optional pair of captions (top, bottom). Defaults to the
            original captions, whose fidelity percentages are hard-coded
            text, not values computed from the inputs.

    Raises:
        ValueError: If ``row_titles`` does not contain exactly two captions.
    """
    if row_titles is None:
        row_titles = (
            r"(a) Without circulators, fidelity with $\left|\Psi^-\right\rangle$: 98.2%.",
            r"(b) With circulators, fidelity with $\left|\Psi^-\right\rangle$: 98.4%.",
        )
    if len(row_titles) != 2:
        raise ValueError("row_titles must contain exactly two captions")

    # Larger than matplotlib's rcParams['figure.figsize'] default.
    fig, big_axes = plt.subplots(figsize=(12, 12), nrows=2, ncols=1)
    try:
        ax = fig.add_subplot(221, projection='3d')
        ax2 = fig.add_subplot(222, projection='3d')
        ax3 = fig.add_subplot(223, projection='3d')
        ax4 = fig.add_subplot(224, projection='3d')

        plot_real_imaginary(rho_before, ax, ax2, 'rho_before', r'$(\rho_o)$')
        plot_real_imaginary(rho_after, ax3, ax4, 'rho_after', r'$(\rho)$')

        for big_ax, title in zip(big_axes, row_titles):
            # Caption centred just above the row, in axes-fraction coordinates.
            big_ax.text(0.5, 1.03, title, horizontalalignment='center',
                        transform=big_ax.transAxes)
            # Hide the background axes: transparent labels (alpha 0), no tick
            # marks, no frame. tick_params needs real booleans here; the
            # string 'off' is truthy and would leave the ticks visible.
            big_ax.tick_params(labelcolor=(1., 1., 1., 0.0), top=False,
                               bottom=False, left=False, right=False)
            big_ax.set_frame_on(False)

        fig.tight_layout()
        fig.subplots_adjust(bottom=0.01, left=0.07, right=1 - 0.01)
        fig.savefig(figname + '.pdf')
    finally:
        plt.close(fig)

if __name__ == '__main__':
    # Load each file as a real table and reinterpret adjacent float pairs as
    # complex numbers. This presumably assumes the columns alternate real/imag;
    # verify against how the .dat files are written.
    rho_before, rho_after = (
        np.loadtxt(path).view(complex)
        for path in ('rho_before.dat', 'rho_after.dat')
    )
    plot_real_imaginary_before_after(rho_before, rho_after, 'temp')
