# Source code for modlee.model.callbacks

from functools import partial
import pickle
from typing import Any, Optional
import numpy as np
import logging

import os
import pandas as pd
import torch
from torch.utils.data import Dataset

import lightning.pytorch as pl
from lightning.pytorch import LightningModule, Trainer
from lightning.pytorch.callbacks import Callback
from lightning.pytorch.utilities.types import STEP_OUTPUT

import modlee
from modlee import data_metafeatures, save_run, get_code_text_for_model, save_run_as_json
from modlee import logging, utils as modlee_utils, exp_loss_logger
from modlee.converter import Converter

# Module-level Converter shared by the callbacks below to translate models
# between torch, ONNX, and code/text representations (see LogCodeTextCallback).
modlee_converter = Converter()

from modlee.config import TMP_DIR, MLRUNS_DIR

import mlflow
import json

# Baseline LightningModule instance and the base class's attribute names.
# NOTE(review): neither name is referenced elsewhere in this chunk — presumably
# used to separate user-defined members from Lightning base-class members;
# confirm against the rest of the package before removing.
base_lightning_module = LightningModule()
base_lm_keys = list(LightningModule.__dict__.keys())


class ModleeCallback(Callback):
    """
    Base class for Modlee-specific callbacks.
    """

    def __init__(self) -> None:
        super().__init__()

    def get_input(self, trainer, pl_module):
        """
        Get an input (one element from a batch) from a trainer's dataloader.

        :param trainer: The trainer with the dataloader.
        :param pl_module: The model module, used for loading the data input to the correct device.
        :return: An input from the batch.
        """
        _dataloader = trainer.train_dataloader
        _batch = next(iter(_dataloader))
        # NOTE - how can we generalize to different input schemes?
        # e.g. siamese network with multiple inputs
        # Right now, this makes the assumption that the network
        # uses only the first element of the batch.
        # NOTE - maybe using inspect.signature(pl_module.forward)
        # could help generalize to different forward() calls
        if isinstance(_batch, (list, tuple)):
            _input = _batch[0]
        else:
            _input = _batch
        try:
            # Best-effort device transfer; non-tensor inputs are left as-is.
            # BUG FIX: was a bare `except:`, which also swallows
            # KeyboardInterrupt/SystemExit — narrowed to Exception.
            _input = _input.to(pl_module.device)
        except Exception:
            pass
        return _input
class PushServerCallback(Callback):
    """Pushes the run's assets to the server once training finishes."""

    def on_fit_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Serialize the completed run to JSON, then defer to the base hook."""
        run_path = pl_module.run_path
        # save_run(run_path) was the previous, non-JSON upload path.
        save_run_as_json(run_path)
        return super().on_fit_end(trainer, pl_module)
class LogParamsCallback(Callback):
    """Logs training hyperparameters when the train loop begins."""

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Record the train dataloader's batch size as an MLflow parameter."""
        batch_size = trainer.train_dataloader.batch_size
        mlflow.log_param("batch_size", batch_size)
        return super().on_train_start(trainer, pl_module)
