import torch
import torch.nn.functional as F
from .commons import (
false_negatives,
false_positives,
multiclass_tp_fp_tn_fn,
true_negatives,
true_positives,
)
class Binarizer(torch.nn.Module):
    """Convert raw model outputs into hard predictions and their probabilities.

    Two input layouts are distinguished:

    * multiclass: a 2-D tensor whose last dimension is larger than 1; softmax
      is applied over dim 1 and the most probable class is selected;
    * binary: everything else; sigmoid is applied element-wise and compared
      against ``threshold``.
    """

    def __init__(self, threshold=0.5):
        """
        Args:
            threshold: probability that a binary output must strictly exceed
                to be labelled ``1``. Not used on the multiclass path.
        """
        super().__init__()
        self.threshold = threshold

    def forward(self, output):
        """Binarize ``output``.

        Args:
            output: tensor of raw (pre-activation) model outputs.

        Returns:
            Tuple ``(indices, probabilities)``.

            * multiclass: ``probabilities`` is the maximum softmax value along
              dim 1 and ``indices`` the position of that maximum.
            * binary: ``probabilities`` is ``sigmoid(output)`` and ``indices``
              is a tensor of the same shape and dtype holding ``1`` where the
              probability is greater than ``threshold`` and ``0`` elsewhere.
        """
        if output.ndim == 2 and output.shape[-1] > 1:
            # multiclass
            probabilities, indices = torch.max(F.softmax(output, dim=1), dim=1)
        else:
            # binary
            probabilities = torch.sigmoid(output)
            indices = torch.zeros_like(probabilities)
            indices[probabilities > self.threshold] = 1
        return indices, probabilities
class Accuracy(torch.nn.Module):
    """Fraction of binarized predictions that are equal to the target."""

    def __init__(self, threshold=0.5):
        """
        Args:
            threshold: decision threshold forwarded to ``Binarizer``.
        """
        super().__init__()
        self.binarize = Binarizer(threshold)

    def forward(self, output, target):
        """Compute the accuracy of ``output`` against ``target``.

        Args:
            output: raw model outputs, passed to ``Binarizer``.
            target: ground-truth labels.

        Returns:
            Scalar float tensor: the mean of the element-wise equality between
            the predicted indices and ``target``.
        """
        indices, _ = self.binarize(output)
        # On the binary path the predictions keep the shape of `output`
        # (e.g. (N, 1)). Comparing that with an (N,) target would silently
        # broadcast to (N, N) and yield a meaningless mean, so align the
        # shapes whenever both hold the same number of elements.
        if (
            torch.is_tensor(target)
            and indices.shape != target.shape
            and indices.numel() == target.numel()
        ):
            indices = indices.reshape(target.shape)
        return (indices == target).float().mean()
class Precision(torch.nn.Module):
    """Precision: ``tp / (tp + fp)``, stabilised by ``epsilon``."""

    def __init__(self, threshold=0.5, epsilon=1e-6):
        """
        Args:
            threshold: decision threshold forwarded to ``Binarizer``.
            epsilon: small constant added to the denominator so the ratio is
                defined when there are no positive predictions.
        """
        super().__init__()
        self.binarize = Binarizer(threshold)
        self.epsilon = epsilon

    def forward(self, output, target):
        """Compute precision of ``output`` against ``target``.

        Args:
            output: raw model outputs, passed to ``Binarizer``.
            target: ground-truth labels.

        Returns:
            ``tp / (tp + fp + epsilon)``. The counts come from the helpers in
            ``.commons``; how they reduce over classes/samples is defined
            there and is not visible in this module.
        """
        indices, _ = self.binarize(output)
        # Use exactly the same multiclass test as Binarizer.forward; checking
        # only `shape[-1] > 1` would send 1-D binary logits of length > 1 down
        # the multiclass path although they were binarized with sigmoid.
        if output.ndim == 2 and output.shape[-1] > 1:
            # multiclass
            tp, fp, _, _ = multiclass_tp_fp_tn_fn(indices, target)
        else:
            # binary
            tp = true_positives(indices, target)
            fp = false_positives(indices, target)
        return tp / (tp + fp + self.epsilon)
class Recall(torch.nn.Module):
    """Recall: ``tp / (tp + fn)``, stabilised by ``epsilon``."""

    def __init__(self, threshold=0.5, epsilon=1e-6):
        """
        Args:
            threshold: decision threshold forwarded to ``Binarizer``.
            epsilon: small constant added to the denominator so the ratio is
                defined when there are no positive targets.
        """
        super().__init__()
        self.binarize = Binarizer(threshold)
        self.epsilon = epsilon

    def forward(self, output, target):
        """Compute recall of ``output`` against ``target``.

        Args:
            output: raw model outputs, passed to ``Binarizer``.
            target: ground-truth labels.

        Returns:
            ``tp / (tp + fn + epsilon)``. The counts come from the helpers in
            ``.commons``; how they reduce over classes/samples is defined
            there and is not visible in this module.
        """
        indices, _ = self.binarize(output)
        # Use exactly the same multiclass test as Binarizer.forward; checking
        # only `shape[-1] > 1` would send 1-D binary logits of length > 1 down
        # the multiclass path although they were binarized with sigmoid.
        if output.ndim == 2 and output.shape[-1] > 1:
            # multiclass
            tp, _, _, fn = multiclass_tp_fp_tn_fn(indices, target)
        else:
            # binary
            tp = true_positives(indices, target)
            fn = false_negatives(indices, target)
        return tp / (tp + fn + self.epsilon)
