import numpy as np
import scipy.linalg # type: ignore
from typing import TypeVar, Union, Optional, Callable
from dataclasses import dataclass
from .cross import (
CrossInterpolation,
CrossResults,
CrossStrategy,
BlackBox,
_check_convergence,
)
from ..sampling import random_mps_indices
from ...state import MPS
from ...state._contractions import _contract_last_and_first
from ...tools import make_logger, Logger
@dataclass
class CrossStrategyGreedy(CrossStrategy):
    """
    Dataclass containing parameters for TCI with greedy pivot updates.
    Supplements the base `CrossStrategy` class.

    Parameters
    ----------
    tol_pivot : float, default=1e-10
        Minimum allowable error for a pivot, excluding those below this threshold.
        The algorithm halts when the maximum pivot error across all sites falls below this limit.
    partial : bool, default=True
        Whether to use a row-column alternating partial search strategy to find pivots in the superblock.
        If False, performs a 'full search' that uses more function evaluations (O(chi) vs. O(chi^2)) but
        can introduce potentially smaller errors.
    maxiter_partial : int, default=5
        Number of row-column iterations in each partial search.
    points_partial : int, default=10
        Number of initial random points used to initialize each partial search.
    """

    # NOTE: the docstring above must precede the fields; placed after them it is
    # a dead string expression and never becomes the class docstring.
    tol_pivot: float = 1e-10
    partial: bool = True
    maxiter_partial: int = 5
    points_partial: int = 10
def cross_greedy(
    black_box: BlackBox,
    cross_strategy: CrossStrategyGreedy = CrossStrategyGreedy(),
    initial_points: Optional[np.ndarray] = None,
    callback: Optional[Callable] = None,
) -> CrossResults:
    """
    Computes the MPS representation of a black-box function using the tensor cross-approximation (TCI)
    algorithm based on two-site optimizations following greedy updates of the pivot matrices.
    The black-box function can represent several different structures. See `black_box` for usage examples.

    Parameters
    ----------
    black_box : BlackBox
        The black box to approximate as a MPS.
    cross_strategy : CrossStrategyGreedy, default=CrossStrategyGreedy()
        A dataclass containing the parameters of the algorithm.
    initial_points : np.ndarray, optional
        A collection of initial points used to initialize the algorithm.
        If None, an initial random point is used.
    callback : Callable, optional
        A callable called on the MPS after each iteration.
        The output of the callback is included in a list 'callback_output' in CrossResults.

    Returns
    -------
    CrossResults
        A dataclass containing the MPS representation of the black-box function,
        among other useful information.
    """
    if initial_points is None:
        initial_points = random_mps_indices(
            black_box.physical_dimensions,
            num_indices=1,
            allowed_indices=getattr(black_box, "allowed_indices", None),
            rng=cross_strategy.rng,
        )
    cross = CrossInterpolationGreedy(black_box, initial_points)
    # Partial search trades accuracy for fewer black-box evaluations.
    update_method = (
        _update_partial_search if cross_strategy.partial else _update_full_search
    )
    pivot_errors = np.zeros((black_box.sites - 1,))
    converged = False
    # `direction` is read after the loop; initialize it in case maxiter == 0.
    direction = True
    callback_output = []
    with make_logger(2) as logger:
        for i in range(cross_strategy.maxiter):
            # Forward sweep
            direction = True
            for k in range(cross.sites - 1):
                pivot_errors[k] = update_method(cross, k, cross_strategy)
            if callback:
                callback_output.append(callback(cross.mps, logger=logger))
            if converged := (
                _check_convergence(cross, i, cross_strategy, logger)
                or _check_local_convergence(pivot_errors, cross_strategy, logger)
            ):
                break
            # Backward sweep
            direction = False
            for k in reversed(range(cross.sites - 1)):
                pivot_errors[k] = update_method(cross, k, cross_strategy)
            if callback:
                callback_output.append(callback(cross.mps, logger=logger))
            if converged := (
                _check_convergence(cross, i, cross_strategy, logger)
                or _check_local_convergence(pivot_errors, cross_strategy, logger)
            ):
                break
        if not converged:
            logger("Maximum number of TT-Cross iterations reached")
    points = cross.indices_to_points(direction)
    return CrossResults(
        mps=cross.mps,
        points=points,
        evals=black_box.evals,
        callback_output=callback_output,
    )
