diff --git a/grassmann_tensor/tensor.py b/grassmann_tensor/tensor.py index 117c21e..30269c1 100644 --- a/grassmann_tensor/tensor.py +++ b/grassmann_tensor/tensor.py @@ -844,8 +844,35 @@ def contract( order_b = left_leg_b + right_leg_b # 1. Permutation - a = a.permute(order_a) - b = b.permute(order_b) + arrow_a = tuple(a.arrow[i] for i in order_a) + edges_a = tuple(a.edges[i] for i in order_a) + tensor_a = a.tensor.permute(order_a) + parity_a = tuple(a.parity[i] for i in order_a) + mask_a = a.mask.permute(order_a) + + a = dataclasses.replace( + a, + _arrow=arrow_a, + _edges=edges_a, + _tensor=tensor_a, + _parity=parity_a, + _mask=mask_a, + ) + + arrow_b = tuple(b.arrow[i] for i in order_b) + edges_b = tuple(b.edges[i] for i in order_b) + tensor_b = b.tensor.permute(order_b) + parity_b = tuple(b.parity[i] for i in order_b) + mask_b = b.mask.permute(order_b) + + b = dataclasses.replace( + b, + _arrow=arrow_b, + _edges=edges_b, + _tensor=tensor_b, + _parity=parity_b, + _mask=mask_b, + ) arrow = a.arrow[:-contract_length_a] + b.arrow[contract_length_b:] edges = a.edges[:-contract_length_a] + b.edges[contract_length_b:]