#!/usr/bin/env python2
import numpy as np

# Model constants shared by the functions below.
# NOTE(review): units are not stated anywhere in this file -- confirm.
# DeltaT multiplies the squared inco_f term in pairs_f; 3.0e-8 looks like a
# 30 ns time window if the unit is seconds (TODO confirm).
DeltaT = 3.0e-8
# Decay rates handed to s_31 / s_33 as ``Gammaa`` / ``Gammab`` by co_f and
# inco_f.  NOTE(review): the values resemble linewidths in MHz -- confirm.
Gamma1 = 6.066
Gamma2 = 0.666


def s_33(Oma, Omb, Gammaa, Gammab, delta, Delta):
    """Evaluate the closed-form rational expression ``s_33``.

    The numerator carries an overall factor ``Oma**2 * Omb**2``.  By
    inspection the denominator is the same polynomial that appears in
    ``s_22``, ``s_11`` and ``s_31``.  No complex constants appear, so
    real inputs give a real result.  Works element-wise on numpy arrays.

    NOTE(review): by its name this is presumably the steady-state
    population rho_33 of a three-level system driven by two fields with
    Rabi frequencies ``Oma``/``Omb``, decay rates ``Gammaa``/``Gammab`` and
    detunings ``delta``/``Delta``.  The expression looks machine-generated
    (e.g. exported from a computer-algebra system).  Regenerate it from
    the derivation rather than editing it by hand -- TODO confirm source.
    """
    return (Oma**2*Omb**2*(Gammaa*Gammab*((delta - Delta)**2 + (Gammaa + Gammab)**2) + Gammaa*(Gammaa + Gammab)*Oma**2 + (Gammaa + Gammab)**2*Omb**2))/(Delta**4*Gammaa*Gammab**3 + delta**4*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2) - 2*delta**3*Delta*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2 + Omb**2) + (Gammab*(Gammaa + Gammab) + Oma**2 + Omb**2)*(Gammaa**2*Gammab + 2*Gammab*Oma**2 + Gammaa*Omb**2)*(Gammaa*(Gammab*(Gammaa + Gammab) + Oma**2) + (Gammaa + Gammab)*Omb**2) + Delta**2*Gammab*(Gammaa*(Gammab**2*(2*Gammaa**2 + 2*Gammaa*Gammab + Gammab**2) + 2*Gammab*(Gammaa + 2*Gammab)*Oma**2 + Oma**4) + Gammab*(Gammaa*(3*Gammaa + Gammab) + Oma**2)*Omb**2 + Gammaa*Omb**4) + delta**2*(Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Gammaa*Gammab + 2*Gammab**2 - 2*Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) + (Delta**2*Gammaa*(Gammaa + 5*Gammab) + Gammaa**2*(Gammaa**2 + Gammaa*Gammab + 2*Gammab**2) + 2*(Gammaa + Gammab)**2*Oma**2)*Omb**2 + Gammaa*Gammab*Omb**4) + 2*delta*Delta*(Gammaa*Gammab*(-Gammab**2 + Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) - Gammab*(Gammaa*(Delta**2 + Gammaa**2 + 4*Gammaa*Gammab + Gammab**2) + Gammab*Oma**2)*Omb**2 - Gammaa*(Gammaa + 2*Gammab)*Omb**4))


