"""`Lightning <https://lightning.ai/>`_ utilities."""
import contextlib
import copy
import logging
import math
import os
import sys
import time
from functools import partial
from typing import Annotated as An
from typing import Any
import numpy as np
import torch
from lightning.pytorch import Trainer
from lightning.pytorch.callbacks import BatchSizeFinder, ModelCheckpoint
from lightning.pytorch.loggers.wandb import WandbLogger
from lightning.pytorch.trainer.connectors.checkpoint_connector import (
_CheckpointConnector,
)
from omegaconf import OmegaConf
from torch.distributed import ReduceOp
from cneuromax.fitting.deeplearning.datamodule import BaseDataModule
from cneuromax.fitting.deeplearning.litmodule import BaseLitModule
from cneuromax.fitting.utils.hydra import get_launcher_config
from cneuromax.utils.beartype import one_of
from cneuromax.utils.misc import can_connect_to_internet
log = logging.getLogger(__name__)
[docs]
def instantiate_trainer(
trainer_partial: partial[Trainer],
logger_partial: partial[WandbLogger],
device: An[str, one_of("cpu", "gpu")],
output_dir: str,
save_every_n_train_steps: int | None,
) -> Trainer:
"""Instantiates :paramref:`trainer_partial`.
Args:
trainer_partial
logger_partial
device: See :paramref:`~.FittingSubtaskConfig.device`.
output_dir: See :paramref:`~.BaseSubtaskConfig.output_dir`.
save_every_n_train_steps: See
:paramref:`~.DeepLearningSubtaskConfig.save_every_n_train_steps`.
Returns:
A :class:`lightning.pytorch.Trainer` instance.
"""
launcher_config = get_launcher_config()
# Retrieve the `Trainer` callbacks specified in the config
callbacks: list[Any] = trainer_partial.keywords["callbacks"] or []
# Check internet connection
offline = (
logger_partial.keywords["offline"] or not can_connect_to_internet()
)
if offline:
os.environ["WANDB_DISABLE_SERVICE"] = "True"
# Adds a callback to save the state of training (not just the model
# despite the name) at the end of every validation epoch.
callbacks.append(
ModelCheckpoint(
dirpath=trainer_partial.keywords["default_root_dir"],
monitor="val/loss",
save_last=True,
save_top_k=1,
every_n_train_steps=save_every_n_train_steps,
),
)
# Instantiate the :class:`WandbLogger`.`
logger = logger_partial(offline=offline)
# Feed the Hydra (https://hydra.cc) config to W&B (https://wandb.ai).
logger.experiment.config.update(
OmegaConf.to_container(
OmegaConf.load(f"{output_dir}/.hydra/config.yaml"),
resolve=True,
throw_on_missing=True,
),
)
# Instantiate the trainer.
return trainer_partial(
devices=(
launcher_config.gpus_per_node or 1
if device == "gpu"
else launcher_config.tasks_per_node
),
logger=logger,
callbacks=callbacks,
)
[docs]
def set_batch_size_and_num_workers(
trainer: Trainer,
datamodule: BaseDataModule,
litmodule: BaseLitModule,
device: An[str, one_of("cpu", "gpu")],
output_dir: str,
) -> None:
"""Sets attribute values for a :class:`~.BaseDataModule`.
See :func:`find_good_per_device_batch_size` and
:func:`find_good_per_device_num_workers` for more details on how
these variables' values are determined.
Args:
trainer
datamodule
litmodule
device: See :paramref:`~.FittingSubtaskConfig.device`.
output_dir: See :paramref:`~.BaseSubtaskConfig.output_dir`.
"""
if not datamodule.config.fixed_per_device_batch_size:
proposed_per_device_batch_size = find_good_per_device_batch_size(
litmodule=litmodule,
datamodule=datamodule,
device=device,
device_ids=trainer.device_ids,
output_dir=output_dir,
)
per_device_batch_size = int(
trainer.strategy.reduce(
torch.tensor(proposed_per_device_batch_size),
reduce_op=ReduceOp.MIN, # type: ignore [arg-type]
),
)
else:
per_device_batch_size = datamodule.config.fixed_per_device_batch_size
if datamodule.config.fixed_per_device_num_workers is None:
proposed_per_device_num_workers = find_good_per_device_num_workers(
datamodule=datamodule,
per_device_batch_size=per_device_batch_size,
)
per_device_num_workers = int(
trainer.strategy.reduce(
torch.tensor(proposed_per_device_num_workers),
reduce_op=ReduceOp.MAX, # type: ignore [arg-type]
),
)
else:
per_device_num_workers = datamodule.config.fixed_per_device_num_workers
datamodule.per_device_batch_size = per_device_batch_size
datamodule.per_device_num_workers = per_device_num_workers
[docs]
def find_good_per_device_batch_size(
litmodule: BaseLitModule,
datamodule: BaseDataModule,
device: str,
device_ids: list[int],
output_dir: str,
) -> int:
"""Probes a :attr:`~.BaseDataModule.per_device_batch_size` value.
This functionality makes the following, not always correct, but
generally reasonable assumptions:
- As long as the ``total_batch_size / dataset_size`` ratio remains
small (e.g. ``< 0.01`` so as to benefit from the stochasticity of
gradient updates), running the same number of gradient updates with
a larger batch size will yield better training performance than
running the same number of gradient updates with a smaller batch
size.
- Loading data from disk to RAM is a larger bottleneck than loading
data from RAM to GPU VRAM.
- If you are training on multiple GPUs, each GPU has roughly the
same amount of VRAM.
Args:
litmodule
datamodule
device: See :paramref:`~.FittingSubtaskConfig.device`.
device_ids: See :class:`lightning.pytorch.Trainer.device_ids`.
output_dir: See :paramref:`~.BaseSubtaskConfig.output_dir`.
Returns:
A roughly optimal ``per_device_batch_size`` value.
"""
launcher_config = get_launcher_config()
datamodule_copy = copy.deepcopy(datamodule)
datamodule_copy.prepare_data()
datamodule_copy.setup("fit")
# (since the default `datamodule` `batch_size` is 1)
dataset_len = len(datamodule_copy.train_dataloader())
num_computing_devices = launcher_config.nodes * (
launcher_config.gpus_per_node or 1
if device == "gpu"
else launcher_config.tasks_per_node
)
per_device_batch_size: int | None
# Ensures total batch_size is < 1% of the train dataloader size.
max_per_device_batch_size = dataset_len // (100 * num_computing_devices)
# If a maximum batch size was specified, use the smaller of the two.
if datamodule.config.max_per_device_batch_size:
max_per_device_batch_size = min(
max_per_device_batch_size,
datamodule.config.max_per_device_batch_size,
)
if device == "cpu":
# Only running batch size finding on GPU: reaching out of GPU
# memory errors does not freeze the system wheras reaching out
# of CPU memory errors does.
per_device_batch_size = max_per_device_batch_size
else:
litmodule_copy = copy.deepcopy(litmodule)
# Speeds up the batch size search by removing the validation
# epoch end method, which is independent of the batch size.
litmodule_copy.on_validation_epoch_end = None # type: ignore[assignment,method-assign]
# Speeds up the batch size search by using a reasonable number
# of workers for the search.
if launcher_config.cpus_per_task:
datamodule_copy.per_device_num_workers = (
launcher_config.cpus_per_task
)
batch_size_finder = BatchSizeFinder(
mode="power",
batch_arg_name="per_device_batch_size",
max_trials=int(math.log2(max_per_device_batch_size)),
)
# Stops the `fit` method after the batch size has been found.
batch_size_finder._early_exit = True # noqa: SLF001
trainer = Trainer(
accelerator=device,
devices=[device_ids[0]], # The first available device.
default_root_dir=output_dir + "/lightning/tuner/",
callbacks=[batch_size_finder],
)
log.info("Finding good `batch_size` parameter...")
per_device_batch_size = None
# Prevents the `fit` method from raising a `KeyError`, see:
# https://github.com/Lightning-AI/pytorch-lightning/issues/18114
with contextlib.suppress(KeyError):
trainer.fit(model=litmodule_copy, datamodule=datamodule_copy)
# If any OOM occurs the value is stored in
# `per_device_batch_size` else in `optimal_batch_size`.
optimal_per_device_batch_size = max(
datamodule_copy.per_device_batch_size,
batch_size_finder.optimal_batch_size or 1,
)
per_device_batch_size = min(
optimal_per_device_batch_size,
max_per_device_batch_size,
)
if per_device_batch_size == 0:
per_device_batch_size = 1
log.info(f"Best `batch_size` parameter: {per_device_batch_size}.")
return per_device_batch_size
[docs]
def find_good_per_device_num_workers(
datamodule: BaseDataModule,
per_device_batch_size: int,
max_num_data_passes: int = 100,
) -> int:
"""Probes a :attr:`~.BaseDataModule.per_device_num_workers` value.
Iterates through a range of ``num_workers`` values and measures the
time it takes to iterate through a fixed number of data passes;
returning the value that yields the shortest time.
Args:
datamodule
per_device_batch_size: The return value of
:func:`find_good_per_device_batch_size`.
max_num_data_passes: Maximum number of data passes to iterate
through.
Returns:
A roughly optimal ``per_device_num_workers`` value.
"""
launcher_config = get_launcher_config()
log.info("Finding good `num_workers` parameter...")
if launcher_config.cpus_per_task in [None, 1]:
log.info("Only 1 worker available/provided. Returning 0.")
return 0
# Static type checking purposes, already narrowed down to `int`
# through the `if` statement above.
assert launcher_config.cpus_per_task # noqa: S101
times: list[float] = [
sys.float_info.max for _ in range(launcher_config.cpus_per_task + 1)
]
datamodule_copy = copy.deepcopy(datamodule)
datamodule_copy.per_device_batch_size = per_device_batch_size
datamodule_copy.prepare_data()
datamodule_copy.setup("fit")
for num_workers in range(launcher_config.cpus_per_task, -1, -1):
datamodule_copy.per_device_num_workers = num_workers
start_time = time.time()
num_data_passes = 0
while num_data_passes < max_num_data_passes:
for _ in datamodule_copy.train_dataloader():
num_data_passes += 1
if num_data_passes == max_num_data_passes:
break
times[num_workers] = time.time() - start_time
log.info(
f"num_workers: {num_workers}, time taken: {times[num_workers]}",
)
# If the time taken is not decreasing, stop the search.
if (
# Not after the first iteration.
num_workers != launcher_config.cpus_per_task # noqa: PLR1714
# Still want to attempt `num_workers` = 0.
and num_workers != 1
and times[num_workers + 1] <= times[num_workers]
):
break
best_time = int(np.argmin(times))
log.info(f"Best `num_workers` parameter: {best_time}.")
return best_time
[docs]
class InitOptimParamsCheckpointConnector(_CheckpointConnector):
"""Tweaked `Lightning <https://lightning.ai/>`_ ckpt connector.
Allows to make use of the instantiated optimizers'
hyper-parameters rather than the checkpointed hyper-parameters.
For use when resuming training with different optimizer
hyper-parameters (e.g. with a PBT `Hydra <https://hydra.cc>`_
Sweeper).
"""
[docs]
def restore_optimizers(self: "InitOptimParamsCheckpointConnector") -> None:
"""Tweaked method to preserve newly instantiated parameters."""
new_optims = copy.deepcopy(self.trainer.strategy.optimizers)
super().restore_optimizers()
for ckpt_optim, new_optim in zip(
self.trainer.strategy.optimizers,
new_optims,
strict=True,
):
for ckpt_optim_param_group, new_optim_param_group in zip(
ckpt_optim.param_groups,
new_optim.param_groups,
strict=True,
):
for ckpt_optim_param_group_key in ckpt_optim_param_group:
# Skip the `params` key as it is not a HP.
if ckpt_optim_param_group_key != "params":
# Place the new Hydra instantiated optimizers'
# HPs back into the restored optimizers.
ckpt_optim_param_group[ckpt_optim_param_group_key] = (
new_optim_param_group[ckpt_optim_param_group_key]
)