diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 5abb675..117c21e 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -901,7 +901,7 @@ def contract( # Merge tensor `a` common edges arrow_merge_common_edges = (False, True) shape_merge_common_edges = (a.tensor.shape[0], math.prod(a.tensor.shape[1:])) - even, odd, reorder, sign = self._reorder_indices(a.edges[1:]) + even, odd, _, sign = self._reorder_indices(a.edges[1:]) edges_merge_common_edges = typing.cast( tuple[tuple[int, int], ...], (a.edges[0], (even, odd)), @@ -915,7 +915,6 @@ def contract( tensor_merge_common_edges = torch.where( merging_parity, -tensor_merge_common_edges, +tensor_merge_common_edges ) - tensor_merge_common_edges = tensor_merge_common_edges.index_select(1, reorder) a = dataclasses.replace( a, diff --git a/tests/svd_test.py b/tests/svd_test.py index 84a2b8e..01b57ec 100644 --- a/tests/svd_test.py +++ b/tests/svd_test.py @@ -241,7 +241,7 @@ def test_svd_int_cutoff_odd_block_empty_select_from_even_only(a: int, b: int, k: devices = [torch.device("cpu")] if torch.cuda.is_available(): - devices.append(torch.device("cuda")) + devices.append(torch.device("cuda:0")) @pytest.mark.parametrize(