Source code for deepml.transforms
import numpy as np
import torch
import torchvision
from deepml import constants
[docs]
class AlbumentationTorchTranforms:
    """
    Composition of albumentations augmentations and torch transforms.

    The albumentations transforms are applied first, to both the image and
    the mask. The torch transforms (if any) are then applied to the image
    only, never to the target mask. If the image is still not a
    ``torch.Tensor`` after that, it is converted with
    ``torchvision.transforms.ToTensor()``.
    """

    def __init__(self, albu_transforms=None, torch_transforms=None):
        """
        :param albu_transforms: optional callable invoked as
            ``albu_transforms(image=..., mask=...)`` and expected to return a
            mapping with ``"image"`` and ``"mask"`` keys.
        :param torch_transforms: optional callable applied to the image only.
        """
        super().__init__()
        self.albu_transforms = albu_transforms
        self.to_tensor = torchvision.transforms.ToTensor()
        self.torch_transforms = torch_transforms

    def __call__(self, image, mask):
        """
        Transform one ``(image, mask)`` pair.

        :param image: input image; anything that is not already an
            ``np.ndarray`` (e.g. a ``PIL.Image``) is converted with
            ``np.array``.
        :param mask: target mask, converted the same way as ``image``.
        :return: tuple ``(image, mask)``; ``image`` is a ``torch.Tensor`` and
            ``mask`` is a ``torch.float32`` tensor.
        """
        if not isinstance(image, np.ndarray):
            image = np.array(image)

        if not isinstance(mask, np.ndarray):
            mask = np.array(mask)

        if self.albu_transforms is not None:
            augmented = self.albu_transforms(image=image, mask=mask)
            image, mask = augmented["image"], augmented["mask"]

        if self.torch_transforms is not None:
            image = self.torch_transforms(image)

        if not isinstance(image, torch.Tensor):
            image = self.to_tensor(image)

        # The augmentation pipeline may already have produced a tensor mask;
        # torch.from_numpy only accepts numpy arrays.
        if not isinstance(mask, torch.Tensor):
            mask = torch.from_numpy(mask)

        # torch.Tensor has no ``astype``; ``.to(dtype)`` is the tensor
        # equivalent of numpy's dtype cast.
        mask = mask.to(torch.float32)

        return image, mask
[docs]
class ImageInverseTransform:
    """Inverse of a mean / std_dev image normalisation.

    Calling an instance returns ``image_batch * std + mean``, with ``mean``
    and ``std`` indexed as ``[:, None, None]`` so that they broadcast over the
    trailing axes. Accepts image_batch in #B, #C, #H, #W order.
    """

    def __init__(self, mean, std):
        """Store ``mean`` and ``std`` as tensors built with ``torch.tensor``."""
        super().__init__()
        self.mean = torch.tensor(mean)
        self.std = torch.tensor(std)

    def __call__(self, image_batch):
        """Return the de-normalised version of ``image_batch``."""
        target_device = image_batch.device

        # The statistics are moved to the batch's device and stored back on
        # the instance.
        self.mean = self.mean.to(target_device)
        self.std = self.std.to(target_device)

        scale = self.std[:, None, None]
        shift = self.mean[:, None, None]
        return image_batch * scale + shift
[docs]
class ImageNetInverseTransform(ImageInverseTransform):
    """
    Inverse transform using the ImageNet statistics from ``constants``.

    Accepts image_batch in #B, #C, #H, #W order.
    """

    def __init__(self):
        """Initialise the parent with ``constants.IMAGENET_MEAN`` and ``constants.IMAGENET_STD``."""
        imagenet_mean = constants.IMAGENET_MEAN
        imagenet_std = constants.IMAGENET_STD
        super().__init__(imagenet_mean, imagenet_std)
[docs]
class DivideBy255:
    """
    Callable that divides its input by 255.
    """

    def __call__(self, image_batch):
        """Return ``image_batch / 255``."""
        divisor = 255
        scaled = image_batch / divisor
        return scaled
[docs]
class MulticlassSegmentationTargetTransform:
    """
    Converts categorical class index tensor into one-hot vector required for multiclass segmentation.

    For a target of shape (H, W), the result has shape (num_classes, H, W)
    and dtype ``torch.float32``; channel ``i`` is 1.0 where the target equals
    ``i`` and 0.0 elsewhere.
    """

    def __init__(self, num_classes=None):
        """
        :param num_classes: number of classes, i.e. number of output
            channels. May be omitted if ``num_classes`` is supplied as an
            attribute some other way before the transform is called.
        """
        # Assign only when given, so an existing ``num_classes`` attribute
        # (e.g. defined on a subclass) is not shadowed by ``None``.
        if num_classes is not None:
            self.num_classes = num_classes

    def __call__(self, target):
        """
        :param target: 2-D (H, W) tensor of class indices.
        :return: ``torch.float32`` tensor of shape (num_classes, H, W).
        :raises AttributeError: if ``num_classes`` has not been set.
        """
        assert target.ndim == 2  # H,W

        num_classes = getattr(self, "num_classes", None)
        if num_classes is None:
            raise AttributeError(
                "num_classes is not set; construct the transform as "
                "MulticlassSegmentationTargetTransform(num_classes=...)"
            )

        class_masks = [target == class_index for class_index in range(num_classes)]
        return torch.stack(class_masks).to(torch.float32)