
prune_image 🧠

The goal of pruning MRIs is to remove the background around brains / heads, i.e., to reduce the size of a 3D image.

One key objective is to find the smallest box that can enclose each brain / head in the whole dataset, i.e., the space of the 'biggest' brain.
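As a minimal sketch of this objective, assuming that background voxels are exactly 0 and that all volumes share the same orientation, the per-axis size of such a box is the maximum, over all volumes, of each brain's tight bounding-box lengths. The helpers `bbox_lengths()` and `dataset_max_box()` are hypothetical illustrations, not part of the module:

```python
import numpy as np

BG_VALUE = 0  # assumption: background voxels are exactly 0


def bbox_lengths(x3d: np.ndarray, bg: float = BG_VALUE) -> np.ndarray:
    """Side lengths (in voxels) of the tight box around all non-background voxels."""
    coords = np.argwhere(x3d != bg)  # (n_foreground_voxels, 3) indices
    return coords.max(axis=0) - coords.min(axis=0) + 1


def dataset_max_box(volumes: list[np.ndarray]) -> np.ndarray:
    """Per-axis maximum of the box lengths, i.e., a box that can hold every brain."""
    return np.max([bbox_lengths(v) for v in volumes], axis=0)


# Two synthetic "brains" of different shapes in 10 x 10 x 10 volumes
a = np.zeros((10, 10, 10))
a[2:5, 3:9, 1:4] = 1.0  # box: 3 x 6 x 3
b = np.zeros((10, 10, 10))
b[0:7, 4:6, 2:8] = 1.0  # box: 7 x 2 x 6

print(dataset_max_box([a, b]))  # [7 6 6]
```

The resulting box is not necessarily the box of one single brain: each axis takes its maximum from whichever brain is longest along that axis.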

Author: Simon M. Hofmann
Years: 2023-2024

find_brain_edges 🧠

find_brain_edges(
    x3d: ndarray, sl: bool = False
) -> tuple[slice, slice, slice] | tuple[int, ...]

Find the onset and offset of brain (or head) voxels along each axis.

This will find the tangential edges of the brain in the given 3D volume.

```text
                    /      3D-volume    /
                   +-------------------+
                   |   +_____edge____+ |
                   |   |    *****    | |
                   |   |  **     **  | |
                   |   | *  ** **  * | |
                   |   |*    ***    *| |
                   | Y |*    ***    *| |
                   |   |*    ***    *| |  Z
                   |   |*   ** **   *| | /
                   |   | * **   ** * | |/  /
                   |   |  **     **  | |  /
                   |   |    *****    |/| /
                   |   +––––– X –––––+ |/
                   +-------------------+
```

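Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x3d` | `ndarray` | 3D data. | required |
| `sl` | `bool` | If `True`, return `slice` objects; if `False` (default), return integer coordinates instead. | `False` |

Returns:

| Type | Description |
|------|-------------|
| `tuple[slice, slice, slice] \| tuple[int, ...]` | With `sl=True`: three slices, one per axis, that cut out the brain. With `sl=False`: six integer coordinates, a (lower, upper) pair per axis in the order `(il, iu, jl, ju, kl, ku)`; the upper coordinates are inclusive. |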

Source code in src/xai4mri/dataloader/prune_image.py
def find_brain_edges(x3d: np.ndarray, sl: bool = False) -> tuple[slice, slice, slice] | tuple[int, ...]:
    """
    Find the on- & the offset of brain (or head) voxels for each plane.

    This will find the tangential edges of the brain in the given 3D volume.

    ```text
                        /      3D-volume    /
                       +-------------------+
                       |   +_____edge____+ |
                       |   |    *****    | |
                       |   |  **     **  | |
                       |   | *  ** **  * | |
                       |   |*    ***    *| |
                       | Y |*    ***    *| |
                       |   |*    ***    *| |  Z
                       |   |*   ** **   *| | /
                       |   | * **   ** * | |/  /
                       |   |  **     **  | |  /
                       |   |    *****    |/| /
                       |   +––––– X –––––+ |/
                       +-------------------+
    ```

    :param x3d: 3D data.
    :param sl: Whether to return `slice`'s.
               Instead, provide coordinates (if set to `False`, default).
    :return: Tuple with six values of slices or coordinates, two values (lower, upper) per dimension / axis.
    """
    # Slice through image until first appearing brain-voxels are detected (i.e., no background)
    # Find 'lower' planes (i.e., low, left, back, respectively)
    il, jl, kl = 0, 0, 0  # initialize
    while np.all(x3d[il, :, :] == BG_VALUE):  # sagittal slice
        il += 1
    while np.all(x3d[:, jl, :] == BG_VALUE):  # transverse slice
        jl += 1
    while np.all(x3d[:, :, kl] == BG_VALUE):  # coronal/posterior/frontal
        kl += 1

    # Now, find 'upper' planes (i.e., upper, right, front, respectively)
    iu, ju, ku = np.array(x3d.shape) - 1
    while np.all(x3d[iu, :, :] == BG_VALUE):  # sagittal/longitudinal
        iu -= 1
    while np.all(x3d[:, ju, :] == BG_VALUE):  # transverse/inferior/horizontal
        ju -= 1
    while np.all(x3d[:, :, ku] == BG_VALUE):  # coronal/posterior/frontal
        ku -= 1

    if sl:  # return slices
        return slice(il, iu + 1), slice(jl, ju + 1), slice(kl, ku + 1)
    # else return coordinates
    return il, iu, jl, ju, kl, ku
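The while-loops above scan plane by plane from each side. A vectorized sketch of the same idea, which is not the module's implementation, collapses a boolean foreground mask along the two other axes and takes the first and last non-empty plane per axis; `find_edges_vectorized()` is a hypothetical name and `BG_VALUE = 0` is an assumption:

```python
import numpy as np

BG_VALUE = 0  # assumption: background value, as the module's BG_VALUE


def find_edges_vectorized(x3d: np.ndarray, sl: bool = False):
    """Inclusive (lower, upper) foreground index per axis, or the matching slices."""
    fg = x3d != BG_VALUE  # boolean foreground mask
    if not fg.any():
        raise ValueError("Volume contains only background voxels.")
    edges = []
    for axis in range(3):
        # True for every plane perpendicular to `axis` that holds at least one foreground voxel
        other_axes = tuple(a for a in range(3) if a != axis)
        nonempty = np.flatnonzero(fg.any(axis=other_axes))
        edges += [int(nonempty[0]), int(nonempty[-1])]
    il, iu, jl, ju, kl, ku = edges
    if sl:
        return slice(il, iu + 1), slice(jl, ju + 1), slice(kl, ku + 1)
    return il, iu, jl, ju, kl, ku


vol = np.zeros((8, 9, 10))
vol[2:5, 1:7, 3:4] = 1.0

print(find_edges_vectorized(vol))                       # (2, 4, 1, 6, 3, 3)
print(vol[find_edges_vectorized(vol, sl=True)].shape)   # (3, 6, 1)
```

One behavioral difference: for a volume without any brain voxels, the first while-loop above increments `il` past the last index and raises an `IndexError`, whereas this sketch raises an explicit `ValueError`.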

get_brain_axes_length 🧠

get_brain_axes_length(x3d: ndarray) -> Sequence[int]

Get the length of each brain axis (x, y, z) in voxels.

This will find the tangential edges of the brain in the given 3D volume and measure the distance between opposing edges along each axis.

```text
                    /      3D-volume    /
                   +-------------------+
                   |   +_____edge____+ |
                   |   |    *****    | |
                   |   |  **     **  | |
                   |   | *  ** **  * | |
                   |   |*    ***    *| |
                   | Y |*    ***    *| |
                   |   |*    ***    *| |  Z
                   |   |*   ** **   *| | /
                   |   | * **   ** * | |/  /
                   |   |  **     **  | |  /
                   |   |    *****    |/| /
                   |   +––––– X –––––+ |/
                   +-------------------+
```

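Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x3d` | `ndarray` | 3D volume holding a brain / mask. | required |

Returns:

| Type | Description |
|------|-------------|
| `Sequence[int]` | The brain axes lengths. |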

Source code in src/xai4mri/dataloader/prune_image.py
def get_brain_axes_length(x3d: np.ndarray) -> Sequence[int]:
    """
    Get the length of each brain axis (x,y,z) in voxels.

    This will find the tangential edges of the brain in the given 3D volume and measure their lengths.

    ```text
                        /      3D-volume    /
                       +-------------------+
                       |   +_____edge____+ |
                       |   |    *****    | |
                       |   |  **     **  | |
                       |   | *  ** **  * | |
                       |   |*    ***    *| |
                       | Y |*    ***    *| |
                       |   |*    ***    *| |  Z
                       |   |*   ** **   *| | /
                       |   | * **   ** * | |/  /
                       |   |  **     **  | |  /
                       |   |    *****    |/| /
                       |   +––––– X –––––+ |/
                       +-------------------+
    ```

    :param x3d: 3D volume holding a brain / mask.
    :return: The brain axes lengths.
    """
    il, iu, jl, ju, kl, ku = find_brain_edges(x3d)
    return [iu + 1 - il, ju + 1 - jl, ku + 1 - kl]
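Because `find_brain_edges()` returns inclusive upper coordinates, each length is `upper + 1 - lower`, which equals the shape of the cropped brain `x3d[find_brain_edges(x3d, sl=True)]`. A small self-contained check of that arithmetic, using `np.argwhere` as a stand-in for `find_brain_edges()` and assuming a background of 0:

```python
import numpy as np

vol = np.zeros((8, 9, 10))  # hypothetical volume with background 0
vol[2:5, 1:7, 3] = 1.0      # "brain" spanning 3 x 6 x 1 voxels

# Inclusive lower/upper edges per axis (stand-in for find_brain_edges(vol))
coords = np.argwhere(vol != 0)
(il, jl, kl), (iu, ju, ku) = coords.min(axis=0), coords.max(axis=0)

# Upper edges are inclusive, hence the "+ 1"
lengths = [int(iu + 1 - il), int(ju + 1 - jl), int(ku + 1 - kl)]
cropped = vol[il:iu + 1, jl:ju + 1, kl:ku + 1]  # same as vol[find_brain_edges(vol, sl=True)]

print(lengths)              # [3, 6, 1]
print(list(cropped.shape))  # [3, 6, 1]
```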

get_global_max_axes 🧠

get_global_max_axes(
    nifti_img: Nifti1Image, per_axis: bool
) -> int | Sequence[int]

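Get the global max axis-length(s), in voxels, for the given brain image.

The global lengths are the maximum axis-lengths over all brains in the dataset, i.e., they are defined globally rather than per brain. They can be set in the `PruneConfig` class (`PruneConfig.largest_brain_max_axes`), are defined for 1 mm isotropic resolution, and are rescaled to the voxel size of the given image. These values are used for pruning brain images. The orientation of the image must match `GLOBAL_ORIENTATION_SPACE`; otherwise, a `ValueError` is raised.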

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `nifti_img` | `Nifti1Image` | NIfTI image. | required |
| `per_axis` | `bool` | `True`: return the max axis-length for each axis (for `prune_mode='max'`); `False`: return one max axis-length for all axes (for `prune_mode='cube'`). | required |

Returns:

| Type | Description |
|------|-------------|
| `int \| Sequence[int]` | Global max axis-length(s). |

Source code in src/xai4mri/dataloader/prune_image.py
def get_global_max_axes(nifti_img: nib.Nifti1Image, per_axis: bool) -> int | Sequence[int]:
    """
    Get the global max axis-length(s) for the given brain.

    The global lengths are the maximum axis-length for all brain axes.
    It is globally defined for all brains in the dataset.
    The value can be set in the `PruneConfig` class (`PruneConfig.largest_brain_max_axes`).
    These values are used for pruning brain images.

    :param nifti_img: NIfTI image.
    :param per_axis: True: return max axis-length for each axis (for `prune_mode='max'`);
                     False: return max axis-length for all axes (for `prune_mode='cube'`).
    :return: Global max axis-length(s).
    """
    if nib.orientations.aff2axcodes(nifti_img.affine) != tuple(GLOBAL_ORIENTATION_SPACE):
        msg = (
            f"The orientation of the given NIfTI must match the GLOBAL_ORIENTATION_SPACE: "
            f"'{GLOBAL_ORIENTATION_SPACE}'!"
        )
        raise ValueError(msg)

    # PruneConfig.largest_brain_max_axes is defined for 1 mm isotropic resolution
    resolution = np.round(nifti_img.header["pixdim"][1:4], decimals=3)  # image resolution per axis
    # Adapt the global max axes to the resolution of the given image
    global_max_axis = np.round(PruneConfig.largest_brain_max_axes // resolution).astype(int)
    if not per_axis:
        global_max_axis = int(global_max_axis.max())
    return global_max_axis
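The resolution adaptation can be illustrated in isolation. This sketch replicates only the scaling expression from the listing above; the 1 mm reference lengths in `LARGEST_BRAIN_MAX_AXES_1MM` are made-up stand-ins for `PruneConfig.largest_brain_max_axes`, and `scale_global_max_axes()` is a hypothetical helper. Note that `//` floors the ratio before `np.round` is applied, so non-integer ratios are rounded down (e.g., 190 / 1.5 ≈ 126.7 yields 126):

```python
import numpy as np

# Made-up 1 mm reference lengths, standing in for PruneConfig.largest_brain_max_axes
LARGEST_BRAIN_MAX_AXES_1MM = np.array([150, 190, 160])


def scale_global_max_axes(resolution, per_axis: bool = True):
    """Replicate the resolution scaling of get_global_max_axes() (resolution in mm per axis)."""
    resolution = np.round(np.asarray(resolution, dtype=float), decimals=3)
    max_axes = np.round(LARGEST_BRAIN_MAX_AXES_1MM // resolution).astype(int)
    return max_axes if per_axis else int(max_axes.max())


print(scale_global_max_axes([1.0, 1.0, 1.0]))                  # [150 190 160]
print(scale_global_max_axes([2.0, 2.0, 2.0]))                  # [75 95 80]
print(scale_global_max_axes([1.5, 1.5, 1.5]))                  # [100 126 106]
print(scale_global_max_axes([2.0, 2.0, 2.0], per_axis=False))  # 95
```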

permute_array 🧠

permute_array(xd: ndarray) -> ndarray

Randomly shuffle all entries (e.g., voxels) of the given N-dimensional array (e.g., a 3D MRI).

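Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `xd` | `ndarray` | N-dimensional array. | required |

Returns:

| Type | Description |
|------|-------------|
| `ndarray` | Permuted array. |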

Source code in src/xai4mri/dataloader/prune_image.py
def permute_array(xd: np.ndarray) -> np.ndarray:
    """
    Swap all entries (e.g., voxels) in the given x-dimensional array (e.g., 3D-MRI).

    :param xd: x-dimensional array
    :return: permuted array
    """
    flat_xd = xd.flatten()
    np.random.shuffle(flat_xd)  # noqa: NPY002
    return flat_xd.reshape(xd.shape)
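`permute_array()` draws from NumPy's legacy global random state (hence the `noqa: NPY002` marker). A variant with its own seeded `np.random.Generator` could look like the following sketch; `permute_array_rng()` and its `seed` parameter are hypothetical, not part of the module:

```python
from __future__ import annotations

import numpy as np


def permute_array_rng(xd: np.ndarray, seed: int | None = None) -> np.ndarray:
    """Return a shuffled copy of `xd` (all entries permuted), using a local, seedable Generator."""
    rng = np.random.default_rng(seed)
    return rng.permutation(xd.ravel()).reshape(xd.shape)  # input stays untouched


x = np.arange(24, dtype=float).reshape(2, 3, 4)
noise = permute_array_rng(x, seed=0)

print(noise.shape)                                                       # (2, 3, 4)
print(np.array_equal(np.sort(noise, axis=None), np.sort(x, axis=None)))  # True: same values
print(np.array_equal(noise, permute_array_rng(x, seed=0)))               # True: same seed, same result
```

Either way, the shuffled array holds exactly the same values as the input, so its intensity histogram is unchanged while all spatial structure is destroyed.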

permute_nifti 🧠

permute_nifti(nifti_img: Nifti1Image) -> Nifti1Image

Randomly shuffle all entries (e.g., voxels) of the given NIfTI image, keeping its affine and header.

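Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `nifti_img` | `Nifti1Image` | NIfTI image. | required |

Returns:

| Type | Description |
|------|-------------|
| `Nifti1Image` | Permuted NIfTI image (i.e., a noise image). |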

Source code in src/xai4mri/dataloader/prune_image.py
def permute_nifti(nifti_img: nib.Nifti1Image) -> nib.Nifti1Image:
    """
    Swap all entries (e.g., voxels) in the given NIfTI image.

    :param nifti_img: NIfTI image
    :return: permuted NIfTI image (i.e., a noise image)
    """
    xd = nifti_img.get_fdata()
    flat_xd = xd.flatten()
    np.random.shuffle(flat_xd)  # noqa: NPY002
    xd = flat_xd.reshape(xd.shape)
    return nib.Nifti1Image(dataobj=xd, affine=nifti_img.affine, header=nifti_img.header)

prune_mri 🧠

prune_mri(
    x3d: ndarray,
    make_cube: bool = False,
    max_axis: (
        int | Sequence[int] | ndarray[int] | None
    ) = None,
    padding: int = 0,
) -> ndarray | None

Prune the given 3D MRI to a (smaller) volume with side length(s) equal to `max_axis` (an `int` or a 3D tuple).

If `max_axis` is `None`, find the smallest volume that covers the brain (i.e., remove the zero-padding). This works very fast. [An implementation with `np.pad` is possible, too.]

Compare to `nilearn.image.crop_img()` for NIfTIs:

* `crop_img()` crops exactly along the brain only,
* which is the same as `mri[find_brain_edges(mri, sl=True)]`,
* but it is slower.

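Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `x3d` | `ndarray` | 3D MRI. | required |
| `max_axis` | `int \| Sequence[int] \| ndarray[int] \| None` | Either the side length (`int`) of the pruned cube, or the pruned side length for each axis (3D sequence `[int, int, int]`). If `None`, it is derived from the smallest volume that covers the brain. | `None` |
| `make_cube` | `bool` | `True`: the pruned MRI will be a cube; `False`: each axis will be pruned to its entry in `max_axis`. | `False` |
| `padding` | `int` | Number of extra background voxels added to each target side length, i.e., zero-padding that remains around the brain (`int >= 0`). | `0` |

Returns:

| Type | Description |
|------|-------------|
| `ndarray \| None` | Pruned brain image, or `None` if `x3d` is not a NumPy array. |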

Source code in src/xai4mri/dataloader/prune_image.py
def prune_mri(
    x3d: np.ndarray,
    make_cube: bool = False,
    max_axis: int | Sequence[int] | np.ndarray[int] | None = None,
    padding: int = 0,
) -> np.ndarray | None:
    """
    Prune given 3D MRI to (smaller) volume with side-length(s) == `max_axis` [`int` OR 3D tuple].

    If `max_axis` is `None`, find the smallest volume, which covers the brain (i.e., remove zero-padding).
    Works very fast.
    [Implementation with `np.pad` is possible, too].

    Compare to: `nilearn.image.crop_img()` for NIfTI's:
        * This crops exactly along the brain only
        * which is the same as: `mri[find_brain_edges(mri, sl=True)]`
        * but it is slower

    :param x3d: 3D MRI
    :param max_axis: Either side-length [int] of a pruned cube; Or
                     pruned side-length for each axis [3D-sequence: [int, int, int]].
    :param make_cube: True: pruned MRI will be a cube; False: each axis will be pruned to `max_axis`
    :param padding: Number of zero-padding layers that should remain around the brain [int >= 0]
    :return: pruned brain image or `None` if `x3d` is not a numpy array
    """
    # Check argument:
    if max_axis is not None:
        if make_cube:
            if not isinstance(max_axis, (int, np.int_)):
                msg = "If the target volume is supposed to be a cube, 'max_axis' must be of type int!"
                raise TypeError(msg)
        else:
            msg = "If the target volume is not supposed to be a cube, 'max_axis' must be a 3D-shaped tuple of integers!"
            if not isinstance(max_axis, (Sequence, np.ndarray)) or not all(
                isinstance(e, (int, np.int_)) for e in max_axis
            ):
                raise TypeError(msg)
            n_dims = 3
            if len(max_axis) != n_dims:
                raise ValueError(msg)

    if not isinstance(padding, (int, np.int_)) or padding < 0:
        msg = "'padding' must be an integer >= 0!"
        raise ValueError(msg)

    if isinstance(x3d, np.ndarray):
        # Cut out
        x3d_minimal = x3d[find_brain_edges(x3d, sl=True)]

        # Prune to smaller volume
        if max_axis is None:
            # find the longest axis for cubing [int] OR take the shape of the minimal volume [3D-tuple]
            max_axis = np.max(x3d_minimal.shape) if make_cube else np.array(x3d_minimal.shape)

        # Add padding at the borders (if requested) & make max_axis a 3D shape-tuple/list
        max_axis = [max_axis + padding] * 3 if make_cube else np.array(max_axis) + padding

        # Initialize an empty 3D target volume
        x3d_small_vol = np.zeros(max_axis, dtype=x3d.dtype)
        if x3d.min() != 0.0:
            x3d_small_vol[x3d_small_vol == 0] = x3d.min()  # in case background is e.g. -1

        x3d_small_vol, _ = _place_small_in_middle_of_big(big=x3d_small_vol, small=x3d_minimal)  # _ = cut

    else:
        cprint(string="'x3d' is not a numpy array!", col="r")
        x3d_small_vol = None

    return x3d_small_vol
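The pruning steps in the listing above (crop to the brain's bounding box, derive the target shape from `max_axis`, `make_cube` and `padding`, then fill a background-valued target volume) can be sketched with NumPy alone. The module's centering helper `_place_small_in_middle_of_big()` is not shown on this page, so `place_in_middle()` below is a hypothetical stand-in that centers the brain and, where the brain is larger than the target, crops it symmetrically; `prune_sketch()` is likewise illustrative and assumes a background of 0 for edge detection:

```python
import numpy as np


def place_in_middle(big: np.ndarray, small: np.ndarray) -> np.ndarray:
    """Center `small` in a copy of `big`; where `small` is larger, crop it symmetrically."""
    out = big.copy()
    dst, src = [], []
    for b, s in zip(big.shape, small.shape):
        if s <= b:  # fits: surround it with background
            start = (b - s) // 2
            dst.append(slice(start, start + s))
            src.append(slice(0, s))
        else:  # too large: cut away the overhang on both sides
            start = (s - b) // 2
            dst.append(slice(0, b))
            src.append(slice(start, start + b))
    out[tuple(dst)] = small[tuple(src)]
    return out


def prune_sketch(x3d, make_cube=False, max_axis=None, padding=0, bg=0):
    """Crop to the brain's bounding box, then center it in the target volume."""
    coords = np.argwhere(x3d != bg)
    lo, hi = coords.min(axis=0), coords.max(axis=0) + 1
    minimal = x3d[lo[0]:hi[0], lo[1]:hi[1], lo[2]:hi[2]]
    if max_axis is None:  # smallest cube, or the tight box itself
        max_axis = max(minimal.shape) if make_cube else minimal.shape
    if make_cube:
        shape = (int(max_axis) + padding,) * 3
    else:
        shape = tuple(int(a) + padding for a in max_axis)
    target = np.full(shape, x3d.min(), dtype=x3d.dtype)  # background value of the input
    return place_in_middle(target, minimal)


vol = np.zeros((20, 20, 20))
vol[5:9, 6:12, 10:13] = 1.0  # 4 x 6 x 3 "brain"

print(prune_sketch(vol).shape)                                 # (4, 6, 3)
print(prune_sketch(vol, make_cube=True).shape)                 # (6, 6, 6)
print(prune_sketch(vol, max_axis=(8, 8, 8), padding=2).shape)  # (10, 10, 10)
print(prune_sketch(vol, max_axis=(2, 8, 8)).sum())             # 36.0 (axis 0 cropped from 4 to 2)
```

In this sketch, `padding` enlarges each target side length, so the extra background voxels are split between the two sides of an axis.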

reverse_pruning 🧠

reverse_pruning(
    original_mri: ndarray | Nifti1Image,
    pruned_mri: ndarray,
    pruned_stats_map: ndarray | None = None,
) -> ndarray | Nifti1Image

Reverse the pruning of an MRI or of its corresponding statistical map.

In both cases, the original MRI and the pruned MRI are required: the edges of their brains mark where the volume was cut off during pruning. If a statistical map (in the pruned space) is given, that map is placed back into the original space. If no statistical map is given, the pruned MRI itself is placed back, e.g., a processed and pruned version of the original MRI.

Make sure that the original MRI and the pruned MRI have the same orientation.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| `original_mri` | `ndarray \| Nifti1Image` | Original (i.e., non-pruned) MRI. | required |
| `pruned_mri` | `ndarray` | Pruned MRI. | required |
| `pruned_stats_map` | `ndarray \| None` | [Optional] pruned statistical map. | `None` |

Returns:

| Type | Description |
|------|-------------|
| `ndarray \| Nifti1Image` | MRI with original size (if `original_mri` is given as a `Nifti1Image`, a `Nifti1Image` is returned). |

Source code in src/xai4mri/dataloader/prune_image.py
def reverse_pruning(
    original_mri: np.ndarray | nib.Nifti1Image,
    pruned_mri: np.ndarray,
    pruned_stats_map: np.ndarray | None = None,
) -> np.ndarray | nib.Nifti1Image:
    """
    Reverse the pruning of an MRI or its corresponding statistical map.

    If a statistical map is given, both the original MRI and the pruned MRI are necessary to find the edges of
    the cut-off during pruning.
    If no statistical map is given, only the original MRI and the pruned MRI are required.
    Note, in this case `reverse_pruning()` is applied to a processed and pruned version of the original MRI.

    Make sure that the original MRI and the pruned MRI have the same orientation.

    :param original_mri: Original (i.e., non-pruned) MRI.
    :param pruned_mri: Pruned MRI.
    :param pruned_stats_map: [Optional] pruned statistical map.
    :return: MRI with original size (if original_mri is given as Nifti1Image, returns Nifti1Image).
    """
    # Check whether original_mri is Nifti1Image
    is_nifti = isinstance(original_mri, nib.Nifti1Image)

    # Define which volume to use for reverse pruning
    volume_to_reverse_pruning = pruned_mri if pruned_stats_map is None else pruned_stats_map

    # Initialize the MRI to fill
    volume_to_fill = np.zeros(shape=original_mri.shape)
    volume_to_fill[...] = BG_VALUE  # set background

    # Find the edges of the brain (slice format)
    original_mri_edge_slices = find_brain_edges(x3d=original_mri.get_fdata() if is_nifti else original_mri, sl=True)
    pruned_mri_edge_slices = find_brain_edges(x3d=pruned_mri, sl=True)

    # Use the edges to place the brain data at the right spot
    volume_to_fill[original_mri_edge_slices] = volume_to_reverse_pruning[pruned_mri_edge_slices]

    if is_nifti:
        return nib.Nifti1Image(
            dataobj=volume_to_fill,
            affine=original_mri.affine,
            header=original_mri.header if pruned_stats_map is None else None,
            extra=original_mri.extra if pruned_stats_map is None else None,
            dtype=original_mri.get_data_dtype() if pruned_stats_map is None else None,
        )

    return volume_to_fill
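The core of `reverse_pruning()` is a single assignment between two bounding boxes: the brain box of the original MRI receives the content of the brain box of the pruned MRI (or of the statistical map at the same position). A NumPy-only round-trip sketch, assuming a background of 0; `edge_slices()` and `reverse_prune_sketch()` are hypothetical helpers, and `np.pad` simulates pruning with a 2-voxel margin:

```python
import numpy as np

BG_VALUE = 0  # assumption: background value


def edge_slices(x3d: np.ndarray, bg: float = BG_VALUE) -> tuple[slice, slice, slice]:
    """Tight bounding box around non-background voxels, as a tuple of slices."""
    coords = np.argwhere(x3d != bg)
    lo, hi = coords.min(axis=0), coords.max(axis=0) + 1
    return tuple(slice(int(a), int(b)) for a, b in zip(lo, hi))


def reverse_prune_sketch(original, pruned, pruned_stats_map=None, bg=BG_VALUE):
    """Paste the brain box of `pruned` (or of the stats map) back into an array shaped like `original`."""
    source = pruned if pruned_stats_map is None else pruned_stats_map
    restored = np.full(original.shape, bg, dtype=float)
    restored[edge_slices(original, bg)] = source[edge_slices(pruned, bg)]
    return restored


# Round trip on synthetic data
original = np.zeros((12, 12, 12))
original[3:7, 2:8, 5:8] = np.arange(1, 4 * 6 * 3 + 1).reshape(4, 6, 3)  # non-zero "brain"
pruned = np.pad(original[edge_slices(original)], 2)  # crop to the brain, keep a 2-voxel margin
stats = pruned * 0.5                                  # hypothetical statistical map in pruned space

print(pruned.shape)                                                                # (8, 10, 7)
print(np.array_equal(reverse_prune_sketch(original, pruned), original))            # True
print(np.allclose(reverse_prune_sketch(original, pruned, stats), original * 0.5))  # True
```

As with the listing above, this only works if the pruned MRI still contains the complete brain: if pruning cut off brain voxels, the two boxes differ in size and the assignment fails. Values of the statistical map outside the brain box of the pruned MRI are not carried over.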