Source code for qmetro.qmtensor.operations

from __future__ import annotations

from collections.abc import Hashable
from typing import get_args

import numpy as np

from .classes import (
    ConstTensor, ParamTensor, VarTensor, Tensor, Scalar
)




[docs] def Tr(tensor: Tensor, *spaces: Hashable, full: bool=False ) -> Tensor | Scalar: """ For Choi-like tensor compute partial trace of the Choi matrix. Parameters ---------- spaces[0...*] : Hashable Spaces to be traced out. full : bool, optional If True then computes the trace over all spaces and returns a complex number, by default False. Returns ------- tensor : Tensor | complex Tensor of the result. """ return tensor.choi_trace(*spaces, full=full)
[docs] def kron(*args: Tensor | Scalar) -> Tensor: """ For Choi-like tensors it computes their Kronecker product and for other tensors it is just contraction where there are no doubled indices (no gets contracted). Parameters ---------- args[0...*] : Hashable Tensors to be Kronecker multiplied. Returns ------- tensor : Tensor Kronecker product. """ x = 1 for i, arg in enumerate(args): if isinstance(arg, (int, float, complex)): x *= arg elif isinstance(arg, get_args(Tensor)): return arg.kron(*args[i + 1:], x) else: raise ValueError(f'Unsupported type: {type(arg)}') raise ValueError('Empty list given.')
[docs] def contr(*args: Tensor | Scalar) -> Tensor: """ Contract tensors. Parameters ---------- others[0...*] : Tensor | Scalar Tensors to be contracted. Returns ------- tensor : Tensor Contraction result. """ x = 1 for i, arg in enumerate(args): if isinstance(arg, (int, float, complex)): x *= arg elif isinstance(arg, get_args(Tensor)): return arg.contr(*args[i + 1:], x) else: raise ValueError(f'Unsupported type: {arg}') raise ValueError('Empty list given.')
[docs] def is_scalar(x: Tensor | Scalar) -> bool: if isinstance(x, get_args(Scalar)): return True if isinstance(x, ParamTensor): return bool( len(x.spaces) == 0 and np.all(x.dtensor.choi([]) == 0) ) if isinstance(x, ConstTensor): return len(x.spaces) == 0 return False
[docs] def is_var(x : Tensor | Scalar) -> bool: return isinstance(x, VarTensor)
[docs] def is_cptp_var(x : Tensor | Scalar) -> bool: return is_var(x) and not x.bond_spaces and not x.is_measurement
[docs] def is_mps_var(x : Tensor | Scalar) -> bool: return bool(is_var(x) and x.bond_spaces and not x.is_measurement)
[docs] def is_comb_var(x : Tensor | Scalar) -> bool: return bool(is_var(x) and x.is_comb)
[docs] def is_mps(x : Tensor | Scalar) -> bool: return bool( is_mps_var(x) or ( isinstance(x, ConstTensor) and x.bond_spaces and not x.input_spaces ) )
[docs] def is_param(x : Tensor | Scalar) -> bool: return isinstance(x, ParamTensor)
[docs] def is_measurement(x: Tensor | Scalar) -> bool: return is_var(x) and x.is_measurement
[docs] def is_sld_mpo(x: Tensor | Scalar) -> bool: return bool(is_measurement(x) and x.bond_spaces)
[docs] def is_full_sld(x: Tensor | Scalar) -> bool: return bool(is_measurement(x) and not x.bond_spaces)
[docs] def is_const(x: Tensor) -> bool: return isinstance(x, ConstTensor)