diff --git a/ot/datasets.py b/ot/datasets.py
index 6e3be518a..161fce670 100644
--- a/ot/datasets.py
+++ b/ot/datasets.py
@@ -8,6 +8,7 @@
 
 import numpy as np
 import scipy as sp
+from scipy.stats import ortho_group, multivariate_normal
 from .utils import check_random_state, deprecated
 
 
@@ -180,3 +181,108 @@ def make_data_classif(dataset, n, nz=0.5, theta=0, p=0.5, random_state=None, **k
 def get_data_classif(dataset, n, nz=0.5, theta=0, random_state=None, **kwargs):
     """Deprecated see make_data_classif"""
     return make_data_classif(dataset, n, nz=0.5, theta=0, random_state=None, **kwargs)
+
+
+def make_gauss_hd(
+    ns,
+    nt,
+    p=100,
+    dim=5,
+    m_diff=3,
+    a=(10, 15),
+    b=(3, 3),
+    sub_the_same=False,
+    random_state=None,
+):
+    r"""Generation of source and target domains from Gaussian HD distributions
+
+    Parameters
+    ----------
+    ns : int
+        number of samples (source)
+    nt : int
+        number of samples (target)
+    p : int
+        dimension of the ambient space the data live in
+    dim : (int,int) or int
+        the intrinsic dimensions of the source and target Gaussian HD distributions. If a single int the intrinsic dimension is assumed to be the same
+    m_diff : float
+        the first coordinate of the target mean mt_0; the source mean ms is zero, so this is the shift between the two means (see code)
+    a : (float, float)
+        positive floating numbers corresponding to the isotropic variances in the principal subspace, for the source and target distributions, respectively. The same as \delta in :ref:`[1] <references-make_gauss_hd>`, Proposition 2.2
+    b : (float, float)
+        positive floating numbers corresponding to the isotropic variance outside the principal subspace for the source and target distributions, respectively.
+    sub_the_same : bool
+        should the source/target Gaussian HD distributions live in the same principal subspace?
+    random_state : int, RandomState instance or None, optional (default=None)
+        seed or random number generator used to draw the orthogonal bases and the samples; None keeps the draws non-reproducible
+
+    Returns
+    -------
+    Xs : ndarray, shape (ns, p)
+        `ns` observations of size `p` (source)
+    Xt : ndarray, shape (nt, p)
+        `nt` observations of size `p` (destination)
+    prmts : dict
+        a dictionary containing the parameters of the Gaussian HD distributions (keys: ms, mt, sigma2_s, sigma2_t, ls, lt, Us, Ut, ds, dt, Cs, Ct)
+
+    .. _references-make_gauss_hd:
+    References
+    ----------
+
+    .. [1] Bouveyron, C. & Corneli, M. ("Scaling Optimal Transport to High-Dimensional Gaussian Distributions")
+
+    """
+    d = (dim, dim) if isinstance(dim, int) else dim
+    rng = check_random_state(random_state)
+    mu = np.zeros((2, p))
+    S = []
+    mu[1, 0] = m_diff  # only the target mean is shifted
+    Q = [ortho_group.rvs(p, random_state=rng) for _ in range(2)]
+
+    if sub_the_same:
+        Q[1] = Q[0]  # both distributions share the same principal subspace
+
+    S.append(
+        Q[0]
+        @ np.diag(np.hstack((np.full(d[0], a[0]), np.full(p - d[0], b[0]))))
+        @ Q[0].T
+    )
+    S.append(
+        Q[1]
+        @ np.diag(np.hstack((np.full(d[1], a[1]), np.full(p - d[1], b[1]))))
+        @ Q[1].T
+    )
+
+    Xs = multivariate_normal.rvs(mean=mu[0], cov=S[0], size=ns, random_state=rng)
+    Xt = multivariate_normal.rvs(mean=mu[1], cov=S[1], size=nt, random_state=rng)
+
+    ms = mu[0]
+    mt = mu[1]
+    ds = d[0]
+    dt = d[1]
+    sigma2_s = np.array(b[0])
+    sigma2_t = np.array(b[1])
+    ls = np.repeat(a[0], ds) - sigma2_s  # variance in excess of the residual one
+    lt = np.repeat(a[1], dt) - sigma2_t
+    Us = Q[0][:, :ds]
+    Ut = Q[1][:, :dt]
+    ds = np.array([ds])
+    dt = np.array([dt])
+
+    prmts = {
+        "ms": ms,
+        "mt": mt,
+        "sigma2_s": sigma2_s,
+        "sigma2_t": sigma2_t,
+        "ls": ls,
+        "lt": lt,
+        "Us": Us,
+        "Ut": Ut,
+        "ds": ds,
+        "dt": dt,
+        "Cs": S[0],
+        "Ct": S[1],
+    }
+
+    return Xs, Xt, prmts
diff --git a/ot/gaussian.py b/ot/gaussian.py
index 7c25cd660..771de8285 100644
--- a/ot/gaussian.py
+++ b/ot/gaussian.py
@@ -93,6 +93,119 @@ def bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False):
         return A, b
 
 
