import numpy as np
import scipy.linalg # type: ignore
from dataclasses import dataclass
from typing import Optional, Callable
from .cross import (
BlackBox,
CrossResults,
CrossInterpolation,
CrossStrategy,
maxvol_square,
_check_convergence,
)
from ..sampling import random_mps_indices
from ...state import Strategy, DEFAULT_TOLERANCE
from ...state.schmidt import svd
from ...state.core import destructively_truncate_vector
from ...truncate import SIMPLIFICATION_STRATEGY
from ...tools import make_logger
# Default truncation strategy for the DMRG-based cross interpolation.
# It is derived from SIMPLIFICATION_STRATEGY via `.replace`, switching off
# normalization and setting both `tolerance` and `simplification_tolerance`
# to the square of DEFAULT_TOLERANCE.
DEFAULT_CROSS_STRATEGY = SIMPLIFICATION_STRATEGY.replace(
    tolerance=DEFAULT_TOLERANCE**2,
    simplification_tolerance=DEFAULT_TOLERANCE**2,
    normalize=False,
)
# TODO: Implement local error evaluation
@dataclass
class CrossStrategyDMRG(CrossStrategy):
    """
    Dataclass containing the parameters for the DMRG-based TCI.

    The common parameters are documented in the base `CrossStrategy` class.

    Parameters
    ----------
    strategy : Strategy, default=DEFAULT_CROSS_STRATEGY
        Simplification strategy used at the truncation of Schmidt values
        at each SVD split of the DMRG superblocks.
    tol_maxvol_square : float, default=1.05
        Tolerance passed to the square maxvol decomposition.
    maxiter_maxvol_square : int, default=10
        Maximum number of iterations for the square maxvol decomposition.
    """

    # NOTE: the docstring must be the first statement of the class body;
    # a string placed after the fields is not picked up as `__doc__`.
    strategy: Strategy = DEFAULT_CROSS_STRATEGY
    tol_maxvol_square: float = 1.05
    maxiter_maxvol_square: int = 10
def cross_dmrg(
    black_box: BlackBox,
    cross_strategy: CrossStrategyDMRG = CrossStrategyDMRG(),
    initial_points: Optional[np.ndarray] = None,
    callback: Optional[Callable] = None,
) -> CrossResults:
    """
    Computes the MPS representation of a black-box function using the tensor
    cross-approximation (TCI) algorithm based on two-site optimizations in a
    DMRG-like manner.

    The black-box function can represent several different structures.
    See `black_box` for usage examples.

    Parameters
    ----------
    black_box : BlackBox
        The black box to approximate as a MPS.
    cross_strategy : CrossStrategyDMRG, default=CrossStrategyDMRG()
        A dataclass containing the parameters of the algorithm.
    initial_points : np.ndarray, optional
        A collection of initial points used to initialize the algorithm.
        If None, a single random point is drawn using `cross_strategy.rng`
        (restricted to `black_box.allowed_indices` when that attribute exists).
    callback : Callable, optional
        A callable invoked as ``callback(mps, logger=logger)`` after every
        half-sweep, i.e. once after the forward sweep and once after the
        backward sweep of each iteration. Its return values are collected in
        the list 'callback_output' of CrossResults.

    Returns
    -------
    CrossResults
        A dataclass containing the MPS representation of the black-box function,
        among other useful information.
    """
    if initial_points is None:
        initial_points = random_mps_indices(
            black_box.physical_dimensions,
            num_indices=1,
            allowed_indices=getattr(black_box, "allowed_indices", None),
            rng=cross_strategy.rng,
        )
    cross = CrossInterpolationDMRG(black_box, initial_points)
    converged = False
    callback_output = []
    # Direction of the last sweep performed. Initialized here so that the
    # final `indices_to_points` call is well defined even when no sweep runs
    # (cross_strategy.maxiter == 0).
    direction = True
    with make_logger(2) as logger:
        for i in range(cross_strategy.maxiter):
            # Forward sweep: bonds 0, 1, ..., sites - 2.
            direction = True
            for k in range(cross.sites - 1):
                _update_dmrg(cross, k, direction, cross_strategy)
            if callback:
                callback_output.append(callback(cross.mps, logger=logger))
            if converged := _check_convergence(cross, i, cross_strategy, logger):
                break
            # Backward sweep: bonds sites - 2, ..., 1, 0.
            direction = False
            for k in reversed(range(cross.sites - 1)):
                _update_dmrg(cross, k, direction, cross_strategy)
            if callback:
                callback_output.append(callback(cross.mps, logger=logger))
            if converged := _check_convergence(cross, i, cross_strategy, logger):
                break
        if not converged:
            logger("Maximum number of TT-Cross iterations reached")
    # The direction of the last sweep determines which index sets are
    # turned into sample points.
    points = cross.indices_to_points(direction)
    return CrossResults(
        mps=cross.mps,
        points=points,
        evals=black_box.evals,
        callback_output=callback_output,
    )
