Monitor memory usage at each BCD iteration with callbacks
In this example, we use a callback function to monitor memory usage at each iteration of the block-coordinate descent (BCD) algorithm. This can be useful to detect memory leaks, or to check that the memory usage is not too high for a given device.
import re
from functools import partial
import matplotlib.pyplot as plt
import torch
from fugw.mappings import FUGW
from fugw.utils import _init_mock_distribution
from rich.console import Console
from rich.table import Table
n_points_source = 50
n_points_target = 40
n_features_train = 2
n_features_test = 2
Let us generate random training data for the source and target distributions.
/usr/local/lib/python3.8/site-packages/torch/distributions/ UserWarning:
Singular sample detected.
Features and geometries should be normalized before calling the solver.
source_features_train_normalized = source_features_train / torch.linalg.norm(
source_features_train, dim=1
).reshape(-1, 1)
target_features_train_normalized = target_features_train / torch.linalg.norm(
target_features_train, dim=1
).reshape(-1, 1)
source_geometry_normalized = source_geometry / source_geometry.max()
target_geometry_normalized = target_geometry / target_geometry.max()
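As a quick sanity check, the sketch below applies the same two normalizations to random stand-in data (the example's own data comes from `_init_mock_distribution`). It assumes features are laid out as (n_features, n_points), so dividing by the norm along `dim=1` gives every feature row unit L2 norm. It also compares the manual version with the built-in `torch.nn.functional.normalize`, which additionally guards against division by zero. Dividing the geometry by its maximum makes its largest entry exactly 1.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

# Stand-in data (assumed layout): 2 features over 50 points,
# and a 50 x 50 matrix of pairwise distances as the geometry.
features = torch.rand(2, 50)
points = torch.rand(50, 3)
geometry = torch.cdist(points, points)

# Manual row-wise L2 normalization, as in the example
features_manual = features / torch.linalg.norm(features, dim=1).reshape(-1, 1)
# Built-in equivalent (clamps tiny norms to avoid dividing by zero)
features_builtin = F.normalize(features, p=2, dim=1)
# Scale the geometry so that its largest entry equals 1
geometry_normalized = geometry / geometry.max()

print(torch.allclose(features_manual, features_builtin))  # True
print(torch.linalg.norm(features_builtin, dim=1))  # each entry is numerically 1
print(geometry_normalized.max().item())  # 1.0
```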
We define a function that checks memory usage at each iteration of the BCD algorithm, as well as utility functions to print relevant information. In short, fugw callback functions receive locals(), a dictionary of all local variables in the solver's current scope. This gives us access to the tensors used by the BCD algorithm. In particular, we keep only the tensors that live on our device of interest and compute the memory each one occupies.
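def is_sparse(t):
    # True for any sparse layout (torch.sparse_coo, torch.sparse_csr, ...)
    return "sparse" in str(t.layout)
def str_size(s):
    # Turn "torch.Size([50, 40])" into "50, 40" (empty string for scalars)
    m = re.match(r"torch\.Size\(\[(.*)\]\)", str(s))
    return m.group(1)
def str_mem(mem, unit="KB"):
    # Format a number of bytes in KB or MB
    if unit == "KB":
        return f"{mem / 1024:,.3f} KB"
    elif unit == "MB":
        return f"{mem / 1024 ** 2:,.3f} MB"
    raise ValueError(f"Unknown unit: {unit}")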
def check_memory_usage(local_vars, device=torch.device("cpu")):
    # fugw passes locals() from the BCD loop: a dict mapping variable names to values
    console = Console()
    # Keep only the tensors that live on the device of interest, sorted by name
    variables = []
    for name, value in local_vars.items():
        if torch.is_tensor(value) and value.device == device:
            variables.append([name, value])
    variables = sorted(variables, key=lambda x: x[0].lower())
    table = Table()
    table.add_column("Variable")
    table.add_column("Size", justify="right")
    table.add_column("Numel", justify="right")
    table.add_column("Memory allocated", justify="right")
    memory_allocated = 0
    for name, value in variables:
        s = value.size()
        if is_sparse(value):
            # Sparse tensors: count stored values only (index storage is ignored)
            numel = value._nnz()
        else:
            numel = value.numel()
        var_memory_allocated = numel * value.element_size()
        table.add_row(
            name, str_size(s), f"{numel:,}", str_mem(var_memory_allocated)
        )
        memory_allocated += var_memory_allocated
    table.add_row(
        f"Total ({len(variables)})", "", "", str_mem(memory_allocated)
    )
    console.print(table)
    if device.type == "cuda":
        # torch.cuda.memory_cached() is a deprecated alias of memory_reserved()
        memory_lines = [
            ("Memory allocated", str_mem(torch.cuda.memory_allocated(device))),
            ("Memory reserved", str_mem(torch.cuda.memory_reserved(device))),
        ]
        console.print("\n".join(f"{label}\t{mem}" for label, mem in memory_lines))
    # Record the total (in bytes) so that it can be plotted after fitting
    memory_at_bcd_step.append(memory_allocated)
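To see how a callback that receives locals() behaves without running fugw, here is a self-contained sketch. `toy_bcd` and `record_memory` are hypothetical names: `toy_bcd` is a stand-in loop, not fugw's solver, that updates two 50 x 40 float32 tensors and hands `locals()` to the callback at every iteration, as the paragraph above describes for fugw. As in the example, `functools.partial` binds the extra `device` argument so the loop only has to pass the dictionary of locals.

```python
from functools import partial

import torch


def toy_bcd(n_iter, callback_bcd=None):
    # Hypothetical solver loop: a stand-in, NOT fugw's implementation
    pi = torch.full((50, 40), 1.0 / 2000)
    for step in range(n_iter):
        pi_prev = pi.clone()
        pi = 0.5 * (pi + pi_prev)
        if callback_bcd is not None:
            # Hand every local variable of the loop to the callback
            callback_bcd(locals())
    return pi


recorded_bytes = []


def record_memory(local_vars, device=torch.device("cpu")):
    # Keep only tensors on the device of interest and sum their sizes in bytes
    tensors = [
        v for v in local_vars.values()
        if torch.is_tensor(v) and v.device == device
    ]
    recorded_bytes.append(sum(t.numel() * t.element_size() for t in tensors))


toy_bcd(3, callback_bcd=partial(record_memory, device=torch.device("cpu")))
print(recorded_bytes)  # [16000, 16000, 16000]
```

Each call sees `pi` and `pi_prev` (2,000 float32 values, so 8,000 bytes each), while non-tensor locals such as `step` and `n_iter` are filtered out. A constant list is exactly the pattern the real callback is meant to reveal.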
Let us define the optimization problem to solve.
Now, we fit a transport plan between the source and target distributions using the Sinkhorn solver. Our callback function is called at each iteration of the BCD algorithm.
device = torch.device("cpu")
memory_at_bcd_step = []
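_ = mapping.fit(
    source_features_train_normalized,
    target_features_train_normalized,
    source_geometry=source_geometry_normalized,
    target_geometry=target_geometry_normalized,
    solver="sinkhorn",
    solver_params={"nits_bcd": 5, "nits_uot": 100},
    callback_bcd=partial(check_memory_usage, device=device),
    verbose=True,
)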
[21:41:40] Validation data for feature maps is not provided. Using training data instead.
Validation data for anatomical kernels is not provided. Using training data instead.
BCD step 1/5 FUGW loss: 15.431012153625488
Validation loss: 15.431012153625488
┃ Variable ┃ Size ┃ Numel ┃ Memory allocated ┃
│ cost_gamma │ 50, 40 │ 2,000 │ 7.812 KB │
│ cost_pi │ 50, 40 │ 2,000 │ 7.812 KB │
│ Ds │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_sqr │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_sqr_val │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_val │ 50, 50 │ 2,500 │ 9.766 KB │
│ Dt │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_sqr │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_sqr_val │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_val │ 40, 40 │ 1,600 │ 6.250 KB │
│ F │ 50, 40 │ 2,000 │ 7.812 KB │
│ gamma │ 50, 40 │ 2,000 │ 7.812 KB │
│ init_plan │ 50, 40 │ 2,000 │ 7.812 KB │
│ mass_gamma │ │ 1 │ 0.004 KB │
│ mass_pi │ │ 1 │ 0.004 KB │
│ new_eps │ │ 1 │ 0.004 KB │
│ new_rho_s │ │ 1 │ 0.004 KB │
│ new_rho_t │ │ 1 │ 0.004 KB │
│ pi │ 50, 40 │ 2,000 │ 7.812 KB │
│ pi_prev │ 50, 40 │ 2,000 │ 7.812 KB │
│ ws │ 50 │ 50 │ 0.195 KB │
│ ws_dot_wt │ 50, 40 │ 2,000 │ 7.812 KB │
│ wt │ 40 │ 40 │ 0.156 KB │
│ Total (23) │ │ │ 126.934 KB │
BCD step 2/5 FUGW loss: 9.585478782653809
Validation loss: 9.585478782653809
┃ Variable ┃ Size ┃ Numel ┃ Memory allocated ┃
│ cost_gamma │ 50, 40 │ 2,000 │ 7.812 KB │
│ cost_pi │ 50, 40 │ 2,000 │ 7.812 KB │
│ Ds │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_sqr │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_sqr_val │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_val │ 50, 50 │ 2,500 │ 9.766 KB │
│ Dt │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_sqr │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_sqr_val │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_val │ 40, 40 │ 1,600 │ 6.250 KB │
│ F │ 50, 40 │ 2,000 │ 7.812 KB │
│ gamma │ 50, 40 │ 2,000 │ 7.812 KB │
│ init_plan │ 50, 40 │ 2,000 │ 7.812 KB │
│ mass_gamma │ │ 1 │ 0.004 KB │
│ mass_pi │ │ 1 │ 0.004 KB │
│ new_eps │ │ 1 │ 0.004 KB │
│ new_rho_s │ │ 1 │ 0.004 KB │
│ new_rho_t │ │ 1 │ 0.004 KB │
│ pi │ 50, 40 │ 2,000 │ 7.812 KB │
│ pi_prev │ 50, 40 │ 2,000 │ 7.812 KB │
│ ws │ 50 │ 50 │ 0.195 KB │
│ ws_dot_wt │ 50, 40 │ 2,000 │ 7.812 KB │
│ wt │ 40 │ 40 │ 0.156 KB │
│ Total (23) │ │ │ 126.934 KB │
BCD step 3/5 FUGW loss: 5.271676540374756
Validation loss: 5.271676540374756
┃ Variable ┃ Size ┃ Numel ┃ Memory allocated ┃
│ cost_gamma │ 50, 40 │ 2,000 │ 7.812 KB │
│ cost_pi │ 50, 40 │ 2,000 │ 7.812 KB │
│ Ds │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_sqr │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_sqr_val │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_val │ 50, 50 │ 2,500 │ 9.766 KB │
│ Dt │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_sqr │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_sqr_val │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_val │ 40, 40 │ 1,600 │ 6.250 KB │
│ F │ 50, 40 │ 2,000 │ 7.812 KB │
│ gamma │ 50, 40 │ 2,000 │ 7.812 KB │
│ init_plan │ 50, 40 │ 2,000 │ 7.812 KB │
│ mass_gamma │ │ 1 │ 0.004 KB │
│ mass_pi │ │ 1 │ 0.004 KB │
│ new_eps │ │ 1 │ 0.004 KB │
│ new_rho_s │ │ 1 │ 0.004 KB │
│ new_rho_t │ │ 1 │ 0.004 KB │
│ pi │ 50, 40 │ 2,000 │ 7.812 KB │
│ pi_prev │ 50, 40 │ 2,000 │ 7.812 KB │
│ ws │ 50 │ 50 │ 0.195 KB │
│ ws_dot_wt │ 50, 40 │ 2,000 │ 7.812 KB │
│ wt │ 40 │ 40 │ 0.156 KB │
│ Total (23) │ │ │ 126.934 KB │
BCD step 4/5 FUGW loss: 6.342617511749268
Validation loss: 6.342617511749268
┃ Variable ┃ Size ┃ Numel ┃ Memory allocated ┃
│ cost_gamma │ 50, 40 │ 2,000 │ 7.812 KB │
│ cost_pi │ 50, 40 │ 2,000 │ 7.812 KB │
│ Ds │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_sqr │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_sqr_val │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_val │ 50, 50 │ 2,500 │ 9.766 KB │
│ Dt │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_sqr │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_sqr_val │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_val │ 40, 40 │ 1,600 │ 6.250 KB │
│ F │ 50, 40 │ 2,000 │ 7.812 KB │
│ gamma │ 50, 40 │ 2,000 │ 7.812 KB │
│ init_plan │ 50, 40 │ 2,000 │ 7.812 KB │
│ mass_gamma │ │ 1 │ 0.004 KB │
│ mass_pi │ │ 1 │ 0.004 KB │
│ new_eps │ │ 1 │ 0.004 KB │
│ new_rho_s │ │ 1 │ 0.004 KB │
│ new_rho_t │ │ 1 │ 0.004 KB │
│ pi │ 50, 40 │ 2,000 │ 7.812 KB │
│ pi_prev │ 50, 40 │ 2,000 │ 7.812 KB │
│ ws │ 50 │ 50 │ 0.195 KB │
│ ws_dot_wt │ 50, 40 │ 2,000 │ 7.812 KB │
│ wt │ 40 │ 40 │ 0.156 KB │
│ Total (23) │ │ │ 126.934 KB │
BCD step 5/5 FUGW loss: 7.891664028167725
Validation loss: 7.891664028167725
┃ Variable ┃ Size ┃ Numel ┃ Memory allocated ┃
│ cost_gamma │ 50, 40 │ 2,000 │ 7.812 KB │
│ cost_pi │ 50, 40 │ 2,000 │ 7.812 KB │
│ Ds │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_sqr │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_sqr_val │ 50, 50 │ 2,500 │ 9.766 KB │
│ Ds_val │ 50, 50 │ 2,500 │ 9.766 KB │
│ Dt │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_sqr │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_sqr_val │ 40, 40 │ 1,600 │ 6.250 KB │
│ Dt_val │ 40, 40 │ 1,600 │ 6.250 KB │
│ F │ 50, 40 │ 2,000 │ 7.812 KB │
│ gamma │ 50, 40 │ 2,000 │ 7.812 KB │
│ init_plan │ 50, 40 │ 2,000 │ 7.812 KB │
│ mass_gamma │ │ 1 │ 0.004 KB │
│ mass_pi │ │ 1 │ 0.004 KB │
│ new_eps │ │ 1 │ 0.004 KB │
│ new_rho_s │ │ 1 │ 0.004 KB │
│ new_rho_t │ │ 1 │ 0.004 KB │
│ pi │ 50, 40 │ 2,000 │ 7.812 KB │
│ pi_prev │ 50, 40 │ 2,000 │ 7.812 KB │
│ ws │ 50 │ 50 │ 0.195 KB │
│ ws_dot_wt │ 50, 40 │ 2,000 │ 7.812 KB │
│ wt │ 40 │ 40 │ 0.156 KB │
│ Total (23) │ │ │ 126.934 KB │
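The total printed at every step can be checked by hand. The per-variable sizes imply 4 bytes per element (for example, 2,000 elements → 8,000 bytes → 7.812 KB), so this sketch assumes float32 throughout and sums element counts over the shapes copied from the table, with `()` for the scalar tensors.

```python
import math

# Shapes copied from the table printed at each BCD step; () is a scalar tensor
shapes = {
    "cost_gamma": (50, 40), "cost_pi": (50, 40), "F": (50, 40),
    "gamma": (50, 40), "init_plan": (50, 40), "pi": (50, 40),
    "pi_prev": (50, 40), "ws_dot_wt": (50, 40),
    "Ds": (50, 50), "Ds_sqr": (50, 50), "Ds_sqr_val": (50, 50), "Ds_val": (50, 50),
    "Dt": (40, 40), "Dt_sqr": (40, 40), "Dt_sqr_val": (40, 40), "Dt_val": (40, 40),
    "mass_gamma": (), "mass_pi": (), "new_eps": (), "new_rho_s": (), "new_rho_t": (),
    "ws": (50,), "wt": (40,),
}
BYTES_PER_ELEMENT = 4  # float32, assumed from the reported sizes

# math.prod(()) == 1, so scalars count as a single element
total_bytes = sum(math.prod(s) * BYTES_PER_ELEMENT for s in shapes.values())
print(len(shapes), total_bytes, f"{total_bytes / 1024:,.3f} KB")
# 23 129980 126.934 KB
```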
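In this example, the memory held by tensors on the CPU stays constant across all five BCD iterations (126.934 KB for 23 tensors), so the tensors visible in fugw's BCD loop show no sign of a memory leak.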
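fig = plt.figure(figsize=(5, 5))
fig.suptitle("Memory usage at each BCD iteration")
ax = fig.add_subplot()
# Assumes the callback appended one value in bytes per BCD iteration; plotted in KB
ax.plot(range(1, len(memory_at_bcd_step) + 1), [m / 1024 for m in memory_at_bcd_step])
ax.set_xlabel("BCD iteration")
ax.set_ylabel("Memory allocated (KB)")
plt.show()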
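Beyond reading the plot, a leak can also be flagged programmatically. This minimal sketch assumes the per-iteration totals are a list of byte counts, like `memory_at_bcd_step` above. `has_memory_growth` and its `tolerance_bytes` parameter are hypothetical names, not part of fugw.

```python
def has_memory_growth(memory_per_step, tolerance_bytes=0):
    """Return True if any later step exceeds the first one by more than the tolerance."""
    if not memory_per_step:
        return False
    baseline = memory_per_step[0]
    return any(m - baseline > tolerance_bytes for m in memory_per_step[1:])


print(has_memory_growth([129980] * 5))  # False: constant, as in this example
print(has_memory_growth([129980, 137980, 145980]))  # True: +8,000 bytes per step
```

Comparing against the first iteration rather than requiring strictly increasing values also catches a footprint that jumps once and then stays high. The tolerance absorbs small, expected fluctuations.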

Total running time of the script: (0 minutes 1.630 seconds)
Estimated memory usage: 61 MB