#!/usr/bin/env python3

"""Implement fast random orthogonal transforms based on a complex embedding.

Functions
---------
real2complex
    Converts a real vector into a complex vector of half the length.
complex2real
    Converts a complex vector into a real vector of twice the length.
frandinit
    Creates random numbers for use in the transforms frand and frandinv.
frand
    Implements a fast random orthogonal transform.
frandinv
    Implements the inverse of frand (which is also the adjoint).
frandsinit
    Creates random numbers for use in the transforms frands and frandsinv.
frands
    Implements a fast random orthogonal transform, more random than just frand.
frandsinv
    Implements the inverse of frands (which is also the adjoint).
testfastrand
    Tests real2complex, complex2real, frand, frandinv, frands, & frandsinv.
"""

import numpy as np
from numpy import absolute
from numpy.random import permutation, standard_normal
from scipy.linalg import norm
from scipy.fftpack import dct, idct


def real2complex(a):
    """Converts a real vector into a complex vector of half the length.

    For an input of length n, the first (n + 1) // 2 entries become the real
    parts and the remaining n // 2 entries become the imaginary parts; when n
    is odd, the final imaginary part is zero.

    Parameters
    ----------
    a : ndarray
        vector of real numbers

    Returns
    -------
    c : ndarray
        vector of (a.size + 1) // 2 complex numbers (the last entry is real if
        the length of a is odd)

    Notes
    -----
    real2complex(complex2real(c, 2 * c.size)) = c  for any complex vector c.
    complex2real(real2complex(a), a.size) = a  for any real vector a.

    """
    half = (a.size + 1) // 2
    real_part = a[:half]
    imag_part = a[half:]
    if a.size % 2:
        # An odd-length input leaves imag_part one entry short; pad it with a
        # zero so that the last complex entry is purely real.
        imag_part = np.append(imag_part, 0)
    return real_part + 1j * imag_part


def complex2real(c, n):
    """Converts a complex vector into a real vector of twice the length.

    The real parts of c come first, followed by the imaginary parts; when n
    is one less than 2 * c.size, the last imaginary part is dropped.

    Parameters
    ----------
    c : ndarray
        vector of complex numbers
    n : integer
        length of the real vector to return (specifying this allows eliminating
        the last entry of the real vector returned, so that the length of the
        returned vector can be odd)

    Returns
    -------
    a : ndarray
        vector of n real numbers

    Raises
    ------
    ValueError
        if n is neither 2 * c.size nor 2 * c.size - 1

    Notes
    -----
    real2complex(complex2real(c, 2 * c.size)) = c  for any complex vector c.
    complex2real(real2complex(a), a.size) = a  for any real vector a.

    """
    full = 2 * c.size
    if n == full:
        imag_part = c.imag
    elif n == full - 1:
        # Odd n: the final imaginary part has no slot in the real vector.
        imag_part = c.imag[0:-1]
    else:
        raise ValueError('Neither ' + str(full) + ' nor ' +
                         str(full - 1) + ' equals ' + str(n) + '.')
    # Real parts first, then the (possibly trimmed) imaginary parts.
    return np.append(c.real, imag_part)


def frandinit(n):
    """Creates random numbers for use in length-n transforms frand & frandinv
    (as well as in frands & frandsinv).

    Parameters
    ----------
    n : integer
        length of the vectors to be transformed in frand & frandinv

    Returns
    -------
    d : ndarray
        complex vector whose entries are random numbers distributed uniformly
        over the unit circle in the complex plane; if n is even, the length
        of d is n/2, whereas if n is odd, then the length of d is (n+1)/2 and
        the last entry of d is real (either +1 or -1)
    p : ndarray
        uniformly random permutation of the integers 0, 1, 2, ..., n-1

    """
    n_complex = (n + 1) // 2
    # Normalizing a standard complex Gaussian yields a point distributed
    # uniformly on the unit circle.
    z = real2complex(standard_normal(2 * n_complex))
    d = z / absolute(z)
    if n % 2 == 1:
        # For odd n, frand multiplies d[-1] by a purely real entry and
        # complex2real then discards that product's imaginary part, so d[-1]
        # must be real (its sign is kept) for the transform to be orthogonal.
        last = d[-1].real
        d[-1] = last / abs(last)
    return d, permutation(n)


