from abc import ABC, abstractmethod
from typing import Union
import torch
from segmentation_models_pytorch.metrics.functional import (
accuracy,
f1_score,
get_stats,
iou_score,
precision,
recall,
)
class ToClassIndex(torch.nn.Module):
    """Turn raw model outputs into hard class predictions.

    In ``'binary'`` and ``'multilabel'`` modes the activated output is
    thresholded element-wise and a 0/1 tensor with the same shape and dtype
    as the activated output is returned. In ``'multiclass'`` mode the index
    of the largest activated value along dimension 1 is returned.

    Args:
        mode (str): One of 'binary', 'multiclass' or 'multilabel'. Default is 'binary'.
        threshold (float, optional): Cut-off applied to the activated output in
            'binary' and 'multilabel' modes. When None, 0.5 is used. Must be None
            in 'multiclass' mode. Default is None.
        activation (torch.nn.Module, optional): Activation applied to the raw output.
            When None, ``Softmax2d`` is used for 'multiclass' and ``Sigmoid`` otherwise.

    Raises:
        ValueError: If ``mode`` is not supported, or if a threshold is given
            together with 'multiclass' mode.
    """

    def __init__(
        self,
        mode: str = "binary",
        threshold: Union[float, None] = None,
        activation=None,
    ):
        super().__init__()
        self.mode = mode
        self.activation = activation
        self.threshold = threshold
        if self.mode not in ["binary", "multiclass", "multilabel"]:
            raise ValueError(
                "mode should be either 'binary', 'multiclass' or 'multilabel' "
            )
        # Compare against None explicitly: a threshold of 0.0 is falsy but is
        # still a user-supplied threshold and must be rejected for multiclass.
        if self.threshold is not None and self.mode == "multiclass":
            raise ValueError(f"threshold and mode={self.mode} cannot be used together")
        if self.activation is None:
            self.activation = (
                torch.nn.Softmax2d()
                if self.mode == "multiclass"
                else torch.nn.Sigmoid()
            )

    def forward(self, output: torch.FloatTensor) -> torch.Tensor:
        """Convert a batch of raw outputs to class predictions.

        Args:
            output (torch.FloatTensor): 4-dimensional tensor (B, C, H, W).

        Returns:
            torch.Tensor: For 'binary'/'multilabel', a 0/1 tensor shaped like the
            activated output. For 'multiclass', the arg-max over dimension 1.
        """
        assert output.ndim == 4, "output must be a 4D tensor (B, C, H, W)"
        probability = self.activation(output)
        if self.mode == "multiclass":
            return torch.argmax(probability, dim=1)
        # Binary / multilabel: fall back to 0.5 when no threshold was configured.
        threshold = self.threshold if self.threshold is not None else 0.5
        class_indices = torch.zeros_like(probability)
        class_indices[probability >= threshold] = 1
        return class_indices
