import numpy as np
import scipy.linalg # type: ignore
import dataclasses
import functools
from typing import Optional, Callable
from .cross import (
CrossInterpolation,
CrossResults,
CrossStrategy,
BlackBox,
maxvol_square,
_check_convergence,
)
from ..sampling import random_mps_indices
from ...tools import make_logger
# TODO: Implement local error evaluation
@dataclasses.dataclass
class CrossStrategyMaxvol(CrossStrategy):
    """
    Dataclass containing the parameters for the rectangular maxvol-based TCI.
    The common parameters are documented in the base `CrossStrategy` class.

    Parameters
    ----------
    rank_kick : tuple, default=(0, 1)
        Minimum and maximum rank increase or 'kick' at each rectangular maxvol decomposition.
    maxiter_maxvol_square : int, default=10
        Maximum number of iterations for the square maxvol decomposition.
    tol_maxvol_square : float, default=1.05
        Tolerance (sensitivity) for the square maxvol decomposition.
    tol_maxvol_rect : float, default=1.05
        Tolerance (sensitivity) for the rectangular maxvol decomposition: once the
        minimum rank kick is reached, rows stop being added when no remaining row
        of the interpolation matrix has a norm above this value.
    fortran_order : bool, default=True
        Whether to use the Fortran (column-major) order in the computation of the
        maxvol indices.
    """

    # NOTE: the docstring must precede the fields; placed after them it is a
    # no-op string statement and does not become the class `__doc__`.
    rank_kick: tuple = (0, 1)
    maxiter_maxvol_square: int = 10
    tol_maxvol_square: float = 1.05
    tol_maxvol_rect: float = 1.05
    fortran_order: bool = True
def cross_maxvol(
    black_box: BlackBox,
    cross_strategy: CrossStrategyMaxvol = CrossStrategyMaxvol(),
    initial_points: Optional[np.ndarray] = None,
    callback: Optional[Callable] = None,
) -> CrossResults:
    """
    Computes the MPS representation of a black-box function using the tensor cross-approximation (TCI)
    algorithm based on one-site optimizations using the rectangular maxvol decomposition.
    The black-box function can represent several different structures. See `black_box` for usage examples.

    Parameters
    ----------
    black_box : BlackBox
        The black box to approximate as a MPS.
    cross_strategy : CrossStrategyMaxvol, default=CrossStrategyMaxvol()
        A dataclass containing the parameters of the algorithm.
    initial_points : np.ndarray, optional
        A collection of initial points used to initialize the algorithm.
        If None, an initial random point is used.
    callback : Callable, optional
        A callable called on the MPS after each iteration.
        The output of the callback is included in a list 'callback_output' in CrossResults.

    Returns
    -------
    CrossResults
        A dataclass containing the MPS representation of the black-box function,
        among other useful information.
    """
    if initial_points is None:
        # One random starting multi-index, restricted to the black box's
        # `allowed_indices` when it defines that attribute.
        initial_points = random_mps_indices(
            black_box.physical_dimensions,
            num_indices=1,
            allowed_indices=getattr(black_box, "allowed_indices", None),
            rng=cross_strategy.rng,
        )
    cross = CrossInterpolationMaxvol(black_box, initial_points)
    converged = False
    callback_output = []
    with make_logger(2) as logger:
        for i in range(cross_strategy.maxiter):
            # Forward sweep: refreshes the left multi-indices site by site.
            for k in range(cross.sites):
                _update_maxvol(cross, k, True, cross_strategy)
            # Backward sweep: refreshes the right multi-indices and rewrites
            # every MPS core, so `cross.mps` is complete at this point.
            for k in reversed(range(cross.sites)):
                _update_maxvol(cross, k, False, cross_strategy)
            if callback:
                callback_output.append(callback(cross.mps, logger=logger))
            if converged := _check_convergence(cross, i, cross_strategy, logger):
                break
        if not converged:
            logger("Maximum number of iterations reached")
    points = cross.indices_to_points(False)
    return CrossResults(
        mps=cross.mps,
        points=points,
        evals=black_box.evals,
        callback_output=callback_output,
    )
