base
BaseLitModule & its config.
- class cneuromax.fitting.deeplearning.litmodule.base.BaseLitModuleConfig(log_val_wandb=False)
Bases: object
Holds BaseLitModule config values.
- Parameters:
log_val_wandb (bool, default: False) – Whether to log validation data to W&B.
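As a rough illustration, the config can be pictured as a plain dataclass with a single boolean field. This is a stand-in, not the cneuromax source: only the field name and its default come from the signature above, and the dataclass form is an assumption.

```python
from dataclasses import dataclass


@dataclass
class BaseLitModuleConfig:
    """Stand-in mirroring the documented signature (assumed dataclass, not the actual source)."""

    # Whether to log validation data to W&B; off by default.
    log_val_wandb: bool = False


default_config = BaseLitModuleConfig()
wandb_config = BaseLitModuleConfig(log_val_wandb=True)
```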
- class cneuromax.fitting.deeplearning.litmodule.base.BaseLitModule(config, nnmodule, optimizer, scheduler)
Bases: WandbValLoggingLightningModule, ABC
Base lightning.pytorch.core.LightningModule.
Subclasses need to implement the step() method, which takes both data and stage arguments and returns the loss value(s) as a torch.Tensor.
Example definition:
    def step(
        self: "BaseClassificationLitModule",
        data: tuple[
            Float[Tensor, " batch_size *x_dim"],
            Int[Tensor, " batch_size"],
        ],
        stage: An[str, one_of("train", "val", "test")],
    ) -> Float[Tensor, " "]:
        ...
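The step() contract can be sketched without Lightning or torch: an abstract base declares step(data, stage), and a concrete subclass returns a loss. The names ToyBaseLitModule and ToyRegressionLitModule are hypothetical, plain floats stand in for torch.Tensor, and the real base class also derives from WandbValLoggingLightningModule rather than only ABC.

```python
from abc import ABC, abstractmethod
from typing import Literal

Stage = Literal["train", "val", "test"]


class ToyBaseLitModule(ABC):
    """Hypothetical stand-in for BaseLitModule's step() contract."""

    @abstractmethod
    def step(self, data: tuple[list[float], list[float]], stage: Stage) -> float:
        """Return the loss for one batch of ``data`` at ``stage``."""


class ToyRegressionLitModule(ToyBaseLitModule):
    """Hypothetical subclass: mean squared error of an identity "model"."""

    def step(self, data: tuple[list[float], list[float]], stage: Stage) -> float:
        x, y = data
        # The identity model predicts x; the loss is the mean squared error.
        return sum((xi - yi) ** 2 for xi, yi in zip(x, y)) / len(x)


module = ToyRegressionLitModule()
loss = module.step(([1.0, 2.0], [1.0, 4.0]), stage="train")
```

Because step is abstract, the base class itself cannot be instantiated; only subclasses that define step() can.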
Note
The data and loss value(s) type hints in this class are not rendered properly in the documentation due to an incompatibility between Sphinx and jaxtyping. Refer to the source code to find the correct types.
- Parameters:
nnmodule (Module) – A torch nn.Module to be used by this instance.
optimizer (partial[Optimizer]) – A torch Optimizer to be used by this instance. It is passed as a partial because the nnmodule parameters are required for its initialization.
scheduler (partial[LRScheduler]) – A torch LRScheduler to be used by this instance. It is passed as a partial because the optimizer is required for its initialization.
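The reason for the partials is an ordering constraint: the optimizer cannot be built until the module's parameters exist, and the scheduler cannot be built until the optimizer exists. The sketch below mimics that two-stage construction with functools.partial. ToyOptimizer and ToyScheduler are hypothetical stand-ins for torch.optim.Optimizer and torch.optim.lr_scheduler.LRScheduler; the params and optimizer keyword names follow torch's conventions, and the lr/gamma hyperparameters are illustrative.

```python
from functools import partial


class ToyOptimizer:
    """Hypothetical stand-in for a torch Optimizer."""

    def __init__(self, params: list[float], lr: float) -> None:
        self.params = params
        self.lr = lr


class ToyScheduler:
    """Hypothetical stand-in for a torch LRScheduler."""

    def __init__(self, optimizer: ToyOptimizer, gamma: float) -> None:
        self.optimizer = optimizer
        self.gamma = gamma

    def step(self) -> None:
        # Decay the wrapped optimizer's learning rate.
        self.optimizer.lr *= self.gamma


# Hyperparameters are fixed up front (e.g. from a config)...
optimizer_partial = partial(ToyOptimizer, lr=0.1)
scheduler_partial = partial(ToyScheduler, gamma=0.5)

# ...and the missing dependencies are supplied later, in order.
nnmodule_params = [0.0, 1.0, 2.0]  # stands in for nnmodule.parameters()
optimizer = optimizer_partial(params=nnmodule_params)
scheduler = scheduler_partial(optimizer=optimizer)
scheduler.step()
```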
- optimizer_partial
See optimizer.
- Type: partial[torch.optim.Optimizer]
- scheduler_partial
See scheduler.
- Type: partial[torch.optim.lr_scheduler.LRScheduler]
- Raises:
NotImplementedError – If the step() method is not defined or not callable.
- final stage_step(data, stage)
Generic stage wrapper around the step() method.
Verifies that the step() method exists and is callable, calls it, and logs the loss value(s).
- Parameters:
data – The batch data, forwarded to step().
stage – The current stage: "train", "val" or "test".
- Return type: Num[Tensor, '*_']
- Returns: The loss value(s).
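The verify, call, log sequence of stage_step() can be approximated as below, which also shows where the NotImplementedError listed under Raises would come from. This is a sketch, not the cneuromax source: the class names, the log() helper and the "{stage}/loss" metric key are hypothetical, and a plain float stands in for the loss tensor.

```python
from typing import Any


class ToyLitModule:
    """Hypothetical stand-in for BaseLitModule.stage_step()."""

    def __init__(self) -> None:
        self.logged: dict[str, float] = {}

    def log(self, name: str, value: float) -> None:
        # Stand-in for LightningModule.log(): just record the value.
        self.logged[name] = value

    def stage_step(self, data: Any, stage: str) -> float:
        step = getattr(self, "step", None)
        # Verify that a subclass actually provides a callable step().
        if not callable(step):
            raise NotImplementedError("Subclasses must implement `step()`.")
        loss = step(data, stage)
        self.log(name=f"{stage}/loss", value=loss)
        return loss


class ToySumLitModule(ToyLitModule):
    """Hypothetical subclass whose "loss" is the sum of the batch."""

    def step(self, data: list[float], stage: str) -> float:
        return sum(data)


module = ToySumLitModule()
loss = module.stage_step([1.0, 2.0, 3.0], stage="val")

missing_step_raised = False
try:
    ToyLitModule().stage_step([1.0], stage="train")
except NotImplementedError:
    missing_step_raised = True
```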
- final training_step(data)
Calls stage_step() with argument stage="train".
- final validation_step(data, *args, **kwargs)
Calls stage_step() with argument stage="val".
- final test_step(data)
Calls stage_step() with argument stage="test".
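The three hooks are thin wrappers that only fix the stage argument. In the minimal sketch below, the hypothetical ToyStageHooks.stage_step() simply echoes its inputs instead of computing a loss. The extra *args/**kwargs on validation_step() are assumed here to absorb additional arguments Lightning may pass (such as indices) and are ignored.

```python
from typing import Any


class ToyStageHooks:
    """Hypothetical stand-in showing how the stage hooks forward to stage_step()."""

    def stage_step(self, data: Any, stage: str) -> tuple[Any, str]:
        # Stand-in: echo the inputs instead of computing a loss.
        return data, stage

    def training_step(self, data: Any) -> tuple[Any, str]:
        return self.stage_step(data=data, stage="train")

    def validation_step(self, data: Any, *args: Any, **kwargs: Any) -> tuple[Any, str]:
        # Extra arguments (assumed: indices passed by Lightning) are ignored.
        return self.stage_step(data=data, stage="val")

    def test_step(self, data: Any) -> tuple[Any, str]:
        return self.stage_step(data=data, stage="test")


hooks = ToyStageHooks()
train_out = hooks.training_step("batch")
val_out = hooks.validation_step("batch", 0, dataloader_idx=1)
test_out = hooks.test_step("batch")
```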
- configure_optimizers()
Returns a dict with this instance's optimizer and scheduler.
- Return type: OptimizerLRSchedulerConfig
- Returns: This instance's Optimizer and torch.optim.lr_scheduler.LRScheduler.
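Lightning's OptimizerLRSchedulerConfig is a TypedDict with an "optimizer" and an "lr_scheduler" key. The sketch below mirrors that shape using stdlib types only. ToyOptimizer, ToyScheduler and ToyModule are hypothetical, and building the optimizer and scheduler from the stored partials inside configure_optimizers() is an assumption about where construction happens, not a statement about the cneuromax source.

```python
from functools import partial
from typing import Any, TypedDict


class OptimizerLRSchedulerConfig(TypedDict):
    """Stdlib mirror of the two keys used here (not Lightning's own class)."""

    optimizer: Any
    lr_scheduler: Any


class ToyOptimizer:
    def __init__(self, params: list[float], lr: float) -> None:
        self.params, self.lr = params, lr


class ToyScheduler:
    def __init__(self, optimizer: ToyOptimizer) -> None:
        self.optimizer = optimizer


class ToyModule:
    """Hypothetical module holding partials, like optimizer_partial/scheduler_partial."""

    def __init__(self) -> None:
        self.params = [0.5, -0.5]
        self.optimizer_partial = partial(ToyOptimizer, lr=1e-3)
        self.scheduler_partial = partial(ToyScheduler)

    def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
        # Build the optimizer first, then the scheduler that wraps it.
        optimizer = self.optimizer_partial(params=self.params)
        scheduler = self.scheduler_partial(optimizer=optimizer)
        return {"optimizer": optimizer, "lr_scheduler": scheduler}


config = ToyModule().configure_optimizers()
```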