base

BaseLitModule & its config.

class cneuromax.fitting.deeplearning.litmodule.base.BaseLitModuleConfig(wandb_column_names, wandb_train_log_interval=50, wandb_num_samples=3)[source]

Bases: object

Holds BaseLitModule config values.

Parameters:
  • wandb_column_names – Column names for the W&B tables (see wandb_train_table below).
  • wandb_train_log_interval – Defaults to 50.
  • wandb_num_samples – Defaults to 3.
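
A hedged instantiation sketch of this config; the exact type expected for wandb_column_names is not documented above, so the list of strings below (and its values) is an assumption.

# Hypothetical instantiation of BaseLitModuleConfig; the column names and the
# assumed list[str] type for `wandb_column_names` are illustrative only.
from cneuromax.fitting.deeplearning.litmodule.base import BaseLitModuleConfig

config = BaseLitModuleConfig(
    wandb_column_names=["input_data", "logits"],  # assumption: list[str]
    wandb_train_log_interval=50,  # default value
    wandb_num_samples=3,  # default value
)
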
class cneuromax.fitting.deeplearning.litmodule.base.BaseLitModule(config, nnmodule, optimizer, scheduler)[source]

Bases: LightningModule, ABC

Base lightning.pytorch.core.LightningModule.

Subclasses need to implement the step() method, which takes data and stage arguments and returns the loss value(s) as a torch.Tensor.

Example definition:

from typing import Annotated as An

from jaxtyping import Float, Int
from torch import Tensor
from cneuromax.utils.beartype import one_of  # assumed import path for `one_of`


def step(
    self: "BaseClassificationLitModule",
    data: tuple[
        Float[Tensor, " batch_size *x_dim"],
        Int[Tensor, " batch_size"],
    ],
    stage: An[str, one_of("train", "val", "test")],
) -> Float[Tensor, " "]:
    ...

Note

In cneuromax, we propose to split the PyTorch module definition from the Lightning module definition for (arguably) better code organization, reuse & readability. As a result, each Lightning module receives a PyTorch module as an argument and stores it as an instance attribute. This is in contrast with the suggested Lightning best practice of defining the model directly inside the Lightning module (which itself subclasses torch.nn.Module), with PyTorch module methods defined alongside the Lightning hooks.
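
To make the convention concrete, here is a minimal sketch of what a subclass could look like; MyClassificationLitModule, the attribute name self.nnmodule and the cross-entropy loss are illustrative assumptions rather than part of the documented API.

# Illustrative sketch: the subclass receives a ready-made torch.nn.Module
# instead of defining layers itself. Names other than BaseLitModule and
# step() are assumptions.
import torch.nn.functional as F
from torch import Tensor

from cneuromax.fitting.deeplearning.litmodule.base import BaseLitModule


class MyClassificationLitModule(BaseLitModule):
    """Wraps an externally defined classifier passed in as ``nnmodule``."""

    def step(self, data: tuple[Tensor, Tensor], stage: str) -> Tensor:
        x, y = data
        logits = self.nnmodule(x)  # assumes the module is stored as ``self.nnmodule``
        return F.cross_entropy(logits, y)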

Parameters:
  • config – See BaseLitModuleConfig.
  • nnmodule – The wrapped torch.nn.Module (see the note above).
  • optimizer – See optimizer_partial and optimizer below.
  • scheduler – See scheduler_partial and scheduler below.

Attributes:
  • optimizer_partial – The optimizer argument, kept before instantiation.
  • scheduler_partial – The scheduler argument, kept before instantiation.
  • optimizer – The instantiated optimizer.
  • scheduler – The instantiated scheduler.
  • curr_train_step – Index of the current training step.
  • curr_val_epoch – Index of the current validation epoch.
  • wandb_train_table (wandb.Table) – Table containing the rich training data that gets logged to W&B.
  • wandb_train_data (list[dict[str, Any]]) – A list of dictionaries, each holding the data of one specific example (e.g. input_data, logits, …).
  • wandb_val_table (wandb.Table) – See wandb_train_table.
  • wandb_val_data (list[dict[str, Any]]) – See wandb_train_data.
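
For illustration, a sketch of how such per-example dictionaries could be turned into a W&B table; this is not BaseLitModule's actual logging code, and the column names below stand in for config.wandb_column_names.

# Sketch only: build a wandb.Table from per-example dictionaries.
# Requires an active run (wandb.init()) for wandb.log() to do anything.
import wandb

columns = ["input_data", "logits"]  # stands in for config.wandb_column_names
table = wandb.Table(columns=columns)
wandb_train_data = [
    {"input_data": [0.1, 0.9], "logits": [2.3, -1.1]},
    {"input_data": [0.4, 0.2], "logits": [-0.7, 1.8]},
]
for sample in wandb_train_data:
    table.add_data(*(sample[col] for col in columns))
wandb.log({"train_data": table})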

on_save_checkpoint(checkpoint)[source]

Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.

Parameters:

checkpoint (dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.

Return type:

None

Example:

def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object

Note

Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.

on_load_checkpoint(checkpoint)[source]

Called by Lightning to restore your model. If you saved something with on_save_checkpoint() this is your chance to restore this.

Parameters:

checkpoint (dict[str, Any]) – Loaded checkpoint

Return type:

None

Example:

def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']

Note

Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.

on_train_batch_start(*args, **kwargs)[source]

Called in the training loop before anything happens for that batch.

If you return -1 here, you will skip training for the rest of the current epoch.

Parameters:
  • batch – The batched data as it is returned by the training DataLoader.

  • batch_idx – The index of the batch.

Return type:

None
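
For reference, this is how the -1 skip behaviour of the underlying Lightning hook is typically used in a subclass; whether BaseLitModule's own override returns anything is not shown above, and some_early_stop_flag is a hypothetical attribute.

# Generic Lightning hook usage, not BaseLitModule's actual override:
# returning -1 skips the remaining batches of the current epoch.
def on_train_batch_start(self, batch, batch_idx):
    if self.some_early_stop_flag:  # hypothetical attribute
        return -1
    return None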

on_validation_start()[source]

Called at the beginning of validation.

Return type:

None

optimizer_step(*args, **kwargs)[source]

Override this method to adjust the default way the Trainer calls the optimizer.

By default, Lightning calls step() and zero_grad() as shown in the example. This method (and zero_grad()) won’t be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.

Parameters:
  • epoch – Current epoch

  • batch_idx – Index of current batch

  • optimizer – A PyTorch optimizer

  • optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().

Return type:

None

Examples:

def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # Add your custom logic to run directly before `optimizer.step()`

    optimizer.step(closure=optimizer_closure)

    # Add your custom logic to run directly after `optimizer.step()`

on_validation_epoch_end()[source]

Called in the validation loop at the very end of the epoch.

Return type:

None

update_wandb_data_before_log(data)[source]

Hook for subclasses to perform a final update of the W&B data before it is logged.

Parameters:

data (list[dict[str, Any]]) – See wandb_train_data and wandb_val_data.

Return type:

None
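
A possible override, assuming data is the list of per-example dictionaries described under wandb_train_data; the "logits" and "prediction" keys are purely illustrative.

# Hypothetical override: enrich each per-example dictionary right before it
# is logged to W&B. The dictionary keys used here are assumptions.
import torch


def update_wandb_data_before_log(self, data):
    for sample in data:
        logits = torch.as_tensor(sample["logits"])
        sample["prediction"] = int(logits.argmax())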

abstract step(data, stage)[source]

Method to be implemented by subclasses (see the example definition above).

Parameters:
  • data – The batched input data.
  • stage (str) – One of "train", "val" or "test".

Return type:

Num[Tensor, '*_']

final stage_step(data, stage)[source]

Generic stage wrapper around the step() method.

Parameters:
  • data – The batched input data.
  • stage (str) – One of "train", "val" or "test".

Return type:

Num[Tensor, '*_']

Returns:

The loss value(s).
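
Conceptually, a generic stage wrapper of this kind delegates to step() and reports the resulting loss under the stage name; the body below is a plausible sketch under that assumption, not the actual implementation.

# Plausible sketch of a stage wrapper (not the verbatim source): call step()
# and log the loss via the standard LightningModule.log API.
def stage_step(self, data, stage):
    loss = self.step(data, stage)
    self.log(f"{stage}/loss", loss)
    return loss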

final training_step(data)[source]

Calls stage_step() with argument stage="train".

Parameters:

data (Any)

Return type:

Num[Tensor, '*_']

Returns:

The loss value(s).

final validation_step(data, *args, **kwargs)[source]

Calls stage_step() with argument stage="val".

Parameters:

data (Any)

Return type:

Num[Tensor, '*_']

Returns:

The loss value(s).

final test_step(data)[source]

Calls stage_step() with argument stage="test".

Parameters:

data (Any)

Return type:

Num[Tensor, '*_']

Returns:

The loss value(s).
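
As documented above, the three stage-specific steps are thin wrappers around stage_step(); they behave equivalently to:

# Equivalent behaviour of the three wrappers described above.
def training_step(self, data):
    return self.stage_step(data, stage="train")


def validation_step(self, data, *args, **kwargs):
    return self.stage_step(data, stage="val")


def test_step(self, data):
    return self.stage_step(data, stage="test")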

final configure_optimizers()[source]

Returns this instance's optimizer and scheduler, each wrapped in a list.

Return type:

tuple[list[Optimizer], list[LRScheduler]]

Returns:

This instance’s torch.optim.Optimizer and torch.optim.lr_scheduler.LRScheduler.
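
Given the documented return type, the hook amounts to returning the instantiated optimizer and scheduler wrapped in lists; a sketch consistent with that (the exact body is not shown above):

# Sketch consistent with the documented return type
# tuple[list[Optimizer], list[LRScheduler]]; not the verbatim source.
def configure_optimizers(self):
    return [self.optimizer], [self.scheduler]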