datamodule

MNISTDataModule & its config.

class cneuromax.projects.classify_mnist.datamodule.MNISTDataModuleConfig(data_dir='${config.data_dir}', device='${config.device}', max_per_device_batch_size=None, fixed_per_device_batch_size=None, fixed_per_device_num_workers=None, val_percentage=0.005)

Bases: BaseDataModuleConfig

Holds MNISTDataModule config values.

Parameters:

val_percentage (float, default: 0.005) – Fraction (between 0 and 1) of the training dataset to hold out for validation; the default of 0.005 corresponds to 0.5%.

class cneuromax.projects.classify_mnist.datamodule.MNISTDataModule(config)

Bases: BaseDataModule

BaseDataModule for the MNIST classification project.

Parameters:

config (MNISTDataModuleConfig)

train_val_split

The train/validation split (sums to 1).

Type:

tuple[float, float]
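As a rough illustration of how val_percentage and train_val_split could relate, here is a minimal stdlib sketch. It assumes train_val_split is simply (1 - val_percentage, val_percentage); the helper names make_train_val_split and split_sizes are hypothetical and not part of cneuromax. A pair of fractions summing to 1 is also the form that torch.utils.data.random_split accepts as fractional lengths in recent PyTorch releases, but whether this module passes it there is an assumption.

```python
def make_train_val_split(val_percentage: float) -> tuple[float, float]:
    """Hypothetical helper: derive a (train, val) fraction pair that sums to 1."""
    if not 0.0 <= val_percentage <= 1.0:
        raise ValueError("val_percentage must be a fraction in [0, 1]")
    return (1.0 - val_percentage, val_percentage)


def split_sizes(num_samples: int, split: tuple[float, float]) -> tuple[int, int]:
    """Hypothetical helper: turn fractions into sample counts.

    The validation count is rounded and the training set takes the remainder,
    so the two counts always add up to num_samples.
    """
    num_val = round(num_samples * split[1])
    return (num_samples - num_val, num_val)
```

With the default of 0.005 and the 60,000-image MNIST training set, split_sizes(60_000, make_train_val_split(0.005)) gives 59,700 training and 300 validation samples.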

transform

The torchvision dataset transformations.

Type:

torchvision.transforms.Compose
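The exact transforms are not listed here. A common choice for MNIST is transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]); whether this module uses that pipeline, and those mean/std constants, is an assumption. The stdlib sketch below reproduces the per-pixel arithmetic of that assumed pipeline (scale a uint8 pixel to [0, 1], then standardize) with a hypothetical compose helper standing in for Compose.

```python
# Assumed constants: the widely quoted MNIST mean/std, not confirmed for this module.
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def to_unit_range(pixel: int) -> float:
    """Mimic ToTensor on one uint8 pixel: map [0, 255] to [0.0, 1.0]."""
    return pixel / 255.0


def normalize(x: float, mean: float = MNIST_MEAN, std: float = MNIST_STD) -> float:
    """Mimic Normalize on one channel value: (x - mean) / std."""
    return (x - mean) / std


def compose(*fns):
    """Hypothetical stand-in for Compose: apply callables left to right."""
    def apply(value):
        for fn in fns:
            value = fn(value)
        return value
    return apply


transform = compose(to_unit_range, normalize)
```

Under these assumptions, a black pixel (0) maps to about -0.42 and a white pixel (255) to about 2.82.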

prepare_data()

Downloads the MNIST dataset.

Return type:

None
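In Lightning, prepare_data is meant for one-off disk work such as downloading, and it runs in a single process rather than on every device. A typical MNIST implementation calls torchvision.datasets.MNIST(root=data_dir, train=..., download=True) for both splits; that this module does exactly so is an assumption. The stdlib sketch below only illustrates the idempotent "download if missing" pattern; prepare_mnist, _stub_download and demo are hypothetical, and the <root>/MNIST/raw layout is assumed from torchvision's MNIST dataset.

```python
import tempfile
from pathlib import Path
from typing import Callable


def prepare_mnist(data_dir: str, download: Callable[[str], None]) -> bool:
    """Call `download` only if raw MNIST files are not already on disk.

    Returns True when a download was triggered, False when it was skipped.
    """
    # Assumed layout, following torchvision's MNIST: <root>/MNIST/raw
    raw_dir = Path(data_dir) / "MNIST" / "raw"
    if raw_dir.is_dir() and any(raw_dir.iterdir()):
        return False
    download(data_dir)
    return True


def _stub_download(data_dir: str) -> None:
    """Hypothetical stand-in for the real download: writes one placeholder file."""
    raw_dir = Path(data_dir) / "MNIST" / "raw"
    raw_dir.mkdir(parents=True, exist_ok=True)
    (raw_dir / "placeholder").write_bytes(b"")


def demo() -> list[bool]:
    """Run prepare_mnist twice in a fresh temporary directory."""
    with tempfile.TemporaryDirectory() as tmp:
        return [prepare_mnist(tmp, _stub_download), prepare_mnist(tmp, _stub_download)]
```

Calling demo() triggers the download on the first call only, so re-running prepare_data does not re-fetch files that are already present.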

setup(stage)

Creates the train/val/test datasets.

Parameters:

stage (str) – Current stage type, i.e. a Lightning stage name such as "fit", "validate", "test" or "predict".

Return type:

None
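Unlike prepare_data, setup runs on every process, so it is the place to build dataset objects and assign them as attributes. The sketch below is a hypothetical stdlib stand-in (SketchMNISTDataModule is not a cneuromax class) that shows one way setup might branch on the stage: integer indices stand in for the 60,000 MNIST training images and 10,000 test images, and the seeded shuffle-then-slice split is an assumption, not necessarily what MNISTDataModule does.

```python
import random
from dataclasses import dataclass, field


@dataclass
class SketchMNISTDataModule:
    """Hypothetical stand-in showing how setup(stage) might branch on the stage."""

    val_percentage: float = 0.005
    datasets: dict = field(default_factory=dict)

    def setup(self, stage: str) -> None:
        if stage in ("fit", "validate"):
            # Indices stand in for the 60,000 MNIST training images.
            indices = list(range(60_000))
            random.Random(0).shuffle(indices)  # fixed seed: same split on every process
            num_val = round(len(indices) * self.val_percentage)
            self.datasets["val"] = indices[:num_val]
            self.datasets["train"] = indices[num_val:]
        elif stage == "test":
            # Indices stand in for the 10,000 MNIST test images.
            self.datasets["test"] = list(range(10_000))
```

Seeding the shuffle matters in this design: because setup runs independently on each process, an unseeded split could give different processes different validation sets.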