# Source code for seemps.truncate.simplify

from __future__ import annotations
from typing import Optional, Union
from math import sqrt
import numpy as np
from ..tools import make_logger
from ..state import (
    DEFAULT_TOLERANCE,
    MAX_BOND_DIMENSION,
    MPS,
    CanonicalMPS,
    MPSSum,
    Simplification,
    Strategy,
    Truncation,
)
from ..typing import Weight
from .antilinear import AntilinearForm

# TODO: We have to rationalize all this about directions. The user should
# not really care about it and we can guess the direction from the canonical
# form of either the guess or the state.

# Default strategy for the simplification routines in this module: it is the
# default value of the `strategy` argument of `simplify`, `simplify_mps_sum`
# and `combine`.
#
# NOTE(review): presumably `max_sweeps` is the value returned by
# `strategy.get_max_sweeps()`, which bounds the number of sweeps of the
# variational loops, and `simplify` is the value returned by
# `strategy.get_simplification_method()`, which selects between the
# variational sweeps and the CANONICAL_FORM / DO_NOT_SIMPLIFY shortcuts --
# confirm against the `Strategy` class.
SIMPLIFICATION_STRATEGY = Strategy(
    method=Truncation.RELATIVE_NORM_SQUARED_ERROR,
    tolerance=DEFAULT_TOLERANCE,
    max_bond_dimension=MAX_BOND_DIMENSION,
    normalize=True,
    max_sweeps=4,
    simplify=Simplification.VARIATIONAL,
)


def simplify(
    state: Union[MPS, MPSSum],
    strategy: Strategy = SIMPLIFICATION_STRATEGY,
    direction: int = +1,
    guess: Optional[MPS] = None,
) -> CanonicalMPS:
    """Simplify an MPS state transforming it into another one with a smaller
    bond dimension, sweeping until convergence is achieved.

    When `state` is an `MPSSum` the work is delegated to `simplify_mps_sum`.
    Otherwise, depending on the simplification method of `strategy`, the
    state is either brought to canonical form (two passes for
    `CANONICAL_FORM`, a single pass for `DO_NOT_SIMPLIFY`) or optimized
    variationally with two-site sweeps that alternate direction. Sweeping
    stops after `max(1, strategy.get_max_sweeps())` sweeps, when the
    estimated error falls below the simplification tolerance, or when the
    error grows with respect to the previous sweep.

    Parameters
    ----------
    state : MPS | MPSSum
        State to approximate.
    strategy : Strategy
        Truncation strategy. Defaults to `SIMPLIFICATION_STRATEGY`.
    direction : { +1, -1 }
        Initial direction for the sweeping algorithm.
    guess : MPS
        A guess for the new state, to ease the optimization. Only used by
        the variational method.

    Returns
    -------
    CanonicalMPS
        Approximation :math:`\\xi` to the state. If `state.norm_squared()`
        is zero, the canonical form of `state.zero_state()` is returned.
    """
    if isinstance(state, MPSSum):
        return simplify_mps_sum(state, strategy, direction, guess)

    # Prepare initial guess
    normalize = strategy.get_normalize_flag()
    size = state.size
    start = 0 if direction > 0 else -1

    logger = make_logger(2)
    # The logger must be closed on *every* exit path, including the early
    # returns below and any exception raised while sweeping.
    try:
        # If we only do canonical forms, not variational optimization, a second
        # pass on that initial guess suffices
        if strategy.get_simplification_method() == Simplification.CANONICAL_FORM:
            mps = CanonicalMPS(state, center=start, strategy=strategy)
            mps = CanonicalMPS(mps, center=-1 - start, strategy=strategy)
            if logger:
                logger(
                    f"SIMPLIFY state with |state|={mps.norm():5e}\nusing two-pass "
                    f"canonical form, with tolerance {strategy.get_tolerance():5e}\n"
                    f"produces error {mps.error():5e}.\nStrategy: {strategy}",
                )
            return mps

        # TODO: DO_NOT_SIMPLIFY should do nothing. However, since the
        # output is expected to be a CanonicalMPS, we must use the
        # strategy to construct it.
        if strategy.get_simplification_method() == Simplification.DO_NOT_SIMPLIFY:
            mps = CanonicalMPS(state, center=-1 - start, strategy=strategy)
            if logger:
                logger(
                    f"SIMPLIFY state with |state|={mps.norm():5e}\nusing single-pass "
                    f"canonical form, with tolerance {strategy.get_tolerance():5e}\n"
                    f"produces error {mps.error():5e}.\nStrategy: {strategy}",
                )
            return mps

        mps = CanonicalMPS(
            state if guess is None else guess,
            center=start,
            normalize=False,
            strategy=strategy,
        )
        simplification_tolerance = strategy.get_simplification_tolerance()
        if not (norm_state_sqr := state.norm_squared()):
            return CanonicalMPS(state.zero_state(), is_canonical=True)
        form = AntilinearForm(mps, state, center=start)
        err = 2.0
        if logger:
            logger(
                f"SIMPLIFY state with |state|={norm_state_sqr**0.5} for "
                f"{strategy.get_max_sweeps()} sweeps, with tolerance {simplification_tolerance}.\nStrategy: {strategy}",
            )
        # At least one sweep is always performed, so that `last_tensor` and
        # `norm_mps_sqr` are defined after the loop.
        for sweep in range(max(1, strategy.get_max_sweeps())):
            if direction > 0:
                for n in range(0, size - 1):
                    mps.update_2site_right(form.tensor2site(direction), n, strategy)
                    form.update_right()
                last_tensor = mps[size - 1]
            else:
                for n in reversed(range(0, size - 1)):
                    mps.update_2site_left(form.tensor2site(direction), n, strategy)
                    form.update_left()
                last_tensor = mps[0]
            #
            # We estimate the error
            #
            norm_mps_sqr = np.vdot(last_tensor, last_tensor).real
            mps_state_scprod = np.vdot(last_tensor, form.tensor1site())
            old_err = err
            err = 2 * abs(
                1.0 - mps_state_scprod.real / sqrt(norm_mps_sqr * norm_state_sqr)
            )
            if logger:
                logger(
                    f"sweep={sweep}, rel.err.={err:6g}, old err.={old_err:6g}, "
                    f"|mps|={norm_mps_sqr**0.5:6g}, tol={simplification_tolerance:6g}",
                )
            if err < simplification_tolerance or err > old_err:
                logger("Stopping, as tolerance reached")
                break
            # Next sweep runs in the opposite direction
            direction = -direction

        total_error_bound = state._error + sqrt(err)
        if normalize and norm_mps_sqr:
            factor = sqrt(norm_mps_sqr)
            # In-place division: rescales the tensor stored inside `mps`
            last_tensor /= factor
            total_error_bound /= factor
        mps._error = total_error_bound
        return mps
    finally:
        logger.close()
