base

BaseLitModule & its config.

class cneuromax.fitting.deeplearning.litmodule.base.BaseLitModuleConfig(log_val_wandb=False)[source]

Bases: object

Holds BaseLitModule config values.

Parameters:

log_val_wandb (bool, default: False) – Whether to log validation data to W&B.
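The sketch below shows how a config object with this single field behaves. It assumes BaseLitModuleConfig is a plain dataclass. The signature and `Bases: object` suggest this, but this page does not confirm it. `LitModuleConfigSketch` is a hypothetical stand-in, not the cneuromax class.

```python
from dataclasses import dataclass


@dataclass
class LitModuleConfigSketch:
    """Hypothetical stand-in mirroring BaseLitModuleConfig's documented field."""

    # Whether to log validation data to W&B (documented default: False).
    log_val_wandb: bool = False


# Default construction leaves W&B validation logging off.
default_config = LitModuleConfigSketch()
# Opting in takes a single keyword argument.
wandb_config = LitModuleConfigSketch(log_val_wandb=True)
```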

class cneuromax.fitting.deeplearning.litmodule.base.BaseLitModule(config, nnmodule, optimizer, scheduler)[source]

Bases: WandbValLoggingLightningModule, ABC

Base lightning.pytorch.core.LightningModule.

Subclasses must implement a step() method that takes data and stage arguments and returns the loss value(s) as a torch.Tensor.

Example definition:

def step(
    self: "BaseClassificationLitModule",
    data: tuple[
        Float[Tensor, " batch_size *x_dim"],
        Int[Tensor, " batch_size"],
    ],
    stage: An[str, one_of("train", "val", "test")],
) -> Float[Tensor, " "]:
    ...

Note

The data and loss value(s) type hints in this class are not rendered properly in the documentation because of an incompatibility between Sphinx and jaxtyping. Refer to the source code, linked next to each method signature, for the correct types.
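The step() contract can be illustrated without PyTorch. This sketch is hypothetical. `ToyRegressionModule`, its scalar `weight`, and the mean-squared-error loss are illustrative assumptions, and plain Python floats stand in for the Float/Int tensors in the signature. The sketch keeps the two documented requirements: step() accepts `data` and `stage`, and it returns the loss.

```python
VALID_STAGES = ("train", "val", "test")


class ToyRegressionModule:
    """Hypothetical subclass-like object showing the step() contract."""

    def __init__(self, weight: float) -> None:
        # A single scalar parameter stands in for an nn.Module.
        self.weight = weight

    def step(
        self,
        data: tuple[list[float], list[float]],
        stage: str,
    ) -> float:
        # Mimic the one_of("train", "val", "test") annotation with a runtime check.
        if stage not in VALID_STAGES:
            raise ValueError(f"stage must be one of {VALID_STAGES}, got {stage!r}")
        x, y = data
        # Predict with y_hat = weight * x, then average the squared errors.
        squared_errors = [(self.weight * xi - yi) ** 2 for xi, yi in zip(x, y)]
        return sum(squared_errors) / len(squared_errors)


module = ToyRegressionModule(weight=2.0)
loss = module.step(([1.0, 2.0], [2.0, 5.0]), stage="train")
```

The predictions here are 2.0 and 4.0, so the squared errors are 0.0 and 1.0 and the returned loss is their mean, 0.5.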

Attributes:
config

See config.

Type:

BaseLitModuleConfig

nnmodule

See nnmodule.

Type:

torch.nn.Module

optimizer_partial

See optimizer.

Type:

partial[torch.optim.Optimizer]

scheduler_partial

See scheduler.

Type:

partial[torch.optim.lr_scheduler.LRScheduler]

optimizer

The instantiated form of optimizer.

Type:

torch.optim.Optimizer

scheduler

The instantiated form of scheduler.

Type:

torch.optim.lr_scheduler.LRScheduler

Raises:

NotImplementedError – If the step() method is not defined or is not callable.

final stage_step(data, stage)[source]

Generic stage wrapper around the step() method.

Verifies that the step() method exists and is callable, calls it and logs the loss value(s).

Parameters:
  • data (Any) – The batched input data.

  • stage (str) – The current stage (train, val, test or predict).

Return type:

Num[Tensor, '*_']

Returns:

The loss value(s).

final training_step(data)[source]

Calls stage_step() with argument stage="train".

Parameters:

data (Any) – See the data parameter of stage_step().

Return type:

Num[Tensor, '*_']

Returns:

The loss value(s).

final validation_step(data, *args, **kwargs)[source]

Calls stage_step() with argument stage="val".

Parameters:

data (Any) – See the data parameter of stage_step().

Return type:

Num[Tensor, '*_']

Returns:

The loss value(s).

final test_step(data)[source]

Calls stage_step() with argument stage="test".

Parameters:

data (Any) – See the data parameter of stage_step().

Return type:

Num[Tensor, '*_']

Returns:

The loss value(s).
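The wrapper pattern documented above can be sketched in plain Python. It consists of one final stage_step() that checks for step(), calls it, and logs the result, plus thin training, validation, and test entry points. This is a hypothetical reconstruction from the descriptions, not the cneuromax source. The class names, the `logged` dict (standing in for Lightning's logging), and the `"{stage}/loss"` key format are assumptions, and floats stand in for tensors.

```python
from typing import Any, final


class StageWrapperSketch:
    """Hypothetical base class mirroring the documented stage_step() flow."""

    def __init__(self) -> None:
        # Stand-in for Lightning's logging: maps metric names to values.
        self.logged: dict[str, float] = {}

    @final
    def stage_step(self, data: Any, stage: str) -> float:
        # Verify that a subclass provided a callable step() method.
        step = getattr(self, "step", None)
        if not callable(step):
            raise NotImplementedError("Subclasses must define a callable step() method.")
        loss = step(data, stage)
        # Record the loss under a stage-specific key (key format is an assumption).
        self.logged[f"{stage}/loss"] = loss
        return loss

    @final
    def training_step(self, data: Any) -> float:
        return self.stage_step(data=data, stage="train")

    @final
    def validation_step(self, data: Any, *args: Any, **kwargs: Any) -> float:
        # In this sketch, extra positional/keyword arguments are accepted and ignored.
        return self.stage_step(data=data, stage="val")

    @final
    def test_step(self, data: Any) -> float:
        return self.stage_step(data=data, stage="test")


class SumLossSketch(StageWrapperSketch):
    """Toy subclass whose "loss" is simply the sum of the batch."""

    def step(self, data: list[float], stage: str) -> float:
        return float(sum(data))


module = SumLossSketch()
module.training_step([1.0, 2.0])
module.validation_step([0.5], 0)
```

In this sketch, step() is the only extension point. typing.final only marks the wrappers for static type checkers and does not block overrides at runtime.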

configure_optimizers()[source]

Returns a dict with optimizer and scheduler.

Return type:

OptimizerLRSchedulerConfig

Returns:

This instance’s torch.optim.Optimizer and torch.optim.lr_scheduler.LRScheduler.
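The optimizer_partial/scheduler_partial attributes and configure_optimizers() plausibly fit together as sketched here. Hyperparameters are bound ahead of time with functools.partial. The remaining arguments (the parameters, then the optimizer) are supplied when the optimizers are configured. `ToyOptimizer`, `ToyScheduler`, and `OptimizerConfigSketch` are hypothetical stand-ins for torch.optim.Optimizer, torch.optim.lr_scheduler.LRScheduler, and the module. The sketch instantiates both objects inside configure_optimizers(), but this page does not say where the real class does so. The `"optimizer"` and `"lr_scheduler"` keys follow Lightning's OptimizerLRSchedulerConfig dictionary.

```python
from functools import partial
from typing import Any


class ToyOptimizer:
    """Hypothetical stand-in for torch.optim.Optimizer."""

    def __init__(self, params: list[float], lr: float) -> None:
        self.params = params
        self.lr = lr


class ToyScheduler:
    """Hypothetical stand-in for torch.optim.lr_scheduler.LRScheduler."""

    def __init__(self, optimizer: ToyOptimizer, gamma: float) -> None:
        self.optimizer = optimizer
        self.gamma = gamma


class OptimizerConfigSketch:
    """Hypothetical module holding partials until configure_optimizers() runs."""

    def __init__(
        self,
        optimizer: partial[ToyOptimizer],
        scheduler: partial[ToyScheduler],
    ) -> None:
        self.params = [0.1, 0.2]  # stand-in for self.parameters()
        self.optimizer_partial = optimizer
        self.scheduler_partial = scheduler

    def configure_optimizers(self) -> dict[str, Any]:
        # Complete each partial with the arguments only known at this point.
        self.optimizer = self.optimizer_partial(params=self.params)
        self.scheduler = self.scheduler_partial(optimizer=self.optimizer)
        return {"optimizer": self.optimizer, "lr_scheduler": self.scheduler}


module = OptimizerConfigSketch(
    optimizer=partial(ToyOptimizer, lr=1e-3),
    scheduler=partial(ToyScheduler, gamma=0.9),
)
config = module.configure_optimizers()
```

Deferring construction this way lets a configuration fix hyperparameters such as the learning rate before the model's parameters exist.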