mrinets 🧠

Train & load deep learning models for MRI-based predictions.

This module provides a set of out-of-the-box models for MRI-based predictions. The models are designed for 3D MRI data and can be used for regression or classification tasks. Moreover, there are pretrained models available for some model architectures.

Models are built on top of the TensorFlow Keras API. This is necessary to ensure compatibility with iNNvestigate for model interpretability.

Author: Simon M. Hofmann
Years: 2022-2024

MRInet 🧠

Bases: _ModelCreator

MRInet model creator.

MRInet is the basemodel architecture (3D-ConvNet) used within the multi-level ensembles in Hofmann et al. (2022, NeuroImage). The models were trained to predict age from MR images of different sequences (T1, FLAIR, SWI) in the LIFE Adult study.

This model creator class provides a hard-coded Keras implementation of the 3D-CNN model. Creating a model with this class is as simple as calling the create method. This will return a fresh (i.e., untrained) and compiled Keras model ready to be trained:

Create a new instance of MRInet

# Create a new instance of MRInet
mrinet = MRInet.create(name="MyMRInet", n_classes=False, input_shape=(91, 109, 91))

# Train on MRI dataset
mrinet.fit(X_train, y_train, ...)

Pretrained models are available for T1, FLAIR, and SWI images.

Get an overview of pretrained models by calling:

MRInet.pretrained_models().show()

create staticmethod 🧠

create(
    name: str,
    n_classes: bool | None | int,
    input_shape: tuple[int, int, int],
    learning_rate: float = 0.0005,
    target_bias: float | None = None,
    leaky_relu: bool = False,
    batch_norm: bool = False,
) -> Sequential

Create a new instance of MRInet, a 3D-convolutional neural network (CNN) for predictions on MRIs.

This is a hard-coded Keras implementation of 3D-CNN model as reported in Hofmann et al. (2022, NeuroImage).

The model can be trained for regression (n_classes=False) or classification (n_classes: int >= 2) tasks.
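
How `n_classes` selects the output layer can be sketched as follows. The mapping to `(units, activation)` mirrors the output block in the source code of this method; the validation rules are an assumption based on the parameter description, since `_check_n_classes` is not shown on this page. `output_head` is a hypothetical helper, not part of `xai4mri`.

```python
def output_head(n_classes):
    """Return (units, activation) of the final Dense layer for a given n_classes."""
    # bool is a subclass of int in Python, so True (== 1) is rejected explicitly
    if n_classes is True or n_classes == 1:
        raise ValueError("n_classes must be False/0/None (regression) or an int >= 2")
    if not n_classes:  # False, 0, or None -> regression
        return 1, "linear"
    if not isinstance(n_classes, int) or n_classes < 2:
        raise ValueError("n_classes must be False/0/None (regression) or an int >= 2")
    return n_classes, "softmax"  # classification


for value in (False, 0, None, 2, 5):
    print(repr(value), "->", output_head(value))
```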

Parameters:

Name Type Description Default
name str

Model name, which, for instance, could refer to the project it is applied for.

required
target_bias float | None

Model output bias. For classification tasks, this can be left blank [None]. For regression tasks, it is recommended to set this bias to the average of the prediction target distribution in the dataset.

None
input_shape tuple[int, int, int]

Shape of the input to the model. This should be the shape of a single MRI (e.g., (91, 109, 91)).

required
learning_rate float

Learning rate which is used for the model's optimizer (here, Adam).

0.0005
batch_norm bool

Use batch normalization or not. Batch normalization should only be used if the model is fed with larger batches. For model interpretability provided by xai4mri, it is recommended to not use batch normalization.

False
leaky_relu bool

Whether to use leaky ReLU instead of vanilla ReLU activation functions. Leaky ReLU is recommended for better performance. However, iNNvestigate, which is used for model interpretability, currently does not support leaky ReLU.

False
n_classes bool | None | int

Number of classes. For regression tasks, set to False or 0; for classification tasks, provide an integer >= 2.

required

Returns:

Type Description
Sequential

Compiled MRInet model (based on Keras), ready to be trained.
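
The recommendation for `target_bias` in regression tasks can be sketched as follows. This is a minimal illustration with made-up ages; `regression_target_bias` is a hypothetical helper and not part of `xai4mri`. Computing the mean on the training split only (an assumption of this sketch) keeps information from the test data out of the model.

```python
from statistics import fmean


def regression_target_bias(targets):
    """Mean of the training targets, intended as the initial output-layer bias."""
    targets = list(targets)
    if not targets:
        raise ValueError("targets must not be empty")
    return fmean(targets)


y_train = [23.0, 35.5, 61.0, 47.5, 70.0]  # made-up ages, for illustration only
target_bias = regression_target_bias(y_train)
print(f"target_bias = {target_bias:.3f}")

# The value would then be passed on to the model creator, e.g.:
# MRInet.create(name="MyMRInet", n_classes=False, input_shape=(91, 109, 91), target_bias=target_bias)
```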

Source code in src/xai4mri/model/mrinets.py
@staticmethod
def create(
    name: str,
    n_classes: bool | None | int,
    input_shape: tuple[int, int, int],
    learning_rate: float = 5e-4,
    target_bias: float | None = None,
    leaky_relu: bool = False,
    batch_norm: bool = False,
) -> keras.Sequential:
    """
    Create a new instance of `MRInet`, a 3D-convolutional neural network (CNN) for predictions on MRIs.

    This is a hard-coded Keras implementation of 3D-CNN model as reported in
    [Hofmann et al. (2022, *NeuroImage*)](https://doi.org/10.1016/j.neuroimage.2022.119504).

    The model can be trained for regression (`n_classes=False`) or classification (`n_classes: int >= 2`) tasks.

    :param name: Model name, which, for instance, could refer to the project it is applied for.
    :param target_bias: Model output bias.
                        For classification tasks, this can be left blank [`None`].
                        For regression tasks, it is recommended to set this bias to the average of the
                        prediction target distribution in the dataset.
    :param input_shape: Shape of the input to the model.
                        This should be the shape of a single MRI (e.g., (91, 109, 91)).
    :param learning_rate: Learning rate which is used for the model's optimizer (here, `Adam`).
    :param batch_norm: Use batch normalization or not.
                       Batch normalization should only be used if the model is fed with larger batches.
                       For model interpretability provided by `xai4mri`,
                       it is recommended to not use batch normalization.
    :param leaky_relu: Using leaky or vanilla ReLU activation functions.
                       Leaky ReLU is recommended for better performance.
                       However, `iNNvestigate`, which is used for model interpretability,
                       does not support leaky ReLU currently.
    :param n_classes: Number of classes.
                      For regression tasks, set to `False` or `0`;
                      for classification tasks, provide an integer >= 2.
    :return: Compiled `MRInet` model (based on `Keras`), ready to be trained.
    """
    if target_bias is not None:
        cprint(string=f"\nGiven target bias is {target_bias:.3f}\n", col="y")

    actfct = None if leaky_relu and not batch_norm else "relu"

    _check_n_classes(n_classes=n_classes)

    k_model = keras.Sequential(name=name)  # OR: Sequential([keras.layer.Conv3d(....), layer...])

    # 3D-Conv
    if batch_norm:
        k_model.add(keras.layers.BatchNormalization(input_shape=(*input_shape, 1)))
        k_model.add(keras.layers.Conv3D(filters=16, kernel_size=(3, 3, 3), padding="SAME", activation=actfct))
    else:
        k_model.add(
            keras.layers.Conv3D(
                filters=16,
                kernel_size=(3, 3, 3),
                padding="SAME",
                activation=actfct,
                input_shape=(*input_shape, 1),
            )
        )
        # auto-add batch:None, last: channels
    if leaky_relu:
        k_model.add(keras.layers.LeakyReLU(alpha=0.2))  # lrelu
    k_model.add(keras.layers.MaxPool3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding="SAME"))

    if batch_norm:
        k_model.add(keras.layers.BatchNormalization())
    k_model.add(keras.layers.Conv3D(filters=16, kernel_size=(3, 3, 3), padding="SAME", activation=actfct))
    if leaky_relu:
        k_model.add(keras.layers.LeakyReLU(alpha=0.2))
    k_model.add(keras.layers.MaxPool3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding="SAME"))

    if batch_norm:
        k_model.add(keras.layers.BatchNormalization())
    k_model.add(keras.layers.Conv3D(filters=32, kernel_size=(3, 3, 3), padding="SAME", activation=actfct))
    if leaky_relu:
        k_model.add(keras.layers.LeakyReLU(alpha=0.2))
    k_model.add(keras.layers.MaxPool3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding="SAME"))

    if batch_norm:
        k_model.add(keras.layers.BatchNormalization())
    k_model.add(keras.layers.Conv3D(filters=64, kernel_size=(3, 3, 3), padding="SAME", activation=actfct))
    if leaky_relu:
        k_model.add(keras.layers.LeakyReLU(alpha=0.2))
    k_model.add(keras.layers.MaxPool3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding="SAME"))

    # 3D-Conv (1x1x1)
    if batch_norm:
        k_model.add(keras.layers.BatchNormalization())
    k_model.add(keras.layers.Conv3D(filters=32, kernel_size=(1, 1, 1), padding="SAME", activation=actfct))
    if leaky_relu:
        k_model.add(keras.layers.LeakyReLU(alpha=0.2))

    k_model.add(keras.layers.MaxPool3D(pool_size=(3, 3, 3), strides=(2, 2, 2), padding="SAME"))

    if batch_norm:
        k_model.add(keras.layers.BatchNormalization())

    # FC
    k_model.add(keras.layers.Flatten())
    k_model.add(keras.layers.Dropout(rate=0.5))
    k_model.add(keras.layers.Dense(units=64, activation=actfct))
    if leaky_relu:
        k_model.add(keras.layers.LeakyReLU(alpha=0.2))

    # Output
    if n_classes:
        k_model.add(
            keras.layers.Dense(
                units=n_classes,
                activation="softmax",  # in binary case. also: 'sigmoid'
                use_bias=False,
            )
        )  # default: True

    else:
        k_model.add(
            keras.layers.Dense(
                units=1,
                activation="linear",
                # add target bias (recommended: mean of target distribution)
                use_bias=True,
                bias_initializer=keras.initializers.Constant(value=target_bias) if target_bias else "zeros",
            )
        )

    # Compile
    k_model.compile(
        optimizer=keras.optimizers.Adam(learning_rate),  # ="adam",
        loss="mse",
        metrics=["accuracy"] if n_classes else ["mae"],
    )

    # Summary
    k_model.summary()

    return k_model
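
To see how the input shape determines the size of the fully connected part, the spatial reduction of the network above can be traced by hand. The sketch below assumes only what the source shows: all convolutions use `padding="SAME"` with stride 1 (spatial size unchanged), each of the five `MaxPool3D` layers uses stride 2 with `padding="SAME"` (size `n` becomes `ceil(n / 2)`), and the last convolution has 32 filters. `mrinet_feature_shape` is a hypothetical helper.

```python
import math


def mrinet_feature_shape(input_shape, n_pools=5, last_filters=32):
    """Spatial shape before Flatten and number of flattened features."""
    spatial = tuple(input_shape)
    for _ in range(n_pools):
        # MaxPool3D(strides=2, padding="SAME") maps n -> ceil(n / 2)
        spatial = tuple(math.ceil(n / 2) for n in spatial)
    n_features = math.prod(spatial) * last_filters
    return spatial, n_features


for shape in [(91, 109, 91), (198, 198, 198)]:
    spatial, n_features = mrinet_feature_shape(shape)
    print(shape, "->", spatial, "->", n_features, "features into Dense(64)")
```

For an input of (91, 109, 91) this gives a (3, 4, 3) volume with 32 channels, i.e. 1152 features; for (198, 198, 198) it gives (7, 7, 7) and 10976 features.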

pretrained_models staticmethod 🧠

pretrained_models() -> type[PretrainedModelsMRInet]

Return enum of pretrained MRInet models.

Call the show method to get an overview of available pretrained models:

MRInet.pretrained_models().show()
Source code in src/xai4mri/model/mrinets.py
@staticmethod
def pretrained_models() -> type[PretrainedModelsMRInet]:
    """
    Return enum of pretrained `MRInet` models.

    !!! tip "Call the `show` method to get an overview of available pretrained models:"
        ```python
        MRInet.pretrained_models().show()
        ```
    """
    return PretrainedModelsMRInet

reference staticmethod 🧠

reference() -> str

Get the reference for the MRInet model.

Source code in src/xai4mri/model/mrinets.py
@staticmethod
def reference() -> str:
    """Get the reference for the `MRInet` model."""
    return "Hofmann et al. (2022). NeuroImage. https://doi.org/10.1016/j.neuroimage.2022.119504"

OutOfTheBoxModels 🧠

Bases: Enum

Enum class for out-of-the-box models for MRI-based predictions.

Call the show method to get an overview of available models:

OutOfTheBoxModels.show()

Get a model by selecting it

mrinet_creator = OutOfTheBoxModels.MRINET.value
mrinet = mrinet_creator.create(...)

# Alternatively, get the model by string
sfcn = OutOfTheBoxModels("sfcn").value.create(...)
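
The enum members carry a model creator as their value, yet a member can also be selected by a string. How `OutOfTheBoxModels` achieves this is not shown on this page; one way to get such behaviour with the standard library is the `_missing_` hook of `enum.Enum`. The sketch below is an assumption-laden illustration with hypothetical stand-ins (`FakeMRInet`, `FakeSFCN`, `ModelsSketch`), not the actual implementation.

```python
from enum import Enum


class FakeMRInet:  # hypothetical stand-in for a model creator class
    @staticmethod
    def create(**kwargs):
        return ("mrinet", kwargs)


class FakeSFCN:  # hypothetical stand-in for a model creator class
    @staticmethod
    def create(**kwargs):
        return ("sfcn", kwargs)


class ModelsSketch(Enum):
    MRINET = FakeMRInet
    SFCN = FakeSFCN

    @classmethod
    def _missing_(cls, value):
        # Called by Enum when ModelsSketch(value) matches no member value.
        if isinstance(value, str):
            for member in cls:
                if member.name.lower() == value.lower():
                    return member
        return None  # Enum then raises ValueError


by_member = ModelsSketch.MRINET.value.create(name="a")
by_string = ModelsSketch("sfcn").value.create(name="b")
print(by_member, by_string)
```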

default classmethod 🧠

default()

Get the default model, which is MRInet.

Source code in src/xai4mri/model/mrinets.py
@classmethod
def default(cls):
    """Get the default model, which is `MRInet`."""
    return cls.MRINET

pretrained_models 🧠

pretrained_models()

Get the pretrained models for the model.

Get information about pretrained MRInet models

OutOfTheBoxModels.MRINET.pretrained_models().show()

Ultimately, this is just a small wrapper avoiding calling the value attribute.

Source code in src/xai4mri/model/mrinets.py
def pretrained_models(self):
    """
    Get the pretrained models for the model.

    !!! example "Get information about pretrained `MRInet` models"
        ```python
        OutOfTheBoxModels.MRINET.pretrained_models().show()
        ```
    Ultimately, this is just a small wrapper avoiding calling the `value` attribute.
    """
    return self.value.pretrained_models()

reference 🧠

reference()

Get the reference for the model.

Get the reference for MRInet

OutOfTheBoxModels.MRINET.reference()

Ultimately, this is just a small wrapper avoiding calling the value attribute.

Source code in src/xai4mri/model/mrinets.py
def reference(self):
    """
    Get the reference for the model.

    !!! example "Get the reference for `MRInet`"
        ```python
        OutOfTheBoxModels.MRINET.reference()
        ```
    Ultimately, this is just a small wrapper avoiding calling the `value` attribute.
    """
    return self.value.reference()

show classmethod 🧠

show()

Show available models.

Source code in src/xai4mri/model/mrinets.py
@classmethod
def show(cls):
    """Show available models."""
    return print(*[(model, model.reference()) for model in cls], sep="\n")

PretrainedMRInetFLAIR dataclass 🧠

PretrainedMRInetFLAIR(
    name: str = "flair_model",
    url: str = "https://keeper.mpdl.mpg.de/f/8481f5906f3d4192ab12/?dl=1",
)

Bases: _PretrainedMRInet

Pretrained MRInet model for FLAIR images.

This is a basemodel in the FLAIR sub-ensemble as reported in Hofmann et al. (2022, NeuroImage).

The training data stem from the LIFE Adult study. They were normalized and registered to the T1w-FreeSurfer file (brain.finalsurfs.mgz), and subsequently pruned to (198, 198, 198) voxels (for more details, see Hofmann et al., 2022, NeuroImage).

load_model classmethod 🧠

load_model(
    parent_folder: str | Path,
    custom_objects: T | None = None,
    compile_model: bool = False,
) -> Sequential

Load a pretrained MRInet model.

If the model is not present on the local machine, it will be downloaded from the server and saved in the provided parent folder.

Parameters:

Name Type Description Default
parent_folder str | Path

Path to parent folder of the model.

required
custom_objects T | None

Custom objects to load (e.g., loss functions).

None
compile_model bool

Compile the model or not.

False

Returns:

Type Description
Sequential

Pretrained MRInet model.
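
The download-if-missing behaviour described above can be sketched as follows. This is a minimal illustration of the caching pattern, not the code of `_load_pretrained_mrinet_model` (which is not shown on this page). `ensure_local_copy`, `fake_download`, and the file name `flair_model.h5` are hypothetical; the fake downloader avoids any network access.

```python
import tempfile
from pathlib import Path


def ensure_local_copy(parent_folder, file_name, download):
    """Return the local path of file_name; call download(target) only if it is missing."""
    target = Path(parent_folder) / file_name
    if not target.exists():
        target.parent.mkdir(parents=True, exist_ok=True)
        download(target)  # a real implementation would fetch from the model URL
    return target


calls = []


def fake_download(target):
    """Stand-in for a real download: records the call and writes a placeholder file."""
    calls.append(target)
    target.write_bytes(b"placeholder weights")


with tempfile.TemporaryDirectory() as tmp:
    first = ensure_local_copy(tmp, "flair_model.h5", fake_download)
    second = ensure_local_copy(tmp, "flair_model.h5", fake_download)
    print("same path:", first == second, "| downloads:", len(calls))
```

The second call finds the file on disk and skips the download, which is the behaviour to expect when `load_model` is called repeatedly with the same `parent_folder`.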

Source code in src/xai4mri/model/mrinets.py
@classmethod
def load_model(
    cls, parent_folder: str | Path, custom_objects: T | None = None, compile_model: bool = False
) -> keras.Sequential:
    """
    Load a pretrained `MRInet` model.

    If the model is not present on the local machine, it will be downloaded from the server
    and saved in the provided parent folder.

    :param parent_folder: Path to parent folder of the model.
    :param custom_objects: Custom objects to load (e.g., loss functions).
    :param compile_model: Compile the model or not.
    :return: Pretrained `MRInet` model.
    """
    return _load_pretrained_mrinet_model(
        pretrained_model_type=cls,
        parent_folder=parent_folder,
        custom_objects=custom_objects,
        compile_model=compile_model,
    )

PretrainedMRInetSWI dataclass 🧠

PretrainedMRInetSWI(
    name: str = "swi_model",
    url: str = "https://keeper.mpdl.mpg.de/f/f53d43b723274687b6e2/?dl=1",
)

Bases: _PretrainedMRInet

Pretrained MRInet model for SWI.

This is a basemodel in the SWI sub-ensemble as reported in Hofmann et al. (2022, NeuroImage).

The training data stem from the LIFE Adult study. They were normalized and registered to the T1w-FreeSurfer file (brain.finalsurfs.mgz), and subsequently pruned to (198, 198, 198) voxels (for more details, see Hofmann et al., 2022, NeuroImage).

load_model classmethod 🧠

load_model(
    parent_folder: str | Path,
    custom_objects: T | None = None,
    compile_model: bool = False,
) -> Sequential

Load a pretrained MRInet model.

If the model is not present on the local machine, it will be downloaded from the server and saved in the provided parent folder.

Parameters:

Name Type Description Default
parent_folder str | Path

Path to parent folder of the model.

required
custom_objects T | None

Custom objects to load (e.g., loss functions).

None
compile_model bool

Compile the model or not.

False

Returns:

Type Description
Sequential

Pretrained MRInet model.

Source code in src/xai4mri/model/mrinets.py
@classmethod
def load_model(
    cls, parent_folder: str | Path, custom_objects: T | None = None, compile_model: bool = False
) -> keras.Sequential:
    """
    Load a pretrained `MRInet` model.

    If the model is not present on the local machine, it will be downloaded from the server
    and saved in the provided parent folder.

    :param parent_folder: Path to parent folder of the model.
    :param custom_objects: Custom objects to load (e.g., loss functions).
    :param compile_model: Compile the model or not.
    :return: Pretrained `MRInet` model.
    """
    return _load_pretrained_mrinet_model(
        pretrained_model_type=cls,
        parent_folder=parent_folder,
        custom_objects=custom_objects,
        compile_model=compile_model,
    )

PretrainedMRInetT1 dataclass 🧠

PretrainedMRInetT1(
    name: str = "t1_model",
    url: str = "https://keeper.mpdl.mpg.de/f/3be6fed59b4948aca699/?dl=1",
)

Bases: _PretrainedMRInet

Pretrained MRInet model for T1-weighted images.

This is a basemodel in the T1 sub-ensemble as reported in Hofmann et al. (2022, NeuroImage).

The training data stem from the LIFE Adult study. They were preprocessed using FreeSurfer (brain.finalsurfs.mgz), and subsequently pruned to (198, 198, 198) voxels (for more details, see Hofmann et al., 2022, NeuroImage).

load_model classmethod 🧠

load_model(
    parent_folder: str | Path,
    custom_objects: T | None = None,
    compile_model: bool = False,
) -> Sequential

Load a pretrained MRInet model.

If the model is not present on the local machine, it will be downloaded from the server and saved in the provided parent folder.

Parameters:

Name Type Description Default
parent_folder str | Path

Path to parent folder of the model.

required
custom_objects T | None

Custom objects to load (e.g., loss functions).

None
compile_model bool

Compile the model or not.

False

Returns:

Type Description
Sequential

Pretrained MRInet model.

Source code in src/xai4mri/model/mrinets.py
@classmethod
def load_model(
    cls, parent_folder: str | Path, custom_objects: T | None = None, compile_model: bool = False
) -> keras.Sequential:
    """
    Load a pretrained `MRInet` model.

    If the model is not present on the local machine, it will be downloaded from the server
    and saved in the provided parent folder.

    :param parent_folder: Path to parent folder of the model.
    :param custom_objects: Custom objects to load (e.g., loss functions).
    :param compile_model: Compile the model or not.
    :return: Pretrained `MRInet` model.
    """
    return _load_pretrained_mrinet_model(
        pretrained_model_type=cls,
        parent_folder=parent_folder,
        custom_objects=custom_objects,
        compile_model=compile_model,
    )

PretrainedModelsMRInet 🧠

Bases: Enum

Enum class for pretrained MRInet models (T1, FLAIR, SWI).

Call the show method to get an overview of available models

PretrainedModelsMRInet.show()

Get a model by selecting it

t1_mrinet = PretrainedModelsMRInet.T1_MODEL.value.load_model(...)

# Alternatively, get the model by string
swi_mrinet = PretrainedModelsMRInet("swi_model").value.load_model(...)

show classmethod 🧠

show()

Show available pretrained MRInet models.

Source code in src/xai4mri/model/mrinets.py
@classmethod
def show(cls):
    """Show available pretrained `MRInet` models."""
    return print(*[model.value for model in cls], sep="\n")

SFCN 🧠

Bases: _ModelCreator

SFCN model creator.

SFCN is the Simple Fully Convolutional Network by Peng et al. (2021, Medical Image Analysis) for age prediction from MRI data of the UK Biobank.

The architecture won first place in the brain-age competition PAC 2019.

create staticmethod 🧠

create(
    name: str,
    n_classes: int,
    input_shape: tuple[int, int, int] = (160, 192, 160),
    learning_rate: float = 0.01,
    dropout: bool = True,
) -> Sequential

Create the fully convolutional neural network (SFCN) model.

The SFCN was introduced in Peng et al. (2021, Medical Image Analysis).

The original open-source implementation is done in PyTorch and can be found at: https://github.com/ha-ha-ha-han/UKBiobank_deep_pretrain

For training:

from Peng et al. (2021, p.4):

"The L2 weight decay coefficient was 0.001. The batch size was 8. The learning rate for the SGD optimiser was initialized as 0.01, then multiplied by 0.3 every 30 epochs unless otherwise specified. The total epoch number is 130 for the 12,949 training subjects. The epoch number is adjusted accordingly for the experiments with smaller training sets so that the training steps are roughly the same."

Use this version of the SFCN model with caution
  • The implementation is still experimental, is not tested properly yet, and might be updated in the future.
  • The authors used "Gaussian soft labels" (sigma=1, mean=*true age*) for their loss function. This is not implemented here yet, and might require additional adjustments of the xai4mri.dataloader module.
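
To make the notion of "Gaussian soft labels" concrete, the following sketch turns a true age into a probability distribution over age bins and compares distributions with the Kullback-Leibler divergence. It is an illustration only and not part of `xai4mri`: the bin layout (40 one-year bins starting at age 42) is an assumption of this sketch, and the Gaussian is evaluated at the bin centres and renormalised, which is a simplification rather than the authors' exact procedure. `gaussian_soft_label` and `kl_divergence` are hypothetical helpers.

```python
import math


def gaussian_soft_label(age, bin_start=42.0, n_bins=40, bin_width=1.0, sigma=1.0):
    """Discretised Gaussian over age bins, centred on the true age, summing to 1."""
    centers = [bin_start + (i + 0.5) * bin_width for i in range(n_bins)]
    weights = [math.exp(-0.5 * ((c - age) / sigma) ** 2) for c in centers]
    total = sum(weights)
    if total == 0.0:  # age lies far outside the binned range
        raise ValueError("age is too far outside the bin range")
    return [w / total for w in weights]


def kl_divergence(target, predicted, eps=1e-12):
    """KL(target || predicted) for two discrete probability distributions."""
    return sum(t * math.log((t + eps) / (p + eps)) for t, p in zip(target, predicted) if t > 0)


label = gaussian_soft_label(age=63.2)
uniform = [1.0 / len(label)] * len(label)
peak = max(range(len(label)), key=label.__getitem__)

print("most probable bin index:", peak)
print("KL(label || label)   =", kl_divergence(label, label))
print("KL(label || uniform) =", kl_divergence(label, uniform))
```

A model output that matches the soft label exactly yields a divergence of zero, while a uniform prediction is penalised.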

Parameters:

Name Type Description Default
name str

Model name, which, for instance, could refer to the project it is applied for.

required
input_shape tuple[int, int, int]

Shape of the input to the model. This should be the shape of a single MRI.

(160, 192, 160)
n_classes int

Number of classes. In Peng et al. (2021) there were 40 age classes, representing 40 age strata.

required
learning_rate float

Learning rate which is used for the optimizer of the model.

0.01
dropout bool

Use dropout or not.

True

Returns:

Type Description
Sequential

Compiled SFCN model (based on Keras), ready to be trained.
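
The learning-rate recipe quoted from Peng et al. (2021) above (start at 0.01, multiply by 0.3 every 30 epochs) can be written out as a small function. This is a plain-Python sketch with hypothetical helper names. Note that Keras' `ExponentialDecay` counts `decay_steps` in optimizer steps (batches) rather than epochs, so an epoch-based recipe has to be converted using the number of batches per epoch.

```python
import math


def step_decay_lr(epoch, initial_lr=0.01, factor=0.3, every=30):
    """Learning rate in a given (0-based) epoch under step-wise decay."""
    return initial_lr * factor ** (epoch // every)


def decay_steps_for(n_samples, batch_size, every_n_epochs=30):
    """Number of optimizer steps that corresponds to every_n_epochs epochs."""
    steps_per_epoch = math.ceil(n_samples / batch_size)
    return steps_per_epoch * every_n_epochs


for epoch in (0, 29, 30, 60, 90, 129):
    print(f"epoch {epoch:3d}: lr = {step_decay_lr(epoch):.2e}")

# With the 12,949 training subjects and batch size 8 from the quote:
print("decay_steps =", decay_steps_for(n_samples=12949, batch_size=8))
```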

Source code in src/xai4mri/model/mrinets.py
@staticmethod
@_experimental
def create(
    name: str,
    n_classes: int,
    input_shape: tuple[int, int, int] = (160, 192, 160),
    learning_rate: float = 0.01,
    dropout: bool = True,
) -> keras.Sequential:
    """
    Create the fully convolutional neural network (`SFCN`) model.

    The `SFCN` was introduced in
    [Peng et al. (2021, *Medical Image Analysis*)](https://doi.org/10.1016/j.media.2020.101871).

    The original open-source implementation is done in `PyTorch` and can be found at:
    https://github.com/ha-ha-ha-han/UKBiobank_deep_pretrain

    For training:

    !!! quote "from [Peng et al. (2021, p.4)](https://doi.org/10.1016/j.media.2020.101871):"
        > "The L2 weight decay coefficient was 0.001.
        The batch size was 8.
        The learning rate for the SGD optimiser was initialized as 0.01,
        then multiplied by 0.3 every 30 epochs unless otherwise specified.
        The total epoch number is 130 for the 12,949 training subjects.
        The epoch number is adjusted accordingly for the experiments with
        smaller training sets so that the training steps are roughly the
        same."

    ???+ warning "Use this version of the `SFCN` model with caution"
        - The **implementation is still experimental**,
        is not tested properly yet, and might be updated in the future.
        - The authors used "Gaussian soft labels" (`sigma=1, mean=*true age*`) for their loss function.
        This is not implemented here yet,
        and might require additional adjustments of the `xai4mri.dataloader` module.

    :param name: Model name, which, for instance, could refer to the project it is applied for.
    :param input_shape: Shape of the input to the model.
                        This should be the shape of a single MRI.
    :param n_classes: Number of classes.
                      In [Peng et al. (2021)](https://doi.org/10.1016/j.media.2020.101871)
                      there were 40 age classes, representing 40 age strata.
    :param learning_rate: Learning rate which is used for the optimizer of the model.
    :param dropout: Use dropout or not.
    :return: Compiled `SFCN` model (based on `Keras`), ready to be trained.
    """
    conv_ctn = 0  # conv block counter
    _check_n_classes(n_classes=n_classes)

    def conv_block(
        _out_channel: int,
        max_pool: bool = True,
        kernel_size: int = 3,
        padding: str = "same",
        max_pool_stride: int = 2,
        in_shape: tuple[int, int, int] | None = None,
    ) -> keras.Sequential:
        """Define a convolutional block for SFCN."""
        c_block = keras.Sequential(name=f"conv3D_block_{conv_ctn}")

        conv_kwargs = {} if in_shape is None else {"input_shape": (*in_shape, 1)}
        c_block.add(keras.layers.Conv3D(_out_channel, kernel_size=kernel_size, padding=padding, **conv_kwargs))
        c_block.add(keras.layers.BatchNormalization())
        if max_pool:
            c_block.add(keras.layers.MaxPooling3D(pool_size=2, strides=max_pool_stride))
        c_block.add(keras.layers.ReLU())
        return c_block

    # Build the model
    k_model = keras.Sequential(name=name)

    # Feature extractor
    channel_number = (32, 64, 128, 256, 256, 64)
    n_layer = len(channel_number)

    for i in range(n_layer):
        out_channel = channel_number[i]
        if i < n_layer - 1:
            k_model.add(
                conv_block(
                    out_channel,
                    max_pool=True,
                    kernel_size=3,
                    padding="same",
                    in_shape=input_shape if i == 0 else None,
                )
            )
        else:
            k_model.add(conv_block(out_channel, max_pool=False, kernel_size=1, padding="valid"))
        conv_ctn += 1

    # Classifier [in: (bs, 5, 6, 5, 64)]
    avg_shape = [5, 6, 5]
    k_model.add(keras.layers.AveragePooling3D(pool_size=avg_shape))
    if dropout:
        k_model.add(keras.layers.Dropout(rate=0.5))

    k_model.add(keras.layers.Conv3D(filters=n_classes, kernel_size=1, padding="valid"))

    # Output
    k_model.add(keras.layers.Activation(activation="log_softmax", name="log_softmax"))

    # Compile
    # Define the learning rate schedule
    lr_schedule = keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=learning_rate,
        decay_steps=30,  # counted in optimizer steps (batches), not epochs; Peng et al. (2021) decay every 30 of 130 epochs
        decay_rate=0.3,
        staircase=True,
    )
    k_model.compile(
        optimizer=keras.optimizers.SGD(  # SGD best in Peng et al. (2021)
            learning_rate=lr_schedule,
            weight_decay=0.001,  # weight decay coefficient from Peng et al. (2021); the optimizer expects a float
        ),
        loss="kl_divergence",  # Peng et al. (2021) use K-L Divergence loss:
        # no one-hot encoding required: loss=keras.losses.SparseCategoricalCrossentropy(from_logits=True),
        # no log_softmax when using from_logits=True, however, KL-Div. loss expects log_softmax
        metrics=["accuracy"],  # ... if n_classes else ["mae"],
    )
    # ... to minimize a Kullback-Leibler divergence loss function between the predicted probability and a
    # Gaussian distribution (the mean is the true age,
    # and the distribution sigma is 1 year for UKB) for each training subject.
    # This soft-classification loss encourages the model to predict age as accurately as possible.

    # Summary
    k_model.summary()
    return k_model
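
The hard-coded `avg_shape = [5, 6, 5]` in the classifier ties this implementation to the default input shape. The sketch below traces the spatial size through the five pooling blocks, assuming only what the source shows: convolutions in these blocks use `padding="same"`, and each `MaxPooling3D(pool_size=2, strides=2)` uses the Keras default `"valid"` padding, so a size `n` becomes `n // 2`. `sfcn_feature_shape` is a hypothetical helper.

```python
def sfcn_feature_shape(input_shape, n_pools=5):
    """Spatial shape of the feature map that reaches the average-pooling layer."""
    spatial = tuple(input_shape)
    for _ in range(n_pools):
        # MaxPooling3D(pool_size=2, strides=2, padding="valid") maps n -> n // 2
        spatial = tuple(n // 2 for n in spatial)
    return spatial


AVG_SHAPE = (5, 6, 5)  # pool size hard-coded in the classifier above

for shape in [(160, 192, 160), (91, 109, 91)]:
    spatial = sfcn_feature_shape(shape)
    print(shape, "->", spatial, "| matches avg_shape:", spatial == AVG_SHAPE)
```

Only the default (160, 192, 160) input yields a (5, 6, 5) feature map; an input of (91, 109, 91) ends up at (2, 3, 2), which is smaller than the average-pooling window.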

pretrained_models staticmethod 🧠

pretrained_models() -> None

Return enum of pretrained SFCN models.

Note

There are no pretrained models available for SFCN yet.

Source code in src/xai4mri/model/mrinets.py
@staticmethod
def pretrained_models() -> None:
    """
    Return enum of pretrained `SFCN` models.

    !!! note
        There are no pretrained models available for `SFCN` yet.
    """
    cprint(string="There are no pretrained models available for SFCN (yet).", col="y")

reference staticmethod 🧠

reference() -> str

Get the reference for the SFCN model.

Source code in src/xai4mri/model/mrinets.py
@staticmethod
def reference() -> str:
    """Get the reference for the `SFCN` model."""
    return "Peng et al. (2021). Medical Image Analysis. https://doi.org/10.1016/j.media.2020.101871"

get_model 🧠

get_model(
    model_type: OutOfTheBoxModels | str, **kwargs
) -> Sequential | None

Get a freshly initiated out-of-the-box model.

Example of how to get an MRInet model

mrinet = get_model(
             model_type=OutOfTheBoxModels.MRINET,
             name="MyMRInet",
             n_classes=False,
             input_shape=(91, 109, 91),
             )

# Alternatively, get the model by string
sfcn = get_model(
           model_type="sfcn",
           name="MySFCN",
           n_classes=40,
           input_shape=(160, 192, 160),
           )

This is a wrapper for the `create` method of the model creator classes.

Parameters:

Name Type Description Default
model_type OutOfTheBoxModels | str

model of type OutOfTheBoxModels or string (e.g., 'mrinet')

required
kwargs

keyword arguments for model creation

{}
Source code in src/xai4mri/model/mrinets.py
def get_model(model_type: OutOfTheBoxModels | str, **kwargs) -> keras.Sequential | None:
    """
    Get a freshly initiated out-of-the-box model.

    !!! example "Example of how to get an `MRInet` model"
        ```python
        mrinet = get_model(
                     model_type=OutOfTheBoxModels.MRINET,
                     name="MyMRInet",
                     n_classes=False,
                     input_shape=(91, 109, 91),
                     )

        # Alternatively, get the model by string
        sfcn = get_model(
                   model_type="sfcn",
                   name="MySFCN",
                   n_classes=40,
                   input_shape=(160, 192, 160),
                   )

        ```
    This is a wrapper for the `create` method of the model creator classes.

    :param model_type: model of type `OutOfTheBoxModels` or `string` (e.g., 'mrinet')
    :param kwargs: keyword arguments for model creation
    """
    if OutOfTheBoxModels(model_type) == OutOfTheBoxModels.MRINET:
        return MRInet.create(**kwargs)
    if OutOfTheBoxModels(model_type) == OutOfTheBoxModels.SFCN:
        return SFCN.create(**kwargs)
    return None