Source code for seemps.cgs

from __future__ import annotations
from typing import Optional, Callable, Any, Union
from .state import (
    MPS,
    MPSSum,
    CanonicalMPS,
    DEFAULT_TOLERANCE,
    DEFAULT_STRATEGY,
    Strategy,
)
from .operators import MPO, MPOList, MPOSum
from .truncate import simplify
from .tools import make_logger


# TODO: Write tests for this
[docs] def cgs( A: Union[MPO, MPOList, MPOSum], b: Union[MPS, MPSSum], guess: Optional[MPS] = None, maxiter: int = 100, strategy: Strategy = DEFAULT_STRATEGY, tolerance: float = DEFAULT_TOLERANCE, callback: Optional[Callable[[MPS, float], Any]] = None, ) -> tuple[CanonicalMPS, float]: """Approximate solution of :math:`A \\psi = b`. Given the :class:`MPO` `A` and the :class:`MPS` `b`, use the conjugate gradient method to estimate another MPS that solves the linear system of equations :math:`A \\psi = b`. Parameters ---------- A : MPO | MPOList | MPOSum Matrix product state that will be inverted b : MPS | MPSSum Right-hand side of the equation maxiter : int, default = 100 Maximum number of iterations strategy : Strategy, default = DEFAULT_STRATEGY Truncation strategy for MPS and MPO operations tolerance : float, default = DEFAULT_TOLERANCE Error tolerance for the algorithm. Results ------- MPS Approximate solution to :math:`A ψ = b` float Norm square of the residual :math:`\\Vert{A \\psi - b}\\Vert^2` """ normb2 = b.norm_squared() if strategy.get_normalize_flag(): strategy = strategy.replace(normalize=False) x = simplify(b if guess is None else guess, strategy=strategy) r = b - A @ x p = simplify(r, strategy=strategy) ρ = r.norm_squared() with make_logger(2) as logger: logger(f"CGS algorithm for {maxiter} iterations") for i in range(maxiter): α = ρ / A.expectation(p).real x = simplify(MPSSum([1, α], [x, p]), strategy=strategy) r = b - A @ x ρ, ρold = r.norm_squared(), ρ if callback is not None: callback(x, ρ) if ρ < tolerance * normb2: logger( f"CGS converged with residual {ρ} below relative tolerance {tolerance}" ) break p = simplify(MPSSum([1.0, ρ / ρold], [r, p]), strategy=strategy) logger(f"CGS step {i:5}: |r|^2={ρ:5g} tol={tolerance:5g}") return x, abs(ρ)