def s_22(Oma, Omb, Gammaa, Gammab, delta, Delta):
    """Evaluate the closed-form rational expression ``s_22``.

    The numerator carries an overall factor ``Oma**2``.  By inspection the
    denominator is the same polynomial that appears in ``s_33``, ``s_11``
    and ``s_31``.  No complex constants appear, so real inputs give a real
    result.  Works element-wise on numpy arrays.

    NOTE(review): by its name this is presumably the steady-state
    population rho_22 (intermediate level) of the same three-level model
    as ``s_33``.  The expression looks machine-generated; regenerate
    rather than hand-edit -- TODO confirm source.
    """
    return (Oma**2*(Gammaa*Gammab*((delta**2 + Gammab**2)*((delta - Delta)**2 + (Gammaa + Gammab)**2) + 2*(delta*(-delta + Delta) + Gammab*(Gammaa + Gammab))*Oma**2 + Oma**4) + (-2*delta*Delta*Gammab**2 + delta**2*(Gammaa**2 + Gammaa*Gammab + Gammab**2) + Gammab*(Delta**2*Gammab + (2*Gammaa + Gammab)*(Gammab*(Gammaa + Gammab) + Oma**2)))*Omb**2 + Gammab*(Gammaa + Gammab)*Omb**4))/(Delta**4*Gammaa*Gammab**3 + delta**4*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2) - 2*delta**3*Delta*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2 + Omb**2) + (Gammab*(Gammaa + Gammab) + Oma**2 + Omb**2)*(Gammaa**2*Gammab + 2*Gammab*Oma**2 + Gammaa*Omb**2)*(Gammaa*(Gammab*(Gammaa + Gammab) + Oma**2) + (Gammaa + Gammab)*Omb**2) + Delta**2*Gammab*(Gammaa*(Gammab**2*(2*Gammaa**2 + 2*Gammaa*Gammab + Gammab**2) + 2*Gammab*(Gammaa + 2*Gammab)*Oma**2 + Oma**4) + Gammab*(Gammaa*(3*Gammaa + Gammab) + Oma**2)*Omb**2 + Gammaa*Omb**4) + delta**2*(Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Gammaa*Gammab + 2*Gammab**2 - 2*Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) + (Delta**2*Gammaa*(Gammaa + 5*Gammab) + Gammaa**2*(Gammaa**2 + Gammaa*Gammab + 2*Gammab**2) + 2*(Gammaa + Gammab)**2*Oma**2)*Omb**2 + Gammaa*Gammab*Omb**4) + 2*delta*Delta*(Gammaa*Gammab*(-Gammab**2 + Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) - Gammab*(Gammaa*(Delta**2 + Gammaa**2 + 4*Gammaa*Gammab + Gammab**2) + Gammab*Oma**2)*Omb**2 - Gammaa*(Gammaa + 2*Gammab)*Omb**4))