class LogCodeTextCallback(ModleeCallback):
    """
    Callback to log the model as code and text.
    """

    def __init__(self, kwargs_to_cache=None, *args, **kwargs):
        """
        Constructor for LogCodeTextCallback.

        :param kwargs_to_cache: A dictionary of kwargs to cache in the run
            for rebuilding the model. Defaults to an empty dict.
        """
        Callback.__init__(self, *args, **kwargs)
        # BUG FIX: was a mutable default argument ({}), shared across
        # instances; each instance now gets its own dict.
        self.kwargs_cache = {} if kwargs_to_cache is None else kwargs_to_cache

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # Logging is deferred to on_train_start; setup only chains to the base hook.
        # self._log_code_text(trainer=trainer, pl_module=pl_module)
        return super().setup(trainer, pl_module, stage)

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        logging.info("Logging model as code (model_graph.py) and text (model_graph.txt)...")
        self._log_code_text(trainer=trainer, pl_module=pl_module)
        return super().on_train_start(trainer, pl_module)

    def _log_code_text(self, trainer: Trainer, pl_module: LightningModule):
        """
        Log the model as code and text.
        Converts the model through modlee.converter pipelines.

        :param trainer: The trainer that contains the dataloader.
        :param pl_module: The model as a module.
        """
        _get_code_text_for_model = get_code_text_for_model
        code_text = ""
        if _get_code_text_for_model is not None:
            # ==== METHOD 1 ====
            # Save model as code using parsing
            code_text = get_code_text_for_model(pl_module, include_header=True)
            mlflow.log_text(code_text, "model.py")
            # Save variables required to rebuild the model
            pl_module._update_kwargs_cached()
            mlflow.log_dict(self.kwargs_cache, "cached_vars")

            # ==== METHOD 2 ====
            # Save model as code by converting to a graph through ONNX
            input_dummy = self.get_input(trainer, pl_module)
            onnx_model = modlee_converter.torch2onnx(pl_module, input_dummy=input_dummy)
            onnx_text = modlee_converter.onnx2onnx_text(onnx_model)
            mlflow.log_text(onnx_text, "model_graph.txt")
            torch_graph_code = modlee_converter.onnx_text2code(onnx_text)
            mlflow.log_text(torch_graph_code, "model_graph.py")

            # Save model size (in bytes, since as_MB=False)
            model_size = modlee_utils.get_model_size(pl_module, as_MB=False)
            mlflow.log_text(str(model_size), "model_size")
        else:
            logging.warning(
                "Could not access model-text converter, "
                "not logging but continuing experiment"
            )
        if exp_loss_logger.module_available:
            _extract_loss_functions = getattr(
                exp_loss_logger, "extract_loss_functions", None
            )
            if _extract_loss_functions is not None:
                loss_calls = exp_loss_logger.extract_loss_functions(code_text)
                if len(loss_calls) > 0:
                    loss_calls_str = str.join("\n", loss_calls)
                    mlflow.log_text(loss_calls_str, "loss_calls.txt")
                else:
                    # Deliberate no-op: absence of explicit loss calls is not an error.
                    pass
            else:
                # BUG FIX: this warning string was broken in two by extraction;
                # reassembled into a single message.
                logging.warning(
                    "exp_loss_logger has no attribute extract_loss_functions"
                )
        else:
            logging.warning(
                "Could not access exp_loss_logger, "
                "not logging but continuing experiment"
            )
class LogONNXCallback(ModleeCallback):
    """
    Callback for logging the model in its ONNX representations.
    Deprecated, will be combined with LogCodeTextCallback.
    """

    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None:
        # Logging is deferred to on_train_start; setup only chains to the base hook.
        # self._log_onnx(trainer, pl_module)
        return super().setup(trainer, pl_module, stage)

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self._log_onnx(trainer, pl_module)
        # BUG FIX: previously chained to super().on_fit_start() from the
        # on_train_start hook; chain to the matching base hook instead.
        return super().on_train_start(trainer, pl_module)

    def _log_onnx(self, trainer, pl_module):
        """
        Export the model to ONNX and log the file as an MLflow artifact.

        NOTE assumes that the model input is the first element of a batch
        (via ModleeCallback.get_input).

        :param trainer: The trainer that contains the dataloader.
        :param pl_module: The model as a module.
        """
        modlee_utils.safe_mkdir(TMP_DIR)
        data_filename = f"{TMP_DIR}/model.onnx"
        _input = self.get_input(trainer, pl_module)
        # Sanity-check that the model can consume the sampled input
        # before attempting the ONNX export.
        model_output = pl_module.forward(_input)
        torch.onnx.export(
            pl_module,
            _input,
            data_filename,
            export_params=False,
        )
        mlflow.log_artifact(data_filename)
class LogOutputCallback(Callback):
    """
    Callback to log the output metrics for each batch.
    """

    def __init__(self, *args, **kwargs):
        """
        Constructor for LogOutputCallback.
        """
        Callback.__init__(self, *args, **kwargs)
        # Bind the shared handler to both the train and validation hooks.
        self.on_train_batch_end = partial(self._on_batch_end, phase="train")
        self.on_validation_batch_end = partial(self._on_batch_end, phase="val")
        # Per-phase history of raw step outputs.
        self.outputs = {"train": [], "val": []}

    def _on_batch_end(
        self,
        trainer: Trainer,
        pl_module: LightningModule,
        outputs: STEP_OUTPUT,
        batch: Any,
        batch_idx: int,
        phase="train",
    ) -> None:
        """
        Helper function to log output metrics on batch end.
        Currently catches metrics formatted as '{phase}_loss'.

        :param trainer: The trainer.
        :param pl_module: The model as a module.
        :param outputs: The outputs on batch end, automatically passed by the base callback.
        :param batch: The batch, automatically passed by the base callback.
        :param batch_idx: The index of the batch, automatically passed by the base callback.
        :param phase: The phase of training for logging, ["train", "val"]. Defaults to "train".
        """
        if trainer.is_last_batch:
            if isinstance(outputs, dict):
                for output_key, output_value in outputs.items():
                    pl_module.log(output_key, output_value)
            elif isinstance(outputs, list):
                # BUG FIX: the original iterated `for output_idx, output_value
                # in outputs`, attempting to unpack (index, value) pairs from
                # plain list items; use enumerate() to get the index.
                for output_idx, output_value in enumerate(outputs):
                    pl_module.log(f"{phase}_step_output_{output_idx}", output_value)
            elif outputs is not None:
                pl_module.log(f"{phase}_loss", outputs)
        self.outputs[phase].append(outputs)
        return super().on_train_batch_end(trainer, pl_module, outputs, batch, batch_idx)
