import os
from abc import ABC, abstractmethod
from typing import Any, Callable, Dict, List, Sequence, Tuple, Union
import numpy as np
import torch
import torch.nn.functional as F
import torchvision
from PIL import Image
from tqdm import tqdm
from deepml.tracking import MLExperimentLogger
from deepml.utils import blend, create_text_image, get_random_samples_batch_from_loader
from deepml.visualize import plot_images, plot_images_with_title
[docs]
class Task(ABC):
"""Abstract base class for all deep learning tasks.
This class provides the foundation for task-specific implementations including
model management, device handling, and prediction workflows.
Subclasses must implement methods for transforming targets and outputs, batch prediction, training and
evaluation steps, and visualization.
"""
[docs]
def __init__(
self,
model: torch.nn.Module,
model_dir: str,
load_saved_model: bool = False,
model_file_name: str = "latest_model.pt",
device: str = "auto",
):
"""Initializes the Task.
Args:
model: PyTorch model instance to be trained or used for inference.
model_dir: Directory path for saving and loading model checkpoints.
load_saved_model: Whether to load a previously saved model from
model_dir. Defaults to False.
Set to True if you want to load model weights from a checkpoint file in model_dir.
model_file_name: Name of the model checkpoint file.
Defaults to "latest_model.pt".
device: Device to use for computation. Options: "auto", "cpu",
"cuda", or "mps". When "auto", automatically selects the best
available device. Defaults to "auto".
Raises:
AssertionError: If model is not a torch.nn.Module instance, or if
model_dir is None, or if model_file_name is not a string, or
if device is not one of the valid options.
"""
super(Task, self).__init__()
assert isinstance(
model, torch.nn.Module
), "model should be an instance of torch.nn.Module"
assert model_dir is not None, "model_dir should not be None"
assert isinstance(model_file_name, str), "model_file_name should be a string"
assert device in [
"auto",
"cpu",
"cuda",
"mps",
], "device should be one of 'auto', 'cpu', 'cuda', 'mps'"
self._model = model
self._model_dir = model_dir
self._model_file_name = model_file_name
if device == "auto":
if torch.cuda.is_available():
self._device = torch.device("cuda")
elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
self._device = torch.device("mps")
else:
self._device = torch.device("cpu")
else:
self._device = torch.device(device)
os.makedirs(self.model_dir, exist_ok=True)
weights_file_path = os.path.join(self._model_dir, self._model_file_name)
if load_saved_model:
self.__load_model_weights(weights_file_path)
def __load_model_weights(self, weights_file_path: str):
if weights_file_path and os.path.exists(weights_file_path):
print(f"Loading Saved Model Weights: {weights_file_path}")
state_dict = torch.load(weights_file_path, map_location=self._device)
self._model.load_state_dict(state_dict["model_state_dict"])
print("Model Weights Successfully Loaded!")
else:
print("Failed to load model weights..!")
@property
def model(self) -> torch.nn.Module:
    """The model instance stored by the constructor."""
    return self._model
@property
def model_dir(self) -> str:
    """Directory used for saving and loading model checkpoints."""
    return self._model_dir
@property
def device(self) -> torch.device:
    """The ``torch.device`` resolved by the constructor."""
    return self._device
@property
def model_file_name(self) -> str:
    """File name of the model checkpoint inside ``model_dir``."""
    return self._model_file_name
[docs]
@abstractmethod
def predict_batch(self, x, *args, **kwargs):
    """Run the model on a single batch.

    Abstract: concrete tasks must provide the implementation.

    Args:
        x: Input batch.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        Model predictions for the batch.
    """
[docs]
@abstractmethod
def train_step(self, x, y, *args, **kwargs) -> Tuple[Any, Any, Any]:
    """Execute a single training step on one batch.

    Implementations may also apply batch-level transformations to the
    target before returning it.

    Args:
        x: Input batch.
        y: Target batch.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple of (predictions, processed_inputs, processed_targets).
    """
[docs]
@abstractmethod
def eval_step(self, x, y, *args, **kwargs) -> Tuple[Any, Any, Any]:
    """Execute a single evaluation step on one batch.

    Args:
        x: Input batch.
        y: Target batch.
        *args: Additional positional arguments.
        **kwargs: Additional keyword arguments.

    Returns:
        Tuple of (predictions, processed_inputs, processed_targets).
    """
[docs]
@abstractmethod
def predict(self, loader):
    """Generate predictions for every batch yielded by ``loader``.

    Args:
        loader: DataLoader containing data for prediction.

    Returns:
        Predictions and targets.
    """
[docs]
@abstractmethod
def predict_class(self, loader):
    """Generate class predictions for every batch yielded by ``loader``.

    Args:
        loader: DataLoader containing data for prediction.

    Returns:
        Predicted classes, probabilities, and targets.
    """
[docs]
@abstractmethod
def show_predictions(
    self,
    loader,
    image_inverse_transform=None,
    samples=9,
    cols=3,
    figsize=(10, 10),
    target_known=True,
):
    """Visualize model predictions for a sample of the data.

    Args:
        loader: DataLoader containing data for visualization.
        image_inverse_transform: Transformation that reverses the input
            normalization so images can be displayed. Defaults to None.
        samples: Number of samples to display. Defaults to 9.
        cols: Number of columns in the visualization grid. Defaults to 3.
        figsize: Figure size tuple. Defaults to (10, 10).
        target_known: Whether ground truth is available. Defaults to True.
    """
[docs]
@abstractmethod
def write_prediction_to_logger(
    self, tag, loader, logger, image_inverse_transform, global_step, img_size=224
):
    """Write a visualization of predictions to an experiment logger.

    Args:
        tag: Tag identifier for the logged data.
        loader: DataLoader containing data.
        logger: Experiment logger instance.
        image_inverse_transform: Transformation that reverses the input
            normalization.
        global_step: Current training step/epoch.
        img_size: Image size for logging. Defaults to 224.
    """
