Skip to content

model 🧠

Init model submodule of xai4mri.

Author: Simon M. Hofmann
Years: 2023-2024

OutOfTheBoxModels 🧠

Bases: Enum

Enum class for out-of-the-box models for MRI-based predictions.

Call the show method to get an overview of available models:

OutOfTheBoxModels.show()

Get a model by selecting it

mrinet_creator = OutOfTheBoxModels.MRINET.value
mrinet = mrinet_creator.create(...)

# Alternatively, get the model by string
sfcn = OutOfTheBoxModels("sfcn").value.create(...)

default classmethod 🧠

default()

Get the default model, which is MRInet.

Source code in src/xai4mri/model/mrinets.py
509
510
511
512
@classmethod
def default(cls):
    """Return the enum member that serves as the default model, i.e., `MRInet`."""
    default_member = cls.MRINET
    return default_member

pretrained_models 🧠

pretrained_models()

Get the pretrained models for the model.

Get information about pretrained MRInet models

OutOfTheBoxModels.MRINET.pretrained_models().show()

Ultimately, this is just a small wrapper avoiding calling the value attribute.

Source code in src/xai4mri/model/mrinets.py
497
498
499
500
501
502
503
504
505
506
507
def pretrained_models(self):
    """
    Return the result of `pretrained_models()` of the model creator stored in this member.

    !!! example "Get information about pretrained `MRInet` models"
        ```python
        OutOfTheBoxModels.MRINET.pretrained_models().show()
        ```
    This is only a convenience shortcut, which spares the caller from accessing the `value` attribute first.
    """
    model_creator = self.value
    return model_creator.pretrained_models()

reference 🧠

reference()

Get the reference for the model.

Get the reference for MRInet

OutOfTheBoxModels.MRINET.reference()

Ultimately, this is just a small wrapper avoiding calling the value attribute.

Source code in src/xai4mri/model/mrinets.py
485
486
487
488
489
490
491
492
493
494
495
def reference(self):
    """
    Return the result of `reference()` of the model creator stored in this member.

    !!! example "Get the reference for `MRInet`"
        ```python
        OutOfTheBoxModels.MRINET.reference()
        ```
    This is only a convenience shortcut, which spares the caller from accessing the `value` attribute first.
    """
    model_creator = self.value
    return model_creator.reference()

show classmethod 🧠

show()

Show available models.

Source code in src/xai4mri/model/mrinets.py
514
515
516
517
@classmethod
def show(cls):
    """Print each available model together with its reference, one `(model, reference)` pair per line."""
    overview = []
    for member in cls:
        overview.append((member, member.reference()))
    # `print` returns `None`, hence, so does this method
    print(*overview, sep="\n")

analyze_model 🧠

analyze_model(
    model: Model,
    ipt: ndarray,
    norm: bool,
    analyzer_type: str = "lrp.sequential_preset_a",
    neuron_selection: int | None = None,
    **kwargs
) -> ndarray

Analyze the prediction of a model with respect to a given input.

Produce an analyzer map ('heatmap') for a given model and input image. The heatmap indicates the relevance of each pixel w.r.t. the model's prediction.

Parameters:

Name Type Description Default
model Model

Deep learning model.

required
ipt ndarray

Input image to model, shape: [batch_size: = 1, x, y, z, channels: = 1].

required
norm bool

True: normalize the computed analyzer map to [-1, 1].

required
analyzer_type str

Type of model analyzers [default: "lrp.sequential_preset_a" for ConvNets]. Check documentation of iNNvestigate for different types of analyzers.

'lrp.sequential_preset_a'
neuron_selection int | None

Index of the model's output neuron [int], whose activity is to be analyzed; Or take the 'max_activation' neuron [if None]

None
kwargs

Additional keyword arguments for the innvestigate.create_analyzer() function.

{}

Returns:

Type Description
ndarray

The computed analyzer map.