class DataMetafeaturesCallback(ModleeCallback):
    """
    Callback to calculate and log data meta-features.
    """

    def __init__(self, data_snapshot_size=1e7, DataMetafeatures=None, *args, **kwargs):
        """
        Constructor for the data metafeature callback.

        :param data_snapshot_size: The maximum size (in bytes) of the cached data snapshot.
        :param DataMetafeatures: The DataMetafeatures module. If not provided, will not calculate metafeatures.
        """
        # BUG FIX: the original called Callback.__init__ and then
        # super().__init__(), initializing the base class twice; a single
        # explicit base-class call (preserving *args/**kwargs) suffices.
        Callback.__init__(self, *args, **kwargs)
        self.data_snapshot_size = data_snapshot_size
        if not DataMetafeatures:
            DataMetafeatures = getattr(data_metafeatures, "DataMetafeatures", None)
        self.DataMetafeatures = DataMetafeatures

    def on_train_start(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """Log data metafeatures and the model output size at train start."""
        # Earlier snapshot-based variants, kept for reference:
        # data, targets = self._get_data_targets(trainer)
        # data_snapshots = self._get_snapshots_batched(trainer.train_dataloader)
        # self._save_snapshots_batched(data_snapshots)
        # self._log_data_metafeatures(data, targets)
        logging.info("Logging data metafeatures...")
        self._log_data_metafeatures_dataloader(trainer.train_dataloader)
        self._log_output_size(trainer, pl_module)
        return super().on_train_start(trainer, pl_module)

    def _log_data_metafeatures(self, data, targets=[]) -> None:
        """
        Log the data metafeatures from input data and targets.
        Deprecated in favor of _log_data_metafeatures_dataloader.

        :param data: The input data.
        :param targets: The targets.
        """
        if self.DataMetafeatures:
            if isinstance(data, torch.Tensor):
                # NOTE(review): assumes targets is also a Tensor whenever data
                # is; the empty-list default would raise here — confirm callers.
                data, targets = data.numpy(), targets.numpy()
            data_metafeatures = self.DataMetafeatures(x=data, y=targets)
            mlflow.log_dict(data_metafeatures.data_metafeatures, "data_metafeatures")
        else:
            logging.warning(
                "Could not access data statistics calculation from server, "
                "not logging but continuing experiment"
            )

    def _log_data_metafeatures_dataloader(self, dataloader) -> None:
        """
        Log data metafeatures with a dataloader.

        :param dataloader: The dataloader.
        """
        if self.DataMetafeatures:
            # TODO - use data batch and model to get output size
            data_metafeatures = self.DataMetafeatures(dataloader)
            mlflow.log_dict(data_metafeatures._serializable_stats_rep, "stats_rep")
        else:
            logging.warning("Cannot log data statistics, could not access from server")

    def _log_output_size(self, trainer: Trainer, pl_module: LightningModule) -> None:
        """
        Log the output size of the model.

        :param trainer: The trainer.
        :param pl_module: The model as a module.
        """
        _input = self.get_input(trainer, pl_module)
        try:
            _output = pl_module.forward(_input)
            # Drop the batch dimension; log the per-sample output shape.
            output_shape = list(_output.shape[1:])
            mlflow.log_param("output_shape", output_shape)
        except Exception:
            # Best-effort: some forward() signatures won't accept a raw batch.
            # (Narrowed from a bare `except:` to avoid catching KeyboardInterrupt.)
            logging.warning(
                "Cannot log output shape, could not pass batch through network"
            )

    def _get_data_targets(self, trainer: Trainer):
        """
        Get the data and targets from a trainer's dataloader.

        :param trainer: The trainer.
        :return: The data and targets.
        """
        _dataset = trainer.train_dataloader.dataset
        if isinstance(_dataset, list):
            data = np.array(_dataset)
        elif isinstance(_dataset, torch.utils.data.dataset.IterableDataset):
            data = list(_dataset)
        else:
            if isinstance(_dataset, torch.utils.data.Subset):
                _dataset = _dataset.dataset
            data = _dataset.data
            if isinstance(data, torch.Tensor):
                data = data.numpy()
        self._save_snapshot(data, "data")

        targets = []
        if hasattr(_dataset, "targets"):
            targets = _dataset.targets
            self._save_snapshot(targets, "targets", max_len=len(data))
        return data, targets

    def _save_snapshot(self, data, snapshot_name="data", max_len=None):
        """
        Save a snapshot of data as an .npy MLflow artifact.

        :param data: The data to save.
        :param snapshot_name: The name to save the data.
        :param max_len: The maximum length of the data.
        """
        data = self._get_snapshot(data=data, max_len=max_len)
        modlee_utils.safe_mkdir(TMP_DIR)
        data_filename = f"{TMP_DIR}/{snapshot_name}_snapshot.npy"
        np.save(data_filename, data)
        mlflow.log_artifact(data_filename)

    def _get_snapshot(self, data, max_len=None):
        """
        Get a snapshot of data, truncated to at most data_snapshot_size bytes.

        :param data: The data.
        :param max_len: The maximum length of the snapshot.
        :return: A snapshot of the data.
        """
        if isinstance(data, torch.Tensor):
            data = data.numpy()
        elif not isinstance(data, np.ndarray):
            data = np.array(data)
        if max_len is None:
            data_size = data.nbytes
            # Take a slice that should be no larger than data_snapshot_size bytes.
            max_len = int(
                np.min([(len(data) * self.data_snapshot_size) // data_size, len(data)])
            )
        return data[:max_len]

    def _save_snapshots_batched(self, data_snapshots):
        """
        Save batches of data snapshots as .npy MLflow artifacts.

        :param data_snapshots: A batch of data snapshots.
        """
        modlee_utils.safe_mkdir(TMP_DIR)
        for snapshot_idx, data_snapshot in enumerate(data_snapshots):
            data_filename = f"{TMP_DIR}/snapshot_{snapshot_idx}.npy"
            np.save(data_filename, data_snapshot)
            mlflow.log_artifact(data_filename)

    def _get_snapshots_batched(self, dataloader, max_len=None):
        """
        Get a batch of data snapshots.

        :param dataloader: The dataloader of the data to snapshot.
        :param max_len: The maximum length of the snapshot (currently unused).
        :return: A batch of data snapshots, one array per batch element.
        """
        # Use one batch to determine how many "sub"-snapshots to create.
        _batch = next(iter(dataloader))
        data_snapshot_size = self.data_snapshot_size
        if isinstance(_batch, (list, tuple)):
            n_snapshots = len(_batch)
        else:
            n_snapshots = 1
        data_snapshots = [np.array([])] * n_snapshots

        # Keep appending batches until the combined size reaches the limit.
        # BUG FIX: the original called next(iter(dataloader)) inside the loop,
        # which restarts the iterator each time and repeatedly yields the
        # *first* batch (never terminating for small datasets). Create the
        # iterator once and stop cleanly when it is exhausted.
        batch_ctr = 0
        _batch_iter = iter(dataloader)
        while np.sum([ds.nbytes for ds in data_snapshots]) < data_snapshot_size:
            try:
                _batch = next(_batch_iter)
            except StopIteration:
                break
            if isinstance(_batch, (list, tuple)):
                # Multiple elements in the batch: append to respective sub-snapshots.
                for batch_idx, _subbatch in enumerate(_batch):
                    if str(_subbatch.device) != "cpu":
                        _subbatch = _subbatch.cpu()
                    if data_snapshots[batch_idx].size == 0:
                        data_snapshots[batch_idx] = _subbatch.numpy()
                    else:
                        data_snapshots[batch_idx] = np.vstack(
                            [data_snapshots[batch_idx], _subbatch.numpy()]
                        )
            else:
                if data_snapshots[0].size == 0:
                    data_snapshots[0] = _batch.numpy()
                else:
                    data_snapshots[0] = np.vstack([data_snapshots[0], _batch.numpy()])
            batch_ctr += 1
        return data_snapshots
class LogTransformsCallback(ModleeCallback):
    """
    Logs transforms applied to the dataset, if applied with torchvision.transforms
    """

    def on_train_start(self, trainer, pl_module):
        """Log the dataset's transform pipeline, if present, as a text artifact."""
        dataset = trainer.train_dataloader.dataset
        # BUG FIX: the original checked hasattr(dataset, "transforms") but then
        # logged dataset.transform (no trailing 's'), so the guard did not
        # protect the attribute access. Check and log the same attribute —
        # torchvision datasets expose the user-supplied pipeline as `.transform`.
        if hasattr(dataset, "transform"):
            mlflow.log_text(str(dataset.transform), "transforms.txt")