Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 22 additions & 29 deletions grassmann_tensor/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -841,35 +841,8 @@ def contract(
order_b = left_leg_b + right_leg_b

# 1. Permutation
arrow_a = tuple(a.arrow[i] for i in order_a)
edges_a = tuple(a.edges[i] for i in order_a)
tensor_a = a.tensor.permute(order_a)
parity_a = tuple(a.parity[i] for i in order_a)
mask_a = a.mask.permute(order_a)

a = dataclasses.replace(
a,
_arrow=arrow_a,
_edges=edges_a,
_tensor=tensor_a,
_parity=parity_a,
_mask=mask_a,
)

arrow_b = tuple(b.arrow[i] for i in order_b)
edges_b = tuple(b.edges[i] for i in order_b)
tensor_b = b.tensor.permute(order_b)
parity_b = tuple(b.parity[i] for i in order_b)
mask_b = b.mask.permute(order_b)

b = dataclasses.replace(
b,
_arrow=arrow_b,
_edges=edges_b,
_tensor=tensor_b,
_parity=parity_b,
_mask=mask_b,
)
a = a.permute(order_a)
b = b.permute(order_b)

arrow = a.arrow[:-contract_length_a] + b.arrow[contract_length_b:]
edges = a.edges[:-contract_length_a] + b.edges[contract_length_b:]
Expand Down Expand Up @@ -1125,6 +1098,17 @@ def _tensor_mask(self) -> torch.Tensor:
torch.zeros_like(self._tensor, dtype=torch.bool),
)

def reciprocal(self) -> GrassmannTensor:
return dataclasses.replace(
self, _tensor=torch.where(self.tensor == 0, self.tensor, 1 / self.tensor)
)

def norm(self, p: typing.Any) -> float:
return float(torch.linalg.vector_norm(self.tensor.masked_select(~self.mask), ord=p))

def sqrt(self) -> GrassmannTensor:
return dataclasses.replace(self, _tensor=torch.sqrt(torch.abs(self.tensor)))

