Skip to content

visualize_heatmap 🧠

Functions to visualize analyzer (XAI) heatmaps.

Quick Start

In particular, the plot_heatmap function visualizes the relevance maps produced by the LRP analyzer from xai4mri.model.interpreter.

from xai4mri.model.interpreter import analyze_model
from xai4mri.visualizer import plot_heatmap

# Analyze model
analyzer_obj = analyze_model(model=model, ipt=mri_image, ...)

# Visualize heatmap / relevance map
analyzer_fig = plot_heatmap(ipt=mri_image, analyser_obj=analyzer_obj, ...)

Author: Simon M. Hofmann
Years: 2023-2024

gregoire_black_fire_red 🧠

gregoire_black_fire_red(analyser_obj: ndarray) -> ndarray

Apply a color scheme to the analyzer object.

Parameters:

Name Type Description Default
analyser_obj ndarray

XAI analyzer object (e.g., LRP relevance map).

required

Returns:

Type Description
ndarray

Colorized relevance map.

Source code in src/xai4mri/visualizer/visualize_heatmap.py
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
def gregoire_black_fire_red(analyser_obj: np.ndarray) -> np.ndarray:
    """
    Apply the black-fire-red color scheme to the analyzer object.

    The values are first scaled by their maximum absolute value, such that they lie within [-1, 1].
    Positive values then run from black over red and yellow to white,
    negative values from black over blue and cyan to white.
    The three color channels (R, G, B) are appended as a new last axis.

    An all-zero (or empty) input is not rescaled, and is hence returned as a black image instead of `NaN`s.
    Non-floating input (e.g., integer arrays) is converted to `float64` before scaling.
    The input array itself is never modified.

    :param analyser_obj: XAI analyzer object (e.g., `LRP` relevance map).
    :return: Colorized relevance map of shape `analyser_obj.shape + (3,)`.
    """
    x = np.array(analyser_obj, copy=True)
    if not np.issubdtype(x.dtype, np.inexact):
        # In-place true division is not defined for integer / boolean arrays
        x = x.astype(np.float64)

    # Scale to [-1, 1]; skip for empty or all-zero maps to avoid a 0/0 division (which would yield NaNs).
    # A NaN maximum compares unequal to zero, such that NaNs in the input still propagate as before.
    abs_max = np.max(np.abs(x)) if x.size else 0
    if abs_max != 0:
        x /= abs_max

    hrp = np.clip(x - 0.00, a_min=0, a_max=0.25) / 0.25  # all pos. values(+) above 0 get red, above .25 full red(=1.)
    hgp = np.clip(x - 0.25, a_min=0, a_max=0.25) / 0.25  # all above .25 get green, above .50 full green
    hbp = np.clip(x - 0.50, a_min=0, a_max=0.50) / 0.50  # all above .50 get blue until full blue at 1. (mix 2 white)

    hbn = np.clip(-x - 0.00, a_min=0, a_max=0.25) / 0.25  # all neg. values(-) below 0 get blue ...
    hgn = np.clip(-x - 0.25, a_min=0, a_max=0.25) / 0.25  # ... green ....
    hrn = np.clip(-x - 0.50, a_min=0, a_max=0.50) / 0.50  # ... red ... mixes to white (1.,1.,1.)

    # Stack the channels as a new trailing axis: (..., 3) = (R, G, B)
    return np.stack([hrp + hrn, hgp + hgn, hbp + hbn], axis=-1)

list_supported_cmaps 🧠

list_supported_cmaps()

Return a list of supported color maps for heatmap plotting.

Source code in src/xai4mri/visualizer/visualize_heatmap.py
155
156
157
158
def list_supported_cmaps():
    """
    Print and return the names of all color maps that are supported for heatmap plotting.

    The custom color maps are listed first, followed by the entries of `CMAPS`.
    Each name is printed on its own line.

    :return: List of the supported color map names.
    """
    supported = list(custom_maps.keys()) + CMAPS
    print("\n".join(map(str, supported)))
    return supported

plot_heatmap 🧠

plot_heatmap(
    ipt: ndarray,
    analyser_obj: ndarray,
    cmap_name: str = "black-fire-red",
    mode: str = "triplet",
    fig_name: str = "Heatmap",
    **kwargs
) -> Figure

Plot an XAI-based analyzer object over the model input in the form of a heatmap.

How to use

from xai4mri.model.interpreter import analyze_model
from xai4mri.visualizer import plot_heatmap

# Analyze model
analyzer_obj = analyze_model(model=model, ipt=mri_image, ...)

# Visualize heatmap / relevance map
analyzer_fig = plot_heatmap(ipt=mri_image, analyser_obj=analyzer_obj, ...)

Parameters:

Name Type Description Default
ipt ndarray

Model input image.

required
analyser_obj ndarray

Analyzer object (relevance map) that is computed by the model interpreter (e.g., LRP). Both the input image and the analyzer object must have the same shape.

required
cmap_name str

Name of color-map (cmap) to be applied.

'black-fire-red'
mode str

