fugw.solvers.FUGWSolver
- class fugw.solvers.FUGWSolver(nits_bcd=10, nits_uot=1000, tol_bcd=None, tol_uot=None, tol_loss=None, eval_bcd=1, eval_uot=10, ibpp_eps_base=1, ibpp_nits_sinkhorn=1)
Solver computing dense solutions.
Methods
- fugw_loss(pi, gamma, data_const, ...): Compute the FUGW loss and each of its components.
- get_parameters_uot_l2(pi, tuple_weights, ...): Compute parameters of the L2 loss.
- local_biconvex_cost(pi, transpose, ...): Compute a matrix representing the local biconvex cost.
- solve([alpha, rho_s, rho_t, eps, reg_mode, ...]): Run BCD iterations.
- __init__(nits_bcd=10, nits_uot=1000, tol_bcd=None, tol_uot=None, tol_loss=None, eval_bcd=1, eval_uot=10, ibpp_eps_base=1, ibpp_nits_sinkhorn=1)
Init FUGW solver.
- Parameters:
- nits_bcd: int or None,
Number of block-coordinate-descent iterations to run. If None, run until tol_bcd or tol_loss is reached. Default: 10
- nits_uot: int or None,
Number of solver iterations to run at each BCD iteration. If None, run until tol_uot is reached. Default: 1000
- tol_bcd: float or None,
Stop the BCD procedure early if the absolute difference between two consecutive transport plans falls under this threshold. If None, do not stop early. Default: None
- tol_uot: float or None,
Stop the inner UOT solver early if the absolute difference between two consecutive transport plans falls under this threshold. If None, do not stop early. Default: None
- tol_loss: float or None,
Stop the BCD procedure early if the FUGW loss falls under this threshold. If None, do not stop early. Default: None
- eval_bcd: int,
During .fit(), at every eval_bcd step:
1. compute the FUGW loss and store it in an array
2. consider stopping early if tol_loss is not None
3. consider stopping early if tol_bcd is not None
Default: 1
- eval_uot: int,
During .fit(), at every eval_uot step, consider stopping early if tol_uot is not None. Default: 10
- ibpp_eps_base: int,
Regularization parameter specific to the ibpp solver. Default: 1
- ibpp_nits_sinkhorn: int,
Number of Sinkhorn iterations to run within each UOT iteration of the ibpp solver. Default: 1
- Attributes:
- Same as parameters.
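The way nits_bcd, eval_bcd, tol_bcd and tol_loss interact can be pictured with the toy loop below. This is a sketch that follows the parameter descriptions literally, not the library's implementation: `run_outer_loop`, `step` and `loss_fn` are hypothetical names, and the iterate is a single float, whereas the solver compares transport plans.

```python
def run_outer_loop(step, loss_fn, x0, nits_bcd=10, tol_bcd=None,
                   tol_loss=None, eval_bcd=1):
    """Toy outer loop mirroring the documented stopping parameters.

    `step` maps the current iterate to the next one and `loss_fn` maps an
    iterate to a float. Both are hypothetical stand-ins for a BCD update
    and for the FUGW loss.
    """
    if nits_bcd is None and tol_bcd is None and tol_loss is None:
        raise ValueError("need nits_bcd or a tolerance to terminate")

    x = x0
    losses, steps = [], []
    idx = 0
    while nits_bcd is None or idx < nits_bcd:
        x_prev, x = x, step(x)
        idx += 1
        if idx % eval_bcd == 0:
            # 1. compute the loss and store it
            loss = loss_fn(x)
            losses.append(loss)
            steps.append(idx)
            # 2. stop if the loss fell under tol_loss
            if tol_loss is not None and loss < tol_loss:
                break
            # 3. stop if two consecutive iterates are close enough
            if tol_bcd is not None and abs(x - x_prev) < tol_bcd:
                break
    return x, losses, steps


# Halving converges to 0, so the tolerance on the loss triggers early.
x, losses, steps = run_outer_loop(
    step=lambda v: v / 2, loss_fn=abs, x0=1.0,
    nits_bcd=None, tol_loss=0.1, eval_bcd=2,
)
```

With eval_bcd=2 the loss is only inspected every second step, so the loop can overshoot the tolerance by one step before it stops. That is the trade-off between evaluation cost and stopping precision.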
- fugw_loss(pi, gamma, data_const, tuple_weights, hyperparams)
Compute the FUGW loss and each of its components.
Computes the scalar FUGW loss, which is a combination of:
- a Wasserstein loss on features
- a Gromov-Wasserstein loss on geometries
- marginal constraints on the computed OT plan
- a regularization term (KL, i.e. entropic, or L2)
- Parameters:
- pi: torch.Tensor
- gamma: torch.Tensor
- data_const: tuple
- tuple_weights: tuple
- hyperparams: tuple
- Returns:
- l: dict
Dictionary containing the loss and its unweighted components. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None.
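Because the returned components are unweighted, it can help to see how they might be recombined into a single scalar. The sketch below assumes the usual FUGW weighting: (1 - alpha) on the Wasserstein term, alpha on the Gromov-Wasserstein term, rho_s and rho_t on the two marginal terms, and eps on the regularization. `weighted_total` is a hypothetical helper, not part of fugw, the component values are made up, and skipping None components is also an assumption. Check the library source before relying on this weighting.

```python
def weighted_total(components, alpha, rho_s, rho_t, eps):
    """Combine unweighted loss components into one scalar (assumed weighting)."""
    weights = {
        "wasserstein": 1 - alpha,
        "gromov_wasserstein": alpha,
        "marginal_constraint_dim1": rho_s,
        "marginal_constraint_dim2": rho_t,
        "regularization": eps,
    }
    # Components are documented as "float or None": skip the ones that are None.
    return sum(
        w * components[k] for k, w in weights.items()
        if components.get(k) is not None
    )


# Made-up component values, for illustration only.
components = {
    "wasserstein": 2.0,
    "gromov_wasserstein": 4.0,
    "marginal_constraint_dim1": 1.0,
    "marginal_constraint_dim2": 1.0,
    "regularization": None,
}
total = weighted_total(components, alpha=0.5, rho_s=1, rho_t=1, eps=0.01)
```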
- get_parameters_uot_l2(pi, tuple_weights, hyperparams)
Compute parameters of the L2 loss.
- local_biconvex_cost(pi, transpose, data_const, tuple_weights, hyperparams)
Compute a matrix representing the local biconvex cost.
Before each block coordinate descent (BCD) step, the local cost matrix is updated. This local cost is a matrix of size (n, m) that gives the cost between every pair of points of the source and target distributions. A BCD step (sinkhorn, ibpp or mm) then uses this cost to update the transport plans.
- Parameters:
- pi: torch.Tensor
Transport plan.
- transpose: bool
Whether to transpose the transport plan.
- data_const: tuple
- Tuple containing the following elements:
- X_sqr: torch.Tensor of size n x n
Squared distances between source points.
- Y_sqr: torch.Tensor of size m x m
Squared distances between target points.
- X: torch.Tensor of size n x d1
Source points.
- Y: torch.Tensor of size m x d2
Target points.
- D: torch.Tensor of size n x m
Kernel matrix between the source and target training features.
- tuple_weights: tuple
- Tuple containing the following elements:
- ws: torch.Tensor of size n
Measures assigned to source points.
- wt: torch.Tensor of size m
Measures assigned to target points.
- ws_dot_wt: torch.Tensor of size n x m
Outer product of ws and wt.
- hyperparams: tuple
- Tuple containing the following elements:
- rho_s: float
Regularization parameter for source marginal constraints.
- rho_t: float
Regularization parameter for target marginal constraints.
- eps: float
Regularization parameter for joint regularization.
- alpha: float
Weight of the Gromov-Wasserstein loss.
- reg_mode: string
Regularization mode.
- divergence: string
Divergence to use.
- Returns:
- cost: torch.Tensor
Local biconvex cost matrix of size (n, m).
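The layout of tuple_weights can be illustrated with plain Python lists. This is only a shape sketch: the solver works with torch.Tensor objects, `uniform_weights` and `outer` are hypothetical helpers, and uniform weights are an assumed choice rather than documented behaviour.

```python
def uniform_weights(n):
    """Uniform measure over n points (an assumed choice, sums to 1)."""
    return [1.0 / n] * n


def outer(ws, wt):
    """Outer product as a nested list of shape (len(ws), len(wt))."""
    return [[a * b for b in wt] for a in ws]


n, m = 2, 4
ws, wt = uniform_weights(n), uniform_weights(m)
ws_dot_wt = outer(ws, wt)  # n x m, same shape as the local cost matrix

# Same ordering as in the parameter description above.
tuple_weights = (ws, wt, ws_dot_wt)
```

With torch tensors, the same outer product can be written as `ws[:, None] * wt[None, :]`.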
- solve(alpha=0.5, rho_s=1, rho_t=1, eps=0.01, reg_mode='joint', divergence='kl', F=None, Ds=None, Dt=None, F_val=None, Ds_val=None, Dt_val=None, ws=None, wt=None, init_plan=None, init_duals=None, solver='sinkhorn', callback_bcd=None, verbose=False)
Run BCD iterations.
- Parameters:
- alpha: float, optional
- rho_s: float, optional
- rho_t: float, optional
- eps: float, optional
- reg_mode: string, optional
- divergence: string, optional
- F: matrix of size n x m
Kernel matrix between the source and target training features.
- Ds: matrix of size n x n
- Dt: matrix of size m x m
- F_val: matrix of size n x m, None
Kernel matrix between the source and target validation features.
- Ds_val: matrix of size n x n, None
- Dt_val: matrix of size m x m, None
- ws: ndarray(n), None
Measures assigned to source points.
- wt: ndarray(m), None
Measures assigned to target points.
- init_plan: matrix of size n x m if not None.
Initialization matrix for coupling.
- init_duals: tuple or None
Initialization duals for coupling.
- solver: “sinkhorn”, “mm”, “ibpp”
Solver to use.
- callback_bcd: callable or None
Callback function called at the end of each BCD step. It is called with the following argument: locals, a dictionary containing all local variables.
- verbose: bool, optional, defaults to False
Log solving process.
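A callback_bcd function receives the dictionary of local variables at the end of each BCD step. The sketch below records only the keys it is given, since the exact contents of that dictionary are not documented here. `make_recorder` is a hypothetical helper, and the loop is a stand-in for the solver with made-up keys "idx" and "pi".

```python
def make_recorder():
    """Build a callback that records which local variables it was given."""
    seen = []

    def callback_bcd(local_vars):
        # `local_vars` is the dictionary of local variables; only its keys
        # are recorded because its exact contents are not documented.
        seen.append(sorted(local_vars))

    return callback_bcd, seen


callback_bcd, seen = make_recorder()

# Stand-in for the solver: the keys "idx" and "pi" are hypothetical.
for idx in range(3):
    callback_bcd({"idx": idx, "pi": None})
```

With the real solver, the callback would be passed as `solve(..., callback_bcd=callback_bcd)`, and `seen` could then be inspected to discover which variables are available.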
- Returns:
- res: dict
- Dictionary containing the following keys:
- pi: torch.Tensor of size n x m
Sample matrix.
- gamma: torch.Tensor of size d1 x d2
Feature matrix.
- duals_pi: tuple of torch.Tensor of size
Duals of pi
- duals_gamma: tuple of torch.Tensor of size
Duals of gamma
- loss: dict of lists
Dictionary containing the loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None.
- loss_val: dict of lists
Dictionary containing the loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated on the validation set.
- loss_steps: list
BCD steps at the end of which the FUGW loss was evaluated.
- loss_times: list
Elapsed time at the end of each BCD step for which the FUGW loss was evaluated.
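Since loss, loss_steps and loss_times are aligned lists, a loss curve can be rebuilt by zipping them. The sketch below uses a hand-written dictionary in place of an actual return value of solve(). The numbers are made up, only documented keys appear, and `loss_curve` is a hypothetical helper.

```python
def loss_curve(res, key="total"):
    """Pair each evaluated BCD step with the loss recorded at that step."""
    return list(zip(res["loss_steps"], res["loss"][key]))


# Hand-written stand-in for the dictionary returned by solve();
# the numbers are made up and only documented keys are used.
res = {
    "loss": {"total": [3.0, 1.5, 1.2], "wasserstein": [2.0, 1.0, 0.9]},
    "loss_steps": [1, 2, 3],
    "loss_times": [0.1, 0.2, 0.3],
}
curve = loss_curve(res)
final_step, final_total = curve[-1]
```

Passing another key such as "wasserstein" gives the curve of a single component. Entries documented as possibly None would need to be filtered out before plotting.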