class SegmentationMetric(torch.nn.Module, ABC):
    """Abstract base class for segmentation metrics computed from tp/fp/fn/tn counts.

    The constructor stores and validates the configuration and builds a
    ``ToClassIndex`` module that turns raw model outputs into class
    predictions. Subclasses implement ``forward`` and typically call
    ``_get_stats`` to obtain the counts.

    Args:
        mode (str): One of 'binary', 'multiclass' or 'multilabel'. Default is 'binary'.
        reduction (str, optional): How the metric is aggregated between classes and
            images. Stored for use by subclasses. Default is 'macro-imagewise'.
        activation (torch.nn.Module, optional): Activation applied to the model output.
            Default is None.
        ignore_index (int, optional): Target value excluded from the metric. Not
            supported in 'binary' mode. Default is None.
        threshold (float, optional): Threshold used to binarize the output. Default is None.
        num_classes (int, optional): Number of classes. Required in 'multiclass' mode.
        class_weights (optional): Per-class weights, stored for use by subclasses.
        target_class_index (int, optional): Single class to report. Not supported in
            'binary' mode; must be smaller than ``num_classes`` when both are given.
        zero_division (float): Value used on zero division. Default is 1.0.
        callable (callable, optional): Function ``(output, target) -> (output, target)``
            applied before any other processing. Default is None.

    Raises:
        ValueError: If the configuration is inconsistent (see the checks below).
    """

    def __init__(
        self,
        mode: str = "binary",
        reduction: str = "macro-imagewise",
        activation=None,
        ignore_index=None,
        threshold=None,
        num_classes=None,
        class_weights=None,
        target_class_index=None,
        zero_division=1.0,
        callable=None,
    ):
        super().__init__()
        self.mode = mode
        self.ignore_index = ignore_index
        self.threshold = threshold
        self.num_classes = num_classes
        self.reduction = reduction
        self.class_weights = class_weights
        self.zero_division = zero_division
        self.activation = activation
        self.target_class_index = target_class_index
        self.callable = callable
        if self.mode not in ["binary", "multiclass", "multilabel"]:
            raise ValueError(
                "mode should be either 'binary', 'multiclass' or 'multilabel'"
            )
        if self.ignore_index is not None and self.mode == "binary":
            raise ValueError("ignore_index is not supported for binary")
        if self.target_class_index is not None and self.mode == "binary":
            raise ValueError("target_class_index is not supported for binary")
        if self.num_classes is None and self.mode == "multiclass":
            raise ValueError("num_classes is required for multiclass mode")
        if (
            self.target_class_index is not None
            and self.num_classes is not None
            and self.target_class_index >= self.num_classes
        ):
            raise ValueError("target_class_index should be less than num_classes")
        self.to_class_index = ToClassIndex(self.mode, self.threshold, self.activation)

    @abstractmethod
    def forward(
        self,
        output: Union[torch.LongTensor, torch.FloatTensor],
        target: torch.LongTensor,
    ):
        """Compute the metric for a batch; must be implemented by subclasses."""

    def _get_stats(
        self,
        output: Union[torch.LongTensor, torch.FloatTensor],
        target: torch.LongTensor,
    ) -> tuple:
        """Return the ``(tp, fp, fn, tn)`` statistics for a batch.

        Applies the optional ``callable``, converts ``output`` to class
        predictions, moves ``target`` to the output's device and delegates to
        ``get_stats``.

        Args:
            output: Raw model output for the batch.
            target: Ground-truth tensor for the batch.

        Returns:
            tuple: The value returned by ``get_stats`` (unpacked by subclasses
            as ``tp, fp, fn, tn``).
        """
        if self.callable is not None:
            output, target = self.callable(output, target)
        output = self.to_class_index(output)
        # In binary/multilabel mode the prediction comes back as a float 0/1
        # mask. get_stats expects an integer-typed output when no threshold is
        # passed, so cast the (already binarised) mask in that case. When a
        # threshold is configured, get_stats re-applies it itself.
        if self.threshold is None and torch.is_floating_point(output):
            output = output.long()
        # Ensure target is on the same device as output (e.g. when output is on GPU)
        target = target.to(output.device)

        shared_kwargs = dict(
            mode=self.mode,
            num_classes=self.num_classes,
            threshold=self.threshold,
        )
        if self.mode == "multiclass" and self.ignore_index == 0:
            # To ignore class 0 (background) in multiclass segmentation, shift
            # every label down by one and ignore the resulting -1 instead.
            # NOTE(review): num_classes is passed unchanged after the shift, so
            # the remaining classes occupy indices 0..num_classes-2 — confirm
            # that this is the intended layout for target_class_index/averaging.
            return get_stats(output - 1, target - 1, ignore_index=-1, **shared_kwargs)
        return get_stats(
            output, target, ignore_index=self.ignore_index, **shared_kwargs
        )
