import os
from collections import OrderedDict, defaultdict
from typing import Callable, Dict, Optional, Tuple, Union
import torch
from lightning_fabric import Fabric
from tqdm import tqdm
from deepml.base import BaseLearner
from deepml.tasks import Task
from deepml.tracking import MLExperimentLogger, TensorboardLogger
[docs]
class FabricTrainer(BaseLearner):
"""Training class for learning model weights using Lightning Fabric.
This trainer leverages Lightning Fabric for distributed training, mixed precision,
and hardware acceleration while maintaining a simple PyTorch-like interface.
It supports features like gradient accumulation, gradient clipping, learning rate
scheduling, checkpointing, and logging with experiment tracking integration. The trainer
is designed to be flexible and extensible for various types of learning tasks defined by the Task abstraction.
"""
[docs]
def __init__(
    self,
    task: Task,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    lr_scheduler_fn: Optional[
        Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler._LRScheduler]
    ] = None,
    lr_scheduler_step_policy: str = "epoch",
    accelerator: Union[str, int] = "auto",
    strategy: Union[str, int] = "auto",
    devices: Union[str, int] = "auto",
    precision: str = "32-true",
    num_nodes: int = 1,
    fabric_plugins: Optional = None,
):
    """Creates the trainer together with the Fabric instance it trains with.

    The learner-related arguments are forwarded to the parent constructor
    (always with ``lr_scheduler=None``; the scheduler is built later from
    ``lr_scheduler_fn``). The hardware-related arguments are forwarded
    unchanged to ``lightning_fabric.Fabric``.

    Args:
        task: Task object defining the learning task (e.g., classification,
            segmentation). Its ``_device`` attribute is overwritten with the
            Fabric device.
        optimizer: PyTorch optimizer instance used for parameter updates.
        criterion: Loss function module.
        lr_scheduler_fn: Factory that receives an optimizer and returns a
            learning rate scheduler, e.g.
            ``lambda optimizer: StepLR(optimizer, step_size=5, gamma=0.5)``.
            Defaults to None.
        lr_scheduler_step_policy: When ``scheduler.step()`` is called:
            ``"epoch"`` (after every epoch) or ``"step"`` (after every
            optimizer update). Defaults to ``"epoch"``.
        accelerator: Passed to ``Fabric(accelerator=...)``, e.g. ``"cpu"``,
            ``"cuda"``, ``"mps"``, ``"gpu"``, ``"tpu"`` or ``"auto"``.
            Defaults to ``"auto"``.
        strategy: Passed to ``Fabric(strategy=...)``, e.g. ``"dp"``,
            ``"ddp"``, ``"fsdp"``, ``"deepspeed"``, ``"ddp_spawn"`` or
            ``"auto"``. Defaults to ``"auto"``.
        devices: Passed to ``Fabric(devices=...)``; number or list of
            devices, or ``"auto"``. Defaults to ``"auto"``.
        precision: Passed to ``Fabric(precision=...)``, e.g. ``"16-mixed"``,
            ``"32-true"``, ``"64-true"``, ``"bf16-mixed"``, ``"bf16-true"``.
            Defaults to ``"32-true"``.
        num_nodes: Passed to ``Fabric(num_nodes=...)``. Defaults to 1.
        fabric_plugins: Passed to ``Fabric(plugins=...)`` (e.g. a precision
            plugin). Defaults to None.

    Example:
        >>> from lightning_fabric.plugins import BitsandbytesPrecision
        >>> plugin = BitsandbytesPrecision(mode="int8")
        >>> trainer = FabricTrainer(
        ...     task=task,
        ...     optimizer=optimizer,
        ...     criterion=criterion,
        ...     fabric_plugins=plugin
        ... )
    """
    learner_config = {
        "task": task,
        "optimizer": optimizer,
        "criterion": criterion,
        "lr_scheduler": None,
        "lr_scheduler_fn": lr_scheduler_fn,
        "lr_scheduler_step_policy": lr_scheduler_step_policy,
    }
    super().__init__(**learner_config)

    fabric_config = {
        "accelerator": accelerator,
        "strategy": strategy,
        "devices": devices,
        "precision": precision,
        "num_nodes": num_nodes,
        "plugins": fabric_plugins,
    }
    self.fabric = Fabric(**fabric_config)

    # The task must place tensors on the device Fabric selected.
    self._task._device = self.fabric.device

    # Bookkeeping that persists across successive fit() calls.
    self.epochs_completed = 0
    self.best_val_loss = float("inf")
    self.history = defaultdict(list)
    self.logger = None

    # self._model_dir is presumably set by the parent constructor; make sure
    # the checkpoint directory exists before any training starts.
    os.makedirs(self._model_dir, exist_ok=True)
