diff --git a/grassmann_tensor/__init__.py b/grassmann_tensor/__init__.py index a6dde9b..f2dcb4a 100644 --- a/grassmann_tensor/__init__.py +++ b/grassmann_tensor/__init__.py @@ -2,7 +2,7 @@ A Grassmann algebra tensor package. """ -__all__ = ["__version__", "GrassmannTensor"] +__all__ = ["__version__", "GrassmannTensor", "NamedGrassmannTensor"] from .version import __version__ -from .tensor import GrassmannTensor +from .tensor import GrassmannTensor, NamedGrassmannTensor diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 30269c1..944a31a 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -4,12 +4,13 @@ from __future__ import annotations -__all__ = ["GrassmannTensor"] +__all__ = ["GrassmannTensor", "NamedGrassmannTensor"] import dataclasses import functools import typing import math +import operator import torch @@ -718,10 +719,6 @@ def svd( f"Cutoff must be an integer or a tuple of two integers, but got {cutoff}" ) - assert (k_even > 0 or n_even == 0) and (k_odd > 0 or n_odd == 0), ( - "Per-block cutoff must be compatible with available singulars" - ) - keep_even = torch.zeros(n_even, dtype=torch.bool, device=S_even.device) keep_odd = torch.zeros(n_odd, dtype=torch.bool, device=S_odd.device) if k_even > 0: @@ -1354,3 +1351,735 @@ def __copy__(self) -> GrassmannTensor: def __deepcopy__(self, memo: dict) -> GrassmannTensor: return self.clone() + + +@dataclasses.dataclass +class NamedGrassmannTensor: + _names: tuple[str, ...] + _arrow: tuple[bool, ...] + _edges: tuple[tuple[int, int], ...] + _tensor: torch.Tensor + _parity: tuple[torch.Tensor, ...] | None = None + _mask: torch.Tensor | None = None + + _name_dict: dict[str, int] = dataclasses.field(init=False, repr=False) + _gt: GrassmannTensor = dataclasses.field(init=False, repr=False) + + @property + def names(self) -> tuple[str, ...]: + """ + The names of the tensor, represented as a tuple of str. + """ + return self._names + + @property + def arrow(self) -> tuple[bool, ...]: + return self._arrow + + @property + def edges(self) -> tuple[tuple[int, int], ...]: + return self._edges + + @property + def tensor(self) -> torch.Tensor: + return self._tensor + + @property + def gt(self) -> GrassmannTensor: + return self._gt + + @property + def parity(self) -> tuple[torch.Tensor, ...]: + if self._parity is None: + self._parity = self.gt.parity + return self._parity + + @property + def mask(self) -> torch.Tensor: + if self._mask is None: + self._mask = self.gt.mask + return self._mask + + def update_mask(self) -> NamedGrassmannTensor: + tensor = self.gt.update_mask() + return dataclasses.replace( + self, + _tensor=tensor.tensor, + ) + + def rename(self, name_map: dict[str, str]) -> NamedGrassmannTensor: + if not name_map: + return self + + names = tuple(name_map.get(name, name) for name in self.names) + + if len(set(names)) != len(names): + raise ValueError(f"Duplicate names after rename: {names}") + + return dataclasses.replace(self, _names=names) + + def __post_init__(self) -> None: + assert len(self._names) == len(set(self._names)), ( + f"Names must be unique, but got {self._names}" + ) + assert len(self._names) == self._tensor.dim(), ( + f"Names length ({len(self._names)}) must match tensor dimensions ({self._tensor.dim()})." + ) + object.__setattr__(self, "_name_dict", {name: i for i, name in enumerate(self._names)}) + gt = GrassmannTensor( + _arrow=self._arrow, + _edges=self._edges, + _tensor=self._tensor, + _parity=self._parity, + _mask=self._mask, + ) + object.__setattr__(self, "_gt", gt) + + def to( + self, + whatever: torch.device | torch.dtype | str | None = None, + *, + device: torch.device | None = None, + dtype: torch.dtype | None = None, + ) -> NamedGrassmannTensor: + tensor = self.gt.to(whatever, device=device, dtype=dtype) + return dataclasses.replace( + self, + _tensor=tensor.tensor, + _parity=None, + _mask=None, + ) + + def get_name_index(self, name: str) -> int: + try: + return self._name_dict[name] + except KeyError: + raise KeyError(f"{name!r} not in names list {self._names!r}") from None + + def permute(self, before_by_after: tuple[str, ...]) -> NamedGrassmannTensor: + order = tuple(self.get_name_index(name) for name in before_by_after) + tensor = self.gt.permute(order) + return dataclasses.replace( + self, + _names=before_by_after, + _arrow=tensor.arrow, + _edges=tensor.edges, + _tensor=tensor.tensor, + _parity=None, + _mask=None, + ) + + def reverse(self, reversed_names: set[str]) -> NamedGrassmannTensor: + assert len(reversed_names) == len(set(reversed_names)), ( + f"Indices must be unique, but got {reversed_names}" + ) + indices = tuple(self.get_name_index(name) for name in reversed_names) + tensor = self.gt.reverse(indices) + return dataclasses.replace( + self, + _arrow=tensor.arrow, + _edges=tensor.edges, + _tensor=tensor.tensor, + ) + + def _merge_edge_get_names(self, merge_map: dict[str, tuple[str, ...]]) -> tuple[str, ...]: + reserved_names: list[str] = [] + for name in self.names: + found = next( + ( + (new_name, old_names) + for new_name, old_names in merge_map.items() + if name in old_names + ), + None, + ) + if found is None: + reserved_names.append(name) + else: + new_name, old_names = found + if name == old_names[0]: + reserved_names.append(new_name) + return tuple(reserved_names) + + @staticmethod + def _merge_edge_get_name_group( + name: str, merge_map: dict[str, tuple[str, ...]] + ) -> tuple[str, ...]: + merge_group = merge_map.get(name, None) + return (name,) if merge_group is None else merge_group + + def merge_edge( + self, + merge_map: dict[str, tuple[str, ...]], + ) -> NamedGrassmannTensor: + all_old_names = [old_name for group in merge_map.values() for old_name in group] + assert len(all_old_names) == len(set(all_old_names)), ( + f"Names must be unique, but got {all_old_names}" + ) + assert all(len(old_names) > 0 for old_names in merge_map.values()), ( + "Merge edge does not support empty old_names." + ) + assert all( + all(old_name in self.names for old_name in old_names) + for old_names in merge_map.values() + ), f"Old names must be in names list, but got {merge_map.values()}" + + merge_map = { + new_name: tuple(sorted(group, key=self.get_name_index)) + for new_name, group in merge_map.items() + } + + names = self._merge_edge_get_names(merge_map) + + permuted_names: list[str] = functools.reduce( + operator.add, + (list(self._merge_edge_get_name_group(name, merge_map)) for name in names), + [], + ) + + permuted_tensor = self.permute(tuple(permuted_names)) + + new_edges: list[tuple[int, int]] = [] + + for new_name in names: + if new_name in merge_map: + old_names = merge_map[new_name] + merge_edges = tuple( + permuted_tensor.edges[permuted_tensor.get_name_index(old_name)] + for old_name in old_names + ) + even, odd = permuted_tensor.gt.calculate_even_odd(merge_edges) + new_edges.append((even, odd)) + else: + index = permuted_tensor.get_name_index(new_name) + new_edges.append(permuted_tensor.edges[index]) + + merged_tensor = permuted_tensor.gt.reshape(tuple(new_edges)) + + return dataclasses.replace( + self, + _names=names, + _arrow=merged_tensor.arrow, + _edges=merged_tensor.edges, + _tensor=merged_tensor.tensor, + _parity=None, + _mask=None, + ) + + def to_scalar(self) -> NamedGrassmannTensor: + tensor = self.gt.reshape(()) + return dataclasses.replace( + self, _names=(), _arrow=(), _edges=(), _tensor=tensor.tensor, _parity=None, _mask=None + ) + + @staticmethod + def _split_edge_get_name_group( + name: str, + split_map: dict[str, tuple[tuple[str, tuple[int, int]], ...]], + ) -> list[str]: + split_group = split_map.get(name, None) + return [name] if split_group is None else [new_name for new_name, _ in split_group] + + @staticmethod + def _split_edge_get_edge_group( + name: str, + edge: tuple[int, int], + split_map: dict[str, tuple[tuple[str, tuple[int, int]], ...]], + ) -> list[tuple[int, int]]: + split_group = split_map.get(name, None) + return [edge] if split_group is None else [new_edge for _, new_edge in split_group] + + def split_edge( + self, split_map: dict[str, tuple[tuple[str, tuple[int, int]], ...]] + ) -> NamedGrassmannTensor: + new_names: tuple[str, ...] + new_edges: tuple[tuple[int, int], ...] + if len(self.names) == 0: + assert set(split_map.keys()) == {""}, ( + "For scalar tensor, split_map must have only key ''." + ) + new_group = split_map[""] + new_names = tuple(new_name for new_name, _ in new_group) + new_edges = tuple(new_edge for _, new_edge in new_group) + + split_tensor = self.gt.reshape(new_edges) + return dataclasses.replace( + self, + _names=new_names, + _arrow=split_tensor.arrow, + _edges=split_tensor.edges, + _tensor=split_tensor.tensor, + _parity=None, + _mask=None, + ) + + assert all(old_name in self.names for old_name in split_map.keys()), ( + f"Old name must be in names {self.names}" + ) + + new_names = tuple( + functools.reduce( + operator.add, + (self._split_edge_get_name_group(name, split_map) for name in self.names), + [], + ) + ) + + new_edges = tuple( + functools.reduce( + operator.add, + ( + self._split_edge_get_edge_group(name, edge, split_map) + for name, edge in zip(self.names, self.edges) + ), + [], + ) + ) + + split_tensor = self.gt.reshape(new_edges) + + return dataclasses.replace( + self, + _names=tuple(new_names), + _arrow=split_tensor.arrow, + _edges=split_tensor.edges, + _tensor=split_tensor.tensor, + _parity=None, + _mask=None, + ) + + def matmul(self, other: NamedGrassmannTensor) -> NamedGrassmannTensor: + tensor_a = self + tensor_b = other + + vector_a = tensor_a.tensor.dim() == 1 + vector_b = tensor_b.tensor.dim() == 1 + + names: list[str] = [] + for i in range(-max(max(tensor_a.tensor.dim(), 2), max(tensor_b.tensor.dim(), 2)), -2): + candidate_a = candidate_b = 1 + name_a = name_b = None + if i >= -tensor_a.tensor.dim(): + candidate_a, _ = tensor_a.edges[i] + name_a = tensor_a.names[i] + if i >= -tensor_b.tensor.dim(): + candidate_b, _ = tensor_b.edges[i] + name_b = tensor_b.names[i] + if candidate_a >= candidate_b: + picked_name = name_a if name_a is not None else name_b + else: + picked_name = name_b if name_b is not None else name_a + names.append(typing.cast(str, picked_name)) + + if not vector_a: + names.append(tensor_a.names[-2]) + if not vector_b: + names.append(tensor_b.names[-1]) + + tensor = tensor_a.gt @ tensor_b.gt + + return NamedGrassmannTensor( + _names=tuple(names), + _arrow=tensor.arrow, + _edges=tensor.edges, + _tensor=tensor.tensor, + _parity=None, + _mask=None, + ) + + def conjugate(self) -> NamedGrassmannTensor: + tensor = self.gt.conj() + return dataclasses.replace( + self, + _arrow=tensor.arrow, + _edges=tensor.edges, + _tensor=tensor.tensor, + _parity=None, + _mask=None, + ) + + def conj(self) -> NamedGrassmannTensor: + return self.conjugate() + + def _get_left_right_indices( + self, pairs: set[tuple[str, str]] + ) -> tuple[tuple[str, ...], tuple[int, ...], tuple[int, ...]]: + pairs_map = {set_0: set_1 for set_0, set_1 in pairs} + left_set = set(pairs_map) + right_set = set(pairs_map.values()) + + are_disjoint = left_set.isdisjoint(right_set) + is_complete_union = (left_set | right_set) == set(self.names) + no_duplicates = len(left_set) + len(right_set) == len(self.names) + + assert are_disjoint and is_complete_union and no_duplicates, ( + f"Input pairs must cover all dimension and disjoint, but got {pairs_map}" + ) + + left_names = tuple(name for name in self.names if name in left_set) + right_names = tuple(pairs_map[name] for name in left_names) + + names = left_names + right_names + + left_idx = tuple(self.get_name_index(name) for name in left_names) + right_idx = tuple(self.get_name_index(name) for name in right_names) + + return names, left_idx, right_idx + + def exponential(self, pairs: set[tuple[str, str]]) -> NamedGrassmannTensor: + names, left_idx, right_idx = self._get_left_right_indices(pairs) + + exp = self.gt.exponential((left_idx, right_idx)) + + return dataclasses.replace( + self, + _names=names, + _arrow=exp.arrow, + _edges=exp.edges, + _tensor=exp.tensor, + ) + + def identity(self, pairs: set[tuple[str, str]]) -> NamedGrassmannTensor: + names, left_idx, right_idx = self._get_left_right_indices(pairs) + + identity = self.gt.identity((left_idx, right_idx)) + + return dataclasses.replace( + self, + _names=names, + _arrow=identity.arrow, + _edges=identity.edges, + _tensor=identity.tensor, + ) + + def svd( + self, + free_names_u: set[str], + common_name_u: str, + common_name_v: str, + singular_name_u: str, + singular_name_v: str, + *, + cutoff: int | None | tuple[int, int] = None, + ) -> tuple[NamedGrassmannTensor, NamedGrassmannTensor, NamedGrassmannTensor]: + free_names_u_indices = tuple(self.get_name_index(name) for name in free_names_u) + u, s, vh = self.gt.svd(free_names_u_indices, cutoff=cutoff) + + left_names = tuple(self.names[i] for i in free_names_u_indices) + right_names = tuple( + name for i, name in enumerate(self.names) if i not in set(free_names_u_indices) + ) + + U = NamedGrassmannTensor( + _names=left_names + (common_name_u,), + _arrow=u.arrow, + _edges=u.edges, + _tensor=u.tensor, + ) + S = NamedGrassmannTensor( + _names=(singular_name_u, singular_name_v), + _arrow=s.arrow, + _edges=s.edges, + _tensor=s.tensor, + ) + Vh = NamedGrassmannTensor( + _names=(common_name_v,) + right_names, + _arrow=vh.arrow, + _edges=vh.edges, + _tensor=vh.tensor, + ) + + return U, S, Vh + + def contract( + self, + other: NamedGrassmannTensor, + contract_pairs: set[tuple[str, str]], + ) -> NamedGrassmannTensor: + assert contract_pairs, "contract_pairs must be non-empty" + + for pair in contract_pairs: + assert ( + isinstance(pair, tuple) and len(pair) == 2 and all(isinstance(x, str) for x in pair) + ), f"Each contract pair must be (str, str), got: {pair!r}" + + names_a = [a for a, _ in contract_pairs] + names_b = [b for _, b in contract_pairs] + assert len(names_a) == len(set(names_a)), f"Duplicate names on A side: {names_a}" + assert len(names_b) == len(set(names_b)), f"Duplicate names on B side: {names_b}" + + assert all(a_name in self.names for a_name in names_a), ( + "Some names of self side not in name list." + ) + assert all(b_name in other.names for b_name in names_b), ( + "Some names of other side not in name list." + ) + + name_set_a = set(names_a) + name_set_b = set(names_b) + + if self.tensor.numel() >= other.tensor.numel(): + dict_map = {a: b for a, b in contract_pairs} + ordered_a = [name for name in self.names if name in name_set_a] + ordered_b = [dict_map[a] for a in ordered_a] + else: + dict_map = {b: a for a, b in contract_pairs} + ordered_b = [name for name in other.names if name in name_set_b] + ordered_a = [dict_map[b] for b in ordered_b] + + assert all( + self.edges[self.get_name_index(name_a)] == other.edges[other.get_name_index(name_b)] + for name_a, name_b in zip(ordered_a, ordered_b) + ), "Contract edges must be same." + + leg_a = tuple(self.get_name_index(name) for name in ordered_a) + leg_b = tuple(other.get_name_index(name) for name in ordered_b) + + c = self.gt.contract(other.gt, leg_a, leg_b) + + contract_set_a = set(ordered_a) + contract_set_b = set(ordered_b) + names = tuple(name for name in self.names if name not in contract_set_a) + tuple( + name for name in other.names if name not in contract_set_b + ) + + return NamedGrassmannTensor( + _names=names, + _arrow=c.arrow, + _edges=c.edges, + _tensor=c.tensor, + _parity=None, + _mask=None, + ) + + def reciprocal(self) -> NamedGrassmannTensor: + return dataclasses.replace( + self, _tensor=torch.where(self.tensor == 0, self.tensor, 1 / self.tensor) + ) + + def _validate_edge_compatibility(self, other: NamedGrassmannTensor) -> None: + assert self._names == other.names, ( + f"Names must match for arithmetic operations. Got {self._names} and {other.names}." + ) + assert self._arrow == other.arrow, ( + f"Arrows must match for arithmetic operations. Got {self._arrow} and {other.arrow}." + ) + assert self._edges == other.edges, ( + f"Edges must match for arithmetic operations. Got {self._edges} and {other.edges}." + ) + + def __pos__(self) -> NamedGrassmannTensor: + return dataclasses.replace( + self, + _tensor=+self._tensor, + ) + + def __neg__(self) -> NamedGrassmannTensor: + return dataclasses.replace( + self, + _tensor=-self._tensor, + ) + + def __add__(self, other: typing.Any) -> NamedGrassmannTensor: + if isinstance(other, NamedGrassmannTensor): + self._validate_edge_compatibility(other) + return dataclasses.replace( + self, + _tensor=self._tensor + other._tensor, + ) + try: + result = self._tensor + other + except TypeError: + return NotImplemented + if isinstance(result, torch.Tensor): + return dataclasses.replace( + self, + _tensor=result, + ) + return NotImplemented + + def __radd__(self, other: typing.Any) -> NamedGrassmannTensor: + try: + result = other + self._tensor + except TypeError: + return NotImplemented + if isinstance(result, torch.Tensor): + return dataclasses.replace( + self, + _tensor=result, + ) + return NotImplemented + + def __iadd__(self, other: typing.Any) -> NamedGrassmannTensor: + if isinstance(other, NamedGrassmannTensor): + self._validate_edge_compatibility(other) + self._tensor += other._tensor + return self + try: + self._tensor += other + except TypeError: + return NotImplemented + if isinstance(self._tensor, torch.Tensor): + return self + return NotImplemented + + def __sub__(self, other: typing.Any) -> NamedGrassmannTensor: + if isinstance(other, NamedGrassmannTensor): + self._validate_edge_compatibility(other) + return dataclasses.replace( + self, + _tensor=self._tensor - other._tensor, + ) + try: + result = self._tensor - other + except TypeError: + return NotImplemented + if isinstance(result, torch.Tensor): + return dataclasses.replace( + self, + _tensor=result, + ) + return NotImplemented + + def __rsub__(self, other: typing.Any) -> NamedGrassmannTensor: + try: + result = other - self._tensor + except TypeError: + return NotImplemented + if isinstance(result, torch.Tensor): + return dataclasses.replace( + self, + _tensor=result, + ) + return NotImplemented + + def __isub__(self, other: typing.Any) -> NamedGrassmannTensor: + if isinstance(other, NamedGrassmannTensor): + self._validate_edge_compatibility(other) + self._tensor -= other._tensor + return self + try: + self._tensor -= other + except TypeError: + return NotImplemented + if isinstance(self._tensor, torch.Tensor): + return self + return NotImplemented + + def __mul__(self, other: typing.Any) -> NamedGrassmannTensor: + if isinstance(other, NamedGrassmannTensor): + self._validate_edge_compatibility(other) + return dataclasses.replace( + self, + _tensor=self._tensor * other._tensor, + ) + try: + result = self._tensor * other + except TypeError: + return NotImplemented + if isinstance(result, torch.Tensor): + return dataclasses.replace( + self, + _tensor=result, + ) + return NotImplemented + + def __rmul__(self, other: typing.Any) -> NamedGrassmannTensor: + try: + result = other * self._tensor + except TypeError: + return NotImplemented + if isinstance(result, torch.Tensor): + return dataclasses.replace( + self, + _tensor=result, + ) + return NotImplemented + + def __imul__(self, other: typing.Any) -> NamedGrassmannTensor: + if isinstance(other, NamedGrassmannTensor): + self._validate_edge_compatibility(other) + self._tensor *= other._tensor + return self + try: + self._tensor *= other + except TypeError: + return NotImplemented + if isinstance(self._tensor, torch.Tensor): + return self + return NotImplemented + + def __truediv__(self, other: typing.Any) -> NamedGrassmannTensor: + if isinstance(other, NamedGrassmannTensor): + self._validate_edge_compatibility(other) + return dataclasses.replace( + self, + _tensor=self._tensor / other._tensor, + ) + try: + result = self._tensor / other + except TypeError: + return NotImplemented + if isinstance(result, torch.Tensor): + return dataclasses.replace( + self, + _tensor=result, + ) + return NotImplemented + + def __rtruediv__(self, other: typing.Any) -> NamedGrassmannTensor: + try: + result = other / self._tensor + except TypeError: + return NotImplemented + if isinstance(result, torch.Tensor): + return dataclasses.replace( + self, + _tensor=result, + ) + return NotImplemented + + def __itruediv__(self, other: typing.Any) -> NamedGrassmannTensor: + if isinstance(other, NamedGrassmannTensor): + self._validate_edge_compatibility(other) + self._tensor /= other._tensor + return self + try: + self._tensor /= other + except TypeError: + return NotImplemented + if isinstance(self._tensor, torch.Tensor): + return self + return NotImplemented + + def __matmul__(self, other: typing.Any) -> NamedGrassmannTensor: + if isinstance(other, NamedGrassmannTensor): + return self.matmul(other) + return NotImplemented + + def __rmatmul__(self, other: typing.Any) -> NamedGrassmannTensor: + return NotImplemented + + def __imatmul__(self, other: typing.Any) -> NamedGrassmannTensor: + if isinstance(other, NamedGrassmannTensor): + return self.matmul(other) + return NotImplemented + + def clone(self) -> NamedGrassmannTensor: + """ + Create a deep copy of the Grassmann tensor. + """ + return dataclasses.replace( + self, + _tensor=self._tensor.clone(), + _parity=tuple(parity.clone() for parity in self._parity) + if self._parity is not None + else None, + _mask=self._mask.clone() if self._mask is not None else None, + ) + + def __copy__(self) -> NamedGrassmannTensor: + return self.clone() + + def __deepcopy__(self, memo: dict) -> NamedGrassmannTensor: + return self.clone() diff --git a/tests/arithmetic_test.py b/tests/arithmetic_test.py index c5f758c..dd4296b 100644 --- a/tests/arithmetic_test.py +++ b/tests/arithmetic_test.py @@ -2,7 +2,7 @@ import typing import pytest import torch -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor @pytest.fixture( @@ -268,3 +268,253 @@ def test_arithmetic_fail(mismatch_tensors: tuple[GrassmannTensor, GrassmannTenso with pytest.raises(AssertionError, match="must match for arithmetic operations"): tensor_c = tensor_a.clone() tensor_c /= tensor_b + + +@pytest.fixture( + params=[ + ( + NamedGrassmannTensor(("a", "b"), (False, False), ((2, 2), (1, 3)), torch.randn([4, 4])), + NamedGrassmannTensor(("a", "b"), (False, False), ((2, 2), (1, 3)), torch.randn([4, 4])), + ), + ( + NamedGrassmannTensor( + ("a", "b", "c"), + (True, False, True), + ((1, 1), (2, 2), (3, 1)), + torch.randn([2, 4, 4]), + ), + NamedGrassmannTensor( + ("a", "b", "c"), + (True, False, True), + ((1, 1), (2, 2), (3, 1)), + torch.randn([2, 4, 4]), + ), + ), + ( + NamedGrassmannTensor( + ("a", "b", "c", "d"), + (True, True, False, False), + ((1, 2), (2, 2), (1, 1), (3, 1)), + torch.randn([3, 4, 2, 4]), + ), + NamedGrassmannTensor( + ("a", "b", "c", "d"), + (True, True, False, False), + ((1, 2), (2, 2), (1, 1), (3, 1)), + torch.randn([3, 4, 2, 4]), + ), + ), + ] +) +def named_tensors( + request: pytest.FixtureRequest, +) -> tuple[NamedGrassmannTensor, NamedGrassmannTensor]: + return request.param + + +@pytest.fixture( + params=[ + ( + NamedGrassmannTensor(("a", "b"), (False, False), ((2, 2), (1, 3)), torch.randn([4, 4])), + NamedGrassmannTensor(("a",), (False,), ((2, 2),), torch.randn([4])), + ), + ( + NamedGrassmannTensor( + ("a", "b", "c"), + (True, False, True), + ((1, 1), (2, 2), (3, 1)), + torch.randn([2, 4, 4]), + ), + NamedGrassmannTensor( + ("a", "b", "c"), + (True, False, True), + ((1, 2), (2, 2), (3, 1)), + torch.randn([3, 4, 4]), + ), + ), + ( + NamedGrassmannTensor( + ("a", "b", "c"), + (True, False, True), + ((1, 1), (2, 2), (3, 1)), + torch.randn([2, 4, 4]), + ), + NamedGrassmannTensor( + ("d", "e", "f"), + (True, False, True), + ((1, 2), (2, 2), (3, 1)), + torch.randn([3, 4, 4]), + ), + ), + ( + NamedGrassmannTensor( + ("a", "b", "c"), + (True, True, False), + ((1, 2), (2, 2), (3, 1)), + torch.randn([3, 4, 4]), + ), + NamedGrassmannTensor( + ("a", "b", "c", "d"), + (True, True, False, False), + ((3, 2), (2, 2), (1, 1), (3, 1)), + torch.randn([5, 4, 2, 4]), + ), + ), + ] +) +def mismatch_named_tensors( + request: pytest.FixtureRequest, +) -> tuple[GrassmannTensor, GrassmannTensor]: + return request.param + + +@pytest.mark.parametrize( + "unsupported_type", + [ + "string", # string + None, # NoneType + {"key", "value"}, # dict + [1, 2, 3], # list + {1, 2}, # set + object(), # arbitrary object + FakeTensor(), # an ill defined tensor-like object + ], +) +def test_named_tensor_arithmetic( + unsupported_type: typing.Any, + named_tensors: tuple[NamedGrassmannTensor, NamedGrassmannTensor], + scalar: torch.Tensor | float, +) -> None: + tensor_a, tensor_b = named_tensors + + # Test __pos__ method. + assert torch.equal((+tensor_a).tensor, +tensor_a.tensor) + + # Test __neg__ method. + assert torch.equal((-tensor_a).tensor, -tensor_a.tensor) + + # Test __add__ method. + assert torch.equal((tensor_a + scalar).tensor, tensor_a.tensor + scalar) + assert torch.equal((tensor_a + tensor_b).tensor, tensor_a.tensor + tensor_b.tensor) + assert torch.equal((scalar + tensor_a).tensor, scalar + tensor_a.tensor) + + with pytest.raises(TypeError): + tensor_a + unsupported_type + + with pytest.raises(TypeError): + unsupported_type + tensor_a + + # Test __iadd__ method. + tensor_c = tensor_a.clone() + tensor_c += scalar + assert torch.equal(tensor_c.tensor, tensor_a.tensor + scalar) + tensor_c = tensor_a.clone() + tensor_c += tensor_b + assert torch.equal(tensor_c.tensor, tensor_a.tensor + tensor_b.tensor) + + with pytest.raises(TypeError): + tensor_c = tensor_a.clone() + tensor_c += unsupported_type + + # Test __sub__ method. + assert torch.equal((tensor_a - scalar).tensor, tensor_a.tensor - scalar) + assert torch.equal((tensor_a - tensor_b).tensor, tensor_a.tensor - tensor_b.tensor) + assert torch.equal((scalar - tensor_a).tensor, scalar - tensor_a.tensor) + + with pytest.raises(TypeError): + tensor_a - unsupported_type + + with pytest.raises(TypeError): + unsupported_type - tensor_a + + # Test __isub__ method. + tensor_c = tensor_a.clone() + tensor_c -= scalar + assert torch.equal(tensor_c.tensor, tensor_a.tensor - scalar) + tensor_c = tensor_a.clone() + tensor_c -= tensor_b + assert torch.equal(tensor_c.tensor, tensor_a.tensor - tensor_b.tensor) + + with pytest.raises(TypeError): + tensor_c = tensor_a.clone() + tensor_c -= unsupported_type + + # Test __mul__ method. + assert torch.allclose((tensor_a * scalar).tensor, tensor_a.tensor * scalar) + assert torch.allclose((tensor_a * tensor_b).tensor, tensor_a.tensor * tensor_b.tensor) + assert torch.allclose((scalar * tensor_a).tensor, scalar * tensor_a.tensor) + + with pytest.raises(TypeError): + tensor_a * unsupported_type + + with pytest.raises(TypeError): + unsupported_type * tensor_a + + # Test __imul__ method. + tensor_c = tensor_a.clone() + tensor_c *= scalar + assert torch.allclose(tensor_c.tensor, tensor_a.tensor * scalar) + tensor_c = tensor_a.clone() + tensor_c *= tensor_b + assert torch.allclose(tensor_c.tensor, tensor_a.tensor * tensor_b.tensor) + + with pytest.raises(TypeError): + tensor_c = tensor_a.clone() + tensor_c *= unsupported_type + + # Test __truediv__ method. + assert torch.allclose((tensor_a / scalar).tensor, tensor_a.tensor / scalar) + assert torch.allclose((tensor_a / tensor_b).tensor, tensor_a.tensor / tensor_b.tensor) + assert torch.allclose((scalar / tensor_a).tensor, scalar / tensor_a.tensor) + + with pytest.raises(TypeError): + tensor_a / unsupported_type + + with pytest.raises(TypeError): + unsupported_type / tensor_a + + # Test __itruediv__ method. + tensor_c = tensor_a.clone() + tensor_c /= scalar + assert torch.allclose(tensor_c.tensor, tensor_a.tensor / scalar) + tensor_c = tensor_a.clone() + tensor_c /= tensor_b + assert torch.allclose(tensor_c.tensor, tensor_a.tensor / tensor_b.tensor) + + with pytest.raises(TypeError): + tensor_c = tensor_a.clone() + tensor_c /= unsupported_type + + +def test_named_tensor_arithmetic_fail( + mismatch_named_tensors: tuple[NamedGrassmannTensor, NamedGrassmannTensor], +) -> None: + tensor_a, tensor_b = mismatch_named_tensors + + # Test __add__ method. + with pytest.raises(AssertionError, match="must match for arithmetic operations"): + tensor_a + tensor_b + with pytest.raises(AssertionError, match="must match for arithmetic operations"): + tensor_c = tensor_a.clone() + tensor_c += tensor_b + + # Test __sub__ method. + with pytest.raises(AssertionError, match="must match for arithmetic operations"): + tensor_a - tensor_b + with pytest.raises(AssertionError, match="must match for arithmetic operations"): + tensor_c = tensor_a.clone() + tensor_c -= tensor_b + + # Test __mul__ method. + with pytest.raises(AssertionError, match="must match for arithmetic operations"): + tensor_a * tensor_b + with pytest.raises(AssertionError, match="must match for arithmetic operations"): + tensor_c = tensor_a.clone() + tensor_c *= tensor_b + + # Test __truediv__ method. + with pytest.raises(AssertionError, match="must match for arithmetic operations"): + tensor_a / tensor_b + with pytest.raises(AssertionError, match="must match for arithmetic operations"): + tensor_c = tensor_a.clone() + tensor_c /= tensor_b diff --git a/tests/attributes_test.py b/tests/attributes_test.py index 9a12a82..0e8a18a 100644 --- a/tests/attributes_test.py +++ b/tests/attributes_test.py @@ -1,6 +1,6 @@ import pytest import torch -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor Initialization = tuple[tuple[bool, ...], tuple[tuple[int, int], ...], torch.Tensor] @@ -60,3 +60,70 @@ def test_mask(x: Initialization) -> None: for rank, parity in enumerate(tensor.parity): expect ^= bool(parity[indices[rank]]) assert mask == expect + + +NamedInitialization = tuple[ + tuple[str, ...], tuple[bool, ...], tuple[tuple[int, int], ...], torch.Tensor +] + + +@pytest.fixture( + params=[ + (("a", "b"), (False, False), ((2, 2), (2, 2)), torch.randn([4, 4])), + (("a", "b"), (False, True), ((2, 2), (1, 3)), torch.randn([4, 4])), + (("a", "b"), (False, True), ((2, 0), (1, 3)), torch.randn([2, 4])), + (("a", "b"), (True, False), ((0, 2), (1, 3)), torch.randn([2, 4])), + (("a", "b"), (True, False), ((0, 0), (1, 3)), torch.randn([0, 4])), + (("a",), (True,), ((2, 0),), torch.randn([2])), + (("a",), (False,), ((0, 2),), torch.randn([2])), + ((), (), (), torch.randn([])), + (("a", "b", "c"), (False, False, True), ((2, 2), (1, 3), (4, 0)), torch.randn([4, 4, 4])), + ] +) +def named_x(request: pytest.FixtureRequest) -> NamedInitialization: + return request.param + + +def test_named_tensor_name(named_x: NamedInitialization) -> None: + tensor = NamedGrassmannTensor(*named_x) + assert tensor.names == named_x[0] + + +def test_named_tensor_arrow(named_x: NamedInitialization) -> None: + tensor = NamedGrassmannTensor(*named_x) + assert tensor.arrow == named_x[1] + + +def test_named_tensor_edges(named_x: NamedInitialization) -> None: + tensor = NamedGrassmannTensor(*named_x) + assert tensor.edges == named_x[2] + + +def test_named_tensor_tensor(named_x: NamedInitialization) -> None: + tensor = NamedGrassmannTensor(*named_x) + assert torch.equal(tensor.tensor, named_x[3]) + + +def test_named_tensor_parity(named_x: NamedInitialization) -> None: + tensor = NamedGrassmannTensor(*named_x) + assert len(tensor.parity) == tensor.tensor.dim() + for [even, odd], parity in zip(named_x[2], tensor.parity): + total = even + odd + assert parity.shape == (total,) + assert parity.dtype == torch.bool + for i in range(total): + assert parity[i] == (i >= even) + + +def test_named_tensor_mask(named_x: NamedInitialization) -> None: + tensor = NamedGrassmannTensor(*named_x) + assert tensor.mask.shape == tensor.tensor.shape + assert tensor.mask.dtype == torch.bool + for indices in zip( + *torch.unravel_index(torch.arange(tensor.tensor.numel()), tensor.tensor.shape) + ): + mask = tensor.mask[indices] + expect = False + for rank, parity in enumerate(tensor.parity): + expect ^= bool(parity[indices[rank]]) + assert mask == expect diff --git a/tests/clone_test.py b/tests/clone_test.py index 28d498a..182f7e4 100644 --- a/tests/clone_test.py +++ b/tests/clone_test.py @@ -2,7 +2,7 @@ import copy import pytest import torch -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor @pytest.mark.parametrize("parity,mask", [(True, True), (True, False), (False, False)]) @@ -50,3 +50,52 @@ def test_clone( assert cloned_tensor._mask is original_tensor._mask assert id(original_tensor.tensor) != id(cloned_tensor.tensor) + + +@pytest.mark.parametrize("parity,mask", [(True, True), (True, False), (False, False)]) +@pytest.mark.parametrize("which", ["clone", "copy", "deepcopy"]) +def test_named_tensor_clone( + parity: bool, + mask: bool, + which: typing.Literal["clone", "copy", "deepcopy"], +) -> None: + original_tensor = NamedGrassmannTensor( + _names=("a", "b"), + _arrow=(False, True), + _edges=((2, 2), (1, 3)), + _tensor=torch.randn([4, 4]), + ) + + if parity: + _ = original_tensor.parity + if mask: + _ = original_tensor.mask + + match which: + case "clone": + cloned_tensor = original_tensor.clone() + case "copy": + cloned_tensor = copy.copy(original_tensor) + case "deepcopy": + cloned_tensor = copy.deepcopy(original_tensor) + + assert cloned_tensor._names == original_tensor._names + assert cloned_tensor._arrow == original_tensor._arrow + assert cloned_tensor._edges == original_tensor._edges + assert torch.equal(cloned_tensor._tensor, original_tensor._tensor) + if parity: + assert cloned_tensor._parity is not None + assert original_tensor._parity is not None + assert all( + torch.equal(c, o) for c, o in zip(cloned_tensor._parity, original_tensor._parity) + ) + else: + assert cloned_tensor._parity is original_tensor._parity + if mask: + assert cloned_tensor._mask is not None + assert original_tensor._mask is not None + assert torch.equal(cloned_tensor._mask, original_tensor._mask) + else: + assert cloned_tensor._mask is original_tensor._mask + + assert id(original_tensor.tensor) != id(cloned_tensor.tensor) diff --git a/tests/conjugate_test.py b/tests/conjugate_test.py index 1d505b5..b9b3cf9 100644 --- a/tests/conjugate_test.py +++ b/tests/conjugate_test.py @@ -2,9 +2,10 @@ import pytest from typing import TypeAlias -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor Arrow_Edges: TypeAlias = tuple[tuple[bool, ...], tuple[tuple[int, int], ...]] +Names_Arrow_Edges: TypeAlias = tuple[tuple[str, ...], tuple[bool, ...], tuple[tuple[int, int], ...]] @pytest.fixture( @@ -104,3 +105,125 @@ def test_conjugate_without_symmetry_equality( assert torch.allclose(a.tensor, a.conj().tensor) assert torch.allclose(b.tensor.conj(), b.conj().tensor) + + +@pytest.fixture( + params=[ + ( + ("a", "b"), + (True, False), + ((2, 2), (2, 2)), + ), + ( + ("a", "b", "c"), + (True, False, False), + ((2, 2), (2, 2), (2, 2)), + ), + ] +) +def names_arrow_edges(request: pytest.FixtureRequest) -> Names_Arrow_Edges: + return request.param + + +@pytest.fixture( + params=[ + ( + ("a", "b"), + (True, False), + ((4, 0), (4, 0)), + ), + ( + ("a", "b", "c"), + (True, False, False), + ((4, 0), (4, 0), (4, 0)), + ), + ] +) +def names_arrow_edges_without_symmetry(request: pytest.FixtureRequest) -> Names_Arrow_Edges: + return request.param + + +def create_random_named_tensor( + names_arrow_edges: Names_Arrow_Edges, + *, + dtype: torch.dtype, +) -> NamedGrassmannTensor: + names, arrow, edges = names_arrow_edges + shape = [even + odd for even, odd in edges] + + if dtype == torch.float64: + tensor = torch.rand(*shape, dtype=dtype) + else: + tensor = torch.randn(*shape) + 1j * torch.randn(*shape).to(dtype) + return NamedGrassmannTensor(names, arrow, edges, tensor) + + +@pytest.fixture +def random_real_named_tensor(names_arrow_edges: Names_Arrow_Edges) -> NamedGrassmannTensor: + return create_random_named_tensor(names_arrow_edges, dtype=torch.float64) + + +@pytest.fixture +def random_complex_named_tensor(names_arrow_edges: Names_Arrow_Edges) -> NamedGrassmannTensor: + return create_random_named_tensor(names_arrow_edges, dtype=torch.complex128) + + +def test_conjugate_involution_with_complex_named_tensor( + random_complex_named_tensor: NamedGrassmannTensor, +) -> None: + contrast_a = random_complex_named_tensor + contrast_b = contrast_a.conj().conj() + assert contrast_a.names == contrast_b.names + assert contrast_a.arrow == contrast_b.arrow + assert torch.allclose(contrast_a.tensor, contrast_b.tensor) + + +def test_conjugate_involution_with_real_named_tensor( + random_real_named_tensor: NamedGrassmannTensor, +) -> None: + contrast_a = random_real_named_tensor + contrast_b = contrast_a.conj().conj() + assert contrast_a.names == contrast_b.names + assert contrast_a.arrow == contrast_b.arrow + assert torch.allclose(contrast_a.tensor, contrast_b.tensor) + + +def test_named_tensor_conjugate_reverse_order_of_contraction( + names_arrow_edges: Names_Arrow_Edges, +) -> None: + a = create_random_named_tensor(names_arrow_edges, dtype=torch.complex128).update_mask() + b = create_random_named_tensor(names_arrow_edges, dtype=torch.complex128).update_mask() + if b.tensor.dim() == 2: + b = b.rename({"a": "i", "b": "j"}) + else: + b = b.rename({"a": "i", "b": "j", "c": "k"}) + + if a.tensor.dim() == 2: + contrast_a = a.contract(b, {("b", "i")}) + contrast_a = contrast_a.conj() + else: + contrast_a = a.contract(b, {("c", "i")}) + contrast_a = contrast_a.conj() + + a_conj = a.conj() + b_conj = b.conj() + + if a.tensor.dim() == 2: + contrast_b = a_conj.contract(b_conj, {("b", "i")}) + else: + contrast_b = a_conj.contract(b_conj, {("c", "i")}) + + assert contrast_a.names == contrast_b.names + assert contrast_a.arrow == contrast_b.arrow + assert contrast_a.edges == contrast_b.edges + assert torch.allclose(contrast_a.tensor, contrast_b.tensor) + + +def test_named_tensor_conjugate_without_symmetry_equality( + names_arrow_edges_without_symmetry: Names_Arrow_Edges, +) -> None: + a = create_random_named_tensor(names_arrow_edges_without_symmetry, dtype=torch.float64) + b = create_random_named_tensor(names_arrow_edges_without_symmetry, dtype=torch.complex128) + + assert torch.allclose(a.tensor, a.conj().tensor) + assert torch.allclose(b.tensor.conj(), b.conj().tensor) diff --git a/tests/contract_test.py b/tests/contract_test.py index 9a57f41..63020c6 100644 --- a/tests/contract_test.py +++ b/tests/contract_test.py @@ -4,9 +4,10 @@ import random from typing import Iterable -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor ContractCases = Iterable[ParameterSet] +NamedContractCases = Iterable[ParameterSet] def contract_cases() -> ContractCases: @@ -99,3 +100,102 @@ def test_contract_full_legs() -> None: ) c = a.contract(b, (0, 1, 2, 3), (0, 1, 2, 3)) assert c.tensor.dim() == 0 + + +def named_contract_cases() -> NamedContractCases: + edge_unit = (4, 4) + max_dim = 4 + num_cases = 10 + + rng = random.Random(0) + gen = torch.Generator().manual_seed(0) + + cases = [] + + for case_idx in range(num_cases): + dim = rng.randint(1, max_dim) + + edges = tuple(edge_unit for _ in range(dim)) + arrow = tuple(bool(rng.getrandbits(1)) for _ in range(dim)) + shape = (sum(edge_unit),) * dim + tensor = torch.randn(*shape, dtype=torch.float64, generator=gen) + a = NamedGrassmannTensor( + tuple(f"a{i}" for i in range(dim)), + arrow, + edges, + tensor, + ) + + b = NamedGrassmannTensor( + tuple(f"b{i}" for i in range(dim)), + arrow, + edges, + tensor, + ) + + contract_length = rng.randint(1, dim) + + leg_a = tuple(sorted(rng.sample(range(dim), contract_length))) + leg_b = tuple(sorted(rng.sample(range(dim), contract_length))) + + pairs: set[tuple[str, str]] = {(a.names[i], b.names[j]) for i, j in zip(leg_a, leg_b)} + contracted_a = set(a.names[i] for i in leg_a) + contracted_b = set(b.names[j] for j in leg_b) + result_names = tuple(n for n in a.names if n not in contracted_a) + tuple( + n for n in b.names if n not in contracted_b + ) + + cases.append( + pytest.param( + a, + b, + pairs, + result_names, + id=f"arrow={arrow}-dim={dim}-pairs={sorted(pairs)}-result={result_names}", + ) + ) + + return cases + + +@pytest.mark.parametrize("a, b, pairs, result_names", named_contract_cases()) +def test_named_tensor_contract( + a: NamedGrassmannTensor, + b: NamedGrassmannTensor, + pairs: set[tuple[str, str]], + result_names: set[tuple[str, str]], +) -> None: + out = a.contract(b, pairs) + assert out.names == result_names + + +def test_named_tensor_contract_full_legs() -> None: + a = NamedGrassmannTensor( + ("a", "b", "c", "d"), + (False, False, False, True), + ((2, 2), (4, 4), (8, 8), (8, 8)), + torch.randn(4, 8, 16, 16, dtype=torch.float64), + ) + b = NamedGrassmannTensor( + ("a", "b", "c", "d"), + (False, True, True, True), + ((8, 8), (8, 8), (4, 4), (2, 2)), + torch.randn(16, 16, 8, 4, dtype=torch.float64), + ) + c = a.contract(b, {("a", "d"), ("b", "c"), ("c", "b"), ("d", "a")}) + assert c.tensor.dim() == 0 + + +def test_named_tensor_contract_different_order() -> None: + a = NamedGrassmannTensor( + ("a0", "a1"), (False, False), ((2, 0), (2, 0)), torch.tensor([[1, 2], [3, 4]]) + ) + b = NamedGrassmannTensor( + ("b0", "b1"), (True, True), ((2, 0), (2, 0)), torch.tensor([[10, 20], [30, 40]]) + ) + c = a.contract(b, {("a0", "b0"), ("a1", "b1")}) + d = a.contract(b, {("a1", "b1"), ("a0", "b0")}) + e = a.contract(b, {("a0", "b1"), ("a1", "b0")}) + f = a.contract(b, {("a1", "b0"), ("a0", "b1")}) + assert torch.allclose(c.tensor, d.tensor) + assert torch.allclose(e.tensor, f.tensor) diff --git a/tests/conversion_test.py b/tests/conversion_test.py index 29b1a59..c5a8411 100644 --- a/tests/conversion_test.py +++ b/tests/conversion_test.py @@ -1,7 +1,7 @@ import typing import pytest import torch -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor @pytest.fixture() @@ -61,3 +61,64 @@ def test_conversion_duplicated_value(x: GrassmannTensor) -> None: x.to(torch.complex128, dtype=torch.complex128) with pytest.raises(AssertionError, match="Duplicate device specification"): x.to("cpu", device=torch.device("cpu")) + + +@pytest.fixture() +def named_x() -> NamedGrassmannTensor: + return NamedGrassmannTensor( + ("a", "b"), (False, False), ((2, 2), (1, 3)), torch.randn([4, 4], device="cpu:0") + ) + + +@pytest.mark.parametrize("dtype_arg", ["position", "keyword", "none"]) +@pytest.mark.parametrize("device_arg", ["position", "keyword", "none"]) +@pytest.mark.parametrize("device_format", ["object", "string"]) +def test_named_tensor_conversion( + named_x: NamedGrassmannTensor, + dtype_arg: typing.Literal["position", "keyword", "none"], + device_arg: typing.Literal["position", "keyword", "none"], + device_format: typing.Literal["object", "string"], +) -> None: + args: list[typing.Any] = [] + kwargs: dict[str, typing.Any] = {} + + device_str = "cuda:0" if torch.cuda.is_available() else "cpu:0" + device = torch.device(device_str) if device_format == "object" else device_str + match device_arg: + case "position": + args.append(device) + case "keyword": + kwargs["device"] = device + case _: + pass + + match dtype_arg: + case "position": + args.append(torch.complex128) + case "keyword": + kwargs["dtype"] = torch.complex128 + case _: + pass + + if len(args) > 1: + pytest.skip("Cannot pass both dtype and device as positional arguments") + + y = named_x.to(*args, **kwargs) + assert isinstance(y, NamedGrassmannTensor) + assert y.arrow == named_x.arrow + assert y.edges == named_x.edges + assert y.tensor.dtype == torch.complex128 if dtype_arg != "none" else torch.float32 + assert ( + y.tensor.device.type + == (torch.device(device_str) if device_arg != "none" else torch.device("cpu:0")).type + ) + assert torch.allclose(y.tensor, named_x.tensor.to(dtype=y.tensor.dtype, device=y.tensor.device)) + + +def test_named_tensor_conversion_duplicated_value(named_x: NamedGrassmannTensor) -> None: + with pytest.raises(AssertionError, match="Duplicate device specification"): + named_x.to(torch.device("cpu"), device=torch.device("cpu")) + with pytest.raises(AssertionError, match="Duplicate dtype specification"): + named_x.to(torch.complex128, dtype=torch.complex128) + with pytest.raises(AssertionError, match="Duplicate device specification"): + named_x.to("cpu", device=torch.device("cpu")) diff --git a/tests/creation_test.py b/tests/creation_test.py index 1185293..1968cee 100644 --- a/tests/creation_test.py +++ b/tests/creation_test.py @@ -1,8 +1,11 @@ import pytest import torch -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor Initialization = tuple[tuple[bool, ...], tuple[tuple[int, int], ...], torch.Tensor] +NamedInitialization = tuple[ + tuple[str, ...], tuple[bool, ...], tuple[tuple[int, int], ...], torch.Tensor +] @pytest.mark.parametrize( @@ -17,6 +20,43 @@ def test_creation_success(x: Initialization) -> None: GrassmannTensor(*x) +@pytest.mark.parametrize( + "x", + [ + (("a", "b"), (False, False), ((2, 2), (1, 3)), torch.randn([4, 4])), + (("a", "b"), (True, False), ((2, 2), (3, 1)), torch.randn([4, 4])), + (("a", "b", "c"), (False, True, False), ((1, 1), (2, 2), (1, 1)), torch.randn([2, 4, 2])), + ], +) +def test_named_tensor_creation_success(x: NamedInitialization) -> None: + NamedGrassmannTensor(*x) + + +@pytest.mark.parametrize( + "x", + [ + (("a", "a"), (False, False), ((2, 2), (1, 3)), torch.randn([4, 4])), + (("a", "b", "a"), (False, True, False), ((1, 1), (2, 2), (1, 1)), torch.randn([2, 4, 2])), + ], +) +def test_named_tensor_not_unique_names(x: NamedInitialization) -> None: + with pytest.raises(AssertionError, match="Names must be unique"): + NamedGrassmannTensor(*x) + + +@pytest.mark.parametrize( + "x", + [ + (("a",), (False, False), ((2, 2), (1, 3)), torch.randn([4, 4])), + (("a", "b", "c"), (True, False), ((2, 2), (3, 1)), torch.randn([4, 4])), + (("a", "b"), (False, True, False), ((1, 1), (2, 2), (1, 1)), torch.randn([2, 4, 2])), + ], +) +def test_named_tensor_invalid_names(x: NamedInitialization) -> None: + with pytest.raises(AssertionError, match="Names length"): + NamedGrassmannTensor(*x) + + @pytest.mark.parametrize( "x", [ @@ -30,6 +70,19 @@ def test_creation_invalid_arrow(x: Initialization) -> None: GrassmannTensor(*x) +@pytest.mark.parametrize( + "x", + [ + (("a", "b"), (False,), ((2, 2), (1, 3)), torch.randn([4, 4])), + (("a", "b"), (True, False, True), ((2, 2), (3, 1)), torch.randn([4, 4])), + (("a", "b", "c"), (False, True), ((1, 1), (2, 2), (1, 1)), torch.randn([2, 4, 2])), + ], +) +def test_named_tensor_creation_invalid_arrow(x: NamedInitialization) -> None: + with pytest.raises(AssertionError, match="Arrow length"): + NamedGrassmannTensor(*x) + + @pytest.mark.parametrize( "x", [ @@ -43,6 +96,19 @@ def test_creation_invalid_edges(x: Initialization) -> None: GrassmannTensor(*x) +@pytest.mark.parametrize( + "x", + [ + (("a", "b"), (False, False), ((2, 2),), torch.randn([4, 4])), + (("a", "b"), (True, False), ((2, 2), (1, 1), (3, 1)), torch.randn([4, 4])), + (("a", "b", "c"), (False, True, False), ((1, 1), (1, 1)), torch.randn([2, 4, 2])), + ], +) +def test_named_tensor_creation_invalid_edges(x: NamedInitialization) -> None: + with pytest.raises(AssertionError, match="Edges length"): + NamedGrassmannTensor(*x) + + @pytest.mark.parametrize( "x", [ @@ -54,3 +120,16 @@ def test_creation_invalid_edges(x: Initialization) -> None: def test_creation_invalid_shape(x: Initialization) -> None: with pytest.raises(AssertionError, match="must equal sum of"): GrassmannTensor(*x) + + +@pytest.mark.parametrize( + "x", + [ + (("a", "b"), (False, False), ((2, 2), (1, 3)), torch.randn([4, 2])), + (("a", "b"), (True, False), ((2, 2), (3, 1)), torch.randn([2, 4])), + (("a", "b", "c"), (False, True, False), ((1, 1), (2, 2), (1, 1)), torch.randn([4, 4, 2])), + ], +) +def test_named_tensor_creation_invalid_shape(x: NamedInitialization) -> None: + with pytest.raises(AssertionError, match="must equal sum of"): + NamedGrassmannTensor(*x) diff --git a/tests/exponential_test.py b/tests/exponential_test.py index 1f24a46..22ad95e 100644 --- a/tests/exponential_test.py +++ b/tests/exponential_test.py @@ -1,11 +1,13 @@ import torch import pytest -from typing import TypeAlias +import typing -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor -Tensor: TypeAlias = GrassmannTensor -Pairs: TypeAlias = tuple[tuple[int, ...], tuple[int, ...]] +Tensor: typing.TypeAlias = GrassmannTensor +Pairs: typing.TypeAlias = tuple[tuple[int, ...], tuple[int, ...]] +NamedTensor: typing.TypeAlias = NamedGrassmannTensor +NamedPairs: typing.TypeAlias = set[tuple[str, str]] def test_exponential_with_empty_parity_block() -> None: @@ -97,3 +99,95 @@ def test_exponential_via_taylor_expansion( tensor_taylor_expansion = tensor_taylor_expansion.permute(inv_order) assert torch.allclose(tensor_taylor_expansion.tensor, tensor_exp.tensor) + + +def test_named_tensor_exponential_with_empty_parity_block() -> None: + a = NamedGrassmannTensor( + ("a", "b"), (False, True), ((1, 0), (1, 0)), torch.randn(1, 1, dtype=torch.float64) + ) + a.exponential({("a", "b")}) + b = NamedGrassmannTensor( + ("a", "b"), (False, True), ((0, 1), (0, 1)), torch.randn(1, 1, dtype=torch.float64) + ) + b.exponential({("a", "b")}) + + +def test_named_tensor_exponential_assertation() -> None: + a = NamedGrassmannTensor( + ("a", "b", "c", "d"), + (True, True, True, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Exponentiation requires arrow"): + a.exponential({("a", "b"), ("c", "d")}) + + b = NamedGrassmannTensor( + ("a", "b", "c", "d"), + (False, True, False, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Exponentiation requires a square operator"): + b.exponential({("a", "b"), ("c", "d")}) + + c = NamedGrassmannTensor( + ("a", "b", "c", "d"), + (False, True, False, True), + ((1, 3), (3, 1), (3, 1), (3, 1)), + torch.randn(4, 4, 4, 4, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Parity blocks must be square"): + c.exponential({("a", "b"), ("c", "d")}) + + d = NamedGrassmannTensor( + ("a", "b", "c", "d"), + (True, True, True, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Input pairs must cover all dimension and disjoint"): + d.exponential({("a", "b"), ("a", "c"), ("a", "d")}) + + +@pytest.mark.parametrize( + "tensor, pairs", + [ + ( + NamedGrassmannTensor( + ("a", "b"), (False, True), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64) + ), + {("a", "b")}, + ), + ( + NamedGrassmannTensor( + ("a", "b"), (True, False), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64) + ), + {("a", "b")}, + ), + ( + NamedGrassmannTensor( + ("a", "b", "c", "d"), + (False, False, True, True), + ((4, 4), (8, 8), (4, 4), (8, 8)), + torch.randn(8, 16, 8, 16, dtype=torch.float64), + ), + {("a", "c"), ("b", "d")}, + ), + ], +) +def test_named_tensor_exponential_via_taylor_expansion( + tensor: NamedTensor, + pairs: NamedPairs, +) -> None: + tensor = tensor.update_mask() + tensor_exp = tensor.exponential(pairs) + iter_tensor = tensor.identity(pairs) + contract_pairs = typing.cast(set[tuple[str, str]], {item[::-1] for item in pairs}) + + tensor_taylor_expansion = iter_tensor + for i in range(1, 50): + iter_tensor = iter_tensor.contract(tensor, contract_pairs) / i + tensor_taylor_expansion += iter_tensor + + assert torch.allclose(tensor_taylor_expansion.tensor, tensor_exp.tensor) diff --git a/tests/identity_test.py b/tests/identity_test.py index b1699d6..f3296e3 100644 --- a/tests/identity_test.py +++ b/tests/identity_test.py @@ -1,11 +1,13 @@ import pytest import torch -from typing import TypeAlias +import typing -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor -Tensor: TypeAlias = GrassmannTensor -Pairs: TypeAlias = tuple[tuple[int, ...], tuple[int, ...]] +Tensor: typing.TypeAlias = GrassmannTensor +NamedTensor: typing.TypeAlias = NamedGrassmannTensor +Pairs: typing.TypeAlias = tuple[tuple[int, ...], tuple[int, ...]] +NamedPairs: typing.TypeAlias = set[tuple[str, str]] def test_identity_assertation() -> None: @@ -81,3 +83,70 @@ def test_identity_via_self_multiplication( assert torch.allclose((identity @ identity).tensor, identity.tensor) assert torch.allclose((identity @ tensor).tensor, tensor.tensor) assert torch.allclose((tensor @ identity).tensor, tensor.tensor) + + +def test_named_tensor_identity_assertation() -> None: + a = NamedGrassmannTensor( + ("a", "b", "c", "d"), + (True, True, True, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Identity requires arrow"): + a.identity({("a", "b"), ("c", "d")}) + + b = NamedGrassmannTensor( + ("a", "b", "c", "d"), + (False, True, False, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Identity requires a square operator"): + b.identity({("a", "b"), ("c", "d")}) + + c = NamedGrassmannTensor( + ("a", "b", "c", "d"), + (False, True, False, True), + ((1, 3), (3, 1), (3, 1), (3, 1)), + torch.randn(4, 4, 4, 4, dtype=torch.float64), + ) + with pytest.raises(AssertionError, match="Parity blocks must be square"): + c.identity({("a", "b"), ("c", "d")}) + + +@pytest.mark.parametrize( + "tensor, pairs", + [ + ( + NamedGrassmannTensor( + ("a", "b"), (False, True), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64) + ), + {("a", "b")}, + ), + ( + NamedGrassmannTensor( + ("a", "b"), (True, False), ((4, 4), (4, 4)), torch.randn(8, 8, dtype=torch.float64) + ), + {("a", "b")}, + ), + ( + NamedGrassmannTensor( + ("a", "b", "c", "d"), + (False, False, True, True), + ((4, 4), (8, 8), (4, 4), (8, 8)), + torch.randn(8, 16, 8, 16, dtype=torch.float64), + ), + {("a", "c"), ("b", "d")}, + ), + ], +) +def test_named_tensor_identity_via_self_multiplication( + tensor: NamedTensor, + pairs: NamedPairs, +) -> None: + tensor = tensor.update_mask() + identity = tensor.identity(pairs) + contract_pairs = typing.cast(set[tuple[str, str]], {item[::-1] for item in pairs}) + assert torch.allclose((identity.contract(identity, contract_pairs)).tensor, identity.tensor) + assert torch.allclose((identity.contract(tensor, contract_pairs)).tensor, tensor.tensor) + assert torch.allclose((tensor.contract(identity, contract_pairs)).tensor, tensor.tensor) diff --git a/tests/import_test.py b/tests/import_test.py index 20b0985..22a071f 100644 --- a/tests/import_test.py +++ b/tests/import_test.py @@ -1,4 +1,5 @@ def test_import() -> None: - from grassmann_tensor import GrassmannTensor + from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor assert isinstance(GrassmannTensor, type) + assert isinstance(NamedGrassmannTensor, type) diff --git a/tests/matmul_test.py b/tests/matmul_test.py index 589817d..b1e0aed 100644 --- a/tests/matmul_test.py +++ b/tests/matmul_test.py @@ -1,9 +1,10 @@ import pytest import torch import typing -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor Broadcast = tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...]] +NamedBroadcast = tuple[tuple[int, ...], tuple[int, ...], tuple[int, ...], tuple[str, ...]] MatmulCase = tuple[bool, bool, tuple[int, int], tuple[int, int], tuple[int, int]] @@ -43,6 +44,27 @@ def broadcast(request: pytest.FixtureRequest) -> Broadcast: return request.param +@pytest.fixture( + params=[ + ((), (), (), ()), + ((2,), (), (2,), ("a0",)), + ((), (3,), (3,), ("b0",)), + ((1,), (4,), (4,), ("b0",)), + ((5,), (1,), (5,), ("a0",)), + ((6,), (6,), (6,), ("a0",)), + ((7, 8), (7, 8), (7, 8), ("a0", "a1")), + ((1, 8), (7, 8), (7, 8), ("b0", "a1")), + ((8,), (7, 8), (7, 8), ("b0", "a0")), + ((7, 1), (7, 8), (7, 8), ("a0", "b1")), + ((7, 8), (1, 8), (7, 8), ("a0", "a1")), + ((7, 8), (8,), (7, 8), ("a0", "a1")), + ((7, 8), (7, 1), (7, 8), ("a0", "a1")), + ], +) +def named_broadcast(request: pytest.FixtureRequest) -> NamedBroadcast: + return request.param + + @pytest.fixture( params=[ (False, False, (1, 1), (1, 1), (1, 1)), @@ -181,23 +203,20 @@ def test_matmul_unpure_even( def test_matmul_operator_matmul( a_is_vector: bool, b_is_vector: bool, - normal_arrow_order: bool, - broadcast: Broadcast, ) -> None: - normal_arrow_order = True broadcast_a, broadcast_b, broadcast_result = (7, 8), (7, 1), (7, 8) arrow_a, arrow_b, edge_a, edge_common, edge_b = True, True, (2, 2), (2, 2), (2, 2) dim_a = sum(edge_a) dim_common = sum(edge_common) dim_b = sum(edge_b) a = GrassmannTensor( - (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False), + (*(False for _ in broadcast_a), arrow_a, True), (*((i, 0) for i in broadcast_a), edge_a, edge_common), torch.randn([*broadcast_a, dim_a, dim_common]), ).update_mask() b = GrassmannTensor( - (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b), + (*(False for _ in broadcast_b), False, arrow_b), (*((i, 0) for i in broadcast_b), edge_common, edge_b), torch.randn([*broadcast_b, dim_common, dim_b]), ).update_mask() @@ -272,3 +291,259 @@ def test_matmul_operator_rmatmul() -> None: assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b) assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b) assert torch.allclose(c.tensor, expected) + + +def test_named_matmul( + a_is_vector: bool, + b_is_vector: bool, + normal_arrow_order: bool, + named_broadcast: NamedBroadcast, + x: MatmulCase, +) -> None: + broadcast_a, broadcast_b, broadcast_result, broadcast_name_result = named_broadcast + arrow_a, arrow_b, edge_a, edge_common, edge_b = x + if a_is_vector and broadcast_a != (): + pytest.skip("Vector a cannot be broadcasted") + if b_is_vector and broadcast_b != (): + pytest.skip("Vector b cannot be broadcasted") + dim_a = sum(edge_a) + dim_common = sum(edge_common) + dim_b = sum(edge_b) + if a_is_vector: + a = NamedGrassmannTensor( + tuple((f"a{i}" for i in range(len(broadcast_a) + 1))), + (*(False for _ in broadcast_a), True if normal_arrow_order else False), + (*((i, 0) for i in broadcast_a), edge_common), + torch.randn([*broadcast_a, dim_common]), + ).update_mask() + else: + a = NamedGrassmannTensor( + tuple((f"a{i}" for i in range(len(broadcast_a) + 2))), + (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False), + (*((i, 0) for i in broadcast_a), edge_a, edge_common), + torch.randn([*broadcast_a, dim_a, dim_common]), + ).update_mask() + if b_is_vector: + b = NamedGrassmannTensor( + tuple((f"b{i}" for i in range(len(broadcast_b) + 1))), + (*(False for _ in broadcast_b), False if normal_arrow_order else True), + (*((i, 0) for i in broadcast_b), edge_common), + torch.randn([*broadcast_b, dim_common]), + ).update_mask() + else: + b = NamedGrassmannTensor( + tuple((f"b{i}" for i in range(len(broadcast_b) + 2))), + (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b), + (*((i, 0) for i in broadcast_b), edge_common, edge_b), + torch.randn([*broadcast_b, dim_common, dim_b]), + ).update_mask() + c = a.matmul(b) + expected = a.tensor.matmul(b.tensor) + if not a_is_vector and not b_is_vector and not normal_arrow_order: + expected[..., edge_a[0] :, edge_b[0] :] *= -1 + expected_tail: list[str] = [] + if not a_is_vector: + expected_tail.append(a.names[len(broadcast_a)]) + if not b_is_vector: + expected_tail.append(b.names[len(broadcast_b) + 1]) + expected_names = tuple(broadcast_name_result) + tuple(expected_tail) + if a_is_vector: + if b_is_vector: + assert c.names == expected_names + assert c.arrow == tuple(False for _ in broadcast_result) + assert c.edges == tuple((i, 0) for i in broadcast_result) + else: + assert c.names == expected_names + assert c.arrow == (*(False for _ in broadcast_result), arrow_b) + assert c.edges == (*((i, 0) for i in broadcast_result), edge_b) + else: + if b_is_vector: + assert c.names == expected_names + assert c.arrow == (*(False for _ in broadcast_result), arrow_a) + assert c.edges == (*((i, 0) for i in broadcast_result), edge_a) + else: + assert c.names == expected_names + assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b) + assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b) + assert torch.allclose(c.tensor, expected) + + +@pytest.mark.parametrize("impure_even_for_broadcast_indices", [1, 2]) +def test_named_matmul_unpure_even( + a_is_vector: bool, + b_is_vector: bool, + normal_arrow_order: bool, + named_broadcast: NamedBroadcast, + x: MatmulCase, + impure_even_for_broadcast_indices: int, +) -> None: + broadcast_a, broadcast_b, broadcast_result, broadcast_name_result = named_broadcast + arrow_a, arrow_b, edge_a, edge_common, edge_b = x + if a_is_vector and broadcast_a != (): + pytest.skip("Vector a cannot be broadcasted") + if b_is_vector and broadcast_b != (): + pytest.skip("Vector b cannot be broadcasted") + if a_is_vector and b_is_vector: + pytest.skip("Both vectors are ignored.") + dim_a = sum(edge_a) + dim_common = sum(edge_common) + dim_b = sum(edge_b) + if a_is_vector: + a = NamedGrassmannTensor( + tuple((f"a{i}" for i in range(len(broadcast_a) + 1))), + (*(False for _ in broadcast_a), True if normal_arrow_order else False), + (*((i, impure_even_for_broadcast_indices) for i in broadcast_a), edge_common), + torch.randn( + [*[i + impure_even_for_broadcast_indices for i in broadcast_a], dim_common] + ), + ).update_mask() + else: + a = NamedGrassmannTensor( + tuple((f"a{i}" for i in range(len(broadcast_a) + 2))), + (*(False for _ in broadcast_a), arrow_a, True if normal_arrow_order else False), + (*((i, impure_even_for_broadcast_indices) for i in broadcast_a), edge_a, edge_common), + torch.randn( + [*[i + impure_even_for_broadcast_indices for i in broadcast_a], dim_a, dim_common] + ), + ).update_mask() + if b_is_vector: + b = NamedGrassmannTensor( + tuple((f"b{i}" for i in range(len(broadcast_b) + 1))), + (*(False for _ in broadcast_b), False if normal_arrow_order else True), + (*((i, impure_even_for_broadcast_indices) for i in broadcast_b), edge_common), + torch.randn( + [*[i + impure_even_for_broadcast_indices for i in broadcast_b], dim_common] + ), + ).update_mask() + else: + b = NamedGrassmannTensor( + tuple((f"b{i}" for i in range(len(broadcast_b) + 2))), + (*(False for _ in broadcast_b), False if normal_arrow_order else True, arrow_b), + (*((i, impure_even_for_broadcast_indices) for i in broadcast_b), edge_common, edge_b), + torch.randn( + [*[i + impure_even_for_broadcast_indices for i in broadcast_b], dim_common, dim_b] + ), + ).update_mask() + if a.tensor.dim() <= 2 and b.tensor.dim() <= 2: + pytest.skip("One of the two tensors needs to have a dimension greater than 2") + with pytest.raises(AssertionError, match="All edges except the last two must be pure even"): + _ = a.matmul(b) + + +def test_named_matmul_operator_matmul( + a_is_vector: bool, + b_is_vector: bool, +) -> None: + broadcast_a, broadcast_b, broadcast_result, broadcast_name_result = ( + (7, 8), + (7, 1), + (7, 8), + ("a0", "a1"), + ) + arrow_a, arrow_b, edge_a, edge_common, edge_b = True, True, (2, 2), (2, 2), (2, 2) + dim_a = sum(edge_a) + dim_common = sum(edge_common) + dim_b = sum(edge_b) + a = NamedGrassmannTensor( + tuple((f"a{i}" for i in range(len(broadcast_a) + 2))), + (*(False for _ in broadcast_a), arrow_a, True), + (*((i, 0) for i in broadcast_a), edge_a, edge_common), + torch.randn([*broadcast_a, dim_a, dim_common]), + ).update_mask() + + b = NamedGrassmannTensor( + tuple((f"b{i}" for i in range(len(broadcast_b) + 2))), + (*(False for _ in broadcast_b), False, arrow_b), + (*((i, 0) for i in broadcast_b), edge_common, edge_b), + torch.randn([*broadcast_b, dim_common, dim_b]), + ).update_mask() + + c = a @ b + expected = a.tensor.matmul(b.tensor) + assert c.names == tuple(broadcast_name_result) + ( + a.names[len(broadcast_a)], + b.names[len(broadcast_b) + 1], + ) + assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b) + assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b) + assert torch.allclose(c.tensor, expected) + + +@pytest.fixture( + params=[ + NamedGrassmannTensor(("a", "b"), (False, False), ((2, 2), (1, 3)), torch.randn([4, 4])), + NamedGrassmannTensor( + ("a", "b", "c"), (True, False, True), ((1, 1), (2, 2), (3, 1)), torch.randn([2, 4, 4]) + ), + NamedGrassmannTensor( + ("a", "b", "c", "d"), + (True, True, False, False), + ((1, 2), (2, 2), (1, 1), (3, 1)), + torch.randn([3, 4, 2, 4]), + ), + ] +) +def named_tensors(request: pytest.FixtureRequest) -> GrassmannTensor: + return request.param + + +@pytest.mark.parametrize( + "unsupported_type", + [ + "string", # string + None, # NoneType + {"key", "value"}, # dict + [1, 2, 3], # list + {1, 2}, # set + object(), # arbitrary object + ], +) +def test_named_matmul_unsupported_type_raises_typeerror( + unsupported_type: typing.Any, + named_tensors: NamedGrassmannTensor, +) -> None: + with pytest.raises(TypeError): + _ = named_tensors @ unsupported_type + + with pytest.raises(TypeError): + _ = unsupported_type @ named_tensors + + with pytest.raises(TypeError): + named_tensors @= unsupported_type + + +def test_named_matmul_operator_rmatmul() -> None: + broadcast_a, broadcast_b, broadcast_result, broadcast_name_result = ( + (7, 8), + (7, 1), + (7, 8), + ("a0", "a1"), + ) + arrow_a, arrow_b, edge_a, edge_common, edge_b = True, True, (2, 2), (2, 2), (2, 2) + dim_a = sum(edge_a) + dim_common = sum(edge_common) + dim_b = sum(edge_b) + a = NamedGrassmannTensor( + tuple((f"a{i}" for i in range(len(broadcast_a) + 2))), + (*(False for _ in broadcast_a), arrow_a, True), + (*((i, 0) for i in broadcast_a), edge_a, edge_common), + torch.randn([*broadcast_a, dim_a, dim_common]), + ).update_mask() + + b = NamedGrassmannTensor( + tuple((f"b{i}" for i in range(len(broadcast_b) + 2))), + (*(False for _ in broadcast_b), False, arrow_b), + (*((i, 0) for i in broadcast_b), edge_common, edge_b), + torch.randn([*broadcast_b, dim_common, dim_b]), + ).update_mask() + + c = a + c @= b + expected = a.tensor.matmul(b.tensor) + assert c.names == tuple(broadcast_name_result) + ( + a.names[len(broadcast_a)], + b.names[len(broadcast_b) + 1], + ) + assert c.arrow == (*(False for _ in broadcast_result), arrow_a, arrow_b) + assert c.edges == (*((i, 0) for i in broadcast_result), edge_a, edge_b) + assert torch.allclose(c.tensor, expected) diff --git a/tests/name_test.py b/tests/name_test.py new file mode 100644 index 0000000..52b11cd --- /dev/null +++ b/tests/name_test.py @@ -0,0 +1,46 @@ +import torch +import pytest + +from grassmann_tensor import NamedGrassmannTensor + + +@pytest.mark.parametrize("rename_range", [(i, j) for i in range(5) for j in range(5) if j >= i]) +def test_rename(rename_range: tuple[int, int]) -> None: + l, h = rename_range # noqa: E741 + edge = (2, 2) + a = NamedGrassmannTensor( + tuple(f"o{i}" for i in range(5)), + tuple([False] * 5), + tuple(edge for _ in range(5)), + torch.randn(*[4] * 5), + ) + new_names = tuple(f"n{i}" for i in range(h - l)) + name_map: dict[str, str] = {old: new for old, new in zip(a.names[l:h], new_names)} + b = a.rename(name_map) + assert b.names == a.names[:l] + new_names + a.names[h:] + + +def test_rename_duplicate_names() -> None: + edge = (2, 2) + a = NamedGrassmannTensor( + tuple(f"o{i}" for i in range(5)), + tuple([False] * 5), + tuple(edge for _ in range(5)), + torch.randn(*[4] * 5), + ) + new_names = tuple("n" for i in range(5)) + name_map: dict[str, str] = {old: new for old, new in zip(a.names, new_names)} + with pytest.raises(ValueError, match="Duplicate name"): + _ = a.rename(name_map) + + +def test_get_name_index() -> None: + edge = (2, 2) + a = NamedGrassmannTensor( + tuple(f"o{i}" for i in range(5)), + tuple([False] * 5), + tuple(edge for _ in range(5)), + torch.randn(*[4] * 5), + ) + with pytest.raises(KeyError, match="not in names"): + _ = a.get_name_index("a") diff --git a/tests/permute_test.py b/tests/permute_test.py index 12f6949..3f30126 100644 --- a/tests/permute_test.py +++ b/tests/permute_test.py @@ -1,6 +1,6 @@ import pytest import torch -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor PermuteCase = tuple[ tuple[bool, ...], tuple[tuple[int, int], ...], torch.Tensor, tuple[int, ...], torch.Tensor @@ -140,3 +140,194 @@ def test_permute_high_order() -> None: assert b.tensor[l, j, i, n, k, m] == -a.tensor[i, j, k, l, m, n] else: assert b.tensor[l, j, i, n, k, m] == a.tensor[i, j, k, l, m, n] + + +NamedPermuteCase = tuple[ + tuple[str, ...], + tuple[bool, ...], + tuple[tuple[int, int], ...], + torch.Tensor, + tuple[str, ...], + torch.Tensor, +] + + +@pytest.mark.parametrize( + "x", + [ + ((), (), (), torch.tensor(6), (), torch.tensor(6)), + (("a",), (False,), ((1, 1),), torch.tensor([1, 2]), ("a",), torch.tensor([1, 2])), + ( + ("a", "b"), + (False, True), + ((1, 1), (0, 0)), + torch.zeros([2, 0]), + ("b", "a"), + torch.zeros([0, 2]), + ), + ( + ("a", "b"), + (False, True), + ((1, 1), (0, 1)), + torch.tensor([[0], [4]]), + ("b", "a"), + torch.tensor([[0, -4]]), + ), + ( + ("a", "b"), + (False, False), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + ("a", "b"), + torch.tensor([[1, 0], [0, 4]]), + ), + ( + ("a", "b"), + (True, False), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + ("b", "a"), + torch.tensor([[1, 0], [0, -4]]), + ), + ( + ("a", "b", "c"), + (False, True, True), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ("a", "c", "b"), + torch.tensor([[[1, 0], [0, -2]], [[0, 4], [3, 0]]]), + ), + ( + ("a", "b", "c"), + (True, True, True), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ("b", "a", "c"), + torch.tensor([[[1, 0], [0, 3]], [[0, 2], [-4, 0]]]), + ), + ( + ("a", "b", "c"), + (True, False, False), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ("c", "b", "a"), + torch.tensor([[[1, 0], [0, -4]], [[0, -3], [-2, 0]]]), + ), + ( + ("a", "b", "c"), + (False, False, False), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ("c", "a", "b"), + torch.tensor([[[1, 0], [0, 4]], [[0, -2], [-3, 0]]]), + ), + ], +) +def test_named_tensor_permute(x: NamedPermuteCase) -> None: + names, arrow, edges, tensor, before_by_after, expected = x + grassmann_tensor = NamedGrassmannTensor(names, arrow, edges, tensor) + result = grassmann_tensor.permute(before_by_after) + assert torch.allclose(result.tensor, expected) + + +NamedPermuteFailCase = tuple[ + tuple[str, ...], + tuple[bool, ...], + tuple[tuple[int, int], ...], + torch.Tensor, + tuple[str, ...], + str, +] + + +@pytest.mark.parametrize( + "x", + [ + ( + ("a", "b"), + (False, False), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + ("a", "a"), + "Permutation indices must be unique", + ), + ( + ("a", "b"), + (False, False), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + ("a",), + "Permutation indices must cover all dimensions", + ), + ( + ("a", "b"), + (False, False), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + ("a", "a", "b"), + "Permutation indices must be unique", + ), + ], +) +def test_named_tensor_permute_fail(x: NamedPermuteFailCase) -> None: + names, arrow, edges, tensor, before_by_after, message = x + grassmann_tensor = NamedGrassmannTensor(names, arrow, edges, tensor) + with pytest.raises(AssertionError, match=message): + grassmann_tensor.permute(before_by_after) + + +def test_named_tensor_permute_high_order() -> None: + edge = (2, 2) + a = NamedGrassmannTensor( + ("i", "j", "k", "l", "m", "n"), + (False, False, False, False, False, False), + (edge, edge, edge, edge, edge, edge), + torch.randn(4, 4, 4, 4, 4, 4), + ).update_mask() + # a[i, j, k, l, m, n] -> b[l, j, i, n, k, m] + b = a.permute(("l", "j", "i", "n", "k", "m")) + for i in range(4): + for j in range(4): + for k in range(4): + for l in range(4): # noqa: E741 + for m in range(4): + for n in range(4): + p = [bool(x & 2) for x in (i, j, k, l, m, n)] + if sum(p) % 2 != 0: + continue + # i j k l m n + # (l) (i j k) m n + # l (j) (i) k m n + # l j i (n) (k m) + sign = ( + (p[3] & (p[0] ^ p[1] ^ p[2])) + ^ (p[1] & p[0]) + ^ (p[5] & (p[2] ^ p[4])) + ) + if sign: + assert b.tensor[l, j, i, n, k, m] == -a.tensor[i, j, k, l, m, n] + else: + assert b.tensor[l, j, i, n, k, m] == a.tensor[i, j, k, l, m, n] + + +def test_mask_cache_should_not_survive_permute() -> None: + a = NamedGrassmannTensor( + ("a", "b", "c"), + (False, False, False), + ((2, 2), (1, 0), (1, 0)), + torch.randn(4, 1, 1, dtype=torch.float64), + ) + + _ = a.mask + + b = a.permute(("c", "a", "b")) + + gt_fresh = GrassmannTensor( + _arrow=b.arrow, + _edges=b.edges, + _tensor=b.tensor, + _parity=None, + _mask=None, + ) + + assert torch.equal(b.mask, gt_fresh.mask) diff --git a/tests/reciprocal_test.py b/tests/reciprocal_test.py new file mode 100644 index 0000000..bda4b01 --- /dev/null +++ b/tests/reciprocal_test.py @@ -0,0 +1,45 @@ +import torch +import pytest + +from grassmann_tensor import NamedGrassmannTensor + + +@pytest.mark.parametrize( + "x", + [ + NamedGrassmannTensor( + ("a", "b"), (True, True), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64) + ).update_mask(), + NamedGrassmannTensor( + ("a", "b"), (False, False), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64) + ).update_mask(), + NamedGrassmannTensor( + ("a", "b", "c"), + (True, True, True), + ((2, 2), (4, 4), (8, 8)), + torch.randn(4, 8, 16, dtype=torch.float64), + ).update_mask(), + NamedGrassmannTensor( + ("a", "b", "c", "d"), + (True, True, True, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ).update_mask(), + ], +) +def test_reciprocal(x: NamedGrassmannTensor) -> None: + tensor = x.reciprocal() + assert tensor.names == x.names + assert tensor.arrow == x.arrow + assert tensor.edges == x.edges + assert tensor.tensor.shape == x.tensor.shape + assert tensor.tensor.dtype == x.tensor.dtype + assert tensor.tensor.device == x.tensor.device + + assert not torch.isinf(tensor.tensor).any() + + zero = x.tensor == 0 + assert torch.equal(tensor.tensor[zero], x.tensor[zero]) + + non_zero = ~zero + assert torch.allclose(tensor.tensor[non_zero], (1 / x.tensor[non_zero])) diff --git a/tests/reshape_test.py b/tests/reshape_test.py index 77f1823..8a8ee34 100644 --- a/tests/reshape_test.py +++ b/tests/reshape_test.py @@ -1,7 +1,7 @@ import random import pytest import torch -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor @pytest.mark.parametrize( @@ -180,10 +180,8 @@ def test_reshape_equal_edges_nontrivial_merging_with_other_edge() -> None: def test_reshape_with_none() -> None: a = GrassmannTensor((), (), torch.tensor(2333)).reshape(((1, 0), (1, 0))).reshape(()) assert len(a.arrow) == 0 and len(a.edges) == 0 and a.tensor.dim() == 0 - b = GrassmannTensor((), (), torch.tensor(2333)).reshape(((1, 0), (1, 0))).reshape(()) - assert len(b.arrow) == 0 and len(b.edges) == 0 and b.tensor.dim() == 0 - c = GrassmannTensor((), (), torch.tensor(2333)).reshape((1, 1)) - assert len(c.arrow) == 2 and len(c.edges) == 2 and c.tensor.dim() == 2 + b = GrassmannTensor((), (), torch.tensor(2333)).reshape((1, 1)) + assert len(b.arrow) == 2 and len(b.edges) == 2 and b.tensor.dim() == 2 def test_reshape_with_none_edge_assertion() -> None: @@ -289,3 +287,183 @@ def test_reshape_plan_exhausted_then_skip_trivial_self_edges() -> None: assert out.edges == ((2, 2),) assert out.tensor.shape == (4,) assert out.arrow == (False,) + + +@pytest.mark.parametrize( + "arrow", + [ + (i, j, k, l, m) + for i in [False, True] + for j in [False, True] + for k in [False, True] + for l in [False, True] # noqa: E741 + for m in [False, True] + ], +) +@pytest.mark.parametrize("plan_range", [(i, j) for i in range(5) for j in range(5) if j > i]) +def test_named_tensor_reshape_consistency( + arrow: tuple[bool, ...], plan_range: tuple[int, int] +) -> None: + names = tuple(chr(ord("a") + i) for i in range(5)) + l, h = plan_range # noqa: E741 + if not all(arrow[l:h]) and any(arrow[l:h]): + pytest.skip("Invalid reshape plan for the given arrow configuration.") + edge = (2, 2) + a = NamedGrassmannTensor( + names, arrow, (edge, edge, edge, edge, edge), torch.randn([4, 4, 4, 4, 4]) + ) + merged_name = f"m{l}_{h}" + merge_map: dict[str, tuple[str, ...]] = {merged_name: names[l:h]} + + split_map: dict[str, tuple[tuple[str, tuple[int, int]], ...]] = { + merged_name: tuple((n, edge) for n in names[l:h]) + } + b = a.merge_edge(merge_map) + c = b.split_edge(split_map) + assert torch.allclose(a.tensor, c.tensor) + + +def test_named_tensor_merging_mixed_arrows() -> None: + names = ("a", "b", "c") + arrow = (True, False, True) + edges = ((2, 2), (2, 2), (2, 2)) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4, 4, 4])) + with pytest.raises(AssertionError, match="Cannot merge edges with different arrows"): + _ = a.merge_edge({"a": ("a", "b", "c")}) + + +def test_named_tensor_splitting_dimension_mismatch_edges_because_of_unequal() -> None: + names = ("a",) + arrow = (True,) + edges = ((8, 8),) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([16])) + _ = a.split_edge({"a": (("a", (2, 2)), ("b", (2, 2)))}) + with pytest.raises(AssertionError, match="Dimension mismatch in splitting"): + _ = a.split_edge({"a": (("a", (4, 4)), ("b", (2, 2)))}) + + +def test_named_tensor_splitting_dimension_mismatch_edges_because_of_different_even_odd() -> None: + names = ("a", "b") + arrow = (True, True) + edges = ((3, 1), (2, 2)) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4, 4])) + _ = a.split_edge({"a": (("a", (0, 1)), ("b", (3, 1)), ("c", (0, 1))), "b": (("d", (2, 2)),)}) + with pytest.raises(AssertionError, match="Dimension mismatch in splitting"): + _ = a.split_edge({"a": (("a", (0, 1)), ("b", (2, 2))), "b": (("c", (0, 1)), ("d", (2, 2)))}) + with pytest.raises(AssertionError, match="Dimension mismatch in splitting"): + _ = a.split_edge({"a": (("a", (0, 1)), ("b", (3, 1))), "b": (("c", (2, 2)),)}) + + +def test_named_tensor_splitting_shape_exceeds() -> None: + names = ("a",) + arrow = (False,) + edges = ((8, 8),) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([16])) + with pytest.raises(AssertionError, match="New shape exceeds in splitting"): + _ = a.split_edge({"a": (("a", (1, 1)), ("b", (1, 1)))}) + + +def test_named_tensor_equal_edges_trivial() -> None: + names = ("a",) + arrow = (True,) + edges = ((2, 2),) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4])) + _ = a.split_edge({"a": (("a", (2, 2)),)}) + + +def test_named_tensor_equal_edges_nontrivial_splitting() -> None: + names = ("a",) + arrow = (True,) + edges = ((1, 3),) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4])) + _ = a.split_edge({"a": (("a", (3, 1)), ("b", (1, 0)), ("c", (0, 1)))}) + + +def test_named_tensor_equal_edges_nontrivial_splitting_with_other_edge() -> None: + names = ("a", "b") + arrow = (True, True) + edges = ((1, 3), (2, 2)) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4, 4])) + _ = a.split_edge({"a": (("a", (3, 1)), ("b", (1, 0)), ("c", (0, 1))), "b": (("d", (2, 2)),)}) + + +def test_named_tensor_equal_edges_nontrivial_merging() -> None: + names = ("a", "b", "c") + arrow = (True, True, True) + edges = ((1, 3), (1, 0), (0, 1)) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4, 1, 1])) + _ = a.merge_edge({"a": ("a", "b", "c")}) + + +def test_named_tensor_equal_edges_nontrivial_merging_with_other_edge() -> None: + names = ("a", "b", "c", "d") + arrow = (True, True, True, True) + edges = ((1, 3), (1, 0), (0, 1), (2, 2)) + a = NamedGrassmannTensor(names, arrow, edges, torch.randn([4, 1, 1, 4])) + _ = a.merge_edge({"a": ("a", "b", "c"), "b": ("d",)}) + + +def test_named_tensor_with_none() -> None: + a = ( + NamedGrassmannTensor((), (), (), torch.tensor(2333)) + .split_edge({"": (("a", (1, 0)), ("b", (1, 0)))}) + .to_scalar() + ) + assert len(a.arrow) == 0 and len(a.edges) == 0 and a.tensor.dim() == 0 + b = NamedGrassmannTensor((), (), (), torch.tensor(2333)).split_edge( + {"": (("a", (1, 0)), ("b", (1, 0)))} + ) + assert len(b.arrow) == 2 and len(b.edges) == 2 and b.tensor.dim() == 2 + + +def test_named_tensor_with_none_edge_assertion() -> None: + with pytest.raises(AssertionError, match="Only pure even edges can be merged into none edges"): + _ = NamedGrassmannTensor( + ("a", "b"), (True, True), ((0, 1), (1, 0)), torch.tensor([[2333]]) + ).to_scalar() + with pytest.raises(AssertionError, match="Cannot split none edges into illegal edges"): + _ = NamedGrassmannTensor((), (), (), torch.tensor(2333)).split_edge({"": (("a", (0, 1)),)}) + with pytest.raises(AssertionError, match="Cannot split none edges into illegal edges"): + _ = NamedGrassmannTensor((), (), (), torch.tensor(2333)).split_edge( + {"": (("a", (0, 1)), ("b", (1, 0)))} + ) + + +def test_named_tensor_plan_exhausted_then_skip_trivial_self_edges() -> None: + a = NamedGrassmannTensor( + ("a", "b", "c"), + (False, False, False), + ((2, 2), (1, 0), (1, 0)), + torch.randn(4, 1, 1), + ) + out = a.merge_edge({"a": ("a", "b", "c")}) + assert out.names == ("a",) + assert out.edges == ((2, 2),) + assert out.tensor.shape == (4,) + assert out.arrow == (False,) + + +def test_named_tensor_permute() -> None: + a = NamedGrassmannTensor( + ("a", "b", "c"), + (False, False, False), + ((2, 2), (1, 0), (1, 0)), + torch.randn(4, 1, 1), + ) + out = a.merge_edge({"a": ("a", "c")}) + assert out.names == ("a", "b") + assert out.edges == ((2, 2), (1, 0)) + assert out.tensor.shape == (4, 1) + assert out.arrow == (False, False) + + out = a.merge_edge({"a": ("c", "a")}) + assert out.names == ("a", "b") + assert out.edges == ((2, 2), (1, 0)) + assert out.tensor.shape == (4, 1) + assert out.arrow == (False, False) + + out = a.merge_edge({"b": ("b", "c")}) + assert out.names == ("a", "b") + assert out.edges == ((2, 2), (1, 0)) + assert out.tensor.shape == (4, 1) + assert out.arrow == (False, False) diff --git a/tests/reverse_test.py b/tests/reverse_test.py index ead3806..419f570 100644 --- a/tests/reverse_test.py +++ b/tests/reverse_test.py @@ -1,6 +1,6 @@ import pytest import torch -from grassmann_tensor.tensor import GrassmannTensor +from grassmann_tensor.tensor import GrassmannTensor, NamedGrassmannTensor ReverseCase = tuple[ tuple[bool, ...], tuple[tuple[int, int], ...], torch.Tensor, tuple[int, ...], torch.Tensor @@ -127,3 +127,147 @@ def test_reverse_fail(x: ReverseFailCase) -> None: grassmann_tensor = GrassmannTensor(arrow, edges, tensor) with pytest.raises(AssertionError, match=message): grassmann_tensor.reverse(reverse_by) + + +NamedReverseCase = tuple[ + tuple[str, ...], + tuple[bool, ...], + tuple[tuple[int, int], ...], + torch.Tensor, + set[str], + torch.Tensor, +] + + +@pytest.mark.parametrize( + "x", + [ + ((), (), (), torch.tensor(6), (), torch.tensor(6)), + ( + ("a", "b"), + (False, False), + ((1, 1), (0, 0)), + torch.zeros([2, 0]), + ("a",), + torch.zeros([2, 0]), + ), + ( + ("a", "b"), + (False, False), + ((1, 1), (0, 1)), + torch.tensor([[0], [4]]), + ("a",), + torch.tensor([[0], [4]]), + ), + ( + ("a", "b"), + (False, False), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + (), + torch.tensor([[1, 0], [0, 4]]), + ), + ( + ("a", "b"), + (False, False), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + ("a",), + torch.tensor([[1, 0], [0, 4]]), + ), + ( + ("a", "b"), + (True, False), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + ("a",), + torch.tensor([[1, 0], [0, -4]]), + ), + ( + ("a", "b"), + (False, False), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + ("b",), + torch.tensor([[1, 0], [0, 4]]), + ), + ( + ("a", "b"), + (False, True), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + ("b",), + torch.tensor([[1, 0], [0, -4]]), + ), + ( + ("a", "b", "c"), + (False, False, False), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ("a",), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ), + ( + ("a", "b", "c"), + (True, False, False), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ("a",), + torch.tensor([[[1, 0], [0, 2]], [[0, -3], [-4, 0]]]), + ), + ( + ("a", "b", "c"), + (False, True, True), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ("c",), + torch.tensor([[[1, 0], [0, -2]], [[0, -3], [4, 0]]]), + ), + ( + ("a", "b", "c"), + (True, False, True), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ("a", "b"), + torch.tensor([[[1, 0], [0, 2]], [[0, -3], [-4, 0]]]), + ), + ( + ("a", "b", "c"), + (True, True, True), + ((1, 1), (1, 1), (1, 1)), + torch.tensor([[[1, 0], [0, 2]], [[0, 3], [4, 0]]]), + ("a", "b"), + torch.tensor([[[1, 0], [0, -2]], [[0, -3], [4, 0]]]), + ), + ], +) +def test_named_tensor_reverse(x: NamedReverseCase) -> None: + names, arrow, edges, tensor, reverse_by, expected = x + grassmann_tensor = NamedGrassmannTensor(names, arrow, edges, tensor) + result = grassmann_tensor.reverse(reverse_by) + assert torch.allclose(result.tensor, expected) + + +NamedReverseFailCase = tuple[ + tuple[str, ...], tuple[bool, ...], tuple[tuple[int, int], ...], torch.Tensor, set[str], str +] + + +@pytest.mark.parametrize( + "x", + [ + ( + ("a", "b"), + (False, False), + ((1, 1), (1, 1)), + torch.tensor([[1, 0], [0, 4]]), + ("a", "a"), + "Indices must be unique", + ), + ], +) +def test_named_tensor_reverse_fail(x: NamedReverseFailCase) -> None: + names, arrow, edges, tensor, reverse_by, message = x + grassmann_tensor = NamedGrassmannTensor(names, arrow, edges, tensor) + with pytest.raises(AssertionError, match=message): + grassmann_tensor.reverse(reverse_by) diff --git a/tests/svd_test.py b/tests/svd_test.py index 01b57ec..98ba506 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -4,14 +4,16 @@ import itertools from typing import TypeAlias, Iterable, Any -from grassmann_tensor import GrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor +Names: TypeAlias = tuple[str, ...] Arrow: TypeAlias = tuple[bool, ...] Edges: TypeAlias = tuple[tuple[int, int], ...] Tensor: TypeAlias = torch.Tensor Cutoff: TypeAlias = int | tuple[int, int] | None Tau: TypeAlias = float FreeNamesU: TypeAlias = tuple[int, ...] +NamedFreeNamesU: TypeAlias = set[str] SVDCases = Iterable[ParameterSet] @@ -169,8 +171,6 @@ def test_svd_with_incompatible_cutoff( @pytest.mark.parametrize("a,b", [(3, 5), (1, 1), (8, 2)]) def test_svd_both_blocks_empty_raises_with_int_cutoff(a: int, b: int) -> None: - # edges: left=(even_left=0, odd_left=a), right=(even_right=b, odd_right=0) - # tensor shape must be (a, b) arrow = (True, True) edges = ((0, a), (b, 0)) tensor = torch.randn(a, b, dtype=torch.float64) @@ -196,30 +196,30 @@ def test_svd_both_blocks_empty_raises_with_tuple_cutoff(a: int, b: int) -> None: @pytest.mark.parametrize( - "a,b,c,k", + "a, b, k", [ - (3, 5, 7, 2), - (4, 1, 2, 3), + (3, 5, 2), + (4, 1, 3), ], ) -def test_svd_int_cutoff_even_block_empty_select_from_odd_only( - a: int, b: int, c: int, k: int -) -> None: +def test_svd_int_cutoff_even_block_empty_select_from_odd_only(a: int, b: int, k: int) -> None: arrow = (True, True) - edges = ((0, a), (b, c)) - tensor = torch.randn(a, b + c, dtype=torch.float64) + edges = ((0, a), (0, b)) + tensor = torch.randn(a, b, dtype=torch.complex128) - gt = GrassmannTensor(arrow, edges, tensor) - U, S, Vh = gt.svd((0,), cutoff=k) + pure_odd_tensor = GrassmannTensor(arrow, edges, tensor).update_mask() + U, S, Vh = pure_odd_tensor.svd((0,), cutoff=k) - expected_k = min(k, min(a, c)) + expected_k = min(k, min(a, b)) + assert U.edges[0] == (0, a) assert U.edges[-1] == (0, expected_k) assert Vh.edges[0] == (0, expected_k) + assert Vh.edges[1] == (0, b) assert S.edges == ((0, expected_k), (0, expected_k)) @pytest.mark.parametrize( - "a,b,k", + "a, b, k", [ (5, 4, 2), (7, 3, 5), @@ -228,14 +228,16 @@ def test_svd_int_cutoff_even_block_empty_select_from_odd_only( def test_svd_int_cutoff_odd_block_empty_select_from_even_only(a: int, b: int, k: int) -> None: arrow = (True, True) edges = ((a, 0), (b, 0)) - tensor = torch.randn(a, b, dtype=torch.float64) + tensor = torch.randn(a, b, dtype=torch.complex128) gt = GrassmannTensor(arrow, edges, tensor) U, S, Vh = gt.svd((0,), cutoff=k) expected_k = min(k, min(a, b)) + assert U.edges[0] == (a, 0) assert U.edges[-1] == (expected_k, 0) assert Vh.edges[0] == (expected_k, 0) + assert Vh.edges[1] == (b, 0) assert S.edges == ((expected_k, 0), (expected_k, 0)) @@ -265,3 +267,254 @@ def test_svd_dtype_device_continuity(dtype: torch.dtype, device: torch.device) - assert u.tensor.device == device assert s.tensor.device == device assert vh.tensor.device == device + + +BASE_NAMED_GT_CASES: list[tuple[Names, Arrow, Edges, Tensor]] = [ + (("a", "b"), (True, True), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64)), + (("a", "b"), (False, False), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64)), + ( + ("a", "b", "c"), + (True, True, True), + ((2, 2), (4, 4), (8, 8)), + torch.randn(4, 8, 16, dtype=torch.float64), + ), + ( + ("a", "b", "c", "d"), + (True, True, True, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ), +] + + +def named_svd_cases() -> SVDCases: + params = [] + for names, arrow, edges, tensor in BASE_NAMED_GT_CASES: + for fnu in choose_free_names(len(edges)): + even_singular, odd_singular = get_total_singular(edges, fnu) + max_singular = max(even_singular, odd_singular) + total = even_singular + odd_singular + cutoff_list = [ + None, + max_singular, + max_singular - 1, + (even_singular, odd_singular), + ] + for cutoff in cutoff_list: + if cutoff is None: + kept = total + elif isinstance(cutoff, int): + k = cutoff + kept = min(k, even_singular) + min(k, odd_singular) + else: + ke = min(int(cutoff[0]), even_singular) + ko = min(int(cutoff[1]), odd_singular) + kept = ke + ko + tau = tau_for_cutoff(kept, total) + named_fnu = tuple(names[i] for i in fnu) + params.append( + pytest.param( + names, + arrow, + edges, + tensor, + cutoff, + tau, + named_fnu, + id=f"edges={tuple(edges)}|fnu={named_fnu}|cut={cutoff}|tau={tau:.2e}", + ) + ) + return params + + +@pytest.mark.parametrize( + "names, arrow, edges, tensor, cutoff, tau, free_names_u", + named_svd_cases(), +) +@pytest.mark.repeat(20) +def test_named_svd( + names: Names, + arrow: Arrow, + edges: Edges, + tensor: Tensor, + cutoff: Cutoff, + tau: Tau, + free_names_u: NamedFreeNamesU, +) -> None: + gt = NamedGrassmannTensor(names, arrow, edges, tensor).update_mask() + U, S, Vh = gt.svd(free_names_u, "right", "left", "left", "right", cutoff=cutoff) + + US = U.contract(S, {("right", "left")}) + USV = US.contract(Vh, {("right", "left")}) + + USV = USV.permute(names) + + masked = gt.update_mask().tensor + den = masked.norm() + eps = torch.finfo(masked.dtype).eps + rel_err = (masked - USV.tensor).norm() / max(den, eps) + assert rel_err <= tau + + +@pytest.mark.parametrize( + "names, arrow, edges, tensor, cutoff , tau, free_names_u", + named_svd_cases(), +) +@pytest.mark.parametrize( + "incompatible_cutoff", + [ + -1, + 0, + ( + 1, + 2, + 3, + ), + "string", + {"key", "value"}, + [1, 2, 3], + {1, 2}, + object(), + ], +) +def test_named_svd_with_incompatible_cutoff( + names: Names, + arrow: Arrow, + edges: Edges, + tensor: Tensor, + cutoff: Cutoff, + tau: Tau, + free_names_u: NamedFreeNamesU, + incompatible_cutoff: Any, +) -> None: + gt = NamedGrassmannTensor(names, arrow, edges, tensor) + if isinstance(incompatible_cutoff, int): + with pytest.raises(AssertionError, match="Cutoff must be greater than 0"): + _, _, _ = gt.svd( + free_names_u, "right", "left", "left", "right", cutoff=incompatible_cutoff + ) + elif isinstance(incompatible_cutoff, tuple): + with pytest.raises( + AssertionError, match="The length of cutoff must be 2 if cutoff is a tuple" + ): + _, _, _ = gt.svd( + free_names_u, "right", "left", "left", "right", cutoff=incompatible_cutoff + ) + else: + with pytest.raises( + ValueError, match="Cutoff must be an integer or a tuple of two integers" + ): + _, _, _ = gt.svd( + free_names_u, "right", "left", "left", "right", cutoff=incompatible_cutoff + ) + + +@pytest.mark.parametrize("a,b", [(3, 5), (1, 1), (8, 2)]) +def test_named_svd_both_blocks_empty_raises_with_int_cutoff(a: int, b: int) -> None: + names = ("a", "b") + arrow = (True, True) + edges = ((0, a), (b, 0)) + tensor = torch.randn(a, b, dtype=torch.float64) + + gt = NamedGrassmannTensor(names, arrow, edges, tensor) + + free_names_u = {"a"} + with pytest.raises(RuntimeError, match="Both parity block are empty. Can not form SVD."): + _ = gt.svd(free_names_u, "right", "left", "left", "right", cutoff=1) + + +@pytest.mark.parametrize("a,b", [(3, 5), (2, 4), (7, 3)]) +def test_named_svd_both_blocks_empty_raises_with_tuple_cutoff(a: int, b: int) -> None: + names = ("a", "b") + arrow = (True, True) + edges = ((0, a), (b, 0)) + tensor = torch.randn(a, b, dtype=torch.float64) + + gt = NamedGrassmannTensor(names, arrow, edges, tensor) + + free_names_u = {"a"} + with pytest.raises(RuntimeError, match="Both parity block are empty. Can not form SVD."): + _ = gt.svd(free_names_u, "right", "left", "left", "right", cutoff=(1, 1)) + + +@pytest.mark.parametrize( + "a, b, k", + [ + (3, 5, 2), + (4, 1, 3), + ], +) +def test_named_svd_int_cutoff_even_block_empty_select_from_odd_only(a: int, b: int, k: int) -> None: + names = ("a", "b") + arrow = (True, True) + edges = ((0, a), (0, b)) + tensor = torch.randn(a, b, dtype=torch.complex128) + + gt = NamedGrassmannTensor(names, arrow, edges, tensor) + U, S, Vh = gt.svd({"a"}, "right", "left", "left", "right", cutoff=k) + + expected_k = min(k, min(a, b)) + assert U.names == ("a", "right") + assert U.edges[0] == (0, a) + assert U.edges[-1] == (0, expected_k) + assert Vh.names == ("left", "b") + assert Vh.edges[0] == (0, expected_k) + assert Vh.edges[1] == (0, b) + assert S.names == ("left", "right") + assert S.edges == ((0, expected_k), (0, expected_k)) + + +@pytest.mark.parametrize( + "a,b,k", + [ + (5, 4, 2), + (7, 3, 5), + ], +) +def test_named_svd_int_cutoff_odd_block_empty_select_from_even_only(a: int, b: int, k: int) -> None: + names = ("a", "b") + arrow = (True, True) + edges = ((a, 0), (b, 0)) + tensor = torch.randn(a, b, dtype=torch.complex128) + + gt = NamedGrassmannTensor(names, arrow, edges, tensor) + U, S, Vh = gt.svd({"a"}, "right", "left", "left", "right", cutoff=k) + + expected_k = min(k, min(a, b)) + assert U.names == ("a", "right") + assert U.edges[0] == (a, 0) + assert U.edges[-1] == (expected_k, 0) + assert Vh.names == ("left", "b") + assert Vh.edges[0] == (expected_k, 0) + assert Vh.edges[1] == (b, 0) + assert S.names == ("left", "right") + assert S.edges == ((expected_k, 0), (expected_k, 0)) + + +devices = [torch.device("cpu")] +if torch.cuda.is_available(): + devices.append(torch.device("cuda:0")) + + +@pytest.mark.parametrize( + "dtype", + [ + torch.float64, + torch.complex128, + ], +) +@pytest.mark.parametrize("device", devices) +def test_named_svd_dtype_device_continuity(dtype: torch.dtype, device: torch.device) -> None: + a = NamedGrassmannTensor( + ("a", "b", "c", "d"), + (True, True, True, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=dtype, device=device), + ) + u, s, vh = a.svd({"a"}, "right", "left", "left", "right", cutoff=1) + assert u.tensor.dtype == dtype + assert s.tensor.dtype == dtype + assert vh.tensor.dtype == dtype + assert u.tensor.device == device + assert s.tensor.device == device + assert vh.tensor.device == device