class Precision(SegmentationMetric):
    """
    Precision metric for segmentation.

    Args:
        mode (str): Either 'binary', 'multiclass' or 'multilabel'. Default is 'binary'.
        reduction (str, optional): How to aggregate the metric between classes and images:
            'micro', 'macro', 'weighted', 'micro-imagewise', 'macro-imagewise', 'weighted-imagewise'.
            Default is 'macro-imagewise'. Reference link:
            https://smp.readthedocs.io/en/latest/metrics.html#segmentation_models_pytorch.metrics.functional.precision
        activation (torch.nn.Module, optional): Activation function applied to the model output. Default is None.
        ignore_index (int, optional): Target value that is ignored and does not contribute to the metric. Default is None.
        threshold (float, optional): Threshold value for binarizing the output. Default is None.
        num_classes (int, optional): Number of classes for the metric calculation. Default is None.
        class_weights (torch.Tensor, optional): Manual rescaling weight given to each class. Default is None.
        target_class_index (int, optional): Class index for which to compute the precision. Default is None.
        zero_division (float): Value to return when there is a zero division. Default is 1.0.
        callable (callable, optional): Function applied to the output and target before metric calculation. Default is None.
    """

    def __init__(
        self,
        mode: str = "binary",
        reduction: str = "macro-imagewise",
        activation=None,
        ignore_index=None,
        threshold=None,
        num_classes=None,
        class_weights=None,
        target_class_index=None,
        zero_division=1.0,
        callable=None,
    ):
        # Forward the complete configuration to the base class unchanged.
        settings = dict(
            mode=mode,
            reduction=reduction,
            activation=activation,
            ignore_index=ignore_index,
            threshold=threshold,
            num_classes=num_classes,
            class_weights=class_weights,
            target_class_index=target_class_index,
            zero_division=zero_division,
            callable=callable,
        )
        super().__init__(**settings)

    def forward(
        self,
        output: Union[torch.LongTensor, torch.FloatTensor],
        target: torch.LongTensor,
    ):
        """Return the precision of ``output`` against ``target``."""
        # Each statistic is indexed as [image, class] below.
        tp, fp, fn, tn = self._get_stats(output, target)
        class_idx = self.target_class_index
        if class_idx is not None:
            # Restrict every statistic to the column of the requested class.
            tp, fp, fn, tn = (stat[:, class_idx] for stat in (tp, fp, fn, tn))
        return precision(
            tp=tp,
            fp=fp,
            fn=fn,
            tn=tn,
            reduction=self.reduction,
            class_weights=self.class_weights,
            zero_division=self.zero_division,
        )
class Recall(SegmentationMetric):
    """
    Recall metric for segmentation tasks.

    Args:
        mode (str): Either 'binary', 'multiclass' or 'multilabel'. Default is 'binary'.
        reduction (str, optional): How to aggregate the metric between classes and images:
            'micro', 'macro', 'weighted', 'micro-imagewise', 'macro-imagewise', 'weighted-imagewise'.
            Default is 'macro-imagewise'. Reference link:
            https://smp.readthedocs.io/en/latest/metrics.html#segmentation_models_pytorch.metrics.functional.recall
        activation (torch.nn.Module, optional): Activation function applied to the model output. Default is None.
        ignore_index (int, optional): Target value that is ignored and does not contribute to the metric. Default is None.
        threshold (float, optional): Threshold value for binarizing the output. Default is None.
        num_classes (int, optional): Number of classes for the metric calculation. Default is None.
        class_weights (torch.Tensor, optional): Manual rescaling weight given to each class. Default is None.
        target_class_index (int, optional): Class index for which to compute the recall. Default is None.
        zero_division (float): Value to return when there is a zero division. Default is 1.0.
        callable (callable, optional): Function applied to the output and target before metric calculation. Default is None.
    """

    def __init__(
        self,
        mode: str = "binary",
        reduction: str = "macro-imagewise",
        activation=None,
        ignore_index=None,
        threshold=None,
        num_classes=None,
        class_weights=None,
        target_class_index=None,
        zero_division=1.0,
        callable=None,
    ):
        # Forward the complete configuration to the base class unchanged.
        settings = dict(
            mode=mode,
            reduction=reduction,
            activation=activation,
            ignore_index=ignore_index,
            threshold=threshold,
            num_classes=num_classes,
            class_weights=class_weights,
            target_class_index=target_class_index,
            zero_division=zero_division,
            callable=callable,
        )
        super().__init__(**settings)

    def forward(
        self,
        output: Union[torch.LongTensor, torch.FloatTensor],
        target: torch.LongTensor,
    ):
        """Return the recall of ``output`` against ``target``."""
        # Each statistic is indexed as [image, class] below.
        tp, fp, fn, tn = self._get_stats(output, target)
        class_idx = self.target_class_index
        if class_idx is not None:
            # Restrict every statistic to the column of the requested class.
            tp, fp, fn, tn = (stat[:, class_idx] for stat in (tp, fp, fn, tn))
        return recall(
            tp=tp,
            fp=fp,
            fn=fn,
            tn=tn,
            reduction=self.reduction,
            class_weights=self.class_weights,
            zero_division=self.zero_division,
        )