class CrossInterpolationMaxvol(CrossInterpolation):
    """
    Cross-interpolation object used by `cross_maxvol`.

    Extends `CrossInterpolation` with a Fortran-ordered multi-index combinator
    that matches `reshape(..., order="F")` of the sampled fibers.
    """

    def __init__(self, black_box: BlackBox, initial_point: np.ndarray):
        super().__init__(black_box, initial_point)

    @staticmethod
    def combine_indices_fortran(*indices: np.ndarray) -> np.ndarray:
        """
        Computes the Cartesian product of a set of multi-index arrays and arranges
        the result as concatenated indices in Fortran order (column-major): the
        rows of the first array vary fastest and those of the last array slowest.

        Parameters
        ----------
        indices : *np.ndarray
            A variable number of 2-D arrays of shape (n_i, d_i); each row is
            treated as one multi-index.

        Returns
        -------
        np.ndarray
            Array of shape (n_1 * n_2 * ..., d_1 + d_2 + ...) whose row
            ``a + n_1 * b`` (for two inputs) is ``indices[0][a]`` followed by
            ``indices[1][b]``.

        Example
        -------
        >>> CrossInterpolationMaxvol.combine_indices_fortran(
        ...     np.array([[1, 2, 3], [4, 5, 6]]), np.array([[0], [1]])
        ... )
        array([[1, 2, 3, 0],
               [4, 5, 6, 0],
               [1, 2, 3, 1],
               [4, 5, 6, 1]])
        """

        def cartesian_fortran(A: np.ndarray, B: np.ndarray) -> np.ndarray:
            # Whole block A repeated once per row of B (A varies fastest) ...
            A_tiled = np.tile(A, (B.shape[0], 1))
            # ... and each row of B repeated once per row of A.
            B_repeated = np.repeat(B, repeats=A.shape[0], axis=0)
            return np.hstack((A_tiled, B_repeated))

        return functools.reduce(cartesian_fortran, indices)
def _update_maxvol(
    cross: CrossInterpolationMaxvol,
    k: int,
    forward: bool,
    cross_strategy: CrossStrategyMaxvol,
) -> None:
    """
    Update site `k` of `cross` from a freshly sampled fiber using maxvol.

    The fiber of shape (r_l, s, r_g) returned by `cross.sample_fiber(k)` is
    reshaped into a matrix, orthonormalized with an economic QR and passed to
    `choose_maxvol` to select interpolation rows.

    - Forward (``forward=True``), for ``k < cross.sites - 1``: the fiber is
      viewed as (r_l * s, r_g) and the selected rows become the left
      multi-indices ``cross.I_l[k + 1]``.
    - Backward (``forward=False``), for ``k > 0``: the fiber is viewed as
      (r_l, s * r_g); the selected columns become the right multi-indices
      ``cross.I_g[k - 1]`` and the interpolation matrix becomes core
      ``cross.mps[k]``. At ``k == 0`` the fiber itself is stored as
      ``cross.mps[0]``.

    `cross_strategy.fortran_order` selects both the reshape order and the
    matching multi-index combinator, so row/column positions and index
    combinations stay consistent.
    """
    # Truthiness (not `is True`) so that e.g. `np.bool_(True)` also selects
    # the Fortran ordering.
    if cross_strategy.fortran_order:
        combine_indices = cross.combine_indices_fortran
        order = "F"
    else:
        combine_indices = cross.combine_indices
        order = "C"
    fiber = cross.sample_fiber(k)
    r_l, s, r_g = fiber.shape
    if forward:
        # The fiber is sampled for every site, but the last site has no left
        # index set to update, so the QR/maxvol step is only done before it.
        if k < cross.sites - 1:
            C = fiber.reshape(r_l * s, r_g, order=order)  # type: ignore
            Q, _ = scipy.linalg.qr(  # type: ignore
                C, mode="economic", overwrite_a=True, check_finite=False
            )
            I, _ = choose_maxvol(
                Q,  # type: ignore
                cross_strategy.rank_kick,
                cross_strategy.maxiter_maxvol_square,
                cross_strategy.tol_maxvol_square,
                cross_strategy.tol_maxvol_rect,
            )
            cross.I_l[k + 1] = combine_indices(cross.I_l[k], cross.I_s[k])[I]
    elif k > 0:
        R = fiber.reshape(r_l, s * r_g, order=order)  # type: ignore
        # QR of the transpose so maxvol selects among the (s * r_g) columns.
        Q, _ = scipy.linalg.qr(  # type: ignore
            R.T, mode="economic", overwrite_a=True, check_finite=False
        )
        I, G = choose_maxvol(
            Q,  # type: ignore
            cross_strategy.rank_kick,
            cross_strategy.maxiter_maxvol_square,
            cross_strategy.tol_maxvol_square,
            cross_strategy.tol_maxvol_rect,
        )
        cross.mps[k] = (G.T).reshape(-1, s, r_g, order=order)  # type: ignore
        cross.I_g[k - 1] = combine_indices(cross.I_s[k], cross.I_g[k])[I]
    else:
        cross.mps[0] = fiber
