Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/4133.feat.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Allow {func}`scanpy.pp.neighbors` to accept precomputed connectivities for downstream graph tools such as {func}`scanpy.tl.umap`. {smaller}`E Estaji`
116 changes: 90 additions & 26 deletions src/scanpy/neighbors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def neighbors( # noqa: PLR0913
n_pcs: int | None = None,
*,
distances: np.ndarray | SpBase | None = None,
connectivities: np.ndarray | SpBase | None = None,
use_rep: str | None = None,
knn: bool = True,
method: _Method = "umap",
Expand Down Expand Up @@ -126,6 +127,15 @@ def neighbors( # noqa: PLR0913

*ignored if ``transformer`` is an instance.*
{n_pcs}
distances
Precomputed distance matrix to use instead of computing distances from
`adata`. Connectivities are computed from this matrix using `method`.
connectivities
Precomputed connectivities matrix to use instead of deriving
connectivities from distances. This allows downstream tools such as
:func:`~scanpy.tl.umap` to consume externally computed neighborhood
graphs. If `distances` is also provided, it is stored alongside the
connectivities.
{use_rep}
knn
If `True`, use a hard threshold to restrict the number of neighbors to
Expand Down Expand Up @@ -153,8 +163,8 @@ def neighbors( # noqa: PLR0913
:class:`~pynndescent.pynndescent_.PyNNDescentTransformer`
metric
A known metric’s name or a callable that returns a distance.
If `distances` is given, this parameter is simply stored in `.uns` (see below),
otherwise defaults to `'euclidean'`.
If `distances` or `connectivities` is given, this parameter is simply
stored in `.uns` (see below), otherwise defaults to `'euclidean'`.

*ignored if ``transformer`` is an instance.*
metric_kwds
Expand All @@ -180,6 +190,7 @@ def neighbors( # noqa: PLR0913

`adata.obsp['distances' | f'{{key_added}}_distances']` : :class:`scipy.sparse.csr_matrix` (dtype `float`)
Distance matrix of the nearest neighbors search. Each row (cell) has `n_neighbors`-1 non-zero entries. These are the distances to their `n_neighbors`-1 nearest neighbors (excluding the cell itself).
Not set if only `connectivities` is provided.
`adata.obsp['connectivities' | f'{{key_added}}_connectivities']` : :class:`scipy.sparse._csr.csr_matrix` (dtype `float`)
Weighted adjacency matrix of the neighborhood graph of data
points. Weights should be interpreted as connectivities.
Expand Down Expand Up @@ -209,13 +220,14 @@ def neighbors( # noqa: PLR0913
dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {}
)

if distances is None:
adata = adata.copy() if copy else adata
if adata.is_view: # we shouldn't need this here...
adata._init_as_actual(adata.copy())

if distances is None and connectivities is None:
if metric is None:
metric = "euclidean"
start = logg.info("computing neighbors")
adata = adata.copy() if copy else adata
if adata.is_view: # we shouldn't need this here...
adata._init_as_actual(adata.copy())
neighbors_ = Neighbors(adata)
neighbors_.compute_neighbors(
n_neighbors,
Expand All @@ -242,31 +254,33 @@ def neighbors( # noqa: PLR0913
meta_random_state.pop("random_state", None)
if ignored:
warn(
f"Parameter(s) ignored if `distances` is given: {ignored}",
f"Parameter(s) ignored if `distances` or `connectivities` is given: {ignored}",
UserWarning,
)
if callable(metric):
msg = "`metric` must be a string if `distances` is given."
msg = (
"`metric` must be a string if `distances` or `connectivities` is given."
)
raise TypeError(msg)
start = logg.info("computing connectivities")
# if a precomputed distance matrix is provided, skip the PCA and distance computation
if isinstance(distances, SpBase):
if TYPE_CHECKING:
from scipy.sparse._base import _spbase

assert isinstance(distances, _spbase)
distances = distances.tocsr(copy=True)
distances.setdiag(0)
distances.eliminate_zeros()
else:
distances = np.asarray(distances)
np.fill_diagonal(distances, 0)
start = logg.info(
"computing connectivities"
if connectivities is None
else "using precomputed connectivities"
)

neighbors_ = Neighbors(adata)
neighbors_.n_neighbors = n_neighbors
neighbors_.knn = True
neighbors_._distances = distances
neighbors_._connectivities = neighbors_._compute_connectivites(method)
if distances is not None:
neighbors_._distances = _prepare_precomputed_distances(
distances, n_obs=adata.n_obs
)
if connectivities is None:
neighbors_._connectivities = neighbors_._compute_connectivites(method)
else:
neighbors_._connectivities = _prepare_precomputed_connectivities(
connectivities, n_obs=adata.n_obs
)

key_added, neighbors_dict = _get_metadata(
key_added,
Expand All @@ -283,21 +297,71 @@ def neighbors( # noqa: PLR0913
neighbors_dict["rp_forest"] = neighbors_.rp_forest

adata.uns[key_added] = neighbors_dict
adata.obsp[neighbors_dict["distances_key"]] = neighbors_.distances
if neighbors_.distances is not None:
adata.obsp[neighbors_dict["distances_key"]] = neighbors_.distances
adata.obsp[neighbors_dict["connectivities_key"]] = neighbors_.connectivities

logg.info(
" finished",
time=start,
deep=(
f"added to `.uns[{key_added!r}]`\n"
f" `.obsp[{neighbors_dict['distances_key']!r}]`, distances for each pair of neighbors\n"
f" `.obsp[{neighbors_dict['connectivities_key']!r}]`, weighted adjacency matrix"
+ (
f" `.obsp[{neighbors_dict['distances_key']!r}]`, distances for each pair of neighbors\n"
if neighbors_.distances is not None
else ""
)
+ f" `.obsp[{neighbors_dict['connectivities_key']!r}]`, weighted adjacency matrix"
),
)
return adata if copy else None


def _validate_precomputed_matrix_shape(
    matrix: np.ndarray | SpBase, *, name: str, n_obs: int
) -> None:
    """Check that a precomputed pairwise matrix is ``n_obs`` × ``n_obs``.

    Parameters
    ----------
    matrix
        Matrix whose ``.shape`` attribute is inspected.
    name
        Argument name inserted into the error message.
    n_obs
        Required number of rows and of columns.

    Raises
    ------
    ValueError
        If ``matrix.shape`` differs from ``(n_obs, n_obs)``.
    """
    expected_shape = (n_obs, n_obs)
    if matrix.shape == expected_shape:
        return
    msg = f"`{name}` must have shape ({n_obs}, {n_obs}), found {matrix.shape}."
    raise ValueError(msg)


def _prepare_precomputed_distances(
    distances: np.ndarray | SpBase, *, n_obs: int
) -> np.ndarray | CSBase:
    """Return a cleaned copy of a user-supplied distance matrix.

    The input is never modified: both branches operate on a copy.

    Parameters
    ----------
    distances
        Precomputed distance matrix, sparse or dense.
    n_obs
        Number of observations; the matrix must have shape ``(n_obs, n_obs)``.

    Returns
    -------
    For sparse input, a CSR copy with the diagonal set to 0 and explicit
    zeros removed. For any other input, a dense array copy with the diagonal
    set to 0.

    Raises
    ------
    ValueError
        If the shape check performed by `_validate_precomputed_matrix_shape` fails.
    """
    _validate_precomputed_matrix_shape(distances, name="distances", n_obs=n_obs)
    if isinstance(distances, SpBase):
        if TYPE_CHECKING:
            from scipy.sparse._base import _spbase

            assert isinstance(distances, _spbase)
        distances = distances.tocsr(copy=True)
        distances.setdiag(0)
        distances.eliminate_zeros()
    else:
        # `np.asarray` would alias an ndarray argument, so `fill_diagonal`
        # would overwrite the caller's matrix in place (and raise on
        # read-only arrays). Copy first, mirroring the sparse branch.
        distances = np.array(distances)
        np.fill_diagonal(distances, 0)
    return distances


def _prepare_precomputed_connectivities(
    connectivities: np.ndarray | SpBase, *, n_obs: int
) -> CSBase:
    """Return a user-supplied connectivities matrix as a cleaned CSR matrix.

    Parameters
    ----------
    connectivities
        Precomputed connectivities, sparse or dense.
    n_obs
        Number of observations; the matrix must have shape ``(n_obs, n_obs)``.

    Returns
    -------
    A new CSR matrix whose diagonal has been set to 0 and whose explicit
    zeros have been removed.

    Raises
    ------
    ValueError
        If the shape check performed by `_validate_precomputed_matrix_shape` fails.
    """
    _validate_precomputed_matrix_shape(
        connectivities, name="connectivities", n_obs=n_obs
    )
    if not isinstance(connectivities, SpBase):
        # Non-sparse input is converted into a fresh CSR matrix.
        graph = sparse.csr_matrix(connectivities)  # noqa: TID251
    else:
        if TYPE_CHECKING:
            from scipy.sparse._base import _spbase

            assert isinstance(connectivities, _spbase)
        # Explicit copy so the cleanup below works on our own matrix.
        graph = connectivities.tocsr(copy=True)
    graph.setdiag(0)
    graph.eliminate_zeros()
    return graph


def _get_metadata(
key_added: str | None,
**params: Unpack[NeighborsParams],
Expand Down
82 changes: 82 additions & 0 deletions tests/test_neighbors.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,3 +294,85 @@ def test_neighbors_distance_equivalence() -> None:
assert p.pop("metric") == "euclidean"
assert p_d.pop("metric") is None
assert p == p_d


def test_neighbors_connectivities_support_umap() -> None:
    """UMAP on precomputed connectivities matches UMAP on a computed graph."""
    adata = pbmc68k_reduced()
    adata_c = adata.copy()
    # Start the copy without any `.obsp` entries or `neighbors` metadata.
    for key in [*adata_c.obsp]:
        del adata_c.obsp[key]
    adata_c.uns.pop("neighbors", None)

    sc.pp.neighbors(adata)
    reference = adata.obsp["connectivities"]
    sc.pp.neighbors(adata_c, connectivities=reference)

    np.testing.assert_allclose(
        reference.toarray(),
        adata_c.obsp["connectivities"].toarray(),
    )
    assert "distances" not in adata_c.obsp

    for data in (adata, adata_c):
        sc.tl.umap(data, random_state=0)

    np.testing.assert_allclose(adata.obsm["X_umap"], adata_c.obsm["X_umap"])


def test_neighbors_dense_precomputed_inputs_are_prepared() -> None:
    """Dense inputs with non-zero diagonals are stored with zeroed diagonals."""
    adata = AnnData(np.array(X))
    distances = np.array(distances_euclidean_all)
    connectivities = np.array(connectivities_umap)
    # Put ones on both diagonals so the cleanup is observable.
    for dense in (distances, connectivities):
        np.fill_diagonal(dense, 1)

    sc.pp.neighbors(adata, distances=distances, connectivities=connectivities)

    stored_distances = adata.obsp["distances"]
    np.testing.assert_allclose(np.diag(stored_distances), 0)
    stored_connectivities = adata.obsp["connectivities"]
    assert isinstance(stored_connectivities, CSBase)
    np.testing.assert_allclose(stored_connectivities.diagonal(), 0)


@pytest.mark.parametrize(
    ("argument", "name"),
    [pytest.param(arg, arg, id=arg) for arg in ("distances", "connectivities")],
)
def test_neighbors_precomputed_shape_validation(argument: str, name: str) -> None:
    """A precomputed matrix one row/column too small is rejected."""
    adata = AnnData(np.array(X))
    side = adata.n_obs - 1
    wrong_shape = {argument: np.ones((side, side))}

    with pytest.raises(ValueError, match=rf"`{name}` must have shape"):
        sc.pp.neighbors(adata, **wrong_shape)


def test_neighbors_precomputed_rejects_callable_metric() -> None:
    """A callable `metric` combined with precomputed connectivities raises."""
    adata = AnnData(np.array(X))
    precomputed = np.array(connectivities_umap)

    def norm_of_difference(a, b):
        return np.linalg.norm(a - b)

    with pytest.raises(TypeError, match="`metric` must be a string"):
        sc.pp.neighbors(
            adata, connectivities=precomputed, metric=norm_of_difference
        )


def test_neighbors_precomputed_warns_for_ignored_parameters() -> None:
    """Extra parameters passed alongside precomputed connectivities trigger a warning."""
    adata = AnnData(np.array(X))
    extra_params = dict(
        n_pcs=1,
        use_rep="X",
        knn=False,
        metric_kwds={"p": 1},
        rng=1,
    )
    expected = r"Parameter\(s\) ignored if `distances` or `connectivities` is given"

    with pytest.warns(UserWarning, match=expected):
        sc.pp.neighbors(
            adata,
            connectivities=np.array(connectivities_umap),
            **extra_params,
        )
Loading