Source code for deepml.accelerator_trainer

import os
from collections import OrderedDict, defaultdict
from typing import Callable, Dict, Optional, Tuple, Union

import torch
from accelerate import Accelerator
from tqdm import tqdm

from deepml.base import BaseLearner
from deepml.tasks import Task
from deepml.tracking import MLExperimentLogger, TensorboardLogger


class AcceleratorTrainer(BaseLearner):
    """Training class using HuggingFace Accelerate for distributed training.

    This trainer leverages the Accelerate library for seamless distributed
    training, mixed precision, and device management across CPUs, GPUs, and
    TPUs. It supports gradient accumulation, gradient clipping, and automatic
    model/optimizer preparation.
    """

    def __init__(
        self,
        task: Task,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        lr_scheduler_step_policy: str = "epoch",
        accelerator_config: Optional[dict] = None,
    ):
        """Initializes the AcceleratorTrainer.

        Args:
            task: Task object defining the learning task (e.g., classification,
                segmentation).
            optimizer: PyTorch optimizer instance for parameter updates.
            criterion: Loss function module.
            lr_scheduler: Learning rate scheduler instance. Defaults to None.
            lr_scheduler_step_policy: When to call scheduler.step(). Valid
                options are ``"epoch"`` (step after each epoch) or ``"step"``
                (step after each optimizer update). Defaults to ``"epoch"``.
            accelerator_config: Optional dictionary of keyword arguments passed
                to ``accelerate.Accelerator()``. Common options include:

                - ``gradient_accumulation_steps``: Number of steps to
                  accumulate gradients
                - ``mixed_precision``: Mixed precision mode
                  ("no", "fp16", "bf16")
                - ``device_placement``: Whether to automatically place tensors
                  on device
                - ``split_batches``: Whether to split batches across devices

                Defaults to None (uses Accelerate defaults).

        Note:
            Unlike FabricTrainer, this class accepts an lr_scheduler instance
            directly rather than a factory function (lr_scheduler_fn).
        """
        super().__init__(
            task=task,
            optimizer=optimizer,
            criterion=criterion,
            lr_scheduler=lr_scheduler,
            lr_scheduler_fn=None,
            lr_scheduler_step_policy=lr_scheduler_step_policy,
        )

        if accelerator_config is None:
            accelerator_config = {}

        self.accelerator = Accelerator(**accelerator_config)
        # The task moves tensors itself, so point it at Accelerate's device.
        self._task._device = self.accelerator.device

        # State carried across successive fit() calls / checkpoint resumes.
        self.epochs_completed = 0
        self.best_val_loss = float("inf")
        self.history = defaultdict(list)

        os.makedirs(self._model_dir, exist_ok=True)

    def __aggregate_batch_metrics(
        self,
        local_batch_metrics_dict: Dict[str, object],
        global_metrics_dict: Dict[str, float],
        step: int,
    ) -> None:
        """Gathers one batch's metrics from all processes into the running averages.

        The per-process metric values are packed into a single 1-D tensor,
        gathered with ``accelerator.gather`` and folded into
        ``global_metrics_dict`` in place. This runs on every process so that
        each rank ends up with the same aggregated values.

        Args:
            local_batch_metrics_dict: Metric name -> value computed by this
                process for the current batch.
            global_metrics_dict: Running aggregate, updated in place.
            step: 1-based index of the current batch.
        """
        names = list(local_batch_metrics_dict.keys())

        # A flat tensor is gathered instead of a dict.
        values = torch.tensor(
            list(local_batch_metrics_dict.values()),
            device=self.accelerator.device,
            dtype=torch.float32,
        )

        # gather() concatenates along dim 0; reshape to (world_size, num_metrics).
        gathered = self.accelerator.gather(values).view(
            self.accelerator.num_processes, len(names)
        )

        # gathered[:, 0] -> loss, gathered[:, 1] -> first metric, etc.
        per_metric = {name: gathered[:, i] for i, name in enumerate(names)}

        AcceleratorTrainer.update_metrics_with_simple_moving_average(
            per_metric, global_metrics_dict, step
        )

    def __train(
        self,
        model: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        criterion: torch.nn.Module,
        train_loader: torch.utils.data.DataLoader,
        step_lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
        metrics: Optional[Dict[str, torch.nn.Module]] = None,
        non_blocking: bool = True,
        gradient_clip_value: Optional[float] = None,
        gradient_clip_max_norm: Optional[float] = None,
    ) -> Dict[str, float]:
        """Runs a single training epoch with Accelerate-managed gradient accumulation.

        Args:
            model: Model to train (prepared by Accelerate).
            optimizer: Optimizer for parameter updates (prepared by Accelerate).
            criterion: Loss function accepting (outputs, targets).
            train_loader: DataLoader yielding (inputs, targets) batches
                (prepared by Accelerate).
            step_lr_scheduler: Learning rate scheduler stepped after each
                optimizer update (for "step" policy). Defaults to None.
            metrics: Dictionary mapping metric names to metric modules. Each
                metric should accept (outputs, targets). Defaults to None.
            non_blocking: Whether to use asynchronous CUDA transfers.
                Defaults to True.
            gradient_clip_value: Maximum absolute value for gradient clipping.
                Defaults to None (no clipping).
            gradient_clip_max_norm: Maximum L2 norm for gradient clipping.
                Defaults to None (no clipping).

        Returns:
            Dictionary mapping metric names to values aggregated (simple
            moving average) over all batches and all processes. The same
            values are computed on every process.

        Note:
            - Uses Accelerate's ``accumulate`` context manager for gradient
              accumulation
            - Gradients are clipped only when ``sync_gradients`` is True
            - The progress bar is shown only on the main process
        """
        model.train()

        # Per-process values for the current batch, and the running aggregate
        # across batches and processes.
        local_batch_metrics_dict = AcceleratorTrainer.init_metrics(metrics)
        global_metrics_dict = AcceleratorTrainer.init_metrics(metrics)

        training_progress_bar = None
        if self.accelerator.is_main_process:
            training_progress_bar = tqdm(
                total=len(train_loader),
                desc="{:12s}".format("Training"),
                dynamic_ncols=True,
            )

        # Nullify the parameter gradients
        optimizer.zero_grad(set_to_none=True)

        try:
            for step, (x, y) in enumerate(train_loader, start=1):
                with self.accelerator.accumulate(model):
                    outputs, x, y = self._task.train_step(
                        x,
                        y,
                        model=model,
                        device=self.accelerator.device,
                        non_blocking=non_blocking,
                    )

                    # Single-column outputs: give the targets the same shape.
                    if (
                        isinstance(outputs, torch.Tensor)
                        and outputs.ndim == 2
                        and outputs.shape[1] == 1
                    ):
                        y = y.view_as(outputs)

                    loss = criterion(outputs, y)

                    # No manual loss scaling: the accelerator takes care of it.
                    self.accelerator.backward(loss)

                    # Clip only on the iterations where gradients are synced.
                    if self.accelerator.sync_gradients:
                        if gradient_clip_value is not None:
                            self.accelerator.clip_grad_value_(
                                model.parameters(), gradient_clip_value
                            )
                        elif gradient_clip_max_norm is not None:
                            self.accelerator.clip_grad_norm_(
                                model.parameters(), max_norm=gradient_clip_max_norm
                            )

                    optimizer.step()

                    if step_lr_scheduler is not None:
                        step_lr_scheduler.step()

                    # Nullify the parameter gradients
                    optimizer.zero_grad(set_to_none=True)

                # Detach so the bookkeeping does not hold on to the autograd
                # graph; metrics never need gradients either.
                local_batch_metrics_dict["loss"] = loss.detach()
                with torch.no_grad():
                    AcceleratorTrainer.update_metrics(
                        outputs, y, metrics, local_batch_metrics_dict
                    )

                self.__aggregate_batch_metrics(
                    local_batch_metrics_dict, global_metrics_dict, step
                )

                if training_progress_bar is not None:
                    training_progress_bar.update(1)
                    training_progress_bar.set_postfix(
                        {
                            name: f"{round(value, 4)}"
                            for name, value in global_metrics_dict.items()
                        }
                    )
        finally:
            # Close the bar even if a batch raised.
            if training_progress_bar is not None:
                training_progress_bar.close()

        return global_metrics_dict

    def fit(
        self,
        train_loader: torch.utils.data.DataLoader,
        val_loader: Optional[torch.utils.data.DataLoader] = None,
        epochs: int = 10,
        save_model_after_every_epoch: int = 5,
        metrics: Optional[Dict[str, torch.nn.Module]] = None,
        gradient_clip_value: Optional[float] = None,
        gradient_clip_max_norm: Optional[float] = None,
        resume_from_checkpoint: Optional[str] = None,
        load_optimizer_state: bool = False,
        load_scheduler_state: bool = False,
        logger: Optional[MLExperimentLogger] = None,
        non_blocking: bool = True,
        image_inverse_transform: Optional[Callable] = None,
        logger_img_size: Optional[Union[int, Tuple[int, int]]] = None,
    ):
        """Trains the model for the specified number of epochs using Accelerate.

        Args:
            train_loader: DataLoader for training data.
            val_loader: DataLoader for validation data. Defaults to None.
            epochs: Number of epochs to train in this call. Defaults to 10.
            save_model_after_every_epoch: Frequency (in epochs) to save model
                checkpoints. A falsy value (0 / None) disables periodic
                checkpoints. Defaults to 5.
            metrics: Dictionary mapping metric names to metric instances. Each
                metric must be a torch.nn.Module with a forward() method.
                Defaults to None.
            gradient_clip_value: Maximum absolute value for gradient clipping.
                Mutually exclusive with gradient_clip_max_norm.
                Defaults to None.
            gradient_clip_max_norm: Maximum L2 norm for gradient clipping.
                Mutually exclusive with gradient_clip_value. Defaults to None.
            resume_from_checkpoint: Path to checkpoint file to resume training
                from. If the path does not exist a warning is printed and
                training starts from the current state. Defaults to None.
            load_optimizer_state: Whether to load optimizer state from the
                checkpoint. Defaults to False.
            load_scheduler_state: Whether to load learning rate scheduler state
                from the checkpoint. Defaults to False.
            logger: Experiment logger for tracking metrics and artifacts. If
                None, uses TensorboardLogger. Defaults to None.
            non_blocking: Whether to use asynchronous CUDA tensor transfers.
                Defaults to True.
            image_inverse_transform: Transformation to reverse image
                normalization for visualization. Defaults to None.
            logger_img_size: Image size (int or tuple) for logging.
                Defaults to None.

        Returns:
            Dictionary containing training history with metric names as keys
            and lists of values as entries (populated on the main process).

        Raises:
            ValueError: If both gradient_clip_value and gradient_clip_max_norm
                are provided.
            TypeError: If any metric is not a torch.nn.Module with a forward()
                method.

        Note:
            - Model, optimizer, scheduler and dataloaders are prepared by
              Accelerate
            - Only the main process saves checkpoints and manages logging
            - All processes synchronize at the end of each epoch using
              ``wait_for_everyone()``
            - Every checkpoint stores the unwrapped model
            - With ``ReduceLROnPlateau`` the monitored value is the validation
              loss, or the training loss when no ``val_loader`` is given
        """
        if gradient_clip_value is not None and gradient_clip_max_norm is not None:
            raise ValueError(
                "Only one of gradient_clip_value or gradient_clip_max_norm should be passed."
            )

        # Check valid metrics types
        if metrics:
            for metric_instance in metrics.values():
                if not (
                    isinstance(metric_instance, torch.nn.Module)
                    and hasattr(metric_instance, "forward")
                ):
                    raise TypeError(f"{metric_instance.__class__} is not supported")

        is_main_process = self.accelerator.is_main_process

        # Resume from checkpoint if provided
        if resume_from_checkpoint is not None:
            if os.path.exists(resume_from_checkpoint):
                state_dict = torch.load(
                    resume_from_checkpoint, map_location=self.accelerator.device
                )
                self._model.load_state_dict(state_dict["model_state_dict"])

                if load_optimizer_state:
                    AcceleratorTrainer.load_optimizer_state(
                        self._optimizer, state_dict
                    )

                if (
                    self._lr_scheduler is not None
                    and load_scheduler_state
                    and "scheduler_state_dict" in state_dict
                ):
                    AcceleratorTrainer.load_lr_schedular_state(
                        self._lr_scheduler, state_dict
                    )

                self.epochs_completed = state_dict.get("epoch", 0)
                self.best_val_loss = state_dict.get("val_loss", float("inf"))

                if is_main_process:
                    print(
                        f"Resuming training from epoch {self.epochs_completed} with best validation loss {self.best_val_loss}"
                    )
            elif is_main_process:
                # Do not fail, but make the ignored path visible to the user.
                print(
                    f"Checkpoint '{resume_from_checkpoint}' not found; "
                    "training starts without resuming."
                )

        # Prepare everything for the current device (CPU/GPU/TPU)
        model, optimizer, lr_scheduler, train_loader, val_loader = (
            self.accelerator.prepare(
                self._model,
                self._optimizer,
                self._lr_scheduler,
                train_loader,
                val_loader,
            )
        )

        if is_main_process:
            self.logger = (
                logger if logger is not None else TensorboardLogger(self._model_dir)
            )
            self.logger.log_params(
                task=self._task,
                model=self._model,
                optimizer=self._optimizer,
                lr_scheduler=self._lr_scheduler,
                criterion=self._criterion,
                loader=val_loader,
                epochs=epochs,
                gradient_clip_value=gradient_clip_value,
                gradient_clip_max_norm=gradient_clip_max_norm,
                resume_from_checkpoint=resume_from_checkpoint,
            )

        criterion = self._criterion
        epochs_completed = self.epochs_completed
        best_val_loss = self.best_val_loss
        epochs = epochs_completed + epochs
        history = defaultdict(list)

        # prepare() returns a wrapper around the scheduler, so the scheduler
        # type must be checked on the instance passed to the constructor.
        uses_plateau_scheduler = isinstance(
            self._lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
        )
        step_lr_scheduler = (
            lr_scheduler if self._lr_scheduler_step_policy == "step" else None
        )

        val_global_metrics_dict = {"loss": float("inf")}
        train_loss = float("inf")
        val_loss = float("inf")

        for epoch in range(epochs_completed, epochs):
            if is_main_process:
                print("Epoch {}/{}:".format(epoch + 1, epochs))
                AcceleratorTrainer.write_lr(optimizer, epoch + 1, self.logger, history)

            # training
            train_global_metrics_dict = self.__train(
                model,
                optimizer,
                criterion,
                train_loader,
                step_lr_scheduler=step_lr_scheduler,
                metrics=metrics,
                non_blocking=non_blocking,
                gradient_clip_value=gradient_clip_value,
                gradient_clip_max_norm=gradient_clip_max_norm,
            )

            # evaluation
            if val_loader is not None:
                val_global_metrics_dict = self.__validate(
                    model,
                    val_loader,
                    criterion,
                    metrics,
                    non_blocking=non_blocking,
                )

            # Both loops aggregate on every process, so this bookkeeping is
            # identical on all ranks.
            epochs_completed = epochs_completed + 1
            train_loss = train_global_metrics_dict["loss"]
            val_loss = val_global_metrics_dict["loss"]
            is_best = val_loss < best_val_loss
            if is_best:
                best_val_loss = val_loss

            # After each epoch completed, write metrics to logger
            if is_main_process:
                self.log_metrics(
                    val_loader,
                    train_global_metrics_dict,
                    val_global_metrics_dict,
                    history,
                    epochs_completed,
                    logger_img_size,
                    image_inverse_transform,
                )

                # Always checkpoint the unwrapped model, for every checkpoint kind.
                unwrapped_model = self.accelerator.unwrap_model(model)

                message = f"\nTrain Loss: {train_loss:.4f} Val Loss: {val_loss:.4f}"

                # Save the best validation model
                if is_best:
                    message = message + " [Saving best validation model]"
                    self.save(
                        "best_val_model",
                        unwrapped_model,
                        optimizer,
                        criterion,
                        lr_scheduler,
                        epoch=epochs_completed,
                        train_loss=train_loss,
                        val_loss=val_loss,
                    )

                # Log info message to console only global zero process
                tqdm.write(message)

                if (
                    save_model_after_every_epoch
                    and epochs_completed % save_model_after_every_epoch == 0
                ):
                    last_checkpoint = "epoch_{}_model".format(epochs_completed)
                    self.save(
                        last_checkpoint,
                        unwrapped_model,
                        optimizer,
                        criterion,
                        lr_scheduler,
                        epoch=epochs_completed,
                        train_loss=train_loss,
                        val_loss=val_loss,
                    )

            # Ensure all processes are synchronized before proceeding next epoch
            self.accelerator.wait_for_everyone()

            # LR Scheduler step after each epoch
            if lr_scheduler is not None and self._lr_scheduler_step_policy == "epoch":
                if uses_plateau_scheduler:
                    # ReduceLROnPlateau needs a monitored value; fall back to
                    # the training loss when there is no validation data.
                    monitored = val_loss if val_loader is not None else train_loss
                    lr_scheduler.step(monitored)
                else:
                    lr_scheduler.step()

        # Persist progress so that a later fit() call continues from here.
        self.epochs_completed = epochs_completed
        self.best_val_loss = best_val_loss

        # Save latest model at the end
        if is_main_process:
            self.save(
                "latest_model",
                self.accelerator.unwrap_model(model),
                optimizer,
                criterion,
                lr_scheduler,
                epoch=epochs_completed,
                train_loss=train_loss,
                val_loss=val_loss,
            )

        return history

    @torch.no_grad()
    def __validate(
        self,
        model: torch.nn.Module,
        loader: torch.utils.data.DataLoader,
        criterion: torch.nn.Module,
        metrics: Optional[Dict[str, torch.nn.Module]] = None,
        non_blocking: bool = True,
    ):
        """Runs a single validation epoch across all processes.

        Args:
            model: Model to evaluate (prepared by Accelerate); set to eval()
                mode.
            loader: DataLoader for validation data (prepared by Accelerate).
            criterion: Loss function accepting (outputs, targets).
            metrics: Dictionary mapping metric names to metric modules. Each
                metric should accept (outputs, targets). Defaults to None.
            non_blocking: Whether to use asynchronous CUDA transfers.
                Defaults to True.

        Returns:
            Dictionary mapping metric names to values aggregated (simple
            moving average) over all batches and all processes. The same
            values are computed on every process.

        Note:
            - Gradients are disabled via the ``@torch.no_grad()`` decorator
            - Metrics are gathered from all processes using
              ``accelerator.gather()``
            - The progress bar is shown only on the main process
        """
        model.eval()

        local_batch_metrics_dict = AcceleratorTrainer.init_metrics(metrics)
        global_metrics_dict = AcceleratorTrainer.init_metrics(metrics)

        validation_progress_bar = None
        if self.accelerator.is_main_process:
            validation_progress_bar = tqdm(
                total=len(loader),
                desc="{:12s}".format("Validation"),
                dynamic_ncols=True,
                leave=True,
            )

        try:
            for step, (x, y) in enumerate(loader, start=1):
                outputs, x, y = self._task.eval_step(
                    x,
                    y,
                    model=model,
                    device=self.accelerator.device,
                    non_blocking=non_blocking,
                )

                if isinstance(y, torch.Tensor):
                    y = y.to(self.accelerator.device)

                # Single-column outputs: give the targets the same shape.
                if (
                    isinstance(outputs, torch.Tensor)
                    and outputs.ndim == 2
                    and outputs.shape[1] == 1
                ):
                    y = y.view_as(outputs)

                loss = criterion(outputs, y)

                local_batch_metrics_dict["loss"] = loss
                AcceleratorTrainer.update_metrics(
                    outputs, y, metrics, local_batch_metrics_dict
                )

                self.__aggregate_batch_metrics(
                    local_batch_metrics_dict, global_metrics_dict, step
                )

                if validation_progress_bar is not None:
                    validation_progress_bar.update(1)
                    validation_progress_bar.set_postfix(
                        {
                            name: f"{round(value, 4)}"
                            for name, value in global_metrics_dict.items()
                        }
                    )
        finally:
            # Close the bar even if a batch raised.
            if validation_progress_bar is not None:
                validation_progress_bar.close()

        return global_metrics_dict

    def fit_temp(
        self, train_loader, val_loader, epochs=10, metrics: Optional[dict] = None
    ):
        """Temporary/experimental training method with simplified Accelerate workflow.

        **Warning**: This method appears to be legacy/debug code and should not
        be used in production. Use the ``fit()`` method instead.

        Args:
            train_loader: DataLoader for training data.
            val_loader: DataLoader for validation data.
            epochs: Number of epochs to train. Defaults to 10.
            metrics: Dictionary mapping metric names to metric functions
                called as ``metric_fn(outputs, targets)``. Defaults to None
                (no extra metrics).

        Note:
            Compared to ``fit()`` this method:

            - Writes to the hardcoded ``./checkpoint`` directory, overwriting
              it at the end of every epoch
            - Has no resume, logging, gradient clipping or scheduler support

        Deprecated:
            Use ``fit()`` method instead for production training.
        """
        # None instead of a shared mutable default argument.
        metrics = {} if metrics is None else metrics
        is_main_process = self.accelerator.is_main_process

        print("Number of processes:", self.accelerator.num_processes)
        print("Is main process:", is_main_process)
        print("Device:", self.accelerator.device)

        # Prepare everything for the current device (CPU/GPU/TPU)
        model, optimizer, train_loader, val_loader = self.accelerator.prepare(
            self._model, self._optimizer, train_loader, val_loader
        )

        def gather_batch_metrics(loss, outputs, targets):
            """Return {name: tensor of per-process values} for one batch."""
            local_metrics = {"loss": loss.detach()}
            for name, metric_fn in metrics.items():
                value = torch.as_tensor(
                    metric_fn(outputs, targets),
                    dtype=torch.float32,
                    device=self.accelerator.device,
                )
                local_metrics[name] = value.detach()
            # Plain gather: these are per-batch scalars, not per-sample
            # tensors, so the end-of-dataloader truncation applied by
            # gather_for_metrics must not be used here.
            return self.accelerator.gather(local_metrics)

        for epoch in range(epochs):
            # Important for multi-process shuffling
            if hasattr(train_loader, "sampler") and hasattr(
                train_loader.sampler, "set_epoch"
            ):
                train_loader.sampler.set_epoch(epoch)

            # Training loop
            model.train()
            step = 0
            global_metrics = OrderedDict(loss=0.0, **{k: 0.0 for k in metrics})

            pbar = None
            if is_main_process:
                print(f"Epoch {epoch + 1}")
                pbar = tqdm(train_loader, desc="Training", dynamic_ncols=True)

            for inputs, targets in train_loader:
                outputs = model(inputs)
                loss = self._criterion(outputs, targets)
                self.accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                with torch.no_grad():
                    gathered = gather_batch_metrics(loss, outputs, targets)

                if pbar is not None:
                    step += 1
                    # Incremental mean over the batches seen so far.
                    for k, v in gathered.items():
                        global_metrics[k] += (
                            v.mean().item() - global_metrics[k]
                        ) / step
                    pbar.set_postfix(
                        {f"train_{k}": round(v, 4) for k, v in global_metrics.items()}
                    )
                    pbar.update()

            if pbar is not None:
                pbar.close()

            # Validation loop
            model.eval()
            step = 0
            val_metrics = OrderedDict(loss=0.0, **{k: 0.0 for k in metrics})

            pbar = None
            if is_main_process:
                pbar = tqdm(val_loader, desc="Validation", dynamic_ncols=True)

            with torch.no_grad():
                for inputs, targets in val_loader:
                    outputs = model(inputs)
                    loss = self._criterion(outputs, targets)
                    gathered = gather_batch_metrics(loss, outputs, targets)

                    if pbar is not None:
                        step += 1
                        for k, v in gathered.items():
                            val_metrics[k] += (v.mean().item() - val_metrics[k]) / step
                        pbar.set_postfix(
                            {f"val_{k}": round(v, 4) for k, v in val_metrics.items()}
                        )
                        pbar.update()

            if pbar is not None:
                pbar.close()
                print("-" * 40)

            # Save model if needed
            # NOTE(review): whether this save belongs inside or after the epoch
            # loop could not be determined from the original layout; the final
            # files on disk are the same either way.
            if is_main_process:
                self.accelerator.save_model(model, "./checkpoint")
                torch.save(
                    {"optimizer": optimizer.state_dict()}, "./checkpoint/optimizer.pt"
                )