#!/usr/bin/env python3

"""Validate Gram-based tall-skinny singular value decomposition (SVD).

Functions
---------
gram1
    Computes the SVD of a tall-skinny matrix A, via the Gram matrix A^* A,
    orthonormalizing the left singular vectors only once, for efficiency.
gram2
    Computes the SVD of a tall-skinny matrix A, via the Gram matrix A^* A,
    orthonormalizing left singular vectors twice in succession, for accuracy.
testgrams
    Tests the functions gram1 and gram2.
"""

import numpy as np
from numpy import exp, log
from numpy.random import standard_normal
from scipy.linalg import eigh, norm, qr, svd


def gram1(A, *, tol=1e-6):
    """Computes the SVD of the tall-skinny matrix A, via the Gram matrix A^* A,
    orthonormalizing the left singular vectors only once, for efficiency.

    Parameters
    ----------
    A : ndarray
        tall-skinny matrix to be decomposed (real or complex)
    tol : float, optional
        relative cutoff, expected in [0, 1); singular values smaller than
        tol times the largest singular value are treated as numerically
        zero and discarded together with their singular vectors
        (default 1e-6)

    Returns
    -------
    U : ndarray
        matrix of left singular vectors, one column per retained singular
        value
    s : ndarray
        vector of retained singular values, in the order of the eigenvalues
        returned by eigh for A^* A (ascending), not the decreasing order
        used by scipy.linalg.svd
    V : ndarray
        matrix of right singular vectors, one column per retained singular
        value

    Notes
    -----
    A = U S V^*, where S = diag(s), up to the discarded singular values.
    Forming A^* A squares the condition number, so the small singular
    values are resolved less accurately than by a direct SVD.

    """
    # Form the Gram matrix B = A^* A.  The conjugate transpose keeps B
    # Hermitian for complex A (eigh assumes a Hermitian input); for real A
    # it coincides with A.T @ A.
    B = A.conj().T @ A
    # Calculate the eigendecomposition B = V D V^*.  Only V is used: the
    # singular values are recomputed below as column norms of A V.
    _, V = eigh(B)
    # Form Utilde = AV.
    Utilde = A @ V
    # Calculate the vector s of singular values.
    s = norm(Utilde, ord=2, axis=0)
    # Discard singular values (and corresponding singular vectors) that are
    # numerically zero.  Every index is checked, including index 0: eigh
    # orders eigenvalues ascending, so index 0 holds the smallest one.
    keep = s >= tol * np.amax(s)
    s = s[keep]
    Utilde = Utilde[:, keep]
    V = V[:, keep]
    # Normalize every column of Utilde to obtain U.
    U = Utilde / s
    # Return the singular values and singular vectors.
    return U, s, V


def _gram_orthonormalize(X, tol):
    """Orthonormalize the columns of X once, via its Gram matrix X^* X.

    Returns (Q, t, W) with X W = Q diag(t): W holds the retained eigenvectors
    of X^* X, t the Euclidean norms of the columns of X W, and Q those
    columns normalized.  Columns whose norm falls below tol times the
    largest norm are dropped (every index is checked, including index 0,
    which eigh's ascending ordering makes the smallest).
    """
    # Conjugate transpose keeps the Gram matrix Hermitian for complex X.
    _, W = eigh(X.conj().T @ X)
    XW = X @ W
    t = norm(XW, ord=2, axis=0)
    keep = t >= tol * np.amax(t)
    t, XW, W = t[keep], XW[:, keep], W[:, keep]
    return XW / t, t, W


def gram2(A, *, tol=1e-6):
    """Computes the SVD of the tall-skinny matrix A, via the Gram matrix A^* A,
    orthonormalizing left singular vectors twice in succession, for accuracy.

    Parameters
    ----------
    A : ndarray
        tall-skinny matrix to be decomposed (real or complex)
    tol : float, optional
        relative cutoff, expected in [0, 1), applied in both
        orthonormalization passes; column norms smaller than tol times the
        largest are treated as numerically zero and discarded (default 1e-6)

    Returns
    -------
    U : ndarray
        matrix of left singular vectors
    s : ndarray
        vector of singular values, in decreasing order (as returned by
        scipy.linalg.svd)
    V : ndarray
        matrix of right singular vectors

    Notes
    -----
    A = U S V^*, where S = diag(s), up to the discarded singular values.

    """
    # First pass: A Vtilde = Y Stilde, with Stilde = diag(stilde), so that
    # A is approximately Y Stilde Vtilde^*.
    Y, stilde, Vtilde = _gram_orthonormalize(A, tol)
    # Second pass re-orthonormalizes Y, which the first pass leaves only
    # approximately orthonormal: Y W = Q T, with T = diag(t).
    Q, t, W = _gram_orthonormalize(Y, tol)
    # Form R = T W^* Stilde Vtilde^*, so that A is approximately Q R.
    # Conjugate transposes make this valid for complex A as well.
    R = (W * t).conj().T @ (Vtilde * stilde).conj().T
    # Calculate the SVD P S V^* of the small matrix R, where S = diag(s).
    P, s, Vh = svd(R, full_matrices=False)
    V = Vh.conj().T
    # Form U = QP.
    U = Q @ P
    # Return the singular values and singular vectors.
    return U, s, V


def testgrams():
    """Tests the functions gram1 and gram2.

    Synthesizes a random 100 x 50 matrix whose singular values decay
    geometrically from 1 to 1e-14, decomposes it with both gram1 and gram2,
    and prints, for each method, the spectral norm of the reconstruction
    error and the spectral-norm distance of U^T U and V^T V from the
    identity.
    """
    nrows, ncols = 100, 50
    # Random orthonormal factors from QR of Gaussian matrices; the left
    # factor is drawn before the right one.
    left, _ = qr(standard_normal((nrows, ncols)), mode='economic')
    right, _ = qr(standard_normal((ncols, ncols)), mode='economic')
    # sigma_j = sigma_min ** (j / (ncols - 1)) for j = 0, ..., ncols - 1.
    sigma_min = .1e-13
    sigma = exp(log(sigma_min) * np.arange(ncols) / (ncols - 1))
    A = left * sigma @ right.T

    first = gram1(A)
    second = gram2(A)

    # Reconstruction accuracy: spectral norm of A - U diag(s) V^T.
    accuracy_report = (
        (' d1 = spectral-norm accuracy of the approximation from gram1 =',
         first),
        (' d2 = spectral-norm accuracy of the approximation from gram2 =',
         second),
    )
    for heading, (Uk, sk, Vk) in accuracy_report:
        print(heading)
        print(norm(A - Uk * sk @ Vk.T, ord=2))

    # Orthonormality: spectral norm of Q^T Q - I, left factors first, then
    # right factors; each group is preceded by a blank line.
    orthonormality_groups = (
        ((' numerical orthonormality '
          'of the left singular vectors from gram1 =', first[0]),
         (' numerical orthonormality '
          'of the left singular vectors from gram2 =', second[0])),
        ((' numerical orthonormality '
          'of the right singular vectors from gram1 =', first[2]),
         (' numerical orthonormality '
          'of the right singular vectors from gram2 =', second[2])),
    )
    for group in orthonormality_groups:
        print()
        for heading, Q in group:
            print(heading)
            print(norm(Q.T @ Q - np.identity(Q.shape[1]), ord=2))


if __name__ == "__main__":
    # When executed as a script, run the validation report.
    testgrams()