class F1Score(SegmentationMetric):
    """
    F1 score metric for segmentation tasks.

    Args:
        mode (str): Either 'binary', 'multiclass' or 'multilabel'. Default is 'binary'.
        reduction (str, optional): How to aggregate the metric between classes and images,
            e.g. 'micro', 'macro', 'weighted'. Default is 'macro-imagewise'.
        activation (torch.nn.Module, optional): Activation function applied to the model output. Default is None.
        ignore_index (int, optional): Target value that is ignored and does not contribute to the metric. Default is None.
        threshold (float, optional): Threshold value for binarizing the output. Default is None.
        num_classes (int, optional): Number of classes for the metric calculation. Default is None.
        class_weights (torch.Tensor, optional): Manual rescaling weight given to each class. Default is None.
        target_class_index (int, optional): Class index for which to compute the f1 score. Default is None.
        zero_division (float): Value to return when there is a zero division. Default is 1.0.
        callable (callable, optional): Function applied to the output and target before metric calculation. Default is None.
    """

    def __init__(
        self,
        mode: str = "binary",
        reduction: str = "macro-imagewise",
        activation=None,
        ignore_index=None,
        threshold=None,
        num_classes=None,
        class_weights=None,
        target_class_index=None,
        zero_division=1.0,
        callable=None,
    ):
        # Forward the complete configuration to the base class unchanged.
        settings = dict(
            mode=mode,
            reduction=reduction,
            activation=activation,
            ignore_index=ignore_index,
            threshold=threshold,
            num_classes=num_classes,
            class_weights=class_weights,
            target_class_index=target_class_index,
            zero_division=zero_division,
            callable=callable,
        )
        super().__init__(**settings)

    def forward(
        self,
        output: Union[torch.LongTensor, torch.FloatTensor],
        target: torch.LongTensor,
    ):
        """Return the F1 score of ``output`` against ``target``."""
        # Each statistic is indexed as [image, class] below.
        tp, fp, fn, tn = self._get_stats(output, target)
        class_idx = self.target_class_index
        if class_idx is not None:
            # Restrict every statistic to the column of the requested class.
            tp, fp, fn, tn = (stat[:, class_idx] for stat in (tp, fp, fn, tn))
        return f1_score(
            tp=tp,
            fp=fp,
            fn=fn,
            tn=tn,
            reduction=self.reduction,
            class_weights=self.class_weights,
            zero_division=self.zero_division,
        )
