from __future__ import annotations
import numpy as np
from math import sqrt
from typing import Union, Iterable
from ..tools import InvalidOperation
from ..typing import Weight, Vector, Tensor3
from .environments import scprod
[docs]
class MPSSum:
"""Class representing a weighted sum (or difference) of two or more :class:`MPS`.
This class is an intermediate representation for the linear combination of
MPS quantum states. Assume that :math:`\\psi, \\phi` and :math:`\\xi` are
MPS and :math:`a, b, c` some real or complex numbers. The addition
:math:`a \\psi - b \\phi + c \\xi` can be stored as
`MPSSum([a, -b, c], [ψ, ϕ, ξ])`.
Parameters
----------
weights : list[Weight]
Real or complex numbers representing the weights of the linear combination.
states : list[MPS]
List of matrix product states weighted.
"""
weights: list[Weight]
states: list[MPS]
size: int
#
# This class contains all the matrices and vectors that form
# a Matrix-Product State.
#
__array_priority__ = 10000
def __init__(
self,
weights: Iterable[Weight],
states: Iterable[Union[MPS, MPSSum]],
check_args: bool = True,
):
if check_args:
self.weights = new_weights = []
self.states = new_states = []
for w, s in zip(weights, states):
if isinstance(s, MPS):
new_weights.append(w)
new_states.append(s)
elif isinstance(s, MPSSum):
new_weights.extend(w * wi for wi in s.weights)
new_states.extend(s.states)
else:
raise ValueError(s)
self.size = new_states[0].size
else:
self.weights = weights # type: ignore
self.states = states # type: ignore
self.size = states[0].size # type: ignore
[docs]
def as_mps(self) -> MPS:
return self.join()
[docs]
def copy(self) -> MPSSum:
"""Return a shallow copy of the MPS sum and its data. Does not copy
the states, only the list that stores them."""
return MPSSum(self.weights.copy(), self.states.copy())
def __copy__(self) -> MPSSum:
return self.copy()
def __add__(self, state: Union[MPS, MPSSum]) -> MPSSum:
"""Add `self + state`, incorporating it to the lists."""
match state:
case MPS():
return MPSSum(
self.weights + [1.0], self.states + [state], check_args=False
)
case MPSSum(weights=w, states=s):
return MPSSum(self.weights + w, self.states + s, check_args=False)
case _:
raise InvalidOperation("+", self, state)
def __sub__(self, state: Union[MPS, MPSSum]) -> MPSSum:
"""Subtract `self - state`, incorporating it to the lists."""
match state:
case MPS():
return MPSSum(
self.weights + [-1], self.states + [state], check_args=False
)
case MPSSum(weights=w, states=s):
return MPSSum(
self.weights + [-wi for wi in w], self.states + s, check_args=False
)
case _:
raise InvalidOperation("-", self, state)
def __mul__(self, n: Weight) -> MPSSum:
"""Rescale the linear combination `n * self` for scalar `n`."""
if isinstance(n, (int, float, complex)):
return MPSSum([n * w for w in self.weights], self.states, check_args=False)
raise InvalidOperation("*", self, n)
def __rmul__(self, n: Weight) -> MPSSum:
"""Rescale the linear combination `self * n` for scalar `n`."""
if isinstance(n, (int, float, complex)):
return MPSSum([n * w for w in self.weights], self.states, check_args=False)
raise InvalidOperation("*", n, self)
[docs]
def to_vector(self) -> Vector:
"""Return the wavefunction of this quantum state."""
return sum(wa * A.to_vector() for wa, A in zip(self.weights, self.states)) # type: ignore
def _joined_tensors(self, i: int, L: int) -> Tensor3:
"""Join the tensors from all MPS into bigger tensors."""
As: list[Tensor3] = [s[i] for s in self.states]
if i == 0:
return np.concatenate([w * A for w, A in zip(self.weights, As)], axis=2)
if i == L - 1:
return np.concatenate(As, axis=0)
DL: int = 0
DR: int = 0
d: int
w: Weight = 0
for A in As:
a, d, b = A.shape
DL += a
DR += b
w += A[0, 0, 0]
B = np.zeros((DL, d, DR), dtype=type(w))
DL = 0
DR = 0
for A in As:
a, d, b = A.shape
B[DL : DL + a, :, DR : DR + b] = A
DL += a
DR += b
return B
[docs]
def join(self) -> MPS:
"""Create an `MPS` by combining all tensors from all states in
the linear combination.
Returns
-------
MPS
Quantum state approximating this sum.
"""
L = self.size
return MPS([self._joined_tensors(i, L) for i in range(L)])
[docs]
def join_canonical(self, *args, **kwdargs) -> CanonicalMPS:
"""Similar to join() but return canonical form"""
return CanonicalMPS(self.join(), *args, **kwdargs)
[docs]
def conj(self) -> MPSSum:
"""Return the complex-conjugate of this quantum state."""
return MPSSum(
[np.conj(w) for w in self.weights], [state.conj() for state in self.states]
)
[docs]
def norm_squared(self) -> float:
"""Norm-2 squared :math:`\\Vert{\\psi}\\Vert^2` of this MPS."""
w = self.weights
s = self.states
L = len(w)
return abs(
sum(
(w[i].conjugate() * w[j] * scprod(s[i], s[j])).real
* (1 if i == j else 2)
for i in range(L)
for j in range(i, L)
)
)
[docs]
def norm(self) -> float:
"""Norm-2 :math:`\\Vert{\\psi}\\Vert^2` of this MPS."""
return sqrt(self.norm_squared())
[docs]
def error(self) -> float:
"""Upper bound of the norm-2 error accumulated in this MPS."""
return sum(
abs(weight) * state._error
for weight, state in zip(self.weights, self.states)
)
[docs]
def delete_zero_components(self) -> float:
"""Compute the norm-squared of the linear combination of weights and
states and eliminate states that are zero or have zero weight."""
c: float = 0.0
final_weights: list[Weight] = []
final_states: list[MPS] = []
for wi, si in zip(self.weights, self.states):
wic = wi.conjugate()
ni = (wic * wi).real * si.norm_squared()
if ni:
for wj, sj in zip(final_weights, final_states):
c += 2 * (wic * wj * scprod(si, sj)).real
final_states.append(si)
final_weights.append(wi)
c += ni
L = len(final_weights)
if L < len(self.states):
if L == 0:
self.weights = [0.0]
self.states = [self.states[0].zero_state()]
return 0.0
else:
self.weights = final_weights
self.states = final_states
return abs(c)
from .canonical_mps import CanonicalMPS # noqa: E402
from .mps import MPS # noqa: E402