def s_11(Oma, Omb, Gammaa, Gammab, delta, Delta):
    """Evaluate the closed-form rational expression ``s_11``.

    Unlike ``s_22``/``s_33``, the numerator has no overall Rabi-frequency
    factor.  By inspection the denominator is the same polynomial that
    appears in ``s_33``, ``s_22`` and ``s_31``.  No complex constants
    appear, so real inputs give a real result.  Works element-wise on
    numpy arrays.

    NOTE(review): by its name this is presumably the steady-state
    population rho_11 (ground level) of the same three-level model.  If so,
    s_11 + s_22 + s_33 should equal 1 -- unverified, TODO check.  The
    expression looks machine-generated; regenerate rather than hand-edit.
    """
    # The return expression deliberately continues onto the next line
    # (inside the parentheses); keep the two lines together.
    return (delta**4*Gammaa*Gammab*(Delta**2 + Gammaa**2 + Oma**2) + Gammaa*Gammab*(Delta**2 + Gammaa**2 + Oma**2)*(Delta**2*Gammab**2 + (Gammab*(Gammaa + Gammab) + Oma**2)**2) + Gammab*(Delta**2*Gammaa*(3*Gammaa*Gammab + Gammab**2 - Oma**2) + (Gammab*(Gammaa + Gammab) + Oma**2)*(Gammaa**2*(3*Gammaa + 2*Gammab) + (Gammaa + Gammab)*Oma**2))*Omb**2 + Gammaa*(Gammab*(Delta**2 + (Gammaa + Gammab)*(3*Gammaa + Gammab)) + Gammaa*Oma**2)*Omb**4 + Gammaa*(Gammaa + Gammab)*Omb**6 - 2*delta**3*Delta*Gammaa*Gammab*(Delta**2 + Gammaa**2 + Oma**2 + Omb**2) + delta**2*(Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Gammaa*Gammab + 2*Gammab**2 - 2*Oma**2)*(Delta**2 + Gammaa**2 + Oma**2) + (Delta**2*Gammaa*(Gammaa + 5*Gammab) + Gammaa**2*(Gammaa**2 + Gammaa*Gammab + 2*Gammab**2) + (Gammaa + Gammab)**2*Oma**2)*Omb**2 + Gammaa*Gammab*Omb**4) + 2*delta*Delta*Gammaa*(Gammab*(Delta**2 + Gammaa**2 + Oma**2)*(-Gammab**2 + Oma**2) - Gammab*(Delta**2 + Gammaa**2 + 4*Gammaa*Gammab + Gammab**2 - Oma**2)*Omb**2 - (Gammaa + 2*Gammab)*Omb**4))/(Delta**4*Gammaa*Gammab**3 + delta**4*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2) - 2*delta**3*Delta*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2 + Omb**2) + (Gammab*(Gammaa + Gammab) + Oma**2 + Omb**2)*(Gammaa**2*Gammab + 2*Gammab*Oma**2 + Gammaa*Omb**2)*(Gammaa*(Gammab*(Gammaa + Gammab) + Oma**2) + (Gammaa + Gammab)*Omb**2) + Delta**2*Gammab*(Gammaa*(Gammab**2*(2*Gammaa**2 + 2*Gammaa*Gammab + Gammab**2) + 2*Gammab*(Gammaa + 2*Gammab)*Oma**2 + Oma**4) + Gammab*(Gammaa*(3*Gammaa + Gammab) + Oma**2)*Omb**2 + Gammaa*Omb**4) + delta**2*(Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Gammaa*Gammab + 2*Gammab**2 - 2*Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) + (Delta**2*Gammaa*(Gammaa + 5*Gammab) + Gammaa**2*(Gammaa**2 + Gammaa*Gammab + 2*Gammab**2) + 2*(Gammaa + Gammab)**2*Oma**2)*Omb**2 + Gammaa*Gammab*Omb**4) + 2*delta*Delta*(Gammaa*Gammab*(-Gammab**2 + Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) - Gammab*(Gammaa*(Delta**2 + Gammaa**2 + 4*Gammaa*Gammab + Gammab**2) + 
Gammab*Oma**2)*Omb**2 - Gammaa*(Gammaa + 2*Gammab)*Omb**4))


