fugw.solvers.FUGWSolver

class fugw.solvers.FUGWSolver(nits_bcd=10, nits_uot=1000, tol_bcd=None, tol_uot=None, tol_loss=None, eval_bcd=1, eval_uot=10, ibpp_eps_base=1, ibpp_nits_sinkhorn=1)

Solver computing dense solutions.

Methods

fugw_loss(pi, gamma, data_const, ...)

Compute the FUGW loss and each of its components.

get_parameters_uot_l2(pi, tuple_weights, ...)

Compute parameters of the L2 loss.

local_biconvex_cost(pi, transpose, ...)

Compute a matrix representing the local biconvex cost.

solve([alpha, rho_s, rho_t, eps, reg_mode, ...])

Run BCD iterations.

__init__(nits_bcd=10, nits_uot=1000, tol_bcd=None, tol_uot=None, tol_loss=None, eval_bcd=1, eval_uot=10, ibpp_eps_base=1, ibpp_nits_sinkhorn=1)

Initialize the FUGW solver.

Parameters:
nits_bcd: int or None

Number of block-coordinate-descent iterations to run. If None, run until tol_bcd or tol_loss is reached. Default: 10

nits_uot: int or None

Number of solver iterations to run at each BCD iteration. If None, run until tol_uot is reached. Default: 1000

tol_bcd: float or None

Stop the BCD procedure early if the absolute difference between two consecutive transport plans falls under this threshold. If None, do not stop early. Default: None

tol_uot: float or None

Stop the inner UOT solver early if the absolute difference between two consecutive transport plans falls under this threshold. If None, do not stop early. Default: None

tol_loss: float or None

Stop the BCD procedure early if the FUGW loss falls under this threshold. If None, do not stop early. Default: None

eval_bcd: int

During .fit(), at every eval_bcd step:
  1. compute the FUGW loss and store it in an array;
  2. consider stopping early if tol_loss is not None;
  3. consider stopping early if tol_bcd is not None.
Default: 1

eval_uot: int

During .fit(), at every eval_uot step of the UOT solver, consider stopping early if tol_uot is not None. Default: 10

ibpp_eps_base: int

Regularization parameter specific to the ibpp solver. Default: 1

ibpp_nits_sinkhorn: int

Number of Sinkhorn iterations to run within each UOT iteration of the ibpp solver. Default: 1

Attributes:
Same as parameters.
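The stopping rules for the BCD loop (nits_bcd, tol_bcd, tol_loss, eval_bcd) can be pictured with the pure-Python sketch below. It is not fugw's code: run_bcd, abs_diff and the toy step/loss functions are hypothetical, and the comparison metric (sum of absolute entry-wise differences between consecutive plans) is an assumption.

```python
from typing import Callable, List, Optional, Tuple

Matrix = List[List[float]]


def abs_diff(a: Matrix, b: Matrix) -> float:
    """Sum of entry-wise absolute differences (assumed comparison metric)."""
    return sum(abs(x - y) for ra, rb in zip(a, b) for x, y in zip(ra, rb))


def run_bcd(
    step: Callable[[Matrix], Matrix],
    loss: Callable[[Matrix], float],
    pi: Matrix,
    nits_bcd: Optional[int] = 10,
    tol_bcd: Optional[float] = None,
    tol_loss: Optional[float] = None,
    eval_bcd: int = 1,
) -> Tuple[Matrix, List[float], List[int]]:
    """Hypothetical outer loop mirroring the documented stopping rules."""
    losses: List[float] = []
    loss_steps: List[int] = []
    it = 0
    # With nits_bcd=None, termination relies on tol_bcd or tol_loss being reached.
    while nits_bcd is None or it < nits_bcd:
        pi_prev, pi = pi, step(pi)  # one (toy) BCD step
        it += 1
        if it % eval_bcd != 0:
            continue
        # 1. compute the loss and store it
        current = loss(pi)
        losses.append(current)
        loss_steps.append(it)
        # 2. stop early if the loss falls under tol_loss
        if tol_loss is not None and current < tol_loss:
            break
        # 3. stop early if two consecutive plans are close enough
        if tol_bcd is not None and abs_diff(pi, pi_prev) < tol_bcd:
            break
    return pi, losses, loss_steps


def halve(p: Matrix) -> Matrix:
    """Toy 'BCD step': halve every entry of the plan."""
    return [[x / 2 for x in row] for row in p]


def mass(p: Matrix) -> float:
    """Toy 'loss': total mass of the plan."""
    return sum(map(sum, p))


_, losses, steps = run_bcd(halve, mass, [[1.0, 1.0], [1.0, 1.0]], tol_loss=0.5)
print(losses, steps)  # [2.0, 1.0, 0.5, 0.25] [1, 2, 3, 4]
```

Here the loop stops at step 4 because the loss (0.25) is the first value strictly under tol_loss=0.5; with eval_bcd=2 the loss would only be recorded at even steps.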

fugw_loss(pi, gamma, data_const, tuple_weights, hyperparams)

Compute the FUGW loss and each of its components.

Computes the scalar FUGW loss, which combines:
  • a Wasserstein loss on features;
  • a Gromov-Wasserstein loss on geometries;
  • marginal constraints on the computed OT plan;
  • a regularization term (KL, i.e. entropic, or L2).

Parameters:
pi: torch.Tensor
gamma: torch.Tensor
data_const: tuple
tuple_weights: tuple
hyperparams: tuple
Returns:
l: dict

Dictionary containing the loss and its unweighted components. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None.
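To illustrate how the "total" entry could relate to the unweighted components, the sketch below assumes the usual weighting of the FUGW objective by the solve() hyperparameters: 1 - alpha on the Wasserstein term, alpha on the Gromov-Wasserstein term, rho_s and rho_t on the two marginal constraints, and eps on the regularization term. The function weighted_total is hypothetical, and this page does not confirm the exact weighting used by fugw.

```python
from typing import Dict, Optional


def weighted_total(
    components: Dict[str, Optional[float]],
    alpha: float = 0.5,
    rho_s: float = 1.0,
    rho_t: float = 1.0,
    eps: float = 0.01,
) -> float:
    """Hypothetical weighted sum of unweighted FUGW loss components."""
    # Assumed weights; not confirmed to match fugw's implementation.
    weights = {
        "wasserstein": 1 - alpha,
        "gromov_wasserstein": alpha,
        "marginal_constraint_dim1": rho_s,
        "marginal_constraint_dim2": rho_t,
        "regularization": eps,
    }
    # Components reported as None are skipped.
    return sum(
        w * components[key]
        for key, w in weights.items()
        if components.get(key) is not None
    )


loss = {
    "wasserstein": 2.0,
    "gromov_wasserstein": 4.0,
    "marginal_constraint_dim1": 1.0,
    "marginal_constraint_dim2": 3.0,
    "regularization": None,
}
loss["total"] = weighted_total(loss)
print(loss["total"])  # 7.0
```

The default hyperparameter values in weighted_total mirror the defaults in the solve() signature.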

get_parameters_uot_l2(pi, tuple_weights, hyperparams)

Compute parameters of the L2 loss.

local_biconvex_cost(pi, transpose, data_const, tuple_weights, hyperparams)

Compute a matrix representing the local biconvex cost.

Before each block coordinate descent (BCD) step, the local cost matrix is updated. This local cost is a matrix of size (n, m) which evaluates the cost between every pair of points of the source and target distributions. A BCD step (using the sinkhorn, ibpp or mm solver) then uses this cost to update the transport plans.
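For the Gromov-Wasserstein part with a squared loss, freezing one plan (gamma) makes the objective linear in the other, with local cost C[i][j] = sum over k, l of (Ds[i][k] - Dt[j][l])**2 * gamma[k][l]. The pure-Python sketch below computes only this part, once naively and once through the expansion (a - b)**2 = a**2 + b**2 - 2ab, which needs only squared distances, the marginals of gamma and one matrix product. The function names are hypothetical; the actual method also combines feature, marginal and regularization terms weighted by the hyperparameters, which are not reproduced here.

```python
from typing import List

Matrix = List[List[float]]


def gw_local_cost_naive(Ds: Matrix, Dt: Matrix, gamma: Matrix) -> Matrix:
    """C[i][j] = sum over k, l of (Ds[i][k] - Dt[j][l])**2 * gamma[k][l]."""
    n, m = len(Ds), len(Dt)
    return [
        [
            sum(
                (Ds[i][k] - Dt[j][l]) ** 2 * gamma[k][l]
                for k in range(n)
                for l in range(m)
            )
            for j in range(m)
        ]
        for i in range(n)
    ]


def gw_local_cost_fast(Ds: Matrix, Dt: Matrix, gamma: Matrix) -> Matrix:
    """Same matrix via (a - b)**2 = a**2 + b**2 - 2ab and the marginals of gamma."""
    n, m = len(Ds), len(Dt)
    g1 = [sum(gamma[k]) for k in range(n)]  # row sums of gamma (size n)
    g2 = [sum(gamma[k][l] for k in range(n)) for l in range(m)]  # column sums (size m)
    # Squared-distance terms: (Ds**2 @ g1)[i] and (Dt**2 @ g2)[j]
    a = [sum(Ds[i][k] ** 2 * g1[k] for k in range(n)) for i in range(n)]
    b = [sum(Dt[j][l] ** 2 * g2[l] for l in range(m)) for j in range(m)]
    # Cross term: (Ds @ gamma @ Dt^T)[i][j]
    Dg = [[sum(Ds[i][k] * gamma[k][l] for k in range(n)) for l in range(m)] for i in range(n)]
    cross = [[sum(Dg[i][l] * Dt[j][l] for l in range(m)) for j in range(m)] for i in range(n)]
    return [[a[i] + b[j] - 2 * cross[i][j] for j in range(m)] for i in range(n)]


Ds = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]]  # n = 3 source points
Dt = [[0.0, 3.0], [3.0, 0.0]]  # m = 2 target points
gamma = [[0.1, 0.2], [0.3, 0.1], [0.2, 0.1]]  # n x m plan
slow = gw_local_cost_naive(Ds, Dt, gamma)
fast = gw_local_cost_fast(Ds, Dt, gamma)
assert all(abs(s - f) < 1e-9 for rs, rf in zip(slow, fast) for s, f in zip(rs, rf))
```

The naive version costs O(n²m²) operations, while the expanded one costs O(n²m + nm²), which is why precomputing squared distances pays off.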

Parameters:
pi: torch.Tensor

Transport plan.

transpose: bool

Whether to transpose the transport plan.

data_const: tuple
Tuple containing the following elements:
  • X_sqr: torch.Tensor of size n x n

    Squared distances between source points.

  • Y_sqr: torch.Tensor of size m x m

    Squared distances between target points.

  • X: torch.Tensor of size n x d1

    Source points.

  • Y: torch.Tensor of size m x d2

    Target points.

  • D: torch.Tensor of size n x m

    Kernel matrix between the source and target training features.

tuple_weights: tuple
Tuple containing the following elements:
  • ws: torch.Tensor of size n

    Measures assigned to source points.

  • wt: torch.Tensor of size m

    Measures assigned to target points.

  • ws_dot_wt: torch.Tensor of size n x m

    Outer product of ws and wt.

hyperparams: tuple
Tuple containing the following elements:
  • rho_s: float

    Regularization parameter for source marginal constraints.

  • rho_t: float

    Regularization parameter for target marginal constraints.

  • eps: float

    Regularization parameter for joint regularization.

  • alpha: float

    Weight of the Gromov-Wasserstein loss.

  • reg_mode: string

    Regularization mode (e.g. “joint”, the default in solve()).

  • divergence: string

    Divergence to use (e.g. “kl”, the default in solve()).

Returns:
cost: torch.Tensor

Local biconvex cost matrix of size (n, m).

solve(alpha=0.5, rho_s=1, rho_t=1, eps=0.01, reg_mode='joint', divergence='kl', F=None, Ds=None, Dt=None, F_val=None, Ds_val=None, Dt_val=None, ws=None, wt=None, init_plan=None, init_duals=None, solver='sinkhorn', callback_bcd=None, verbose=False)

Run BCD iterations.

Parameters:
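alpha: float, optional

Weight of the Gromov-Wasserstein loss. Default: 0.5

rho_s: float, optional

Regularization parameter for source marginal constraints. Default: 1

rho_t: float, optional

Regularization parameter for target marginal constraints. Default: 1

eps: float, optional

Regularization parameter for joint regularization. Default: 0.01

reg_mode: string, optional

Regularization mode. Default: “joint”

divergence: string, optional

Divergence to use. Default: “kl”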
F: matrix of size n x m

Kernel matrix between the source and target training features.

Ds: matrix of size n x n

Pairwise distances between source points.

Dt: matrix of size m x m

Pairwise distances between target points.
F_val: matrix of size n x m or None

Kernel matrix between the source and target validation features.

Ds_val: matrix of size n x n or None
Dt_val: matrix of size m x m or None
ws: ndarray(n) or None

Measures assigned to source points.

wt: ndarray(m) or None

Measures assigned to target points.

init_plan: matrix of size n x m or None

Initial transport plan (coupling).

init_duals: tuple or None

Initial dual variables for the coupling.

solver: “sinkhorn”, “mm” or “ibpp”

Solver used to update the transport plans at each BCD step. Default: “sinkhorn”

callback_bcd: callable or None

Callback function called at the end of each BCD step. It will be called with the following arguments:

  • locals (dictionary containing all local variables)

verbose: bool, optional, defaults to False

Whether to log the solving process.

Returns:
res: dict
Dictionary containing the following keys:
pi: torch.Tensor of size n x m

Sample matrix.

gamma: torch.Tensor of size d1 x d2

Feature matrix.

duals_pi: tuple of torch.Tensor of size

Duals of pi.

duals_gamma: tuple of torch.Tensor of size

Duals of gamma.

loss: dict of lists

Dictionary containing the loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None.

loss_val: dict of lists

Dictionary containing the loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated on the validation set.

loss_steps: list

BCD steps at the end of which the FUGW loss was evaluated.

loss_times: list

Elapsed time at the end of each BCD step for which the FUGW loss was evaluated.
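With solver="sinkhorn", each BCD step runs an unbalanced OT solver on the local cost, controlled by nits_uot, tol_uot and eval_uot. As an illustration only, the pure-Python sketch below implements the generic scaling (Sinkhorn-like) iterations for entropic unbalanced OT with KL marginal penalties weighted by rho_s and rho_t. The function uot_sinkhorn, its stopping metric (comparing the plan with the one from the previous evaluation) and its return values are assumptions; fugw's actual solver and its handling of reg_mode, divergence and init_duals may differ.

```python
import math
from typing import List, Optional, Tuple

Matrix = List[List[float]]


def _plan(u: List[float], K: Matrix, v: List[float]) -> Matrix:
    """P = diag(u) K diag(v)."""
    return [[u[i] * K[i][j] * v[j] for j in range(len(v))] for i in range(len(u))]


def uot_sinkhorn(
    cost: Matrix,
    ws: List[float],
    wt: List[float],
    rho_s: float = 1.0,
    rho_t: float = 1.0,
    eps: float = 0.01,
    nits_uot: Optional[int] = 1000,
    tol_uot: Optional[float] = None,
    eval_uot: int = 10,
) -> Tuple[Matrix, int]:
    """Hypothetical entropic UOT solver (scaling iterations), not the fugw code.

    Approximately minimizes
        <cost, P> + rho_s KL(P 1 | ws) + rho_t KL(P^T 1 | wt) + eps KL(P | ws wt^T)
    over nonnegative P of size n x m. Returns the plan and the iteration count.
    With nits_uot=None, the loop only ends once tol_uot is reached.
    """
    n, m = len(ws), len(wt)
    # Gibbs kernel relative to the reference measure ws (x) wt
    K = [[ws[i] * wt[j] * math.exp(-cost[i][j] / eps) for j in range(m)] for i in range(n)]
    fs = rho_s / (rho_s + eps)  # damping exponent from the source KL penalty
    ft = rho_t / (rho_t + eps)  # damping exponent from the target KL penalty
    u, v = [1.0] * n, [1.0] * m
    prev = _plan(u, K, v)  # plan implied by the initial scalings
    it = 0
    while nits_uot is None or it < nits_uot:
        Kv = [sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        u = [(ws[i] / Kv[i]) ** fs for i in range(n)]
        Ktu = [sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        v = [(wt[j] / Ktu[j]) ** ft for j in range(m)]
        it += 1
        # Every eval_uot iterations, consider stopping early if tol_uot is set
        if tol_uot is not None and it % eval_uot == 0:
            plan = _plan(u, K, v)
            diff = sum(abs(p - q) for rp, rq in zip(plan, prev) for p, q in zip(rp, rq))
            if diff < tol_uot:
                break
            prev = plan
    return _plan(u, K, v), it


# Zero cost: the optimum is ws (x) wt, already reached by the initial scalings.
plan, n_iter = uot_sinkhorn([[0.0, 0.0], [0.0, 0.0]], [0.5, 0.5], [0.25, 0.75],
                            eps=0.1, tol_uot=1e-9, eval_uot=10)
print(plan, n_iter)  # [[0.125, 0.375], [0.125, 0.375]] 10
```

The exponents rho / (rho + eps) are what make the problem unbalanced: as rho grows they approach 1 and the updates approach those of balanced Sinkhorn, while smaller rho lets the plan's marginals deviate from ws and wt.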