[docs]
def fit(
    self,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader = None,
    epochs: int = 10,
    save_model_after_every_epoch: int = 5,
    metrics: Dict[str, torch.nn.Module] = None,
    gradient_accumulation_steps: int = 1,
    gradient_clip_value: Optional[float] = None,
    gradient_clip_max_norm: Optional[float] = None,
    resume_from_checkpoint: str = None,
    load_optimizer_state: bool = False,
    load_scheduler_state: bool = False,
    logger: MLExperimentLogger = None,
    non_blocking: bool = True,
    image_inverse_transform: Callable = None,
    logger_img_size: Union[int, Tuple[int, int]] = None,
):
    """Trains the model for the specified number of epochs.

    Launches ``_fit_impl`` through ``self.fabric.launch`` with all arguments
    forwarded unchanged, then (on the global zero process) reloads the
    ``latest_model`` checkpoint into the trainer and merges the returned
    history into ``self.history``.

    Args:
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data. Defaults to None.
        epochs: Number of epochs to train in this call. Defaults to 10.
        save_model_after_every_epoch: Frequency (in epochs) to save model
            checkpoints. Defaults to 5.
        metrics: Dictionary mapping metric names to metric instances. Each
            metric must be a torch.nn.Module. Defaults to None.
        gradient_accumulation_steps: Number of batches to accumulate
            gradients over before an optimizer step. Defaults to 1.
        gradient_clip_value: Maximum absolute value for gradient clipping.
            Defaults to None (no clipping).
        gradient_clip_max_norm: Maximum norm for gradient clipping.
            Defaults to None (no clipping).
        resume_from_checkpoint: Path to a checkpoint file to resume from.
            Defaults to None.
        load_optimizer_state: Whether to load the optimizer state from the
            checkpoint. Defaults to False.
        load_scheduler_state: Whether to load the learning rate scheduler
            state from the checkpoint. Defaults to False.
        logger: Experiment logger for tracking metrics and artifacts. If
            None, a TensorboardLogger is used. Defaults to None.
        non_blocking: Whether to use asynchronous tensor transfers.
            Defaults to True.
        image_inverse_transform: Transformation to reverse image
            normalization for logged images. Defaults to None.
        logger_img_size: Image size (int or tuple) for logged images.
            Defaults to None.

    Note:
        After training completes, the latest model checkpoint is loaded into
        the trainer's model and optimizer on the global zero process.
        ``self.best_val_loss`` never increases: the latest checkpoint stores
        the *last* epoch's validation loss, which may be worse than the best
        value already tracked.
    """
    history = self.fabric.launch(
        self._fit_impl,
        train_loader,
        val_loader=val_loader,
        epochs=epochs,
        save_model_after_every_epoch=save_model_after_every_epoch,
        metrics=metrics,
        gradient_accumulation_steps=gradient_accumulation_steps,
        gradient_clip_value=gradient_clip_value,
        gradient_clip_max_norm=gradient_clip_max_norm,
        resume_from_checkpoint=resume_from_checkpoint,
        load_optimizer_state=load_optimizer_state,
        load_scheduler_state=load_scheduler_state,
        logger=logger,
        non_blocking=non_blocking,
        image_inverse_transform=image_inverse_transform,
        logger_img_size=logger_img_size,
    )

    # after training is complete, load model weights back
    if self.fabric.is_global_zero:
        latest_checkpoint_filepath = (
            f"{os.path.join(self._model_dir, 'latest_model')}.pt"
        )
        state_dict = torch.load(
            latest_checkpoint_filepath, map_location=self.fabric.device
        )
        self._model.load_state_dict(state_dict["model_state_dict"])
        self._optimizer.load_state_dict(state_dict["optimizer_state_dict"])
        self.epochs_completed = state_dict.get("epoch", 0)

        # The latest checkpoint carries the final epoch's validation loss,
        # not the best one. Blindly assigning it could raise the tracked
        # best and let a later fit() call overwrite "best_val_model" with a
        # worse model, so only ever lower the tracked value.
        latest_val_loss = state_dict.get("val_loss", float("inf"))
        self.best_val_loss = min(self.best_val_loss, latest_val_loss)

        # update history list (history is only populated on global zero)
        for key, value in history.items():
            self.history[key].extend(value)
