#!/usr/bin/env python3

"""Computes almost optimally accurate lowrank approximations via randomization.

Functions
---------
iterations
    Conducts subspace iterations.
ssvd
    Computes an almost optimally accurate low-rank approximation.
testlowrank
    Tests the function ssvd.
"""

import numpy as np
from numpy import exp, log
from numpy.random import standard_normal
from scipy.linalg import norm, qr, svd
from gram import gram1, gram2
from randtsvd import randtsvd1, randtsvd2


def iterations(A, k, i=2, l=None, ifgram=False, iffast=True):
    """Runs randomized subspace iterations and returns a matrix Q with
    orthonormal columns such that Q Q^* A approximates A.

    Parameters
    ----------
    A : ndarray
        matrix whose range is being approximated
    k : integer
        target rank, that is, the rank of the best-possible approximation
        whose accuracy Q Q^* A should nearly match (used in this function
        only to set the default value of l)
    i : integer, optional
        number of passes through the loop; each pass multiplies by A and then
        by the transpose of A, orthonormalizing after each product (one more
        product with A always follows the loop)
    l : integer, optional
        number of columns in Q, that is, the rank of the approximation being
        constructed (defaults to k+2)
    ifgram : boolean, optional
        set to True to orthonormalize using Gram matrices; set to False
        to orthonormalize using a binary reduce tree
    iffast : boolean, optional
        (irrelevant if ifgram is True) set to True to use fast discrete cosine
        transforms when generating random numbers; set to False to use random
        orthonormal Gaussian matrices

    Returns
    -------
    Q : ndarray
        matrix whose l columns are orthonormal, such that the spectral norm
        || Q Q^* A - A || is nearly as small as the (k+1)st singular value of A

    Notes
    -----
    The adjoint A^* is applied as the plain transpose A.T, without complex
    conjugation, so the two coincide when A is real.

    """
    if l is None:
        # Oversample the target rank slightly by default.
        l = k + 2

    def orthonormalize(Y, twice=False):
        # Replace Y by a matrix whose columns are orthonormal, keeping only
        # the first of the three factors that the helper returns. The "2"
        # variants (double orthonormalization) are used when twice is True.
        if ifgram:
            factorize = gram2 if twice else gram1
            basis, _, _ = factorize(Y)
        else:
            factorize = randtsvd2 if twice else randtsvd1
            basis, _, _ = factorize(Y, iffast)
        return basis

    # Start from an n x l matrix of independent and identically distributed
    # standard normal variates, where A is m x n.
    Q = standard_normal((A.shape[1], l))
    for _ in range(i):
        # Alternate between applying A and applying its transpose,
        # orthonormalizing the product every time.
        for M in (A, A.T):
            Q = orthonormalize(M @ Q)
    # Apply A one last time and orthonormalize doubly for the returned basis.
    return orthonormalize(A @ Q, twice=True)


def ssvd(A, k, i=2, l=None, ifgram=False, iffast=True):
    """Constructs a rank-l approximation U S V^* to A whose accuracy nearly
    matches that of the best possible rank-k approximation.

    Parameters
    ----------
    A : ndarray
        matrix being approximated
    k : integer
        target rank, that is, the rank of the best-possible approximation
        whose accuracy the constructed approximation should nearly match
    i : integer, optional
        number of subspace iterations to conduct when constructing the approx.
        (passed through to the function iterations)
    l : integer, optional
        rank of the approximation being constructed (defaults to k+2)
    ifgram : boolean, optional
        set to True to orthonormalize using Gram matrices; set to False
        to orthonormalize using a binary reduce tree
    iffast : boolean, optional
        (irrelevant if ifgram is True) set to True to use fast discrete cosine
        transforms when generating random numbers; set to False to use random
        orthonormal Gaussian matrices

    Returns
    -------
    U : ndarray
        matrix in the approximation U S V^* to A whose columns are orthonormal,
        where S = diag(s)
    s : ndarray
        vector in the approximation U S V^* to A, where S = diag(s)
    V : ndarray
        matrix in the approximation U S V^* to A whose columns are orthonormal,
        where S = diag(s)

    Notes
    -----
    The spectral norm || A - U S V^* ||, where S = diag(s), is nearly as small
    as the spectral-norm accuracy of the best rank-k approximation to A
    (whereas the rank of U S V^* is l). Transposes in the code are plain
    (unconjugated) transposes.

    """
    # Obtain Q with orthonormal columns such that || A - Q Q^* A || is small.
    Q = iterations(A, k, i=i, l=l, ifgram=ifgram, iffast=iffast)
    # Take a reduced SVD of the small matrix B = Q^* A.
    W, s, Vh = svd(Q.T @ A, full_matrices=False)
    # Lift the left singular vectors of B back through Q, and transpose
    # the right factor so that its columns are the right singular vectors.
    return Q @ W, s, Vh.T


def testlowrank():
    """Tests the function ssvd.

    For each combination of the options ifgram and iffast that is tried,
    generates a random 1000 x 800 matrix with prescribed singular values,
    runs ssvd with k = 10, and prints the spectral-norm error of the
    reconstruction together with the deviation from orthonormality of
    the computed left and right singular vectors. Nothing is returned
    and nothing is asserted; the results are only printed.
    """
    # Numbers of rows and columns of the test matrix, and the target rank.
    m, n, k = 1000, 800, 10
    # Spectral-norm accuracy of the best-possible approximation.
    prec = .1e-12
    # Singular values of the test matrix: k exponentially decaying values
    # running from 1 down to prec, followed by a flat tail of n-k values prec.
    decaying = exp(log(prec) * np.arange(k) / (k - 1))
    sigma = np.append(decaying, prec * np.ones(n - k))

    def defect(M):
        # Spectral norm of M^T M - I, which measures how far the columns
        # of M are from being orthonormal.
        return norm(M.T @ M - np.identity(M.shape[1]), ord=2)

    print()
    print(' spectral-norm accuracy of the best-possible approximation =')
    print(prec)
    print()
    # Loop through all test cases.
    cases = ((True, False), (False, True), (False, False))
    for ifgram, iffast in cases:
        print()
        print('ifgram = ' + str(ifgram))
        if not ifgram:
            # The option iffast matters only when ifgram is False.
            print('iffast = ' + str(iffast))
        print()
        # Draw random factors with orthonormal columns (m x n on the left,
        # n x n on the right) and assemble A with singular values sigma.
        left, _ = qr(standard_normal((m, n)), mode='economic')
        right, _ = qr(standard_normal((n, n)), mode='economic')
        A = (left * sigma) @ right.T
        # Construct the approximation U S V^* to A, where S = diag(s).
        U, s, V = ssvd(A, k, ifgram=ifgram, iffast=iffast)
        # Compute the spectral norm d of the difference between A
        # and its reconstruction from the computed factors.
        d = norm(A - (U * s) @ V.T, ord=2)
        print(' d = spectral-norm accuracy of the approximation =')
        print(d)
        # Assess the numerical orthonormality of the singular vectors.
        print()
        print(' numerical orthonormality of the left singular vectors =')
        print(defect(U))
        print(' numerical orthonormality of the right singular vectors =')
        print(defect(V))
        print()


# Run the test of ssvd only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    testlowrank()
