#!/usr/bin/env python3

"""Validate randomized tall-skinny singular value decomposition (SVD).

Functions
---------
randtsvd1
    Computes the SVD of a tall-skinny matrix, via a binary reduce tree,
    orthonormalizing the left singular vectors only once, for efficiency.
randtsvd2
    Computes the SVD of a tall-skinny matrix, via a binary reduce tree,
    orthonormalizing left singular vectors twice in succession, for accuracy.
maketree
    Constructs a binary reduce tree.
tsqr
    Computes the QR factorization of a matrix, via a binary reduce tree.
testrandtsvds
    Tests the functions randtsvd1 and randtsvd2.
"""

import numpy as np
from numpy import exp, log
from numpy.random import standard_normal
from scipy.linalg import lstsq, norm, qr, svd
from fastrand import frandsinit, frands, frandsinv


def randtsvd1(A, iffast=True):
    """Computes the SVD of the tall-skinny matrix A, via a binary reduce tree,
    orthonormalizing the left singular vectors only once, for efficiency.

    Parameters
    ----------
    A : ndarray
        tall-skinny matrix being decomposed
    iffast : boolean, optional
        set to True to use fast discrete cosine transforms when generating
        random numbers; set to False to use random orthogonal Gaussian matrices

    Returns
    -------
    U : ndarray
        matrix of left singular vectors
    s : ndarray
        vector of singular values
    V : ndarray
        matrix of right singular vectors

    Notes
    -----
    A = U S V^*, where S = diag(s).

    tsqr discards rows of its triangular factor whose diagonal entries are
    numerically zero, so fewer than A.shape[1] singular values may be returned.

    When iffast is True and A does not have a floating-point or complex dtype
    (for instance, an integer matrix), the randomly mixed matrix is stored
    as float64 so that fractional values returned by frands are not truncated;
    this matches the promotion that iffast=False gets from Omega @ A.T.

    """
    if iffast:
        # Initialize the appropriately random orthogonal matrix.
        # NOTE(review): the meaning of the second argument (2) is defined
        # in fastrand and is not visible here -- confirm before changing.
        ds, ps = frandsinit(A.shape[1], 2)
        # Choose a dtype for B that can hold non-integer values; zeros_like
        # alone would inherit an integer dtype from A and truncate on store.
        if np.issubdtype(A.dtype, np.inexact):
            mixed_dtype = A.dtype
        else:
            mixed_dtype = np.float64
        # Apply the appropriately random orthogonal matrix
        # to every column of A^*, obtaining B (zeros_like keeps the memory
        # layout of A^*, overriding only the dtype).
        B = np.zeros_like(A.T, dtype=mixed_dtype)
        for k in range(A.shape[0]):
            B[:, k] = frands(A.T[:, k], ds, ps)
    else:
        # Construct an appropriately random orthogonal matrix Omega
        # as the Q factor of a square Gaussian matrix.
        Omega = standard_normal((A.shape[1], A.shape[1]))
        Omega, _ = qr(Omega)
        # Apply the appropriately random orthogonal matrix Omega
        # to every column of A^*, obtaining B.
        B = Omega @ A.T
    # Construct a QR decomposition of B^*.
    Q, R = tsqr(B.T)
    # Calculate the SVD  Utilde S Vtilde^* = R, where S = diag(s).
    Utilde, s, Vtildet = svd(R, full_matrices=False)
    Vtilde = Vtildet.T
    # Form U = Q Utilde.
    U = Q @ Utilde
    if iffast:
        # Apply the inverse of the random orthogonal matrix
        # to every column of Vtilde, obtaining V.
        V = np.zeros_like(Vtilde)
        for k in range(Vtilde.shape[1]):
            V[:, k] = frandsinv(Vtilde[:, k], ds, ps)
    else:
        # Apply the inverse of the random orthogonal matrix Omega
        # to every column of Vtilde, obtaining V (as Omega is orthogonal,
        # its inverse is Omega^*).
        V = Omega.T @ Vtilde
    # Return the singular values and singular vectors.
    return U, s, V


