Source code for deepml.base

import abc
import os
from collections import OrderedDict
from typing import Callable, Dict, Optional, Tuple, Union

import torch

from deepml.tasks import Task
from deepml.tracking import MLExperimentLogger


[docs] class BaseLearner(abc.ABC):
[docs] def __init__( self, task: Task, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, lr_scheduler_fn: Optional[ Callable[[torch.optim.Optimizer], torch.optim.lr_scheduler._LRScheduler] ] = None, lr_scheduler_step_policy: str = "epoch", ): assert isinstance(task, Task) assert (lr_scheduler is None) or ( lr_scheduler_fn is None ), "Either lr_scheduler or lr_scheduler_fn can be provided, not both." self._task = task self._model = self._task.model self._model_dir = self._task.model_dir self._model_file_name = self._task.model_file_name self._optimizer = None self._criterion = None self._lr_scheduler = lr_scheduler self._lr_scheduler_fn = lr_scheduler_fn self._lr_scheduler_step_policy = None self.logger = None self.set_optimizer(optimizer) self.set_criterion(criterion) self.set_lr_scheduler_policy(lr_scheduler_step_policy)
[docs] def set_optimizer(self, optimizer: torch.optim.Optimizer): assert isinstance(optimizer, torch.optim.Optimizer) self._optimizer = optimizer
[docs] def set_criterion(self, criterion: torch.nn.Module): assert isinstance(criterion, torch.nn.Module) self._criterion = criterion
[docs] def set_lr_scheduler_policy(self, lr_scheduler_step_policy: str = "epoch"): assert isinstance( lr_scheduler_step_policy, str ) and lr_scheduler_step_policy in ["epoch", "step"] self._lr_scheduler_step_policy = lr_scheduler_step_policy
[docs] @staticmethod def load_optimizer_state(optimizer: torch.optim.Optimizer, state_dict: dict): if "optimizer_state_dict" in state_dict: optimizer.load_state_dict(state_dict["optimizer_state_dict"])
[docs] @staticmethod def load_lr_schedular_state( lr_scheduler: torch.optim.lr_scheduler._LRScheduler, state_dict: dict ): if "scheduler_state_dict" in state_dict: lr_scheduler.load_state_dict(state_dict["scheduler_state_dict"])
[docs] def create_state_dict( self, model: torch.nn.Module, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, epoch: int = -1, train_loss: float = float("inf"), val_loss: float = float("inf"), ): state_dict = { "model_state_dict": model.state_dict(), "optimizer": optimizer.__class__.__name__, "optimizer_state_dict": optimizer.state_dict(), "criterion": self._criterion.__class__.__name__, "epoch": epoch, "train_loss": train_loss, "val_loss": val_loss, } if lr_scheduler is not None: state_dict["scheduler"] = lr_scheduler.__class__.__name__ state_dict["scheduler_state_dict"] = lr_scheduler.state_dict() return state_dict
[docs] def save( self, tag: str, model: torch.nn.Module, optimizer: torch.optim.Optimizer, criterion: torch.nn.Module, lr_scheduler: Optional[torch.optim.lr_scheduler._LRScheduler] = None, epoch: int = -1, train_loss: float = float("inf"), val_loss: float = float("inf"), **kwargs, ): state_dict = self.create_state_dict( model=model, optimizer=optimizer, criterion=criterion, lr_scheduler=lr_scheduler, epoch=epoch, train_loss=train_loss, val_loss=val_loss, ) filepath = f"{os.path.join(self._model_dir, tag)}.pt" torch.save(state_dict, filepath) self.logger.log_model(tag, model, epoch, artifact_path=filepath) return filepath
[docs] @staticmethod def init_metrics(metrics: Dict[str, torch.nn.Module]) -> OrderedDict[str, float]: metrics_dict = OrderedDict({"loss": 0.0}) if metrics is None: return metrics_dict for metric_name, _ in metrics.items(): if metric_name == "loss": raise ValueError("Metric name 'loss' is reserved of criterion") metrics_dict[metric_name] = 0.0 return metrics_dict
[docs] @staticmethod def update_metrics( outputs: torch.Tensor, targets: torch.Tensor, metrics_instance_dict: Dict[str, torch.nn.Module], target_metrics_dict: OrderedDict[str, float], ): if metrics_instance_dict is None: return for metric_name, metric_instance in metrics_instance_dict.items(): target_metrics_dict[metric_name] = metric_instance(outputs, targets)
[docs] @staticmethod def update_metrics_with_simple_moving_average( source_metrics_dict: Dict[str, torch.nn.Module], target_metrics_dict: OrderedDict[str, float], step: int, ): for metric_name, metric_value in source_metrics_dict.items(): target_metrics_dict[metric_name] = target_metrics_dict[metric_name] + ( metric_value.mean().item() - target_metrics_dict[metric_name] ) / float(step)
[docs] @staticmethod def write_metrics_to_logger( metrics_dict: dict, tag: str, global_step: int, logger: MLExperimentLogger, history: dict, ): for name, value in metrics_dict.items(): logger.log_metric(f"{name}/{tag}", value, global_step) history[f"{tag}_{name}"].append(value)
[docs] @staticmethod def write_lr( optimizer, global_step: int, logger: MLExperimentLogger, history: dict ): # Write lr to tensor-board and history dict if len(optimizer.param_groups) == 1: param_group = optimizer.param_groups[0] logger.log_metric("learning_rate", param_group["lr"], global_step) history["learning_rate"].append(param_group["lr"]) else: for index, param_group in enumerate(optimizer.param_groups): logger.log_metric( f"learning_rate/param_group_{index}", param_group["lr"], global_step ) history[f"learning_rate/param_group_{index}"].append(param_group["lr"])
[docs] def log_metrics( self, val_loader: torch.utils.data.DataLoader, train_metrics: dict, val_metrics: dict, metrics_history: dict, epochs_completed: int, logger_img_size: Union[int, Tuple[int, int]], image_inverse_transform: Callable, ): BaseLearner.write_metrics_to_logger( train_metrics, "train", epochs_completed, self.logger, metrics_history, ) BaseLearner.write_metrics_to_logger( val_metrics, "val", epochs_completed, self.logger, metrics_history, ) # write random val images to tensorboard if logger_img_size is not None: self._task.write_prediction_to_logger( "val", val_loader, self.logger, image_inverse_transform, epochs_completed, img_size=logger_img_size, )
[docs] def fit(self, *args, **kwargs): raise NotImplementedError("Subclass should implement this method.")
[docs] def predict(self, *args, **kwargs): raise NotImplementedError("Subclass should implement this method.")