fugw.mappings.FUGWSparse
- class fugw.mappings.FUGWSparse(alpha=0.5, rho=1, eps=0.01, reg_mode='joint', divergence='kl')
Class computing sparse transport plans
Methods
- fit([source_features, target_features, ...]): Compute transport plan between source and target distributions using feature maps and geometries.
- inverse_transform(target_features[, device]): Transport target feature maps using fitted OT plan.
- transform(source_features[, id_reg, device]): Transport source feature maps using fitted OT plan.
- __init__(alpha=0.5, rho=1, eps=0.01, reg_mode='joint', divergence='kl')
Initialize the FUGW problem.
- Parameters:
- alpha: float, optional, defaults to 0.5
Value in ]0, 1[, interpolates the relative importance of the Wasserstein and the Gromov-Wasserstein losses in the FUGW loss (see equation)
- rho: float or tuple of 2 floats, optional, defaults to 1
Value in ]0, +inf[, controls the relative importance of the marginal constraints. High values force the mass of each point to be transported; low values allow for some mass loss.
- eps: float, optional, defaults to 1e-2
Value in ]0, +inf[, controls the relative importance of the entropy loss
- reg_mode: “joint” or “independent”, optional, defaults to “joint”
- “joint”: use an unbalanced-GW-like regularisation term
- “independent”: use an unbalanced-W-like regularisation term
- divergence: string, optional
What divergence to use for the marginal constraints and regularization. Can be “kl” or “l2”. Defaults to “kl”.
- Attributes:
- alpha: float
- rho: float
- eps: float
- reg_mode: “joint” or “independent”
- pi: numpy.ndarray or None
Transport plan computed with .fit().
- loss: dict of lists
Dictionary containing the training loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None.
- loss_val: dict of lists
Dictionary containing the validation loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated. Values are float or None.
- loss_steps: list
BCD steps at the end of which the FUGW loss was evaluated
- loss_times: list
Elapsed time at the end of each BCD step for which the FUGW loss was evaluated.
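The value ranges documented for the __init__ parameters can be checked before any solving starts. The helper below is a minimal sketch, not part of fugw: check_fugw_hyperparams is a hypothetical name, and it assumes that a tuple rho holds one value per marginal (source, then target) while a single float is shared by both.

```python
import math


def check_fugw_hyperparams(alpha=0.5, rho=1, eps=0.01, reg_mode="joint", divergence="kl"):
    """Hypothetical helper (not part of fugw): validate FUGWSparse.__init__
    arguments against the documented ranges and return (rho_source, rho_target)."""
    if not 0 < alpha < 1:
        raise ValueError("alpha must lie in ]0, 1[")
    # Assumption: a tuple gives one rho per marginal (source, target);
    # a single float is shared by both marginals.
    if isinstance(rho, tuple):
        if len(rho) != 2:
            raise ValueError("rho must be a float or a tuple of 2 floats")
        rho_source, rho_target = rho
    else:
        rho_source = rho_target = rho
    for value in (rho_source, rho_target):
        if not 0 < value < math.inf:
            raise ValueError("rho must lie in ]0, +inf[")
    if not 0 < eps < math.inf:
        raise ValueError("eps must lie in ]0, +inf[")
    if reg_mode not in ("joint", "independent"):
        raise ValueError("reg_mode must be 'joint' or 'independent'")
    if divergence not in ("kl", "l2"):
        raise ValueError("divergence must be 'kl' or 'l2'")
    return rho_source, rho_target


print(check_fugw_hyperparams(rho=(1.0, 10.0)))  # (1.0, 10.0)
```

Note that the open intervals matter: alpha=0 or alpha=1 would drop one of the two losses entirely, which the documented range excludes.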
- fit(source_features=None, target_features=None, source_geometry_embedding=None, target_geometry_embedding=None, source_features_val=None, target_features_val=None, source_geometry_embedding_val=None, target_geometry_embedding_val=None, source_weights=None, target_weights=None, init_plan=None, init_duals=None, solver='mm', solver_params={}, callback_bcd=None, device='auto', verbose=False)
Compute transport plan between source and target distributions using feature maps and geometries. In our case, feature maps are fMRI contrast maps, and geometries are that of the cortical meshes of individuals under study.
- Parameters:
- source_features: ndarray(n_features, n), optional
Feature maps for source subject. n_features is the number of contrast maps; it should be the same for source and target data. n is the number of nodes on the source graph; it can differ from m, the number of nodes on the target graph. This array should be normalized, otherwise you will run into computational errors.
- target_features: ndarray(n_features, m), optional
Feature maps for target subject. This array should be normalized, otherwise you will run into computational errors.
- source_geometry_embedding: ndarray(n, k), optional
Embedding X such that norm(X_i - X_j) approximates the anatomical distance between vertices i and j of the source mesh. This array should be normalized, otherwise you will run into computational errors.
- target_geometry_embedding: ndarray(m, k), optional
Embedding X such that norm(X_i - X_j) approximates the anatomical distance between vertices i and j of the target mesh. This array should be normalized, otherwise you will run into computational errors.
- source_features_val: ndarray(n_features, n) or None
Feature maps for source subject used for validation. If None, source_features will be used instead.
- target_features_val: ndarray(n_features, m) or None
Feature maps for target subject used for validation. If None, target_features will be used instead.
- source_geometry_embedding_val: ndarray(n, k) or None
Embedding of the source mesh used for validation, following the same conventions as source_geometry_embedding. If None, source_geometry_embedding will be used instead.
- target_geometry_embedding_val: ndarray(m, k) or None
Embedding of the target mesh used for validation, following the same conventions as target_geometry_embedding. If None, target_geometry_embedding will be used instead.
- source_weights: ndarray(n) or None
Distribution weights of source nodes. Should sum to 1. If None, each node’s weight will be set to 1 / n.
- target_weights: ndarray(m) or None
Distribution weights of target nodes. Should sum to 1. If None, each node’s weight will be set to 1 / m.
- init_plan: torch.sparse COO or CSR matrix or None
Torch sparse matrix whose sparsity mask will be that of the transport plan computed by this solver.
- solver: “mm” or “ibpp”
Solver to use.
- solver_params: fugw.solvers.utils.BaseSolver params
Parameters given to the solver.
- callback_bcd: callable or None
Callback function called at the end of each BCD step. It will be called with the following argument: locals (dictionary containing all local variables).
- device: “auto” or torch.device
If “auto”: use the first available GPU if there is one, the CPU otherwise.
- verbose: bool, optional, defaults to False
Log solving process.
- Returns:
- self: FUGWSparse class object
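fit expects normalized features, node weights summing to 1, geometry embeddings whose pairwise norms approximate anatomical distances, and optionally an init_plan whose sparsity mask restricts the plan. The sketch below prepares such inputs on a toy mesh with numpy only. It rests on several assumptions that the reference above does not state: each feature map is scaled to unit L2 norm, the embedding comes from classical multidimensional scaling (one possible technique, swapped in here) of a distance matrix divided by its maximum, and the mask keeps vertex pairs closer than a radius. All helper names are hypothetical. The resulting (2, nnz) indices could then be turned into a torch sparse tensor, for instance with torch.sparse_coo_tensor, to serve as init_plan; which values fugw expects on that mask is not checked here.

```python
import numpy as np


def normalize_features(features):
    """Scale each feature map (row of an (n_features, n) array) to unit L2 norm."""
    return features / np.linalg.norm(features, axis=1, keepdims=True)


def uniform_weights(n_nodes):
    """Uniform node weights summing to 1, like the documented default of 1 / n."""
    return np.full(n_nodes, 1.0 / n_nodes)


def classical_mds_embedding(distances, k):
    """Classical MDS: (n, k) array X with norm(X_i - X_j) ~ distances[i, j]."""
    n = distances.shape[0]
    centering = np.eye(n) - np.ones((n, n)) / n
    gram = -0.5 * centering @ (distances ** 2) @ centering
    eigvals, eigvecs = np.linalg.eigh(gram)
    top = np.argsort(eigvals)[::-1][:k]  # indices of the k largest eigenvalues
    return eigvecs[:, top] * np.sqrt(np.clip(eigvals[top], 0.0, None))


def radius_mask_indices(source_coords, target_coords, radius):
    """(2, nnz) COO indices of (source, target) vertex pairs at most `radius` apart."""
    diff = source_coords[:, None, :] - target_coords[None, :, :]
    rows, cols = np.nonzero(np.linalg.norm(diff, axis=-1) <= radius)
    return np.stack([rows, cols])


# Toy mesh: the 4 corners of a unit square (hypothetical data).
coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
distances = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)

rng = np.random.default_rng(0)
features = normalize_features(rng.normal(size=(3, 4)))  # 3 contrast maps, 4 vertices
weights = uniform_weights(4)
# Assumption: geometry is normalized by the largest pairwise distance.
embedding = classical_mds_embedding(distances / distances.max(), k=2)
mask = radius_mask_indices(coords, coords, radius=1.0)
print(mask.shape)  # (2, 12)
```

On a real cortical mesh the distance matrix would hold geodesic distances, which a low-dimensional embedding only approximates; on this flat toy square the embedding reproduces the normalized distances exactly.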
- inverse_transform(target_features, device='auto')
Transport target feature maps using fitted OT plan. Use GPUs if available.
- Parameters:
- target_features: ndarray(n_samples, m) or ndarray(m)
Contrast map for target subject
- device: “auto” or torch.device
If “auto”: use the first available GPU if there is one, the CPU otherwise.
- Returns:
- transported_data: ndarray(n_samples, n) or ndarray(n)
Contrast map transported in source subject’s space
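To make the shapes above concrete, here is a minimal numpy sketch of how target data can be carried back to the source vertices through a fitted plan pi of shape (n, m). It assumes a barycentric projection (each source vertex receives the average of target values weighted by its row of pi, divided by the row sum); this is an assumption about the computation, not a transcription of fugw's code, and it uses a dense array whereas FUGWSparse stores a sparse plan. barycentric_inverse_transform is a hypothetical name.

```python
import numpy as np


def barycentric_inverse_transform(pi, target_features):
    """Map target data of shape (n_samples, m) or (m,) onto the n source vertices.

    Each source vertex i receives the average of target values weighted by row i of pi.
    """
    row_mass = pi.sum(axis=1) + 1e-16  # tiny offset guards against empty rows
    return (target_features @ pi.T) / row_mass


# Hypothetical dense plan between n = 3 source and m = 2 target vertices.
pi = np.array([[0.5, 0.0], [0.0, 0.25], [0.0, 0.25]])
print(barycentric_inverse_transform(pi, np.array([2.0, 4.0])))
```

Because the matrix product acts on the last axis, the same function handles a single map of shape (m,) and a stack of shape (n_samples, m).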
- transform(source_features, id_reg=0, device='auto')
Transport source feature maps using fitted OT plan. Use GPUs if available.
- Parameters:
- source_features: ndarray(n_samples, n) or ndarray(n)
Contrast map for source subject
- id_reg: float, in the [0, 1] interval, defaults to 0
If source/target share the same geometry, interpolate the transport plan with the identity using the provided coefficient. A value of 1 (resp. 0) will rely solely on the identity (resp. the transport plan).
- device: “auto” or torch.device
If “auto”: use the first available GPU if there is one, the CPU otherwise.
- Returns:
- transported_data: ndarray(n_samples, m) or ndarray(m)
Contrast map transported in target subject’s space
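The id_reg option can be illustrated with a similar sketch. Assuming transform is the barycentric projection of source data through a plan pi of shape (n, m), normalized by column sums, then interpolating that column-normalized plan with the identity is equivalent to mixing the transported map with the untouched source map. The function name is hypothetical, the plan is dense for readability, and the shape check mirrors the documented requirement that source and target share the same geometry.

```python
import numpy as np


def barycentric_transform(pi, source_features, id_reg=0.0):
    """Map source data of shape (n_samples, n) or (n,) onto the m target vertices."""
    col_mass = pi.sum(axis=0) + 1e-16  # tiny offset guards against empty columns
    transported = (source_features @ pi) / col_mass
    if id_reg > 0:
        if pi.shape[0] != pi.shape[1]:
            raise ValueError("id_reg > 0 requires source and target to share the same vertices")
        # Mixing the column-normalized plan with the identity is equivalent to
        # mixing the transported map with the untouched source map.
        transported = (1 - id_reg) * transported + id_reg * source_features
    return transported


pi = np.array([[0.25, 0.25], [0.0, 0.5]])  # hypothetical 2 x 2 plan
x = np.array([2.0, 6.0])
print(barycentric_transform(pi, x))              # approx. [2.0, 4.667]
print(barycentric_transform(pi, x, id_reg=0.5))  # approx. [2.0, 5.333]
```

With id_reg=1 the source map is returned unchanged, and with id_reg=0 only the transport plan is used, matching the description of the parameter.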