[docs]
@abstractmethod
def evaluate(
    self,
    loader: torch.utils.data.DataLoader,
    criterion: torch.nn.Module,
    metrics: Dict[str, torch.nn.Module] = None,
    non_blocking=False,
):
    """Evaluate model performance on the given data.

    Args:
        loader: DataLoader containing evaluation data.
        criterion: Loss function module.
        metrics: Dictionary of metric modules. Defaults to None.
        non_blocking: Whether to use async CUDA transfers.
            Defaults to False.

    Returns:
        Dictionary of evaluation metrics.

    Note:
        NOTE(review): ``NeuralNetTask.evaluate`` in this file is declared
        without the ``criterion`` parameter, so the two signatures do not
        line up positionally; confirm which one callers rely on.
    """
[docs]
class NeuralNetTask(Task):
"""Base task implementation for general deep learning tasks.
This class provides a simple implementation suitable for any deep learning task.
It performs predictions without applying task-specific transformations and does
not write to TensorBoard by default.
Use this class when you need a minimal task implementation without specialized
handling for classification, segmentation, or regression.
"""
[docs]
def __init__(
    self,
    model: torch.nn.Module,
    model_dir: str,
    load_saved_model: bool = False,
    model_file_name: str = "latest_model.pt",
    device: str = "auto",
):
    """Create a generic task; all set-up is delegated to ``Task``.

    Args:
        model: PyTorch module that will be trained or used for inference.
        model_dir: Directory used for saving and loading checkpoints.
        load_saved_model: Whether to restore previously saved weights from
            ``model_dir``. Defaults to False.
        model_file_name: Checkpoint file name. Defaults to
            "latest_model.pt".
        device: "auto", "cpu", "cuda" or "mps". Defaults to "auto".
    """
    super(NeuralNetTask, self).__init__(
        model=model,
        model_dir=model_dir,
        load_saved_model=load_saved_model,
        model_file_name=model_file_name,
        device=device,
    )
[docs]
def predict_batch(self, x: torch.Tensor, *args, **kwargs):
"""Performs prediction on a single batch.
Args:
x: Input batch tensor.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments. If 'model' key is present,
uses that model instead of the task's default model.
Returns:
Model predictions for the batch.
"""
x = self.move_input_to_device(x, **kwargs)
if "model" in kwargs:
return kwargs["model"](x)
else:
return self._model(x)
[docs]
def train_step(self, x, y, *args, **kwargs):
"""Executes a single training step.
Args:
x: Input batch.
y: Target batch.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
Tuple of (predictions, inputs, targets).
"""
return self.predict_batch(x, *args, **kwargs), x, y
[docs]
def eval_step(self, x, y, *args, **kwargs):
"""Executes a single evaluation step.
Args:
x: Input batch.
y: Target batch.
*args: Additional positional arguments.
**kwargs: Additional keyword arguments.
Returns:
Tuple of (predictions, inputs, targets).
"""
return self.predict_batch(x, *args, **kwargs), x, y
[docs]
def predict(self, loader: torch.utils.data.DataLoader):
    """Collect model predictions and targets over a whole data loader.

    The model is switched to eval mode and moved to the task device, and
    every batch is run through ``eval_step`` with gradients disabled.

    Args:
        loader: DataLoader yielding ``(inputs, targets)`` batches.

    Returns:
        Tuple of (predictions, targets) where:
            - predictions: all batch predictions concatenated with
              ``torch.cat``
            - targets: concatenated tensor when the batch targets are
              tensors, otherwise a flat Python list

    Raises:
        AssertionError: If loader is None or empty.
    """
    assert loader is not None and len(loader) > 0

    self._model.eval()
    self._model = self._model.to(self._device)

    batch_outputs, batch_targets = [], []
    progress = tqdm(loader, total=len(loader), desc="{:12s}".format("Prediction"))
    with torch.no_grad():
        for inputs, labels in progress:
            outputs, _, labels = self.eval_step(inputs, labels)
            batch_outputs.append(outputs)
            batch_targets.append(labels)

    predictions = torch.cat(batch_outputs)
    # Non-tensor targets (e.g. file names) are flattened into a plain list.
    if isinstance(batch_targets[0], torch.Tensor):
        targets = torch.cat(batch_targets)
    else:
        targets = np.hstack(batch_targets).tolist()
    return predictions, targets
[docs]
def predict_class(self, loader: torch.utils.data.DataLoader):
    """Generate class predictions for all data in the loader.

    Not supported by the generic task; this implementation always raises.

    Args:
        loader: DataLoader containing data for prediction.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
[docs]
def show_predictions(
    self,
    loader: torch.utils.data.DataLoader,
    image_inverse_transform: Callable = None,
    samples: int = 9,
    cols: int = 3,
    figsize: Tuple[int, int] = (10, 10),
    target_known: bool = True,
):
    """Visualize model predictions.

    Not supported by the generic task; this implementation always raises.

    Args:
        loader: DataLoader containing data for visualization.
        image_inverse_transform: Transformation to reverse normalization.
        samples: Number of samples to display.
        cols: Number of columns in visualization grid.
        figsize: Figure size tuple.
        target_known: Whether ground truth is available.

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
[docs]
def write_prediction_to_logger(
    self,
    tag: str,
    loader,
    logger,
    image_inverse_transform,
    global_step,
    img_size=224,
    **kwargs: dict,
):
    """Write predictions to the experiment logger (no-op here).

    The body contains only this docstring, so calling it does nothing and
    returns None. Subclasses override it to log visualizations.

    Args:
        tag: Tag identifier for logged data.
        loader: DataLoader containing data.
        logger: Experiment logger instance.
        image_inverse_transform: Transformation to reverse normalization.
        global_step: Current training step/epoch.
        img_size: Image size for logging. Defaults to 224.
        **kwargs: Additional keyword arguments (ignored).
    """