def _fit_impl(
    self,
    fabric: Fabric,
    train_loader: torch.utils.data.DataLoader,
    val_loader: torch.utils.data.DataLoader = None,
    epochs: int = 10,
    save_model_after_every_epoch: int = 5,
    metrics: Dict[str, torch.nn.Module] = None,
    gradient_accumulation_steps: int = 1,
    gradient_clip_value: Optional[float] = None,
    gradient_clip_max_norm: Optional[float] = None,
    resume_from_checkpoint: str = None,
    load_optimizer_state: bool = False,
    load_scheduler_state: bool = False,
    logger: MLExperimentLogger = None,
    non_blocking: bool = True,
    image_inverse_transform: Callable = None,
    logger_img_size: Union[int, Tuple[int, int]] = None,
) -> Dict[str, list]:
    """Internal implementation of the training loop, launched by Fabric.

    Validates the arguments, optionally restores a checkpoint, wraps the
    model, optimizer and dataloaders with Fabric, then alternates training
    and validation epochs while logging and checkpointing on the global zero
    process.

    Args:
        fabric: Lightning Fabric instance (supplied by ``Fabric.launch``).
        train_loader: DataLoader for training data.
        val_loader: DataLoader for validation data. When None, validation is
            skipped and the validation loss stays at ``inf``. Defaults to None.
        epochs: Number of epochs to train in this call (added on top of
            ``self.epochs_completed``). Defaults to 10.
        save_model_after_every_epoch: Frequency (in epochs) to save model
            checkpoints. A value of 0 disables periodic checkpoints.
            Defaults to 5.
        metrics: Dictionary mapping metric names to metric instances. Each
            metric must be a torch.nn.Module with a forward() method.
            Defaults to None.
        gradient_accumulation_steps: Number of batches to accumulate
            gradients over before an optimizer step. Must be greater than 0.
            Defaults to 1.
        gradient_clip_value: Maximum absolute value for gradient clipping.
            Defaults to None (no clipping).
        gradient_clip_max_norm: Maximum norm for gradient clipping.
            Defaults to None (no clipping).
        resume_from_checkpoint: Path to a checkpoint file to resume from.
            Ignored when the path does not exist. Defaults to None.
        load_optimizer_state: Whether to load the optimizer state from the
            checkpoint. Defaults to False.
        load_scheduler_state: Whether to load the learning rate scheduler
            state from the checkpoint. Defaults to False.
        logger: Experiment logger. If None, a TensorboardLogger writing to
            the model directory is created. Defaults to None.
        non_blocking: Whether to use asynchronous tensor transfers.
            Defaults to True.
        image_inverse_transform: Transformation to reverse image
            normalization for logged images. Defaults to None.
        logger_img_size: Image size (int or tuple) for logged images.
            Defaults to None.

    Returns:
        Dictionary containing training history with metric names as keys and
        lists of values as entries. Only populated on the global zero process.

    Raises:
        AssertionError: If gradient_accumulation_steps is not greater than 0.
        ValueError: If both gradient_clip_value and gradient_clip_max_norm
            are provided (only one can be used).
        TypeError: If any metric is not a torch.nn.Module with a forward()
            method.

    Note:
        Only the global zero process saves checkpoints and manages the
        logger. All processes synchronize at the end of each epoch using
        ``fabric.barrier()``.
    """
    assert (
        gradient_accumulation_steps > 0
    ), "Accumulation steps should be greater than 0"

    if gradient_clip_value is not None and gradient_clip_max_norm is not None:
        raise ValueError(
            "Only one of gradient_clip_value or gradient_clip_max_norm should be passed."
        )

    # Check valid metrics types
    if metrics:
        for metric_instance in metrics.values():
            if not (
                isinstance(metric_instance, torch.nn.Module)
                and hasattr(metric_instance, "forward")
            ):
                raise TypeError(f"{metric_instance.__class__} is not supported")

    state_dict = {}
    if resume_from_checkpoint is not None and os.path.exists(
        resume_from_checkpoint
    ):
        state_dict = torch.load(resume_from_checkpoint, map_location=fabric.device)
        self._model.load_state_dict(state_dict["model_state_dict"])
        self.epochs_completed = state_dict.get("epoch", 0)
        self.best_val_loss = state_dict.get("val_loss", float("inf"))
        if fabric.is_global_zero:
            print(
                f"Resuming training from epoch {self.epochs_completed} with best validation loss {self.best_val_loss}"
            )

    model, optimizer = fabric.setup(self._model, self._optimizer)

    if load_optimizer_state:
        FabricTrainer.load_optimizer_state(optimizer, state_dict)

    # Fabric.setup_dataloaders only accepts DataLoader instances, so the
    # optional validation loader must not be passed when it is None.
    if val_loader is not None:
        train_loader, val_loader = fabric.setup_dataloaders(
            train_loader, val_loader
        )
    else:
        train_loader = fabric.setup_dataloaders(train_loader)

    # The scheduler is built against the Fabric-wrapped optimizer.
    lr_scheduler = (
        self._lr_scheduler_fn(optimizer)
        if self._lr_scheduler_fn is not None
        else None
    )

    if (
        lr_scheduler is not None
        and load_scheduler_state
        and "scheduler_state_dict" in state_dict
    ):
        FabricTrainer.load_lr_schedular_state(lr_scheduler, state_dict)

    if fabric.is_global_zero:
        self.logger = (
            logger if logger is not None else TensorboardLogger(self._model_dir)
        )
        self.logger.log_params(
            task=self._task,
            loader=val_loader,
            epochs=epochs,
            criterion=self._criterion,
            lr_scheduler=lr_scheduler,
        )

    criterion = self._criterion
    epochs_completed = self.epochs_completed
    best_val_loss = self.best_val_loss
    epochs = epochs_completed + epochs  # absolute index of the final epoch

    history = defaultdict(list)
    # Placeholder used when no validation loader is given.
    val_global_metrics_dict = {"loss": float("inf")}
    train_loss = float("inf")
    val_loss = float("inf")

    for epoch in range(epochs_completed, epochs):
        if fabric.is_global_zero:
            print("Epoch {}/{}:".format(epoch + 1, epochs))
            # self.logger is only set on the global zero process
            FabricTrainer.write_lr(optimizer, epoch + 1, self.logger, history)

        # training
        train_global_metrics_dict = self.__train(
            fabric,
            model,
            optimizer,
            criterion,
            train_loader,
            step_lr_scheduler=(
                lr_scheduler if self._lr_scheduler_step_policy == "step" else None
            ),
            metrics=metrics,
            non_blocking=non_blocking,
            gradient_accumulation_steps=gradient_accumulation_steps,
            gradient_clip_value=gradient_clip_value,
            gradient_clip_max_norm=gradient_clip_max_norm,
        )

        # evaluation
        if val_loader is not None:
            val_global_metrics_dict = self.__validate(
                fabric,
                model,
                val_loader,
                criterion,
                metrics,
                non_blocking=non_blocking,
            )

        # After each epoch completed, write metrics to logger
        if fabric.is_global_zero:
            epochs_completed = epochs_completed + 1
            self.log_metrics(
                val_loader,
                train_global_metrics_dict,
                val_global_metrics_dict,
                history,
                epochs_completed,
                logger_img_size,
                image_inverse_transform,
            )

            train_loss = train_global_metrics_dict["loss"]
            val_loss = val_global_metrics_dict["loss"]
            message = f"\nTrain Loss: {train_loss:.4f} Val Loss: {val_loss:.4f}"

            # Save the best validation model
            if val_loss < best_val_loss:
                message = message + " [Saving best validation model]"
                best_val_loss = val_loss
                self.save(
                    "best_val_model",
                    model,
                    optimizer,
                    criterion,
                    lr_scheduler,
                    epoch=epochs_completed,
                    train_loss=train_loss,
                    val_loss=val_loss,
                )

            # Log info message to console only global zero process
            tqdm.write(message)

            # A frequency of 0 means "no periodic checkpoints"; without this
            # guard the modulo raises ZeroDivisionError on rank zero only and
            # leaves the other ranks waiting at the barrier below.
            if (
                save_model_after_every_epoch
                and epochs_completed % save_model_after_every_epoch == 0
            ):
                last_checkpoint = "epoch_{}_model".format(epochs_completed)
                self.save(
                    last_checkpoint,
                    model,
                    optimizer,
                    criterion,
                    lr_scheduler,
                    epoch=epochs_completed,
                    train_loss=train_loss,
                    val_loss=val_loss,
                )

        # Ensure all processes are synchronized before proceeding next epoch
        fabric.barrier()

        # LR Scheduler step after each epoch
        if lr_scheduler is not None and self._lr_scheduler_step_policy == "epoch":
            if val_loader is not None and isinstance(
                lr_scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau
            ):
                # The aggregated validation loss is only computed on the
                # global zero process. Broadcast it so every rank feeds the
                # same value to the scheduler and learning rates stay in sync.
                monitored_loss = fabric.broadcast(
                    val_global_metrics_dict["loss"], src=0
                )
                lr_scheduler.step(monitored_loss)
            else:
                lr_scheduler.step()

    # Save latest model at the end
    if fabric.is_global_zero:
        self.save(
            "latest_model",
            model,
            optimizer,
            criterion,
            lr_scheduler,
            epoch=epochs_completed,
            train_loss=train_loss,
            val_loss=val_loss,
        )

    return history
