base¶
BaseLitModule & its config.
- class cneuromax.fitting.deeplearning.litmodule.base.BaseLitModuleConfig(wandb_column_names, wandb_train_log_interval=50, wandb_num_samples=3)[source]¶
Bases: object
Holds BaseLitModule config values.
- Parameters:
  - wandb_train_log_interval (int, default: 50)
  - wandb_num_samples (int, default: 3)
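The signature above does not document the type of wandb_column_names; the sketch below assumes it accepts a list of column names, and the column names themselves are hypothetical.
Example (sketch):
from cneuromax.fitting.deeplearning.litmodule.base import BaseLitModuleConfig

config = BaseLitModuleConfig(
    wandb_column_names=["input_data", "logits"],  # hypothetical columns, list type assumed
    wandb_train_log_interval=50,  # default
    wandb_num_samples=3,  # default
)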
- class cneuromax.fitting.deeplearning.litmodule.base.BaseLitModule(config, nnmodule, optimizer, scheduler)[source]¶
Bases: LightningModule, ABC
Base lightning.pytorch.core.LightningModule.
Subclasses need to implement the step() method, which takes both data and stage arguments and returns the loss value(s) in the form of a torch.Tensor.
Example definition:
def step(
    self: "BaseClassificationLitModule",
    data: tuple[
        Float[Tensor, " batch_size *x_dim"],
        Int[Tensor, " batch_size"],
    ],
    stage: An[str, one_of("train", "val", "test")],
) -> Float[Tensor, " "]:
    ...
Note
In cneuromax, we propose to split the PyTorch module definition from the Lightning module definition for (arguably) better code organization, reuse & readability. As a result, each Lightning module receives a PyTorch module as an argument, which it turns into an instance attribute. This is in contrast with the suggested Lightning best practice of having Lightning modules subclass PyTorch modules, which allows PyTorch module method definitions inside the Lightning module. An instantiation sketch follows the parameter list below.
- Parameters:
  - scheduler (partial[LRScheduler])
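A minimal instantiation sketch, assuming a hypothetical concrete subclass MyLitModule (its step() sketch appears under the abstract step() entry below) and assuming that optimizer, like scheduler, is passed as a functools.partial to be bound to the module's parameters later:
Example (sketch):
from functools import partial

import torch
from torch import nn

litmodule = MyLitModule(
    config=config,  # e.g. the BaseLitModuleConfig from the earlier sketch
    nnmodule=nn.Linear(16, 2),  # the PyTorch module, kept separate as per the note above
    optimizer=partial(torch.optim.AdamW, lr=1e-3),  # assumed to be a partial, mirroring `scheduler`
    scheduler=partial(torch.optim.lr_scheduler.ConstantLR),
)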
- curr_train_step¶
- curr_val_epoch¶
- wandb_train_table¶
Table containing the rich training data that gets logged to W&B.
- Type: wandb.Table
- wandb_train_data¶
A list of dictionaries, each containing training data relating to one specific example (ex: input_data, logits, …).
- wandb_val_table¶
See wandb_train_table.
- Type: wandb.Table
- wandb_val_data¶
See wandb_train_data.
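A minimal sketch of how these attributes fit together, assuming each dictionary in wandb_train_data maps column names to per-example values (the column names and values below are hypothetical):
Example (sketch):
import wandb

# Hypothetical per-example records accumulated during training.
wandb_train_data = [
    {"input_data": [0.1, 0.9], "logits": [2.3, -1.1]},
    {"input_data": [0.4, 0.2], "logits": [-0.5, 0.8]},
]

# Build a wandb.Table with one row per recorded example, then log it.
table = wandb.Table(columns=["input_data", "logits"])
for example in wandb_train_data:
    table.add_data(example["input_data"], example["logits"])
wandb.log({"train_table": table})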
- on_save_checkpoint(checkpoint)[source]¶
Called by Lightning when saving a checkpoint to give you a chance to store anything else you might want to save.
- Parameters:
  - checkpoint (dict[str, Any]) – The full checkpoint dictionary before it gets dumped to a file. Implementations of this hook can insert additional data into this dictionary.
- Return type: None
Example:
def on_save_checkpoint(self, checkpoint):
    # 99% of use cases you don't need to implement this method
    checkpoint['something_cool_i_want_to_save'] = my_cool_pickable_object
Note
Lightning saves all aspects of training (epoch, global step, etc…) including amp scaling. There is no need for you to store anything about training.
- on_load_checkpoint(checkpoint)[source]¶
Called by Lightning to restore your model. If you saved something with on_save_checkpoint(), this is your chance to restore it.
- Parameters:
  - checkpoint (dict[str, Any]) – Loaded checkpoint
- Return type: None
Example:
def on_load_checkpoint(self, checkpoint):
    # 99% of the time you don't need to implement this method
    self.something_cool_i_want_to_save = checkpoint['something_cool_i_want_to_save']
Note
Lightning auto-restores global step, epoch, and train state including amp scaling. There is no need for you to restore anything regarding training.
- on_train_batch_start(*args, **kwargs)[source]¶
Called in the training loop before anything happens for that batch.
If you return -1 here, you will skip training for the rest of the current epoch.
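A minimal override sketch in a subclass; the early-exit flag below is hypothetical and not part of the class documented here:
Example (sketch):
class MyLitModule(BaseLitModule):
    def on_train_batch_start(self, *args, **kwargs):
        # Returning -1 skips training for the rest of the current epoch.
        if getattr(self, "early_exit_requested", False):  # hypothetical flag
            return -1
        return super().on_train_batch_start(*args, **kwargs)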
- optimizer_step(*args, **kwargs)[source]¶
Override this method to adjust the default way the Trainer calls the optimizer.
By default, Lightning calls step() and zero_grad() as shown in the example. This method (and zero_grad()) won’t be called during the accumulation phase when Trainer(accumulate_grad_batches != 1). Overriding this hook has no benefit with manual optimization.
- Parameters:
  - epoch – Current epoch
  - batch_idx – Index of current batch
  - optimizer – A PyTorch optimizer
  - optimizer_closure – The optimizer closure. This closure must be executed as it includes the calls to training_step(), optimizer.zero_grad(), and backward().
- Return type: None
Examples:
def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_closure):
    # Add your custom logic to run directly before `optimizer.step()`
    optimizer.step(closure=optimizer_closure)
    # Add your custom logic to run directly after `optimizer.step()`
- on_validation_epoch_end()[source]¶
Called in the validation loop at the very end of the epoch.
- Return type: None
- update_wandb_data_before_log(data)[source]¶
Hook for subclasses to run a final update to W&B data.
- Return type: None
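A minimal override sketch, assuming data is the list of per-example dictionaries described above and is updated in place (the logits and prediction keys are hypothetical):
Example (sketch):
import torch


class MyLitModule(BaseLitModule):
    def update_wandb_data_before_log(self, data):
        # Hypothetical final touch-up: derive a predicted class from each
        # example's logits right before the data is logged to W&B.
        for example in data:
            example["prediction"] = int(torch.as_tensor(example["logits"]).argmax())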
- abstract step(data, stage)[source]¶
Method to be implemented by subclasses.
- Return type: Num[Tensor, '*_']
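A minimal concrete implementation sketch for a classification-style subclass, assuming the PyTorch module passed to the constructor is stored as self.nnmodule and that data is an (input, target) pair; both are assumptions, since the attribute name is not documented above:
Example (sketch):
import torch.nn.functional as F


class MyLitModule(BaseLitModule):
    def step(self, data, stage):
        x, y = data  # assumed (input, target) batch layout
        logits = self.nnmodule(x)  # assumed attribute holding the PyTorch module
        loss = F.cross_entropy(logits, y)
        # Optionally log a stage-prefixed metric through Lightning.
        self.log(f"{stage}/loss", loss)
        return loss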
- final training_step(data)[source]¶
Calls stage_step() with argument stage="train".
- final validation_step(data, *args, **kwargs)[source]¶
Calls stage_step() with argument stage="val".
- final test_step(data)[source]¶
Calls stage_step() with argument stage="test".
- final configure_optimizers()[source]¶
Returns a dict with optimizer and scheduler.
- Return type:
- Returns:
  This instance’s Optimizer and torch.optim.lr_scheduler.LRScheduler.
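The exact dictionary layout returned here is not shown above; as a point of reference, a minimal sketch of the standard Lightning convention for pairing an optimizer with a scheduler:
Example (sketch):
import torch
from torch import nn


def configure_optimizers_sketch(params):
    """Standalone illustration of the conventional Lightning return value."""
    optimizer = torch.optim.AdamW(params, lr=1e-3)
    scheduler = torch.optim.lr_scheduler.ConstantLR(optimizer)
    # Lightning accepts a dict pairing the optimizer with its scheduler.
    return {"optimizer": optimizer, "lr_scheduler": scheduler}


# Usage with a throwaway module's parameters, to show the shape of the output.
result = configure_optimizers_sketch(nn.Linear(4, 2).parameters())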