base

BaseClassificationLitModule & its config.

class cneuromax.fitting.deeplearning.litmodule.classification.base.BaseClassificationLitModuleConfig(wandb_column_names=<factory>, wandb_train_log_interval=50, wandb_num_samples=3, num_classes=2)

Bases: BaseLitModuleConfig

Holds BaseClassificationLitModule config values.

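Parameters:
  • wandb_column_names

  • wandb_train_log_interval (int)

  • wandb_num_samples (int)

  • num_classes (int) – The number of target classes.
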
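The `<factory>` default in the signature is how Python renders a dataclass field declared with `default_factory`. The sketch below mirrors the four fields using hypothetical class names (SketchLitModuleConfig, SketchClassificationLitModuleConfig). The field types and the empty-list default for wandb_column_names are assumptions, not the library's definitions.

```python
from dataclasses import dataclass, field, fields


# Hypothetical stand-in for BaseLitModuleConfig. The field names come from the
# rendered signature; the types and the list default are assumptions.
@dataclass
class SketchLitModuleConfig:
    # A mutable default must come from a factory, which Sphinx renders as <factory>.
    wandb_column_names: list[str] = field(default_factory=list)
    wandb_train_log_interval: int = 50
    wandb_num_samples: int = 3


# Hypothetical stand-in for BaseClassificationLitModuleConfig.
@dataclass
class SketchClassificationLitModuleConfig(SketchLitModuleConfig):
    num_classes: int = 2


cfg = SketchClassificationLitModuleConfig(num_classes=10)
# Inherited fields come first, so num_classes is last, as in the signature.
print([f.name for f in fields(cfg)])
```

Because each instance calls the factory, two configs never share the same wandb_column_names list.
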
class cneuromax.fitting.deeplearning.litmodule.classification.base.BaseClassificationLitModule(*args, **kwargs)

Bases: BaseLitModule, ABC

Base Classification LightningModule.

Ref: lightning.pytorch.core.LightningModule

config
Type:

BaseClassificationLitModuleConfig

accuracy
Type:

torchmetrics.classification.MulticlassAccuracy

wandb_table

A table to upload to W&B containing validation data.

Type:

wandb.Table

abstract property wandb_media_x

Converts a tensor to a W&B media object.
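Because wandb_media_x is an abstract property, BaseClassificationLitModule cannot be instantiated until a subclass defines it. The minimal sketch below shows that pattern using hypothetical names (SketchBase, SketchSubclass). It assumes the property returns a converter callable. A real subclass would return something that builds a W&B media object rather than the placeholder string used here.

```python
from abc import ABC, abstractmethod
from collections.abc import Callable


class SketchBase(ABC):
    """Hypothetical stand-in for an abstract base with an abstract property."""

    @property
    @abstractmethod
    def wandb_media_x(self) -> Callable[[list[float]], str]:
        """Return a function that converts one input sample for logging."""


class SketchSubclass(SketchBase):
    @property
    def wandb_media_x(self) -> Callable[[list[float]], str]:
        # Placeholder converter (assumption): a real subclass would build a
        # W&B media object, e.g. an image, instead of a string.
        return lambda x: f"sample with {len(x)} values"


try:
    SketchBase()  # abstract property not implemented: raises TypeError
except TypeError as err:
    print(err)

convert = SketchSubclass().wandb_media_x
print(convert([0.1, 0.2, 0.3]))  # sample with 3 values
```
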

final step(data, stage)

Computes the model accuracy and cross entropy loss.

Parameters:
  • data (tuple[Float[Tensor, 'batch_size *x_dim'], Int[Tensor, 'batch_size']]) – A tuple (x, y), where x is the input data and y holds the target class labels.

  • stage (str) – See stage.

Return type:

Float[Tensor, '']

Returns:

The cross entropy loss.
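The two quantities that step documents can be computed with plain PyTorch. The function below, sketch_step, is a hypothetical standalone sketch and not the library's implementation. It assumes the logits come from calling a model on x, and it computes accuracy by hand instead of through the accuracy metric.

torchmetrics' MulticlassAccuracy averages per class (average="macro") by default, so its value can differ from the micro-averaged fraction computed here unless it is configured with average="micro". How the module configures it is not stated on this page.

```python
import torch
import torch.nn.functional as F


def sketch_step(
    model: torch.nn.Module, data: tuple[torch.Tensor, torch.Tensor]
) -> tuple[torch.Tensor, torch.Tensor]:
    """Return (cross entropy loss, accuracy) for one batch (x, y)."""
    x, y = data  # x: (batch_size, *x_dim); y: (batch_size,) class indices
    logits = model(x)  # (batch_size, num_classes) raw network outputs
    loss = F.cross_entropy(logits, y)  # mean over the batch, a scalar of shape ()
    y_hat = logits.argmax(dim=1)  # predicted class per sample
    accuracy = (y_hat == y).float().mean()  # fraction correct (micro-averaged)
    return loss, accuracy


# nn.Identity treats the inputs as precomputed logits, which keeps the demo deterministic.
x = torch.tensor([[2.0, 0.5], [0.1, 1.0], [3.0, -1.0]])
y = torch.tensor([0, 1, 1])
loss, accuracy = sketch_step(torch.nn.Identity(), (x, y))
# The third sample is misclassified, so accuracy is 2/3.
print(loss.shape, accuracy.item())
```

The scalar loss shape matches the documented return type Float[Tensor, ''].
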

final save_wandb_data(stage, x, y, y_hat, logits)

Saves rich data to be logged to W&B.

Parameters:
  • stage (str)

  • x (Float[Tensor, 'batch_size *x_dim'])

  • y (Int[Tensor, 'batch_size'])

  • y_hat (Int[Tensor, 'batch_size'])

  • logits (Float[Tensor, 'batch_size num_classes']) – The raw (unnormalized) network outputs, one value per class.

Return type:

None
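The body of save_wandb_data is not documented here. One plausible shape for the per-sample data is sketched below as a hypothetical helper, sketch_wandb_rows. It assumes rows of (input, target, prediction, logits), capped at wandb_num_samples. The real method may instead choose columns through wandb_column_names and convert x with wandb_media_x. Rows like these could be passed to wandb.Table(columns=..., data=rows), which this sketch does not call.

```python
import torch


def sketch_wandb_rows(
    x: torch.Tensor,
    y: torch.Tensor,
    y_hat: torch.Tensor,
    logits: torch.Tensor,
    num_samples: int = 3,
) -> list[list]:
    """Hypothetical helper: one table row for each of the first num_samples samples."""
    rows = []
    for i in range(min(num_samples, x.shape[0])):
        rows.append(
            [
                x[i].tolist(),  # raw input (a real module might convert it to media)
                int(y[i]),  # target class
                int(y_hat[i]),  # predicted class
                logits[i].tolist(),  # one raw score per class
            ]
        )
    return rows


logits = torch.tensor([[1.0, 0.0], [0.0, 1.0], [2.0, 1.0], [0.0, 3.0], [1.0, 1.5]])
x = torch.zeros(5, 4)
y = torch.tensor([0, 1, 1, 1, 0])
rows = sketch_wandb_rows(x, y, logits.argmax(dim=1), logits)
print(len(rows))  # 3
```
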