Source code in src/xai4mri/model/interpreter.py
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
def analyze_model(
    model: keras.Model,
    ipt: np.ndarray,
    norm: bool,
    analyzer_type: str = "lrp.sequential_preset_a",
    neuron_selection: int | None = None,
    **kwargs,
) -> np.ndarray:
    """
    Analyze the prediction of a model with respect to a given input.

    Produce an analyzer map ('heatmap') for a given model and input image.
    The heatmap indicates the relevance of each pixel w.r.t. the model's prediction.

    :param model: Deep learning model.
    :param ipt: Input image to model, shape: `[batch_size: = 1, x, y, z, channels: = 1]`.
    :param norm: True: normalize the computed analyzer map to [-1, 1].
                 A map without any non-zero value is returned unscaled, since it cannot be normalized.
    :param analyzer_type: Type of model analyzers [default: "lrp.sequential_preset_a" for ConvNets].
                          Check documentation of `iNNvestigate` for different types of analyzers.
    :param neuron_selection: Index of the model's output neuron [int], whose activity is to be analyzed;
                             Or take the 'max_activation' neuron [if `None`]
    :param kwargs: Additional keyword arguments for the `innvestigate.create_analyzer()` function.
                   `disable_model_checks` defaults to `True` if not given.
    :return: The computed analyzer map.
    """
    # Create analyzer
    disable_model_checks = kwargs.pop("disable_model_checks", True)
    analyzer = innvestigate.create_analyzer(
        name=analyzer_type,
        model=model,
        disable_model_checks=disable_model_checks,
        neuron_selection_mode="index" if isinstance(neuron_selection, int) else "max_activation",
        **kwargs,
    )

    # Apply analyzer w.r.t. the selected output-neuron (the maximally activated one if no index is given)
    a = analyzer.analyze(ipt, neuron_selection=neuron_selection)

    if norm:
        # Normalize between [-1, 1]
        max_abs = np.max(np.abs(a))
        # Dividing an all-zero map by its maximum (0) would turn every value into NaN, so skip the scaling then
        if max_abs > 0:
            a /= max_abs

    return a

get_model 🧠

get_model(
    model_type: OutOfTheBoxModels | str, **kwargs
) -> Sequential | None

Get a freshly initiated out-of-the-box model.

Example of how to get an MRInet model

mrinet = get_model(
             model_type=OutOfTheBoxModels.MRINET,
             name="MyMRInet",
             n_classes=False,
             input_shape=(91, 109, 91),
             )

# Alternatively, get the model by string
sfcn = get_model(
           model_type="sfcn",
           name="MySFCN",
           n_classes=40,
           input_shape=(160, 192, 160),
           )

This is a wrapper for the create method of the model creator classes.

Parameters:

Name Type Description Default
model_type OutOfTheBoxModels | str

model of type OutOfTheBoxModels or string (e.g., 'mrinet')

required
kwargs

keyword arguments for model creation

{}
Source code in src/xai4mri/model/mrinets.py
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
def get_model(model_type: OutOfTheBoxModels | str, **kwargs) -> keras.Sequential | None:
    """
    Get a freshly initiated out-of-the-box model.

    !!! example "Example of how to get an `MRInet` model"
        ```python
        mrinet = get_model(
                     model_type=OutOfTheBoxModels.MRINET,
                     name="MyMRInet",
                     n_classes=False,
                     input_shape=(91, 109, 91),
                     )

        # Alternatively, get the model by string
        sfcn = get_model(
                   model_type="sfcn",
                   name="MySFCN",
                   n_classes=40,
                   input_shape=(160, 192, 160),
                   )
        ```

    This is a wrapper for the `create` method of the model creator classes.

    :param model_type: model of type `OutOfTheBoxModels` or `string` (e.g., 'mrinet')
    :param kwargs: keyword arguments for model creation, passed on to the `create` method of the model creator
    :return: the created model, or `None` if the resolved model type is neither `MRINET` nor `SFCN`
    """
    # Resolve the requested model type once; this accepts both enum members and their string form
    selected_model = OutOfTheBoxModels(model_type)
    if selected_model == OutOfTheBoxModels.MRINET:
        return MRInet.create(**kwargs)
    if selected_model == OutOfTheBoxModels.SFCN:
        return SFCN.create(**kwargs)
    return None

mono_phase_model_training 🧠

mono_phase_model_training(
    model: Sequential,
    epochs: int,
    data: BaseDataSet,
    target: str,
    model_parent_path: str | Path,
    split_dict: dict | None = None,
    callbacks: list[Callback] | None = None,
    **kwargs
) -> Sequential | None

Train / finetune all model weights at once.

This simply trains the model on the provided dataset for the given number of epochs.

When used for transfer learning