class Accuracy(SegmentationMetric):
    """
    Accuracy metric for segmentation tasks.

    Args:
        mode (str): Either 'binary', 'multiclass' or 'multilabel'. Default is 'binary'.
        reduction (str, optional): How to aggregate the metric between classes and images:
            'micro', 'macro', 'weighted', 'micro-imagewise', 'macro-imagewise', 'weighted-imagewise'.
            Default is 'macro-imagewise'. Reference link: https://smp.readthedocs.io/en/latest/metrics.html
        activation (torch.nn.Module, optional): Activation function applied to the model output. Default is None.
        ignore_index (int, optional): Target value that is ignored and does not contribute to the metric. Default is None.
        threshold (float, optional): Threshold value for binarizing the output. Default is None.
        num_classes (int, optional): Number of classes for the metric calculation. Default is None.
        class_weights (torch.Tensor, optional): Manual rescaling weight given to each class. Default is None.
        target_class_index (int, optional): Class index for which to compute the accuracy. Default is None.
        zero_division (float): Value to return when there is a zero division. Default is 1.0.
        callable (callable, optional): Function applied to the output and target before metric calculation. Default is None.
    """

    def __init__(
        self,
        mode: str = "binary",
        reduction: str = "macro-imagewise",
        activation=None,
        ignore_index=None,
        threshold=None,
        num_classes=None,
        class_weights=None,
        target_class_index=None,
        zero_division=1.0,
        callable=None,
    ):
        # Forward the complete configuration to the base class unchanged.
        settings = dict(
            mode=mode,
            reduction=reduction,
            activation=activation,
            ignore_index=ignore_index,
            threshold=threshold,
            num_classes=num_classes,
            class_weights=class_weights,
            target_class_index=target_class_index,
            zero_division=zero_division,
            callable=callable,
        )
        super().__init__(**settings)

    def forward(
        self,
        output: Union[torch.LongTensor, torch.FloatTensor],
        target: torch.LongTensor,
    ):
        """Return the accuracy of ``output`` against ``target``."""
        # Each statistic is indexed as [image, class] below.
        tp, fp, fn, tn = self._get_stats(output, target)
        class_idx = self.target_class_index
        if class_idx is not None:
            # Restrict every statistic to the column of the requested class.
            tp, fp, fn, tn = (stat[:, class_idx] for stat in (tp, fp, fn, tn))
        return accuracy(
            tp=tp,
            fp=fp,
            fn=fn,
            tn=tn,
            reduction=self.reduction,
            class_weights=self.class_weights,
            zero_division=self.zero_division,
        )
class IoUScore(SegmentationMetric):
    """
    Jaccard index (IoU) metric for segmentation.

    Args:
        mode (str): Either 'binary', 'multiclass' or 'multilabel'. Default is 'binary'.
        reduction (str, optional): How to aggregate the metric between classes and images,
            e.g. 'micro', 'macro', 'weighted'. Default is 'macro-imagewise'.
        activation (torch.nn.Module, optional): Activation function applied to the model output. Default is None.
        ignore_index (int, optional): Target value that is ignored and does not contribute to the metric. Default is None.
        threshold (float, optional): Threshold value for binarizing the output. Default is None.
        num_classes (int, optional): Number of classes for the metric calculation. Default is None.
        class_weights (torch.Tensor, optional): Manual rescaling weight given to each class. Default is None.
        target_class_index (int, optional): Class index for which to compute the IoU score. Default is None.
        zero_division (float): Value to return when there is a zero division. Default is 1.0.
        callable (callable, optional): Function applied to the output and target before metric calculation. Default is None.
    """

    def __init__(
        self,
        mode: str = "binary",
        reduction: str = "macro-imagewise",
        activation=None,
        ignore_index=None,
        threshold=None,
        num_classes=None,
        class_weights=None,
        target_class_index=None,
        zero_division=1.0,
        callable=None,
    ):
        # Forward the complete configuration to the base class unchanged.
        settings = dict(
            mode=mode,
            reduction=reduction,
            activation=activation,
            ignore_index=ignore_index,
            threshold=threshold,
            num_classes=num_classes,
            class_weights=class_weights,
            target_class_index=target_class_index,
            zero_division=zero_division,
            callable=callable,
        )
        super().__init__(**settings)

    def forward(
        self,
        output: Union[torch.LongTensor, torch.FloatTensor],
        target: torch.LongTensor,
    ):
        """Return the IoU score of ``output`` against ``target``."""
        # Each statistic is indexed as [image, class] below.
        tp, fp, fn, tn = self._get_stats(output, target)
        class_idx = self.target_class_index
        if class_idx is not None:
            # Restrict every statistic to the column of the requested class.
            tp, fp, fn, tn = (stat[:, class_idx] for stat in (tp, fp, fn, tn))
        return iou_score(
            tp=tp,
            fp=fp,
            fn=fn,
            tn=tn,
            reduction=self.reduction,
            class_weights=self.class_weights,
            zero_division=self.zero_division,
        )