fugw.mappings.FUGWSparse
- class fugw.mappings.FUGWSparse(alpha=0.5, rho=1, eps=0.01, reg_mode='joint', divergence='kl')
- Class computing sparse transport plans.
- Methods:
- fit([source_features, target_features, ...]): Compute transport plan between source and target distributions using feature maps and geometries.
- inverse_transform(target_features[, device]): Transport target feature maps using fitted OT plan.
- transform(source_features[, id_reg, device]): Transport source feature maps using fitted OT plan.
- __init__(alpha=0.5, rho=1, eps=0.01, reg_mode='joint', divergence='kl')
- Initialize the FUGW problem.
- Parameters:
- alpha: float, optional, defaults to 0.5
- Value in ]0, 1[, interpolates the relative importance of the Wasserstein and the Gromov-Wasserstein losses in the FUGW loss (see equation) 
- rho: float or tuple of 2 floats, optional, defaults to 1
- Value in ]0, +inf[, controls the relative importance of the marginal constraints. High values force all the mass of each point to be transported; low values allow some of the mass to be lost.
- eps: float, optional, defaults to 1e-2
- Value in ]0, +inf[, controls the relative importance of the entropy loss 
- reg_mode: “joint” or “independent”, optional, defaults to “joint”
- “joint”: use an unbalanced-GW-like regularization term
- “independent”: use an unbalanced-W-like regularization term
- divergence: “kl” or “l2”, optional, defaults to “kl”
- Divergence used for the marginal constraints and the regularization term.
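To make the roles of alpha, rho, divergence and reg_mode concrete, here is a small dense NumPy sketch of the unregularized part of a FUGW-style loss. It is a sketch under stated assumptions, not fugw's implementation: it assumes the loss combines (1 - alpha) times a Wasserstein term and alpha times a Gromov-Wasserstein term, adds rho times a divergence between the plan's marginals and the node weights, uses the generalized KL and half the squared L2 distance for the two divergence options, models “joint” as the tensorized, unbalanced-GW-style penalty comparing a ⊗ a with b ⊗ b, and leaves out the eps-weighted entropic term. All function names are hypothetical.

```python
import numpy as np


def wasserstein_term(F_s, F_t, P):
    # F_s: (n_features, n), F_t: (n_features, m), P: (n, m) transport plan.
    # C[i, j] = squared Euclidean distance between feature vectors of nodes i and j.
    C = ((F_s[:, :, None] - F_t[:, None, :]) ** 2).sum(axis=0)
    return float((C * P).sum())


def gromov_wasserstein_term(D_s, D_t, P):
    # D_s: (n, n) and D_t: (m, m) anatomical distance matrices.
    # diff[i, j, k, l] = D_s[i, k] - D_t[j, l]; dense, so only for tiny examples.
    diff = D_s[:, None, :, None] - D_t[None, :, None, :]
    return float(np.einsum("ijkl,ij,kl->", diff ** 2, P, P))


def kl_div(a, b):
    # Generalized KL divergence between non-negative vectors of any total mass.
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    pos = a > 0  # entries with a == 0 contribute 0 * log(0) = 0
    return float((a[pos] * np.log(a[pos] / b[pos])).sum() - a.sum() + b.sum())


def l2_div(a, b):
    # Half the squared Euclidean distance (assumed form of the "l2" option).
    a, b = np.asarray(a, dtype=float), np.asarray(b, dtype=float)
    return float(0.5 * ((a - b) ** 2).sum())


def marginal_penalty(a, b, divergence="kl", reg_mode="joint"):
    div = kl_div if divergence == "kl" else l2_div
    if reg_mode == "joint":
        # Tensorized penalty (assumption): compare a ⊗ a with b ⊗ b.
        return div(np.outer(a, a).ravel(), np.outer(b, b).ravel())
    return div(a, b)  # "independent": compare the marginals directly


def fugw_like_loss(F_s, F_t, D_s, D_t, P, w_s, w_t,
                   alpha=0.5, rho=1.0, divergence="kl", reg_mode="joint"):
    transport = ((1 - alpha) * wasserstein_term(F_s, F_t, P)
                 + alpha * gromov_wasserstein_term(D_s, D_t, P))
    marginals = (marginal_penalty(P.sum(axis=1), w_s, divergence, reg_mode)
                 + marginal_penalty(P.sum(axis=0), w_t, divergence, reg_mode))
    return transport + rho * marginals


# Identical source and target, diagonal plan matching the weights.
F = np.array([[0.0, 1.0], [2.0, 0.5]])
D = np.array([[0.0, 1.0], [1.0, 0.0]])
w = np.full(2, 0.5)
P = np.diag(w)
print(fugw_like_loss(F, F, D, D, P, w, w))
```

With identical source and target data and a diagonal plan that carries exactly the reference weights, every term vanishes, which makes a convenient sanity check; moving alpha from 0 to 1 shifts the weight from feature mismatch to geometry mismatch.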
 
- Attributes:
- alpha: float
- rho: float
- eps: float
- reg_mode: “joint” or “independent”
- pi: numpy.ndarray or None
- Transport plan computed with .fit()
- loss: dict of lists
- Dictionary containing the training loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Values are float or None. 
- loss_val: dict of lists
- Dictionary containing the validation loss and its unweighted components for each step of the block-coordinate-descent for which the FUGW loss was evaluated. Values are float or None. 
- loss_steps: list
- BCD steps at the end of which the FUGW loss was evaluated 
- loss_times: list
- Elapsed time at the end of each BCD step for which the FUGW loss was evaluated. 
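Because loss, loss_steps and loss_times are all recorded only at the BCD steps where the FUGW loss was evaluated, their entries line up index by index. The hypothetical helper below relies on that alignment to find the evaluated step with the lowest total loss. The dictionary passed to it is made-up data standing in for a fitted instance's attributes; in practice you would pass mapping.loss, mapping.loss_steps and mapping.loss_times, where mapping is whatever name you gave the fitted object.

```python
def best_evaluated_step(loss, loss_steps, loss_times, key="total"):
    """Return (step, elapsed_time, value) for the lowest non-None loss[key].

    Assumes loss[key][i] was recorded at BCD step loss_steps[i], after
    loss_times[i] of elapsed time, as described in the attribute list.
    """
    rows = [
        (value, step, elapsed)
        for value, step, elapsed in zip(loss[key], loss_steps, loss_times)
        if value is not None  # skip entries that were not evaluated
    ]
    if not rows:
        return None
    value, step, elapsed = min(rows)
    return step, elapsed, value


# Made-up values standing in for a fitted mapping's attributes.
loss = {"total": [1.20, 0.80, 0.85], "wasserstein": [0.9, 0.6, 0.6]}
print(best_evaluated_step(loss, loss_steps=[0, 5, 10], loss_times=[0.4, 2.1, 3.9]))
```

The same helper works on loss_val, or on any single component such as "wasserstein", by changing the key argument.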
 
 
- fit(source_features=None, target_features=None, source_geometry_embedding=None, target_geometry_embedding=None, source_features_val=None, target_features_val=None, source_geometry_embedding_val=None, target_geometry_embedding_val=None, source_weights=None, target_weights=None, init_plan=None, init_duals=None, solver='mm', solver_params={}, callback_bcd=None, device='auto', verbose=False)
- Compute transport plan between source and target distributions using feature maps and geometries. In our case, feature maps are fMRI contrast maps, and geometries are those of the cortical meshes of the individuals under study.
- Parameters:
- source_features: ndarray(n_features, n), optional
- Feature maps for the source subject. n_features is the number of contrast maps; it should be the same for source and target data. n is the number of nodes on the source graph; it can differ from m, the number of nodes on the target graph. This array should be normalized, otherwise you will run into computational errors.
- target_features: ndarray(n_features, m), optional
- Feature maps for target subject. This array should be normalized, otherwise you will run into computational errors. 
- source_geometry_embedding: ndarray(n, k), optional
- Embedding X such that norm(X_i - X_j) approximates the anatomical distance between vertices i and j of the source mesh. This array should be normalized, otherwise you will run into computational errors.
- target_geometry_embedding: ndarray(m, k), optional
- Embedding X such that norm(X_i - X_j) approximates the anatomical distance between vertices i and j of the target mesh. This array should be normalized, otherwise you will run into computational errors.
- source_features_val: ndarray(n_features, n) or None
- Feature maps for source subject used for validation. If None, source_features will be used instead. 
- target_features_val: ndarray(n_features, m) or None
- Feature maps for target subject used for validation. If None, target_features will be used instead. 
- source_geometry_embedding_val: ndarray(n, k) or None
- Embedding of the source mesh used for validation, such that norm(X_i - X_j) approximates the anatomical distance between vertices i and j. If None, source_geometry_embedding will be used instead.
- target_geometry_embedding_val: ndarray(m, k) or None
- Embedding of the target mesh used for validation, such that norm(X_i - X_j) approximates the anatomical distance between vertices i and j. If None, target_geometry_embedding will be used instead.
- source_weights: ndarray(n) or None
- Distribution weights of source nodes. Should sum to 1. If None, each node’s weight will be set to 1 / n. 
- target_weights: ndarray(m) or None
- Distribution weights of target nodes. Should sum to 1. If None, each node’s weight will be set to 1 / m. 
- init_plan: torch.sparse COO or CSR matrix or None
- Torch sparse matrix whose sparsity mask will be that of the transport plan computed by this solver. 
- solver: “mm” or “ibpp”, optional, defaults to “mm”
- Solver to use. 
- solver_params: fugw.solvers.utils.BaseSolver params
- Parameters given to the solver. 
- callback_bcd: callable or None
- Callback function called at the end of each BCD step. It will be called with the following argument:
- locals (dictionary containing all local variables)
 
- device: “auto” or torch.device, optional, defaults to “auto”
- If “auto”: use the first available GPU if there is one, CPU otherwise.
- verbose: bool, optional, defaults to False
- Log solving process. 
 
- Returns:
- self: FUGWSparse class object
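The parameter descriptions above ask for normalized features and geometry embeddings without defining the normalization. The sketch below shows one plausible preparation step under that ambiguity: it z-scores each feature map across nodes, rescales an embedding by a single scalar so that every pairwise distance falls within [0, 2], and builds the uniform 1 / n weights that fit uses when source_weights or target_weights is None. The helper names are hypothetical and the normalization choices are assumptions, not requirements documented by fugw.

```python
import numpy as np


def zscore_features(features):
    # features: (n_features, n). Standardize each contrast map across nodes.
    features = np.asarray(features, dtype=float)
    mean = features.mean(axis=1, keepdims=True)
    std = features.std(axis=1, keepdims=True)
    std[std == 0] = 1.0  # constant maps become all zeros instead of NaN
    return (features - mean) / std


def rescale_embedding(embedding):
    # embedding: (n, k). Divide by the largest distance to the centroid, so every
    # point lies in the unit ball and, by the triangle inequality, every pairwise
    # distance norm(X_i - X_j) is at most 2. Costs O(n * k), not O(n^2).
    embedding = np.asarray(embedding, dtype=float)
    radius = np.linalg.norm(embedding - embedding.mean(axis=0), axis=1).max()
    return embedding / radius if radius > 0 else embedding


def uniform_weights(n):
    # Matches the documented default when weights are None: 1 / n per node.
    return np.full(n, 1.0 / n)


source_features = zscore_features(np.array([[1.0, 2.0, 3.0], [5.0, 5.0, 5.0]]))
source_embedding = rescale_embedding(np.array([[0.0, 0.0], [3.0, 4.0]]))
print(source_features, source_embedding, uniform_weights(3))
```

Dividing the whole embedding by one scalar scales all anatomical distances by the same factor, so their relative values are preserved; only the overall magnitude changes.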
 
 
- inverse_transform(target_features, device='auto')
- Transport target feature maps using the fitted OT plan. Uses GPUs if available.
- Parameters:
- target_features: ndarray(n_samples, m) or ndarray(m)
- Contrast map for target subject 
- device: “auto” or torch.device, optional, defaults to “auto”
- If “auto”: use the first available GPU if there is one, CPU otherwise.
 
- Returns:
- transported_data: ndarray(n_samples, n) or ndarray(n)
- Contrast map transported in source subject’s space
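The direction of inverse_transform can be sketched with a dense NumPy plan pi of shape (n, m), where n is the number of source nodes and m the number of target nodes; FUGWSparse stores a sparse plan, and a dense array is used here only for readability. The sketch assumes a barycentric projection: each source node receives the pi-weighted average of the target values, normalized by that node's mass pi.sum(axis=1). fugw's actual normalization and its handling of sparse plans and devices may differ, and the function name is hypothetical.

```python
import numpy as np


def inverse_transport(pi, target_features):
    # pi: (n, m) transport plan; target_features: (n_samples, m) or (m,).
    pi = np.asarray(pi, dtype=float)
    y = np.atleast_2d(np.asarray(target_features, dtype=float))
    row_mass = pi.sum(axis=1)  # mass leaving each source node, shape (n,)
    row_mass = np.where(row_mass > 0, row_mass, 1.0)  # empty rows give 0, not NaN
    out = (y @ pi.T) / row_mass  # (n_samples, n)
    return out[0] if np.ndim(target_features) == 1 else out


pi = np.array([[0.0, 0.5], [0.5, 0.0]])  # swaps the two nodes
print(inverse_transport(pi, [10.0, 20.0]))
```

A 1-D input of length m comes back as a 1-D array of length n, and an (n_samples, m) input as (n_samples, n), mirroring the shapes listed above.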
 
 
- transform(source_features, id_reg=0, device='auto')
- Transport source feature maps using the fitted OT plan. Uses GPUs if available.
- Parameters:
- source_features: ndarray(n_samples, n) or ndarray(n)
- Contrast map for source subject 
- id_reg: float, in the [0, 1] interval, defaults to 0
- If source/target share the same geometry, interpolate the transport plan with the identity using the provided coefficient. A value of 1 (resp. 0) will rely solely on the identity (resp. the transport plan). 
- device: “auto” or torch.device, optional, defaults to “auto”
- If “auto”: use the first available GPU if there is one, CPU otherwise.
 
- Returns:
- transported_data: ndarray(n_samples, m) or ndarray(m)
- Contrast map transported in target subject’s space
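A matching sketch of transform, again with a dense NumPy plan pi of shape (n, m) and under the same barycentric-projection assumption: each target node receives the pi-weighted average of source values, normalized by pi.sum(axis=0). For id_reg it assumes the identity coupling is scaled to carry the same total mass as pi before forming (1 - id_reg) * pi + id_reg * identity; this scaling is a guess, not documented behavior, and the function name is hypothetical. Whatever the scaling, id_reg=1 returns the input unchanged and id_reg=0 uses the plan alone, as described above.

```python
import numpy as np


def forward_transport(pi, source_features, id_reg=0.0):
    # pi: (n, m) transport plan; source_features: (n_samples, n) or (n,).
    pi = np.asarray(pi, dtype=float)
    if id_reg > 0:
        if pi.shape[0] != pi.shape[1]:
            raise ValueError("id_reg > 0 requires source and target to share a geometry (n == m)")
        # Assumption: identity coupling carrying the same total mass as pi.
        identity = np.eye(pi.shape[0]) * pi.sum() / pi.shape[0]
        pi = (1 - id_reg) * pi + id_reg * identity
    x = np.atleast_2d(np.asarray(source_features, dtype=float))
    col_mass = pi.sum(axis=0)  # mass arriving at each target node, shape (m,)
    col_mass = np.where(col_mass > 0, col_mass, 1.0)  # empty columns give 0, not NaN
    out = (x @ pi) / col_mass  # (n_samples, m)
    return out[0] if np.ndim(source_features) == 1 else out


pi = np.array([[0.0, 0.5], [0.5, 0.0]])  # swaps the two nodes
print(forward_transport(pi, [1.0, 2.0]))              # plan only
print(forward_transport(pi, [1.0, 2.0], id_reg=0.5))  # half plan, half identity
```

With this swapping plan, id_reg=0.5 blends the swapped and original values equally, so both target nodes receive the midpoint 1.5.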