def frand(a, d, p):
    """Transforms the input vector a via a fast discrete cosine transform
    together with entrywise multiplication by d and permutation specified by p,
    performing the entrywise multiplication first, the cosine transform next,
    and the permutation last.

    Parameters
    ----------
    a : ndarray
        real vector being transformed
    d : ndarray
        complex vector obtained from  d, p = frandinit(a.size)
    p : ndarray
        permutation of the integers 0, 1, 2, ..., a.size-1, obtained from
        d, p = frandinit(a.size)

    Returns
    -------
    prbhat : ndarray
        transform of a

    Notes
    -----
    The transform is orthogonal.
    frandinv(frand(a, d, p), d, p) = a  for any real vector a.
    frand(frandinv(prbhat, d, p), d, p) = prbhat  for any real vector prbhat.

    """
    # Multiply entrywise by d inside the complex embedding, then unfold back
    # into a real vector of the original length.
    rotated = complex2real(real2complex(a) * d, a.size)
    # Apply the orthonormal DCT-II and shuffle the coefficients by p.
    return dct(rotated, type=2, norm='ortho')[p]


def frandinv(prbhat, d, p):
    """Transforms the input vector prbhat via a fast discrete cosine transform
    together with entrywise division by d and the permutation specified by p,
    performing the permutation first, the (inverse) cosine transform next, and
    the entrywise division last.

    Parameters
    ----------
    prbhat : ndarray
        real vector being transformed
    d : ndarray
        complex vector obtained from  d, p = frandinit(prbhat.size)
    p : ndarray
        permutation of the integers 0, 1, 2, ..., prbhat.size-1, obtained from
        d, p = frandinit(prbhat.size)

    Returns
    -------
    a : ndarray
        transform of prbhat

    Notes
    -----
    The transform is orthogonal.
    frandinv(frand(a, d, p), d, p) = a  for any real vector a.
    frand(frandinv(prbhat, d, p), d, p) = prbhat  for any real vector prbhat.

    """
    n = prbhat.size
    # Unshuffle prbhat according to p, obtaining rbhat: frand produced
    # prbhat = rbhat[p], so scattering back through p inverts it.
    rbhat = np.zeros_like(prbhat)
    rbhat[p] = prbhat
    # The orthonormal inverse DCT-II undoes dct(..., type=2, norm='ortho').
    rb = idct(rbhat, type=2, norm='ortho')
    # Convert rb to a complex vector b.
    b = real2complex(rb)
    # Divide b by d entrywise to undo frand's multiplication, obtaining c.
    c = b / d
    # Convert c back to a real vector of length n.
    a = complex2real(c, n)
    return a


def frandsinit(n, n_iter=1):
    """Creates random numbers for use in length-n transforms frands & frandsinv,
    where frands consists of n_iter iterations of frand in succession and
    frandsinv consists of n_iter iterations of frandinv in succession.

    Parameters
    ----------
    n : integer
        length of the vectors to be transformed in frands & frandsinv
    n_iter : integer, optional
        number of times to call frandinit (default 1)

    Returns
    -------
    ds : list
        list of n_iter complex vectors obtained from n_iter calls to frandinit
    ps : list
        list of n_iter permutations of the integers 0, 1, 2, ..., n-1, obtained
        from n_iter calls to frandinit

    """
    # Call frandinit n_iter times in succession, keeping each (d, p) pair;
    # ds[i] and ps[i] always come from the same call.
    pairs = [frandinit(n) for _ in range(n_iter)]
    ds = [d for d, _ in pairs]
    ps = [p for _, p in pairs]
    return ds, ps


