import importlib.resources as resources
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
# Locate the OpenSans-Light TrueType file inside the "deepml.resources.fonts" package.
font_resource = resources.files("deepml.resources.fonts").joinpath("OpenSans-Light.ttf")
# Module-level font (size 16), loaded once at import time and used by create_text_image.
# NOTE(review): str(font_resource) presumably assumes the package is installed as
# regular files on disk (not inside a zip archive) - confirm.
FONT = ImageFont.truetype(str(font_resource), 16)
[docs]
def create_text_image(text, img_size=(224, 224), text_color="black"):
    """
    Creates a white RGB image with the given text drawn at its center.

    :param text: text to render on the image
    :param img_size: (width, height) of the image in pixels, defaults to (224, 224)
    :param text_color: fill color of the text, defaults to "black"
    :return: PIL image with the text rendered in the module-level FONT
    """
    image = Image.new("RGB", img_size, color=(255, 255, 255))
    img_width, img_height = img_size
    draw = ImageDraw.Draw(image)

    if hasattr(draw, "textbbox"):
        # ImageDraw.textsize() was removed in Pillow 10; textbbox() is its
        # replacement and returns (left, top, right, bottom) relative to the
        # anchor point passed in.
        left, top, right, bottom = draw.textbbox(
            (0, 0), text, font=FONT, align="center"
        )
        text_width, text_height = right - left, bottom - top
    else:
        # Older Pillow releases that do not provide textbbox().
        text_width, text_height = draw.textsize(text, font=FONT)
        left, top = 0, 0

    # Subtract the bounding-box offset so the rendered glyphs (not the anchor
    # point) end up centered in the image.
    position = (
        (img_width - text_width) / 2 - left,
        (img_height - text_height) / 2 - top,
    )
    draw.text(
        position,
        text,
        fill=text_color,
        align="center",
        font=FONT,
    )
    return image
[docs]
def get_random_samples_batch_from_loader(loader, samples=None):
    """
    Returns a collated batch of randomly chosen samples from the loader's dataset.

    :param loader: data loader exposing ``dataset``, ``batch_size`` and ``collate_fn``
    :param samples: no. of samples to draw, defaults to the loader's batch size
    :return: the sampled items combined by ``loader.collate_fn``
    :raises ValueError: if the loader has length 0
    """
    if not len(loader):
        raise ValueError("Loader is empty")

    source = loader.dataset
    # Fall back to the loader's own batch size when no explicit count is given.
    if samples is None:
        num_samples = loader.batch_size
    else:
        num_samples = samples

    picked: list = get_random_samples_batch_from_dataset(source, num_samples)
    return loader.collate_fn(picked)
[docs]
def get_random_samples_batch_from_dataset(dataset, samples=8) -> list:
    """
    Draws random items from a dataset, sampling indexes with replacement.

    :param dataset: torch.utils.data.Dataset or any object supporting len() and integer indexing
    :param samples: no. of samples to return, defaults to 8
    :return: list of samples from the dataset
    :raises ValueError: if the dataset has no items
    """
    total = len(dataset)
    if total == 0:
        raise ValueError("Dataset is empty")

    # Indexes are drawn uniformly from [0, total), so repeats are possible.
    chosen = np.random.randint(0, total, samples)
    return [dataset[position] for position in chosen]
[docs]
def blend(
    image: torch.Tensor, mask: torch.Tensor, alpha: float = 0.6, beta: float = 0.4
) -> torch.Tensor:
    """
    Blends a batch of images with a batch of masks as ``image * alpha + mask * beta``.

    :param image: torch.Tensor of size BCHW (e.g. C=1 for grayscale, C=3 for RGB)
    :param mask: torch.Tensor of size BHW, B1HW or BCHW; a mask without a channel
                 dimension, or with a single channel, is expanded to the image's channels
    :param alpha: blending factor for the image
    :param beta: blending factor for the mask
    :return: torch.Tensor of dtype uint8 with the same shape as the image,
             values clipped to the range [0, 255]
    :raises AssertionError: if the image is not 4D, or if the image and mask shapes
                            still differ after the mask has been expanded
    """
    assert image.ndim == 4, "Image must be a 4D tensor of size BCHW"
    channels = image.shape[1]

    if mask.ndim == 3:  # mask dim is B, H, W
        # Add the missing channel dimension -> B, 1, H, W
        mask = mask.unsqueeze(1)

    if mask.ndim == 4 and mask.shape[1] == 1 and channels != 1:
        # Expand the single-channel mask to the image's channel count.
        # -1 keeps a dimension unchanged; expand() returns a view and does not
        # allocate new memory, whereas repeat() would copy the data.
        mask = mask.expand(-1, channels, -1, -1)

    assert (
        image.shape == mask.shape
    ), f"Input image shape {image.shape} and mask shape {mask.shape} must match"

    return (image * alpha + mask * beta).clip(0, 255).to(torch.uint8)