fugw.solvers.FUGWSolver
- class fugw.solvers.FUGWSolver(nits_bcd=10, nits_uot=1000, tol_bcd=None, tol_uot=None, tol_loss=None, eval_bcd=1, eval_uot=10, ibpp_eps_base=1, ibpp_nits_sinkhorn=1)
- Solver computing dense solutions.
- Methods:
- fugw_loss(pi, gamma, data_const, ...)
- Compute the FUGW loss and each of its components.
- get_parameters_uot_l2(pi, tuple_weights, ...)
- Compute parameters of the L2 loss.
- local_biconvex_cost(pi, transpose, ...)
- Compute a matrix representing the local biconvex cost.
- solve([alpha, rho_s, rho_t, eps, reg_mode, ...])
- Run BCD iterations.
 - __init__(nits_bcd=10, nits_uot=1000, tol_bcd=None, tol_uot=None, tol_loss=None, eval_bcd=1, eval_uot=10, ibpp_eps_base=1, ibpp_nits_sinkhorn=1)
- Initialize the FUGW solver.
- Parameters:
- nits_bcd: int or None,
- Number of block-coordinate-descent iterations to run. If None, run until tol_bcd or tol_loss is reached. Default: 10 
- nits_uot: int or None,
- Number of iterations of the inner UOT solver to run at each BCD iteration. If None, run until tol_uot is reached. Default: 1000
- tol_bcd: float or None,
- Stop the BCD procedure early if the absolute difference between two consecutive transport plans falls under this threshold. If None, do not stop early. Default: None
- tol_uot: float or None,
- Stop the inner unbalanced OT (UOT) solver early if the absolute difference between two consecutive iterates falls under this threshold. If None, do not stop early. Default: None
- tol_loss: float or None,
- Stop the BCD procedure early if the FUGW loss falls under this threshold. If None, do not stop early. Default: None 
- eval_bcd: int,
- During .solve(), every eval_bcd BCD steps: (1) compute the FUGW loss and store it in an array, (2) consider stopping early if tol_loss is not None, and (3) consider stopping early if tol_bcd is not None. Default: 1
- eval_uot: int,
- During .solve(), every eval_uot iterations of the inner UOT solver, consider stopping early if tol_uot is not None. Default: 10
- ibpp_eps_base: int,
- Regularization parameter specific to the ibpp solver. Default: 1 
- ibpp_nits_sinkhorn: int,
- Number of Sinkhorn iterations to run within each UOT iteration of the ibpp solver. Default: 1
 
- Attributes:
- Same as parameters.
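The stopping rules above (nits_bcd, tol_bcd, tol_loss, eval_bcd) can be sketched as a small, self-contained control loop. This is not fugw code: run_bcd, step and loss_fn are hypothetical placeholders, the plan is a flat list of floats, and measuring the "absolute difference" between consecutive plans as an L1 distance is an assumption. The returned loss_steps and loss_times mirror the keys of the same name documented for solve().

```python
import time
from typing import Callable, Dict, List, Optional

Plan = List[float]


def run_bcd(
    step: Callable[[Plan], Plan],
    loss_fn: Callable[[Plan], float],
    plan: Plan,
    nits_bcd: Optional[int] = 10,
    tol_bcd: Optional[float] = None,
    tol_loss: Optional[float] = None,
    eval_bcd: int = 1,
) -> Dict[str, list]:
    """Hypothetical BCD driver mimicking the documented stopping rules."""
    # Guard added for this sketch: without a tolerance, nits_bcd=None never stops.
    if nits_bcd is None and tol_bcd is None and tol_loss is None:
        raise ValueError("nits_bcd=None requires tol_bcd or tol_loss")
    losses, loss_steps, loss_times = [], [], []
    start = time.perf_counter()
    it = 0
    while nits_bcd is None or it < nits_bcd:
        previous = plan
        plan = step(plan)  # one BCD update (placeholder)
        it += 1
        if it % eval_bcd != 0:
            continue
        # 1. compute the loss and store it
        loss = loss_fn(plan)
        losses.append(loss)
        loss_steps.append(it)
        loss_times.append(time.perf_counter() - start)
        # 2. stop early if the loss falls under tol_loss
        if tol_loss is not None and loss < tol_loss:
            break
        # 3. stop early if two consecutive plans are close (L1 distance: an assumption)
        if tol_bcd is not None:
            if sum(abs(a - b) for a, b in zip(plan, previous)) < tol_bcd:
                break
    return {
        "plan": plan,
        "losses": losses,
        "loss_steps": loss_steps,
        "loss_times": loss_times,
    }


def halve(plan: Plan) -> Plan:
    """Toy update that halves every entry of the plan."""
    return [x / 2 for x in plan]


res = run_bcd(halve, sum, [1.0, 1.0], tol_loss=0.1)
print(res["loss_steps"])  # [1, 2, 3, 4, 5]
```

With nits_bcd=None the loop can only end through a tolerance, which is why the sketch rejects that setting when both tolerances are None; whether fugw performs a similar check is not stated here.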
 
 
 - fugw_loss(pi, gamma, data_const, tuple_weights, hyperparams)
- Compute the FUGW loss and each of its components.
- Computes the scalar FUGW loss, which combines:
- a Wasserstein loss on features,
- a Gromov-Wasserstein loss on geometries,
- marginal constraints on the computed OT plan,
- a regularization term (KL, i.e. entropic, or L2).
- Parameters:
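- pi: torch.Tensor
- Transport plan of size n x m.
- gamma: torch.Tensor
- Second transport plan, updated alternately with pi during BCD.
- data_const: tuple
- Same as in local_biconvex_cost.
- tuple_weights: tuple
- Same as in local_biconvex_cost.
- hyperparams: tuple
- Same as in local_biconvex_cost.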
 
- Returns:
- l: dict
- Dictionary containing the loss and its unweighted components. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None. 
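As a rough illustration of how the components listed above combine into the total, here is a pure-Python sketch for a single plan (that is, assuming pi = gamma) on tiny inputs. It returns a dictionary with the documented keys. The generalized KL divergences, the weighting of the total by alpha, rho_s, rho_t and eps, and the names gen_kl and simple_fugw_loss are assumptions made for illustration; fugw's own formulas, which involve both pi and gamma, may differ.

```python
import math
from typing import Dict, List

Matrix = List[List[float]]


def gen_kl(p: List[float], q: List[float]) -> float:
    """Generalized KL divergence sum(p log(p/q) - p + q), with 0 log 0 = 0."""
    total = 0.0
    for a, b in zip(p, q):
        if a > 0:
            total += a * math.log(a / b)
        total += b - a
    return total


def simple_fugw_loss(pi: Matrix, F: Matrix, Ds: Matrix, Dt: Matrix,
                     ws: List[float], wt: List[float],
                     alpha: float, rho_s: float, rho_t: float,
                     eps: float) -> Dict[str, float]:
    """Hypothetical single-plan version of the FUGW loss components."""
    n, m = len(ws), len(wt)
    pi1 = [sum(row) for row in pi]                             # row marginal
    pi2 = [sum(pi[i][j] for i in range(n)) for j in range(m)]  # column marginal
    components = {
        # feature term: <F, pi>
        "wasserstein": sum(F[i][j] * pi[i][j] for i in range(n) for j in range(m)),
        # geometry term: sum of (Ds[i][k] - Dt[j][q])^2 * pi[i][j] * pi[k][q]
        "gromov_wasserstein": sum(
            (Ds[i][k] - Dt[j][q]) ** 2 * pi[i][j] * pi[k][q]
            for i in range(n) for j in range(m)
            for k in range(n) for q in range(m)
        ),
        "marginal_constraint_dim1": gen_kl(pi1, ws),
        "marginal_constraint_dim2": gen_kl(pi2, wt),
        # entropic regularization relative to the product measure ws x wt
        "regularization": gen_kl(
            [pi[i][j] for i in range(n) for j in range(m)],
            [ws[i] * wt[j] for i in range(n) for j in range(m)],
        ),
    }
    components["total"] = (
        (1 - alpha) * components["wasserstein"]
        + alpha * components["gromov_wasserstein"]
        + rho_s * components["marginal_constraint_dim1"]
        + rho_t * components["marginal_constraint_dim2"]
        + eps * components["regularization"]
    )
    return components


half = [0.5, 0.5]
uniform = [[0.25, 0.25], [0.25, 0.25]]
swap = [[0.0, 1.0], [1.0, 0.0]]
loss = simple_fugw_loss(uniform, swap, swap, swap, half, half,
                        alpha=0.5, rho_s=1.0, rho_t=1.0, eps=0.01)
print(loss["total"])  # 0.5
```

Because the uniform plan has exactly the marginals ws and wt, both marginal terms and the regularization vanish here, and the total reduces to the weighted feature and geometry terms.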
 
 
 - get_parameters_uot_l2(pi, tuple_weights, hyperparams)
- Compute the parameters of the L2 loss used in the unbalanced optimal transport (UOT) updates.
 - local_biconvex_cost(pi, transpose, data_const, tuple_weights, hyperparams)
- Compute a matrix representing the local biconvex cost.
- Before each block-coordinate-descent (BCD) step, the local cost matrix is updated. This local cost is a matrix of size (n, m) that evaluates the cost between every pair of source and target points. The inner solver (sinkhorn, ibpp or mm) then uses this cost to update the transport plans.
- Parameters:
- pi: torch.Tensor
- Transport plan. 
- transpose: bool
- Whether to transpose the transport plan. 
- data_const: tuple
- Tuple containing the following elements:
- X_sqr: torch.Tensor of size n x n
- Squared distances between source points. 
 
- Y_sqr: torch.Tensor of size m x m
- Squared distances between target points. 
 
- X: torch.Tensor of size n x n
- Distances between source points. 

- Y: torch.Tensor of size m x m
- Distances between target points. 
 
- D: torch.Tensor of size n x m
- Kernel matrix between the source and target training features. 
 
 
 
- tuple_weights: tuple
- Tuple containing the following elements:
- ws: torch.Tensor of size n
- Measures assigned to source points. 
 
- wt: torch.Tensor of size m
- Measures assigned to target points. 
 
- ws_dot_wt: torch.Tensor of size n x m
- Outer product of ws and wt. 
 
 
 
- hyperparams: tuple
- Tuple containing the following elements:
- rho_s: float
- Regularization parameter for source marginal constraints. 
 
- rho_t: float
- Regularization parameter for target marginal constraints. 
 
- eps: float
- Regularization parameter for joint regularization. 
 
- alpha: float
- Weight of the Gromov-Wasserstein loss. 
 
- reg_mode: string
- Regularization mode (for instance “joint”, the default in solve). 
 
- divergence: string
- Divergence to use (for instance “kl”, the default in solve). 
 
 
 
 
- Returns:
- cost: torch.Tensor
- Local biconvex cost matrix of size (n, m). 
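The Wasserstein and Gromov-Wasserstein parts of such a local cost can be sketched with the standard linearization of the Gromov-Wasserstein term: for a fixed plan pi, the cost of pairing source point k with target point q is the sum over i, j of (X[i][k] - Y[j][q])^2 * pi[i][j], which expands into terms built from the squared distances, the marginals of pi, and X pi Y^T. The sketch below assumes X (n x n) and Y (m x m) are symmetric distance matrices, weights the feature term as (1 - alpha) * D, and leaves out the marginal and regularization contributions and the transpose flag; local_cost_sketch, matmul and transpose are hypothetical helpers, not fugw functions.

```python
from typing import List

Matrix = List[List[float]]


def matmul(A: Matrix, B: Matrix) -> Matrix:
    """Naive matrix product."""
    return [[sum(A[i][k] * B[k][j] for k in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def transpose(A: Matrix) -> Matrix:
    """Matrix transpose."""
    return [list(col) for col in zip(*A)]


def local_cost_sketch(pi: Matrix, X: Matrix, Y: Matrix, D: Matrix, alpha: float) -> Matrix:
    """Hypothetical (n, m) cost: (1 - alpha) * D + alpha * linearized GW cost.

    Assumes X (n x n) and Y (m x m) are symmetric distance matrices.
    """
    n, m = len(X), len(Y)
    pi1 = [sum(row) for row in pi]                             # row marginal
    pi2 = [sum(pi[i][j] for i in range(n)) for j in range(m)]  # column marginal
    # A[k] = sum_i X[k][i]^2 * pi1[i] and B[q] = sum_j Y[q][j]^2 * pi2[j]
    A = [sum(X[k][i] ** 2 * pi1[i] for i in range(n)) for k in range(n)]
    B = [sum(Y[q][j] ** 2 * pi2[j] for j in range(m)) for q in range(m)]
    cross = matmul(matmul(X, pi), transpose(Y))  # X pi Y^T, size n x m
    return [[(1 - alpha) * D[k][q] + alpha * (A[k] + B[q] - 2 * cross[k][q])
             for q in range(m)] for k in range(n)]


X = [[0.0, 1.0, 2.0], [1.0, 0.0, 1.0], [2.0, 1.0, 0.0]]
Y = [[0.0, 3.0], [3.0, 0.0]]
pi = [[0.1, 0.2], [0.3, 0.05], [0.15, 0.2]]
D = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]

# Direct (unexpanded) GW local cost, for comparison
brute = [[sum((X[i][k] - Y[j][q]) ** 2 * pi[i][j] for i in range(3) for j in range(2))
          for q in range(2)] for k in range(3)]
cost = local_cost_sketch(pi, X, Y, D, alpha=1.0)
max_gap = max(abs(cost[k][q] - brute[k][q]) for k in range(3) for q in range(2))
print(max_gap < 1e-12)  # True
```

The brute-force comparison confirms that the expanded form matches the direct sum for symmetric inputs, while avoiding its O(n^2 m^2) cost.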
 
 
 - solve(alpha=0.5, rho_s=1, rho_t=1, eps=0.01, reg_mode='joint', divergence='kl', F=None, Ds=None, Dt=None, F_val=None, Ds_val=None, Dt_val=None, ws=None, wt=None, init_plan=None, init_duals=None, solver='sinkhorn', callback_bcd=None, verbose=False)
- Run BCD iterations.
- Parameters:
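- alpha: float, optional
- Weight of the Gromov-Wasserstein loss. Default: 0.5
- rho_s: float, optional
- Regularization parameter for source marginal constraints. Default: 1
- rho_t: float, optional
- Regularization parameter for target marginal constraints. Default: 1
- eps: float, optional
- Regularization parameter for joint regularization. Default: 0.01
- reg_mode: string, optional
- Regularization mode. Default: 'joint'
- divergence: string, optional
- Divergence to use. Default: 'kl'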
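- F: matrix of size n x m
- Kernel matrix between the source and target training features. 
- Ds: matrix of size n x n
- Geometry (pairwise distance) matrix of the source points. 
- Dt: matrix of size m x m
- Geometry (pairwise distance) matrix of the target points. 
- F_val: matrix of size n x m, or None
- Kernel matrix between the source and target validation features. 
- Ds_val: matrix of size n x n, or None
- Source geometry matrix used to evaluate the loss on the validation set. 
- Dt_val: matrix of size m x m, or None
- Target geometry matrix used to evaluate the loss on the validation set. 
- ws: ndarray(n), or None
- Measures assigned to source points. 
- wt: ndarray(m), or None
- Measures assigned to target points. 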
- init_plan: matrix of size n x m, or None
- Initial transport plan (coupling). 
- init_duals: tuple or None
- Initial dual variables for the coupling. 
- solver: “sinkhorn”, “mm” or “ibpp”
- Inner solver used to update the transport plans at each BCD step. Default: “sinkhorn” 
- callback_bcd: callable or None
- Callback function called at the end of each BCD step. It is called with the following argument: locals, a dictionary containing all local variables. 
 
- verbose: bool, optional, defaults to False
- Log solving process. 
 
- Returns:
- res: dict
- Dictionary containing the following keys:
- pi: torch.Tensor of size n x m
- Transport plan between source and target samples. 
- gamma: torch.Tensor of size n x m
- Second transport plan, updated alternately with pi during BCD. 
- duals_pi: tuple of torch.Tensor of size
- Duals of pi 
- duals_gamma: tuple of torch.Tensor of size
- Duals of gamma 
- loss: dict of lists
- Dictionary containing the loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None. 
- loss_val: dict of lists
- Dictionary containing the loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated on the validation set. 
- loss_steps: list
- BCD steps at the end of which the FUGW loss was evaluated.
- loss_times: list
- Elapsed time at the end of each BCD step for which the FUGW loss was evaluated.
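The inner solver chosen with solver is driven by nits_uot, tol_uot and eval_uot. To illustrate what a "sinkhorn" inner step might look like, here is a minimal sketch of the classical scaling iterations for entropic unbalanced OT with KL marginal penalties: given a cost matrix C (for example a local biconvex cost), weights ws and wt, and rho_s, rho_t and eps, it alternately updates two scaling vectors u and v and returns diag(u) K diag(v). It is not fugw's implementation; the function name, the use of ws x wt as the reference measure in the kernel, and the stopping test on the change of u are assumptions.

```python
import math
from typing import List, Optional

Matrix = List[List[float]]


def unbalanced_sinkhorn(C: Matrix, ws: List[float], wt: List[float],
                        rho_s: float, rho_t: float, eps: float,
                        nits_uot: Optional[int] = 1000,
                        tol_uot: Optional[float] = None,
                        eval_uot: int = 10) -> Matrix:
    """Hypothetical entropic UOT solver with KL marginal penalties."""
    if nits_uot is None and tol_uot is None:
        raise ValueError("nits_uot=None requires tol_uot")
    n, m = len(ws), len(wt)
    # Gibbs kernel relative to the reference measure ws x wt
    K = [[math.exp(-C[i][j] / eps) * ws[i] * wt[j] for j in range(m)] for i in range(n)]
    # Exponents coming from the KL marginal penalties
    fs, ft = rho_s / (rho_s + eps), rho_t / (rho_t + eps)
    u, v = [1.0] * n, [1.0] * m
    it = 0
    while nits_uot is None or it < nits_uot:
        it += 1
        u_prev = u
        Kv = [sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        u = [(ws[i] / Kv[i]) ** fs for i in range(n)]
        Ktu = [sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        v = [(wt[j] / Ktu[j]) ** ft for j in range(m)]
        # Every eval_uot iterations, stop if u barely changed (assumed criterion)
        if tol_uot is not None and it % eval_uot == 0:
            if max(abs(a - b) for a, b in zip(u, u_prev)) < tol_uot:
                break
    # Transport plan diag(u) K diag(v)
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


C = [[0.0, 1.0], [1.0, 0.0]]
P = unbalanced_sinkhorn(C, [0.5, 0.5], [0.5, 0.5],
                        rho_s=10.0, rho_t=10.0, eps=0.1, tol_uot=1e-9)
print([round(sum(row), 3) for row in P])  # row marginals slightly below ws
```

Larger rho_s and rho_t push the marginals closer to ws and wt but make the iterations converge more slowly, since the exponents rho / (rho + eps) approach 1.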