Source code for deepml.losses

import torch
import torch.nn.functional as F


[docs] class JaccardLoss(torch.nn.Module): """Jaccard Loss (Intersection over Union) for segmentation tasks. Computes 1 - IoU as a differentiable loss function for both binary and multiclass segmentation. Attributes: activation: Activation function applied to output logits. Softmax2d for multiclass, Sigmoid for binary. """
[docs] def __init__(self, is_multiclass): """Initializes JaccardLoss with the appropriate activation. Args: is_multiclass: If True, uses Softmax2d activation for multiclass segmentation. Otherwise, uses Sigmoid for binary segmentation. """ super(JaccardLoss, self).__init__() if is_multiclass: self.activation = torch.nn.Softmax2d() else: self.activation = torch.nn.Sigmoid()
[docs] def forward(self, output, target): """Computes the Jaccard loss between predictions and targets. Args: output: Raw model output logits of shape (N, C, H, W). target: Ground truth tensor of the same shape as output. Returns: Scalar tensor representing 1 - mean(IoU). """ output = self.activation(output) intersection = torch.sum(output * target) union = torch.sum(output) + torch.sum(target) jac = (intersection / (union - intersection + 1e-7)).mean() return 1 - jac
[docs] class RMSELoss(torch.nn.Module): """Root Mean Squared Error loss. Computes sqrt(MSE + eps) to provide a differentiable RMSE loss that avoids numerical instability near zero. Attributes: mse: Underlying MSELoss module. eps: Small epsilon value added before the square root for numerical stability. """
[docs] def __init__(self, eps=1e-6): """Initializes RMSELoss. Args: eps: Small constant for numerical stability. Defaults to 1e-6. """ super().__init__() self.mse = torch.nn.MSELoss() self.eps = eps
[docs] def forward(self, output, target): """Computes the RMSE loss. Args: output: Predicted tensor of arbitrary shape. target: Ground truth tensor of the same shape as output. Returns: Scalar tensor representing sqrt(MSE(output, target) + eps). """ return torch.sqrt(self.mse(output, target) + self.eps)
[docs] class WeightedBCEWithLogitsLoss(torch.nn.Module): """Weighted Binary Cross-Entropy loss with logits. Applies separate weights to positive and negative samples in the binary cross-entropy computation. Attributes: w_p: Weight for positive samples. w_n: Weight for negative samples. """
[docs] def __init__(self, w_p=None, w_n=None): """Initializes WeightedBCEWithLogitsLoss. Args: w_p: Weight applied to the positive class loss term. Defaults to None. w_n: Weight applied to the negative class loss term. Defaults to None. """ super(WeightedBCEWithLogitsLoss, self).__init__() self.w_p = w_p self.w_n = w_n
[docs] def forward(self, logits, labels, epsilon=1e-7): """Computes the weighted binary cross-entropy loss. Args: logits: Raw model output logits of shape (N,) or (N, 1). labels: Binary ground truth labels of shape (N,). epsilon: Small constant to avoid log(0). Defaults to 1e-7. Returns: Scalar tensor representing the weighted BCE loss. """ ps = torch.sigmoid(logits.squeeze()) loss_pos = -1 * torch.mean(self.w_p * labels * torch.log(ps + epsilon)) loss_neg = -1 * torch.mean( self.w_n * (1 - labels) * torch.log((1 - ps) + epsilon) ) loss = loss_pos + loss_neg return loss
[docs] class ContrastiveLoss(torch.nn.Module): """Contrastive loss for siamese networks. Encourages embeddings of similar pairs to be close together and embeddings of dissimilar pairs to be at least margin apart. Attributes: margin: Minimum distance margin between negative pairs. distance_func: Optional custom distance function. If None, pairwise Euclidean distance is used. label_transform: Optional transformation applied to target labels before loss computation. """
[docs] def __init__(self, margin=2.0, distance_func=None, label_transform=None): """Initializes ContrastiveLoss. Args: margin: The distance margin between positive and negative class. Defaults to 2.0. distance_func: Custom distance function to use. If None, Euclidean pairwise distance is used. Defaults to None. label_transform: Transformation function to apply on the target label, e.g., lambda label: label[:, 0]. Defaults to None. """ super(ContrastiveLoss, self).__init__() self.margin = margin self.distance_func = distance_func self.label_transform = label_transform
[docs] def forward(self, embeddings: torch.Tensor, label: torch.Tensor): """Computes the contrastive loss for a pair of embeddings. Args: embeddings: A tuple of two tensors (embeddings1, embeddings2), each of shape (N, D) where D is the embedding dimension. label: Tensor of shape (N,). A value of 1 indicates a positive (similar) pair; 0 indicates a negative (dissimilar) pair. Returns: Scalar tensor representing the mean contrastive loss. """ embeddings1, embeddings2 = embeddings distance = ( self.distance_func(embeddings) if self.distance_func else F.pairwise_distance(embeddings1, embeddings2) ) label = self.label_transform(label) if self.label_transform else label pos = label * torch.pow(distance, 2) neg = (1 - label) * torch.pow(torch.clamp(self.margin - distance, min=0.0), 2) loss_contrastive = torch.mean(pos + neg) return loss_contrastive
[docs] class AngularPenaltySMLoss(torch.nn.Module): """Angular Penalty Softmax Loss for deep face recognition. Implements three angular margin-based softmax losses: - **ArcFace**: Additive angular margin loss. See `ArcFace <https://arxiv.org/abs/1801.07698>`_. - **SphereFace**: Multiplicative angular margin loss. See `SphereFace <https://arxiv.org/abs/1704.08063>`_. - **CosFace**: Additive cosine margin loss. See `CosFace <https://arxiv.org/abs/1801.05599>`_. Attributes: s: Scaling factor for the logits. m: Angular or cosine margin penalty. loss_type: One of 'arcface', 'sphereface', or 'cosface'. in_features: Size of the input feature vector. out_features: Number of output classes. fc: Fully connected layer mapping input features to class logits (without bias). eps: Small epsilon for numerical stability in acos clamping. """
[docs] def __init__( self, in_features, out_features, loss_type="arcface", eps=1e-7, s=None, m=None ): """Initializes AngularPenaltySMLoss. Args: in_features: Dimensionality of the input feature embeddings. out_features: Number of target classes. loss_type: Type of angular penalty loss. Must be one of 'arcface', 'sphereface', or 'cosface'. Defaults to 'arcface'. eps: Small constant for numerical stability when clamping values for acos. Defaults to 1e-7. s: Scaling factor for logits. If None, uses the default for the chosen loss type (64.0 for arcface/sphereface, 30.0 for cosface). Defaults to None. m: Margin penalty. If None, uses the default for the chosen loss type (0.5 for arcface, 1.35 for sphereface, 0.4 for cosface). Defaults to None. Raises: AssertionError: If loss_type is not one of the supported types. """ super(AngularPenaltySMLoss, self).__init__() loss_type = loss_type.lower() assert loss_type in ["arcface", "sphereface", "cosface"] if loss_type == "arcface": self.s = 64.0 if not s else s self.m = 0.5 if not m else m if loss_type == "sphereface": self.s = 64.0 if not s else s self.m = 1.35 if not m else m if loss_type == "cosface": self.s = 30.0 if not s else s self.m = 0.4 if not m else m self.loss_type = loss_type self.in_features = in_features self.out_features = out_features self.fc = torch.nn.Linear(in_features, out_features, bias=False) self.eps = eps
[docs] def forward(self, x, labels): """Computes the angular penalty softmax loss. Args: x: Input feature embeddings of shape (N, in_features). labels: Ground truth class labels of shape (N,), with values in the range [0, out_features). Returns: Scalar tensor representing the negative mean log probability. Raises: AssertionError: If input and labels have mismatched batch sizes, or if labels contain values outside the valid range. """ assert len(x) == len(labels) assert torch.min(labels) >= 0 assert torch.max(labels) < self.out_features for W in self.fc.parameters(): W = F.normalize(W, p=2, dim=1) x = F.normalize(x, p=2, dim=1) wf = self.fc(x) if self.loss_type == "cosface": numerator = self.s * (torch.diagonal(wf.transpose(0, 1)[labels]) - self.m) if self.loss_type == "arcface": numerator = self.s * torch.cos( torch.acos( torch.clamp( torch.diagonal(wf.transpose(0, 1)[labels]), -1.0 + self.eps, 1 - self.eps, ) ) + self.m ) if self.loss_type == "sphereface": numerator = self.s * torch.cos( self.m * torch.acos( torch.clamp( torch.diagonal(wf.transpose(0, 1)[labels]), -1.0 + self.eps, 1 - self.eps, ) ) ) excl = torch.cat( [ torch.cat((wf[i, :y], wf[i, y + 1 :])).unsqueeze(0) for i, y in enumerate(labels) ], dim=0, ) denominator = torch.exp(numerator) + torch.sum(torch.exp(self.s * excl), dim=1) L = numerator - torch.log(denominator) return -torch.mean(L)