def __train(
    self,
    fabric: Fabric,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    train_loader: torch.utils.data.DataLoader,
    step_lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None,
    metrics: Dict[str, torch.nn.Module] = None,
    non_blocking: bool = True,
    gradient_accumulation_steps: int = 1,
    gradient_clip_value: Optional[float] = None,
    gradient_clip_max_norm: Optional[float] = None,
) -> OrderedDict[str, float]:
    """Runs a single training epoch with gradient accumulation.

    Args:
        fabric: Lightning Fabric instance used for device, sync and
            distributed utilities.
        model: Model to train; it is switched to train() mode.
        optimizer: Optimizer used to update model parameters.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        train_loader: Iterable yielding ``(inputs, targets)`` batches; must
            support ``len()``.
        step_lr_scheduler: Learning rate scheduler that is stepped after each
            ``optimizer.step()`` (for the "step" policy). Defaults to None.
        metrics: Mapping of metric name to metric module. Defaults to None.
        non_blocking: Forwarded to ``self._task.train_step``.
            Defaults to True.
        gradient_accumulation_steps: Number of batches to accumulate
            gradients over before calling ``optimizer.step()``. Defaults to 1.
        gradient_clip_value: If set, gradients are clipped with
            ``fabric.clip_gradients(..., clip_val=...)``. Defaults to None.
        gradient_clip_max_norm: If set (and ``gradient_clip_value`` is not),
            gradients are clipped with
            ``fabric.clip_gradients(..., max_norm=...)``. Defaults to None.

    Returns:
        Mapping of metric names to values aggregated (simple moving average)
        over all batches and processes. Only populated on the global zero
        process; other processes return an empty dict.

    Note:
        - The optimizer steps after every ``gradient_accumulation_steps``
          batches and once more on the final batch of the epoch if a partial
          group remains.
        - ``fabric.no_backward_sync`` skips gradient synchronization only for
          batches that are *not* followed by an optimizer step, so every
          step uses synchronized gradients.
        - Per-batch metrics are collected with ``fabric.all_gather``.
    """
    # Training mode
    model.train()

    # init all metrics with zeros
    local_batch_metrics_dict = FabricTrainer.init_metrics(metrics)
    training_progress_bar = None
    step = None
    global_metrics_dict = {}

    num_batches = len(train_loader)

    if fabric.is_global_zero:
        # Global metrics dict for tracking metrics from all processes,
        # separate history is used to track metrics across multiple calls to fit method
        global_metrics_dict = FabricTrainer.init_metrics(metrics)
        training_progress_bar = tqdm(
            total=num_batches,
            desc="{:12s}".format("Training"),
            dynamic_ncols=True,
        )
        # count number of steps
        step = 0

    # Nullify the parameter gradients
    optimizer.zero_grad(set_to_none=True)

    for batch_index, (x, y) in enumerate(train_loader):
        is_last_batch = (batch_index + 1) == num_batches
        # Step once a full group of micro-batches has been accumulated, or on
        # the final batch so a trailing partial group is not discarded.
        should_step = (
            (batch_index + 1) % gradient_accumulation_steps == 0
        ) or is_last_batch

        # Gradient synchronization may only be skipped when no optimizer step
        # follows this backward pass; otherwise each rank would update its
        # weights from unsynchronized gradients.
        with fabric.no_backward_sync(model, enabled=not should_step):
            outputs, x, y = self._task.train_step(
                x, y, model=model, device=fabric.device, non_blocking=non_blocking
            )

            # Align the target shape with (N, 1) model outputs.
            if (
                isinstance(outputs, torch.Tensor)
                and outputs.ndim == 2
                and outputs.shape[1] == 1
            ):
                y = y.view_as(outputs)

            loss = criterion(outputs, y)
            fabric.backward(loss / gradient_accumulation_steps)  # normalize loss

        # we gather log loss and metrics at each batch, so no need to sum up running loss during accumulation
        local_batch_metrics_dict["loss"] = loss.detach()
        FabricTrainer.update_metrics(outputs, y, metrics, local_batch_metrics_dict)

        # collect metric values from all processes using tensor type, avoid dict type
        values = torch.tensor(
            [
                v.detach() if isinstance(v, torch.Tensor) else v
                for v in local_batch_metrics_dict.values()
            ],
            device=fabric.device,
            dtype=torch.float32,
        )

        # all_gather is used to aggregate the value across processes
        all_batch_metrics = fabric.all_gather(
            values
        )  # returns tensor of shape (world_size, num_metrics)

        # update progress bar for each batch
        # Aggregate metrics across all processes
        if fabric.is_global_zero:
            training_progress_bar.update(1)
            step = step + 1

            # With a single process all_gather returns a 1-D tensor; the view
            # normalizes both cases to (world_size, num_metrics).
            all_batch_metrics = all_batch_metrics.view(
                fabric.world_size, len(local_batch_metrics_dict)
            )

            # Convert all_batch_metrics to dict with metric names
            # all_batch_metrics[:, 0] -> loss, all_batch_metrics[:, 1] -> acc, etc.
            all_batch_metrics = {
                name: all_batch_metrics[:, i]
                for i, name in enumerate(local_batch_metrics_dict.keys())
            }

            FabricTrainer.update_metrics_with_simple_moving_average(
                all_batch_metrics, global_metrics_dict, step
            )

            training_progress_bar.set_postfix(
                {
                    name: f"{round(value, 4)}"
                    for name, value in global_metrics_dict.items()
                }
            )

        if should_step:
            # Gradient clipping
            if gradient_clip_value is not None:
                fabric.clip_gradients(
                    model, optimizer, clip_val=gradient_clip_value
                )
            elif gradient_clip_max_norm is not None:
                fabric.clip_gradients(
                    model, optimizer, max_norm=gradient_clip_max_norm
                )

            optimizer.step()

            if step_lr_scheduler is not None:
                step_lr_scheduler.step()

            # Nullify the parameter gradients
            optimizer.zero_grad(set_to_none=True)

    if fabric.is_global_zero:
        training_progress_bar.close()

    return global_metrics_dict
