diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 9fdb6bb..03d0572 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -164,7 +164,7 @@ def permute(self, before_by_after: tuple[int, ...]) -> GrassmannTensor: _mask=mask, ) - def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor: + def reverse(self, indices: tuple[int, ...], apply_parity: bool = True) -> GrassmannTensor: """ Reverse the specified indices of the Grassmann tensor. @@ -184,7 +184,7 @@ def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor: ( self._unsqueeze(parity, index, self.tensor.dim()) for index, parity in enumerate(self.parity) - if index in indices and self.arrow[index] + if index in indices and self.arrow[index] is apply_parity ), torch.zeros([], dtype=torch.bool, device=self.tensor.device), ) @@ -665,7 +665,7 @@ def svd( arrow_reverse = tuple(i for i, current in enumerate(tensor.arrow) if current) if arrow_reverse: - tensor = tensor.reverse(arrow_reverse).reverse(arrow_reverse).reverse(arrow_reverse) + tensor = tensor.reverse(arrow_reverse, apply_parity=False) left_dim = math.prod(tensor.tensor.shape[: len(left_legs)]) right_dim = math.prod(tensor.tensor.shape[len(left_legs) :]) @@ -947,7 +947,9 @@ def contract( c = dataclasses.replace(c, _arrow=arrow, _edges=edges, _tensor=c.tensor.reshape(shape)) return c - def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor: + def exponential( + self, pairs: tuple[tuple[int, ...], tuple[int, ...]], *, permute_back: bool = True + ) -> GrassmannTensor: tensor, left_legs, right_legs = self._group_edges(pairs) assert tensor.arrow in ((False, True), (True, False)), ( @@ -956,7 +958,7 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma tensor_reverse_flag = tensor.arrow != (False, True) if tensor_reverse_flag: - tensor = tensor.reverse((0, 1)) + tensor = tensor.reverse((0, 1), False) left_dim, right_dim = tensor.tensor.shape @@ -988,13 +990,16 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma edges_after_permute = tuple(self.edges[i] for i in order) tensor_exp = tensor_exp.reshape(edges_after_permute) - inv_order = self.get_inv_order(order) + if permute_back: + inv_order = self.get_inv_order(order) - tensor_exp = tensor_exp.permute(inv_order) + tensor_exp = tensor_exp.permute(inv_order) return tensor_exp - def identity(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor: + def identity( + self, pairs: tuple[tuple[int, ...], tuple[int, ...]], *, permute_back: bool = True + ) -> GrassmannTensor: tensor, left_legs, right_legs = self._group_edges(pairs) assert tensor.arrow in ((False, True), (True, False)), ( @@ -1003,7 +1008,7 @@ def identity(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannT tensor_reverse_flag = tensor.arrow != (False, True) if tensor_reverse_flag: - tensor = tensor.reverse((0, 1)) + tensor = tensor.reverse((0, 1), False) left_dim, right_dim = tensor.tensor.shape @@ -1029,9 +1034,10 @@ def identity(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannT edges_after_permute = tuple(self.edges[i] for i in order) tensor_identity = tensor_identity.reshape(edges_after_permute) - inv_order = self.get_inv_order(order) + if permute_back: + inv_order = self.get_inv_order(order) - tensor_identity = tensor_identity.permute(inv_order) + tensor_identity = tensor_identity.permute(inv_order) return tensor_identity @@ -1453,12 +1459,12 @@ def permute(self, before_by_after: tuple[str, ...]) -> NamedGrassmannTensor: _mask=None, ) - def reverse(self, reversed_names: set[str]) -> NamedGrassmannTensor: + def reverse(self, reversed_names: set[str], apply_parity: bool = True) -> NamedGrassmannTensor: assert len(reversed_names) == len(set(reversed_names)), ( f"Indices must be unique, but got {reversed_names}" ) indices = tuple(self.get_name_index(name) for name in reversed_names) - tensor = self.gt.reverse(indices) + tensor = self.gt.reverse(indices, apply_parity=apply_parity) return dataclasses.replace( self, _arrow=tensor.arrow, @@ -1713,7 +1719,7 @@ def _get_left_right_indices( def exponential(self, pairs: set[tuple[str, str]]) -> NamedGrassmannTensor: names, left_idx, right_idx = self._get_left_right_indices(pairs) - exp = self.gt.exponential((left_idx, right_idx)) + exp = self.gt.exponential((left_idx, right_idx), permute_back=False) return dataclasses.replace( self, @@ -1726,7 +1732,7 @@ def exponential(self, pairs: set[tuple[str, str]]) -> NamedGrassmannTensor: def identity(self, pairs: set[tuple[str, str]]) -> NamedGrassmannTensor: names, left_idx, right_idx = self._get_left_right_indices(pairs) - identity = self.gt.identity((left_idx, right_idx)) + identity = self.gt.identity((left_idx, right_idx), permute_back=False) return dataclasses.replace( self, @@ -1850,6 +1856,36 @@ def sqrt(self) -> NamedGrassmannTensor: def rank(self) -> int: return len(self.names) + def allclose( + self, other: NamedGrassmannTensor, rtol: float = 1e-05, atol: float = 1e-8 + ) -> bool: + if not isinstance(other, NamedGrassmannTensor): + raise TypeError(f"Expected NamedGrassmannTensor, got {type(other)}") + if set(self._names) != set(other._names): + raise TypeError( + f"Expected same name, but got self: {set(self._names)}, other: {set(other._names)}" + ) + + if tuple(self._names) != tuple(other._names): + other = other.permute(self._names) + + if self._arrow != other._arrow: + raise TypeError( + f"Expected same arrow, but got self: {self._arrow}, other: {other._arrow}" + ) + + if self._edges != other._edges: + raise TypeError( + f"Expected same edges, but got self: {self._edges}, other: {other._edges}" + ) + + tensor_a = self.update_mask()._tensor + tensor_b = other.update_mask()._tensor + + tensor_b = tensor_b.to(tensor_a.device) + + return tensor_a.allclose(tensor_b, rtol=rtol, atol=atol) + def _validate_edge_compatibility(self, other: NamedGrassmannTensor) -> None: assert self._names == other.names, ( f"Names must match for arithmetic operations. Got {self._names} and {other.names}." diff --git a/tests/exponential_test.py b/tests/exponential_test.py index 22ad95e..9915343 100644 --- a/tests/exponential_test.py +++ b/tests/exponential_test.py @@ -174,6 +174,15 @@ def test_named_tensor_exponential_assertation() -> None: ), {("a", "c"), ("b", "d")}, ), + ( + NamedGrassmannTensor( + ("a", "b", "c", "d"), + (False, True, False, True), + ((4, 4), (4, 4), (8, 8), (8, 8)), + torch.randn(8, 8, 16, 16, dtype=torch.float64), + ), + {("a", "b"), ("c", "d")}, + ), ], ) def test_named_tensor_exponential_via_taylor_expansion( diff --git a/tests/identity_test.py b/tests/identity_test.py index f3296e3..1c6d55b 100644 --- a/tests/identity_test.py +++ b/tests/identity_test.py @@ -138,6 +138,15 @@ def test_named_tensor_identity_assertation() -> None: ), {("a", "c"), ("b", "d")}, ), + ( + NamedGrassmannTensor( + ("a", "b", "c", "d"), + (False, True, False, True), + ((4, 4), (4, 4), (8, 8), (8, 8)), + torch.randn(8, 8, 16, 16, dtype=torch.float64), + ), + {("a", "b"), ("c", "d")}, + ), ], ) def test_named_tensor_identity_via_self_multiplication( @@ -147,6 +156,6 @@ def test_named_tensor_identity_via_self_multiplication( tensor = tensor.update_mask() identity = tensor.identity(pairs) contract_pairs = typing.cast(set[tuple[str, str]], {item[::-1] for item in pairs}) - assert torch.allclose((identity.contract(identity, contract_pairs)).tensor, identity.tensor) - assert torch.allclose((identity.contract(tensor, contract_pairs)).tensor, tensor.tensor) - assert torch.allclose((tensor.contract(identity, contract_pairs)).tensor, tensor.tensor) + assert identity.contract(identity, contract_pairs).allclose(identity) + assert (identity.contract(tensor, contract_pairs)).allclose(tensor) + assert (tensor.contract(identity, contract_pairs)).allclose(tensor) diff --git a/tests/utility_test.py b/tests/utility_test.py index cfa32ad..d3b034d 100644 --- a/tests/utility_test.py +++ b/tests/utility_test.py @@ -115,3 +115,54 @@ def test_rank() -> None: ("a", "b"), (False, False), ((1, 1), (1, 1)), torch.Tensor([[-4, 9], [0, -1]]) ) assert tensor.rank() == 2 + + +def test_allclose() -> None: + data1 = torch.Tensor([[0, 1], [2, 3]]).to(dtype=torch.float64, device="cpu") + device = "cuda" if torch.cuda.is_available() else "cpu" + data2 = torch.Tensor([[0, -1], [-2, 3]]).to(dtype=torch.float64, device=device) + tensor1 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data1) + tensor2 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data2) + assert tensor1.allclose(tensor2) + + tensor3 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), torch.randn(2, 2)) + tensor4 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), torch.randn(2, 2)) + assert not tensor3.allclose(tensor4) + + data = generate_filled_data(((1, 1), (1, 1))) + tensor5 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data) + tensor6 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data) + tensor6 = tensor6.permute(("b", "a")) + assert tensor5.allclose(tensor6) + + +def test_allclose_other_type() -> None: + data = generate_filled_data(((1, 1), (1, 1))) + tensor1 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data) + tensor2 = data + with pytest.raises(TypeError, match="Expected NamedGrassmannTensor"): + tensor1.allclose(tensor2) # type: ignore[arg-type] + + +def test_allclose_mismatch_names() -> None: + data = generate_filled_data(((1, 1), (1, 1))) + tensor1 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data) + tensor2 = NamedGrassmannTensor(("c", "d"), (False, True), ((1, 1), (1, 1)), data) + with pytest.raises(TypeError, match="Expected same name"): + tensor1.allclose(tensor2) + + +def test_allclose_mismatch_arrows() -> None: + data = generate_filled_data(((1, 1), (1, 1))) + tensor1 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data) + tensor2 = NamedGrassmannTensor(("a", "b"), (False, False), ((1, 1), (1, 1)), data) + with pytest.raises(TypeError, match="Expected same arrow"): + tensor1.allclose(tensor2) + + +def test_allclose_mismatch_edges() -> None: + data = generate_filled_data(((1, 1), (1, 1))) + tensor1 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data) + tensor2 = NamedGrassmannTensor(("a", "b"), (False, True), ((2, 0), (0, 2)), data) + with pytest.raises(TypeError, match="Expected same edges"): + tensor1.allclose(tensor2)