@torch.no_grad()
def evaluate(
    self,
    loader: torch.utils.data.DataLoader,
    metrics: Dict[str, torch.nn.Module] = None,
    non_blocking=False,
):
    """Evaluate the model on a data loader with the given metrics.

    Each metric is computed per batch, weighted by the batch size and
    divided by the total number of samples, so the result is a
    sample-weighted average over the loader.

    Args:
        loader: DataLoader yielding ``(inputs, targets)`` batches.
        metrics: Mapping of metric name to a callable metric module that
            is invoked as ``metric(outputs, targets)`` and returns a
            scalar tensor. ``None`` (the default) is treated as "no
            metrics" and yields an empty result.
        non_blocking: Forwarded positionally to ``eval_step``.
            Defaults to False.
            NOTE(review): ``NeuralNetTask.predict_batch`` only forwards
            keyword arguments to ``move_input_to_device``, so this flag
            appears to be dropped by the default implementation; confirm
            before switching it to a keyword argument.

    Returns:
        Dictionary mapping metric names to their sample-weighted average.

    Raises:
        ValueError: If loader is None.
        ZeroDivisionError: If metrics are given but the loader yields no
            samples.
    """
    if loader is None:
        # ValueError is still caught by callers using `except Exception`.
        raise ValueError("Loader cannot be None.")

    # The signature defaults to None; without this `metrics.keys()` crashed.
    metrics = metrics or {}

    # Same preparation as `predict`, so evaluate also works stand-alone.
    self._model = self._model.to(self._device)
    self._model.eval()

    metrics_dict = dict.fromkeys(metrics, 0.0)
    total_samples = 0

    # The context manager closes the bar even when a batch raises.
    with tqdm(
        total=len(loader), desc="{:12s}".format("Evaluation"), dynamic_ncols=True
    ) as bar:
        for x, y in loader:
            outputs, x, y = self.eval_step(x, y, non_blocking)

            if isinstance(y, torch.Tensor):
                y = y.to(self._device)
                # Single-column outputs (N, 1): align the target shape so
                # element-wise metrics do not broadcast (N,) against (N, 1).
                if (
                    isinstance(outputs, torch.Tensor)
                    and outputs.ndim == 2
                    and outputs.shape[1] == 1
                ):
                    y = y.view_as(outputs)

            batch_size = x.size(0)
            total_samples += batch_size

            for metric_name, metric_instance in metrics.items():
                metric_value = metric_instance(outputs, y).item()
                metrics_dict[metric_name] += metric_value * batch_size

            bar.update(1)

    for metric_name in metrics_dict:
        metrics_dict[metric_name] = metrics_dict[metric_name] / total_samples
    return metrics_dict
[docs]
class Segmentation(NeuralNetTask):
"""Task implementation for binary and multiclass semantic segmentation.
This class handles pixel-level classification tasks including binary and
multiclass segmentation with customizable color mapping for visualization.
Attributes:
mode: Segmentation mode ("binary" or "multiclass").
num_classes: Number of segmentation classes.
threshold: Threshold for binary segmentation predictions.
class_index_to_color: Dictionary mapping class indices to colors.
palette: Color palette for visualization (PIL format).
"""
[docs]
def __init__(
    self,
    model: torch.nn.Module,
    model_dir: str,
    mode: str = "binary",
    load_saved_model: bool = False,
    model_file_name: str = "latest_model.pt",
    device: str = "auto",
    num_classes: int = 1,
    threshold: float = 0.5,
    color_map: dict = None,
):
    """Initialize the segmentation task and its colour palette.

    Args:
        model: PyTorch model architecture for segmentation.
        model_dir: Directory path for saving/loading model checkpoints.
        mode: Segmentation mode, "binary" or "multiclass". Any value other
            than "binary" is handled as multiclass. Defaults to "binary".
        load_saved_model: Whether to load a previously saved model.
            Defaults to False.
        model_file_name: Name of the model checkpoint file.
            Defaults to "latest_model.pt".
        device: "auto", "cpu", "cuda" or "mps". Defaults to "auto".
        num_classes: Number of segmentation classes (at least 1).
            Defaults to 1.
        threshold: Probability threshold stored for binary predictions.
            Defaults to 0.5.
        color_map: Dictionary mapping class indices to colours. When
            omitted (or empty) a default is used:
            - binary: ``{0: 0, 1: 255}``
            - multiclass: class 0 is black and every other class gets a
              random RGB triplet drawn with ``np.random.randint``.
            Defaults to None.

    Raises:
        AssertionError: If num_classes is not an integer or is less than
            1, or if a non-empty color_map is not a dict.

    Example:
        >>> model = UNet(in_channels=3, out_channels=3)
        >>> color_map = {0: [0,0,0], 1: [255,0,0], 2: [0,255,0]}
        >>> task = Segmentation(
        ...     model=model,
        ...     model_dir="./models",
        ...     mode="multiclass",
        ...     num_classes=3,
        ...     color_map=color_map
        ... )
    """
    super(Segmentation, self).__init__(
        model, model_dir, load_saved_model, model_file_name, device
    )
    assert isinstance(num_classes, int), "num_classes should be an integer"
    # The check accepts exactly one class; the previous message claimed
    # "greater than 1 class", which contradicted it.
    assert (
        num_classes >= 1
    ), "for segmentation task, num_classes should be at least 1"

    self.mode = mode
    self.num_classes = num_classes
    self.threshold = threshold

    if color_map:
        assert isinstance(color_map, dict)
        self.class_index_to_color = color_map
    elif self.mode == "binary":
        self.class_index_to_color = {0: 0, 1: 255}
    else:
        # Background stays black; the remaining classes get random colours.
        additional_colors = np.random.randint(
            0, 256, size=(self.num_classes - 1, 3)
        )
        self.class_index_to_color = {0: [0, 0, 0]}
        self.class_index_to_color.update(
            {
                index: color
                for index, color in enumerate(additional_colors.tolist(), start=1)
            }
        )

    self.__create_color_palette()
