base

BaseDataModule and its datasets and config classes.
- class cneuromax.fitting.deeplearning.datamodule.base.Datasets(train=None, val=None, test=None, predict=None)
  Bases: object
  Holds phase-specific torch.utils.data.Dataset objects. The word "phase" is used to avoid overloading Lightning's "stage" terminology, which covers fit, validate and test.
  - Parameters:
    - train (Union[Dataset[Tensor | dict[str, Tensor]], Dataset, None], default: None)
    - val (Union[Dataset[Tensor | dict[str, Tensor]], Dataset, None], default: None)
    - test (Union[Dataset[Tensor | dict[str, Tensor]], Dataset, None], default: None)
    - predict (Union[Dataset[Tensor | dict[str, Tensor]], Dataset, None], default: None)
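Because every field defaults to None, a subclass only has to fill in the phases it actually supports. The stdlib-only sketch below mirrors that idea; PhaseDatasets and defined_phases() are hypothetical names (not part of cneuromax), and plain lists stand in for torch.utils.data.Dataset objects.

```python
from dataclasses import dataclass, fields
from typing import Any, Optional


@dataclass
class PhaseDatasets:
    """Hypothetical stand-in for Datasets: one optional slot per phase."""

    train: Optional[Any] = None
    val: Optional[Any] = None
    test: Optional[Any] = None
    predict: Optional[Any] = None

    def defined_phases(self) -> list[str]:
        """Return the names of the phases whose dataset has been set."""
        return [f.name for f in fields(self) if getattr(self, f.name) is not None]


# Only populate the phases you need; the others stay None.
datasets = PhaseDatasets(train=list(range(8)), val=list(range(2)))
print(datasets.defined_phases())  # ['train', 'val']
```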
- class cneuromax.fitting.deeplearning.datamodule.base.BaseDataModuleConfig(data_dir='${config.data_dir}', device='${config.device}', max_per_device_batch_size=None, fixed_per_device_batch_size=None, fixed_per_device_num_workers=None, shuffle_val_dataset=True)
  Bases: object
  Holds BaseDataModule config values.
  - Parameters:
    - data_dir (str, default: '${config.data_dir}') – See data_dir.
    - max_per_device_batch_size (Optional[int], default: None) – See per_device_batch_size. Sets an upper bound on that attribute.
    - fixed_per_device_batch_size (Optional[int], default: None) – See per_device_batch_size. Setting this value skips the batch size search in find_good_per_device_batch_size(), which is not recommended if resource efficiency matters.
    - fixed_per_device_num_workers (Optional[int], default: None) – See per_device_num_workers. Setting this value skips the worker count search in find_good_per_device_num_workers(), which is not recommended if resource efficiency matters.
    - shuffle_val_dataset (bool, default: True)
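The three batch-size-related fields interact: a fixed value skips the search, while a max value only bounds it. The sketch below shows one plausible precedence order under stated assumptions: resolve_per_device_batch_size() and fake_search() are hypothetical helpers, the search is passed in as a callable standing in for find_good_per_device_batch_size(), and it is assumed (not confirmed by the docs) that a fixed value wins over the max value.

```python
from typing import Callable, Optional


def resolve_per_device_batch_size(
    search: Callable[[], int],
    max_per_device_batch_size: Optional[int] = None,
    fixed_per_device_batch_size: Optional[int] = None,
) -> int:
    """Hypothetical resolution of the per-device batch size from config fields."""
    if fixed_per_device_batch_size is not None:
        # Assumption: a fixed value bypasses the search (and the max) entirely.
        return fixed_per_device_batch_size
    found = search()
    if max_per_device_batch_size is not None:
        # The max field acts as an upper bound on whatever the search found.
        found = min(found, max_per_device_batch_size)
    return found


def fake_search() -> int:
    # Stand-in for a real, potentially expensive batch size search.
    return 256


print(resolve_per_device_batch_size(fake_search))  # 256
print(resolve_per_device_batch_size(fake_search, max_per_device_batch_size=64))  # 64
print(resolve_per_device_batch_size(fake_search, fixed_per_device_batch_size=32))  # 32
```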
- class cneuromax.fitting.deeplearning.datamodule.base.BaseDataModule(config)
  Bases: LightningDataModule, ABC
  Base lightning.pytorch.core.LightningDataModule. With <phase> being any of train, val, test or predict, subclasses must properly define the datasets.<phase> attribute for each phase they support.
  - Parameters:
    - config
  - Attributes:
    - config
    - collate_fn – See the collate_fn argument of torch.utils.data.DataLoader.
    - pin_memory – Whether to copy tensors into device pinned memory before returning them (set to True by default if device is "gpu").
    - per_device_batch_size – Per-device number of samples to load per iteration. Its temporary value (1) is overwritten in set_batch_size_and_num_workers().
    - per_device_num_workers – Per-device number of CPU processes to use for data loading (0 means that the data is loaded by each device's assigned CPU process). Its temporary value (0) is later overwritten in set_batch_size_and_num_workers().
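The attribute descriptions above imply placeholder defaults (1 and 0) plus a device-dependent pin_memory flag. The sketch below mirrors those documented defaults in a plain dataclass; LoaderSettings is a hypothetical name, and taking device as a direct constructor argument (rather than from config) is an assumption made to keep it self-contained.

```python
from dataclasses import dataclass, field


@dataclass
class LoaderSettings:
    """Hypothetical mirror of BaseDataModule's data-loading attributes."""

    device: str
    # Placeholders that a later tuning step is expected to overwrite.
    per_device_batch_size: int = 1
    per_device_num_workers: int = 0
    pin_memory: bool = field(init=False, default=False)

    def __post_init__(self) -> None:
        # Mirrors the documented default: pin memory only when device is "gpu".
        self.pin_memory = self.device == "gpu"


gpu_settings = LoaderSettings(device="gpu")
print(gpu_settings.pin_memory)  # True
print(LoaderSettings(device="cpu").pin_memory)  # False
```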
  - final load_state_dict(state_dict)
    Replaces instance attribute values with state_dict values.
    - Parameters:
      - state_dict (dict[str, int]) – Dictionary containing values to override per_device_batch_size and per_device_num_workers.
  - final state_dict()
    Returns instance attribute values.
    - Returns:
      A new dictionary containing attribute values.
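Lightning saves a LightningDataModule's state_dict() into checkpoints and passes it back to load_state_dict() on restore, so tuned loader settings can survive a restart. The round trip below uses a hypothetical stand-in class, TunedLoaderState; that the real state_dict() returns exactly these two keys is an assumption based on what load_state_dict() is documented to override.

```python
class TunedLoaderState:
    """Hypothetical object exposing the same state_dict/load_state_dict pair."""

    def __init__(self) -> None:
        self.per_device_batch_size = 1
        self.per_device_num_workers = 0

    def state_dict(self) -> dict[str, int]:
        # Build a new dict so callers cannot mutate the instance through it.
        return {
            "per_device_batch_size": self.per_device_batch_size,
            "per_device_num_workers": self.per_device_num_workers,
        }

    def load_state_dict(self, state_dict: dict[str, int]) -> None:
        # Override the instance attributes with the saved values.
        self.per_device_batch_size = state_dict["per_device_batch_size"]
        self.per_device_num_workers = state_dict["per_device_num_workers"]


tuned = TunedLoaderState()
tuned.per_device_batch_size, tuned.per_device_num_workers = 64, 4
saved = tuned.state_dict()

fresh = TunedLoaderState()
fresh.load_state_dict(saved)
print(fresh.state_dict())  # {'per_device_batch_size': 64, 'per_device_num_workers': 4}
```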
  - final x_dataloader(dataset, *, shuffle=True)
    Generic torch.utils.data.DataLoader factory method.
    - Parameters:
      - dataset
      - shuffle (keyword-only, default: True)
    - Raises:
      AttributeError – If dataset is None.
    - Returns:
      A new torch.utils.data.DataLoader instance wrapping the dataset argument.
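A generic factory like this mostly forwards the module's tuned attributes to the DataLoader constructor. The stdlib-only sketch below gathers the keyword arguments one would pass to torch.utils.data.DataLoader (dataset, batch_size, shuffle, num_workers, pin_memory and collate_fn are real DataLoader parameters); dataloader_kwargs() itself is a hypothetical helper, and the exact set of arguments cneuromax forwards is an assumption.

```python
from typing import Any, Callable, Optional


def dataloader_kwargs(
    dataset: Optional[Any],
    *,
    batch_size: int,
    num_workers: int,
    pin_memory: bool,
    collate_fn: Optional[Callable] = None,
    shuffle: bool = True,
) -> dict[str, Any]:
    """Hypothetical helper gathering keyword arguments for a DataLoader."""
    if dataset is None:
        # Mirrors the documented behaviour: a missing dataset raises AttributeError.
        raise AttributeError("dataset is None; define it for this phase first")
    return {
        "dataset": dataset,
        "batch_size": batch_size,
        "shuffle": shuffle,
        "num_workers": num_workers,
        "pin_memory": pin_memory,
        "collate_fn": collate_fn,  # None means DataLoader's default collation
    }


kwargs = dataloader_kwargs(list(range(10)), batch_size=4, num_workers=0, pin_memory=False)
print(kwargs["batch_size"], kwargs["shuffle"])  # 4 True
# With PyTorch installed, these can be passed on: torch.utils.data.DataLoader(**kwargs)
```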
  - final train_dataloader()
    Calls x_dataloader() with datasets.train.
    - Returns:
      A new training torch.utils.data.DataLoader instance.
  - final val_dataloader()
    Calls x_dataloader() with datasets.val.
    - Returns:
      A new validation torch.utils.data.DataLoader instance.
  - final test_dataloader()
    Calls x_dataloader() with datasets.test.
    - Returns:
      A new testing torch.utils.data.DataLoader instance.
  - final predict_dataloader()
    Calls x_dataloader() with datasets.predict.
    - Returns:
      A new prediction torch.utils.data.DataLoader instance that does not shuffle the dataset.
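The four phase-specific methods differ only in which dataset they hand to the generic factory and, for prediction, in disabling shuffling. The sketch below routes a phase to a factory callable; make_phase_loader() and record() are hypothetical. Only the predict behaviour is documented above; tying val shuffling to the shuffle_val_dataset config field and letting train and test use the factory's default are assumptions.

```python
from typing import Any, Callable, Optional


def make_phase_loader(
    phase: str,
    datasets: dict[str, Optional[Any]],
    factory: Callable[..., Any],
    shuffle_val_dataset: bool = True,
) -> Any:
    """Hypothetical dispatcher: route one phase's dataset to a generic factory."""
    if phase == "predict":
        shuffle = False  # documented: prediction keeps the dataset order
    elif phase == "val":
        shuffle = shuffle_val_dataset  # assumption: driven by the config flag
    else:
        shuffle = True  # assumption: train/test use the factory's default
    return factory(datasets.get(phase), shuffle=shuffle)


def record(dataset: Any, *, shuffle: bool) -> tuple[Any, bool]:
    # Stand-in factory that simply echoes what it received.
    return dataset, shuffle


data = {"train": "train-set", "val": "val-set", "test": "test-set", "predict": "pred-set"}
for phase in ("train", "val", "test", "predict"):
    print(phase, make_phase_loader(phase, data, record))
```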