# TODO: We have to rationalize all this about directions. The user should
# not really care about it and we can guess the direction from the canonical
# form of either the guess or the state.
def simplify_mps_sum(
    sum_state: MPSSum,
    strategy: Strategy = SIMPLIFICATION_STRATEGY,
    direction: int = +1,
    guess: Optional[MPS] = None,
) -> CanonicalMPS:
    """Approximate a linear combination of MPS :math:`\\sum_i w_i \\psi_i` by
    another one with a smaller bond dimension, sweeping until convergence is
    achieved.

    Zero components are first removed with
    `sum_state.delete_zero_components()`. Depending on the simplification
    method of `strategy`, the joined state is then either brought to
    canonical form (two passes for `CANONICAL_FORM`, a single pass for
    `DO_NOT_SIMPLIFY`) or optimized variationally with two-site sweeps that
    alternate direction. Sweeping stops after
    `max(1, strategy.get_max_sweeps())` sweeps, when the estimated error
    falls below the simplification tolerance, or when the error grows with
    respect to the previous sweep.

    Parameters
    ----------
    sum_state : MPSSum
        State to approximate
    strategy : Strategy
        Truncation strategy. Defaults to `SIMPLIFICATION_STRATEGY`.
    direction : {+1, -1}
        Initial direction for the sweeping algorithm.
    guess : MPS, optional
        Initial guess for the iterative algorithm. Only used by the
        variational method.

    Returns
    -------
    CanonicalMPS
        Approximation to the linear combination in canonical form. If the
        value returned by `delete_zero_components()` is zero, a zero state
        is returned instead.
    """
    # Compute norm of output and eliminate zero states
    norm_state_sqr = sum_state.delete_zero_components()
    logger = make_logger(2)
    # The logger must be closed on *every* exit path, including the early
    # returns below and any exception raised while sweeping.
    try:
        if not norm_state_sqr:
            if logger:
                logger("COMBINE state with |state|=0. Returning zero state.")
            return CanonicalMPS(sum_state.states[0].zero_state(), is_canonical=True)

        normalize = strategy.get_normalize_flag()
        start = 0 if direction > 0 else -1

        # CANONICAL_FORM implements a simplification based on two passes
        if strategy.get_simplification_method() == Simplification.CANONICAL_FORM:
            mps = CanonicalMPS(sum_state.join(), center=start, strategy=strategy)
            mps = CanonicalMPS(mps, center=-1 - start, strategy=strategy)
            if logger:
                logger(
                    f"COMBINE state with |state|={mps.norm():5e}\nusing two-pass "
                    f"canonical form, with tolerance {strategy.get_tolerance():5e}\n"
                    f"produces error {mps.error():5e}.\nStrategy: {strategy}",
                )
            return mps

        # TODO: DO_NOT_SIMPLIFY should do nothing. However, since the
        # output is expected to be a CanonicalMPS, we must use the
        # strategy to construct it.
        if strategy.get_simplification_method() == Simplification.DO_NOT_SIMPLIFY:
            mps = CanonicalMPS(sum_state.join(), center=-1 - start, strategy=strategy)
            if logger:
                logger(
                    f"COMBINE state with |state|={mps.norm():5e}\nusing single-pass "
                    f"canonical form, with tolerance {strategy.get_tolerance():5e}\n"
                    f"produces error {mps.error():5e}.\nStrategy: {strategy}",
                )
            return mps

        # Prepare initial guess
        mps = CanonicalMPS(
            sum_state.join() if guess is None else guess,
            center=start,
            normalize=False,
            strategy=strategy,
        )
        simplification_tolerance = strategy.get_simplification_tolerance()
        size = mps.size
        weights, states = sum_state.weights, sum_state.states
        # One antilinear form per component of the sum
        forms = [AntilinearForm(mps, si, center=start) for si in states]
        if logger:
            logger(
                f"COMBINE state with |state|={norm_state_sqr**0.5:5e} for "
                f"{strategy.get_max_sweeps()} sweeps with tolerance "
                f"{simplification_tolerance:5e}.\nStrategy: {strategy}"
                f"\nWeights: {weights}",
            )
        err = 2.0
        # At least one sweep is always performed, so that `last_tensor` and
        # `norm_mps_sqr` are defined after the loop.
        for sweep in range(max(1, strategy.get_max_sweeps())):
            if direction > 0:
                for n in range(0, size - 1):
                    mps.update_2site_right(
                        sum(w * f.tensor2site(direction) for w, f in zip(weights, forms)),  # type: ignore
                        n,
                        strategy,
                    )
                    for f in forms:
                        f.update_right()
                last_tensor = mps[size - 1]
            else:
                for n in reversed(range(0, size - 1)):
                    mps.update_2site_left(
                        sum(w * f.tensor2site(direction) for w, f in zip(weights, forms)),  # type: ignore
                        n,
                        strategy,
                    )
                    for f in forms:
                        f.update_left()
                last_tensor = mps[0]
            #
            # We estimate the error
            #
            norm_mps_sqr = np.vdot(last_tensor, last_tensor).real
            mps_state_scprod = np.vdot(
                last_tensor,
                sum(w * f.tensor1site() for w, f in zip(weights, forms)),
            )
            old_err = err
            err = 2 * abs(
                1.0 - mps_state_scprod.real / sqrt(norm_mps_sqr * norm_state_sqr)
            )
            if logger:
                logger(
                    f"sweep={sweep}, rel.err.={err:6g}, old err.={old_err:6g}, "
                    f"|mps|={norm_mps_sqr**0.5:6g}, tol={simplification_tolerance:6g}",
                )
            if err < simplification_tolerance or err > old_err:
                logger("Stopping, as tolerance reached")
                break
            # Next sweep runs in the opposite direction
            direction = -direction

        total_error_bound = sum_state.error() + sqrt(err)
        if normalize and norm_mps_sqr:
            factor = sqrt(norm_mps_sqr)
            # In-place division: rescales the tensor stored inside `mps`
            last_tensor /= factor
            total_error_bound /= factor
        mps._error = total_error_bound
        return mps
    finally:
        logger.close()


def combine(
    weights: list[Weight],
    states: list[MPS],
    strategy: Strategy = SIMPLIFICATION_STRATEGY,
    direction: int = +1,
    guess: Optional[MPS] = None,
) -> CanonicalMPS:
    """Deprecated, use `simplify` instead.

    Builds `MPSSum(weights, states)` and forwards it, together with
    `strategy`, `direction` and `guess`, to `simplify_mps_sum`.
    """
    return simplify_mps_sum(MPSSum(weights, states), strategy, direction, guess)


__all__ = ["simplify"]