base

BaseDataModule and its datasets and config classes.

class cneuromax.fitting.deeplearning.datamodule.base.Datasets(train=None, val=None, test=None, predict=None)

Bases: object

Holds phase-specific torch.utils.data.Dataset objects.

The word phase is used to avoid overloading Lightning's stage terminology, which refers to fit, validate and test.
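A minimal sketch of how such a container could look. It assumes a plain dataclass with one optional field per phase; the field names come from the signature above, while the @dataclass layout and the phases_defined helper are hypothetical. Plain lists stand in for torch.utils.data.Dataset objects, since any object with __getitem__ and __len__ can serve as a map-style dataset.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class Datasets:
    """Sketch: one optional dataset slot per phase (assumed dataclass layout)."""

    train: Optional[Any] = None
    val: Optional[Any] = None
    test: Optional[Any] = None
    predict: Optional[Any] = None


def phases_defined(datasets: Datasets) -> list[str]:
    """Hypothetical helper: names of the phases whose dataset has been set."""
    return [f.name for f in fields(datasets) if getattr(datasets, f.name) is not None]


# Only the train and val phases are populated; test and predict stay None.
datasets = Datasets(train=[0, 1, 2, 3], val=[4, 5])
print(phases_defined(datasets))  # ['train', 'val']
```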

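Parameters:

train (torch.utils.data.Dataset | None) – Training dataset.

val (torch.utils.data.Dataset | None) – Validation dataset.

test (torch.utils.data.Dataset | None) – Testing dataset.

predict (torch.utils.data.Dataset | None) – Prediction dataset.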
class cneuromax.fitting.deeplearning.datamodule.base.BaseDataModuleConfig(data_dir='${config.data_dir}', device='${config.device}', max_per_device_batch_size=None, fixed_per_device_batch_size=None, fixed_per_device_num_workers=None, shuffle_val_dataset=True)

Bases: object

Holds BaseDataModule config values.

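Parameters:

data_dir (str)

device (str)

max_per_device_batch_size (int | None)

fixed_per_device_batch_size (int | None)

fixed_per_device_num_workers (int | None)

shuffle_val_dataset (bool)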
class cneuromax.fitting.deeplearning.datamodule.base.BaseDataModule(config)

Bases: LightningDataModule, ABC

Base lightning.pytorch.core.LightningDataModule subclass.

Subclasses must define the datasets.<phase> attribute for each phase they support, where <phase> is one of train, val, test or predict.
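Setting these attributes in Lightning's setup(stage) hook is one common choice. In the sketch below, a stand-in base class replaces BaseDataModule so the example runs without cneuromax or Lightning installed; StandInBaseDataModule, SquaresDataModule and the data are all hypothetical, and the real base class additionally takes a config argument.

```python
from dataclasses import dataclass
from typing import Any, Optional


@dataclass
class Datasets:  # stand-in for the documented Datasets container
    train: Optional[Any] = None
    val: Optional[Any] = None
    test: Optional[Any] = None
    predict: Optional[Any] = None


class StandInBaseDataModule:
    """Hypothetical stand-in for BaseDataModule: only the datasets attribute."""

    def __init__(self) -> None:
        self.datasets = Datasets()


class SquaresDataModule(StandInBaseDataModule):
    """Hypothetical subclass that defines the train and val phases only."""

    def setup(self, stage: str) -> None:
        # Lightning passes stage as "fit", "validate", "test" or "predict".
        if stage == "fit":
            data = [i * i for i in range(10)]
            self.datasets.train = data[:8]
            self.datasets.val = data[8:]


dm = SquaresDataModule()
dm.setup("fit")
print(len(dm.datasets.train), len(dm.datasets.val))  # 8 2
print(dm.datasets.test)  # None
```

Phases the subclass never sets stay None, so only the corresponding dataloaders can be requested.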

Parameters:

config (BaseDataModuleConfig)

config
Type:

BaseDataModuleConfig

datasets
Type:

Datasets

collate_fn

See the collate_fn argument of torch.utils.data.DataLoader.

Type:

Callable

pin_memory

Whether to copy tensors into device/CUDA pinned memory before returning them. Defaults to True if device is "gpu".

Type:

bool

per_device_batch_size

Per-device number of samples to load per iteration. The temporary value (1) is later overwritten by set_batch_size_and_num_workers().

Type:

int

per_device_num_workers

Per-device number of CPU processes to use for data loading (0 means that the data is loaded in each device's assigned CPU process). The temporary value (0) is later overwritten by set_batch_size_and_num_workers().

Type:

int
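The sketch below shows how collate_fn, pin_memory, per_device_batch_size and per_device_num_workers could map onto torch.utils.data.DataLoader arguments. The padding collate function and the batch size of 3 are illustrative assumptions (the documented placeholder batch size is 1); pin_memory follows the documented rule of being True only when device is "gpu".

```python
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader


def pad_collate(batch: list[torch.Tensor]) -> torch.Tensor:
    """Hypothetical collate_fn: right-pads 1-D tensors to the longest one in the batch."""
    return pad_sequence(batch, batch_first=True, padding_value=0.0)


device = "cpu"  # assumed config.device value
pin_memory = device == "gpu"  # documented default: True only when device is "gpu"
per_device_batch_size = 3  # illustrative value; the documented placeholder is 1
per_device_num_workers = 0  # 0: load data in the device's own CPU process

samples = [torch.ones(2), torch.ones(3), torch.ones(1)]
loader = DataLoader(
    samples,
    batch_size=per_device_batch_size,
    num_workers=per_device_num_workers,
    collate_fn=pad_collate,
    pin_memory=pin_memory,
)
for batch in loader:
    print(batch.tolist())  # [[1.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 0.0, 0.0]]
```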

final load_state_dict(state_dict)

Replaces instance attribute values with state_dict values.

Parameters:

state_dict (dict[str, int]) – Dictionary containing the values that override per_device_batch_size and per_device_num_workers.

Return type:

None

final state_dict()

Returns instance attribute values.

Return type:

dict[str, int]

Returns:

A new dictionary containing the values of the per_device_batch_size and per_device_num_workers attributes.
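Lightning calls a datamodule's state_dict() when saving a checkpoint and load_state_dict() when restoring one, which lets tuned loader settings survive a resume. A minimal round-trip sketch; LoaderState is a hypothetical stand-in, and using the attribute names as dictionary keys is an assumption.

```python
class LoaderState:
    """Hypothetical stand-in holding the two checkpointed attributes."""

    def __init__(self) -> None:
        self.per_device_batch_size = 1  # documented temporary value
        self.per_device_num_workers = 0  # documented temporary value

    def state_dict(self) -> dict[str, int]:
        # A new dict each call, so callers cannot mutate the instance through it.
        return {
            "per_device_batch_size": self.per_device_batch_size,
            "per_device_num_workers": self.per_device_num_workers,
        }

    def load_state_dict(self, state_dict: dict[str, int]) -> None:
        # Assumed key names: identical to the attribute names.
        self.per_device_batch_size = state_dict["per_device_batch_size"]
        self.per_device_num_workers = state_dict["per_device_num_workers"]


# Save the tuned values from one run and restore them into a fresh instance.
tuned = LoaderState()
tuned.per_device_batch_size, tuned.per_device_num_workers = 64, 4
restored = LoaderState()
restored.load_state_dict(tuned.state_dict())
print(restored.state_dict())  # {'per_device_batch_size': 64, 'per_device_num_workers': 4}
```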

final x_dataloader(dataset, *, shuffle=True)

Generic torch.utils.data.DataLoader factory method.

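Parameters:

dataset (torch.utils.data.Dataset[Tensor] | None) – The dataset to wrap.

shuffle (bool) – Whether to shuffle the dataset. Defaults to True.

Raises: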

AttributeError – If dataset is None.

Return type:

DataLoader[Tensor]

Returns:

A new torch.utils.data.DataLoader instance wrapping the dataset argument.
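A minimal sketch of such a factory as a free function. The real method presumably reads the batch size, worker count, collate_fn and pin_memory from the instance attributes documented above; here they are exposed as keyword arguments, which is an assumption made so the example runs on its own.

```python
from typing import Any, Callable, Optional

import torch
from torch.utils.data import DataLoader, Dataset, TensorDataset


def x_dataloader(
    dataset: Optional[Dataset],
    *,
    shuffle: bool = True,
    batch_size: int = 1,
    num_workers: int = 0,
    collate_fn: Optional[Callable[[list[Any]], Any]] = None,
    pin_memory: bool = False,
) -> DataLoader:
    """Sketch of a generic DataLoader factory (extra keyword args are assumptions)."""
    if dataset is None:
        # Mirrors the documented behaviour: a missing phase dataset is an error.
        raise AttributeError("The dataset for this phase has not been defined.")
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=num_workers,
        collate_fn=collate_fn,  # None selects DataLoader's default collation
        pin_memory=pin_memory,
    )


loader = x_dataloader(TensorDataset(torch.arange(6)), batch_size=2, shuffle=False)
print([batch[0].tolist() for batch in loader])  # [[0, 1], [2, 3], [4, 5]]
```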

final train_dataloader()

Calls x_dataloader() with datasets.train.

Return type:

DataLoader[Tensor]

Returns:

A new training torch.utils.data.DataLoader instance.

final val_dataloader()

Calls x_dataloader() with datasets.val.

Return type:

DataLoader[Tensor]

Returns:

A new validation torch.utils.data.DataLoader instance.

final test_dataloader()

Calls x_dataloader() with datasets.test.

Return type:

DataLoader[Tensor]

Returns:

A new testing torch.utils.data.DataLoader instance.

final predict_dataloader()

Calls x_dataloader() with datasets.predict.

Return type:

DataLoader[Tensor]

Returns:

A new prediction torch.utils.data.DataLoader instance that does not shuffle the dataset.
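Keeping the prediction DataLoader unshuffled means outputs come back in dataset order, so each prediction can be matched to its input by index. A minimal sketch using plain torch; the linear model and the input data are hypothetical stand-ins.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

inputs = torch.arange(10, dtype=torch.float32).unsqueeze(1)  # 10 samples, 1 feature
# shuffle=False, as for the documented prediction DataLoader.
predict_loader = DataLoader(TensorDataset(inputs), batch_size=4, shuffle=False)
model = torch.nn.Linear(1, 1)  # hypothetical stand-in model

with torch.no_grad():
    preds = torch.cat([model(x) for (x,) in predict_loader])
    expected = model(inputs)

# Row i of preds corresponds to row i of inputs because order was preserved.
print(torch.allclose(preds, expected))  # True
```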