Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 51 additions & 15 deletions grassmann_tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ def permute(self, before_by_after: tuple[int, ...]) -> GrassmannTensor:
_mask=mask,
)

def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor:
def reverse(self, indices: tuple[int, ...], apply_parity: bool = True) -> GrassmannTensor:
"""
Reverse the specified indices of the Grassmann tensor.

Expand All @@ -184,7 +184,7 @@ def reverse(self, indices: tuple[int, ...]) -> GrassmannTensor:
(
self._unsqueeze(parity, index, self.tensor.dim())
for index, parity in enumerate(self.parity)
if index in indices and self.arrow[index]
if index in indices and self.arrow[index] is apply_parity
),
torch.zeros([], dtype=torch.bool, device=self.tensor.device),
)
Expand Down Expand Up @@ -665,7 +665,7 @@ def svd(

arrow_reverse = tuple(i for i, current in enumerate(tensor.arrow) if current)
if arrow_reverse:
tensor = tensor.reverse(arrow_reverse).reverse(arrow_reverse).reverse(arrow_reverse)
tensor = tensor.reverse(arrow_reverse, apply_parity=False)

left_dim = math.prod(tensor.tensor.shape[: len(left_legs)])
right_dim = math.prod(tensor.tensor.shape[len(left_legs) :])
Expand Down Expand Up @@ -947,7 +947,9 @@ def contract(
c = dataclasses.replace(c, _arrow=arrow, _edges=edges, _tensor=c.tensor.reshape(shape))
return c

def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor:
def exponential(
self, pairs: tuple[tuple[int, ...], tuple[int, ...]], *, permute_back: bool = True
) -> GrassmannTensor:
tensor, left_legs, right_legs = self._group_edges(pairs)

assert tensor.arrow in ((False, True), (True, False)), (
Expand All @@ -956,7 +958,7 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma

tensor_reverse_flag = tensor.arrow != (False, True)
if tensor_reverse_flag:
tensor = tensor.reverse((0, 1))
tensor = tensor.reverse((0, 1), False)

left_dim, right_dim = tensor.tensor.shape

Expand Down Expand Up @@ -988,13 +990,16 @@ def exponential(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> Grassma
edges_after_permute = tuple(self.edges[i] for i in order)
tensor_exp = tensor_exp.reshape(edges_after_permute)

inv_order = self.get_inv_order(order)
if permute_back:
inv_order = self.get_inv_order(order)

tensor_exp = tensor_exp.permute(inv_order)
tensor_exp = tensor_exp.permute(inv_order)

return tensor_exp

def identity(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannTensor:
def identity(
self, pairs: tuple[tuple[int, ...], tuple[int, ...]], *, permute_back: bool = True
) -> GrassmannTensor:
tensor, left_legs, right_legs = self._group_edges(pairs)

assert tensor.arrow in ((False, True), (True, False)), (
Expand All @@ -1003,7 +1008,7 @@ def identity(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannT

tensor_reverse_flag = tensor.arrow != (False, True)
if tensor_reverse_flag:
tensor = tensor.reverse((0, 1))
tensor = tensor.reverse((0, 1), False)

left_dim, right_dim = tensor.tensor.shape

Expand All @@ -1029,9 +1034,10 @@ def identity(self, pairs: tuple[tuple[int, ...], tuple[int, ...]]) -> GrassmannT
edges_after_permute = tuple(self.edges[i] for i in order)
tensor_identity = tensor_identity.reshape(edges_after_permute)

inv_order = self.get_inv_order(order)
if permute_back:
inv_order = self.get_inv_order(order)

tensor_identity = tensor_identity.permute(inv_order)
tensor_identity = tensor_identity.permute(inv_order)

return tensor_identity

Expand Down Expand Up @@ -1453,12 +1459,12 @@ def permute(self, before_by_after: tuple[str, ...]) -> NamedGrassmannTensor:
_mask=None,
)

def reverse(self, reversed_names: set[str]) -> NamedGrassmannTensor:
def reverse(self, reversed_names: set[str], apply_parity: bool = True) -> NamedGrassmannTensor:
assert len(reversed_names) == len(set(reversed_names)), (
f"Indices must be unique, but got {reversed_names}"
)
indices = tuple(self.get_name_index(name) for name in reversed_names)
tensor = self.gt.reverse(indices)
tensor = self.gt.reverse(indices, apply_parity=apply_parity)
return dataclasses.replace(
self,
_arrow=tensor.arrow,
Expand Down Expand Up @@ -1713,7 +1719,7 @@ def _get_left_right_indices(
def exponential(self, pairs: set[tuple[str, str]]) -> NamedGrassmannTensor:
names, left_idx, right_idx = self._get_left_right_indices(pairs)

exp = self.gt.exponential((left_idx, right_idx))
exp = self.gt.exponential((left_idx, right_idx), permute_back=False)

return dataclasses.replace(
self,
Expand All @@ -1726,7 +1732,7 @@ def exponential(self, pairs: set[tuple[str, str]]) -> NamedGrassmannTensor:
def identity(self, pairs: set[tuple[str, str]]) -> NamedGrassmannTensor:
names, left_idx, right_idx = self._get_left_right_indices(pairs)

identity = self.gt.identity((left_idx, right_idx))
identity = self.gt.identity((left_idx, right_idx), permute_back=False)

return dataclasses.replace(
self,
Expand Down Expand Up @@ -1850,6 +1856,36 @@ def sqrt(self) -> NamedGrassmannTensor:
def rank(self) -> int:
return len(self.names)

def allclose(
self, other: NamedGrassmannTensor, rtol: float = 1e-05, atol: float = 1e-8
) -> bool:
if not isinstance(other, NamedGrassmannTensor):
raise TypeError(f"Expected NamedGrassmannTensor, got {type(other)}")
if set(self._names) != set(other._names):
raise TypeError(
f"Expected same name, but got self: {set(self._names)}, other: {set(other._names)}"
)

if tuple(self._names) != tuple(other._names):
other = other.permute(self._names)

if self._arrow != other._arrow:
raise TypeError(
f"Expected same arrow, but got self: {self._arrow}, other: {other._arrow}"
)

if self._edges != other._edges:
raise TypeError(
f"Expected same edges, but got self: {self._edges}, other: {other._edges}"
)

tensor_a = self.update_mask()._tensor
tensor_b = other.update_mask()._tensor

tensor_b = tensor_b.to(tensor_a.device)

return tensor_a.allclose(tensor_b, rtol=rtol, atol=atol)

def _validate_edge_compatibility(self, other: NamedGrassmannTensor) -> None:
assert self._names == other.names, (
f"Names must match for arithmetic operations. Got {self._names} and {other.names}."
Expand Down
9 changes: 9 additions & 0 deletions tests/exponential_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,15 @@ def test_named_tensor_exponential_assertation() -> None:
),
{("a", "c"), ("b", "d")},
),
(
NamedGrassmannTensor(
("a", "b", "c", "d"),
(False, True, False, True),
((4, 4), (4, 4), (8, 8), (8, 8)),
torch.randn(8, 8, 16, 16, dtype=torch.float64),
),
{("a", "b"), ("c", "d")},
),
],
)
def test_named_tensor_exponential_via_taylor_expansion(
Expand Down
15 changes: 12 additions & 3 deletions tests/identity_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,15 @@ def test_named_tensor_identity_assertation() -> None:
),
{("a", "c"), ("b", "d")},
),
(
NamedGrassmannTensor(
("a", "b", "c", "d"),
(False, True, False, True),
((4, 4), (4, 4), (8, 8), (8, 8)),
torch.randn(8, 8, 16, 16, dtype=torch.float64),
),
{("a", "b"), ("c", "d")},
),
],
)
def test_named_tensor_identity_via_self_multiplication(
Expand All @@ -147,6 +156,6 @@ def test_named_tensor_identity_via_self_multiplication(
tensor = tensor.update_mask()
identity = tensor.identity(pairs)
contract_pairs = typing.cast(set[tuple[str, str]], {item[::-1] for item in pairs})
assert torch.allclose((identity.contract(identity, contract_pairs)).tensor, identity.tensor)
assert torch.allclose((identity.contract(tensor, contract_pairs)).tensor, tensor.tensor)
assert torch.allclose((tensor.contract(identity, contract_pairs)).tensor, tensor.tensor)
assert identity.contract(identity, contract_pairs).allclose(identity)
assert (identity.contract(tensor, contract_pairs)).allclose(tensor)
assert (tensor.contract(identity, contract_pairs)).allclose(tensor)
51 changes: 51 additions & 0 deletions tests/utility_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,3 +115,54 @@ def test_rank() -> None:
("a", "b"), (False, False), ((1, 1), (1, 1)), torch.Tensor([[-4, 9], [0, -1]])
)
assert tensor.rank() == 2


def test_allclose() -> None:
data1 = torch.Tensor([[0, 1], [2, 3]]).to(dtype=torch.float64, device="cpu")
device = "cuda" if torch.cuda.is_available() else "cpu"
data2 = torch.Tensor([[0, -1], [-2, 3]]).to(dtype=torch.float64, device=device)
tensor1 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data1)
tensor2 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data2)
assert tensor1.allclose(tensor2)

tensor3 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), torch.randn(2, 2))
tensor4 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), torch.randn(2, 2))
assert not tensor3.allclose(tensor4)

data = generate_filled_data(((1, 1), (1, 1)))
tensor5 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data)
tensor6 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data)
tensor6 = tensor6.permute(("b", "a"))
assert tensor5.allclose(tensor6)


def test_allclose_other_type() -> None:
data = generate_filled_data(((1, 1), (1, 1)))
tensor1 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data)
tensor2 = data
with pytest.raises(TypeError, match="Expected NamedGrassmannTensor"):
tensor1.allclose(tensor2) # type: ignore[arg-type]


def test_allclose_mismatch_names() -> None:
data = generate_filled_data(((1, 1), (1, 1)))
tensor1 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data)
tensor2 = NamedGrassmannTensor(("c", "d"), (False, True), ((1, 1), (1, 1)), data)
with pytest.raises(TypeError, match="Expected same name"):
tensor1.allclose(tensor2)


def test_allclose_mismatch_arrows() -> None:
data = generate_filled_data(((1, 1), (1, 1)))
tensor1 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data)
tensor2 = NamedGrassmannTensor(("a", "b"), (False, False), ((1, 1), (1, 1)), data)
with pytest.raises(TypeError, match="Expected same arrow"):
tensor1.allclose(tensor2)


def test_allclose_mismatch_edges() -> None:
data = generate_filled_data(((1, 1), (1, 1)))
tensor1 = NamedGrassmannTensor(("a", "b"), (False, True), ((1, 1), (1, 1)), data)
tensor2 = NamedGrassmannTensor(("a", "b"), (False, True), ((2, 0), (0, 2)), data)
with pytest.raises(TypeError, match="Expected same edges"):
tensor1.allclose(tensor2)