base
BaseDataModule and its dataset/config classes.
- class cneuromax.fitting.deeplearning.datamodule.base.Datasets(train=None, val=None, test=None, predict=None)[source]

  Bases: object

  Holds phase-specific torch.utils.data.Dataset objects. The word "phase" is used so as not to overload Lightning's "stage" terminology, which covers fit, validate and test.

  - Parameters:
    - train (Union[Dataset[Tensor | dict[str, Tensor]], Dataset, DatasetDict, None], default: None)
    - val (Union[Dataset[Tensor | dict[str, Tensor]], Dataset, DatasetDict, None], default: None)
    - test (Union[Dataset[Tensor | dict[str, Tensor]], Dataset, DatasetDict, None], default: None)
    - predict (Union[Dataset[Tensor | dict[str, Tensor]], Dataset, DatasetDict, None], default: None)
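Since the class is documented as a plain container, its behavior can be sketched with a hypothetical stand-in dataclass (the real class lives in cneuromax.fitting.deeplearning.datamodule.base and accepts torch datasets):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class Datasets:
    """Stand-in for the documented container: one slot per phase."""
    train: Any = None
    val: Any = None
    test: Any = None
    predict: Any = None

# Only the phases you need have to be populated; the rest stay None.
datasets = Datasets(train=[(0, 1), (2, 3)], val=[(4, 5)])
print(datasets.train)  # the training dataset
print(datasets.test)   # None: phase left undefined
```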
- class cneuromax.fitting.deeplearning.datamodule.base.BaseDataModuleConfig(data_dir='${config.data_dir}', device='${config.device}', max_per_device_batch_size=None, fixed_per_device_batch_size=None, fixed_per_device_num_workers=None, shuffle_train_dataset=True, shuffle_val_dataset=True, drop_last=False)[source]

  Bases: object

  Holds BaseDataModule config values.

  - Parameters:
    - data_dir (str, default: '${config.data_dir}') – See data_dir.
    - max_per_device_batch_size (Optional[int], default: None) – See per_device_batch_size. Sets an upper bound on that attribute.
    - fixed_per_device_batch_size (Optional[int], default: None) – See per_device_batch_size. Setting this value skips the batch size search in find_good_per_device_batch_size(), which is not recommended for resource efficiency.
    - fixed_per_device_num_workers (Optional[int], default: None) – See per_device_num_workers. Setting this value skips the num workers search in find_good_per_device_num_workers(), which is not recommended for resource efficiency.
    - shuffle_train_dataset (bool, default: True)
    - shuffle_val_dataset (bool, default: True)
- class cneuromax.fitting.deeplearning.datamodule.base.BaseDataModule(config)[source]

  Bases: LightningDataModule, ABC

  Base lightning.pytorch.core.LightningDataModule. With <phase> being any of train, val, test or predict, subclasses need to properly define the datasets.<phase> attribute(s) for each desired phase.

  - Parameters:
    - config (BaseDataModuleConfig)

  - config
    - Type: BaseDataModuleConfig

  - collate_fn

    See the collate_fn argument in torch.utils.data.DataLoader.

    - Type:

  - pin_memory

    Whether to copy tensors into device pinned memory before returning them (set to True by default if device is "gpu").

    - Type: bool

  - per_device_batch_size

    Per-device number of samples to load per iteration. The temporary value (1) is overwritten in set_batch_size_and_num_workers().

    - Type: int

  - per_device_num_workers

    Per-device number of CPU processes to use for data loading (0 means that the data will be loaded by each device's assigned CPU process). The temporary value (0) is later overwritten in set_batch_size_and_num_workers().

    - Type: int
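In practice, a subclass only has to populate datasets.<phase> for the phases it supports, typically in Lightning's setup() hook. A minimal sketch with hypothetical local stand-ins (no Lightning or torch dependency; names mirror the attributes documented above):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class Datasets:  # stand-in for the documented container
    train: Any = None
    val: Any = None
    test: Any = None
    predict: Any = None

class BaseDataModuleSketch:  # stand-in for the Lightning-based class
    def __init__(self, config: Any) -> None:
        self.config = config
        self.datasets = Datasets()
        self.per_device_batch_size = 1   # temporary, overwritten later
        self.per_device_num_workers = 0  # temporary, overwritten later

class MyDataModule(BaseDataModuleSketch):  # hypothetical subclass
    def setup(self, stage: str) -> None:
        # Real code would build torch Datasets from self.config.data_dir.
        self.datasets.train = list(range(100))
        self.datasets.val = list(range(20))

dm = MyDataModule(config=None)
dm.setup("fit")
print(len(dm.datasets.train))  # 100
```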
- final load_state_dict(state_dict)[source]

  Replaces instance attribute values with state_dict values.

  - Parameters:
    - state_dict (dict[str, int]) – Dictionary containing values to override per_device_batch_size & per_device_num_workers.
  - Return type: None
- final state_dict()[source]

  Returns instance attribute values.

  - Return type: dict[str, int]
  - Returns: A new dictionary containing attribute values.
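The documented state_dict()/load_state_dict() pair only round-trips the two tuned attributes. A hypothetical minimal sketch of that contract:

```python
class DataModuleStateSketch:  # hypothetical stand-in
    def __init__(self) -> None:
        self.per_device_batch_size = 1
        self.per_device_num_workers = 0

    def state_dict(self) -> dict[str, int]:
        # Returns a new dictionary containing the attribute values.
        return {
            "per_device_batch_size": self.per_device_batch_size,
            "per_device_num_workers": self.per_device_num_workers,
        }

    def load_state_dict(self, state_dict: dict[str, int]) -> None:
        # Replaces instance attribute values with state_dict values.
        self.per_device_batch_size = state_dict["per_device_batch_size"]
        self.per_device_num_workers = state_dict["per_device_num_workers"]

src, dst = DataModuleStateSketch(), DataModuleStateSketch()
src.per_device_batch_size = 64
dst.load_state_dict(src.state_dict())
print(dst.per_device_batch_size)  # 64
```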
- final x_dataloader(dataset, *, shuffle=True)[source]

  Generic torch.utils.data.DataLoader factory method.

  - Parameters:
    - dataset
    - shuffle (bool, default: True)
  - Raises: AttributeError – If dataset is None.
  - Return type: DataLoader
  - Returns: A new torch.utils.data.DataLoader instance wrapping the dataset argument.
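The documented contract (keyword-only shuffle, AttributeError when the phase's dataset was never defined) can be sketched without torch; the real method instead returns a torch.utils.data.DataLoader built from the instance's per_device_batch_size, per_device_num_workers, collate_fn and pin_memory attributes:

```python
import random
from typing import Any

def x_dataloader(dataset: Any, *, shuffle: bool = True,
                 batch_size: int = 1) -> list:
    """Hypothetical stand-in mirroring the documented contract."""
    if dataset is None:
        raise AttributeError("dataset is None: phase not defined")
    items = list(dataset)
    if shuffle:
        random.Random(0).shuffle(items)  # seeded for reproducibility
    # Real code returns a torch DataLoader; here we just batch.
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

batches = x_dataloader(range(4), shuffle=False, batch_size=2)
print(batches)  # [[0, 1], [2, 3]]
```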
- final train_dataloader()[source]

  Calls x_dataloader() with datasets.train.

  - Return type: DataLoader
  - Returns: A new training torch.utils.data.DataLoader instance.
- val_dataloader()[source]

  Calls x_dataloader() with datasets.val.

  - Return type: DataLoader
  - Returns: A new validation torch.utils.data.DataLoader instance.
- final test_dataloader()[source]

  Calls x_dataloader() with datasets.test.

  - Return type: DataLoader
  - Returns: A new testing torch.utils.data.DataLoader instance.
- final predict_dataloader()[source]

  Calls x_dataloader() with datasets.predict.

  - Return type: DataLoader
  - Returns: A new prediction torch.utils.data.DataLoader instance that does not shuffle the dataset.
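The four dataloader methods are thin wrappers that pick the matching phase dataset and shuffle flag; notably, prediction disables shuffling so output order is stable. A hypothetical stand-in sketch of that dispatch (the real methods return torch DataLoaders):

```python
from dataclasses import dataclass
from typing import Any

@dataclass
class Datasets:  # stand-in container
    train: Any = None
    val: Any = None
    test: Any = None
    predict: Any = None

class DataModuleSketch:  # hypothetical stand-in, no torch/Lightning
    def __init__(self, datasets: Datasets, shuffle_train: bool = True) -> None:
        self.datasets = datasets
        self.shuffle_train = shuffle_train

    def x_dataloader(self, dataset: Any, *, shuffle: bool = True) -> dict:
        if dataset is None:
            raise AttributeError("dataset is None")
        # Real code returns a torch DataLoader; record the call instead.
        return {"dataset": dataset, "shuffle": shuffle}

    def train_dataloader(self) -> dict:
        return self.x_dataloader(self.datasets.train, shuffle=self.shuffle_train)

    def predict_dataloader(self) -> dict:
        # Prediction order must be stable, so shuffling is disabled.
        return self.x_dataloader(self.datasets.predict, shuffle=False)

dm = DataModuleSketch(Datasets(train=[1, 2], predict=[3, 4]))
print(dm.predict_dataloader()["shuffle"])  # False
```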