Source code for modlee.trainer

""" 
Modlee Trainer.
"""
import os
import mlflow
from modlee.utils import last_run_path
from lightning import pytorch as pl

class Trainer(pl.Trainer):
    """
    Trainer that directs checkpoint files to the current run directory.

    Thin wrapper around ``lightning.pytorch.Trainer``. If the caller does not
    pass ``default_root_dir``, it is set to ``os.path.join(last_run_path(),
    'artifacts')`` so that Lightning's default root directory (its fallback
    location for logs and checkpoints) is the ``artifacts`` folder of the run
    returned by ``last_run_path()``. An explicitly supplied
    ``default_root_dir`` -- including ``None`` -- is forwarded unchanged.

    All positional and keyword arguments are passed through to
    ``pl.Trainer.__init__``.
    """

    def __init__(self, *args, **kwargs):
        # Resolve the run directory only when it is actually needed. Callers
        # who choose their own root directory then do not depend on
        # ``last_run_path()`` succeeding (e.g. before any run exists).
        if 'default_root_dir' not in kwargs:
            kwargs['default_root_dir'] = os.path.join(last_run_path(), 'artifacts')
        super().__init__(*args, **kwargs)