def __create_color_palette(self):
"""Creates a PIL-compatible color palette from the class color mapping.
Generates a flat list of RGB values suitable for use with PIL Image
palettes. For binary segmentation, creates a grayscale palette.
For multiclass, converts the RGB color map to a flat array.
Note:
The palette is padded to 768 values (256 colors * 3 channels) as
required by PIL.
"""
if self.mode == "binary":
self.palette = [0, 0, 0, 255, 255, 255]
else:
self.palette = (
np.array(list(self.class_index_to_color.values()))
.flatten()
.astype(np.uint8)
.tolist()
)
self.palette = self.palette + list(
np.zeros(768 - (len(self.palette)), dtype=np.uint8).tolist()
)
[docs]
def predict_batch(
self, x: Union[torch.Tensor, np.ndarray], *args, **kwargs: Dict[str, Any]
) -> torch.Tensor:
x = self.move_input_to_device(x, **kwargs) # Move input to device
if "model" in kwargs:
pred = kwargs["model"](x)
else:
pred = self._model(x)
if isinstance(pred, dict) and "out" in pred:
return pred["out"] # torchvision model's returns prediction in OrderedDict
else:
return pred
[docs]
def save_prediction(self, loader: torch.utils.data.DataLoader, save_dir: str):
    """Predict segmentation masks for a loader and write them as PNG files.

    Args:
        loader: DataLoader yielding ``(images, filenames)`` batches; the
            second element supplies the output file name of each image.
        save_dir: Directory that receives the PNG files. It is created
            when it does not exist.

    Raises:
        AssertionError: If loader is None or empty, or save_dir is None.

    Note:
        File names that do not end with ".png" are rewritten to use the
        ".png" extension when the masks are saved.
    """
    assert loader is not None and len(loader) > 0
    assert save_dir is not None, "Output directory should not be none."

    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    self._model = self._model.to(self._device)
    self._model.eval()

    progress = tqdm(loader, total=len(loader), desc="{:12s}".format("Prediction"))
    with torch.no_grad():
        for images, filenames in progress:
            raw_output, _, filenames = self.eval_step(images, filenames)
            masks = self.transform_output(raw_output).cpu()
            self._save_image_batch(masks, filenames, save_dir)
def _save_image_batch(
    self, class_indices: torch.Tensor, filenames: List[str], save_dir: str
):
    """Save a batch of segmentation masks as palette PNG images.

    Args:
        class_indices: Batch of masks with shape (B, H, W) holding class
            indices; values are cast to uint8 before saving.
        filenames: One file name per image in the batch.
        save_dir: Directory path where images will be saved.

    Raises:
        AssertionError: If class_indices is not 3-dimensional (BHW format).

    Note:
        Images are written in PIL "P" mode with the task's colour
        palette. A file name that does not end with ".png" has only its
        last extension replaced, so "scan.v2.jpg" becomes "scan.v2.png".
    """
    assert class_indices.ndim == 3, "should be in the form of BHW"

    masks = class_indices.cpu().numpy().astype(np.uint8)
    for i, mask in enumerate(masks):
        image = Image.fromarray(mask, "P")
        image.putpalette(self.palette)

        filename = filenames[i]
        if not filename.endswith(".png"):
            # splitext strips only the final extension; splitting on the
            # first "." truncated names containing extra dots or dotted
            # directory components.
            filename = f"{os.path.splitext(filename)[0]}.png"
        image.save(os.path.join(save_dir, filename))
