diff --git a/docs/release-notes/4133.feat.md b/docs/release-notes/4133.feat.md
new file mode 100644
index 0000000000..bf6035e58b
--- /dev/null
+++ b/docs/release-notes/4133.feat.md
@@ -0,0 +1 @@
+Allow {func}`scanpy.pp.neighbors` to accept precomputed connectivities for downstream graph tools such as {func}`scanpy.tl.umap`. {smaller}`E Estaji`
diff --git a/src/scanpy/neighbors/__init__.py b/src/scanpy/neighbors/__init__.py
index 7bc2470df3..6472a45b49 100644
--- a/src/scanpy/neighbors/__init__.py
+++ b/src/scanpy/neighbors/__init__.py
@@ -90,6 +90,7 @@ def neighbors(  # noqa: PLR0913
     n_pcs: int | None = None,
     *,
     distances: np.ndarray | SpBase | None = None,
+    connectivities: np.ndarray | SpBase | None = None,
     use_rep: str | None = None,
     knn: bool = True,
     method: _Method = "umap",
@@ -126,6 +127,15 @@ def neighbors(  # noqa: PLR0913
 
         *ignored if ``transformer`` is an instance.*
     {n_pcs}
+    distances
+        Precomputed distance matrix to use instead of computing distances from
+        `adata`. Connectivities are computed from this matrix using `method`.
+    connectivities
+        Precomputed connectivities matrix to use instead of deriving
+        connectivities from distances. This allows downstream tools such as
+        :func:`~scanpy.tl.umap` to consume externally computed neighborhood
+        graphs. If `distances` is also provided, it is stored alongside the
+        connectivities.
     {use_rep}
     knn
         If `True`, use a hard threshold to restrict the number of neighbors to
@@ -153,8 +163,8 @@ def neighbors(  # noqa: PLR0913
         :class:`~pynndescent.pynndescent_.PyNNDescentTransformer`
     metric
         A known metric’s name or a callable that returns a distance.
-        If `distances` is given, this parameter is simply stored in `.uns` (see below),
-        otherwise defaults to `'euclidean'`.
+        If `distances` or `connectivities` is given, this parameter is simply
+        stored in `.uns` (see below), otherwise defaults to `'euclidean'`.
 
         *ignored if ``transformer`` is an instance.*
     metric_kwds
@@ -180,6 +190,7 @@ def neighbors(  # noqa: PLR0913
 
     `adata.obsp['distances' | f'{{key_added}}_distances']` : :class:`scipy.sparse.csr_matrix` (dtype `float`)
         Distance matrix of the nearest neighbors search. Each row (cell) has `n_neighbors`-1 non-zero entries. These are the distances to their `n_neighbors`-1 nearest neighbors (excluding the cell itself).
+        Not set if only `connectivities` is provided.
     `adata.obsp['connectivities' | f'{{key_added}}_connectivities']` : :class:`scipy.sparse._csr.csr_matrix` (dtype `float`)
         Weighted adjacency matrix of the neighborhood graph of data points.
         Weights should be interpreted as connectivities.
@@ -209,13 +220,14 @@ def neighbors(  # noqa: PLR0913
         dict(random_state=rng.arg) if isinstance(rng, _LegacyRng) else {}
     )
 
-    if distances is None:
+    adata = adata.copy() if copy else adata
+    if adata.is_view:  # we shouldn't need this here...
+        adata._init_as_actual(adata.copy())
+
+    if distances is None and connectivities is None:
         if metric is None:
             metric = "euclidean"
         start = logg.info("computing neighbors")
-        adata = adata.copy() if copy else adata
-        if adata.is_view:  # we shouldn't need this here...
-            adata._init_as_actual(adata.copy())
         neighbors_ = Neighbors(adata)
         neighbors_.compute_neighbors(
             n_neighbors,
@@ -242,31 +254,33 @@ def neighbors(  # noqa: PLR0913
             meta_random_state.pop("random_state", None)
         if ignored:
             warn(
-                f"Parameter(s) ignored if `distances` is given: {ignored}",
+                f"Parameter(s) ignored if `distances` or `connectivities` is given: {ignored}",
                 UserWarning,
             )
         if callable(metric):
-            msg = "`metric` must be a string if `distances` is given."
+            msg = (
+                "`metric` must be a string if `distances` or `connectivities` is given."
+            )
             raise TypeError(msg)
 
-        start = logg.info("computing connectivities")
-        # if a precomputed distance matrix is provided, skip the PCA and distance computation
-        if isinstance(distances, SpBase):
-            if TYPE_CHECKING:
-                from scipy.sparse._base import _spbase
-
-                assert isinstance(distances, _spbase)
-            distances = distances.tocsr(copy=True)
-            distances.setdiag(0)
-            distances.eliminate_zeros()
-        else:
-            distances = np.asarray(distances)
-            np.fill_diagonal(distances, 0)
+        start = logg.info(
+            "computing connectivities"
+            if connectivities is None
+            else "using precomputed connectivities"
+        )
         neighbors_ = Neighbors(adata)
         neighbors_.n_neighbors = n_neighbors
         neighbors_.knn = True
-        neighbors_._distances = distances
-        neighbors_._connectivities = neighbors_._compute_connectivites(method)
+        if distances is not None:
+            neighbors_._distances = _prepare_precomputed_distances(
+                distances, n_obs=adata.n_obs
+            )
+        if connectivities is None:
+            neighbors_._connectivities = neighbors_._compute_connectivites(method)
+        else:
+            neighbors_._connectivities = _prepare_precomputed_connectivities(
+                connectivities, n_obs=adata.n_obs
+            )
 
     key_added, neighbors_dict = _get_metadata(
         key_added,
@@ -283,7 +297,8 @@ def neighbors(  # noqa: PLR0913
         neighbors_dict["rp_forest"] = neighbors_.rp_forest
 
     adata.uns[key_added] = neighbors_dict
-    adata.obsp[neighbors_dict["distances_key"]] = neighbors_.distances
+    if neighbors_.distances is not None:
+        adata.obsp[neighbors_dict["distances_key"]] = neighbors_.distances
     adata.obsp[neighbors_dict["connectivities_key"]] = neighbors_.connectivities
 
     logg.info(
@@ -291,13 +306,74 @@ def neighbors(  # noqa: PLR0913
         time=start,
         deep=(
             f"added to `.uns[{key_added!r}]`\n"
-            f"    `.obsp[{neighbors_dict['distances_key']!r}]`, distances for each pair of neighbors\n"
-            f"    `.obsp[{neighbors_dict['connectivities_key']!r}]`, weighted adjacency matrix"
+            + (
+                f"    `.obsp[{neighbors_dict['distances_key']!r}]`, distances for each pair of neighbors\n"
+                if neighbors_.distances is not None
+                else ""
+            )
+            + f"    `.obsp[{neighbors_dict['connectivities_key']!r}]`, weighted adjacency matrix"
         ),
     )
     return adata if copy else None
 
 
+def _validate_precomputed_matrix_shape(
+    matrix: np.ndarray | SpBase, *, name: str, n_obs: int
+) -> None:
+    """Raise `ValueError` unless `matrix` has shape `(n_obs, n_obs)`."""
+    # `np.shape` also accepts array-likes lacking `.shape`, e.g. nested lists
+    shape = np.shape(matrix)
+    if shape != (n_obs, n_obs):
+        msg = f"`{name}` must have shape ({n_obs}, {n_obs}), found {shape}."
+        raise ValueError(msg)
+
+
+def _prepare_precomputed_distances(
+    distances: np.ndarray | SpBase, *, n_obs: int
+) -> np.ndarray | CSBase:
+    """Validate precomputed distances and zero their diagonal.
+
+    Always works on a copy, so the caller's matrix is never modified.
+    """
+    _validate_precomputed_matrix_shape(distances, name="distances", n_obs=n_obs)
+    if isinstance(distances, SpBase):
+        if TYPE_CHECKING:
+            from scipy.sparse._base import _spbase
+
+            assert isinstance(distances, _spbase)
+        distances = distances.tocsr(copy=True)
+        distances.setdiag(0)
+        distances.eliminate_zeros()
+    else:
+        # `np.array` copies, so `fill_diagonal` cannot modify the caller's array
+        distances = np.array(distances)
+        np.fill_diagonal(distances, 0)
+    return distances
+
+
+def _prepare_precomputed_connectivities(
+    connectivities: np.ndarray | SpBase, *, n_obs: int
+) -> CSBase:
+    """Validate precomputed connectivities and return them as CSR without a diagonal.
+
+    Always works on a copy, so the caller's matrix is never modified.
+    """
+    _validate_precomputed_matrix_shape(
+        connectivities, name="connectivities", n_obs=n_obs
+    )
+    if isinstance(connectivities, SpBase):
+        if TYPE_CHECKING:
+            from scipy.sparse._base import _spbase
+
+            assert isinstance(connectivities, _spbase)
+        connectivities = connectivities.tocsr(copy=True)
+    else:
+        connectivities = sparse.csr_matrix(connectivities)  # noqa: TID251
+    connectivities.setdiag(0)
+    connectivities.eliminate_zeros()
+    return connectivities
+
+
 def _get_metadata(
     key_added: str | None,
     **params: Unpack[NeighborsParams],
diff --git a/tests/test_neighbors.py b/tests/test_neighbors.py
index 158adfb65b..a9c6aec83a 100644
--- a/tests/test_neighbors.py
+++ b/tests/test_neighbors.py
@@ -294,3 +294,88 @@ def test_neighbors_distance_equivalence() -> None:
     assert p.pop("metric") == "euclidean"
     assert p_d.pop("metric") is None
     assert p == p_d
+
+
+def test_neighbors_connectivities_support_umap() -> None:
+    adata = pbmc68k_reduced()
+    adata_c = adata.copy()
+    for key in list(adata_c.obsp):
+        del adata_c.obsp[key]
+    adata_c.uns.pop("neighbors", None)
+
+    sc.pp.neighbors(adata)
+    sc.pp.neighbors(adata_c, connectivities=adata.obsp["connectivities"])
+
+    np.testing.assert_allclose(
+        adata.obsp["connectivities"].toarray(),
+        adata_c.obsp["connectivities"].toarray(),
+    )
+    assert "distances" not in adata_c.obsp
+
+    sc.tl.umap(adata, random_state=0)
+    sc.tl.umap(adata_c, random_state=0)
+
+    np.testing.assert_allclose(adata.obsm["X_umap"], adata_c.obsm["X_umap"])
+
+
+def test_neighbors_dense_precomputed_inputs_are_prepared() -> None:
+    adata = AnnData(np.array(X))
+    distances = np.array(distances_euclidean_all)
+    connectivities = np.array(connectivities_umap)
+    np.fill_diagonal(distances, 1)
+    np.fill_diagonal(connectivities, 1)
+
+    sc.pp.neighbors(adata, distances=distances, connectivities=connectivities)
+
+    np.testing.assert_allclose(np.diag(adata.obsp["distances"]), 0)
+    assert isinstance(adata.obsp["connectivities"], CSBase)
+    np.testing.assert_allclose(adata.obsp["connectivities"].diagonal(), 0)
+    # the caller's arrays must not be modified in place
+    np.testing.assert_allclose(np.diag(distances), 1)
+    np.testing.assert_allclose(np.diag(connectivities), 1)
+
+
+@pytest.mark.parametrize(
+    ("argument", "name"),
+    [
+        pytest.param("distances", "distances", id="distances"),
+        pytest.param("connectivities", "connectivities", id="connectivities"),
+    ],
+)
+def test_neighbors_precomputed_shape_validation(argument: str, name: str) -> None:
+    adata = AnnData(np.array(X))
+
+    with pytest.raises(ValueError, match=rf"`{name}` must have shape"):
+        sc.pp.neighbors(
+            adata, **{argument: np.ones((adata.n_obs - 1, adata.n_obs - 1))}
+        )
+
+
+def test_neighbors_precomputed_rejects_callable_metric() -> None:
+    adata = AnnData(np.array(X))
+
+    def metric(a, b):
+        return np.linalg.norm(a - b)
+
+    with pytest.raises(TypeError, match="`metric` must be a string"):
+        sc.pp.neighbors(
+            adata, connectivities=np.array(connectivities_umap), metric=metric
+        )
+
+
+def test_neighbors_precomputed_warns_for_ignored_parameters() -> None:
+    adata = AnnData(np.array(X))
+
+    with pytest.warns(
+        UserWarning,
+        match=r"Parameter\(s\) ignored if `distances` or `connectivities` is given",
+    ):
+        sc.pp.neighbors(
+            adata,
+            connectivities=np.array(connectivities_umap),
+            n_pcs=1,
+            use_rep="X",
+            knn=False,
+            metric_kwds={"p": 1},
+            rng=1,
+        )