class FScore(torch.nn.Module):
    """F-beta score: the weighted harmonic mean of precision and recall.

    ``F_beta = (1 + beta^2) * P * R / (beta^2 * P + R)``

    ``beta > 1`` weights recall more heavily, ``beta < 1`` favours precision,
    and ``beta = 1`` gives the usual F1 score.
    """

    def __init__(self, beta=1.0, threshold=0.5, epsilon=1e-6):
        """
        Args:
            beta: relative weight of recall with respect to precision.
            threshold: decision threshold forwarded to ``Binarizer``.
            epsilon: small constant added to every denominator to avoid
                division by zero.
        """
        super().__init__()
        self.beta = beta
        self.binarize = Binarizer(threshold)
        self.epsilon = epsilon

    def forward(self, output, target):
        """Compute the F-beta score of ``output`` against ``target``.

        Args:
            output: raw model outputs, passed to ``Binarizer``.
            target: ground-truth labels.

        Returns:
            The F-beta score built from epsilon-stabilised precision and
            recall. The underlying counts come from the helpers in
            ``.commons``; their reduction is defined there.
        """
        indices, _ = self.binarize(output)
        # Use exactly the same multiclass test as Binarizer.forward; checking
        # only `shape[-1] > 1` would send 1-D binary logits of length > 1 down
        # the multiclass path although they were binarized with sigmoid.
        if output.ndim == 2 and output.shape[-1] > 1:
            # multiclass
            tp, fp, _, fn = multiclass_tp_fp_tn_fn(indices, target)
        else:
            # binary
            tp = true_positives(indices, target)
            fp = false_positives(indices, target)
            fn = false_negatives(indices, target)
        precision = tp / (tp + fp + self.epsilon)
        recall = tp / (tp + fn + self.epsilon)
        beta_sq = self.beta**2
        # beta^2 scales the precision term only. Scaling (precision + recall)
        # as a whole is correct solely for beta == 1. The epsilon keeps the
        # score defined (0 rather than 0/0) when precision and recall are
        # both zero.
        return ((1 + beta_sq) * precision * recall) / (
            beta_sq * precision + recall + self.epsilon
        )
class MCC(torch.nn.Module):
    """
    Matthews correlation coefficient.

    A metric that is useful for imbalanced datasets.
    More information: https://en.wikipedia.org/wiki/Matthews_correlation_coefficient
    """

    def __init__(self, threshold=0.5, epsilon=1e-6):
        """
        Args:
            threshold: decision threshold forwarded to ``Binarizer``.
            epsilon: small constant added to the denominator so the ratio is
                defined when any marginal count is zero.
        """
        super().__init__()
        self.binarize = Binarizer(threshold)
        self.epsilon = epsilon

    def forward(self, output, target):
        """Compute the Matthews correlation coefficient.

        Args:
            output: raw model outputs, passed to ``Binarizer``.
            target: ground-truth labels.

        Returns:
            Float tensor
            ``(tp*tn - fp*fn) / (sqrt((tp+fp)(tp+fn)(tn+fp)(tn+fn)) + epsilon)``.
            The counts come from the helpers in ``.commons``; their reduction
            is defined there.
        """
        indices, _ = self.binarize(output)
        # Use exactly the same multiclass test as Binarizer.forward; checking
        # only `shape[-1] > 1` would send 1-D binary logits of length > 1 down
        # the multiclass path although they were binarized with sigmoid.
        if output.ndim == 2 and output.shape[-1] > 1:
            # multiclass
            tp, fp, tn, fn = multiclass_tp_fp_tn_fn(indices, target)
        else:
            # binary
            tp = true_positives(indices, target)
            tn = true_negatives(indices, target)
            fp = false_positives(indices, target)
            fn = false_negatives(indices, target)
        # Cast the counts to float before multiplying: should the helpers
        # return integer tensors, the four-way product below could overflow.
        # `as_tensor` accepts both tensors and plain numbers and, unlike
        # `torch.tensor`, does not copy-construct from an existing tensor.
        tp, tn, fp, fn = (
            torch.as_tensor(count, dtype=torch.float) for count in (tp, tn, fp, fn)
        )
        numerator = (tp * tn) - (fp * fn)
        denominator = torch.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
        return numerator / (denominator + self.epsilon)