"triplet": Plot three slices, one along each axis. "all": Plot all slices along one axis; slices without any information are dropped unless plot_empty=True is set in kwargs.

'triplet'
fig_name str

Name of the figure.

'Heatmap'
kwargs

Additional kwargs: "c_intensifier", "clip_q", "min_sym_clip", "true_scale", "plot_empty", "axis", "every", "crosshair", "gamma". And, kwargs for plot_mid_slice() and slice_through() from xai4mri.visualizer.visualize_mri.

{}

Returns:

Type Description
Figure

plt.Figure object of the heatmap plot.

Source code in src/xai4mri/visualizer/visualize_heatmap.py
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
def plot_heatmap(
    ipt: np.ndarray,
    analyser_obj: np.ndarray,
    cmap_name: str = "black-fire-red",
    mode: str = "triplet",
    fig_name: str = "Heatmap",
    **kwargs,
) -> plt.Figure:
    """
    Render the relevance map of an XAI analyzer as a heatmap on top of the model input.

    !!! example "How to use"
        ```python
        from xai4mri.model.interpreter import analyze_model
        from xai4mri.visualizer import plot_heatmap

        # Analyze model
        analyzer_obj = analyze_model(model=model, ipt=mri_image, ...)

        # Visualize heatmap / relevance map
        analyzer_fig = plot_heatmap(ipt=mri_image, analyser_obj=analyzer_obj, ...)
        ```

    :param ipt: Model input image.
    :param analyser_obj: Relevance map computed by the model interpreter (e.g., `LRP`).
                         It must have the same shape as the input image.
    :param cmap_name: Name of the color-map (`cmap`) that is applied to the relevance map.
    :param mode: "triplet": plot three slices via `plot_mid_slice()`.
                 "all": plot slices along one axis via `slice_through()`;
                 planes that contain only zeros are removed unless `plot_empty=True` is given in `kwargs`.
                 The value is matched case-insensitively.
    :param fig_name: Name of the figure.
    :param kwargs: Options that are consumed here (with their defaults):
                   "c_intensifier" (1.0), "clip_q" (1e-2), "min_sym_clip" (True), "true_scale" (False),
                   "gamma" (0.2), "plot_empty" (False), "axis" (0), "every" (2), "crosshair" (False).
                   All remaining `kwargs` are forwarded to `plot_mid_slice()` or `slice_through()`
                   from `xai4mri.visualizer.visualize_mri`.
    :return: `plt.Figure` object of the heatmap plot.
    :raises ValueError: If `mode` is neither "triplet" nor "all".
    """
    relevance_map = analyser_obj.copy().squeeze().astype(np.float32)

    mode = mode.lower()
    if mode not in {"triplet", "all"}:
        msg = "mode must be 'triplet', or 'all'!"
        raise ValueError(msg)

    # Options for the color mapping
    color_intensifier = kwargs.pop("c_intensifier", 1.0)
    clip_quantile = kwargs.pop("clip_q", 1e-2)
    symmetric_min_clip = kwargs.pop("min_sym_clip", True)
    use_true_scale = kwargs.pop("true_scale", False)
    gamma_value = kwargs.pop("gamma", 0.2)

    # Options for the slice plots; whatever remains in kwargs is passed on to the plotting function
    keep_empty_planes = kwargs.pop("plot_empty", False)
    slice_axis = kwargs.pop("axis", 0)
    slice_step = kwargs.pop("every", 2)
    show_crosshairs = kwargs.pop("crosshair", False)

    # Colorize the relevance map
    rendered = _apply_colormap(
        analyser_obj=relevance_map,
        input_image=ipt.squeeze().astype(np.float32),
        cmap_name=cmap_name,
        c_intensifier=color_intensifier,
        clip_q=clip_quantile,
        min_sym_clip=symmetric_min_clip,
        gamma=gamma_value,
        true_scale=use_true_scale,
    )
    heatmap = rendered[0]

    # Range of the color bar: symmetric around zero, taken from the third return value in case of true scale
    if use_true_scale:
        cbar_range = (-rendered[2], rendered[2])
    else:
        cbar_range = (-1, 1)

    if mode == "triplet":
        return plot_mid_slice(
            volume=heatmap,
            fig_name=fig_name,
            cmap=_create_cmap(gregoire_black_fire_red),
            c_range="full",
            cbar=True,
            cbar_range=cbar_range,
            edges=False,
            crosshairs=show_crosshairs,
            **kwargs,
        )

    # mode == "all"
    if not keep_empty_planes:
        # Keep only those planes along the slicing axis that have at least one non-zero value
        other_axes = tuple(ax for ax in range(heatmap.ndim) if ax != slice_axis)
        has_information = np.any(heatmap != 0, axis=other_axes)
        heatmap = heatmap.compress(has_information, axis=slice_axis)

    return slice_through(
        volume=heatmap,
        every=slice_step,
        axis=slice_axis,
        fig_name=fig_name,
        edges=False,
        cmap=_create_cmap(gregoire_black_fire_red),
        c_range="full",
        crosshairs=show_crosshairs,
        **kwargs,
    )