Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions RELEASES.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,13 @@
This new release adds support for sparse cost matrices and a new lazy EMD solver that computes distances on-the-fly from coordinates, reducing memory usage from O(n×m) to O(n+m). Both implementations are backend-agnostic and preserve gradient computation for automatic differentiation.

#### New features
- Add `ot.utils.DataScaler` class for backend-aware joint normalization of input
distributions, with sklearn-compatible `fit`/`transform`/`fit_transform` API and
support for `'standard'`, `'minmax'`, and `'l2'` methods (PR #808)
- Add `ot.utils.apply_scaler` helper that dispatches preprocessing to a scaler object,
a callable, or a no-op (PR #808)
- Add optional `scaler` parameter to `sliced_wasserstein_distance` and
`max_sliced_wasserstein_distance` (PR #808)
- Add lazy EMD solver with on-the-fly distance computation from coordinates (PR #788)
- Add Warmstart feature to the EMD solver for existing potentials (PR #793)
- Add Warmstart potentials feature to the EMD solver for lazy and sparse solver (PR #795)
Expand Down
114 changes: 114 additions & 0 deletions examples/sliced-wasserstein/plot_sliced_wasserstein_scaler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# -*- coding: utf-8 -*-
"""
============================================================
Sliced Wasserstein Distance with input scaling (DataScaler)
============================================================

.. note::
Example added in release: 0.9.7.

This example illustrates why input scaling matters when computing the Sliced
Wasserstein Distance (SWD) between distributions whose features have very
different magnitudes. Without scaling, the SWD is dominated by high-magnitude
features and may miss meaningful differences in low-magnitude features.

The :class:`ot.utils.DataScaler` class fits normalization statistics once on a
representative sample and applies the same fixed transformation on every call.
This is preferred over re-normalizing inside each SWD call because the
transformation stays consistent across mini-batches during optimization.

"""

# Author: Harguna Sood <harguna.sood@gmail.com>
#
# License: MIT License

import matplotlib.pylab as pl
import numpy as np

import ot

##############################################################################
# Generate two 2D distributions with mismatched feature scales
# ------------------------------------------------------------
#
# Feature 1 is on the scale of 1000 with random noise.
# Feature 2 is on the scale of 1 with a meaningful 5-sigma shift between
# source and target distributions.

# %% parameters and data generation

rng = np.random.RandomState(0)
n = 500

# The sampling order below is significant for reproducibility: both source
# columns are drawn first, then both target columns.
source_columns = [
    rng.normal(1000, 100, n),  # feature 1: large magnitude, carries no signal
    rng.normal(0, 1, n),  # feature 2: small magnitude, centred at 0
]
X_s = np.column_stack(source_columns)

target_columns = [
    rng.normal(1000, 100, n),  # feature 1: drawn from the same law as the source
    rng.normal(5, 1, n),  # feature 2: mean moved by 5 standard deviations
]
X_t = np.column_stack(target_columns)

##############################################################################
# SWD without scaling
# -------------------
#
# Feature 1 takes values roughly 1000x larger than feature 2, so the random
# projections used by SWD are dominated by feature 1 and the meaningful shift
# carried by feature 2 is buried.

# %% SWD without scaling

# Both SWD evaluations share the same projection budget and seed so that the
# only difference between them is the scaler.
swd_options = dict(n_projections=200, seed=0)

swd_raw = ot.sliced_wasserstein_distance(X_s, X_t, **swd_options)
print("SWD without scaling: {:.4f}".format(swd_raw))

##############################################################################
# SWD with DataScaler
# -------------------
#
# A standard scaler is fitted jointly on the two distributions and then handed
# to SWD. Since the fitted statistics are fixed, every call applies the same
# transformation, which keeps the loss stable across mini-batches.

# %% SWD with DataScaler

scaler = ot.utils.DataScaler(norm="standard")
scaler = scaler.fit([X_s, X_t])  # fit() result is what the original keeps
swd_scaled = ot.sliced_wasserstein_distance(X_s, X_t, scaler=scaler, **swd_options)
print("SWD with DataScaler: {:.4f}".format(swd_scaled))

##############################################################################
# Visualize raw vs. scaled distributions
# ---------------------------------------

# %% plot distributions

X_s_scaled = scaler.transform(X_s)
X_t_scaled = scaler.transform(X_t)

# One entry per subplot:
# (source, target, source label, target label, title, x label, y label).
panels = [
    (
        X_s,
        X_t,
        "$X_s$",
        "$X_t$",
        "Raw distributions\n(feature 2 signal hidden by feature 1 scale)",
        "Feature 1 (large scale)",
        "Feature 2 (small scale)",
    ),
    (
        X_s_scaled,
        X_t_scaled,
        "$X_s$ normalized",
        "$X_t$ normalized",
        "Normalized distributions\n(feature 2 shift clearly visible)",
        "Feature 1 (normalized)",
        "Feature 2 (normalized)",
    ),
]

pl.figure(1, figsize=(12, 5))

for position, panel in enumerate(panels, start=1):
    source, target, source_label, target_label, title, x_label, y_label = panel
    pl.subplot(1, 2, position)
    pl.scatter(source[:, 0], source[:, 1], alpha=0.5, label=source_label, s=10)
    pl.scatter(target[:, 0], target[:, 1], alpha=0.5, label=target_label, s=10)
    pl.title(title)
    pl.xlabel(x_label)
    pl.ylabel(y_label)
    pl.legend()

pl.tight_layout()
pl.show()
36 changes: 35 additions & 1 deletion ot/sliced.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

import numpy as np
from .backend import get_backend, NumpyBackend
from .utils import list_to_array, get_coordinate_circle
from .utils import list_to_array, get_coordinate_circle, apply_scaler
from .lp import (
wasserstein_circle,
semidiscrete_wasserstein2_unif_circle,
Expand Down Expand Up @@ -76,6 +76,7 @@ def sliced_wasserstein_distance(
projections=None,
seed=None,
log=False,
scaler=None,
):
r"""
Computes a Monte-Carlo approximation of the p-Sliced Wasserstein distance
Expand Down Expand Up @@ -109,6 +110,20 @@ def sliced_wasserstein_distance(
Seed used for random number generator
log: bool, optional
if True, sliced_wasserstein_distance returns the projections used and their associated EMD.
scaler: None, object with .transform(), or callable, optional
Preprocessing applied to X_s and X_t before computing the distance.
Useful for normalizing inputs when features have very different scales.

- ``None`` : no preprocessing (default)
- Object with ``.transform()`` method : e.g. an :class:`ot.utils.DataScaler`
fitted on a representative sample. This is the recommended way to get
stable, consistent normalization across multiple calls (e.g. when
using SWD as a loss in mini-batch training).
- Callable : any function, lambda, or PyTorch transform applied
directly as ``scaler(X_s)`` and ``scaler(X_t)``.

See :class:`ot.utils.DataScaler` for a backend-aware scaler that supports
joint fitting on multiple distributions.

Returns
-------
Expand Down Expand Up @@ -136,6 +151,8 @@ def sliced_wasserstein_distance(

nx = get_backend(X_s, X_t, a, b, projections)

X_s, X_t = apply_scaler(X_s, X_t, scaler)

n = X_s.shape[0]
m = X_t.shape[0]

Expand Down Expand Up @@ -181,6 +198,7 @@ def max_sliced_wasserstein_distance(
projections=None,
seed=None,
log=False,
scaler=None,
):
r"""
Computes a Monte-Carlo approximation of the max p-Sliced Wasserstein distance
Expand Down Expand Up @@ -215,6 +233,20 @@ def max_sliced_wasserstein_distance(
Seed used for random number generator
log: bool, optional
if True, max_sliced_wasserstein_distance returns the projections used and their associated EMD.
scaler : None, object with .transform(), or callable, optional
Preprocessing applied to X_s and X_t before computing the distance.
Useful for normalizing inputs when features have very different scales.

- ``None`` : no preprocessing (default)
- Object with ``.transform()`` method : e.g. an :class:`ot.utils.DataScaler`
fitted on a representative sample. This is the recommended way to get
stable, consistent normalization across multiple calls (e.g. when
using SWD as a loss in mini-batch training).
- Callable : any function, lambda, or PyTorch transform applied
directly as ``scaler(X_s)`` and ``scaler(X_t)``.

See :class:`ot.utils.DataScaler` for a backend-aware scaler that supports
joint fitting on multiple distributions.

Returns
-------
Expand Down Expand Up @@ -242,6 +274,8 @@ def max_sliced_wasserstein_distance(

nx = get_backend(X_s, X_t, a, b, projections)

X_s, X_t = apply_scaler(X_s, X_t, scaler)

n = X_s.shape[0]
m = X_t.shape[0]

Expand Down
Loading
Loading