+def bures_wasserstein_mapping_hd(
+    ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt, log=False
+):
+    r"""Return OT linear operator between HD Gaussian distributions.
+
+    The function estimates the optimal linear operator that aligns the two
+    HD Gaussian distributions :math:`\mathcal{N}(\mu_s, U_s, l_s, \sigma_s^2, d_s)`
+    and :math:`\mathcal{N}(\mu_t, U_t, l_t, \sigma_t^2, d_t)` as proposed in
+    :ref:`[3] <references-OT-mapping-linear-hd>`,
+    Th. 2.9.
+
+    The linear operator from source to target :math:`M`
+
+    .. math::
+        M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
+
+    where :
+
+    .. math::
+        \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
+        \Sigma_s^{-1/2} \\
+
+        \Sigma_s^{1/2} &=\sigma_s I_p + U_s C_s U_s^T \\
+
+        C_s &=\text{diag}(\sqrt{l_{s1} + \sigma_s^2} - \sigma_s, \dots, \sqrt{l_{sd_s} + \sigma_s^2} - \sigma_s) \\
+
+        \Sigma_s^{-1/2} &= \frac{1}{\sigma_s} (I_p - U_s D_s U_s^T ) \\
+
+        D_s &= \text{diag}((\sqrt{l_{s1} + \sigma_s^2} - \sigma_s)/\sqrt{l_{s1} + \sigma_s^2}, \dots, (\sqrt{l_{sd_s} + \sigma_s^2} - \sigma_s)/\sqrt{l_{sd_s} + \sigma_s^2}) \\
+
+        \Sigma_t &= U_t \text{diag}(l_t) U_t^T + \sigma_t^2 I_p \\
+
+        \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
+
+    Parameters
+    ----------
+    ms : array-like (p,)
+        mean of the source distribution
+    mt : array-like (p,)
+        mean of the target distribution
+    Us : array-like (p,ds)
+        orthogonal matrix spanning the principal subspace of the source distribution
+    Ut : array-like (p,dt)
+        orthogonal matrix spanning the principal subspace of the target distribution
+    ls : array-like (ds,)
+        the variances associated with the principal sub-axes for the source distribution
+    lt : array-like (dt,)
+        the variances associated with the principal sub-axes for the target distribution
+    sigma2_s : array-like (1,)
+        the residual variance of the source distribution
+    sigma2_t : array-like (1,)
+        the residual variance of the target distribution
+    ds : array-like (1,)
+        the intrinsic dimension of the source distribution (only used for backend detection; the computation relies on the shape of Us)
+    dt : array-like (1,)
+        the intrinsic dimension of the target distribution (only used for backend detection; the computation relies on the shape of Ut)
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    A : (p, p) array-like
+        Linear operator
+    b : (1, p) array-like
+        bias
+    log : dict
+        log dictionary return only if log==True in parameters
+
+
+    .. _references-OT-mapping-linear-hd:
+    References
+    ----------
+    .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
+        distributions", Journal of Optimization Theory and Applications
+        Vol 43, 1984
+
+    .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+        Transport", 2018.
+
+    .. [3] Bouveyron, C. & Corneli, M. ("Scaling Optimal Transport to High-Dimensional Gaussian Distributions")
+    """
+
+    ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt = list_to_array(
+        ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt
+    )
+    nx = get_backend(ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt)
+
+    p = Us.shape[0]
+
+    # source
+    Cs = nx.diag(nx.sqrt(ls + sigma2_s) - nx.sqrt(sigma2_s))
+    Ss_sq = dots(Us, Cs, Us.T) + nx.sqrt(sigma2_s) * nx.eye(p, type_as=Us)
+    Ds = nx.diag((nx.sqrt(ls + sigma2_s) - nx.sqrt(sigma2_s)) / nx.sqrt(ls + sigma2_s))
+    Ss_sqinv = (1 / nx.sqrt(sigma2_s)) * (nx.eye(p, type_as=Us) - dots(Us, Ds, Us.T))
+
+    # destination
+    St = dots(Ut, nx.diag(lt), Ut.T) + sigma2_t * nx.eye(p, type_as=Ut)
+
+    M0 = nx.sqrtm(dots(Ss_sq, St, Ss_sq))
+
+    A = dots(Ss_sqinv, M0, Ss_sqinv)
+    b = mt - nx.dot(ms, A)  # A is symmetric, hence ms @ A == A @ ms
+
+    if log:
+        log = {}
+        log["Ss_sq"] = Ss_sq
+        log["Ss_sqinv"] = Ss_sqinv
+        return A, b, log
+    else:
+        return A, b
+
+
 def empirical_bures_wasserstein_mapping(
     xs, xt, reg=1e-6, ws=None, wt=None, bias=True, log=False
 ):
