base

BaseClassificationLitModule & its config.

class cneuromax.fitting.deeplearning.litmodule.classification.base.BaseClassificationLitModuleConfig(log_val_wandb=False, num_classes=2)

Bases: BaseLitModuleConfig

Holds BaseClassificationLitModule config values.

Parameters:

num_classes (int, default: 2) – Number of classes the model distinguishes between.
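The signature lists log_val_wandb before num_classes, which suggests that log_val_wandb is inherited from BaseLitModuleConfig. The sketch below mirrors that layout with two hypothetical stand-in dataclasses, StandInBaseLitModuleConfig and StandInClassificationConfig. Whether the real classes are dataclasses is an assumption; the sketch only illustrates the documented fields and defaults.

```python
from dataclasses import dataclass


@dataclass
class StandInBaseLitModuleConfig:
    """Hypothetical stand-in for BaseLitModuleConfig (assumed to own log_val_wandb)."""

    log_val_wandb: bool = False


@dataclass
class StandInClassificationConfig(StandInBaseLitModuleConfig):
    """Hypothetical stand-in mirroring BaseClassificationLitModuleConfig."""

    # Inherited fields come first, matching the documented argument order.
    num_classes: int = 2


config = StandInClassificationConfig(num_classes=10)
print(config)  # StandInClassificationConfig(log_val_wandb=False, num_classes=10)
```

With the real package installed, the analogous call would be BaseClassificationLitModuleConfig(num_classes=10), using the documented keyword arguments.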

class cneuromax.fitting.deeplearning.litmodule.classification.base.BaseClassificationLitModule(*args, **kwargs)

Bases: BaseLitModule, ABC

Base Classification LightningModule.

Ref: lightning.pytorch.core.LightningModule

If logging validation data to W&B (see log_val_wandb in the config), make sure to define the wandb_columns attribute in the subclass.

config
Type:

BaseClassificationLitModuleConfig

accuracy
Type:

torchmetrics.classification.MulticlassAccuracy

wandb_table

A table to upload to W&B containing validation data.

Type:

wandb.Table

step(data, stage)

Computes the model's accuracy and cross-entropy loss.

Parameters:
  • data (tuple[Float[Tensor, 'batch_size *x_dim'], Int[Tensor, 'batch_size']]) – A tuple (x, y) where x is the input data and y is the target data.

  • stage (str) – See stage.

Return type:

Float[Tensor, '']

Returns:

The cross-entropy loss.
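The sketch below is a plain-Python approximation of the two quantities step is documented to compute, given a batch of logits of shape (batch_size, num_classes) and integer targets y. It is not the library's implementation. The real method works on tensors, and how it obtains logits from x (presumably by calling the wrapped network) is not shown. The helper names log_softmax and classification_step are hypothetical. The loss is the batch mean of the negative log-probability of each target class, which is what torch.nn.functional.cross_entropy returns with its default reduction="mean". The accuracy here is plain (micro) accuracy. torchmetrics' MulticlassAccuracy defaults to average="macro", so the module's accuracy attribute may report a different value unless it is configured with average="micro".

```python
import math
from typing import Sequence


def log_softmax(row: Sequence[float]) -> list[float]:
    # Shift by the row maximum so exp() cannot overflow (log-sum-exp trick).
    m = max(row)
    log_sum_exp = m + math.log(sum(math.exp(v - m) for v in row))
    return [v - log_sum_exp for v in row]


def classification_step(
    logits: Sequence[Sequence[float]], y: Sequence[int]
) -> tuple[float, float]:
    """Return (mean cross-entropy loss, accuracy) for one batch.

    logits has shape (batch_size, num_classes); y holds integer class indices.
    """
    batch_size = len(y)
    # Cross-entropy: negative log-probability assigned to each target class,
    # averaged over the batch.
    loss = -sum(log_softmax(row)[t] for row, t in zip(logits, y)) / batch_size
    # The predicted class per sample is the index of its largest logit
    # (ties resolve to the lowest index).
    y_hat = [max(range(len(row)), key=row.__getitem__) for row in logits]
    accuracy = sum(p == t for p, t in zip(y_hat, y)) / batch_size
    return loss, accuracy


loss, acc = classification_step(
    logits=[[2.0, 0.5], [0.1, 1.2], [1.0, 0.3]], y=[0, 1, 1]
)
# acc is 2/3: the third sample is predicted as class 0 but its target is 1.
print(f"loss={loss:.4f} accuracy={acc:.4f}")
```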

save_val_data(x, y, y_hat, logits)

Saves data computed during validation for later use.

Parameters:
  • x (Float[Tensor, 'batch_size *x_dim']) – The input data.

  • y (Int[Tensor, 'batch_size']) – The target class.

  • y_hat (Int[Tensor, 'batch_size']) – The predicted class.

  • logits (Float[Tensor, 'batch_size num_classes']) – The raw (unnormalized) network outputs, one per class.

Return type:

None
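This page does not say exactly what save_val_data stores or how wandb_table is filled. The sketch below shows one plausible pattern: each validation sample becomes one row, ordered like a subclass-defined wandb_columns. The class ValTableBuffer and the column names are hypothetical, and plain lists stand in for tensors. Rows are kept in a list so the example runs without wandb installed. With W&B, each row could instead be passed to wandb.Table.add_data(*row) on a table created with wandb.Table(columns=wandb_columns).

```python
from typing import Sequence


class ValTableBuffer:
    """Hypothetical accumulator standing in for the module's wandb_table."""

    # Hypothetical column names. In the real module, the subclass is
    # responsible for defining wandb_columns.
    wandb_columns = ["x", "y", "y_hat", "logits"]

    def __init__(self) -> None:
        self.rows: list[list[object]] = []

    def save_val_data(
        self,
        x: Sequence[Sequence[float]],
        y: Sequence[int],
        y_hat: Sequence[int],
        logits: Sequence[Sequence[float]],
    ) -> None:
        # zip() would silently truncate, so check every input has batch_size entries.
        if not (len(x) == len(y) == len(y_hat) == len(logits)):
            raise ValueError("all inputs must share the same batch_size")
        # One row per sample, in wandb_columns order.
        for row in zip(x, y, y_hat, logits):
            self.rows.append(list(row))


buf = ValTableBuffer()
buf.save_val_data(
    x=[[0.1, 0.2], [0.3, 0.4]],
    y=[0, 1],
    y_hat=[0, 0],
    logits=[[1.0, -1.0], [0.5, 0.2]],
)
print(buf.rows[1])  # [[0.3, 0.4], 1, 0, [0.5, 0.2]]
```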