diff --git a/RELEASES.md b/RELEASES.md index e55ba6138..1d6d72238 100644 --- a/RELEASES.md +++ b/RELEASES.md @@ -6,6 +6,13 @@ This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation. #### New features +- Add `ot.utils.DataScaler` class for backend-aware joint normalization of input + distributions, with sklearn-compatible `fit`/`transform`/`fit_transform` API and + support for `'standard'`, `'minmax'`, and `'l2'` methods (PR #808) +- Add `ot.utils.apply_scaler` helper that dispatches preprocessing to a scaler object, + a callable, or a no-op (PR #808) +- Add optional `scaler` parameter to `sliced_wasserstein_distance` and + `max_sliced_wasserstein_distance` (PR #808) - Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788) - Add Warmstart feature to the EMD solver for existing potentials (PR #793) - Add Warmstart potentials feature to the EMD solver for lazy and sparse solver (PR #795) diff --git a/examples/sliced-wasserstein/plot_sliced_wasserstein_scaler.py b/examples/sliced-wasserstein/plot_sliced_wasserstein_scaler.py new file mode 100644 index 000000000..c5aebfdcc --- /dev/null +++ b/examples/sliced-wasserstein/plot_sliced_wasserstein_scaler.py @@ -0,0 +1,114 @@ +# -*- coding: utf-8 -*- +""" +============================================================ +Sliced Wasserstein Distance with input scaling (DataScaler) +============================================================ + +.. note:: + Example added in release: 0.9.7. + +This example illustrates why input scaling matters when computing the Sliced +Wasserstein Distance (SWD) between distributions whose features have very +different magnitudes. Without scaling, the SWD is dominated by high-magnitude +features and may miss meaningful differences in low-magnitude features. + +The :class:`ot.utils.DataScaler` class fits normalization statistics once on a +representative sample and applies the same fixed transformation on every call. +This is preferred over re-normalizing inside each SWD call because the +transformation stays consistent across mini-batches during optimization. + +""" + +# Author: Harguna Sood +# +# License: MIT License + +import matplotlib.pylab as pl +import numpy as np + +import ot + +############################################################################## +# Generate two 2D distributions with mismatched feature scales +# ------------------------------------------------------------ +# +# Feature 1 is on the scale of 1000 with random noise. +# Feature 2 is on the scale of 1 with a meaningful 5-sigma shift between +# source and target distributions. + +# %% parameters and data generation + +rng = np.random.RandomState(0) +n = 500 + +X_s = np.column_stack( + [ + rng.normal(1000, 100, n), # feature 1: large scale, no real signal + rng.normal(0, 1, n), # feature 2: small scale, no shift + ] +) +X_t = np.column_stack( + [ + rng.normal(1000, 100, n), # feature 1: same distribution as source + rng.normal(5, 1, n), # feature 2: shifted by 5 std + ] +) + +############################################################################## +# SWD without scaling +# ------------------- +# +# Because feature 1 has values ~1000x larger than feature 2, the random +# projections used in SWD are dominated by feature 1. The meaningful shift +# in feature 2 is buried. + +# %% SWD without scaling + +swd_raw = ot.sliced_wasserstein_distance(X_s, X_t, n_projections=200, seed=0) +print("SWD without scaling: {:.4f}".format(swd_raw)) + +############################################################################## +# SWD with DataScaler +# ------------------- +# +# Fit a standard scaler jointly on both distributions, then pass it to SWD. +# The same fixed statistics are reused on every call, giving a stable loss +# across mini-batches. + +# %% SWD with DataScaler + +scaler = ot.utils.DataScaler(norm="standard").fit([X_s, X_t]) +swd_scaled = ot.sliced_wasserstein_distance( + X_s, X_t, n_projections=200, seed=0, scaler=scaler +) +print("SWD with DataScaler: {:.4f}".format(swd_scaled)) + +############################################################################## +# Visualize raw vs. scaled distributions +# --------------------------------------- + +# %% plot distributions + +X_s_n = scaler.transform(X_s) +X_t_n = scaler.transform(X_t) + +pl.figure(1, figsize=(12, 5)) + +pl.subplot(1, 2, 1) +pl.scatter(X_s[:, 0], X_s[:, 1], alpha=0.5, label="$X_s$", s=10) +pl.scatter(X_t[:, 0], X_t[:, 1], alpha=0.5, label="$X_t$", s=10) +pl.title("Raw distributions\n(feature 2 signal hidden by feature 1 scale)") +pl.xlabel("Feature 1 (large scale)") +pl.ylabel("Feature 2 (small scale)") +pl.legend() + +pl.subplot(1, 2, 2) +pl.scatter(X_s_n[:, 0], X_s_n[:, 1], alpha=0.5, label="$X_s$ normalized", s=10) +pl.scatter(X_t_n[:, 0], X_t_n[:, 1], alpha=0.5, label="$X_t$ normalized", s=10) +pl.title("Normalized distributions\n(feature 2 shift clearly visible)") +pl.xlabel("Feature 1 (normalized)") +pl.ylabel("Feature 2 (normalized)") +pl.legend() + +pl.tight_layout() +pl.show() diff --git a/ot/sliced.py b/ot/sliced.py index 4a0c8417b..6e55ab9a9 100644 --- a/ot/sliced.py +++ b/ot/sliced.py @@ -12,7 +12,7 @@ import numpy as np from .backend import get_backend, NumpyBackend -from .utils import list_to_array, get_coordinate_circle +from .utils import list_to_array, get_coordinate_circle, apply_scaler from .lp import ( wasserstein_circle, semidiscrete_wasserstein2_unif_circle, @@ -76,6 +76,7 @@ def sliced_wasserstein_distance( projections=None, seed=None, log=False, + scaler=None, ): r""" Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance @@ -109,6 +110,20 @@ def sliced_wasserstein_distance( Seed used for random number generator log: bool, optional if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + scaler: None, object with .transform(), or callable, optional + Preprocessing applied to X_s and X_t before computing the distance. + Useful for normalizing inputs when features have very different scales. + + - ``None`` : no preprocessing (default) + - Object with ``.transform()`` method : e.g. an :class:`ot.utils.DataScaler` + fitted on a representative sample. This is the recommended way to get + stable, consistent normalization across multiple calls (e.g. when + using SWD as a loss in mini-batch training). + - Callable : any function, lambda, or PyTorch transform applied + directly as ``scaler(X_s)`` and ``scaler(X_t)``. + + See :class:`ot.utils.DataScaler` for a backend-aware scaler that supports + joint fitting on multiple distributions. Returns ------- @@ -136,6 +151,8 @@ def sliced_wasserstein_distance( nx = get_backend(X_s, X_t, a, b, projections) + X_s, X_t = apply_scaler(X_s, X_t, scaler) + n = X_s.shape[0] m = X_t.shape[0] @@ -181,6 +198,7 @@ def max_sliced_wasserstein_distance( projections=None, seed=None, log=False, + scaler=None, ): r""" Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance @@ -215,6 +233,20 @@ def max_sliced_wasserstein_distance( Seed used for random number generator log: bool, optional if True, sliced_wasserstein_distance returns the projections used and their associated EMD. + scaler : None, object with .transform(), or callable, optional + Preprocessing applied to X_s and X_t before computing the distance. + Useful for normalizing inputs when features have very different scales. + + - ``None`` : no preprocessing (default) + - Object with ``.transform()`` method : e.g. an :class:`ot.utils.DataScaler` + fitted on a representative sample. This is the recommended way to get + stable, consistent normalization across multiple calls (e.g. when + using SWD as a loss in mini-batch training). + - Callable : any function, lambda, or PyTorch transform applied + directly as ``scaler(X_s)`` and ``scaler(X_t)``. + + See :class:`ot.utils.DataScaler` for a backend-aware scaler that supports + joint fitting on multiple distributions. Returns ------- @@ -242,6 +274,8 @@ def max_sliced_wasserstein_distance( nx = get_backend(X_s, X_t, a, b, projections) + X_s, X_t = apply_scaler(X_s, X_t, scaler) + n = X_s.shape[0] m = X_t.shape[0] diff --git a/ot/utils.py b/ot/utils.py index 64bf1ace9..a98d95dcf 100644 --- a/ot/utils.py +++ b/ot/utils.py @@ -1493,6 +1493,222 @@ def check_number_threads(numThreads): return numThreads +class DataScaler: + r"""Backend-aware data scaler with sklearn-compatible API. + + Fit normalization statistics on a single array or on the concatenation + of multiple arrays (joint fitting), then apply the same fixed transform + to any array. Supports NumPy, PyTorch, JAX, and TensorFlow backends via + POT's backend abstraction. + + Parameters + ---------- + norm : str, optional + Normalization method. One of: + + - ``'standard'`` (default) : zero mean, unit variance per feature + - ``'minmax'`` : scale each feature to [0, 1] + - ``'l2'`` : unit L2-norm per sample (row-wise, stateless) + + Attributes + ---------- + norm : str + The normalization method. + mean_ : array-like + Per-feature means (only for ``norm='standard'``). + std_ : array-like + Per-feature standard deviations (only for ``norm='standard'``). + min_ : array-like + Per-feature minimums (only for ``norm='minmax'``). + max_ : array-like + Per-feature maximums (only for ``norm='minmax'``). + + Examples + -------- + >>> import numpy as np + >>> from ot.utils import DataScaler + >>> X_s = np.array([[1.0, 100.0], [2.0, 200.0]]) + >>> X_t = np.array([[3.0, 300.0], [4.0, 400.0]]) + >>> scaler = DataScaler(norm='standard').fit([X_s, X_t]) + >>> X_s_scaled = scaler.transform(X_s) + """ + + _VALID_NORMS = ("standard", "minmax", "l2") + + def __init__(self, norm="standard"): + if norm not in self._VALID_NORMS: + raise ValueError( + "Invalid norm '{}'. Expected one of: {}".format(norm, self._VALID_NORMS) + ) + self.norm = norm + self.mean_ = None + self.std_ = None + self.min_ = None + self.max_ = None + self._fitted = False + + def fit(self, X): + r"""Compute normalization statistics from one array or a list of arrays. + + When given a list, arrays are concatenated along axis 0 before + computing statistics (joint fitting). + + Parameters + ---------- + X : array-like or list of array-like + Data to fit on. If a list, arrays must have the same number of + features (columns). + + Returns + ------- + self : DataScaler + """ + if isinstance(X, (list, tuple)): + if len(X) == 0: + raise ValueError("Cannot fit on empty list.") + nx = get_backend(*X) + X_concat = nx.concatenate(list(X), axis=0) + else: + nx = get_backend(X) + X_concat = X + + self._nx = nx + + if self.norm == "l2": + self._fitted = True + return self + + if self.norm == "standard": + self.mean_ = nx.mean(X_concat, axis=0) + self.std_ = nx.std(X_concat, axis=0) + zero_var = self.std_ == 0 + if nx.any(zero_var): + warnings.warn( + "Zero variance detected in one or more feature(s). " + "Those columns will not be scaled.", + RuntimeWarning, + ) + self.std_ = nx.where( + zero_var, + nx.ones(self.std_.shape, type_as=self.std_), + self.std_, + ) + + elif self.norm == "minmax": + self.min_ = nx.min(X_concat, axis=0) + self.max_ = nx.max(X_concat, axis=0) + zero_range = self.max_ == self.min_ + if nx.any(zero_range): + warnings.warn( + "Zero range detected in one or more feature(s). " + "Those columns will not be scaled.", + RuntimeWarning, + ) + self.max_ = nx.where( + zero_range, + self.min_ + 1.0, + self.max_, + ) + + self._fitted = True + return self + + def transform(self, X): + r"""Apply the fitted transformation to X. + + Parameters + ---------- + X : array-like + Data to transform. + + Returns + ------- + X_scaled : array-like + Transformed data, same shape and backend as X. + """ + if self.norm != "l2" and not self._fitted: + raise RuntimeError( + "DataScaler must be fitted before calling transform() " + "for norm='{}'.".format(self.norm) + ) + + nx = get_backend(X) + + if self.norm == "standard": + return (X - self.mean_) / self.std_ + elif self.norm == "minmax": + return (X - self.min_) / (self.max_ - self.min_) + elif self.norm == "l2": + norms = nx.sqrt(nx.sum(X**2, axis=1, keepdims=True)) + zero_norm = norms == 0 + if nx.any(zero_norm): + warnings.warn( + "Zero-norm row(s) detected. These will be left unchanged.", + RuntimeWarning, + ) + norms = nx.where( + zero_norm, + nx.ones(norms.shape, type_as=norms), + norms, + ) + return X / norms + + def fit_transform(self, X): + r"""Fit then transform. + + Parameters + ---------- + X : array-like or list of array-like + + Returns + ------- + X_scaled : array-like or list of array-like + If X was a list, returns a list of transformed arrays. + """ + self.fit(X) + if isinstance(X, (list, tuple)): + return [self.transform(x) for x in X] + return self.transform(X) + + +def apply_scaler(X_s, X_t, scaler=None): + r"""Apply a scaler to two arrays. + + Dispatches based on the type of ``scaler``: + + - ``None`` : returns inputs unchanged. + - Object with a ``.transform()`` method : calls ``scaler.transform()`` on each. + - Callable : calls ``scaler()`` on each (covers functions, lambdas, + PyTorch transforms, neural network encoders, etc.). + + Parameters + ---------- + X_s : array-like + Source samples. + X_t : array-like + Target samples. + scaler : None, object with .transform(), or callable, optional + Preprocessing to apply. + + Returns + ------- + X_s_out : array-like + Possibly transformed source samples. + X_t_out : array-like + Possibly transformed target samples. + """ + if scaler is None: + return X_s, X_t + if hasattr(scaler, "transform") and callable(scaler.transform): + return scaler.transform(X_s), scaler.transform(X_t) + if callable(scaler): + return scaler(X_s), scaler(X_t) + raise ValueError( + "scaler must be None, an object with a .transform() method, " + "or a callable. Got type: {}".format(type(scaler).__name__) + ) + + def fun_to_numpy(fun, arr, nx, warn=True): """Convert a function to a numpy function. diff --git a/test/test_sliced.py b/test/test_sliced.py index 05de13755..1216a38b3 100644 --- a/test/test_sliced.py +++ b/test/test_sliced.py @@ -697,3 +697,68 @@ def test_linear_sliced_sphere_backend_type_devices(nx): nx.assert_same_dtype_device(xb, valb) np.testing.assert_almost_equal(sw_np, nx.to_numpy(valb)) + + +class TestSlicedWassersteinScaler: + """Integration tests for the scaler parameter in sliced_wasserstein_distance.""" + + def test_scaler_none_matches_no_scaler(self): + rng = np.random.RandomState(0) + X_s = rng.normal(0, 1, (50, 3)) + X_t = rng.normal(1, 1, (50, 3)) + result_default = ot.sliced_wasserstein_distance(X_s, X_t, seed=0) + result_none = ot.sliced_wasserstein_distance(X_s, X_t, seed=0, scaler=None) + np.testing.assert_allclose(result_default, result_none) + + def test_scaler_with_datascaler_runs(self): + rng = np.random.RandomState(0) + X_s = rng.normal(0, 1, (50, 3)) + X_t = rng.normal(1, 1, (50, 3)) + scaler = ot.utils.DataScaler(norm="standard").fit([X_s, X_t]) + result = ot.sliced_wasserstein_distance(X_s, X_t, seed=0, scaler=scaler) + assert np.isfinite(result) + assert result >= 0 + + def test_scaler_surfaces_small_scale_signal(self): + """Scaled SWD detects a shift in a small-magnitude feature that unscaled SWD misses.""" + rng = np.random.RandomState(0) + n = 500 + X_s = np.column_stack( + [ + rng.normal(1000, 100, n), + rng.normal(0, 1, n), + ] + ) + X_t = np.column_stack( + [ + rng.normal(1000, 100, n), + rng.normal(5, 1, n), + ] + ) + scaler = ot.utils.DataScaler(norm="standard").fit([X_s, X_t]) + swd_scaled = ot.sliced_wasserstein_distance( + X_s, X_t, seed=0, n_projections=200, scaler=scaler + ) + assert swd_scaled > 1.0 + + def test_scaler_with_lambda(self): + X_s = np.array([[1.0, 2.0], [3.0, 4.0]]) + X_t = np.array([[5.0, 6.0], [7.0, 8.0]]) + result = ot.sliced_wasserstein_distance( + X_s, X_t, seed=0, scaler=lambda x: x / 10 + ) + assert np.isfinite(result) + + def test_invalid_scaler_raises(self): + X_s = np.array([[1.0, 2.0]]) + X_t = np.array([[2.0, 3.0]]) + with pytest.raises(ValueError, match="scaler must be"): + ot.sliced_wasserstein_distance(X_s, X_t, scaler=42) + + def test_max_sliced_scaler_integration(self): + rng = np.random.RandomState(0) + X_s = rng.normal(0, 1, (50, 3)) + X_t = rng.normal(1, 1, (50, 3)) + scaler = ot.utils.DataScaler(norm="standard").fit([X_s, X_t]) + result = ot.max_sliced_wasserstein_distance(X_s, X_t, seed=0, scaler=scaler) + assert np.isfinite(result) diff --git a/test/test_utils.py b/test/test_utils.py index 8c5e65b93..0d2207097 100644 --- a/test/test_utils.py +++ b/test/test_utils.py @@ -772,3 +772,120 @@ def fun(x): # backend function with pytest.raises(ValueError): ot.utils.fun_to_numpy(fun, None, nx, warn=True) + + +class TestDataScaler: + """Tests for the DataScaler class.""" + + def test_invalid_norm_raises(self): + with pytest.raises(ValueError, match="Invalid norm"): + ot.utils.DataScaler(norm="bogus") + + def test_standard_joint_fit_mean_std(self): + X_s = np.array([[0.0, 0.0], [2.0, 200.0]]) + X_t = np.array([[4.0, 400.0], [6.0, 600.0]]) + scaler = ot.utils.DataScaler(norm="standard").fit([X_s, X_t]) + expected_mean = np.array([3.0, 300.0]) + np.testing.assert_allclose(scaler.mean_, expected_mean) + + def test_standard_transform_zero_mean_unit_var(self): + rng = np.random.RandomState(0) + X_s = rng.normal(5, 2, (100, 3)) + X_t = rng.normal(10, 3, (100, 3)) + scaler = ot.utils.DataScaler(norm="standard").fit([X_s, X_t]) + X_s_n = scaler.transform(X_s) + X_t_n = scaler.transform(X_t) + joint = np.concatenate([X_s_n, X_t_n], axis=0) + np.testing.assert_allclose(joint.mean(axis=0), 0, atol=1e-10) + np.testing.assert_allclose(joint.std(axis=0), 1, atol=1e-10) + + def test_minmax_transform_in_unit_range(self): + X_s = np.array([[0.0], [5.0]]) + X_t = np.array([[10.0]]) + scaler = ot.utils.DataScaler(norm="minmax").fit([X_s, X_t]) + joint = np.concatenate([scaler.transform(X_s), scaler.transform(X_t)], axis=0) + assert joint.min() == 0.0 + assert joint.max() == 1.0 + + def test_l2_unit_norm_rows(self): + X = np.array([[3.0, 4.0], [1.0, 0.0]]) + scaler = ot.utils.DataScaler(norm="l2") + X_n = scaler.fit_transform(X) + norms = np.linalg.norm(X_n, axis=1) + np.testing.assert_allclose(norms, 1.0) + + def test_l2_no_explicit_fit_via_fit_transform(self): + scaler = ot.utils.DataScaler(norm="l2") + X = np.array([[3.0, 4.0]]) + result = scaler.fit_transform(X) + np.testing.assert_allclose(np.linalg.norm(result, axis=1), 1.0) + + def test_zero_variance_column_warns(self): + X = np.array([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]]) + with pytest.warns(RuntimeWarning, match="Zero variance"): + ot.utils.DataScaler(norm="standard").fit(X) + + def test_zero_range_column_warns(self): + X = np.array([[1.0, 5.0], [2.0, 5.0], [3.0, 5.0]]) + with pytest.warns(RuntimeWarning, match="Zero range"): + ot.utils.DataScaler(norm="minmax").fit(X) + + def test_zero_norm_row_warns(self): + X = np.array([[0.0, 0.0], [3.0, 4.0]]) + scaler = ot.utils.DataScaler(norm="l2") + with pytest.warns(RuntimeWarning, match="Zero-norm"): + scaler.fit_transform(X) + + def test_transform_before_fit_raises(self): + scaler = ot.utils.DataScaler(norm="standard") + with pytest.raises(RuntimeError, match="must be fitted"): + scaler.transform(np.array([[1.0, 2.0]])) + + def test_fit_empty_list_raises(self): + with pytest.raises(ValueError, match="empty list"): + ot.utils.DataScaler(norm="standard").fit([]) + + def test_fit_single_array(self): + X = np.array([[1.0, 2.0], [3.0, 4.0]]) + scaler = ot.utils.DataScaler(norm="standard").fit(X) + assert scaler.mean_ is not None + + def test_fit_transform_list_returns_list(self): + X_s = np.array([[1.0, 2.0], [3.0, 4.0]]) + X_t = np.array([[5.0, 6.0], [7.0, 8.0]]) + results = ot.utils.DataScaler(norm="standard").fit_transform([X_s, X_t]) + assert isinstance(results, list) + assert len(results) == 2 + assert results[0].shape == X_s.shape + assert results[1].shape == X_t.shape + + +class TestApplyScaler: + """Tests for the apply_scaler helper.""" + + def test_none_returns_inputs_unchanged(self): + X_s = np.array([[1.0, 2.0]]) + X_t = np.array([[3.0, 4.0]]) + out_s, out_t = ot.utils.apply_scaler(X_s, X_t, None) + np.testing.assert_array_equal(out_s, X_s) + np.testing.assert_array_equal(out_t, X_t) + + def test_with_datascaler(self): + rng = np.random.RandomState(0) + X_s = rng.normal(0, 1, (50, 4)) + X_t = rng.normal(5, 1, (50, 4)) + scaler = ot.utils.DataScaler(norm="standard").fit([X_s, X_t]) + out_s, out_t = ot.utils.apply_scaler(X_s, X_t, scaler) + np.testing.assert_allclose(out_s, scaler.transform(X_s)) + np.testing.assert_allclose(out_t, scaler.transform(X_t)) + + def test_with_lambda(self): + X_s = np.array([[1.0]]) + X_t = np.array([[2.0]]) + out_s, out_t = ot.utils.apply_scaler(X_s, X_t, lambda x: x * 10) + np.testing.assert_array_equal(out_s, X_s * 10) + np.testing.assert_array_equal(out_t, X_t * 10) + + def test_invalid_scaler_raises(self): + with pytest.raises(ValueError, match="scaler must be"): + ot.utils.apply_scaler(np.zeros((2, 2)), np.zeros((2, 2)), scaler=42)