def randtsvd2(A, iffast=True):
    """Computes the SVD of the tall-skinny matrix A, via a binary reduce tree,
    orthonormalizing left singular vectors twice in succession, for accuracy.

    Parameters
    ----------
    A : ndarray
        tall-skinny matrix being decomposed
    iffast : boolean, optional
        set to True to use fast discrete cosine transforms when generating
        random numbers; set to False to use random orthogonal Gaussian matrices

    Returns
    -------
    U : ndarray
        matrix of left singular vectors
    s : ndarray
        vector of singular values
    V : ndarray
        matrix of right singular vectors

    Notes
    -----
    A = U S V^*, where S = diag(s).

    tsqr discards rows of its triangular factor whose diagonal entries are
    numerically zero, so fewer than A.shape[1] singular values may be returned.

    When iffast is True and A does not have a floating-point or complex dtype
    (for instance, an integer matrix), the randomly mixed matrix is stored
    as float64 so that fractional values returned by frands are not truncated;
    this matches the promotion that iffast=False gets from Omega @ A.T.

    """
    if iffast:
        # Initialize the appropriately random orthogonal matrix.
        # NOTE(review): the meaning of the second argument (2) is defined
        # in fastrand and is not visible here -- confirm before changing.
        ds, ps = frandsinit(A.shape[1], 2)
        # Choose a dtype for B that can hold non-integer values; zeros_like
        # alone would inherit an integer dtype from A and truncate on store.
        if np.issubdtype(A.dtype, np.inexact):
            mixed_dtype = A.dtype
        else:
            mixed_dtype = np.float64
        # Apply the appropriately random orthogonal matrix
        # to every column of A^*, obtaining B (zeros_like keeps the memory
        # layout of A^*, overriding only the dtype).
        B = np.zeros_like(A.T, dtype=mixed_dtype)
        for k in range(A.shape[0]):
            B[:, k] = frands(A.T[:, k], ds, ps)
    else:
        # Construct an appropriately random orthogonal matrix Omega
        # as the Q factor of a square Gaussian matrix.
        Omega = standard_normal((A.shape[1], A.shape[1]))
        Omega, _ = qr(Omega)
        # Apply the appropriately random orthogonal matrix Omega
        # to every column of A^*, obtaining B.
        B = Omega @ A.T
    # Construct a QR decomposition of B^*.
    Qtilde, Rtilde = tsqr(B.T)
    # Construct a QR decomposition of Qtilde; this second pass is the
    # repeated orthonormalization that distinguishes this routine.
    Q, R = tsqr(Qtilde)
    # Form T = R Rtilde, so that B^* = Qtilde Rtilde = Q R Rtilde = Q T.
    T = R @ Rtilde
    # Calculate the SVD  Utilde S Vtilde^* = T, where S = diag(s).
    Utilde, s, Vtildet = svd(T, full_matrices=False)
    Vtilde = Vtildet.T
    # Form U = Q Utilde.
    U = Q @ Utilde
    if iffast:
        # Apply the inverse of the random orthogonal matrix
        # to every column of Vtilde, obtaining V.
        V = np.zeros_like(Vtilde)
        for k in range(Vtilde.shape[1]):
            V[:, k] = frandsinv(Vtilde[:, k], ds, ps)
    else:
        # Apply the inverse of the random orthogonal matrix Omega
        # to every column of Vtilde, obtaining V (as Omega is orthogonal,
        # its inverse is Omega^*).
        V = Omega.T @ Vtilde
    # Return the singular values and singular vectors.
    return U, s, V


