From e3969a07c9902a895fe1d8ef83f8137c24843689 Mon Sep 17 00:00:00 2001 From: Gausshj Date: Wed, 10 Dec 2025 15:43:39 +0800 Subject: [PATCH] fix(contract): fix contract numerical issue - Remove redundant reorder code in contract method - Fix svd device assertation issue --- grassmann_tensor/tensor.py | 3 +-- tests/svd_test.py | 2 +- 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 5abb675..117c21e 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -901,7 +901,7 @@ def contract( # Merge tensor `a` common edges arrow_merge_common_edges = (False, True) shape_merge_common_edges = (a.tensor.shape[0], math.prod(a.tensor.shape[1:])) - even, odd, reorder, sign = self._reorder_indices(a.edges[1:]) + even, odd, _, sign = self._reorder_indices(a.edges[1:]) edges_merge_common_edges = typing.cast( tuple[tuple[int, int], ...], (a.edges[0], (even, odd)), @@ -915,7 +915,6 @@ def contract( tensor_merge_common_edges = torch.where( merging_parity, -tensor_merge_common_edges, +tensor_merge_common_edges ) - tensor_merge_common_edges = tensor_merge_common_edges.index_select(1, reorder) a = dataclasses.replace( a, diff --git a/tests/svd_test.py b/tests/svd_test.py index 84a2b8e..01b57ec 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -241,7 +241,7 @@ def test_svd_int_cutoff_odd_block_empty_select_from_even_only(a: int, b: int, k: devices = [torch.device("cpu")] if torch.cuda.is_available(): - devices.append(torch.device("cuda")) + devices.append(torch.device("cuda:0")) @pytest.mark.parametrize(