[docs]
def predict_class(self, loader: torch.utils.data.DataLoader):
    """Class prediction is not provided for segmentation; always raises.

    Args:
        loader: DataLoader containing data for prediction (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
[docs]
def show_predictions(
    self,
    loader: torch.utils.data.DataLoader,
    image_inverse_transform: Callable = None,
    samples: int = 4,
    cols: int = 3,
    figsize: Tuple[int, int] = (16, 16),
    target_known: bool = True,
):
    """Plot inputs, masks and overlays for a random sample of the loader.

    Every sample contributes one row of five panels: input, target mask,
    target overlay, predicted mask and predicted overlay.

    Args:
        loader: DataLoader containing data for visualization.
        image_inverse_transform: Transformation that reverses the image
            normalization for display. Defaults to None.
        samples: Number of samples to display. Defaults to 4.
        cols: Accepted for interface compatibility; the grid is always
            drawn with 5 columns. Defaults to 3.
        figsize: Figure size as (width, height). Defaults to (16, 16).
        target_known: Accepted for interface compatibility; not used by
            this implementation. Defaults to True.
    """
    self._model = self._model.to(self._device)
    self._model.eval()

    with torch.no_grad():
        batch, targets = get_random_samples_batch_from_loader(loader, samples)
        predictions, batch, targets = self.eval_step(batch, targets)

    batch = self.transform_input(batch, image_inverse_transform)
    target_mask = self.decode_segmentation_mask(targets)
    output_mask = self.decode_segmentation_mask(self.transform_output(predictions))

    # Overlays are blended on the uint8 version of the input (BCHW).
    batch_uint8 = (batch * 255.0).to(torch.uint8)
    target_overlay = blend(batch_uint8, target_mask)
    output_overlay = blend(batch_uint8, output_mask)

    # BCHW --> BHWC for plotting.
    channels_last = [0, 2, 3, 1]
    batch = batch.permute(channels_last)
    target_mask = target_mask.permute(channels_last)
    output_mask = output_mask.permute(channels_last)
    target_overlay = target_overlay.permute(channels_last)
    output_overlay = output_overlay.permute(channels_last)

    if self.num_classes == 1:
        # Replicate the single mask channel three times along the last axis.
        target_mask = torch.cat([target_mask] * 3, dim=3)
        output_mask = torch.cat([output_mask] * 3, dim=3)

    panels = (batch, target_mask, target_overlay, output_mask, output_overlay)
    sample_count = batch.shape[0]
    images = [panel[i] for i in range(sample_count) for panel in panels]

    row_titles = [
        "Input",
        "Target Mask",
        "Target Overlay",
        "Pred Mask",
        "Pred Overlay",
    ]
    image_titles = row_titles * sample_count
    plot_images(images, image_titles, cols=5, figsize=figsize, fontsize=12)
[docs]
def decode_segmentation_mask(self, class_indices: torch.Tensor) -> torch.Tensor:
"""Converts class indices to RGB color images for visualization.
Args:
class_indices: Batch of segmentation masks with shape (B, H, W)
containing class indices.
Returns:
Batch of RGB images with shape (B, C, H, W) where:
- For binary: C=1 (grayscale)
- For multiclass: C=3 (RGB)
Colors are mapped according to the class_index_to_color palette.
Note:
Uses PIL Image palette for efficient color mapping in multiclass mode.
"""
# Convert to numpy array
class_indices = class_indices.cpu().numpy().astype(np.uint8)
decoded_images = []
# For each image in the batch
for i in range(class_indices.shape[0]):
if self.mode == "binary":
image_arr = np.zeros_like(class_indices[i]) # HW
image_arr[class_indices[i] > 0] = 255
image_arr = image_arr[np.newaxis, ...] # CHW for grayscale
else:
image = Image.fromarray(class_indices[i], mode="P")
image.putpalette(self.palette)
image = image.convert("RGB")
image_arr = np.array(image) # HWC for RGB
image_arr = image_arr.transpose(2, 0, 1) # Convert to CHW format
decoded_images.append(torch.from_numpy(image_arr))
# return tensor of size (B, C, H, W) for both RGB and grayscale images
return torch.stack(decoded_images)
[docs]
def log_prediction(
    self,
    tag: str,
    predictions: torch.Tensor,
    x: torch.Tensor,
    targets: torch.Tensor,
    logger: MLExperimentLogger,
    image_inverse_transform: Callable,
    global_step: int,
    img_size: Union[int, Tuple[int, int], None] = 224,
    **kwargs: dict,
):
    """Log inputs, masks and overlays as one image grid.

    Each sample contributes a row of five panels: input, target mask,
    target overlay, predicted mask and predicted overlay.

    Args:
        tag: Tag identifier for the logged image.
        predictions: Raw model predictions; converted to class indices
            with ``transform_output``.
        x: Input images (B, C, H, W).
        targets: Ground-truth masks holding class indices.
        logger: Experiment logger; its ``log_image`` method is called.
        image_inverse_transform: Callable that reverses the image
            normalization.
        global_step: Current training step/epoch for the logger.
        img_size: Target size passed to ``F.interpolate``. When None the
            panels keep their original size. Defaults to 224.
        **kwargs: Accepted for interface compatibility; not used here.

    Note:
        Override this method to customize what is logged.
    """
    x = self.transform_input(x, image_inverse_transform).cpu()  # BCHW
    target_mask = self.decode_segmentation_mask(targets.cpu())
    class_indices = self.transform_output(predictions).cpu()
    output_mask = self.decode_segmentation_mask(class_indices)

    x = (x * 255.0).to(torch.uint8)  # uint8 for visualization
    target_overlay = blend(x, target_mask)
    output_overlay = blend(x, output_mask)

    panels = [x, target_mask, target_overlay, output_mask, output_overlay]

    # Binary masks decode to one channel; torch.stack below needs every
    # panel to share the input's channel count (show_predictions
    # replicates the mask channel for the same reason).
    channels = x.shape[1]
    panels = [
        panel.expand(-1, channels, -1, -1)
        if panel.shape[1] == 1 and channels > 1
        else panel
        for panel in panels
    ]

    # img_size=None is documented as "no resizing"; F.interpolate rejects
    # a call with neither size nor scale_factor, so skip it entirely.
    if img_size is not None:
        panels = [
            F.interpolate(
                panel, size=img_size, mode="bilinear", align_corners=False
            )
            for panel in panels
        ]

    # Interleave so that each sample's five panels are adjacent.
    images = torch.stack([image for row in zip(*panels) for image in row])

    # nrow=5 puts exactly one sample on each grid row.
    image_grid = torchvision.utils.make_grid(
        images, nrow=5, padding=5, pad_value=255
    )
    image_grid = image_grid.permute(1, 2, 0).cpu().numpy()  # CHW --> HWC
    logger.log_image(tag, image_grid, global_step)
[docs]
def write_prediction_to_logger(
    self,
    tag: str,
    loader: torch.utils.data.DataLoader,
    logger: MLExperimentLogger,
    image_inverse_transform: Callable,
    global_step: int,
    img_size: Union[int, Tuple[int, int], None] = 224,
    **kwargs: dict,
):
    """Sample a few items, predict them and hand the result to the logger.

    Four random samples are drawn from the loader, run through
    ``eval_step`` with gradients disabled, and passed to
    ``log_prediction``.

    Args:
        tag: Tag identifier for the logged images.
        loader: DataLoader containing data for visualization.
        logger: Experiment logger instance.
        image_inverse_transform: Callable that reverses the image
            normalization.
        global_step: Current training step/epoch for the logger.
        img_size: Target size forwarded to ``log_prediction``.
            Defaults to 224.
        **kwargs: Forwarded to both ``eval_step`` and ``log_prediction``.
    """
    self._model.eval()
    with torch.no_grad():
        sample_inputs, sample_targets = get_random_samples_batch_from_loader(
            loader, samples=4
        )
        outputs, sample_inputs, sample_targets = self.eval_step(
            sample_inputs, sample_targets, **kwargs
        )
        # Arguments stay positional so overridden log_prediction
        # implementations keep receiving them in the same order.
        self.log_prediction(
            tag,
            outputs,
            sample_inputs,
            sample_targets,
            logger,
            image_inverse_transform,
            global_step,
            img_size,
            **kwargs,
        )
[docs]
class ImageRegression(NeuralNetTask):
"""Task implementation for image regression problems.
This class handles tasks where the model predicts continuous values from
images, such as age estimation, pose estimation, or depth prediction.
The task supports visualization of predictions alongside ground truth values
and logging to experiment trackers.
"""
[docs]
def __init__(
    self,
    model: torch.nn.Module,
    model_dir: str,
    load_saved_model: bool = False,
    model_file_name: str = "latest_model.pt",
    device: str = "auto",
):
    """Create an image-regression task; set-up is delegated to the parent.

    Args:
        model: PyTorch model instance for regression.
        model_dir: Directory used for saving and loading checkpoints.
        load_saved_model: Whether to restore previously saved weights from
            ``model_dir``. Defaults to False.
        model_file_name: Checkpoint file name. Defaults to
            "latest_model.pt".
        device: "auto", "cpu", "cuda" or "mps". Defaults to "auto".
    """
    super(ImageRegression, self).__init__(
        model=model,
        model_dir=model_dir,
        load_saved_model=load_saved_model,
        model_file_name=model_file_name,
        device=device,
    )
[docs]
def show_predictions(
    self,
    loader: torch.utils.data.DataLoader,
    image_inverse_transform: Callable = None,
    samples: int = 9,
    cols: int = 3,
    figsize: Tuple[int, int] = (10, 10),
    target_known: bool = True,
):
    """Plot random samples titled with their target and predicted values.

    Args:
        loader: DataLoader containing data for visualization.
        image_inverse_transform: Transformation that reverses the image
            normalization for display. Defaults to None.
        samples: Number of samples to display. Defaults to 9.
        cols: Number of columns in the grid. Defaults to 3.
        figsize: Figure size as (width, height). Defaults to (10, 10).
        target_known: When True the title shows the transformed target as
            ground truth; otherwise the raw second batch element is
            printed as-is. Defaults to True.
    """
    self._model = self._model.to(self._device)
    self._model.eval()

    with torch.no_grad():
        x, y = get_random_samples_batch_from_loader(loader, samples)
        predictions, x, y = self.eval_step(x, y)

    # BCHW --> BHWC for plotting.
    images = self.transform_input(x, image_inverse_transform).permute([0, 2, 3, 1])

    def create_title(target, raw_prediction):
        shown_prediction = self.transform_output(raw_prediction)
        if target_known:
            header = f"Ground Truth={self.transform_target(target)}"
        else:
            header = f"{target}"
        return f"{header}\nPrediction={shown_prediction}"

    # Lazily yields (image, title, title_color) triples.
    image_title_generator = (
        (images[i], create_title(y[i], predictions[i]), None)
        for i in range(images.shape[0])
    )
    plot_images_with_title(
        image_title_generator, samples=samples, cols=cols, figsize=figsize
    )
[docs]
def write_prediction_to_logger(
    self,
    tag: str,
    loader: torch.utils.data.DataLoader,
    logger: MLExperimentLogger,
    image_inverse_transform: Callable,
    global_step: int,
    img_size: Union[int, Tuple[int, int], None] = 224,
):
    """Log sample images next to a text panel with target and prediction.

    For every sampled image two tensors are emitted: the image itself
    (resized when its size differs from ``img_size``) and a rendered text
    image showing the ground truth and the prediction.

    Args:
        tag: Unique tag identifier for the logged images.
        loader: DataLoader containing data for visualization.
        logger: Experiment logger; its ``log_artifact`` method is called.
        image_inverse_transform: Transformation that reverses the image
            normalization.
        global_step: Current training epoch/step for the logger.
        img_size: Int or (H, W) tuple. A falsy value (None or 0) makes the
            method return without logging. Defaults to 224.

    Raises:
        AssertionError: If img_size is neither an int nor a 2-tuple.
    """
    if not img_size:
        return
    assert isinstance(img_size, int) or (
        isinstance(img_size, tuple) and len(img_size) == 2
    )
    if isinstance(img_size, int):
        img_size = (img_size, img_size)

    self._model = self._model.to(self._device)
    self._model.eval()
    with torch.no_grad():
        x, targets = get_random_samples_batch_from_loader(loader)
        predictions, x, targets = self.eval_step(x, targets)

    x, y = x.cpu(), targets.cpu()
    x = self.transform_input(x, image_inverse_transform)

    needs_resize = tuple(x.shape[-2:]) != img_size
    resize_as_pillow = torchvision.transforms.Compose(
        [
            torchvision.transforms.ToPILImage(),
            torchvision.transforms.Resize(img_size),
        ]
    )
    to_tensor = torchvision.transforms.ToTensor()

    output_images = []
    for i in range(x.shape[0]):
        ground_truth = self.transform_target(y[i])
        prediction = self.transform_output(predictions[i])
        caption = f"GT={ground_truth}\nPred={prediction}"
        caption_image = create_text_image(caption, img_size=img_size)

        picture = x[i].squeeze(dim=0)
        if needs_resize:
            # Round-trip through PIL to resize to the requested size.
            picture = to_tensor(resize_as_pillow(picture))
        output_images.append(picture)
        output_images.append(to_tensor(caption_image))

    logger.log_artifact(f"{tag}", torch.stack(output_images), global_step)
[docs]
def predict_class(self, loader):
    """Class prediction is not provided for regression; always raises.

    Args:
        loader: DataLoader containing data for prediction (unused).

    Raises:
        NotImplementedError: Always.
    """
    raise NotImplementedError()
[docs]
class ImageClassification(NeuralNetTask):
"""Task implementation for image classification.
This class handles both binary and multiclass classification tasks where
each image belongs to exactly one class. Supports custom class labels and
visualization of predictions.
Attributes:
_classes: Optional sequence of class names for human-readable labels.
"""
[docs]
def __init__(
    self,
    model: torch.nn.Module,
    model_dir: str,
    load_saved_model: bool = False,
    model_file_name: str = "latest_model.pt",
    device: str = "auto",
    classes: Sequence = None,
):
    """Create an image-classification task.

    Args:
        model: PyTorch model instance for classification.
        model_dir: Directory used for saving and loading checkpoints.
        load_saved_model: Whether to restore previously saved weights from
            ``model_dir``. Defaults to False.
        model_file_name: Checkpoint file name. Defaults to
            "latest_model.pt".
        device: "auto", "cpu", "cuda" or "mps". Defaults to "auto".
        classes: Optional sequence of class names (e.g. ['cat', 'dog']),
            stored on the instance as ``_classes``. Defaults to None.
    """
    super(ImageClassification, self).__init__(
        model=model,
        model_dir=model_dir,
        load_saved_model=load_saved_model,
        model_file_name=model_file_name,
        device=device,
    )
    self._classes = classes
[docs]
def predict_class(self, loader: torch.utils.data.DataLoader):
"""Generates class predictions with probabilities for all data.
Args:
loader: DataLoader containing data for prediction.
Returns:
Tuple of (predicted_class, probability, targets) where:
- predicted_class: Tensor of predicted class indices
- probability: Tensor of prediction confidence scores
- targets: Ground truth class labels
"""
predictions, targets = self.predict(loader)
predicted_class, probability = self.transform_output(predictions)
return predicted_class, probability, targets
def _create_title_for_display(
self,
target_class_index,
predicted_class_index,
predicted_probability,
target_known=True,
):
"""Creates a colored title string for prediction visualization.
Args:
target_class_index: Ground truth class index.
predicted_class_index: Predicted class index.
predicted_probability: Prediction confidence score.
target_known: Whether ground truth is available. Defaults to True.
Returns:
Tuple of (title_text, title_color) where:
- title_text: Formatted string with prediction info
- title_color: "green" (correct), "red" (incorrect), or "yellow" (unknown)
"""
predicted_class = self.transform_target(predicted_class_index)
probability = round(predicted_probability.item(), 2)
if target_known:
target_class = self.transform_target(target_class_index)
title_color = "green" if predicted_class == target_class else "red"
return (
f"Ground Truth={target_class}"
f"\nPrediction={predicted_class}, "
f"{probability}",
title_color,
)
else:
return (
f"{target_class_index}\nPrediction={predicted_class}, {probability}",
"yellow",
)
def _create_output_image_for_tensorboard(
self,
target_class_index,
predicted_class_index,
predicted_probability,
img_size=(224, 224),
):
"""Creates a text image showing prediction results for TensorBoard.
Args:
target_class_index: Ground truth class index.
predicted_class_index: Predicted class index.
predicted_probability: Prediction confidence score.
img_size: Output image size as (height, width) tuple.
Defaults to (224, 224).
Returns:
PIL Image containing formatted text with ground truth, prediction,
and probability colored by correctness (green or red).
"""
ground_truth = self.transform_target(target_class_index)
predicted_class = self.transform_target(predicted_class_index)
probability = round(predicted_probability.item(), 2)
text_color = "green" if ground_truth == predicted_class else "red"
display_content = f"{ground_truth}\n{predicted_class}, {probability}"
return create_text_image(
display_content, img_size=img_size, text_color=text_color
)
[docs]
def show_predictions(
self,
loader: torch.utils.data.DataLoader,
image_inverse_transform: Callable = None,
samples: int = 9,
cols: int = 3,
figsize: Tuple[int, int] = (10, 10),
target_known: bool = True,
):
"""Visualizes model predictions on sample images.
Displays random samples from the loader with their ground truth labels,
predicted labels, and confidence scores in a matplotlib figure.
Args:
loader: DataLoader containing data for visualization.
image_inverse_transform: Transformation to reverse image normalization
for display. Defaults to None.
samples: Number of samples to display. Defaults to 9.
cols: Number of columns in the visualization grid. Defaults to 3.
figsize: Figure size as (width, height) tuple. Defaults to (10, 10).
target_known: Whether ground truth targets are available for comparison.
If True, titles will be colored green (correct) or red (incorrect).
Defaults to True.
"""
self._model = self._model.to(self._device)
self._model.eval()
with torch.no_grad():
x, targets = get_random_samples_batch_from_loader(loader, samples)
predictions, x, targets = self.eval_step(x, targets)
x = self.transform_input(x, image_inverse_transform)
# #BCHW --> #BHWC
x = x.permute([0, 2, 3, 1])
class_indices, probabilities = self.transform_output(predictions)
image_title_generator = (
(
x[index],
*self._create_title_for_display(
targets[index],
class_indices[index],
probabilities[index],
target_known,
),
)
for index in range(x.shape[0])
)
plot_images_with_title(
image_title_generator, samples=samples, cols=cols, figsize=figsize
)
[docs]
def write_prediction_to_logger(
self,
tag: str,
loader,
logger: MLExperimentLogger,
image_inverse_transform,
global_step: int,
img_size=224,
):
"""Writes predictions with labels to the experiment logger.
Creates a visualization grid showing input images alongside their
ground truth and predicted class labels with confidence scores.
Args:
tag: Unique tag identifier for the logged images.
loader: DataLoader containing data for visualization.
logger: Experiment logger instance for tracking visualizations.
image_inverse_transform: Transformation to reverse image normalization.
global_step: Current training epoch/step for the logger.
img_size: Image size for logging. Can be int or (H, W) tuple.
If None, no visualization is written. Defaults to 224.
Note:
Predictions are colored green for correct classifications and red
for incorrect ones.
"""
if img_size:
assert isinstance(img_size, int) or (
isinstance(img_size, tuple) and len(img_size) == 2
)
else:
return
if isinstance(img_size, int):
img_size = (img_size, img_size)
self._model = self._model.to(self._device)
self._model.eval()
with torch.no_grad():
x, targets = get_random_samples_batch_from_loader(loader)
predictions, x, targets = self.eval_step(x, targets)
x = self.transform_input(x).cpu()
class_indices, probabilities = self.transform_output(predictions)
input_img_size = tuple(x.shape[-2:])
to_pillow_image = torchvision.transforms.Compose(
[
torchvision.transforms.ToPILImage(),
torchvision.transforms.Resize(img_size),
]
)
to_tensor = torchvision.transforms.ToTensor()
output_images = []
for index in range(x.shape[0]):
content_image = self._create_output_image_for_tensorboard(
targets[index], class_indices[index], probabilities[index], img_size
)
if input_img_size != img_size:
output_images.append(
to_tensor(to_pillow_image(x[index].squeeze(dim=0)))
)
else:
output_images.append(x[index].squeeze(dim=0))
output_images.append(to_tensor(content_image))
logger.log_artifact(f"{tag}", torch.stack(output_images), global_step)
[docs]
class MultiLabelImageClassification(ImageClassification):
"""Task implementation for multi-label image classification.
This class handles classification tasks where each image can belong to
multiple classes simultaneously (e.g., an image containing both a cat
and a dog).
Attributes:
_classes: Optional sequence of class names for human-readable labels.
"""
[docs]
def __init__(
    self,
    model: torch.nn.Module,
    model_dir,
    load_saved_model: bool = False,
    model_file_name: str = "latest_model.pt",
    device: str = "auto",
    classes=None,
):
    """Set up the multi-label image classification task.

    Args:
        model: PyTorch model instance for multi-label classification.
        model_dir: Directory path for saving and loading model checkpoints.
        load_saved_model: Whether to load a previously saved model from
            model_dir. Defaults to False.
        model_file_name: Name of the model checkpoint file.
            Defaults to "latest_model.pt".
        device: Device to use for computation. Options: "auto", "cpu",
            "cuda", or "mps". Defaults to "auto".
        classes: Optional sequence of class names for labeling.
            Defaults to None.
    """
    super(MultiLabelImageClassification, self).__init__(
        model=model,
        model_dir=model_dir,
        load_saved_model=load_saved_model,
        model_file_name=model_file_name,
        device=device,
    )
    # The parent initializer is called without ``classes``, so the label
    # names are recorded on the instance here.
    self._classes = classes
[docs]
def predict_class(self, loader):
    """Run prediction over ``loader`` and decode the raw model outputs.

    Args:
        loader: DataLoader containing data for prediction.

    Returns:
        Tuple of (predicted_class, probability, targets): the pair produced
        by ``transform_output`` for the raw predictions, followed by the
        targets returned by ``predict``.
    """
    raw_outputs, ground_truth = self.predict(loader)
    decoded_labels, decoded_scores = self.transform_output(raw_outputs)
    return decoded_labels, decoded_scores, ground_truth
def _create_title_for_display(
    self,
    target_class_indices,
    predicted_class_indexes,
    predicted_probs,
    target_known=True,
):
    """Creates a colored title string for multi-label prediction visualization.

    Args:
        target_class_indices: Binary indicator sequence of ground truth
            labels. When ``target_known`` is False it is placed in the title
            verbatim.
        predicted_class_indexes: Binary indicator sequence of predicted
            labels, one flag per class.
        predicted_probs: Per-class probabilities, aligned position by
            position with ``predicted_class_indexes``.
        target_known: Whether ground truth is available. Defaults to True.

    Returns:
        Tuple of (title_text, title_color) where:
            - title_text: Predicted classes followed by the comma-separated
              probabilities (rounded to 2 decimals) of the flagged classes
            - title_color: "green" (correct), "red" (incorrect), or
              "yellow" (unknown)
    """
    predicted_classes = self.transform_target(predicted_class_indexes)

    # Pair every flag with the probability at the same position. The earlier
    # code indexed the probabilities with the 0/1 flag values themselves and
    # handed non-string items to str.join, which raises TypeError.
    probabilities_text = ", ".join(
        str(round(float(probability), 2))
        for is_predicted, probability in zip(predicted_class_indexes, predicted_probs)
        if is_predicted
    )

    if not target_known:
        return (
            f"{target_class_indices}\nPred={predicted_classes},\n{probabilities_text}",
            "yellow",
        )

    target_classes = self.transform_target(target_class_indices)
    # NOTE(review): assumes transform_target returns values whose ``==``
    # yields a plain bool (e.g. lists or strings) -- confirm.
    title_color = "green" if target_classes == predicted_classes else "red"
    return (
        f"GT={target_classes}\nPred={predicted_classes},\n{probabilities_text}",
        title_color,
    )
def _create_output_image_for_tensorboard(
    self,
    target_class_indices,
    predicted_class_indexes,
    predicted_probs,
    img_size=(224, 224),
):
    """Creates a text image summarizing a multi-label prediction.

    Args:
        target_class_indices: Binary indicator sequence of ground truth
            labels.
        predicted_class_indexes: Binary indicator sequence of predicted
            labels, one flag per class.
        predicted_probs: Per-class probabilities, aligned position by
            position with ``predicted_class_indexes``.
        img_size: Size forwarded to ``create_text_image``.
            Defaults to (224, 224).

    Returns:
        The result of ``create_text_image`` for three lines of text (target
        classes, predicted classes, probabilities of the flagged classes
        rounded to 2 decimals), colored green when target and predicted
        classes are equal and red otherwise.
    """
    target_classes = self.transform_target(target_class_indices)
    predicted_classes = self.transform_target(predicted_class_indexes)

    # Pair every flag with the probability at the same position. The earlier
    # code indexed the probabilities with the 0/1 flag values themselves and
    # handed non-string items to str.join, which raises TypeError.
    probabilities = ", ".join(
        str(round(float(probability), 2))
        for is_predicted, probability in zip(predicted_class_indexes, predicted_probs)
        if is_predicted
    )

    display_content = f"{target_classes}\n{predicted_classes}\n{probabilities}"
    # NOTE(review): assumes transform_target returns values whose ``==``
    # yields a plain bool (e.g. lists or strings) -- confirm.
    text_color = "green" if target_classes == predicted_classes else "red"
    return create_text_image(
        display_content, img_size=img_size, text_color=text_color
    )