import csv
import os
from collections import OrderedDict, defaultdict
from typing import Callable, Dict, Tuple, Union
import numpy as np
import torch
from tqdm import tqdm
import deepml.tasks
from deepml.tasks import Task
from deepml.tracking import MLExperimentLogger, TensorboardLogger
# torch version string for which ``torch.amp.GradScaler`` is not defined, so
# the older ``torch.cuda.amp.GradScaler`` has to be used instead.
MAC_TORCH_2_2_2 = "2.2.2"
[docs]
class Learner:
    """Training class for learning model weights using PyTorch.

    This trainer provides straightforward training functionality with support
    for learning rate scheduling, automatic mixed precision (AMP), gradient
    accumulation, and gradient clipping. It's designed for single-device
    training and works well in interactive environments like Jupyter notebooks.

    For multi-GPU or distributed training, consider using
    :class:`FabricTrainer` or :class:`AcceleratorTrainer`.

    Attributes:
        epochs_completed: Number of epochs completed in training.
        best_val_loss: Best validation loss achieved during training.
        history: Dictionary storing training history metrics across epochs.
        logger: Experiment logger for tracking metrics and artifacts.

    Note:
        This trainer is ideal for:

        - Single GPU/CPU training
        - Jupyter notebook environments
        - Simple training workflows without distributed requirements
        - Debugging and prototyping
    """
