mrinets
Train & load deep learning models for MRI-based predictions.
This module provides a set of out-of-the-box models for MRI-based predictions. The models are designed for 3D MRI data and can be used for regression or classification tasks. Moreover, there are pretrained models available for some model architectures.
Models are built on top of the TensorFlow Keras API. This is necessary to ensure compatibility with iNNvestigate for model interpretability.
Author: Simon M. Hofmann
Years: 2022-2024
MRInet
Bases: _ModelCreator
`MRInet` model creator.
`MRInet` is the base model architecture (3D-ConvNet) used within the multi-level ensembles in Hofmann et al. (2022, NeuroImage). The models were trained to predict age from MR images of different sequences (T1, FLAIR, SWI) in the LIFE Adult study.
This model creator class provides a hard-coded Keras implementation of the 3D-CNN model. Creating a model with this class is as simple as calling the `create` method, which returns a fresh (i.e., untrained) and compiled Keras model ready to be trained:
```python
# Create a new instance of MRInet
mrinet = MRInet.create(name="MyMRInet", n_classes=False, input_shape=(91, 109, 91))
# Train on MRI dataset
mrinet.fit(X_train, y_train, ...)
```
Pretrained models are available for T1, FLAIR, and SWI images.
Get an overview of pretrained models by calling:
```python
MRInet.pretrained_models().show()
```
create (staticmethod)
create(
name: str,
n_classes: bool | None | int,
input_shape: tuple[int, int, int],
learning_rate: float = 0.0005,
target_bias: float | None = None,
leaky_relu: bool = False,
batch_norm: bool = False,
) -> Sequential
Create a new instance of `MRInet`, a 3D convolutional neural network (CNN) for predictions on MRIs.
This is a hard-coded Keras implementation of the 3D-CNN model reported in Hofmann et al. (2022, NeuroImage).
The model can be trained for regression (`n_classes=False`) or classification (`n_classes` an `int` >= 2) tasks.
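Because `n_classes` accepts a `bool`, `None`, or an `int`, code that branches on it has to be careful: in Python, `bool` is a subclass of `int`, so `isinstance(False, int)` is `True`. The helper below is a hypothetical sketch (not part of xai4mri) of how such a value could be mapped to an output layer configuration. Treating `None` like `False` and the linear/MSE versus softmax/cross-entropy choices are assumed conventions, not necessarily what `MRInet.create` does internally.

```python
from __future__ import annotations


def output_head(n_classes: bool | None | int) -> tuple[int, str, str]:
    """Map an n_classes value to (units, activation, loss).

    Hypothetical helper for illustration only; not part of xai4mri.
    """
    # Check bool before int: in Python, isinstance(False, int) is True.
    if n_classes is None or isinstance(n_classes, bool):
        if n_classes:
            # True is not a meaningful number of classes.
            raise ValueError("n_classes=True is ambiguous; use False/None or an int >= 2")
        # Regression (assumed convention): one linear output unit, mean squared error.
        return 1, "linear", "mean_squared_error"
    if isinstance(n_classes, int) and n_classes >= 2:
        # Classification (assumed convention): one softmax unit per class.
        return n_classes, "softmax", "categorical_crossentropy"
    raise ValueError(f"n_classes must be False/None or an int >= 2, got {n_classes!r}")


print(output_head(False))  # (1, 'linear', 'mean_squared_error')
print(output_head(4))      # (4, 'softmax', 'categorical_crossentropy')
```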
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | `str` | Model name, which, for instance, could refer to the project it is applied for. | required |
target_bias | `float \| None` | Model output bias. For classification tasks, this can be left blank […] | `None` |
input_shape | `tuple[int, int, int]` | Shape of the input to the model. This should be the shape of a single MRI (e.g., `(91, 109, 91)`). | required |
learning_rate | `float` | Learning rate used for the model's optimizer (here, […]) | `0.0005` |
batch_norm | `bool` | Use batch normalization or not. Batch normalization should only be used if the model is fed with larger batches. For model interpretability provided by […] | `False` |
leaky_relu | `bool` | Use leaky or vanilla ReLU activation functions. Leaky ReLU is recommended for better performance. However, […] | `False` |
n_classes | `bool \| None \| int` | Number of classes. For regression tasks, set to `False`; for classification tasks, set to an `int` >= 2. | required |
Returns:
Type | Description |
---|---|
`Sequential` | Compiled, untrained Keras model. |
Source code in src/xai4mri/model/mrinets.py
pretrained_models (staticmethod)
pretrained_models() -> type[PretrainedModelsMRInet]
Return an enum of pretrained `MRInet` models.
Call its `show` method to get an overview of available pretrained models:
```python
MRInet.pretrained_models().show()
```
reference (staticmethod)
reference() -> str
Get the reference for the `MRInet` model.
OutOfTheBoxModels
Bases: Enum
Enum class for out-of-the-box models for MRI-based predictions.
Call the `show` method to get an overview of available models:
```python
OutOfTheBoxModels.show()
```
```python
# Get a model by selecting it
mrinet_creator = OutOfTheBoxModels.MRINET.value
mrinet = mrinet_creator.create(...)
# Alternatively, get the model by string
sfcn = OutOfTheBoxModels("sfcn").value.create(...)
```
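The two access patterns above (select a member, then use its `.value` creator; or select by string) can be illustrated with a minimal, dependency-free sketch. `ToyMRInet`, `ToySFCN`, and `ToyModels` are hypothetical stand-ins, and resolving strings through the standard `Enum._missing_` hook with a case-insensitive name match is an assumption about how string lookup could work, not necessarily how xai4mri implements it.

```python
from __future__ import annotations

from enum import Enum


class ToyMRInet:
    """Hypothetical stand-in for a model creator class (not xai4mri's MRInet)."""

    @staticmethod
    def create(name: str, **kwargs) -> dict:
        # A real creator would return a compiled Keras model; a dict keeps this dependency-free.
        return {"architecture": "mrinet", "name": name, **kwargs}

    @staticmethod
    def reference() -> str:
        return "Hofmann et al. (2022, NeuroImage)"


class ToySFCN:
    """Hypothetical stand-in for a model creator class (not xai4mri's SFCN)."""

    @staticmethod
    def create(name: str, **kwargs) -> dict:
        return {"architecture": "sfcn", "name": name, **kwargs}

    @staticmethod
    def reference() -> str:
        return "Peng et al. (2021, Medical Image Analysis)"


class ToyModels(Enum):
    """Enum whose member values are model creator classes."""

    MRINET = ToyMRInet
    SFCN = ToySFCN

    @classmethod
    def _missing_(cls, value):
        # Called when ToyModels(value) finds no member with that value:
        # fall back to a case-insensitive lookup by member name.
        if isinstance(value, str):
            return cls.__members__.get(value.upper())
        return None

    def reference(self) -> str:
        # Thin wrapper so callers need not access `.value` themselves.
        return self.value.reference()

    @classmethod
    def show(cls) -> None:
        for member in cls:
            print(member.name.lower())


creator = ToyModels.MRINET.value
model = creator.create(name="MyMRInet", n_classes=False)
print(ToyModels("sfcn") is ToyModels.SFCN)  # True
```

Returning `None` from `_missing_` makes the enum raise a `ValueError` for unknown strings, which keeps typos from passing silently.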
default (classmethod)
default()
Get the default model, which is `MRInet`.
pretrained_models
pretrained_models()
Get the pretrained models for the selected model.
```python
# Get information about pretrained MRInet models
OutOfTheBoxModels.MRINET.pretrained_models().show()
```
This is a thin wrapper that saves you from accessing the `value` attribute.
reference
reference()
Get the reference for the selected model.
```python
# Get the reference for MRInet
OutOfTheBoxModels.MRINET.reference()
```
This is a thin wrapper that saves you from accessing the `value` attribute.
show (classmethod)
show()
Show available models.
PretrainedMRInetFLAIR (dataclass)
PretrainedMRInetFLAIR(
name: str = "flair_model",
url: str = "https://keeper.mpdl.mpg.de/f/8481f5906f3d4192ab12/?dl=1",
)
Bases: _PretrainedMRInet
Pretrained `MRInet` model for FLAIR images.
This is a base model in the FLAIR sub-ensemble as reported in Hofmann et al. (2022, NeuroImage). The training data stem from the LIFE Adult study; they were normalized and registered to the T1w FreeSurfer file (`brain.finalsurfs.mgz`) and subsequently pruned to (198, 198, 198) voxels (for more details, see Hofmann et al., 2022, NeuroImage).
load_model (classmethod)
load_model(
parent_folder: str | Path,
custom_objects: T | None = None,
compile_model: bool = False,
) -> Sequential
Load a pretrained `MRInet` model.
If the model is not present on the local machine, it will be downloaded from the server and saved in the provided parent folder.
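The download-if-missing behavior of `load_model` follows a common caching pattern, sketched below with the standard library only. `ensure_local_model` and its `fetch` parameter are hypothetical; the on-disk layout (`parent_folder / name`) and the temporary `.part` file are assumptions, not xai4mri's actual file handling. The `fetch` parameter is injectable so the logic can be exercised without network access.

```python
from __future__ import annotations

import shutil
import urllib.request
from pathlib import Path
from typing import Callable


def _download(url: str, target: Path) -> None:
    # Default fetcher: stream the HTTP response body into `target`.
    with urllib.request.urlopen(url) as response, open(target, "wb") as fh:
        shutil.copyfileobj(response, fh)


def ensure_local_model(
    name: str,
    url: str,
    parent_folder: str | Path,
    fetch: Callable[[str, Path], None] = _download,
) -> Path:
    """Return the local path of a model file, downloading it only if it is missing.

    Hypothetical helper: the file name (parent_folder / name) and the
    temporary ".part" file are assumptions, not xai4mri's actual layout.
    """
    parent = Path(parent_folder)
    parent.mkdir(parents=True, exist_ok=True)
    target = parent / name
    if not target.exists():
        partial = parent / f"{name}.part"
        fetch(url, partial)      # download to a temporary file first ...
        partial.replace(target)  # ... then rename, so an aborted download never looks complete
    return target
```

With the defaults of `PretrainedMRInetFLAIR`, a call would look like `ensure_local_model("flair_model", "https://keeper.mpdl.mpg.de/f/8481f5906f3d4192ab12/?dl=1", "models")`; a second call returns the cached path without downloading. If the downloaded file is in a format Keras can read directly (not documented here), it could then be loaded with `tf.keras.models.load_model(path, custom_objects=..., compile=...)`.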
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parent_folder | `str \| Path` | Path to the parent folder of the model. | required |
custom_objects | `T \| None` | Custom objects to load (e.g., loss functions). | `None` |
compile_model | `bool` | Compile the model or not. | `False` |
Returns:
Type | Description |
---|---|
`Sequential` | Pretrained `MRInet` model. |
PretrainedMRInetSWI (dataclass)
PretrainedMRInetSWI(
name: str = "swi_model",
url: str = "https://keeper.mpdl.mpg.de/f/f53d43b723274687b6e2/?dl=1",
)
Bases: _PretrainedMRInet
Pretrained `MRInet` model for SWI images.
This is a base model in the SWI sub-ensemble as reported in Hofmann et al. (2022, NeuroImage). The training data stem from the LIFE Adult study; they were normalized and registered to the T1w FreeSurfer file (`brain.finalsurfs.mgz`) and subsequently pruned to (198, 198, 198) voxels (for more details, see Hofmann et al., 2022, NeuroImage).
load_model (classmethod)
load_model(
parent_folder: str | Path,
custom_objects: T | None = None,
compile_model: bool = False,
) -> Sequential
Load a pretrained `MRInet` model.
If the model is not present on the local machine, it will be downloaded from the server and saved in the provided parent folder.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parent_folder | `str \| Path` | Path to the parent folder of the model. | required |
custom_objects | `T \| None` | Custom objects to load (e.g., loss functions). | `None` |
compile_model | `bool` | Compile the model or not. | `False` |
Returns:
Type | Description |
---|---|
`Sequential` | Pretrained `MRInet` model. |
PretrainedMRInetT1 (dataclass)
PretrainedMRInetT1(
name: str = "t1_model",
url: str = "https://keeper.mpdl.mpg.de/f/3be6fed59b4948aca699/?dl=1",
)
Bases: _PretrainedMRInet
Pretrained `MRInet` model for T1-weighted images.
This is a base model in the T1 sub-ensemble as reported in Hofmann et al. (2022, NeuroImage). The training data stem from the LIFE Adult study; they were preprocessed using FreeSurfer (`brain.finalsurfs.mgz`) and subsequently pruned to (198, 198, 198) voxels (for more details, see Hofmann et al., 2022, NeuroImage).
load_model (classmethod)
load_model(
parent_folder: str | Path,
custom_objects: T | None = None,
compile_model: bool = False,
) -> Sequential
Load a pretrained `MRInet` model.
If the model is not present on the local machine, it will be downloaded from the server and saved in the provided parent folder.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
parent_folder | `str \| Path` | Path to the parent folder of the model. | required |
custom_objects | `T \| None` | Custom objects to load (e.g., loss functions). | `None` |
compile_model | `bool` | Compile the model or not. | `False` |
Returns:
Type | Description |
---|---|
`Sequential` | Pretrained `MRInet` model. |
PretrainedModelsMRInet
Bases: Enum
Enum class for pretrained `MRInet` models (T1, FLAIR, SWI).
Call the `show` method to get an overview of available models:
```python
PretrainedModelsMRInet.show()
```
```python
# Get a model by selecting it
t1_mrinet = PretrainedModelsMRInet.T1_MODEL.value.load_model(...)
# Alternatively, get the model by string
swi_mrinet = PretrainedModelsMRInet("swi_model").value.load_model(...)
```
show (classmethod)
show()
Show available pretrained `MRInet` models.
SFCN
Bases: _ModelCreator
`SFCN` model creator.
`SFCN` is the simple fully convolutional neural network (SFCN) model by Peng et al. (2021, Medical Image Analysis) for age prediction from MRI data of the UK Biobank. The architecture won first place in the brain-age competition PAC 2019.
create (staticmethod)
create(
name: str,
n_classes: int,
input_shape: tuple[int, int, int] = (160, 192, 160),
learning_rate: float = 0.01,
dropout: bool = True,
) -> Sequential
Create the simple fully convolutional neural network (SFCN) model.
The SFCN was introduced in Peng et al. (2021, Medical Image Analysis). The original open-source implementation is written in PyTorch and can be found at: https://github.com/ha-ha-ha-han/UKBiobank_deep_pretrain
Training details from Peng et al. (2021, p. 4):
"The L2 weight decay coefficient was 0.001. The batch size was 8. The learning rate for the SGD optimiser was initialized as 0.01, then multiplied by 0.3 every 30 epochs unless otherwise specified. The total epoch number is 130 for the 12,949 training subjects. The epoch number is adjusted accordingly for the experiments with smaller training sets so that the training steps are roughly the same."
Use this version of the `SFCN` model with caution:
- The implementation is still experimental, has not yet been tested properly, and might be updated in the future.
- The authors used "Gaussian soft labels" (sigma=1, mean=*true age*) for their loss function. This is not implemented here yet and might require additional adjustments to the `xai4mri.dataloader` module.
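To make the "Gaussian soft labels" mentioned above concrete: instead of a one-hot age class, the target is a normal distribution (sigma = 1) centered on the true age, discretized into the age bins. The function below is a hypothetical sketch, not part of xai4mri, and it discretizes by integrating the normal density over each bin, which is one reasonable choice rather than a confirmed detail of Peng et al.'s code. The bin range 42 to 82 (40 one-year bins) in the usage line is an assumption chosen to match the 40 age classes mentioned for `n_classes`.

```python
from __future__ import annotations

import math


def gaussian_soft_labels(age: float, bin_start: float, bin_end: float,
                         bin_step: float = 1.0, sigma: float = 1.0) -> list[float]:
    """Discretize a normal distribution N(age, sigma**2) into age bins.

    Bin i covers [bin_start + i * bin_step, bin_start + (i + 1) * bin_step) and
    receives the probability mass of the distribution inside that interval.
    The vector is renormalized so that it sums to 1.
    """
    n_bins = round((bin_end - bin_start) / bin_step)

    def cdf(x: float) -> float:
        # Cumulative distribution function of N(age, sigma**2) via the error function.
        return 0.5 * (1.0 + math.erf((x - age) / (sigma * math.sqrt(2.0))))

    mass = [cdf(bin_start + (i + 1) * bin_step) - cdf(bin_start + i * bin_step)
            for i in range(n_bins)]
    total = sum(mass)
    if total == 0.0:
        raise ValueError("age lies too far outside the bin range")
    return [m / total for m in mass]


labels = gaussian_soft_labels(60.3, bin_start=42, bin_end=82)  # 40 one-year bins (assumed range)
print(len(labels), labels.index(max(labels)))  # 40 18
```

The most probable bin is [60, 61), index 60 - 42 = 18, while neighboring bins keep non-zero mass; that spread is what distinguishes soft labels from one-hot targets.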
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name | `str` | Model name, which, for instance, could refer to the project it is applied for. | required |
input_shape | `tuple[int, int, int]` | Shape of the input to the model. This should be the shape of a single MRI. | `(160, 192, 160)` |
n_classes | `int` | Number of classes. In Peng et al. (2021), there were 40 age classes, representing 40 age strata. | required |
learning_rate | `float` | Learning rate used for the model's optimizer. | `0.01` |
dropout | `bool` | Use dropout or not. | `True` |
Returns:
Type | Description |
---|---|
`Sequential` | Compiled Keras model. |
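The training schedule that Peng et al. describe in the quote in the `create` documentation (start at 0.01, multiply by 0.3 every 30 epochs, and scale the epoch count so the total number of training steps stays roughly constant) can be written as two small functions. This is a sketch derived only from that quoted text; the function names are hypothetical, and it is not claimed that `SFCN.create` configures such a schedule. The extra `lr` argument lets `sfcn_learning_rate` be passed to `tf.keras.callbacks.LearningRateScheduler`, which calls `schedule(epoch, lr)` with a 0-indexed epoch.

```python
from __future__ import annotations

import math


def sfcn_learning_rate(epoch: int, lr: float | None = None, initial_lr: float = 0.01,
                       factor: float = 0.3, step: int = 30) -> float:
    """Step decay as quoted from Peng et al. (2021): 0.01, multiplied by 0.3 every 30 epochs.

    `lr` is ignored; it only exists so the function matches the (epoch, lr)
    signature expected by tf.keras.callbacks.LearningRateScheduler.
    Epochs are 0-indexed.
    """
    return initial_lr * factor ** (epoch // step)


def adjusted_epochs(n_train: int, batch_size: int = 8,
                    ref_n_train: int = 12_949, ref_epochs: int = 130) -> int:
    """Scale the number of epochs so the total number of training steps stays about the same."""
    ref_steps = math.ceil(ref_n_train / batch_size) * ref_epochs
    steps_per_epoch = math.ceil(n_train / batch_size)
    return max(1, round(ref_steps / steps_per_epoch))


print([round(sfcn_learning_rate(e), 6) for e in (0, 29, 30, 60)])  # [0.01, 0.01, 0.003, 0.0009]
print(adjusted_epochs(12_949), adjusted_epochs(2_000))  # 130 842
```

The L2 weight decay of 0.001 from the same quote is not covered here; in Keras it would be configured on the layers or the optimizer, not in the schedule.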
pretrained_models (staticmethod)
pretrained_models() -> None
Return an enum of pretrained `SFCN` models.
Note: There are no pretrained models available for `SFCN` yet.
reference (staticmethod)
reference() -> str
Get the reference for the `SFCN` model.
get_model
get_model(
model_type: OutOfTheBoxModels | str, **kwargs
) -> Sequential | None
Get a freshly initialized out-of-the-box model.
```python
# Example: get an MRInet model
mrinet = get_model(
    model_type=OutOfTheBoxModels.MRINET,
    name="MyMRInet",
    n_classes=False,
    input_shape=(91, 109, 91),
)
# Alternatively, get the model by string
sfcn = get_model(
    model_type="sfcn",
    name="MySFCN",
    n_classes=40,
    input_shape=(160, 192, 160),
)
```
This is a wrapper for the `create` method of the model creator classes.
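As a wrapper, `get_model` essentially resolves the model type and forwards all remaining keyword arguments to the selected creator. The sketch below illustrates that dispatch with hypothetical stand-ins (`ModelType`, `_CREATORS`, `get_model_sketch`) that return plain dicts instead of Keras models. It raises a `ValueError` for unknown strings; how the real `get_model` handles unknown types (its return annotation allows `None`) is not documented here.

```python
from __future__ import annotations

from enum import Enum
from typing import Any, Callable


class ModelType(Enum):
    """Hypothetical stand-in for OutOfTheBoxModels, keyed by lowercase strings."""

    MRINET = "mrinet"
    SFCN = "sfcn"


def _create_mrinet(**kwargs: Any) -> dict[str, Any]:
    # A real creator would build and compile a Keras model.
    return {"architecture": "mrinet", **kwargs}


def _create_sfcn(**kwargs: Any) -> dict[str, Any]:
    return {"architecture": "sfcn", **kwargs}


_CREATORS: dict[ModelType, Callable[..., dict[str, Any]]] = {
    ModelType.MRINET: _create_mrinet,
    ModelType.SFCN: _create_sfcn,
}


def get_model_sketch(model_type: ModelType | str, **kwargs: Any) -> dict[str, Any]:
    """Resolve the model type, then forward all keyword arguments to its creator."""
    if isinstance(model_type, str):
        # Strings are matched case-insensitively against the enum values.
        model_type = ModelType(model_type.lower())  # raises ValueError if unknown
    return _CREATORS[model_type](**kwargs)


mrinet = get_model_sketch(ModelType.MRINET, name="MyMRInet", n_classes=False, input_shape=(91, 109, 91))
sfcn = get_model_sketch("SFCN", name="MySFCN", n_classes=40, input_shape=(160, 192, 160))
print(mrinet["architecture"], sfcn["architecture"])  # mrinet sfcn
```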
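Parameters:
Name | Type | Description | Default |
---|---|---|---|
model_type | `OutOfTheBoxModels \| str` | Model type, given as an `OutOfTheBoxModels` member or as a string (e.g., `"sfcn"`). | required |
kwargs | | Keyword arguments for model creation, passed on to the `create` method of the selected model creator. | `{}` |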