def _validate_edge_compatibility(self, other: GrassmannTensor) -> None:
"""
Validate that the edges of two ParityTensor instances are compatible for arithmetic operations.
Expand Down Expand Up @@ -1857,6 +1841,15 @@ def reciprocal(self) -> NamedGrassmannTensor:
self, _tensor=torch.where(self.tensor == 0, self.tensor, 1 / self.tensor)
)

def norm(self, p: typing.Any) -> float:
return float(torch.linalg.vector_norm(self.tensor.reshape(-1), ord=p))

def sqrt(self) -> NamedGrassmannTensor:
return dataclasses.replace(self, _tensor=torch.sqrt(torch.abs(self.tensor)))

def rank(self) -> int:
return len(self.names)

def _validate_edge_compatibility(self, other: NamedGrassmannTensor) -> None:
assert self._names == other.names, (
f"Names must match for arithmetic operations. Got {self._names} and {other.names}."
Expand Down
42 changes: 40 additions & 2 deletions tests/reciprocal_test.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,45 @@
import torch
import pytest

from grassmann_tensor import NamedGrassmannTensor
from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor


@pytest.mark.parametrize(
"x",
[
GrassmannTensor(
(True, True), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64)
).update_mask(),
GrassmannTensor(
(False, False), ((2, 2), (4, 4)), torch.randn(4, 8, dtype=torch.float64)
).update_mask(),
GrassmannTensor(
(True, True, True),
((2, 2), (4, 4), (8, 8)),
torch.randn(4, 8, 16, dtype=torch.float64),
).update_mask(),
GrassmannTensor(
(True, True, True, True),
((2, 2), (4, 4), (8, 8), (16, 16)),
torch.randn(4, 8, 16, 32, dtype=torch.float64),
).update_mask(),
],
)
def test_reciprocal(x: NamedGrassmannTensor) -> None:
tensor = x.reciprocal()
assert tensor.arrow == x.arrow
assert tensor.edges == x.edges
assert tensor.tensor.shape == x.tensor.shape
assert tensor.tensor.dtype == x.tensor.dtype
assert tensor.tensor.device == x.tensor.device

assert not torch.isinf(tensor.tensor).any()

zero = x.tensor == 0
assert torch.equal(tensor.tensor[zero], x.tensor[zero])

non_zero = ~zero
assert torch.allclose(tensor.tensor[non_zero], (1 / x.tensor[non_zero]))


@pytest.mark.parametrize(
Expand All @@ -27,7 +65,7 @@
).update_mask(),
],
)
def test_reciprocal(x: NamedGrassmannTensor) -> None:
def test_named_reciprocal(x: NamedGrassmannTensor) -> None:
tensor = x.reciprocal()
assert tensor.names == x.names
assert tensor.arrow == x.arrow
Expand Down
117 changes: 117 additions & 0 deletions tests/utility_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import torch
import pytest
import itertools

from grassmann_tensor import GrassmannTensor, NamedGrassmannTensor


def generate_filled_data(
edges: tuple[tuple[int, int], ...], data: torch.Tensor | None = None
) -> torch.Tensor:
shape = tuple(even + odd for even, odd in edges)

if data is None:
tensor = torch.zeros(shape)
filled = torch.arange(filled_count(edges))
else:
assert data is not None
filled = data.reshape(-1)
tensor = torch.zeros(shape, dtype=data.dtype, device=data.device)
i = 0
ranges = [range(s) for s in shape]

for idx in itertools.product(*ranges):
total_parity = 0
for k, (even, _) in enumerate(edges):
total_parity ^= 1 if idx[k] >= even else 0
if total_parity == 0:
tensor[idx] = filled[i]
i += 1
return tensor


def filled_count(edges: tuple[tuple[int, int], ...]) -> int:
total = 1
diff = 1
for even, odd in edges:
total *= even + odd
diff *= even - odd
return (total + diff) // 2


@pytest.mark.parametrize(
"edges",
[
((1, 1),),
((2, 2),),
((2, 2), (2, 2)),
((2, 2), (2, 2), (2, 2)),
((2, 2), (2, 2), (2, 2), (2, 2)),
],
)
def test_norm(edges: tuple[tuple[int, int], ...]) -> None:
arrow = tuple([False] * len(edges))
filled = filled_count(edges)
half = filled // 2
if filled % 2 == 0:
data = torch.arange(-half, half, dtype=torch.float64)
else:
data = torch.arange(-half, half + 1, dtype=torch.float64)
tensor_data = generate_filled_data(edges, data=data)
max_val = tensor_data.abs().max().item()
min_val = tensor_data.abs().min().item()

tensor = GrassmannTensor(arrow, edges, tensor_data)
assert tensor.norm(p=torch.inf) == max_val
assert tensor.norm(p=-torch.inf) == min_val
assert tensor.norm(p=0) == filled - 1
assert tensor.norm(p=2) == torch.linalg.vector_norm(data, ord=2)


@pytest.mark.parametrize(
"edges",
[
((1, 1),),
((2, 2),),
((2, 2), (2, 2)),
((2, 2), (2, 2), (2, 2)),
((2, 2), (2, 2), (2, 2), (2, 2)),
],
)
def test_named_norm(edges: tuple[tuple[int, int], ...]) -> None:
names = tuple(chr(96 + i) for i in range(len(edges)))
arrow = tuple([False] * len(edges))
filled = filled_count(edges)
half = filled // 2
if filled % 2 == 0:
data = torch.arange(-half, half, dtype=torch.float64)
else:
data = torch.arange(-half, half + 1, dtype=torch.float64)
tensor_data = generate_filled_data(edges, data=data)
max_val = tensor_data.abs().max().item()
min_val = tensor_data.abs().min().item()

tensor = NamedGrassmannTensor(names, arrow, edges, tensor_data)
assert tensor.norm(p=torch.inf) == max_val
assert tensor.norm(p=-torch.inf) == min_val
assert tensor.norm(p=0) == filled - 1
assert tensor.norm(p=2) == torch.linalg.vector_norm(data, ord=2)


def test_sqrt() -> None:
tensor = GrassmannTensor((False, False), ((1, 1), (1, 1)), torch.Tensor([[-4, 9], [0, -1]]))
assert torch.allclose(tensor.sqrt().tensor, torch.Tensor(([[2, 3], [0, 1]])))


def test_named_sqrt() -> None:
tensor = NamedGrassmannTensor(
("a", "b"), (False, False), ((1, 1), (1, 1)), torch.Tensor([[-4, 9], [0, -1]])
)
assert torch.allclose(tensor.sqrt().tensor, torch.Tensor(([[2, 3], [0, 1]])))


def test_rank() -> None:
tensor = NamedGrassmannTensor(
("a", "b"), (False, False), ((1, 1), (1, 1)), torch.Tensor([[-4, 9], [0, -1]])
)
assert tensor.rank() == 2