From fcbab58061da0718ee089708f576d375c54e3d33 Mon Sep 17 00:00:00 2001 From: Gausshj Date: Tue, 6 Jan 2026 19:28:40 +0800 Subject: [PATCH] feat(tensor): add support for sqrt and norm - Add support for sqrt and norm - Add some test cases for sqrt and norm --- grassmann_tensor/tensor.py | 51 +++++++--------- tests/reciprocal_test.py | 42 ++++++++++++- tests/utility_test.py | 117 +++++++++++++++++++++++++++++++++++++ 3 files changed, 179 insertions(+), 31 deletions(-) create mode 100644 tests/utility_test.py diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 944a31a..9fdb6bb 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -841,35 +841,8 @@ def contract( order_b = left_leg_b + right_leg_b # 1. Permutation - arrow_a = tuple(a.arrow[i] for i in order_a) - edges_a = tuple(a.edges[i] for i in order_a) - tensor_a = a.tensor.permute(order_a) - parity_a = tuple(a.parity[i] for i in order_a) - mask_a = a.mask.permute(order_a) - - a = dataclasses.replace( - a, - _arrow=arrow_a, - _edges=edges_a, - _tensor=tensor_a, - _parity=parity_a, - _mask=mask_a, - ) - - arrow_b = tuple(b.arrow[i] for i in order_b) - edges_b = tuple(b.edges[i] for i in order_b) - tensor_b = b.tensor.permute(order_b) - parity_b = tuple(b.parity[i] for i in order_b) - mask_b = b.mask.permute(order_b) - - b = dataclasses.replace( - b, - _arrow=arrow_b, - _edges=edges_b, - _tensor=tensor_b, - _parity=parity_b, - _mask=mask_b, - ) + a = a.permute(order_a) + b = b.permute(order_b) arrow = a.arrow[:-contract_length_a] + b.arrow[contract_length_b:] edges = a.edges[:-contract_length_a] + b.edges[contract_length_b:] @@ -1125,6 +1098,17 @@ def _tensor_mask(self) -> torch.Tensor: torch.zeros_like(self._tensor, dtype=torch.bool), ) + def reciprocal(self) -> GrassmannTensor: + return dataclasses.replace( + self, _tensor=torch.where(self.tensor == 0, self.tensor, 1 / self.tensor) + ) + + def norm(self, p: typing.Any) -> float: + return float(torch.linalg.vector_norm(self.tensor.masked_select(~self.mask), ord=p)) + + def sqrt(self) -> GrassmannTensor: + return dataclasses.replace(self, _tensor=torch.sqrt(torch.abs(self.tensor))) + def _validate_edge_compatibility(self, other: GrassmannTensor) -> None: """ Validate that the edges of two ParityTensor instances are compatible for arithmetic operations. @@ -1857,6 +1841,15 @@ def reciprocal(self) -> NamedGrassmannTensor: self, _tensor=torch.where(self.tensor == 0, self.tensor, 1 / self.tensor) ) + def norm(self, p: typing.Any) -> float: + return float(torch.linalg.vector_norm(self.tensor.reshape(-1), ord=p)) + + def sqrt(self) -> NamedGrassmannTensor: + return dataclasses.replace(self, _tensor=torch.sqrt(torch.abs(self.tensor))) + + def rank(self) -> int: + return len(self.names) + def _validate_edge_compatibility(self, other: NamedGrassmannTensor) -> None: assert self._names == other.names, ( f"Names must match for arithmetic operations. Got {self._names} and {other.names}." diff --git a/tests/reciprocal_test.py b/tests/reciprocal_test.py index bda4b01..eb3e188 100644 --- a/tests/reciprocal_test.py +++ b/tests/reciprocal_test.py @@ -1,7 +1,45 @@ import torch import pytest -from grassmann_tensor import NamedGrassmannTensor +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor + + +@pytest.mark.parametrize( + "x", + [ + GrassmannTensor( + (True, True), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64) + ).update_mask(), + GrassmannTensor( + (False, False), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64) + ).update_mask(), + GrassmannTensor( + (True, True, True), + ((2, 2), (4, 4), (8, 8)), + torch.randn(4, 8, 16, dtype=torch.float64), + ).update_mask(), + GrassmannTensor( + (True, True, True, True), + ((2, 2), (4, 4), (8, 8), (16, 16)), + torch.randn(4, 8, 16, 32, dtype=torch.float64), + ).update_mask(), + ], +) +def test_reciprocal(x: NamedGrassmannTensor) -> None: + tensor = x.reciprocal() + assert tensor.arrow == x.arrow + assert tensor.edges == x.edges + assert tensor.tensor.shape == x.tensor.shape + assert tensor.tensor.dtype == x.tensor.dtype + assert tensor.tensor.device == x.tensor.device + + assert not torch.isinf(tensor.tensor).any() + + zero = x.tensor == 0 + assert torch.equal(tensor.tensor[zero], x.tensor[zero]) + + non_zero = ~zero + assert torch.allclose(tensor.tensor[non_zero], (1 / x.tensor[non_zero])) @pytest.mark.parametrize( @@ -27,7 +65,7 @@ ).update_mask(), ], ) -def test_reciprocal(x: NamedGrassmannTensor) -> None: +def test_named_reciprocal(x: NamedGrassmannTensor) -> None: tensor = x.reciprocal() assert tensor.names == x.names assert tensor.arrow == x.arrow diff --git a/tests/utility_test.py b/tests/utility_test.py new file mode 100644 index 0000000..cfa32ad --- /dev/null +++ b/tests/utility_test.py @@ -0,0 +1,117 @@ +import torch +import pytest +import itertools + +from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor + + +def generate_filled_data( + edges: tuple[tuple[int, int], ...], data: torch.Tensor | None = None +) -> torch.Tensor: + shape = tuple(even + odd for even, odd in edges) + + if data is None: + tensor = torch.zeros(shape) + filled = torch.arange(filled_count(edges)) + else: + assert data is not None + filled = data.reshape(-1) + tensor = torch.zeros(shape, dtype=data.dtype, device=data.device) + i = 0 + ranges = [range(s) for s in shape] + + for idx in itertools.product(*ranges): + total_parity = 0 + for k, (even, _) in enumerate(edges): + total_parity ^= 1 if idx[k] >= even else 0 + if total_parity == 0: + tensor[idx] = filled[i] + i += 1 + return tensor + + +def filled_count(edges: tuple[tuple[int, int], ...]) -> int: + total = 1 + diff = 1 + for even, odd in edges: + total *= even + odd + diff *= even - odd + return (total + diff) // 2 + + +@pytest.mark.parametrize( + "edges", + [ + ((1, 1),), + ((2, 2),), + ((2, 2), (2, 2)), + ((2, 2), (2, 2), (2, 2)), + ((2, 2), (2, 2), (2, 2), (2, 2)), + ], +) +def test_norm(edges: tuple[tuple[int, int], ...]) -> None: + arrow = tuple([False] * len(edges)) + filled = filled_count(edges) + half = filled // 2 + if filled % 2 == 0: + data = torch.arange(-half, half, dtype=torch.float64) + else: + data = torch.arange(-half, half + 1, dtype=torch.float64) + tensor_data = generate_filled_data(edges, data=data) + max_val = tensor_data.abs().max().item() + min_val = tensor_data.abs().min().item() + + tensor = GrassmannTensor(arrow, edges, tensor_data) + assert tensor.norm(p=torch.inf) == max_val + assert tensor.norm(p=-torch.inf) == min_val + assert tensor.norm(p=0) == filled - 1 + assert tensor.norm(p=2) == torch.linalg.vector_norm(data, ord=2) + + +@pytest.mark.parametrize( + "edges", + [ + ((1, 1),), + ((2, 2),), + ((2, 2), (2, 2)), + ((2, 2), (2, 2), (2, 2)), + ((2, 2), (2, 2), (2, 2), (2, 2)), + ], +) +def test_named_norm(edges: tuple[tuple[int, int], ...]) -> None: + names = tuple(chr(96 + i) for i in range(len(edges))) + arrow = tuple([False] * len(edges)) + filled = filled_count(edges) + half = filled // 2 + if filled % 2 == 0: + data = torch.arange(-half, half, dtype=torch.float64) + else: + data = torch.arange(-half, half + 1, dtype=torch.float64) + tensor_data = generate_filled_data(edges, data=data) + max_val = tensor_data.abs().max().item() + min_val = tensor_data.abs().min().item() + + tensor = NamedGrassmannTensor(names, arrow, edges, tensor_data) + assert tensor.norm(p=torch.inf) == max_val + assert tensor.norm(p=-torch.inf) == min_val + assert tensor.norm(p=0) == filled - 1 + assert tensor.norm(p=2) == torch.linalg.vector_norm(data, ord=2) + + +def test_sqrt() -> None: + tensor = GrassmannTensor((False, False), ((1, 1), (1, 1)), torch.Tensor([[-4, 9], [0, -1]])) + assert torch.allclose(tensor.sqrt().tensor, torch.Tensor(([[2, 3], [0, 1]]))) + + +def test_named_sqrt() -> None: + tensor = NamedGrassmannTensor( + ("a", "b"), (False, False), ((1, 1), (1, 1)), torch.Tensor([[-4, 9], [0, -1]]) + ) + assert torch.allclose(tensor.sqrt().tensor, torch.Tensor(([[2, 3], [0, 1]]))) + + +def test_rank() -> None: + tensor = NamedGrassmannTensor( + ("a", "b"), (False, False), ((1, 1), (1, 1)), torch.Tensor([[-4, 9], [0, -1]]) + ) + assert tensor.rank() == 2