#!/usr/bin/env python3
"""
script to split the bit_compression .data file into individual repetitions of
the sequence of 4 measurements corresponding to a single basis choice.
It already includes the necessary operations of bit swapping and sorting for
symmetrization.
"""


import numpy as np

from itertools import islice, zip_longest
from math import pi

raw_file = 'bit_compression_range_29102014.dat'
sink_file = 'test.dat'
# measurement angles in radians: 0.00-0.14 and 0.16-0.24 in steps of
# 0.02, plus 0.15
angles = [x / 100. for x in range(0, 15, 2)] + \
    [.15] + [x / 100. for x in range(16, 25, 2)]
# convert from radians to degrees
angles = [angle / pi * 180 for angle in angles]

# XOR masks applied to the four consecutive measurement lines for
# symmetrization: identity, flip bit 0, flip bit 1, flip both bits
sym_table = [0b00, 0b01, 0b10, 0b11]


def bit_redux(x):
    """
    Reduction from the 4-bit representation to a 2-bit one,
    keeping bit 0 and bit 2 of x.
    """
    x = int(x)
    return (x & 0x1) + (((x >> 2) & 0x1) << 1)
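# Illustrative examples (not part of the original script):
# bit_redux(0b0101) == 0b11 and bit_redux(0b0100) == 0b10,
# since only bits 0 and 2 of the input survive.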


def grouper(iterable, n, fillvalue=None):
    """Collect data into fixed-length chunks or blocks."""
    # grouper('ABCDEFG', 3, 'x') --> ABC DEF Gxx
    args = [iter(iterable)] * n
    return zip_longest(*args, fillvalue=fillvalue)


def symmetrization(source):
    """
    Reads four lines from source, reduces each value to its 2-bit
    representation, applies the bit flips indicated by sym_table,
    and interleaves the four resulting rows into a single vector.
    """
    vec = np.array([[bit_redux(k) ^ b for k in line.strip().split()]
                    for line, b
                    in zip(islice(source, 4), sym_table)])

    # interleave the values from the four bases; zip truncates the rows
    # to the shortest length
    return np.array([val for pair in zip(*vec) for val in pair])
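# Illustrative sketch (toy input, not from the real data file): four
# identical lines "5 4" each reduce to [3, 2]; the XOR masks turn them
# into [3, 2], [2, 3], [1, 0], [0, 1]; the interleaved result is
# [3, 2, 1, 0, 2, 3, 0, 1].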


def p_corr(vec, N=1):
    """
    Standard single-particle Bell correlation: the mean of the +/-1
    outcomes obtained by XOR-ing the two bits of each value. N is
    ignored; it is accepted only so that bell_value can call every
    correlation function with the same signature.
    """
    return np.sum(2 * (((vec >> 1) & 0x1) ^ (vec & 0x1)) - 1) / len(vec)
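# Illustrative example (values chosen for demonstration): with
# vec = np.array([0b00, 0b11, 0b01]) the first two values map to -1
# (equal bits) and the last to +1, so p_corr(vec) == -1/3.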


def maj(vec):
    """
    Majority vote: XOR of the majority values of bit 0 and bit 1
    across vec.
    """
    n = len(vec)
    return (2 * np.sum(vec & 0x1) - n > 0) ^ \
        (2 * np.sum((vec >> 1) & 0x1) - n > 0)
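# Illustrative example: maj(np.array([0b01, 0b01, 0b10])) == True,
# since the majority of bit 0 is 1 while the majority of bit 1 is 0.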


def m_corr(vec, N=3):
    """
    N-particle correlation based on the majority vote, taken over
    consecutive blocks of N outcomes (the last block is zero-padded).
    """
    # grouper yields tuples, so convert each block back to an array
    # before the bitwise operations in maj
    a = [2 * maj(np.asarray(k)) - 1 for k in grouper(vec, N, 0)]
    return sum(a) / len(a)
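# Illustrative example: m_corr(np.array([0b01, 0b01, 0b10]), N=3) == 1,
# a single block whose majority vote is True, i.e. +1.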


def parityOf(vec):
    """
    Parity of the XOR-ed bit pairs in vec, mapped to +/-1: +1 for an
    even number of values with differing bits, -1 for an odd number.
    """
    parity = 1
    for k in vec:
        parity = (parity + ((k & 0x1) ^ ((k >> 1) & 0x1))) & 0x1
    return 2 * parity - 1
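# Illustrative example: parityOf([0b01, 0b00]) == -1 (one value with
# differing bits), while parityOf([0b01, 0b10]) == 1 (two such values).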


def pp_corr(vec, N=3):
    """
    N-particle correlation based on parity, taken over consecutive
    blocks of N outcomes (the last block is zero-padded).
    """
    a = [parityOf(k) for k in grouper(vec, N, 0)]
    return sum(a) / len(a)
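# Illustrative example: pp_corr(np.array([0b01, 0b00, 0b01, 0b10]), N=2)
# == 0.0, since the two blocks have parities -1 and +1.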


def bell_value(a0b0, a0b1, a1b0, a1b1, corr, N=1):
    """
    Returns the estimated CHSH Bell-operator value
    corr(a0b0) - corr(a0b1) + corr(a1b0) + corr(a1b1)
    for the given correlation function.
    """
    return np.sum([corr(k, N) * s
                   for k, s
                   in zip([a0b0, a0b1, a1b0, a1b1],
                          [1, -1, 1, 1])])
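# Illustrative usage (a0b0 .. a1b1 stand for hypothetical outcome
# vectors, one per joint basis choice):
# chsh = bell_value(a0b0, a0b1, a1b0, a1b1, pp_corr, N=3)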


if __name__ == '__main__':

    # count the lines of the raw file
    with open(raw_file, 'r') as source:
        line_num = sum(1 for _ in source)

    # each repetition spans 16 lines per angle:
    # 4 basis choices times 4 lines per symmetrization block
    repetitions = line_num // (len(angles) * 16)

    with open(raw_file, 'r') as source:
        # one (initially empty) vector per angle and basis choice
        d = [np.array([], dtype=int) for _ in range(len(angles) * 4)]
        for _ in range(repetitions):
            d = [np.concatenate((x, y)).astype(int)
                 for x, y
                 in zip(d, [symmetrization(source) for _ in range(len(d))])]

    # write each measurement vector as one run of digits per file
    for k, row in enumerate(d):
        with open('meas{:02d}.dat'.format(k), 'w') as f:
            f.write(''.join(str(j) for j in row))

    # with open(sink_file, 'w') as f:
    #     [f.write(''.join(map(str, line)) + '\n')
    #      for line
    #      in d]

    # with open(sink_file, 'w') as sink:
    #     sink.write(('angle' + '\t{:.4f}' * len(angles) + '\n').format(*angles))
    #     for N in range(1, 20):
    #         bells = [bell_value(*d[k * 4: k * 4 + 4], pp_corr, N)
    #                  for k
    #                  in range(len(angles))]
    #         sink.write(
    #             ('{}' + '\t{:.5f}' * len(angles) + '\n').format(N, *bells))
    #         print(N)