@@ -128,7 +241,7 @@ def empirical_bures_wasserstein_mapping(
         regularization added to the diagonals of covariances (>0)
     ws : array-like (ns,1), optional
         weights for the source samples
-    wt : array-like (ns,1), optional
+    wt : array-like (nt,1), optional
         weights for the target samples
     bias: boolean, optional
         estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
@@ -200,6 +313,153 @@ def empirical_bures_wasserstein_mapping(
         return A, b
 
 
+def empirical_bures_wasserstein_mapping_hd(
+    xs, xt, ds, dt, reg=0.0, ws=None, wt=None, bias=True, log=False
+):
+    r"""Return OT HD linear operator between samples.
+
+    The function estimates the optimal linear HD operator that aligns the two
+    empirical distributions. This is equivalent to estimating the closed
+    form mapping between two HD Gaussian distributions :math:`\mathcal{N}(\mu_s, U_s, l_s, \sigma_s^2, d_s)`
+    and :math:`\mathcal{N}(\mu_t, U_t, l_t, \sigma_t^2, d_t)` as proposed in
+    :ref:`[3] <references-empirical-OT-mapping-linear-hd>`, Th. 2.9.
+
+    The linear operator from source to target :math:`M`
+
+    .. math::
+        M(\mathbf{x})= \mathbf{A} \mathbf{x} + \mathbf{b}
+
+    where :
+
+    .. math::
+        \mathbf{A} &= \Sigma_s^{-1/2} \left(\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2} \right)^{1/2}
+        \Sigma_s^{-1/2} \\
+
+        \Sigma_s^{1/2} &=\sigma_s I_p + U_s C_s U_s^T \\
+
+        C_s &=\text{diag}(\sqrt{l_{s1} + \sigma_s^2} - \sigma_s, \dots, \sqrt{l_{sd_s} + \sigma_s^2} - \sigma_s) \\
+
+        \Sigma_s^{-1/2} &= \frac{1}{\sigma_s} (I_p - U_s D_s U_s^T ) \\
+
+        D_s &= \text{diag}((\sqrt{l_{s1} + \sigma_s^2} - \sigma_s)/\sqrt{l_{s1} + \sigma_s^2}, \dots, (\sqrt{l_{sd_s} + \sigma_s^2} - \sigma_s)/\sqrt{l_{sd_s} + \sigma_s^2}) \\
+
+        \Sigma_t &= U_t \text{diag}(l_t) U_t^T + \sigma_t^2 I_p \\
+
+        \mathbf{b} &= \mu_t - \mathbf{A} \mu_s
+
+
+    Parameters
+    ----------
+    xs : array-like (ns,p)
+        samples in the source domain
+    xt : array-like (nt,p)
+        samples in the target domain
+    ds : array-like (1,)
+        the intrinsic dimension of the source distribution
+    dt : array-like (1,)
+        the intrinsic dimension of the target distribution
+    reg : float, optional
+        regularization added to the diagonals of covariances (null by default)
+    ws : array-like (ns,1), optional
+        weights for the source samples
+    wt : array-like (nt,1), optional
+        weights for the target samples
+    bias: boolean, optional
+        estimate bias :math:`\mathbf{b}` else :math:`\mathbf{b} = 0` (default:True)
+    log : bool, optional
+        record log if True
+
+
+    Returns
+    -------
+    A : (p, p) array-like
+        Linear operator
+    b : (1, p) array-like
+        bias
+    log : dict
+        log dictionary return only if log==True in parameters
+
+
+    .. _references-empirical-OT-mapping-linear-hd:
+    References
+    ----------
+    .. [1] Knott, M. and Smith, C. S. "On the optimal mapping of
+        distributions", Journal of Optimization Theory and Applications
+        Vol 43, 1984
+
+    .. [2] Peyré, G., & Cuturi, M. (2017). "Computational Optimal
+        Transport", 2018.
+
+    .. [3] Bouveyron, C. & Corneli, M. ("Scaling Optimal Transport to High-Dimensional Gaussian Distributions")
+
+    """
+
+    xs, xt, ds, dt = list_to_array(xs, xt, ds, dt)
+    nx = get_backend(xs, xt, ds, dt)
+    is_input_finite = is_all_finite(xs, xt, ds, dt)
+
+    p = xs.shape[1]
+
+    if ws is None:
+        ws = nx.ones((xs.shape[0], 1), type_as=xs) / xs.shape[0]
+
+    if wt is None:
+        wt = nx.ones((xt.shape[0], 1), type_as=xt) / xt.shape[0]
+
+    if bias:
+        mxs = nx.dot(ws.T, xs) / nx.sum(ws)
+        mxt = nx.dot(wt.T, xt) / nx.sum(wt)
+
+        xs = xs - mxs
+        xt = xt - mxt
+    else:
+        mxs = nx.zeros((1, p), type_as=xs)
+        mxt = nx.zeros((1, p), type_as=xt)
+
+    Cs = nx.dot((xs * ws).T, xs) / nx.sum(ws) + reg * nx.eye(p, type_as=xs)
+    Ct = nx.dot((xt * wt).T, xt) / nx.sum(wt) + reg * nx.eye(p, type_as=xt)
+
+    eigs = nx.eigh(Cs)
+    a_s = eigs[0][-ds[0] :]  # the ds leading eigenvalues (eigh sorts ascending)
+    sgm2_s = (nx.trace(Cs) - nx.sum(a_s)) / (p - ds)  # mean residual eigenvalue
+    Qs = eigs[1]
+    Us = Qs[:, -ds[0] :]
+    ls = a_s - sgm2_s
+
+    eigt = nx.eigh(Ct)
+    a_t = eigt[0][-dt[0] :]
+    sgm2_t = (nx.trace(Ct) - nx.sum(a_t)) / (p - dt)
+    Qt = eigt[1]
+    Ut = Qt[:, -dt[0] :]
+    lt = a_t - sgm2_t
+
+    if log:
+        A, b, log = bures_wasserstein_mapping_hd(
+            mxs, mxt, Us, Ut, ls, lt, sgm2_s, sgm2_t, ds, dt, log=log
+        )
+    else:
+        A, b = bures_wasserstein_mapping_hd(
+            mxs, mxt, Us, Ut, ls, lt, sgm2_s, sgm2_t, ds, dt
+        )
+
+    if is_input_finite and not is_all_finite(A, b):
+        warnings.warn(
+            "Numerical errors were encountered in ot.gaussian.empirical_bures_wasserstein_mapping_hd. "
+            "Consider increasing the regularization parameter `reg` or reducing the intrinsic dimensions ds/dt."
+        )
+
+    if log:
+        log["Us"] = Us
+        log["Ut"] = Ut
+        log["ls"] = ls
+        log["lt"] = lt
+        log["sigma2_s"] = sgm2_s
+        log["sigma2_t"] = sgm2_t
+        return A, b, log
+    else:
+        return A, b
+
+
 def bures_distance(Cs, Ct, paired=False, log=False, nx=None):
     r"""Return Bures distance.
 
diff --git a/test/test_gaussian.py b/test/test_gaussian.py
index 733fcfab9..fdb174caf 100644
--- a/test/test_gaussian.py
+++ b/test/test_gaussian.py
@@ -10,7 +10,7 @@
 import pytest
 
 import ot
-from ot.datasets import make_data_classif
+from ot.datasets import make_data_classif, make_gauss_hd
 from ot.utils import is_all_finite
 
 
@@ -42,6 +42,34 @@ def test_bures_wasserstein_mapping(nx):
     np.testing.assert_allclose(Ct, Cst, rtol=1e-2, atol=1e-2)
 
 
+def test_bures_wasserstein_mapping_hd(nx):
+    ns = 100
+    nt = 100
+
+    Xs, Xt, ll = make_gauss_hd(ns, nt, p=50, dim=10, m_diff=5, a=(7, 7), b=(1, 1))
+
+    ms = ll["ms"]
+    mt = ll["mt"]
+    sigma2_s = ll["sigma2_s"]
+    sigma2_t = ll["sigma2_t"]
+    ls = ll["ls"]
+    lt = ll["lt"]
+    Us = ll["Us"]
+    Ut = ll["Ut"]
+    ds = ll["ds"]
+    dt = ll["dt"]
+    Cs = ll["Cs"]
+    Ct = ll["Ct"]
+
+    A_hd, b_hd = ot.gaussian.bures_wasserstein_mapping_hd(
+        ms, mt, Us, Ut, ls, lt, sigma2_s, sigma2_t, ds, dt, log=False
+    )
+    A, b = ot.gaussian.bures_wasserstein_mapping(ms, mt, Cs, Ct, log=False)
+
+    np.testing.assert_allclose(A_hd, A, rtol=1e-2, atol=1e-2)
+    np.testing.assert_allclose(b_hd, b, rtol=1e-2, atol=1e-2)
+
+
 @pytest.mark.parametrize("bias", [True, False])
 def test_empirical_bures_wasserstein_mapping(nx, bias):
     ns = 50