class CrossInterpolationDMRG(CrossInterpolation):
    """
    Cross-interpolation state used by the two-site (DMRG-like) TCI.

    It extends `CrossInterpolation` with `sample_superblock`, which evaluates
    the black box on a whole two-site superblock.
    """

    def __init__(self, black_box: BlackBox, initial_point: np.ndarray):
        # No additional state: everything is set up by the base class.
        super().__init__(black_box, initial_point)

    def sample_superblock(self, k: int) -> np.ndarray:
        """
        Evaluate the black box on the superblock spanning sites `k` and `k + 1`.

        The index sets ``I_l[k]``, ``I_s[k]``, ``I_s[k + 1]`` and
        ``I_g[k + 1]`` are merged with `combine_indices`, the black box is
        indexed with the result, and the samples are reshaped to
        ``(len(I_l[k]), len(I_s[k]), len(I_s[k + 1]), len(I_g[k + 1]))``.
        """
        left = self.I_l[k]
        site_a = self.I_s[k]
        site_b = self.I_s[k + 1]
        right = self.I_g[k + 1]
        samples = self.black_box[self.combine_indices(left, site_a, site_b, right)]
        block_shape = (len(left), len(site_a), len(site_b), len(right))
        return samples.reshape(block_shape)
def _update_dmrg(
    cross: CrossInterpolationDMRG,
    k: int,
    forward: bool,
    cross_strategy: CrossStrategyDMRG,
) -> None:
    """
    Perform one two-site update on the bond between sites `k` and `k + 1`.

    The superblock is sampled, reshaped into a matrix of shape
    ``(r_l * s1, s2 * r_g)`` and split with an SVD truncated according to
    ``cross_strategy.strategy``. Then, depending on the sweep direction:

    - forward, ``k < sites - 2``: an economic QR of the left factor is fed to
      `maxvol_square`; the selected rows define the new ``I_l[k + 1]`` and the
      returned matrix ``G`` becomes ``mps[k]``;
    - forward, last bond: ``mps[k] = U`` and ``mps[k + 1] = S @ V``;
    - backward, ``k > 0``: an economic QR of the transposed right factor is fed
      to `maxvol_square`; the selected rows define the new ``I_g[k]`` and
      ``G.T`` becomes ``mps[k + 1]``;
    - backward, ``k == 0``: ``mps[k] = U @ S`` and ``mps[k + 1] = V``.

    `cross.mps`, `cross.I_l` and `cross.I_g` are modified in place.
    """
    superblock = cross.sample_superblock(k)
    r_l, s1, s2, r_g = superblock.shape
    U, S, V = svd(superblock.reshape(r_l * s1, s2 * r_g), check_finite=False)
    # Presumably shrinks S in place; the retained rank is read back from its size.
    destructively_truncate_vector(S, cross_strategy.strategy)
    rank = S.size
    U = U[:, :rank]
    V = V[:rank, :]
    S = np.diag(S)

    if forward and k < cross.sites - 2:
        # Interior bond, left-to-right: select new left pivots.
        basis, _ = scipy.linalg.qr(  # type: ignore
            U.reshape(r_l * s1, rank),
            mode="economic",
            overwrite_a=True,
            check_finite=False,
        )
        rows, interp = maxvol_square(
            basis,  # type: ignore
            cross_strategy.maxiter_maxvol_square,
            cross_strategy.tol_maxvol_square,
        )
        candidates = cross.combine_indices(cross.I_l[k], cross.I_s[k])
        cross.I_l[k + 1] = candidates[rows]
        cross.mps[k] = interp.reshape(r_l, s1, rank)
    elif forward:
        # Last bond of a forward sweep: singular values go to the right core.
        cross.mps[k] = U.reshape(r_l, s1, rank)
        cross.mps[k + 1] = (S @ V).reshape(rank, s2, r_g)
    elif k > 0:
        # Interior bond, right-to-left: select new right pivots.
        basis, _ = scipy.linalg.qr(  # type: ignore
            V.reshape(rank, s2 * r_g).T,
            mode="economic",
            overwrite_a=True,
            check_finite=False,
        )
        rows, interp = maxvol_square(
            basis,  # type: ignore
            cross_strategy.maxiter_maxvol_square,
            cross_strategy.tol_maxvol_square,
        )
        candidates = cross.combine_indices(cross.I_s[k + 1], cross.I_g[k + 1])
        cross.I_g[k] = candidates[rows]
        cross.mps[k + 1] = interp.T.reshape(rank, s2, r_g)
    else:
        # First bond of a backward sweep: singular values go to the left core.
        cross.mps[k] = (U @ S).reshape(r_l, s1, rank)
        cross.mps[k + 1] = V.reshape(rank, s2, r_g)