def choose_maxvol(
    A: np.ndarray,
    rank_kick: tuple = (0, np.inf),
    maxiter: int = 10,
    tol: float = 1.1,
    tol_rect: float = 0.1,
) -> tuple[np.ndarray, np.ndarray]:
    """
    Select rows of the (n, r) matrix `A` with square or rectangular maxvol.

    Parameters
    ----------
    A : np.ndarray
        Matrix of shape (n, r) whose rows are to be selected.
    rank_kick : tuple, default=(0, np.inf)
        Minimum and maximum number of rows to add beyond the first r; the
        maximum is capped at n - r.
    maxiter : int, default=10
        Iterations for the square maxvol step.
    tol : float, default=1.1
        Tolerance for the square maxvol step.
    tol_rect : float, default=0.1
        Tolerance for the rectangular step (see `maxvol_rectangular`).
        NOTE(review): this default differs from `maxvol_rectangular`'s 1.05;
        callers in this module always pass it explicitly.

    Returns
    -------
    I : np.ndarray
        Indices of the selected rows.
    B : np.ndarray
        Matrix returned by the chosen maxvol routine; ``np.eye(n)`` when n < r.
    """
    n, r = A.shape
    if n < r:
        return np.arange(n, dtype=int), np.eye(n)
    # At most n - r rows can be added on top of the r square-maxvol pivots.
    max_rank_kick = min(rank_kick[1], n - r)
    min_rank_kick = min(rank_kick[0], max_rank_kick)
    # The previous check `rank_kick == 0` compared a tuple with an int and was
    # always False; the intended case is "no rows may be added".
    if max_rank_kick == 0:
        return maxvol_square(A, maxiter, tol)
    return maxvol_rectangular(
        A, min_rank_kick, max_rank_kick, maxiter, tol, tol_rect
    )
def maxvol_rectangular(
    A: np.ndarray,
    min_rank_kick: int = 0,
    max_rank_kick: float = np.inf,
    maxiter: int = 10,
    tol: float = 1.1,
    tol_rect: float = 1.05,
):
    """
    Rectangular maxvol: select between r + min_rank_kick and r + max_rank_kick rows of `A`.

    Starting from the r rows chosen by `maxvol_square`, rows are added greedily.
    At each step the free row with the largest norm in the current coefficient
    matrix `B` is chosen. The process stops after ``min(r + max_rank_kick, n)``
    rows, or earlier once at least ``r + min_rank_kick`` rows are chosen and no
    free row has a norm above `tol_rect`.

    Parameters
    ----------
    A : np.ndarray
        Matrix of shape (n, r).
    min_rank_kick, max_rank_kick : int / float
        Minimum and maximum number of rows added beyond the initial r.
    maxiter, tol : int, float
        Iteration limit and tolerance forwarded to `maxvol_square`.
    tol_rect : float, default=1.05
        Row-norm threshold for stopping the greedy extension.

    Returns
    -------
    I : np.ndarray
        Indices of the K selected rows.
    B : np.ndarray
        Matrix of shape (n, K): the `maxvol_square` coefficients extended by one
        column per added row, with ``B[I]`` set to the identity.

    Raises
    ------
    ValueError
        If the requested numbers of added rows are inconsistent.
    """
    n, r = A.shape
    r_min = r + min_rank_kick
    # int() so that a float `max_rank_kick` still gives an integer row count
    # (np.zeros below requires one).
    r_max = int(min(r + max_rank_kick, n))
    if r_min < r or r_min > r_max or r_max > n:
        raise ValueError("Invalid minimum/maximum number of added rows")
    I0, B = maxvol_square(A, maxiter, tol)
    I = np.hstack([I0, np.zeros(r_max - r, dtype=I0.dtype)])  # type: ignore
    # S masks out rows already selected; F holds the squared norm of every
    # row of B (per row, axis=1 -- NOT the scalar Frobenius norm), which is
    # what the rank-one update of F below assumes.
    S = np.ones(n, dtype=int)
    S[I0] = 0
    F = S * np.linalg.norm(B, axis=1) ** 2
    for k in range(r, r_max):
        i = np.argmax(F)
        if k >= r_min and F[i] <= tol_rect**2:
            break
        I[k] = i
        S[i] = 0
        # Append row i: rank-one correction of B plus a new column.
        v = B.dot(B[i])
        alpha = 1.0 / (1 + v[i])
        B = np.hstack([B - alpha * np.outer(v, B[i]), alpha * v.reshape(-1, 1)])
        # Squared row norms of the new B are ||B_j||^2 - alpha * v_j^2.
        F = S * (F - alpha * v * v)
    I = I[: B.shape[1]]
    B[I] = np.eye(B.shape[1], dtype=B.dtype)
    return I, B