@torch.no_grad()
def __validate(
    self,
    fabric: Fabric,
    model: torch.nn.Module,
    loader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    metrics: Dict[str, torch.nn.Module] = None,
    non_blocking: bool = True,
):
    """Evaluates the model on one full pass over ``loader``.

    Args:
        fabric: Lightning Fabric instance used for device, sync and
            distributed utilities.
        model: Model to evaluate; it is switched to eval() mode.
        loader: DataLoader yielding ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        metrics: Mapping of metric name to metric module. Defaults to None.
        non_blocking: Forwarded to ``self._task.eval_step``.
            Defaults to True.

    Returns:
        Mapping of metric names to values aggregated (simple moving average)
        over all batches and processes. Only updated on the global zero
        process; other processes return the freshly initialized mapping.

    Note:
        - Runs under ``@torch.no_grad()``.
        - Per-batch metrics are collected with ``fabric.all_gather``.
        - The progress bar exists only on the global zero process.
    """
    model.eval()

    batch_metrics = FabricTrainer.init_metrics(metrics)
    progress = None
    aggregated_metrics = FabricTrainer.init_metrics(metrics)
    batches_seen = 0
    on_rank_zero = fabric.is_global_zero

    if on_rank_zero:
        progress = tqdm(
            total=len(loader),
            desc="{:12s}".format("Validation"),
            dynamic_ncols=True,
            leave=True,
        )

    for x, y in loader:
        outputs, x, y = self._task.eval_step(
            x, y, model=model, device=fabric.device, non_blocking=non_blocking
        )

        if isinstance(y, torch.Tensor):
            y = y.to(fabric.device)

        # Align the target shape with (N, 1) model outputs.
        is_single_column = (
            isinstance(outputs, torch.Tensor)
            and outputs.ndim == 2
            and outputs.shape[1] == 1
        )
        if is_single_column:
            y = y.view_as(outputs)

        batch_loss = criterion(outputs, y)
        batch_metrics["loss"] = batch_loss.detach()
        FabricTrainer.update_metrics(outputs, y, metrics, batch_metrics)

        # Pack the metric values into one tensor so they can be gathered
        # across processes without sending a dict.
        packed = torch.tensor(
            [
                item.detach() if isinstance(item, torch.Tensor) else item
                for item in batch_metrics.values()
            ],
            device=fabric.device,
            dtype=torch.float32,
        )

        # Collective call: every process must take part.
        gathered = fabric.all_gather(packed)

        if not on_rank_zero:
            continue

        progress.update(1)
        batches_seen += 1

        # Normalize to (world_size, num_metrics); column 0 is the loss,
        # the following columns are the metrics in dict order.
        gathered = gathered.view(fabric.world_size, len(batch_metrics))
        per_metric = {
            name: gathered[:, column] for column, name in enumerate(batch_metrics)
        }

        FabricTrainer.update_metrics_with_simple_moving_average(
            per_metric, aggregated_metrics, batches_seen
        )

        progress.set_postfix(
            {
                name: f"{round(value, 4)}"
                for name, value in aggregated_metrics.items()
            }
        )

    if on_rank_zero:
        progress.close()

    return aggregated_metrics
