fugw.mappings.FUGW

class fugw.mappings.FUGW(alpha=0.5, rho=1, eps=0.01, reg_mode='joint', divergence='kl')

Class computing dense transport plans.

Methods

fit([source_features, target_features, ...])

Compute transport plan between source and target distributions using feature maps and geometries.

inverse_transform(target_features[, device])

Transport target feature maps using fitted OT plan.

transform(source_features[, id_reg, device])

Transport source feature maps using fitted OT plan.

__init__(alpha=0.5, rho=1, eps=0.01, reg_mode='joint', divergence='kl')

Initialize the FUGW problem.

Parameters:
alpha: float, optional, defaults to 0.5

Value in ]0, 1[ that sets the relative importance of the Wasserstein and Gromov-Wasserstein terms in the FUGW loss (see equation).

rho: float or tuple of 2 floats, optional, defaults to 1

Value in ]0, +inf[ that controls the relative importance of the marginal constraints. High values force the mass of each point to be transported; low values allow some mass to be lost.

eps: float, optional, defaults to 1e-2

Value in ]0, +inf[ that controls the relative importance of the entropy term.

reg_mode: “joint” or “independent”, optional, defaults to “joint”

“joint”: use an unbalanced-GW-like regularisation term
“independent”: use an unbalanced-W-like regularisation term

divergence: string, optional

Divergence to use for the marginal constraints and regularization. Can be “kl” or “l2”. Defaults to “kl”.
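The following sketch shows what alpha does. It evaluates the two transport terms for a fixed plan P and interpolates between them. It assumes three things: a squared Euclidean cost between feature vectors, a squared difference between geometry entries, and a (1 - alpha) / alpha weighting. The helper names are hypothetical and not part of fugw. The full FUGW loss also includes the rho-weighted marginal terms and the eps-weighted entropy term, which are left out here.

```python
import numpy as np


def wasserstein_term(source_features, target_features, plan):
    """sum_{i,j} ||F_s[:, i] - F_t[:, j]||^2 * P[i, j].

    Features follow the (n_features, n) / (n_features, m) layout used by fit().
    """
    # cost[i, j]: squared Euclidean distance between source node i and target node j
    diff = source_features[:, :, None] - target_features[:, None, :]
    cost = (diff ** 2).sum(axis=0)  # shape (n, m)
    return float((cost * plan).sum())


def gromov_wasserstein_term(source_geometry, target_geometry, plan):
    """sum_{i,j,k,l} (D_s[i, k] - D_t[j, l])^2 * P[i, j] * P[k, l].

    Builds an (n, m, n, m) tensor, so it is only meant for tiny examples.
    """
    diff = source_geometry[:, None, :, None] - target_geometry[None, :, None, :]
    return float(np.einsum("ijkl,ij,kl->", diff ** 2, plan, plan))


def interpolated_loss(alpha, w_term, gw_term):
    # Assumed weighting: alpha -> 0 favours features, alpha -> 1 favours geometry
    return (1 - alpha) * w_term + alpha * gw_term


rng = np.random.default_rng(0)
n = 4
features = rng.normal(size=(3, n))
geometry = rng.random((n, n))
geometry = (geometry + geometry.T) / 2  # symmetric, like a distance matrix
plan = np.eye(n) / n  # match each node to itself, uniform mass

w = wasserstein_term(features, features, plan)
gw = gromov_wasserstein_term(geometry, geometry, plan)
print(interpolated_loss(0.5, w, gw))  # 0.0: identical inputs matched node-to-node
```

Shifting the target features away from the source ones increases only the Wasserstein term. Distorting the target geometry increases only the Gromov-Wasserstein term. alpha decides which of the two mismatches the plan is pushed to reduce.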

Attributes:
alpha: float
rho: float or tuple of 2 floats
eps: float
reg_mode: “joint” or “independent”
pi: numpy.ndarray or None

Transport plan computed with .fit()

loss: dict of lists

Dictionary containing the training loss and its unweighted components for each step of the block-coordinate-descent (BCD) for which the FUGW loss was evaluated. Keys are: “wasserstein”, “gromov_wasserstein”, “marginal_constraint_dim1”, “marginal_constraint_dim2”, “regularization”, “total”. Each value is a list whose entries are float or None.

loss_val: dict of lists

Dictionary containing the validation loss and its unweighted components for each BCD step for which the FUGW loss was evaluated. Each value is a list whose entries are float or None.

loss_steps: list

BCD steps at the end of which the FUGW loss was evaluated.

loss_times: list

Elapsed time at the end of each BCD step for which the FUGW loss was evaluated.
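The two marginal_constraint entries compare the marginals of the plan with the node weights. This is a minimal sketch that rests on several assumptions:

- “kl” is the generalized KL divergence.
- “l2” is half the squared Euclidean distance.
- dim1 refers to the source marginal and dim2 to the target marginal.

It does not try to reproduce how reg_mode changes these penalties. The function names are hypothetical.

```python
import numpy as np


def generalized_kl(p, q):
    """sum(p * log(p / q) - p + q) for non-negative p and positive q.

    Entries with p == 0 contribute q, since p * log(p / q) -> 0 as p -> 0.
    """
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    log_term = np.zeros_like(p)
    pos = p > 0
    log_term[pos] = p[pos] * np.log(p[pos] / q[pos])
    return float((log_term - p + q).sum())


def half_squared_l2(p, q):
    # Assumed form of the "l2" divergence, for illustration only
    diff = np.asarray(p, dtype=float) - np.asarray(q, dtype=float)
    return 0.5 * float((diff ** 2).sum())


def marginal_terms(plan, source_weights, target_weights, divergence="kl"):
    """Return (dim1, dim2): unweighted penalties on the plan's row and column sums."""
    div = generalized_kl if divergence == "kl" else half_squared_l2
    dim1 = div(plan.sum(axis=1), source_weights)  # mass leaving each source node
    dim2 = div(plan.sum(axis=0), target_weights)  # mass reaching each target node
    return dim1, dim2


source_weights = np.full(4, 0.25)
target_weights = np.full(2, 0.5)
balanced = np.outer(source_weights, target_weights)  # marginals match the weights
print(marginal_terms(balanced, source_weights, target_weights))  # (0.0, 0.0)
print(marginal_terms(0.5 * balanced, source_weights, target_weights))  # both positive: half the mass is lost
```

In this simplified picture, the total loss would scale both terms by rho. A large rho then makes any loss of mass expensive.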

fit(source_features=None, target_features=None, source_geometry=None, target_geometry=None, source_features_val=None, target_features_val=None, source_geometry_val=None, target_geometry_val=None, source_weights=None, target_weights=None, init_plan=None, init_duals=None, solver='mm', solver_params={}, callback_bcd=None, device='auto', verbose=False)

Compute the transport plan between source and target distributions using feature maps and geometries. In our case, feature maps are fMRI contrast maps, and geometries are those of the cortical meshes of the individuals under study.

Parameters:
source_features: ndarray(n_features, n), optional

Feature maps for the source subject. n_features is the number of contrast maps and must be the same for source and target data. n is the number of nodes on the source graph; it can differ from m, the number of nodes on the target graph. This array should be normalized; otherwise you will run into computational errors.

target_features: ndarray(n_features, m), optional

Feature maps for the target subject. This array should be normalized; otherwise you will run into computational errors.

source_geometry: ndarray(n, n)

Kernel matrix of anatomical distances between nodes of the source mesh. This array should be normalized; otherwise you will run into computational errors.

target_geometry: ndarray(m, m)

Kernel matrix of anatomical distances between nodes of the target mesh. This array should be normalized; otherwise you will run into computational errors.

source_features_val: ndarray(n_features, n) or None

Feature maps for source subject used for validation. If None, source_features will be used instead.

target_features_val: ndarray(n_features, m) or None

Feature maps for target subject used for validation. If None, target_features will be used instead.

source_geometry_val: ndarray(n, n) or None

Kernel matrix of anatomical distances between nodes of source mesh used for validation. If None, source_geometry will be used instead.

target_geometry_val: ndarray(m, m) or None

Kernel matrix of anatomical distances between nodes of target mesh used for validation. If None, target_geometry will be used instead.

source_weights: ndarray(n) or None

Distribution weights of source nodes. Should sum to 1. If None, each node’s weight will be set to 1 / n.

target_weights: ndarray(m) or None

Distribution weights of target nodes. Should sum to 1. If None, each node’s weight will be set to 1 / m.

init_plan: ndarray(n, m) or None

Transport plan to use at initialisation. If None, an entropic initialization will be used.

init_duals: tuple of [ndarray(n), ndarray(m)] or None

Dual potentials to use at initialisation.

solver: “sinkhorn” or “mm” or “ibpp”

Solver to use. Defaults to “mm”.

solver_params: fugw.solvers.utils.BaseSolver params

Parameters given to the solver.

callback_bcd: callable or None

Callback function called at the end of each BCD step. It will be called with the following arguments:

  • locals (dictionary containing all local variables)

device: “auto” or torch.device

If “auto”: use the first available GPU if there is one, CPU otherwise.

verbose: bool, optional, defaults to False

Log solving process.

Returns:
self: FUGW class object
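The parameter descriptions above require normalized features and geometries but do not say which normalization to use. The sketch below assumes one common choice: scale each feature map (row) to unit L2 norm, and divide each geometry by its largest entry. It also builds the documented default of uniform weights. Euclidean distances between random 3D points stand in for cortical mesh distances. All helper names are hypothetical.

```python
import numpy as np


def normalize_features(features):
    """Scale each feature map (row of an (n_features, n) array) to unit L2 norm."""
    norms = np.linalg.norm(features, axis=1, keepdims=True)
    return features / norms


def normalize_geometry(geometry):
    """Divide a non-negative distance matrix by its maximum so entries lie in [0, 1]."""
    return geometry / np.max(geometry)


def uniform_weights(n_nodes):
    """Documented default for source_weights / target_weights: 1 / n_nodes per node."""
    return np.full(n_nodes, 1.0 / n_nodes)


def pairwise_distances(coords):
    """Euclidean distance matrix for an (n_nodes, 3) array of vertex coordinates."""
    diff = coords[:, None, :] - coords[None, :, :]
    return np.sqrt((diff ** 2).sum(axis=-1))


rng = np.random.default_rng(0)
n, m, n_features = 6, 5, 3  # source and target may have different node counts
source_features = normalize_features(rng.normal(size=(n_features, n)))
target_features = normalize_features(rng.normal(size=(n_features, m)))
source_geometry = normalize_geometry(pairwise_distances(rng.random((n, 3))))
target_geometry = normalize_geometry(pairwise_distances(rng.random((m, 3))))
source_weights = uniform_weights(n)
target_weights = uniform_weights(m)
```

These arrays have the shapes that fit() expects for source_features, target_features, source_geometry, target_geometry, source_weights and target_weights.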
inverse_transform(target_features, device='auto')

Transport target feature maps into the source subject’s space using the fitted OT plan. Uses GPUs if available.

Parameters:
target_features: ndarray(n_samples, m) or ndarray(m)

Contrast map(s) for the target subject.

device: “auto” or torch.device

If “auto”: use the first available GPU if there is one, CPU otherwise.

Returns:
transported_data: ndarray(n_samples, n) or ndarray(n)

Contrast map(s) transported into the source subject’s space.
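One way to picture inverse_transform is as a barycentric projection. Each source node receives the average of the target values, weighted by its row of the plan. This NumPy sketch is an assumption about the mechanics, not fugw's torch implementation. inverse_transport is a hypothetical name.

```python
import numpy as np


def inverse_transport(plan, target_features):
    """Map (n_samples, m) or (m,) target data to (n_samples, n) or (n,) source data."""
    target_features = np.asarray(target_features, dtype=float)
    row_mass = plan.sum(axis=1)  # mass leaving each source node, shape (n,)
    # (..., m) @ (m, n) -> (..., n), then divide by each source node's mass
    return target_features @ plan.T / row_mass


# Plan sending source node i to target node (i + 1) % 4, each carrying mass 1/4
plan = np.roll(np.eye(4), 1, axis=1) / 4
print(inverse_transport(plan, [10.0, 20.0, 30.0, 40.0]))  # [20. 30. 40. 10.]
```

The sketch divides by the row mass, so the result does not depend on how much total mass the plan carries. Only the relative weights within each row matter.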

transform(source_features, id_reg=0, device='auto')

Transport source feature maps into the target subject’s space using the fitted OT plan. Uses GPUs if available.

Parameters:
source_features: ndarray(n_samples, n) or ndarray(n)

Contrast map(s) for the source subject.

id_reg: float, in the [0, 1] interval, defaults to 0

If source/target share the same geometry, interpolate the transport plan with the identity using the provided coefficient. A value of 1 (resp. 0) will rely solely on the identity (resp. the transport plan).

device: “auto” or torch.device

If “auto”: use the first available GPU if there is one, CPU otherwise.

Returns:
transported_data: ndarray(n_samples, m) or ndarray(m)

Contrast map(s) transported into the target subject’s space.
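In the same spirit, transform can be pictured as the reverse barycentric projection, with id_reg blending the result with the untransported map. Blending the outputs is equivalent to interpolating the column-normalized plan with the identity matrix, because both matrices have columns that sum to 1. This NumPy sketch is an assumption, not fugw's implementation. transport is a hypothetical name.

```python
import numpy as np


def transport(plan, source_features, id_reg=0.0):
    """Map (n_samples, n) or (n,) source data to (n_samples, m) or (m,) target data."""
    if not 0.0 <= id_reg <= 1.0:
        raise ValueError("id_reg must lie in [0, 1]")
    source_features = np.asarray(source_features, dtype=float)
    col_mass = plan.sum(axis=0)  # mass reaching each target node, shape (m,)
    transported = source_features @ plan / col_mass  # (..., n) @ (n, m) -> (..., m)
    if id_reg > 0:
        # The identity only makes sense when source and target share nodes (n == m)
        if plan.shape[0] != plan.shape[1]:
            raise ValueError("id_reg > 0 requires a square plan")
        transported = id_reg * source_features + (1 - id_reg) * transported
    return transported


plan = np.roll(np.eye(4), 1, axis=1) / 4  # source i -> target (i + 1) % 4
x = np.array([10.0, 20.0, 30.0, 40.0])
print(transport(plan, x))  # [40. 10. 20. 30.]
print(transport(plan, x, id_reg=0.5))  # [25. 15. 25. 35.]
```

With id_reg=1 the map comes back unchanged. With id_reg=0 only the plan is used, which matches the two extremes described above.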