def maketree(m, n):
    """Constructs a binary reduce tree.

    Parameters
    ----------
    m : integer
        number of indices to partition
    n : integer
        threshold such that no partition contains less than n indices
        (provided that m is at least n; the root always spans all m indices)

    Returns
    -------
    t : dictionary
        binary reduce tree, with each node containing the following entries:
            'low' -- lowest index
            'high' -- highest index
            and with each internal node containing the following entries:
            'blow' -- subtree (a dictionary) with lower indices
            'bhigh' -- subtree (a dictionary) with higher indices

    Notes
    -----
    A block is bisected only when it holds more than 2n indices, so both
    halves keep at least n of them. For n < 1, the split size is floored at 1,
    so that blocks holding a single index become leaves rather than being
    bisected without end.

    """
    # Largest block size that is left unsplit; the floor of 1 guarantees
    # termination when n < 1 and changes nothing when n >= 1.
    leafsize = max(2 * n, 1)
    # Construct the root of the tree.
    t = {'low': 0, 'high': m - 1}
    # Grow the tree, using an explicit stack of the nodes still to examine.
    pending = [t]
    while pending:
        node = pending.pop()
        low, high = node['low'], node['high']
        # Leave the block as a leaf unless it is big enough to cut in two.
        if high - low + 1 <= leafsize:
            continue
        # Find a splitting point in between low and high.
        middle = low + ((high - low) // 2)
        # Construct the lower and higher branches.
        node['blow'] = {'low': low, 'high': middle}
        node['bhigh'] = {'low': middle + 1, 'high': high}
        pending.append(node['bhigh'])
        pending.append(node['blow'])
    # Return the tree.
    return t


def tsqr(A):
    """Factors the matrix A as Q times R, reducing the row blocks of A
    pairwise along a binary tree.

    Parameters
    ----------
    A : ndarray
        matrix to be factored

    Returns
    -------
    Q : ndarray
        factor, with columns intended to be orthonormal, such that A = QR
    R : ndarray
        triangular factor such that A = QR, less any rows discarded because
        their diagonal entries are numerically zero

    Notes
    -----
    Roundoff error can leave the columns of Q short of numerically orthonormal.

    """

    def reducer(M, node):
        """Returns the triangular factor (the R in a QR factorization) for the
        rows of M covered by the given node of the reduce tree: computed
        directly at a leaf, or by merging the factors of the two children.

        Parameters
        ----------
        M : ndarray
            matrix whose row blocks are the leaves of the tree
        node : dictionary
            node of the tree built by maketree(M.shape[0], M.shape[1])

        Returns
        -------
        ndarray
            triangular factor for the rows spanned by node

        """
        if 'blow' not in node:
            # A leaf: factor its block of rows directly.
            block = M[node['low']:(node['high'] + 1), :]
            _, Rleaf = qr(block, mode='economic')
            return Rleaf
        # An internal node: stack the children's triangular factors (lower
        # indices on top) and factor the stack to merge them.
        stacked = np.vstack((reducer(M, node['blow']),
                             reducer(M, node['bhigh'])))
        _, Rmerged = qr(stacked, mode='economic')
        return Rmerged

    # Bisect the rows of A recursively, then reduce the leaves' triangular
    # factors up the tree to a single triangular factor.
    tree = maketree(A.shape[0], A.shape[1])
    R = reducer(A, tree)
    # Working from the bottom row upward (never touching the first row), drop
    # each row whose diagonal entry is negligible relative to the leading one.
    for row in reversed(range(1, R.shape[0])):
        negligible = abs(R[row, row]) < 1e-11 * abs(R[0, 0])
        if negligible:
            R = np.delete(R, row, axis=0)
    # Recover Q from A = QR by solving R^* Q^* = A^* in the least-squares
    # sense.
    Qt = lstsq(R.T, A.T, lapack_driver='gelss')[0]
    # Return Q and R in the QR decomposition of A.
    return Qt.T, R


def testrandtsvds():
    """Runs randtsvd1 and randtsvd2 on a random matrix with exponentially
    decaying singular values, once per setting of iffast, and prints the
    spectral-norm reconstruction errors together with the departures from
    orthonormality of the computed singular vectors.
    """

    def defect(W):
        # Spectral norm of W^* W - I: zero when the columns are orthonormal.
        return norm(W.T @ W - np.identity(W.shape[1]), ord=2)

    # The test matrix has m rows and n columns.
    m, n = 1000, 50
    # Ratio of the smallest singular value to the largest.
    prec = 1e-31
    for iffast in (True, False):
        print()
        print(f'iffast = {iffast}')
        print()
        # Draw random orthonormal factors (left first, then right) and
        # combine them with singular values decaying exponentially
        # from 1 down to prec, forming the m x n matrix A.
        left, _ = qr(standard_normal((m, n)), mode='economic')
        right, _ = qr(standard_normal((n, n)), mode='economic')
        sigma = exp(log(prec) * np.arange(n) / (n - 1))
        A = (left * sigma) @ right.T
        # Decompose A with randtsvd1 and then with randtsvd2.
        results = [randtsvd1(A, iffast), randtsvd2(A, iffast)]
        # Report the spectral norm of A minus its reconstruction U S V^*.
        for idx, (U, s, V) in enumerate(results, start=1):
            err = norm(A - (U * s) @ V.T, ord=2)
            print(f' d{idx} = spectral-norm accuracy of the approximation '
                  f'from randtsvd{idx} =')
            print(err)
        # Assess the numerical orthonormality of the singular vectors,
        # first the left ones (entry 0) and then the right ones (entry 2).
        for side, pick in (('left', 0), ('right', 2)):
            print()
            for idx, factors in enumerate(results, start=1):
                print(' numerical orthonormality '
                      f'of the {side} singular vectors from randtsvd{idx} =')
                print(defect(factors[pick]))
        print()


# Run the validation only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    testrandtsvds()