class CrossInterpolationGreedy(CrossInterpolation):
    """Cross-interpolation state for greedy TCI.

    Holds, for each site, a sampled fiber of the black box together with the
    QR factorization of all but the last fiber, and maintains an MPS whose
    cores are built from the Q factors. Pivots are tracked both as
    multi-indices (``I_l``, ``I_g``, inherited) and as integer positions
    (``J_l``, ``J_g``) inside the corresponding combined index sets.
    """

    def __init__(self, black_box: BlackBox, initial_point: np.ndarray):
        super().__init__(black_box, initial_point)
        self.fibers = [self.sample_fiber(k) for k in range(self.sites)]
        self.Q_factors = []
        self.R_matrices = []
        for fiber in self.fibers[:-1]:
            Q, R = self.fiber_to_QR(fiber)
            self.Q_factors.append(Q)
            self.R_matrices.append(R)
        ## Translate the initial multiindices I_l and I_g to integer indices J_l and J_g
        ## TODO: Refactor

        def get_row_indices(rows, all_rows):
            # Position of each row of `rows` inside the matrix `all_rows`.
            large_set = {tuple(row): idx for idx, row in enumerate(all_rows)}
            return np.array([large_set[tuple(row)] for row in rows])

        J_l = []
        J_g = []
        for k in range(self.sites - 1):
            i_ls = self.combine_indices(self.I_l[k], self.I_s[k])
            J_l.append(get_row_indices(self.I_l[k + 1], i_ls))
            # BUGFIX: J_g[k] must locate the rows of I_g[k] inside the RIGHT
            # combined set (I_s[k+1] x I_g[k+1]), mirroring the invariant kept
            # by `update_indices`. Previously this reused the left combination
            # and looked up I_l[k+1], duplicating J_l. Entries are now built in
            # ascending bond order, so no reversal is needed below.
            i_sg = self.combine_indices(self.I_s[k + 1], self.I_g[k + 1])
            J_g.append(get_row_indices(self.I_g[k], i_sg))
        self.J_l = [np.array([])] + J_l  # add empty indices to respect convention
        self.J_g = J_g + [np.array([])]
        ##
        G_cores = [self.Q_to_G(Q, j_l) for Q, j_l in zip(self.Q_factors, self.J_l[1:])]
        self.mps = MPS(G_cores + [self.fibers[-1]])

    # _Index = TypeVar("_Index", bound=Union[np.intp, np.ndarray, slice])
    # Integer, array of integers, or slice used to select rows/columns.
    _Index = Union[np.intp, np.ndarray, slice]

    def sample_superblock(
        self, k: int, j_l: _Index = slice(None), j_g: _Index = slice(None)
    ) -> np.ndarray:
        """Evaluate the black box on the two-site superblock at bond k,
        restricted to the row selection `j_l` and column selection `j_g`."""
        i_ls = self.combine_indices(self.I_l[k], self.I_s[k])[j_l]
        i_ls = i_ls.reshape(1, -1) if i_ls.ndim == 1 else i_ls  # Prevent collapse to 1D
        i_sg = self.combine_indices(self.I_s[k + 1], self.I_g[k + 1])[j_g]
        i_sg = i_sg.reshape(1, -1) if i_sg.ndim == 1 else i_sg
        mps_indices = self.combine_indices(i_ls, i_sg)
        return self.black_box[mps_indices].reshape((len(i_ls), len(i_sg)))

    def sample_skeleton(
        self, k: int, j_l: _Index = slice(None), j_g: _Index = slice(None)
    ) -> np.ndarray:
        """Evaluate the current cross approximation (MPS core times right
        fiber) on the same restricted superblock as `sample_superblock`."""
        r_l, r_s1, chi = self.mps[k].shape
        chi, r_s2, r_g = self.fibers[k + 1].shape
        G = self.mps[k].reshape(r_l * r_s1, chi)[j_l]
        R = self.fibers[k + 1].reshape(chi, r_s2 * r_g)[:, j_g]
        return _contract_last_and_first(G, R)

    def update_indices(self, k: int, j_l: _Index, j_g: _Index) -> None:
        """Append the new pivot (j_l, j_g) at bond k to the multi-index sets
        I_l[k+1] / I_g[k] and to their integer counterparts J_l / J_g."""
        i_l = self.combine_indices(self.I_l[k], self.I_s[k])[j_l]
        i_g = self.combine_indices(self.I_s[k + 1], self.I_g[k + 1])[j_g]
        self.I_l[k + 1] = np.vstack((self.I_l[k + 1], i_l))
        self.J_l[k + 1] = np.append(self.J_l[k + 1], j_l)  # type: ignore
        self.I_g[k] = np.vstack((self.I_g[k], i_g))
        self.J_g[k] = np.append(self.J_g[k], j_g)  # type: ignore

    def update_tensors(
        self,
        k: int,
        r: np.ndarray,
        c: np.ndarray,
    ) -> None:
        """Enlarge the fibers at sites k and k+1 with the new pivot row `r`
        and column `c`, refresh the affected QR factors and MPS cores."""
        # Update left fiber, Q-factor and MPS site
        r_l, r_s1, chi = self.fibers[k].shape
        C = self.fibers[k].reshape(r_l * r_s1, chi)
        self.fibers[k] = np.hstack((C, c.reshape(-1, 1))).reshape(r_l, r_s1, chi + 1)
        QR_INSERT = False  # TODO: Check why it is unstable
        if QR_INSERT:
            # Rank-one QR update (cheaper than recomputing the full QR).
            Q = self.Q_factors[k].reshape(r_l * r_s1, chi)
            Q, self.R_matrices[k] = scipy.linalg.qr_insert(
                Q,
                self.R_matrices[k],
                u=c,
                k=Q.shape[1],
                which="col",
                rcond=None,
                check_finite=False,
            )
            self.Q_factors[k] = Q.reshape(r_l, r_s1, chi + 1)
        else:
            # Full recomputation: stable fallback while qr_insert is disabled.
            self.Q_factors[k], self.R_matrices[k] = self.fiber_to_QR(self.fibers[k])
        self.mps[k] = self.Q_to_G(self.Q_factors[k], self.J_l[k + 1])
        # Update right fiber, Q-factor and MPS site
        chi, r_s2, r_g = self.fibers[k + 1].shape
        R = self.fibers[k + 1].reshape(chi, r_s2 * r_g)
        self.fibers[k + 1] = np.vstack((R, r)).reshape(chi + 1, r_s2, r_g)
        if k < self.sites - 2:
            if QR_INSERT:
                Q = self.Q_factors[k + 1].reshape(chi * r_s2, r_g)
                Q, self.R_matrices[k + 1] = scipy.linalg.qr_insert(
                    Q,
                    self.R_matrices[k + 1],
                    u=r.reshape(-1, Q.shape[1]),
                    k=Q.shape[0],
                    which="row",
                    check_finite=False,
                )
                self.Q_factors[k + 1] = Q.reshape(chi + 1, r_s2, r_g)
            else:
                self.Q_factors[k + 1], self.R_matrices[k + 1] = self.fiber_to_QR(
                    self.fibers[k + 1]
                )
            self.mps[k + 1] = self.Q_to_G(self.Q_factors[k + 1], self.J_l[k + 2])
        else:
            # The last site keeps the raw fiber as its MPS core.
            self.mps[k + 1] = self.fibers[k + 1]

    @staticmethod
    def fiber_to_QR(fiber: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        """Performs the QR decomposition of a fiber."""
        r_l, r_s, r_g = fiber.shape
        Q, R = scipy.linalg.qr(  # type: ignore
            fiber.reshape(r_l * r_s, r_g), mode="economic", check_finite=False
        )
        Q_factor = Q.reshape(r_l, r_s, r_g)  # type: ignore
        return Q_factor, R

    @staticmethod
    def Q_to_G(Q_factor: np.ndarray, j_l: np.ndarray) -> np.ndarray:
        """Transforms a Q-factor into a MPS tensor core G."""
        r_l, r_s, r_g = Q_factor.shape
        Q = Q_factor.reshape(r_l * r_s, r_g)
        # Normalize so that the pivot rows of G form the identity matrix.
        P = scipy.linalg.inv(Q[j_l], check_finite=False)
        G = _contract_last_and_first(Q, P)
        return G.reshape(r_l, r_s, r_g)
def _update_full_search(
    cross: CrossInterpolationGreedy,
    k: int,
    cross_strategy: CrossStrategyGreedy,
) -> float:
    """Add the largest-error pivot of the full superblock at bond k.

    Samples the whole superblock and its skeleton approximation, picks the
    entry with the largest absolute residual, and (if above `tol_pivot`)
    registers it as a new pivot. Returns the residual of the chosen entry.
    """
    # The pivot count at bond k is capped by the local Hilbert-space size.
    max_pivots = cross.black_box.base ** (1 + min(k, cross.sites - (k + 2)))
    if len(cross.I_g[k]) >= max_pivots or len(cross.I_l[k + 1]) >= max_pivots:
        return 0
    superblock = cross.sample_superblock(k)
    skeleton = cross.sample_skeleton(k)
    residual = np.abs(superblock - skeleton)
    row, col = np.unravel_index(np.argmax(residual), superblock.shape)
    pivot_error = residual[row, col]
    if pivot_error >= cross_strategy.tol_pivot:
        cross.update_indices(k, j_l=row, j_g=col)
        cross.update_tensors(k, r=superblock[row, :], c=superblock[:, col])
    return pivot_error
def _update_partial_search(
    cross: CrossInterpolationGreedy,
    k: int,
    cross_strategy: CrossStrategyGreedy,
) -> float:
    """Add a pivot at bond k found by alternating row/column search.

    Starts from the worst residual among a random sample of superblock
    entries, then alternately maximizes the residual along the current
    column and row until the pivot stops moving or `maxiter_partial` is
    reached. Registers the pivot if its error exceeds `tol_pivot`, and
    returns that error.
    """
    # The pivot count at bond k is capped by the local Hilbert-space size.
    max_pivots = cross.black_box.base ** (1 + min(k, cross.sites - (k + 2)))
    if len(cross.I_g[k]) >= max_pivots or len(cross.I_l[k + 1]) >= max_pivots:
        return 0
    # Seed the search with random (row, column) positions in the superblock.
    rows_random = cross_strategy.rng.integers(
        low=0,
        high=len(cross.I_l[k]) * len(cross.I_s[k]),
        size=cross_strategy.points_partial,
    )
    cols_random = cross_strategy.rng.integers(
        low=0,
        high=len(cross.I_s[k + 1]) * len(cross.I_g[k + 1]),
        size=cross_strategy.points_partial,
    )
    A_sample = cross.sample_superblock(k, j_l=rows_random, j_g=cols_random)
    B_sample = cross.sample_skeleton(k, j_l=rows_random, j_g=cols_random)
    i, j = np.unravel_index(
        np.argmax(np.abs(A_sample - B_sample)), A_sample.shape
    )
    j_l, j_g = rows_random[i], cols_random[j]
    for step in range(cross_strategy.maxiter_partial):
        # Traverse column residual
        c_A = cross.sample_superblock(k, j_g=j_g).reshape(-1)
        c_B = cross.sample_skeleton(k, j_g=j_g)
        best_row = np.argmax(np.abs(c_A - c_B))
        if step > 0 and best_row == j_l:
            break
        j_l = best_row
        # Traverse row residual
        r_A = cross.sample_superblock(k, j_l=j_l).reshape(-1)
        r_B = cross.sample_skeleton(k, j_l=j_l)
        best_col = np.argmax(np.abs(r_A - r_B))
        if best_col == j_g:
            break
        j_g = best_col
    pivot_error = np.abs(c_A[j_l] - c_B[j_l])
    if pivot_error >= cross_strategy.tol_pivot:
        cross.update_indices(k, j_l=j_l, j_g=j_g)
        cross.update_tensors(k, r=r_A, c=c_A)
    return pivot_error
def _check_local_convergence(
    pivot_errors: np.ndarray,
    cross_strategy: CrossStrategyGreedy,
    logger: Logger,
) -> bool:
    """Return True when the worst pivot error is below `tol_pivot`."""
    worst_error = np.max(pivot_errors)
    if logger:
        logger(f"Max. pivot error={worst_error}")
    if worst_error >= cross_strategy.tol_pivot:
        return False
    logger(f"State converged within tolerance {cross_strategy.tol_pivot}")
    return True