[docs]
def predict(self, loader):
    """Runs the task's ``predict`` on ``loader`` and returns its two results.

    Args:
        loader: DataLoader containing the data to predict on.

    Returns:
        Tuple ``(predictions, targets)`` exactly as produced by
        ``self._task.predict(loader)``.
    """
    outputs, ground_truth = self._task.predict(loader)
    return outputs, ground_truth
[docs]
def predict_class(self, loader):
    """Runs the task's ``predict_class`` on ``loader`` and returns its results.

    Args:
        loader: DataLoader containing the data to predict on.

    Returns:
        Tuple ``(predicted_class, probability, targets)`` exactly as produced
        by ``self._task.predict_class(loader)``:
            - predicted_class: Predicted class labels
            - probability: Class probabilities or confidence scores
            - targets: Ground truth labels
    """
    labels, scores, ground_truth = self._task.predict_class(loader)
    return labels, scores, ground_truth
[docs]
def show_predictions(
    self,
    loader,
    image_inverse_transform=None,
    samples=9,
    cols=3,
    figsize=(10, 10),
    target_known=True,
):
    """Delegates prediction visualization to the task.

    Every argument is forwarded unchanged to
    ``self._task.show_predictions``; nothing is returned.

    Args:
        loader: DataLoader containing the data to visualize.
        image_inverse_transform: Transformation to reverse image
            normalization for display. Defaults to None.
        samples: Number of samples to display. Defaults to 9.
        cols: Number of columns in the visualization grid. Defaults to 3.
        figsize: Figure size as a (width, height) tuple.
            Defaults to (10, 10).
        target_known: Whether ground truth targets are available for
            comparison. Defaults to True.
    """
    display_options = {
        "image_inverse_transform": image_inverse_transform,
        "samples": samples,
        "cols": cols,
        "figsize": figsize,
        "target_known": target_known,
    }
    self._task.show_predictions(loader, **display_options)