[docs]
def __init__(
    self,
    task: Task,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    lr_scheduler=None,
    lr_scheduler_step_policy: str = "epoch",
    load_state: bool = False,
    use_amp: bool = False,
):
    """Initializes the Learner.

    Args:
        task: Task object defining the learning task (e.g., classification,
            segmentation). Its ``model``, ``model_dir``, ``model_file_name``
            and ``device`` attributes are read here.
        optimizer: PyTorch optimizer instance for parameter updates.
        criterion: Loss function module.
        lr_scheduler: Learning rate scheduler instance. Defaults to None.
        lr_scheduler_step_policy: When to call scheduler.step(). Valid options
            are ``"epoch"`` (step after each epoch) or ``"step"`` (step after
            each optimizer update). Defaults to ``"epoch"``.
        load_state: Whether to resume model training. If True, the checkpoint
            at ``model_dir/model_file_name`` is read and the optimizer state,
            scheduler state (if any), AMP scaler state, completed epoch count
            and best validation loss are restored from it. Defaults to False.
        use_amp: Whether to use automatic mixed precision (AMP) for training.
            Defaults to False.

    Raises:
        AssertionError: If ``task`` is not a ``Task``, ``optimizer`` is not a
            ``torch.optim.Optimizer``, ``criterion`` is not a
            ``torch.nn.Module`` or the scheduler step policy is invalid.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not isinstance(task, Task):
        raise AssertionError(
            f"task must be a Task instance, got {type(task).__name__}"
        )

    self.__predictor = task
    self.__model = self.__predictor.model
    self.__model_dir = self.__predictor.model_dir
    self.__model_file_name = self.__predictor.model_file_name

    self.__optimizer = None
    self.__criterion = None
    self.__lr_scheduler = None
    self.__lr_scheduler_step_policy = None
    self.__use_amp = use_amp

    self.set_optimizer(optimizer)
    self.set_criterion(criterion)
    self.set_lr_scheduler(lr_scheduler, lr_scheduler_step_policy)

    self.epochs_completed = 0
    self.best_val_loss = np.inf
    self.history = defaultdict(list)
    self.logger = None

    os.makedirs(self.__model_dir, exist_ok=True)
    self.__metrics_dict = OrderedDict({"loss": 0})
    self.__device = self.__predictor.device

    # Gradient scaler for mixed precision training. The same GradScaler
    # instance should be used for the entire convergence run, if multiple
    # calls to fit are used.
    #
    # torch.amp.GradScaler is not defined for MAC torch 2.2.2, so the legacy
    # torch.cuda.amp.GradScaler is used there. The API is feature-detected as
    # well, because ``torch.__version__`` may carry a local suffix such as
    # "+cpu" that defeats an exact string comparison.
    base_version = str(torch.__version__).split("+")[0]
    has_unified_scaler = hasattr(getattr(torch, "amp", None), "GradScaler")
    if has_unified_scaler and base_version != MAC_TORCH_2_2_2:
        # torch.amp.GradScaler expects a device *type* string ("cuda", "cpu"),
        # not an indexed device such as "cuda:0".
        self.__scaler = torch.amp.GradScaler(
            torch.device(self.__device).type, enabled=self.__use_amp
        )
    else:
        # The legacy scaler takes no device argument; its first positional
        # parameter is ``init_scale``, so the device must not be passed.
        self.__scaler = torch.cuda.amp.GradScaler(enabled=self.__use_amp)

    if load_state:
        self.__load_state()
[docs]
def set_optimizer(self, optimizer: torch.optim.Optimizer):
    """Sets the optimizer for training.

    Args:
        optimizer: PyTorch optimizer instance.

    Raises:
        AssertionError: If optimizer is not a torch.optim.Optimizer instance.
    """
    # Raised explicitly (not via ``assert``) so the documented check is still
    # performed when Python runs with optimizations enabled (``-O``).
    if not isinstance(optimizer, torch.optim.Optimizer):
        raise AssertionError(
            "optimizer must be a torch.optim.Optimizer instance, "
            f"got {type(optimizer).__name__}"
        )
    self.__optimizer = optimizer
[docs]
def set_criterion(self, criterion: torch.nn.Module):
    """Sets the loss function for training.

    Args:
        criterion: Loss function module.

    Raises:
        AssertionError: If criterion is not a torch.nn.Module instance.
    """
    # Raised explicitly (not via ``assert``) so the documented check is still
    # performed when Python runs with optimizations enabled (``-O``).
    if not isinstance(criterion, torch.nn.Module):
        raise AssertionError(
            "criterion must be a torch.nn.Module instance, "
            f"got {type(criterion).__name__}"
        )
    self.__criterion = criterion
[docs]
def set_lr_scheduler(self, lr_scheduler, lr_scheduler_step_policy: str = "epoch"):
    """Sets the learning rate scheduler.

    The step policy is validated and lower-cased only when a scheduler is
    supplied; without a scheduler the policy is stored as given.

    Args:
        lr_scheduler: Learning rate scheduler instance. If None, no scheduler
            is used.
        lr_scheduler_step_policy: When to call scheduler.step(). Valid options
            are ``"epoch"`` or ``"step"`` (case-insensitive). Defaults to
            ``"epoch"``.

    Raises:
        AssertionError: If lr_scheduler_step_policy is not ``"epoch"`` or
            ``"step"``.
    """
    if lr_scheduler:
        # Check the type before calling ``.lower()`` so that a non-string
        # policy raises the documented AssertionError, not AttributeError.
        if not isinstance(lr_scheduler_step_policy, str):
            raise AssertionError(
                "lr_scheduler_step_policy must be a string, "
                f"got {type(lr_scheduler_step_policy).__name__}"
            )
        lr_scheduler_step_policy = lr_scheduler_step_policy.lower()
        if lr_scheduler_step_policy not in ("epoch", "step"):
            raise AssertionError(
                "lr_scheduler_step_policy must be 'epoch' or 'step', "
                f"got {lr_scheduler_step_policy!r}"
            )

    self.__lr_scheduler = lr_scheduler
    self.__lr_scheduler_step_policy = lr_scheduler_step_policy
def __load_state(self):
    """Restores training state from the checkpoint in the model directory.

    Reads ``model_dir/model_file_name`` and restores, when present: the
    optimizer state, the scheduler state (only if a scheduler is configured),
    the AMP scaler state (only if AMP is enabled), the number of completed
    epochs and the best validation loss. If the file does not exist, a
    message is printed and nothing is changed.
    """
    model_path = os.path.join(self.__model_dir, self.__model_file_name)
    if not os.path.exists(model_path):
        print(f"{model_path} does not exist.")
        return

    # Without CUDA, tensors saved from a GPU must be remapped onto the CPU.
    map_location = None if torch.cuda.is_available() else torch.device("cpu")
    state_dict = torch.load(model_path, map_location=map_location)

    self.__load_optimizer_state(state_dict)

    if self.__lr_scheduler:
        self.__load_lr_schedular_state(state_dict)

    if self.__use_amp and "scaler" in state_dict:
        self.__scaler.load_state_dict(state_dict["scaler"])

    # ``save`` writes ``epoch=-1`` and ``val_loss=None`` when the caller does
    # not supply them. Adopting those placeholders would corrupt the epoch
    # counter and make the ``val_loss < best_val_loss`` comparison fail.
    saved_epoch = state_dict.get("epoch")
    if saved_epoch is not None and saved_epoch >= 0:
        self.epochs_completed = saved_epoch

    saved_val_loss = state_dict.get("val_loss")
    if saved_val_loss is not None:
        self.best_val_loss = saved_val_loss
def __load_optimizer_state(self, state_dict):
    """Loads the optimizer state from a checkpoint dictionary.

    Nothing happens unless both the ``"optimizer"`` and
    ``"optimizer_state_dict"`` keys are present. The state is loaded only when
    the stored optimizer class name equals the class name of the current
    optimizer; otherwise a message is printed and the state is skipped.

    Args:
        state_dict: Checkpoint dictionary.
    """
    has_optimizer_entry = (
        "optimizer" in state_dict and "optimizer_state_dict" in state_dict
    )
    if not has_optimizer_entry:
        return

    if state_dict["optimizer"] != self.__optimizer.__class__.__name__:
        print(
            f"Skipping load optimizer state because {self.__optimizer.__class__.__name__}"
            f" != {state_dict['optimizer']}"
        )
        return

    self.__optimizer.load_state_dict(state_dict["optimizer_state_dict"])
def __load_lr_schedular_state(self, state_dict):
    """Loads the learning rate scheduler state from a checkpoint dictionary.

    Nothing happens unless both the ``"scheduler"`` and
    ``"scheduler_state_dict"`` keys are present. The state is loaded only when
    the stored scheduler class name equals the class name of the current
    scheduler; otherwise a message is printed and the state is skipped.

    Args:
        state_dict: Checkpoint dictionary.
    """
    has_scheduler_entry = (
        "scheduler" in state_dict and "scheduler_state_dict" in state_dict
    )
    if not has_scheduler_entry:
        return

    if state_dict["scheduler"] != self.__lr_scheduler.__class__.__name__:
        print(
            f"Skipping load lr scheduler state because {self.__lr_scheduler.__class__.__name__}"
            f" != {state_dict['scheduler']}"
        )
        return

    self.__lr_scheduler.load_state_dict(state_dict["scheduler_state_dict"])
[docs]
def save(
    self,
    tag: str,
    save_optimizer_state: bool = False,
    epoch: int = -1,
    train_loss: float = None,
    val_loss: float = None,
):
    """Saves model checkpoint and training state.

    Args:
        tag: Name tag for the checkpoint file (without extension).
        save_optimizer_state: Whether to include optimizer state in the
            checkpoint. Defaults to False.
        epoch: Current epoch number. Defaults to -1.
        train_loss: Training loss value for this checkpoint. Defaults to None.
        val_loss: Validation loss value for this checkpoint. Defaults to None.

    Returns:
        str: Full path to the saved checkpoint file.

    Note:
        - Automatically handles DataParallel models
        - Saves scheduler state if scheduler is configured
        - Saves AMP scaler state if AMP is enabled
        - Logs the model to the experiment logger, if one is attached
    """
    # Unwrap DataParallel so the stored keys carry no "module." prefix.
    if isinstance(self.__model, torch.nn.DataParallel):
        model_weights = self.__model.module.state_dict()
    else:
        model_weights = self.__model.state_dict()

    state_dict = {
        "model_state_dict": model_weights,
        "criterion": self.__criterion.__class__.__name__,
        "epoch": epoch,
        "train_loss": train_loss,
        "val_loss": val_loss,
    }

    if save_optimizer_state:
        state_dict["optimizer"] = self.__optimizer.__class__.__name__
        state_dict["optimizer_state_dict"] = self.__optimizer.state_dict()

    if self.__lr_scheduler:
        state_dict["scheduler"] = self.__lr_scheduler.__class__.__name__
        state_dict["scheduler_state_dict"] = self.__lr_scheduler.state_dict()

    if self.__use_amp:
        state_dict["scaler"] = self.__scaler.state_dict()

    filepath = f"{os.path.join(self.__model_dir, tag)}.pt"
    torch.save(state_dict, filepath)

    # ``logger`` stays None until ``fit`` attaches one; saving before the
    # first ``fit`` call must still succeed.
    if self.logger is not None:
        self.logger.log_model(tag, self.__model, epoch, artifact_path=filepath)

    return filepath
@torch.no_grad()
def validate(
    self,
    loader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    metrics: Dict[str, torch.nn.Module] = None,
    non_blocking=False,
):
    """Evaluates the model on the validation data.

    Args:
        loader: DataLoader for validation data.
        criterion: Loss function module.
        metrics: Dictionary mapping metric names to metric modules.
            Defaults to None.
        non_blocking: Forwarded to the task's ``eval_step``; intended to
            enable asynchronous CUDA transfers. Defaults to False.

    Returns:
        OrderedDict mapping metric names to their average values across all
        batches. The returned mapping is a snapshot: later calls to
        ``validate`` or ``fit`` do not modify it.

    Raises:
        Exception: If loader is None.
        ValueError: If a metric is named ``"loss"``.

    Note:
        - Model is set to eval() mode
        - Gradients are disabled via @torch.no_grad() decorator
        - Metrics are computed as running averages
    """
    if loader is None:
        raise Exception("Loader cannot be None.")

    self.__model.eval()
    self.__metrics_dict["loss"] = 0
    self.__init_metrics(metrics)

    with tqdm(
        total=len(loader), desc="{:12s}".format("Validation"), dynamic_ncols=True
    ) as bar:
        for batch_index, (x, y) in enumerate(loader):
            outputs, x, y = self.__predictor.eval_step(x, y, non_blocking)

            if isinstance(y, torch.Tensor):
                y = y.to(self.__device)

                # Single-column outputs of shape (N, 1): reshape the targets
                # to match. Only tensors support ``view_as``, hence the
                # enclosing isinstance guard.
                if (
                    isinstance(outputs, torch.Tensor)
                    and outputs.ndim == 2
                    and outputs.shape[1] == 1
                ):
                    y = y.view_as(outputs)

            loss = criterion(outputs, y)

            # Incremental mean: mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k
            batches_seen = batch_index + 1
            self.__metrics_dict["loss"] = self.__metrics_dict["loss"] + (
                (loss.item() - self.__metrics_dict["loss"]) / batches_seen
            )
            self.__update_metrics(outputs, y, metrics, batches_seen)

            bar.update(1)
            bar.set_postfix(
                {
                    name: f"{round(value, 4)}"
                    for name, value in self.__metrics_dict.items()
                }
            )

    # Hand out a copy so the caller's result is not overwritten by the next
    # validation or training pass, which reuse the internal dictionary.
    return OrderedDict(self.__metrics_dict)
[docs]
def set_predictor(self, predictor: deepml.tasks.Task):
    """Replaces the task used by this learner.

    The model, model directory, checkpoint file name and device that were
    read from the task in ``__init__`` are refreshed from the new task, so
    training, checkpointing and prediction all refer to the same task. The
    new model directory is created if it does not exist. The optimizer is
    left untouched.

    Args:
        predictor: Task instance to use from now on.

    Raises:
        AssertionError: If predictor is not a ``Task`` instance.
    """
    # Explicit raise instead of ``assert`` so the check survives ``python -O``.
    if not isinstance(predictor, Task):
        raise AssertionError(
            f"predictor must be a Task instance, got {type(predictor).__name__}"
        )

    self.__predictor = predictor
    # Keep the values cached from the task in sync; otherwise ``fit`` and
    # ``save`` would keep operating on the previous task's model.
    self.__model = predictor.model
    self.__model_dir = predictor.model_dir
    self.__model_file_name = predictor.model_file_name
    self.__device = predictor.device
    os.makedirs(self.__model_dir, exist_ok=True)
def __init_metrics(self, metrics: Dict[str, torch.nn.Module]):
    """Resets the running value of every given metric to zero.

    Args:
        metrics: Mapping of metric names to metric modules, or None to do
            nothing.

    Raises:
        ValueError: If a metric is named ``"loss"``, which is reserved for
            the criterion. Metrics visited before it are already reset.
    """
    if metrics is None:
        return

    for name in metrics:
        if name == "loss":
            raise ValueError("Metric name 'loss' is reserved of criterion")
        self.__metrics_dict[name] = 0
def __update_metrics(
    self,
    outputs: torch.Tensor,
    targets: torch.Tensor,
    metrics: Dict[str, torch.nn.Module],
    step: int,
):
    """Folds the current batch into the running average of each metric.

    Args:
        outputs: Model outputs for the batch.
        targets: Targets for the batch.
        metrics: Mapping of metric names to callables invoked as
            ``metric(outputs, targets)``; the result must provide ``.item()``.
            None disables the update.
        step: 1-based count of batches seen so far, used as the divisor of
            the incremental mean.
    """
    if metrics is None:
        return

    with torch.no_grad():
        for name, metric in metrics.items():
            running_mean = self.__metrics_dict[name]
            batch_value = metric(outputs, targets).item()
            # mean_k = mean_{k-1} + (x_k - mean_{k-1}) / k
            self.__metrics_dict[name] = running_mean + (
                (batch_value - running_mean) / step
            )
def __write_metrics_to_logger(self, tag: str, global_step: int):
    """Sends every tracked metric to the logger under ``<metric>/<tag>``.

    Args:
        tag: Suffix appended to each metric name (``fit`` passes ``"train"``
            or ``"val"``).
        global_step: Step value passed along to the logger.
    """
    for metric_name, metric_value in self.__metrics_dict.items():
        logged_name = "{}/{}".format(metric_name, tag)
        self.logger.log_metric(logged_name, metric_value, global_step)
def __write_history(self, stage: str):
    """Appends every tracked metric to ``history`` under ``<stage>_<metric>``.

    Args:
        stage: Prefix for the history keys (``fit`` passes ``"train"`` or
            ``"val"``).
    """
    for metric_name, metric_value in self.__metrics_dict.items():
        history_key = "{}_{}".format(stage, metric_name)
        self.history[history_key].append(metric_value)
def __write_lr(self, global_step: int):
    """Records the current learning rate(s) in the logger and in ``history``.

    With exactly one parameter group the key is ``"learning_rate"``;
    otherwise each group is recorded under
    ``"learning_rate/param_group_<index>"``.

    Args:
        global_step: Step value passed along to the logger.
    """
    groups = self.__optimizer.param_groups
    if len(groups) == 1:
        labelled_groups = [("learning_rate", groups[0])]
    else:
        labelled_groups = [
            (f"learning_rate/param_group_{index}", group)
            for index, group in enumerate(groups)
        ]

    for key, group in labelled_groups:
        current_lr = group["lr"]
        self.logger.log_metric(key, current_lr, global_step)
        self.history[key].append(current_lr)
[docs]
def fit(
    self,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader = None,
    epochs: int = 10,
    steps_per_epoch: int = None,
    save_model_after_every_epoch: int = 5,
    metrics: Dict[str, torch.nn.Module] = None,
    gradient_accumulation_steps: int = 1,
    gradient_clip_value: float = 0,
    gradient_clip_algorithm: str = "norm",
    logger: MLExperimentLogger = None,
    non_blocking: bool = True,
    image_inverse_transform: Callable = None,
    logger_img_size: Union[int, Tuple[int, int]] = None,
):
    """Trains the model for the specified number of epochs.

    Args:
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data. Defaults to None.
        epochs: Number of additional epochs to train, counted on top of
            ``epochs_completed``. Defaults to 10.
        steps_per_epoch: Number of steps per epoch. Should be around
            len(train_loader) to ensure full dataset coverage. If None,
            defaults to len(train_loader). Defaults to None.
        save_model_after_every_epoch: Frequency (in epochs) to save model
            checkpoints. ``0`` or ``None`` disables periodic checkpoints.
            Defaults to 5.
        metrics: Dictionary mapping metric names to metric instances. Each
            metric must be a torch.nn.Module with a forward() method.
            Defaults to None.
        gradient_accumulation_steps: Number of steps to accumulate gradients
            before performing an optimizer step. Simulates larger batch
            sizes. Must be > 0. Defaults to 1.
        gradient_clip_value: Maximum value for gradient clipping. If 0 or
            None, no clipping is applied. Defaults to 0.
        gradient_clip_algorithm: Gradient clipping algorithm. Options:

            - ``"norm"``: Clip by gradient norm (recommended)
            - ``"value"``: Clip by gradient value

            Defaults to ``"norm"``.
        logger: Experiment logger for tracking metrics and artifacts. Only
            used when the learner has no logger yet; if it is None in that
            case, a TensorboardLogger writing to the model directory is
            created. Defaults to None.
        non_blocking: Forwarded to the task's ``train_step`` and to
            ``validate``; intended to enable asynchronous CUDA tensor
            transfers. Defaults to True.
        image_inverse_transform: Transformation to reverse image
            normalization for visualization in the logger. Defaults to None.
        logger_img_size: Image size (int or tuple) for logging validation
            predictions. Nothing is logged when None. Defaults to None.

    Raises:
        AssertionError: If steps_per_epoch > len(train_loader).
        AssertionError: If gradient_accumulation_steps <= 0.
        AssertionError: If gradient_clip_algorithm not in ["norm", "value"].
        TypeError: If any metric is not a torch.nn.Module with a forward()
            method.

    Note:
        - Supports automatic mixed precision (AMP) if enabled in __init__
        - Saves ``best_val_model`` whenever the validation loss improves,
          ``epoch_<n>_model`` periodically and ``latest_model`` at the end
        - The reported training loss is the mean per-batch loss, independent
          of ``gradient_accumulation_steps``
        - Learning rate scheduler can step per epoch or per gradient update;
          ``ReduceLROnPlateau`` monitors the validation loss, or the training
          loss when no validation loader is given
        - For multi-GPU/distributed training, use FabricTrainer or
          AcceleratorTrainer
    """
    if steps_per_epoch is None:
        steps_per_epoch = len(train_loader)

    # Explicit raises (not ``assert``) keep the documented checks active
    # under ``python -O``.
    if steps_per_epoch > len(train_loader):
        raise AssertionError(
            "Steps per epoch should not be greater than len(train_loader)"
        )
    if not gradient_accumulation_steps > 0:
        raise AssertionError("Accumulation steps should be greater than 0")
    if gradient_clip_algorithm not in ("norm", "value"):
        raise AssertionError(
            "gradient_clip_algorithm should be either 'norm' or 'value'"
        )

    self.__model.to(self.__device)
    self.__criterion = self.__criterion.to(self.__device)

    # initialize the logger if not provided
    if self.logger is None:
        self.logger = (
            logger if logger is not None else TensorboardLogger(self.__model_dir)
        )

    # Log params
    self.logger.log_params(
        task=self.__predictor,
        loader=val_loader,
        epochs=epochs,
        criterion=self.__criterion,
        lr_scheduler=self.__lr_scheduler,
    )

    # Check valid metrics types
    if metrics:
        for metric_name, metric_instance in metrics.items():
            if not (
                isinstance(metric_instance, torch.nn.Module)
                and hasattr(metric_instance, "forward")
            ):
                raise TypeError(f"{metric_instance.__class__} is not supported")

    # Replace all metrics during call to learner fit
    self.__metrics_dict = OrderedDict({"loss": 0})

    train_loss = 0
    epochs = self.epochs_completed + epochs

    # torch.autocast expects a device *type* ("cuda", "cpu"), so an indexed
    # device such as "cuda:0" is reduced to its type.
    amp_device_type = (
        torch.device(self.__device).type if self.__use_amp else None
    )
    clip_gradients = gradient_clip_value is not None and gradient_clip_value > 0

    def forward_and_loss(inputs, targets):
        """Runs the task's train step and returns (outputs, targets, loss)."""
        outputs, _, targets = self.__predictor.train_step(
            inputs, targets, non_blocking
        )
        # Single-column outputs of shape (N, 1): reshape the targets to match.
        if (
            isinstance(outputs, torch.Tensor)
            and outputs.ndim == 2
            and outputs.shape[1] == 1
        ):
            targets = targets.view_as(outputs)
        return outputs, targets, self.__criterion(outputs, targets)

    # Nullify the parameter gradients
    self.__optimizer.zero_grad(set_to_none=True)

    for epoch in range(self.epochs_completed, epochs):
        tqdm.write("Epoch {}/{}:".format(epoch + 1, epochs))

        # Training mode
        self.__model.train()

        # Number of batches processed in this epoch
        step = 0

        # init all metrics with zeros
        self.__metrics_dict["loss"] = 0
        self.__init_metrics(metrics)

        # Write current lr to logger
        self.__write_lr(epoch + 1)

        with tqdm(
            total=steps_per_epoch,
            desc="{:12s}".format("Training"),
            dynamic_ncols=True,
        ) as bar:
            for batch_index, (x, y) in enumerate(train_loader):
                # ``batch_loss`` stays unnormalized for reporting; only the
                # value that is back-propagated is divided by the number of
                # accumulation steps.
                if self.__use_amp:
                    # Enable autocast for mixed precision training
                    with torch.autocast(
                        device_type=amp_device_type,
                        dtype=torch.float16,
                        enabled=True,
                    ):
                        outputs, y, batch_loss = forward_and_loss(x, y)
                    # Accumulates scaled gradients
                    self.__scaler.scale(
                        batch_loss / gradient_accumulation_steps
                    ).backward()
                else:
                    outputs, y, batch_loss = forward_and_loss(x, y)
                    (batch_loss / gradient_accumulation_steps).backward()

                # The epoch ends after ``steps_per_epoch`` batches, which may
                # be fewer than len(train_loader). Pending gradients must be
                # applied on that batch so they do not leak into the next
                # epoch.
                is_last_batch = (batch_index + 1) >= steps_per_epoch

                if (
                    batch_index + 1
                ) % gradient_accumulation_steps == 0 or is_last_batch:
                    # Apply Gradient clipping
                    if clip_gradients:
                        if self.__use_amp:
                            # Clip true gradient values, not scaled ones.
                            self.__scaler.unscale_(self.__optimizer)
                        if gradient_clip_algorithm == "norm":
                            torch.nn.utils.clip_grad_norm_(
                                self.__model.parameters(), gradient_clip_value
                            )
                        elif gradient_clip_algorithm == "value":
                            torch.nn.utils.clip_grad_value_(
                                self.__model.parameters(), gradient_clip_value
                            )

                    if self.__use_amp:
                        # Update model parameters
                        self.__scaler.step(self.__optimizer)
                        # Updates the scale for next iteration.
                        self.__scaler.update()
                    else:
                        # Update model parameters
                        self.__optimizer.step()

                    if (
                        self.__lr_scheduler
                        and self.__lr_scheduler_step_policy == "step"
                    ):
                        self.__lr_scheduler.step()

                    # Nullify the parameter gradients
                    self.__optimizer.zero_grad(set_to_none=True)

                step = step + 1
                # Incremental mean of the per-batch loss
                self.__metrics_dict["loss"] = self.__metrics_dict["loss"] + (
                    (batch_loss.item() - self.__metrics_dict["loss"]) / step
                )

                # Update metrics
                self.__update_metrics(outputs, y, metrics, step)

                bar.set_postfix(
                    {
                        name: f"{round(value, 4)}"
                        for name, value in self.__metrics_dict.items()
                    }
                )
                bar.update(1)

                if is_last_batch:
                    break

        self.epochs_completed = self.epochs_completed + 1

        train_loss = self.__metrics_dict["loss"]
        self.__write_metrics_to_logger("train", self.epochs_completed)
        self.__write_history("train")

        message = f"Training Loss: {train_loss:.4f} "

        val_loss = np.inf
        if val_loader is not None:
            self.validate(val_loader, self.__criterion, metrics, non_blocking)
            val_loss = self.__metrics_dict["loss"]

            self.__write_metrics_to_logger("val", self.epochs_completed)
            self.__write_history("val")

            message = message + f"Validation Loss: {val_loss:.4f} "

            # write random val images to tensorboard
            if logger_img_size is not None:
                self.__predictor.write_prediction_to_logger(
                    "val",
                    val_loader,
                    self.logger,
                    image_inverse_transform,
                    self.epochs_completed,
                    img_size=logger_img_size,
                )

            # Save best validation model
            if val_loss < self.best_val_loss:
                message = message + "[Saving best validation model]"
                self.best_val_loss = val_loss
                self.save(
                    "best_val_model",
                    save_optimizer_state=True,
                    epoch=self.epochs_completed,
                    train_loss=train_loss,
                    val_loss=val_loss,
                )

        if self.__lr_scheduler and self.__lr_scheduler_step_policy == "epoch":
            if isinstance(
                self.__lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
            ):
                # ReduceLROnPlateau.step() requires a monitored value; fall
                # back to the training loss when there is no validation data.
                monitored_loss = val_loss if val_loader is not None else train_loss
                self.__lr_scheduler.step(monitored_loss)
            else:
                self.__lr_scheduler.step()

        tqdm.write(message)

        # A falsy frequency (0 / None) disables periodic checkpoints instead
        # of failing on a modulo by zero.
        if (
            save_model_after_every_epoch
            and self.epochs_completed % save_model_after_every_epoch == 0
        ):
            model_tag_name = "epoch_{}_model".format(self.epochs_completed)
            self.save(
                model_tag_name,
                save_optimizer_state=True,
                epoch=self.epochs_completed,
                train_loss=train_loss,
                val_loss=val_loss,
            )

    # Save latest model at the end
    self.save(
        "latest_model",
        save_optimizer_state=True,
        epoch=self.epochs_completed,
        train_loss=train_loss,
        val_loss=self.best_val_loss,
    )