def s_31(Oma, Omb, Gammaa, Gammab, delta, Delta):
    """Evaluate the closed-form complex rational expression ``s_31``.

    The numerator carries an overall factor ``Oma * Omb`` and contains
    imaginary-unit (``1j``) terms, so the result is complex even for real
    inputs.  By inspection the denominator is the same real polynomial that
    appears in ``s_33``, ``s_22`` and ``s_11``.  ``co_f`` uses
    ``abs(s_31)**2``.  Works element-wise on numpy arrays.

    NOTE(review): by its name this is presumably the steady-state
    coherence rho_31 between levels 3 and 1 of the same three-level model.
    The expression looks machine-generated; regenerate rather than
    hand-edit -- TODO confirm source.
    """
    return (Oma*Omb*(delta**3*(Delta - 1j*Gammaa)*Gammaa*Gammab - 1j*Delta**3*Gammaa*Gammab**2 - Delta**2*Gammaa*Gammab*(Gammaa*Gammab - Oma**2 + Omb**2) - delta**2*Gammaa*Gammab*((Delta - 1j*Gammaa)*(2*Delta + 1j*Gammab) + Oma**2 + Omb**2) - 1j*Delta*Gammaa*Gammab*(Gammaa + Gammab)*(Gammab*(Gammaa + Gammab) + 2*Oma**2 + Omb**2) - (Gammaa*Gammab*(Gammaa + Gammab) - Gammab*Oma**2 + Gammaa*Omb**2)*(Gammaa*(Gammab*(Gammaa + Gammab) + Oma**2) + (Gammaa + Gammab)*Omb**2) + delta*Gammaa*((Delta - 1j*Gammaa)*Gammab*(Delta**2 + 2j*Delta*Gammab + (Gammaa + Gammab)**2) + 2j*Gammab*(Gammaa + Gammab)*Oma**2 + ((-1j)*Gammaa*(Gammaa + Gammab) + Delta*(Gammaa + 3*Gammab))*Omb**2)))/(Delta**4*Gammaa*Gammab**3 + delta**4*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2) - 2*delta**3*Delta*Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Oma**2 + Omb**2) + (Gammab*(Gammaa + Gammab) + Oma**2 + Omb**2)*(Gammaa**2*Gammab + 2*Gammab*Oma**2 + Gammaa*Omb**2)*(Gammaa*(Gammab*(Gammaa + Gammab) + Oma**2) + (Gammaa + Gammab)*Omb**2) + Delta**2*Gammab*(Gammaa*(Gammab**2*(2*Gammaa**2 + 2*Gammaa*Gammab + Gammab**2) + 2*Gammab*(Gammaa + 2*Gammab)*Oma**2 + Oma**4) + Gammab*(Gammaa*(3*Gammaa + Gammab) + Oma**2)*Omb**2 + Gammaa*Omb**4) + delta**2*(Gammaa*Gammab*(Delta**2 + Gammaa**2 + 2*Gammaa*Gammab + 2*Gammab**2 - 2*Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) + (Delta**2*Gammaa*(Gammaa + 5*Gammab) + Gammaa**2*(Gammaa**2 + Gammaa*Gammab + 2*Gammab**2) + 2*(Gammaa + Gammab)**2*Oma**2)*Omb**2 + Gammaa*Gammab*Omb**4) + 2*delta*Delta*(Gammaa*Gammab*(-Gammab**2 + Oma**2)*(Delta**2 + Gammaa**2 + 2*Oma**2) - Gammab*(Gammaa*(Delta**2 + Gammaa**2 + 4*Gammaa*Gammab + Gammab**2) + Gammab*Oma**2)*Omb**2 - Gammaa*(Gammaa + 2*Gammab)*Omb**4))


def laser(delta, lw):
    """Return a Gaussian weight profile sampled at ``delta``, normalised to unit sum.

    Parameters
    ----------
    delta : array_like or scalar
        Sample points at which the Gaussian is evaluated.
    lw : float
        Standard deviation of the Gaussian (it enters as ``2 * lw**2``),
        in the same units as ``delta``.  Note this is a sigma, not a FWHM.

    Returns
    -------
    numpy.ndarray or float
        ``exp(-delta**2 / (2 * lw**2))`` divided by the total over all
        samples, so the result sums to 1 and can be used directly as a
        discrete kernel for ``convolve``.  A scalar ``delta`` gives 1.0.

    Notes
    -----
    The Gaussian is centred at ``delta == 0``.  As a ``convolve`` kernel it
    is only centred if the ``delta`` grid is (close to) symmetric about
    zero -- NOTE(review): confirm callers use such a grid.
    ``lw == 0`` divides by zero (NaN/0 weights).
    """
    weights = np.exp(-np.asarray(delta) ** 2 / (2 * lw ** 2))
    # np.sum totals every element in C.  The builtin ``sum`` looped in
    # Python, summed only along the first axis of N-d input, and raised
    # TypeError for a scalar ``delta``.
    return weights / np.sum(weights)


