fugw.solvers.FUGWSparseSolver

class fugw.solvers.FUGWSparseSolver(nits_bcd=10, nits_uot=1000, tol_bcd=None, tol_uot=None, tol_loss=None, eval_bcd=1, eval_uot=10, ibpp_eps_base=1, ibpp_nits_sinkhorn=1)

Solver computing sparse solutions

Methods

fugw_loss(pi, gamma, data_const, ...)

Compute FUGW loss and each of its components.

get_parameters_uot_l2(pi, tuple_weights, ...)

Compute parameters of the L2 loss.

local_biconvex_cost(pi, transpose, ...)

Before each block coordinate descent (BCD) step, the local cost matrix is updated.

solve([alpha, rho_s, rho_t, eps, reg_mode, ...])

Run BCD iterations.

__init__(nits_bcd=10, nits_uot=1000, tol_bcd=None, tol_uot=None, tol_loss=None, eval_bcd=1, eval_uot=10, ibpp_eps_base=1, ibpp_nits_sinkhorn=1)

Init FUGW solver.

Parameters:
nits_bcd: int or None,

Number of block-coordinate-descent iterations to run. If None, run until tol_bcd or tol_loss is reached. Default: 10

nits_uot: int or None,

Number of solver iterations to run at each BCD iteration. If None, run until tol_uot is reached. Default: 1000

tol_bcd: float or None,

Stop the BCD procedure early if the absolute difference between two consecutive transport plans is under this threshold. If None, do not stop early. Default: None

tol_uot: float or None,

Stop the UOT solver early (within a BCD iteration) if the absolute difference between two consecutive transport plans is under this threshold. If None, do not stop early. Default: None

tol_loss: float or None,

Stop the BCD procedure early if the FUGW loss falls under this threshold. If None, do not stop early. Default: None

eval_bcd: int,

During .fit(), at every eval_bcd step:

1. compute the FUGW loss and store it in an array
2. consider stopping early if tol_loss is not None
3. consider stopping early if tol_bcd is not None

Default: 1

eval_uot: int,

During .fit(), at every eval_uot step, consider stopping early if tol_uot is not None. Default: 10

ibpp_eps_base: int,

Regularization parameter specific to the ibpp solver. Default: 1

ibpp_nits_sinkhorn: int,

Number of Sinkhorn iterations to run within each UOT iteration of the ibpp solver. Default: 1

Attributes:
Same as parameters.
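
The interplay between nits_bcd, tol_bcd, tol_loss and eval_bcd can be easier to follow in code. The following is a minimal sketch of the stopping rules described above, not the solver's actual implementation: run_bcd, step and loss_fn are hypothetical names, plans are plain lists, and using the largest entry-wise change as "the absolute difference between two consecutive transport plans" is an assumption.

```python
def run_bcd(step, loss_fn, plan, nits_bcd=10, tol_bcd=None, tol_loss=None, eval_bcd=1):
    """Toy BCD driver illustrating the documented stopping rules."""
    if nits_bcd is None and tol_bcd is None and tol_loss is None:
        raise ValueError("no stopping criterion: the loop would never end")

    history = {"loss": [], "loss_steps": []}
    idx = 0
    # nits_bcd=None means "iterate until a tolerance is reached".
    while nits_bcd is None or idx < nits_bcd:
        new_plan = step(plan)
        idx += 1
        # Assumption: measure the change as the largest entry-wise difference.
        delta = max(abs(a - b) for a, b in zip(new_plan, plan))
        plan = new_plan

        # Only every eval_bcd-th step evaluates the loss and checks tolerances.
        if idx % eval_bcd == 0:
            loss = loss_fn(plan)
            history["loss"].append(loss)
            history["loss_steps"].append(idx)
            if tol_loss is not None and loss < tol_loss:
                break
            if tol_bcd is not None and delta < tol_bcd:
                break
    return plan, history


# Toy problem: every step halves the plan, the "loss" is the sum of its entries.
plan, history = run_bcd(
    lambda p: [x / 2 for x in p], sum, [1.0, 1.0], nits_bcd=None, tol_loss=0.1
)
```

In this toy run the loop has no iteration cap and stops only once the loss drops below tol_loss. With a larger eval_bcd, the same loop would check the tolerances less often and could therefore overshoot them by a few steps.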
fugw_loss(pi, gamma, data_const, tuple_weights, hyperparams)

Compute FUGW loss and each of its components.

Computes the scalar FUGW loss, which is a combination of:

  • a Wasserstein loss on features
  • a Gromov-Wasserstein loss on geometries
  • marginal constraints on the computed OT plan
  • a regularization term (KL, i.e. entropic, or L2)

Parameters:
pi: torch.Tensor
gamma: torch.Tensor
data_const: tuple
tuple_weights: tuple
hyperparams: tuple
Returns:
l: dict

Dictionary containing the loss and its unweighted components. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None.
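
Because the returned components are unweighted, it can help to see how they might be recombined into a single number. The sketch below assumes the usual FUGW weighting (1 - alpha and alpha on the two transport terms, rho_s and rho_t on the marginal terms, eps on the regularization); this weighting is an assumption and is not stated on this page. weighted_total is a hypothetical helper, and treating None components as zero is a choice made for the example.

```python
def weighted_total(components, alpha, rho_s, rho_t, eps):
    """Recombine unweighted loss components (assumed weighting, see lead-in)."""

    def val(key):
        # Components may be None; count them as zero in this sketch.
        v = components.get(key)
        return 0.0 if v is None else v

    return (
        (1 - alpha) * val("wasserstein")
        + alpha * val("gromov_wasserstein")
        + rho_s * val("marginal_constraint_dim1")
        + rho_t * val("marginal_constraint_dim2")
        + eps * val("regularization")
    )


# Illustrative values only, not the output of a real solver run.
components = {
    "wasserstein": 2.0,
    "gromov_wasserstein": 4.0,
    "marginal_constraint_dim1": 1.0,
    "marginal_constraint_dim2": 1.0,
    "regularization": None,
}
total = weighted_total(components, alpha=0.5, rho_s=1, rho_t=1, eps=0.01)
```

If the assumed weighting matches the library, the result should agree with the "total" entry of the returned dictionary; comparing the two is a quick way to check the assumption.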

get_parameters_uot_l2(pi, tuple_weights, hyperparams)

Compute parameters of the L2 loss.

local_biconvex_cost(pi, transpose, data_const, tuple_weights, hyperparams)

Before each block coordinate descent (BCD) step, the local cost matrix is updated. This local cost is a matrix of size (n, m) which evaluates the cost between every pair of points of the source and target distributions. Then, we run a BCD (sinkhorn, ibpp or mm) step which makes use of this cost to update the transport plans.
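
A cost between every pair of points does not have to be stored as a dense (n, m) array. The solve method documents F, Ds and Dt as pairs of arrays with d+2 or k+2 columns, which is consistent with an exact low-rank factorization of squared Euclidean distances. The sketch below illustrates that reading; it is an assumption about what those pairs contain, and low_rank_squared_l2 and cost_entry are hypothetical helpers written with plain lists.

```python
def low_rank_squared_l2(X, Y):
    """Return A (n rows, d+2 cols) and B (m rows, d+2 cols) such that
    dot(A[i], B[j]) == squared Euclidean distance between X[i] and Y[j]."""
    # ||x - y||^2 = ||x||^2 * 1 + 1 * ||y||^2 + sum_k (-2 x_k) * y_k
    A = [[sum(v * v for v in x), 1.0] + [-2.0 * v for v in x] for x in X]
    B = [[1.0, sum(v * v for v in y)] + list(y) for y in Y]
    return A, B


def cost_entry(A, B, i, j):
    """Recover a single entry of the (n, m) cost matrix on demand."""
    return sum(a * b for a, b in zip(A[i], B[j]))


X = [[0.0, 0.0], [1.0, 2.0]]  # n = 2 source points, d = 2
Y = [[3.0, 4.0]]              # m = 1 target point
A, B = low_rank_squared_l2(X, Y)
c00 = cost_entry(A, B, 0, 0)  # squared distance between (0, 0) and (3, 4)
c10 = cost_entry(A, B, 1, 0)  # squared distance between (1, 2) and (3, 4)
```

Storing the two factors takes (n + m) * (d + 2) numbers instead of n * m, which is the kind of saving a sparse solver would need for large inputs.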

solve(alpha=0.5, rho_s=1, rho_t=1, eps=0.01, reg_mode='joint', divergence='kl', F=(None, None), Ds=(None, None), Dt=(None, None), F_val=(None, None), Ds_val=(None, None), Dt_val=(None, None), ws=None, wt=None, init_plan=None, init_duals=None, solver='ibpp', callback_bcd=None, verbose=False)

Run BCD iterations.

Parameters:
alpha: float, optional
rho_s: float, optional
rho_t: float, optional
eps: float, optional
reg_mode: string, optional
divergence: string, optional
F: (ndarray(n, d+2), ndarray(m, d+2)) or (None, None)
Ds: (ndarray(n, k+2), ndarray(n, k+2)), or (None, None)
Dt: (ndarray(m, k+2), ndarray(m, k+2)), or (None, None)
F_val: (ndarray(n, d+2), ndarray(m, d+2)) or (None, None)
Ds_val: (ndarray(n, k+2), ndarray(n, k+2)), or (None, None)
Dt_val: (ndarray(m, k+2), ndarray(m, k+2)), or (None, None)
ws: ndarray(n), None

Measures assigned to source points.

wt: ndarray(m), None

Measures assigned to target points.

init_plan: torch.tensor sparse, None

Initialisation matrix for sample coupling.

init_duals: torch.tensor sparse, None

Initialisation of the dual variables of the sample coupling.

solver: “sinkhorn”, “mm”, “ibpp”

Solver to use.

callback_bcd: callable or None

Callback function called at the end of each BCD step. It will be called with the following arguments:

  • locals (dictionary containing all local variables)

verbose: bool, optional, defaults to False

Log solving process.
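
As an illustration of callback_bcd, here is a minimal sketch of a callback that receives the dictionary of local variables. The names available in that dictionary are not listed on this page, so the example only records which keys it was given; make_recorder is a hypothetical helper, and the dictionary passed at the end is a hand-written stand-in for the solver's locals.

```python
def make_recorder():
    """Build a callback plus the list it appends to."""
    records = []

    def callback_bcd(local_vars):
        # local_vars is the dictionary of local variables at the end of a
        # BCD step. Its exact contents are not documented here, so this
        # sketch stores only the sorted key names.
        records.append(sorted(local_vars.keys()))

    return records, callback_bcd


records, callback_bcd = make_recorder()

# Stand-in call: in practice the solver would invoke the callback itself.
callback_bcd({"pi": None, "gamma": None})
```

Such a callback would then be passed through the callback_bcd argument of solve. Inspecting the recorded keys after a short run is one way to find out which variables are available before writing a more specific callback.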

Returns:
res: dict
Dictionary containing the following keys:
pi: sparse torch.Tensor of size n x m

Sample matrix.

gamma: sparse torch.Tensor of size d1 x d2

Feature matrix.

duals_pi: tuple of torch.Tensor of size

Duals of pi

duals_gamma: tuple of torch.Tensor of size

Duals of gamma

loss: dict of lists

Dictionary containing the loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None.

loss_steps: list

BCD steps at the end of which the FUGW loss was evaluated

loss_times: list

Elapsed time at the end of each BCD step for which the FUGW loss was evaluated.
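
The three loss entries are parallel lists, so they can be zipped together to follow convergence. The sketch below works on a hand-written stand-in for the returned dictionary (the numbers are made up for illustration); loss_trace is a hypothetical helper and only relies on the keys documented above.

```python
def loss_trace(res, key="total"):
    """Return (step, elapsed_time, loss) triples for one loss component."""
    return list(zip(res["loss_steps"], res["loss_times"], res["loss"][key]))


# Stand-in for a dictionary returned by solve(); values are illustrative only.
res = {
    "loss": {"total": [3.0, 2.5, 2.4], "wasserstein": [1.0, 0.9, 0.9]},
    "loss_steps": [0, 5, 10],
    "loss_times": [0.1, 0.6, 1.1],
}

trace = loss_trace(res)
for step, elapsed, loss in trace:
    print(f"step {step:>3}  t={elapsed:.1f}s  total={loss}")
```

Passing another key, such as "wasserstein", gives the trace of a single unweighted component. Components whose values are None would need to be filtered out before plotting.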