[docs]
def predict(self, loader):
    """Runs the task's ``predict`` over the loader.

    Args:
        loader: DataLoader containing data for prediction.

    Returns:
        Tuple ``(predictions, targets)`` as produced by the task's
        ``predict`` method.
    """
    task_result = self.__predictor.predict(loader)
    # Unpack so that exactly two values are enforced and a tuple is returned.
    model_outputs, ground_truth = task_result
    return model_outputs, ground_truth
[docs]
def predict_class(self, loader):
    """Runs the task's ``predict_class`` over the loader.

    Args:
        loader: DataLoader containing data for prediction.

    Returns:
        Tuple ``(predicted_class, probability, targets)`` as produced by the
        task's ``predict_class`` method.
    """
    task_result = self.__predictor.predict_class(loader)
    # Unpack so that exactly three values are enforced and a tuple is returned.
    labels, scores, ground_truth = task_result
    return labels, scores, ground_truth
[docs]
def show_predictions(
    self,
    loader,
    image_inverse_transform=None,
    samples=9,
    cols=3,
    figsize=(10, 10),
    target_known=True,
):
    """Delegates visualization of predictions to the task.

    Args:
        loader: DataLoader containing data for visualization.
        image_inverse_transform: Transformation to reverse image normalization
            for display. Defaults to None.
        samples: Number of samples to display. Defaults to 9.
        cols: Number of columns in the visualization grid. Defaults to 3.
        figsize: Figure size as (width, height) tuple. Defaults to (10, 10).
        target_known: Whether ground truth targets are available for
            comparison. Defaults to True.
    """
    display_options = {
        "image_inverse_transform": image_inverse_transform,
        "samples": samples,
        "cols": cols,
        "figsize": figsize,
        "target_known": target_known,
    }
    self.__predictor.show_predictions(loader, **display_options)