def convolve(dat, kernel):
    """Convolve ``dat`` with ``kernel`` and return a result aligned with ``dat``.

    ``dat`` is extended at both ends by repeating its first/last sample
    (edge padding), so the result does not roll off at the boundaries.
    The padded signal is convolved with ``kernel`` using
    ``np.convolve(..., mode='valid')``, and the window centred on the
    original samples is returned.

    Parameters
    ----------
    dat : array_like or scalar
        1-D samples to smooth.  A scalar is treated as a one-sample signal.
        With edge padding that signal is constant, so the result is
        ``dat * sum(kernel)``, i.e. ``dat`` itself for a unit-sum kernel
        such as the one ``laser`` returns.
    kernel : array_like
        Non-empty 1-D kernel, treated as centred on its middle sample.  For
        even lengths the result is shifted by half a sample.

    Returns
    -------
    numpy.ndarray or scalar
        ``len(dat)`` samples (empty for empty ``dat``), or a scalar when
        ``dat`` is a scalar.
    """
    kernel = np.asarray(kernel)
    if np.ndim(dat) == 0:
        # The previous fallback overwrote ``dat`` with zeros before reading
        # it and therefore always returned 0.  Reuse the array path instead.
        return convolve(np.atleast_1d(dat), kernel)[0]
    dat = np.asarray(dat)
    npts = len(dat)
    if npts == 0:
        return np.zeros(0)
    nker = len(kernel)
    # Pad by the full kernel length on each side so every returned sample
    # sees only real or edge-replicated data.  (The old code padded by
    # min(len(dat), len(kernel)) and truncated the result to that length,
    # which was wrong whenever dat was longer than kernel.)
    padded = np.pad(dat, nker, mode='edge')
    out = np.convolve(padded, kernel, mode='valid')
    # ``out`` has npts + nker + 1 samples; out[noff + i] is centred on dat[i].
    noff = (nker + 1) // 2
    return out[noff:noff + npts]


def co_f(delta, Delta, Oma, Omb, delta0):
    """Return ``abs(s_31)**2`` at the shifted, sign-flipped detuning.

    ``delta`` is mapped to ``-(delta - delta0)`` before being passed to
    ``s_31``.  The module constants ``Gamma1`` and ``Gamma2`` are used as
    the two decay rates.  NOTE(review): presumably the "coherent" part of
    the pair rate used in ``pairs_f`` -- confirm.
    """
    shifted = -(delta - delta0)
    amplitude = s_31(Oma, Omb, Gamma1, Gamma2, shifted, Delta)
    return np.abs(amplitude) ** 2


def inco_f(delta, Delta, Oma, Omb, delta0):
    """Return ``s_33`` at the shifted, sign-flipped detuning.

    ``delta`` is mapped to ``-(delta - delta0)`` before being passed to
    ``s_33``.  The module constants ``Gamma1`` and ``Gamma2`` are used as
    the two decay rates.  NOTE(review): presumably the "incoherent" term --
    confirm.
    """
    shifted = -(delta - delta0)
    result = s_33(Oma, Omb, Gamma1, Gamma2, shifted, Delta)
    return result


def single_f(delta, Delta, Oma, Omb, delta0):
    """Single-channel model: forwards all arguments unchanged to ``inco_f``."""
    arguments = (delta, Delta, Oma, Omb, delta0)
    return inco_f(*arguments)


def single_lw(delta, Delta, Oma, Omb, delta0, lw):
    """Return ``single_f`` smoothed by ``laser(delta, lw)`` via ``convolve``."""
    profile = single_f(delta, Delta, Oma, Omb, delta0)
    kernel = laser(delta, lw)
    return convolve(profile, kernel)


def pairs_f(delta, Delta, Oma, Omb, delta0):
    """Return ``co_f + inco_f**2 * DeltaT`` for the given arguments.

    ``DeltaT`` is the module-level constant.
    """
    coherent = co_f(delta, Delta, Oma, Omb, delta0)
    incoherent = inco_f(delta, Delta, Oma, Omb, delta0)
    return coherent + incoherent ** 2 * DeltaT


def pairs_lw(delta, Delta, Oma, Omb, delta0, lw):
    """Return ``pairs_f`` smoothed by ``laser(delta, lw)`` via ``convolve``."""
    profile = pairs_f(delta, Delta, Oma, Omb, delta0)
    kernel = laser(delta, lw)
    return convolve(profile, kernel)


