fugw.solvers.FUGWSparseSolver
- class fugw.solvers.FUGWSparseSolver(nits_bcd=10, nits_uot=1000, tol_bcd=None, tol_uot=None, tol_loss=None, eval_bcd=1, eval_uot=10, ibpp_eps_base=1, ibpp_nits_sinkhorn=1)
Solver computing sparse solutions (sparse transport plans) to the FUGW problem.
Methods
fugw_loss(pi, gamma, data_const, ...): Compute FUGW loss and each of its components.
get_parameters_uot_l2(pi, tuple_weights, ...): Compute parameters of the L2 loss.
local_biconvex_cost(pi, transpose, ...): Before each block coordinate descent (BCD) step, the local cost matrix is updated.
solve([alpha, rho_s, rho_t, eps, reg_mode, ...]): Run BCD iterations.
- __init__(nits_bcd=10, nits_uot=1000, tol_bcd=None, tol_uot=None, tol_loss=None, eval_bcd=1, eval_uot=10, ibpp_eps_base=1, ibpp_nits_sinkhorn=1)
Initialize the FUGW solver.
- Parameters:
- nits_bcd: int or None,
Number of block-coordinate-descent iterations to run. If None, run until tol_bcd or tol_loss is reached. Default: 10
- nits_uot: int or None,
Number of UOT solver iterations to run at each BCD iteration. If None, run until tol_uot is reached. Default: 1000
- tol_bcd: float or None,
Stop the BCD procedure early if the absolute difference between two consecutive transport plans falls under this threshold. If None, do not stop early. Default: None
- tol_uot: float or None,
Stop the inner UOT solver (run at each BCD iteration) early if the absolute difference between two consecutive transport plans falls under this threshold. If None, do not stop early. Default: None
- tol_loss: float or None,
Stop the BCD procedure early if the FUGW loss falls under this threshold. If None, do not stop early. Default: None
- eval_bcd: int,
During .solve(), every eval_bcd BCD steps:
1. compute the FUGW loss and store it in an array;
2. consider stopping early if tol_loss is not None;
3. consider stopping early if tol_bcd is not None.
Default: 1
- eval_uot: int,
During each UOT solve, every eval_uot UOT iterations, consider stopping early if tol_uot is not None. Default: 10
- ibpp_eps_base: int,
Regularization parameter specific to the ibpp solver. Default: 1
- ibpp_nits_sinkhorn: int,
Number of Sinkhorn iterations to run within each UOT iteration of the ibpp solver. Default: 1
- Attributes:
- Same as parameters.
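The interplay between nits_bcd, tol_bcd, tol_loss and eval_bcd can be sketched as a plain loop. This is a minimal illustration, not fugw's code: the toy step and loss callables, the name run_bcd_sketch, and the use of the maximum element-wise difference as "the absolute difference between two consecutive transport plans" are all assumptions.

```python
from typing import Callable, List, Optional, Tuple

Plan = List[float]  # a flattened transport plan, for illustration only


def run_bcd_sketch(
    step: Callable[[Plan], Plan],
    loss_fn: Callable[[Plan], float],
    init_plan: Plan,
    nits_bcd: Optional[int] = 10,
    tol_bcd: Optional[float] = None,
    tol_loss: Optional[float] = None,
    eval_bcd: int = 1,
) -> Tuple[Plan, List[float], List[int]]:
    """Hypothetical BCD driver mirroring the documented stopping rules."""
    if nits_bcd is None and tol_bcd is None and tol_loss is None:
        # Sketch-specific guard: without any stopping rule the loop never ends.
        raise ValueError("set nits_bcd, tol_bcd or tol_loss")
    plan = list(init_plan)
    losses: List[float] = []
    loss_steps: List[int] = []
    idx = 0
    while nits_bcd is None or idx < nits_bcd:
        new_plan = step(plan)
        # Assumption: distance between plans = max absolute element-wise difference.
        delta = max(abs(a - b) for a, b in zip(new_plan, plan))
        plan = new_plan
        idx += 1
        if idx % eval_bcd == 0:
            # 1. compute the loss and store it
            loss = loss_fn(plan)
            losses.append(loss)
            loss_steps.append(idx)
            # 2. stop early if the loss fell under tol_loss
            if tol_loss is not None and loss < tol_loss:
                break
            # 3. stop early if consecutive plans are closer than tol_bcd
            if tol_bcd is not None and delta < tol_bcd:
                break
    return plan, losses, loss_steps


def halve(p: Plan) -> Plan:
    """Toy BCD step: every entry of the plan shrinks by half."""
    return [x / 2 for x in p]


plan, losses, steps = run_bcd_sketch(halve, lambda p: p[0], [1.0], tol_loss=0.1)
print(steps)  # [1, 2, 3, 4]
```

Here the loss drops below tol_loss at the fourth step, so the loop stops well before nits_bcd=10; with eval_bcd=2 the loss would only be recorded (and the early-stopping rules only checked) at even steps.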
- fugw_loss(pi, gamma, data_const, tuple_weights, hyperparams)
Compute FUGW loss and each of its components.
Computes the scalar FUGW loss, which combines:
- a Wasserstein loss on features;
- a Gromov-Wasserstein loss on geometries;
- marginal constraints on the computed OT plan;
- a regularization term (KL, i.e. entropic, or L2).
- Parameters:
- pi: torch.Tensor
- gamma: torch.Tensor
- data_const: tuple
- tuple_weights: tuple
- hyperparams: tuple
- Returns:
- l: dict
Dictionary containing the loss and its unweighted components. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None.
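To make the listed components concrete, here is a dense, pure-Python sketch on tiny inputs. It follows our reading of the FUGW lower-bound formulation, in which two couplings pi and gamma of size n x m are optimized alternately. The exact definitions (the factor 1/2 on the Wasserstein term, the generalized KL on tensor products, and the weights (1 - alpha), alpha, rho_s, rho_t and eps in the total) are assumptions rather than fugw's implementation, and all helper names are hypothetical. Only the KL (entropic) regularization is shown, and reference weights must be strictly positive.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]


def _outer(a: List[float], b: List[float]) -> List[float]:
    """Flattened outer product of two vectors."""
    return [x * y for x in a for y in b]


def _flat(mat: Matrix) -> List[float]:
    return [x for row in mat for x in row]


def _kl(p: List[float], q: List[float]) -> float:
    """Generalized KL(p | q) = sum p log(p/q) - sum p + sum q, with 0 log 0 = 0."""
    s = sum(x * math.log(x / y) for x, y in zip(p, q) if x > 0)
    return s - sum(p) + sum(q)


def fugw_loss_sketch(
    pi: Matrix, gamma: Matrix, C: Matrix, Ds: Matrix, Dt: Matrix,
    ws: List[float], wt: List[float],
    alpha: float, rho_s: float, rho_t: float, eps: float,
) -> Dict[str, float]:
    """Hypothetical dense FUGW loss; C is the n x m feature cost matrix."""
    n, m = len(ws), len(wt)
    # Feature (Wasserstein) term, averaged over the two couplings (assumed 1/2).
    wass = sum(C[i][j] * (pi[i][j] + gamma[i][j])
               for i in range(n) for j in range(m)) / 2
    # Geometry (Gromov-Wasserstein) term: squared mismatch of pairwise distances.
    gw = sum(
        (Ds[i][k] - Dt[j][l]) ** 2 * pi[i][j] * gamma[k][l]
        for i in range(n) for j in range(m)
        for k in range(n) for l in range(m)
    )
    # Source and target marginals of each coupling.
    pi_s = [sum(row) for row in pi]
    pi_t = [sum(col) for col in zip(*pi)]
    gamma_s = [sum(row) for row in gamma]
    gamma_t = [sum(col) for col in zip(*gamma)]
    mc1 = _kl(_outer(pi_s, gamma_s), _outer(ws, ws))
    mc2 = _kl(_outer(pi_t, gamma_t), _outer(wt, wt))
    # Entropic regularization of pi (x) gamma against (ws (x) wt) (x) (ws (x) wt).
    w_st = _outer(ws, wt)
    reg = _kl(_outer(_flat(pi), _flat(gamma)), _outer(w_st, w_st))
    # Assumed weighting of the unweighted components.
    total = (1 - alpha) * wass + alpha * gw + rho_s * mc1 + rho_t * mc2 + eps * reg
    return {
        "wasserstein": wass,
        "gromov_wasserstein": gw,
        "marginal_constraint_dim1": mc1,
        "marginal_constraint_dim2": mc2,
        "regularization": reg,
        "total": total,
    }


ws = wt = [0.5, 0.5]
pi = [[0.25, 0.25], [0.25, 0.25]]  # product coupling ws (x) wt
D = [[0.0, 1.0], [1.0, 0.0]]
loss = fugw_loss_sketch(pi, pi, C=D, Ds=D, Dt=D, ws=ws, wt=wt,
                        alpha=0.5, rho_s=1.0, rho_t=1.0, eps=0.01)
print(loss["total"])  # 0.5
```

Because pi = gamma = ws (x) wt has exactly the prescribed marginals, both marginal terms and the regularization vanish here, and the total reduces to the weighted Wasserstein and Gromov-Wasserstein terms.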
- get_parameters_uot_l2(pi, tuple_weights, hyperparams)
Compute parameters of the L2 loss.
- local_biconvex_cost(pi, transpose, data_const, tuple_weights, hyperparams)
Before each block coordinate descent (BCD) step, the local cost matrix is updated. This local cost is a matrix of size (n, m) that evaluates the cost between every pair of points of the source and target distributions. During the BCD step, a UOT solver (sinkhorn, ibpp or mm) then uses this cost to update the transport plans.
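The (i, j)-dependent part of such a local cost, for updating pi with gamma held fixed, can be sketched as follows. The sketch assumes the same lower-bound formulation as above (linear in pi once gamma is fixed) and uses the classic decomposition sum_kl (Ds[i][k] - Dt[j][l])^2 gamma[k][l] = sum_k Ds[i][k]^2 gamma_s[k] + sum_l Dt[j][l]^2 gamma_t[l] - 2 (Ds gamma Dt^T)[i][j], which avoids four nested loops. Scalar terms coming from the KL divergences are omitted, and local_cost_sketch is a hypothetical name, not fugw's implementation.

```python
from typing import List

Matrix = List[List[float]]


def _matmul(A: Matrix, B: Matrix) -> Matrix:
    """Plain-list matrix product A @ B."""
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def local_cost_sketch(gamma: Matrix, C: Matrix, Ds: Matrix, Dt: Matrix,
                      alpha: float) -> Matrix:
    """(i, j)-dependent part of the local cost for updating pi, gamma fixed.

    Assumption: cost = (1 - alpha)/2 * C + alpha * GW linearization;
    constant KL-derived terms are left out of this sketch.
    """
    n, m = len(Ds), len(Dt)
    gamma_s = [sum(row) for row in gamma]        # source marginal, size n
    gamma_t = [sum(col) for col in zip(*gamma)]  # target marginal, size m
    a = [sum(Ds[i][k] ** 2 * gamma_s[k] for k in range(n)) for i in range(n)]
    b = [sum(Dt[j][l] ** 2 * gamma_t[l] for l in range(m)) for j in range(m)]
    Dt_T = [list(col) for col in zip(*Dt)]
    cross = _matmul(_matmul(Ds, gamma), Dt_T)    # (Ds @ gamma @ Dt.T)[i][j]
    return [
        [(1 - alpha) / 2 * C[i][j] + alpha * (a[i] + b[j] - 2 * cross[i][j])
         for j in range(m)]
        for i in range(n)
    ]


C = [[1.0, 2.0, 0.5], [0.0, 3.0, 1.0]]
Ds = [[0.0, 1.0], [1.0, 0.0]]                             # n = 2
Dt = [[0.0, 2.0, 1.0], [2.0, 0.0, 3.0], [1.0, 3.0, 0.0]]  # m = 3
gamma = [[0.1, 0.2, 0.05], [0.3, 0.0, 0.15]]
cost = local_cost_sketch(gamma, C, Ds, Dt, alpha=0.3)     # 2 x 3 matrix
```

With symmetric Ds and Dt, passing pi in place of gamma gives the analogous cost for the other half of the BCD step. Comparing the result against the direct four-loop sum on a small example is an easy way to validate the decomposition.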
- solve(alpha=0.5, rho_s=1, rho_t=1, eps=0.01, reg_mode='joint', divergence='kl', F=(None, None), Ds=(None, None), Dt=(None, None), F_val=(None, None), Ds_val=(None, None), Dt_val=(None, None), ws=None, wt=None, init_plan=None, init_duals=None, solver='ibpp', callback_bcd=None, verbose=False)
Run BCD iterations.
- Parameters:
- alpha: float, optional
- rho_s: float, optional
- rho_t: float, optional
- eps: float, optional
- reg_mode: string, optional
- divergence: string, optional
- F: (ndarray(n, d+2), ndarray(m, d+2)) or (None, None)
- Ds: (ndarray(n, k+2), ndarray(n, k+2)) or (None, None)
- Dt: (ndarray(m, k+2), ndarray(m, k+2)) or (None, None)
- F_val: (ndarray(n, d+2), ndarray(m, d+2)) or (None, None)
- Ds_val: (ndarray(n, k+2), ndarray(n, k+2)) or (None, None)
- Dt_val: (ndarray(m, k+2), ndarray(m, k+2)) or (None, None)
- ws: ndarray(n) or None
Measure (weights) assigned to the source points.
- wt: ndarray(m) or None
Measure (weights) assigned to the target points.
- init_plan: sparse torch.Tensor or None
Initialisation matrix for the sample coupling.
- init_duals: sparse torch.Tensor or None
Initialisation of the dual variables.
- solver: “sinkhorn”, “mm”, “ibpp”
UOT solver to use at each BCD step.
- callback_bcd: callable or None
Callback function called at the end of each BCD step. It is called with a single argument, locals, a dictionary containing all local variables.
- verbose: bool, optional, defaults to False
Log solving process.
- Returns:
- res: dict
- Dictionary containing the following keys:
- pi: sparse torch.Tensor of size n x m
Sample coupling (transport plan between source and target points).
- gamma: sparse torch.Tensor of size n x m
Second sample coupling, optimized alternately with pi during BCD.
- duals_pi: tuple of torch.Tensor
Dual variables associated with pi.
- duals_gamma: tuple of torch.Tensor
Dual variables associated with gamma.
- loss: dict of lists
Dictionary containing the loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None.
- loss_steps: list
BCD steps at the end of which the FUGW loss was evaluated.
- loss_times: list
Elapsed time at the end of each BCD step for which the FUGW loss was evaluated.
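The shapes (n, d+2) and (m, d+2) listed for F suggest that each input is passed as a pair of factors whose product gives a squared Euclidean cost, a standard trick for large or sparse problems because the full n x m cost never has to be stored. The sketch below shows one such factorization: for X of shape (n, d) and Y of shape (m, d), augmenting rows as [||x||^2, 1, -2x] and [1, ||y||^2, y] makes X_aug @ Y_aug.T equal the matrix of squared distances. Reading F, Ds and Dt this way is our interpretation of the shapes, not something this page states, and augment_squared_l2 is a hypothetical name. The same construction applied to k-dimensional embeddings of the source (or target) geometry would yield the (n, k+2) pair of Ds (or the (m, k+2) pair of Dt).

```python
from typing import List, Tuple

Matrix = List[List[float]]


def augment_squared_l2(X: Matrix, Y: Matrix) -> Tuple[Matrix, Matrix]:
    """Hypothetical helper: factors such that X_aug @ Y_aug.T = ||x_i - y_j||^2."""
    X_aug = [[sum(v * v for v in x), 1.0] + [-2.0 * v for v in x] for x in X]
    Y_aug = [[1.0, sum(v * v for v in y)] + [float(v) for v in y] for y in Y]
    return X_aug, Y_aug


def rows_product(A: Matrix, B: Matrix) -> Matrix:
    """Compute A @ B.T with plain lists."""
    return [[sum(a * b for a, b in zip(ra, rb)) for rb in B] for ra in A]


X = [[0.0, 0.0], [1.0, 2.0]]              # n = 2 source points, d = 2
Y = [[1.0, 0.0], [3.0, 4.0], [0.0, 0.0]]  # m = 3 target points
X_aug, Y_aug = augment_squared_l2(X, Y)   # shapes (2, 4) and (3, 4): d + 2 columns
print(rows_product(X_aug, Y_aug))         # [[1.0, 25.0, 0.0], [4.0, 8.0, 5.0]]
```

Each entry expands ||x||^2 + ||y||^2 - 2 x.y, so the n x m cost can be applied implicitly through the two thin factors.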