def frands(a, ds, ps):
    """Transforms the input vector a via len(ds) = len(ps) calls to frand.

    Parameters
    ----------
    a : ndarray
        real vector being transformed
    ds : list
        list of complex vectors obtained from  ds, ps = frandsinit(a.size)
        or  ds, ps = frandsinit(a.size, n_iter)
    ps : list
        list of permutations of the integers 0, 1, 2, ..., a.size-1, obtained
        from ds, ps = frandsinit(a.size) or ds, ps = frandsinit(a.size, n_iter)

    Returns
    -------
    b : ndarray
        transform of a (a itself when ds and ps are empty)

    Raises
    ------
    ValueError
        if ds and ps have different lengths

    Notes
    -----
    The transform is orthogonal.
    frandsinv(frands(a, ds, ps), ds, ps) = a  for any real vector a.
    frands(frandsinv(b, ds, ps), ds, ps) = b  for any real vector b.

    """
    # Each d must be paired with the p generated alongside it; a length
    # mismatch means the inputs did not come from a single frandsinit call.
    if len(ds) != len(ps):
        raise ValueError('Lengths of ds (' + str(len(ds)) + ') and ps (' +
                         str(len(ps)) + ') differ.')
    b = a
    # Apply frand once per (d, p) pair, in order.
    for d, p in zip(ds, ps):
        b = frand(b, d, p)
    return b


def frandsinv(b, ds, ps):
    """Transforms the input vector b via len(ds) = len(ps) calls to frandinv.

    Parameters
    ----------
    b : ndarray
        real vector being transformed
    ds : list
        list of complex vectors obtained from  ds, ps = frandsinit(b.size)
        or  ds, ps = frandsinit(b.size, n_iter)
    ps : list
        list of permutations of the integers 0, 1, 2, ..., b.size-1, obtained
        from ds, ps = frandsinit(b.size) or ds, ps = frandsinit(b.size, n_iter)

    Returns
    -------
    a : ndarray
        transform of b (b itself when ds and ps are empty)

    Raises
    ------
    ValueError
        if ds and ps have different lengths

    Notes
    -----
    The transform is orthogonal.
    frandsinv(frands(a, ds, ps), ds, ps) = a  for any real vector a.
    frands(frandsinv(b, ds, ps), ds, ps) = b  for any real vector b.

    """
    # With unequal lengths, indexing both lists from the end would pair
    # ds[i] with a different ps[j] than frands uses, so the result would not
    # invert frands; reject such input explicitly.
    if len(ds) != len(ps):
        raise ValueError('Lengths of ds (' + str(len(ds)) + ') and ps (' +
                         str(len(ps)) + ') differ.')
    a = b
    # Undo frands' calls to frand in reverse order, keeping ds[i] with ps[i].
    for d, p in zip(reversed(ds), reversed(ps)):
        a = frandinv(a, d, p)
    return a


def testfastrand(n):
    """Tests real2complex, complex2real, frand, frandinv, frands, & frandsinv.

    Draws one random real vector of length n and checks that (1) the complex
    embedding round-trips exactly, (2) frand preserves the Euclidean norm and
    is undone by frandinv, and (3) three chained transforms via frands
    preserve the norm and are undone by frandsinv.

    Parameters
    ----------
    n : integer
        dimension of all real vectors

    """
    vec = standard_normal(n)

    # Embedding into complex numbers and back must reproduce vec exactly.
    np.testing.assert_array_equal(vec, complex2real(real2complex(vec), n))

    # One transform: norm-preserving and inverted by frandinv.
    d, p = frandinit(n)
    once = frand(vec, d, p)
    np.testing.assert_almost_equal(norm(vec), norm(once))
    np.testing.assert_almost_equal(frandinv(once, d, p), vec)

    # Three chained transforms: norm-preserving and inverted by frandsinv.
    ds, ps = frandsinit(n, 3)
    thrice = frands(vec, ds, ps)
    np.testing.assert_almost_equal(norm(vec), norm(thrice))
    np.testing.assert_almost_equal(vec, frandsinv(thrice, ds, ps))


if __name__ == '__main__':
    # Cover both parities: an even length and an odd length (where the last
    # complex entry is purely real).
    for length in (110, 111):
        testfastrand(length)
    print('All tests succeeded!')