def eff(delta, Delta, Oma, Omb, delta0):
    """Return the ratio ``pairs_f / single_f`` (no guard against zero)."""
    numerator = pairs_f(delta, Delta, Oma, Omb, delta0)
    denominator = single_f(delta, Delta, Oma, Omb, delta0)
    return numerator / denominator


def eff_lw(delta, Delta, Oma, Omb, delta0, lw):
    """Return ``eff`` smoothed by ``laser(delta, lw)`` via ``convolve``.

    The ratio is formed first and then smoothed; it is not a ratio of
    smoothed rates.
    """
    ratio = eff(delta, Delta, Oma, Omb, delta0)
    kernel = laser(delta, lw)
    return convolve(ratio, kernel)


def signal_f(delta, pump_a, pump_b, parvals):
    """Signal-channel model: ``num * etas * single_lw(...) + dc_s``.

    The Rabi-frequency arguments are ``sqrt(pump_a) * ma`` and
    ``sqrt(pump_b) * mb``.  Keys read from ``parvals``: 'num', 'etas',
    'Delta', 'ma', 'mb', 'x0', 'lw', 'dc_s'.  NOTE(review): presumably
    'etas' is a detection efficiency and 'dc_s' a dark-count offset --
    confirm.
    """
    scale = parvals['num'] * parvals['etas']
    detuning = parvals['Delta']
    rabi_a = np.sqrt(pump_a) * parvals['ma']
    rabi_b = np.sqrt(pump_b) * parvals['mb']
    rate = single_lw(delta, detuning, rabi_a, rabi_b,
                     parvals['x0'], parvals['lw'])
    return scale * rate + parvals['dc_s']


def idler_f(delta, pump_a, pump_b, parvals):
    """Idler-channel model: ``num * etai * single_lw(...) + dc_i``.

    The Rabi-frequency arguments are ``sqrt(pump_a) * ma`` and
    ``sqrt(pump_b) * mb``.  Keys read from ``parvals``: 'num', 'etai',
    'Delta', 'ma', 'mb', 'x0', 'lw', 'dc_i'.  NOTE(review): presumably
    'etai' is a detection efficiency and 'dc_i' a dark-count offset --
    confirm.
    """
    scale = parvals['num'] * parvals['etai']
    detuning = parvals['Delta']
    rabi_a = np.sqrt(pump_a) * parvals['ma']
    rabi_b = np.sqrt(pump_b) * parvals['mb']
    rate = single_lw(delta, detuning, rabi_a, rabi_b,
                     parvals['x0'], parvals['lw'])
    return scale * rate + parvals['dc_i']


def pair_f(delta, pump_a, pump_b, parvals):
    """Pair-rate model: ``num * etai * etas * pairs_lw(...)`` (no offset).

    The Rabi-frequency arguments are ``sqrt(pump_a) * ma`` and
    ``sqrt(pump_b) * mb``.  Keys read from ``parvals``: 'num', 'etai',
    'etas', 'Delta', 'ma', 'mb', 'x0', 'lw'.
    """
    scale = parvals['num'] * parvals['etai'] * parvals['etas']
    detuning = parvals['Delta']
    rabi_a = np.sqrt(pump_a) * parvals['ma']
    rabi_b = np.sqrt(pump_b) * parvals['mb']
    rate = pairs_lw(delta, detuning, rabi_a, rabi_b,
                    parvals['x0'], parvals['lw'])
    return scale * rate


def eff_s_f(delta, pump_a, pump_b, parvals):
    """Return ``pair_f / signal_f`` (pairs per signal-channel count)."""
    pairs = pair_f(delta, pump_a, pump_b, parvals)
    singles = signal_f(delta, pump_a, pump_b, parvals)
    return pairs / singles


def eff_i_f(delta, pump_a, pump_b, parvals):
    """Return ``pair_f / idler_f`` (pairs per idler-channel count)."""
    pairs = pair_f(delta, pump_a, pump_b, parvals)
    singles = idler_f(delta, pump_a, pump_b, parvals)
    return pairs / singles