This is a naive approach to transfer learning, and can lead to issues such as catastrophic forgetting.

Parameters:

Name Type Description Default
model Sequential

Compiled Keras model to be trained on the provided dataset.

required
epochs int

Number of training epochs.

required
data BaseDataSet

Dataset for training and evaluation. This must be a subclass of BaseDataSet.

required
target str

Variable to be predicted. This must be in the study_table of the dataset (data).

required
model_parent_path str | Path

The path to the parent folder of the given model, where the model will be saved.

required
split_dict dict | None

Data split dictionary for training, validation, and test data.

None
callbacks list[Callback] | None

A list of Keras callbacks (except ModelCheckpoint).

None
kwargs

Additional keyword arguments for data.create_data_split().

{}

Returns:

Type Description
Sequential | None

Trained model.

Source code in src/xai4mri/model/transfer.py
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
def mono_phase_model_training(
    model: keras.Sequential,
    epochs: int,
    data: BaseDataSet,
    target: str,
    model_parent_path: str | Path,
    split_dict: dict | None = None,
    callbacks: list[keras.callbacks.Callback] | None = None,
    **kwargs,
) -> keras.Sequential | None:
    """
    Train / finetune all model weights at once.

    This simply trains the model on the provided dataset for the given number of epochs.
    After training, the data split dictionary and the model are saved in the model folder
    (`model_parent_path / model.name`), and the model is evaluated on the test data.

    !!! note "When used for transfer learning"
        This is a naive approach to transfer learning, and can lead to issues such as catastrophic forgetting.

    :param model: Compiled `Keras` model to be trained on the provided dataset.
    :param epochs: Number of training epochs.
    :param data: Dataset for training and evaluation. This must be a subclass of `BaseDataSet`.
    :param target: Variable to be predicted. This must be in the `study_table` of the dataset (`data`).
    :param model_parent_path: The path to the parent folder of the given model, where the model will be saved.
    :param split_dict: Data split dictionary for training, validation, and test data.
    :param callbacks: A list of `Keras`'s `callbacks`.
                      If it contains no `ModelCheckpoint`, one is added, which saves checkpoints in the
                      'checkpoints' folder of the model. The given list itself is not modified.
    :param kwargs: Additional keyword arguments for `data.create_data_split()`.
                   Currently, only `batch_size` [default: 1] and `split_ratio` [default: (0.8, 0.1, 0.1)]
                   are passed on.
    :return: Trained model, or `None` if training is skipped since checkpoints of the model exist already.
    """
    # Set path to model and split dictionary
    path_to_model = Path(model_parent_path) / model.name
    path_to_checkpoints = path_to_model / "checkpoints"
    path_to_split_dict = path_to_model / f"{model.name}_split_dict"

    # Check if model has been trained already (i.e., whether there is anything in the checkpoint folder)
    if list(path_to_checkpoints.glob("*")):
        cprint(string="Model has been trained already. Skipping training.", col="y")
        return None

    # Create data splits
    split_dict, train_data_gen, val_data_gen, test_data_gen = data.create_data_split(
        target=target,
        batch_size=kwargs.pop("batch_size", 1),
        split_ratio=kwargs.pop("split_ratio", (0.8, 0.1, 0.1)),
        split_dict=split_dict,
    )

    # Define callbacks: work on a copy, such that the list of the caller is not mutated by the append below
    callbacks = [] if callbacks is None else list(callbacks)
    if not any(isinstance(c, keras.callbacks.ModelCheckpoint) for c in callbacks):
        # Add a default checkpoint callback only if the caller did not provide one
        callbacks.append(
            keras.callbacks.ModelCheckpoint(
                filepath=path_to_checkpoints / "cp-{epoch:04d}.ckpt",
                save_weights_only=True,  # TODO: revisit  # noqa: FIX002
                monitor="val_loss",
                mode="auto",
                save_best_only=True,  # TODO: revisit  # noqa: FIX002
                save_freq="epoch",
            )
        )

    # Train model
    model.fit(
        x=train_data_gen,
        epochs=epochs,
        validation_data=val_data_gen,
        callbacks=callbacks,
    )

    # Save split dictionary and model
    data.save_split_dict(save_path=path_to_split_dict)
    model.save(path_to_model / f"{model.name}.h5")

    # Model evaluation
    model.evaluate(test_data_gen)

    return model