Neural network training environment (including various MLOps tools). PyTorch Lightning and Hydra serve as the foundation of this template. It is designed to streamline experimentation, foster modularity, and simplify tracking and reproducibility:
✅ Minimal boilerplate code (easily add new models, datasets, tasks, experiments, and different accelerator configurations).
✅ Logging experiments to one place for easier comparison of performance metrics.
✅ Hyperparameter search integration.
# clone project
git clone https://github.com/lus105/DeepTrainer.git
# change directory
cd DeepTrainer
# update conda
conda update -n base conda
# create conda environment and install dependencies
conda env create -f environment.yaml -n DeepTrainer
# activate conda environment
conda activate DeepTrainer
Train an MNIST model with the default configuration (to check that the environment is set up properly):
# train on CPU (mnist dataset)
python src/train.py trainer=cpu
# train on GPU (mnist dataset)
python src/train.py trainer=gpu
The structure of a machine learning project can vary depending on its specific requirements and goals, as well as on the tools and frameworks being used. However, here is a general outline of a common directory structure:
- src/
- data/
- logs/
- tests/
- some additional directories, like notebooks/, docs/, etc.
In this project, the directory structure looks like:
├── configs <- Hydra configuration files
│ ├── callbacks <- Callbacks configs
│ ├── data <- Datamodule configs
│ ├── debug <- Debugging configs
│ ├── experiment <- Experiment configs
│ ├── extras <- Extra utilities configs
│ ├── hparams_search <- Hyperparameter search configs
│ ├── hydra <- Hydra settings configs
│ ├── local <- Local configs
│ ├── logger <- Logger configs
│ ├── model <- Model configs
│ ├── paths <- Project paths configs
│ ├── trainer <- Trainer configs
│ │
│ ├── eval.yaml <- Main config for evaluation
│ └── train.yaml <- Main config for training
│
├── data <- Project data
├── docs <- Resources, additional readme files
├── logs <- Logs generated by hydra, lightning loggers, etc.
├── notebooks <- Jupyter notebooks.
├── scripts <- Shell/bat scripts
│
├── src <- Source code
│ ├── data <- Lightning datamodules
│ ├── models <- Lightning modules
│ ├── utils <- Utility scripts
│ │
│ ├── eval.py <- Run evaluation
│ └── train.py <- Run training
│
├── tests <- Tests of any kind
│
├── .dockerignore <- List of files ignored by docker
├── .env.example <- File for environment variables
├── .gitignore <- List of files ignored by git
├── .pre-commit-config.yaml <- Pre-commit hooks for code formatting
├── Dockerfile <- Dockerfile
├── environment.yaml <- Conda environment creation file
├── pyproject.toml <- Configuration options for testing and linting
├── README.md <- Readme file
└── requirements.txt <- File for installing python dependencies
Configuration
This part of the diagram illustrates how configuration files (train.yaml, eval.yaml, model.yaml, etc.) are used to manage different aspects of the project, such as data preprocessing, model parameters, and training settings.
Hydra Loader
The diagram shows how Hydra loads all configuration files and combines them into a single configuration object (DictConfig). This unified configuration object simplifies the management of settings across different modules and aspects of the project, such as data handling, model specifics, callbacks, logging, and the training process.
Train/eval script
This section represents the operational part of the project. The scripts train.py and eval.py are used for training and evaluating the model, respectively.
- DictConfig: the combined configuration object passed to these scripts, guiding the instantiation of the subsequent components.
- LightningDataModule: manages data loading and processing specific to the training, validation, testing, and predicting phases.
- LightningModule (model): defines the model, including the computation that transforms inputs into outputs, loss computation, and metrics.
- Callbacks: provide a way to insert custom logic into the training loop, such as model checkpointing, early stopping, etc.
- Logger: handles the logging of training, testing, and validation metrics for monitoring progress.
- Trainer: the central object in PyTorch Lightning that orchestrates the training process, leveraging all the other components.

The trainer uses the model, data module, logger, and callbacks to execute the training/evaluation process through the trainer.fit/test/predict methods, integrating all the configuration settings specified through Hydra.
- Write your PyTorch Lightning DataModule (see example in data/mnist_datamodule.py)
- Write your PyTorch Lightning Module (see examples in models/mnist_module.py)
- Fill in your configs, in particular create experiment/run configs
- Run experiments:
Run training with a chosen experiment config:
python src/train.py experiment=experiment_name
Use hyperparameter search, for example with the Optuna sweeper via Hydra:
# using Hydra multirun mode
python src/train.py -m hparams_search=mnist_optuna
Execute multiple runs with manually specified config parameter values:
python src/train.py -m logger=csv model.optimizer.weight_decay=0.,0.00001,0.0001
- Run evaluation with different checkpoints or run prediction on a custom dataset for additional analysis
At the start, you need to create a PyTorch Dataset for your task. It has to implement the `__getitem__` and `__len__` methods. See more details in the PyTorch documentation.
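As a minimal illustration (the class name and in-memory data layout below are hypothetical, not part of the template):

```python
import torch
from torch.utils.data import Dataset


class InMemoryDataset(Dataset):
    """Hypothetical dataset returning (input, target) pairs from in-memory tensors."""

    def __init__(self, inputs: torch.Tensor, targets: torch.Tensor):
        self.inputs = inputs
        self.targets = targets

    def __len__(self) -> int:
        # Number of samples in the dataset.
        return len(self.targets)

    def __getitem__(self, idx: int) -> tuple[torch.Tensor, torch.Tensor]:
        # Return a single (input, target) pair by index.
        return self.inputs[idx], self.targets[idx]
```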
Then, you need to create a DataModule using the PyTorch Lightning DataModule API. By default, the API has the following methods:
- `prepare_data` (optional): perform data operations on CPU via a single process, like loading and preprocessing data, etc.
- `setup` (optional): perform data operations on every GPU, like train/val/test splits, creating datasets, etc.
- `train_dataloader`: used to generate the training dataloader
- `val_dataloader`: used to generate the validation dataloader
- `test_dataloader`: used to generate the test dataloader
- `predict_dataloader` (optional): used to generate the prediction dataloader
See examples of LightningDataModule in src/data and configs/data.
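A simplified sketch of such a datamodule, loosely modelled on the MNIST example (assuming the lightning 2.x import path and torchvision being available; the real implementation in src/data is more complete):

```python
from lightning import LightningDataModule
from torch.utils.data import DataLoader, random_split
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor


class MNISTDataModule(LightningDataModule):
    def __init__(self, data_dir: str = "data/", batch_size: int = 64):
        super().__init__()
        self.data_dir = data_dir
        self.batch_size = batch_size

    def prepare_data(self) -> None:
        # Download once, on a single process.
        MNIST(self.data_dir, train=True, download=True)
        MNIST(self.data_dir, train=False, download=True)

    def setup(self, stage: str | None = None) -> None:
        # Create train/val/test splits on every device.
        full = MNIST(self.data_dir, train=True, transform=ToTensor())
        self.train_set, self.val_set = random_split(full, [55_000, 5_000])
        self.test_set = MNIST(self.data_dir, train=False, transform=ToTensor())

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self.train_set, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self) -> DataLoader:
        return DataLoader(self.val_set, batch_size=self.batch_size)

    def test_dataloader(self) -> DataLoader:
        return DataLoader(self.test_set, batch_size=self.batch_size)
```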
Next, create a LightningModule using the PyTorch Lightning LightningModule API. The minimal API has the following methods:
- `forward`: use for inference only (separate from `training_step`)
- `training_step`: the complete training loop
- `validation_step`: the complete validation loop
- `test_step`: the complete test loop
- `predict_step`: the complete prediction loop
- `configure_optimizers`: define optimizers and LR schedulers
Also, you can override optional methods for each step to perform additional logic:
- `training_step_end`: training step end operations
- `training_epoch_end`: training epoch end operations
- `validation_step_end`: validation step end operations
- `validation_epoch_end`: validation epoch end operations
- `test_step_end`: test step end operations
- `test_epoch_end`: test epoch end operations
See examples of LightningModule in src/models and configs/model.
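A bare-bones sketch of such a module (hypothetical architecture and metric names; see src/models for the template's actual MNIST module):

```python
import torch
from lightning import LightningModule


class SimpleClassifier(LightningModule):
    def __init__(self, in_features: int = 784, num_classes: int = 10, lr: float = 1e-3):
        super().__init__()
        self.save_hyperparameters()
        self.net = torch.nn.Linear(in_features, num_classes)
        self.criterion = torch.nn.CrossEntropyLoss()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Inference-only path, kept separate from training_step.
        return self.net(x.flatten(1))

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = self.criterion(self(x), y)
        self.log("train/loss", loss)
        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        self.log("val/loss", self.criterion(self(x), y))

    def test_step(self, batch, batch_idx):
        x, y = batch
        self.log("test/loss", self.criterion(self(x), y))

    def configure_optimizers(self):
        # Single optimizer; an LR scheduler could be returned alongside it.
        return torch.optim.Adam(self.parameters(), lr=self.hparams.lr)
```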
Training loop in the template contains the following stages:
- Instantiating the LightningDataModule
- Instantiating the LightningModule
- Instantiating callbacks
- Instantiating loggers
- Instantiating the Trainer
- Logging hyperparameters
- Training the model
- Testing the best model
See more details in training loop and configs/train.yaml.
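Conceptually, these stages map onto a training entry point along the following lines (a condensed sketch rather than the exact code in src/train.py; the config groups data, model, callbacks, logger, and trainer are assumed to match the structure shown above):

```python
import hydra
from lightning import LightningDataModule, LightningModule, Trainer
from omegaconf import DictConfig, OmegaConf


def train(cfg: DictConfig) -> None:
    # Instantiate each component from its corresponding Hydra config group.
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
    model: LightningModule = hydra.utils.instantiate(cfg.model)
    callbacks = [hydra.utils.instantiate(cb) for cb in cfg.callbacks.values()]
    loggers = [hydra.utils.instantiate(lg) for lg in cfg.logger.values()]
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=loggers)

    # Log the resolved config as hyperparameters.
    for lg in loggers:
        lg.log_hyperparams(OmegaConf.to_container(cfg, resolve=True))

    # Train, then test with the best checkpoint.
    trainer.fit(model=model, datamodule=datamodule)
    trainer.test(model=model, datamodule=datamodule, ckpt_path="best")
```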
Evaluation loop in the template contains the following stages:
- Instantiating the LightningDataModule
- Instantiating the LightningModule
- Instantiating callbacks
- Instantiating loggers
- Instantiating the Trainer
- Logging hyperparameters
- Evaluating the model or predicting
See more details in evaluation loop and configs/eval.yaml.
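The evaluation script follows the same pattern, only it loads an existing checkpoint and calls trainer.test (or trainer.predict) instead of fit. A hedged sketch, assuming the checkpoint path is provided via a ckpt_path config key:

```python
import hydra
from lightning import LightningDataModule, LightningModule, Trainer
from omegaconf import DictConfig


def evaluate(cfg: DictConfig) -> None:
    # Instantiate the same building blocks as in training.
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)
    model: LightningModule = hydra.utils.instantiate(cfg.model)
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer)

    # Evaluate a previously trained checkpoint.
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)
```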
PyTorch Lightning has a lot of built-in callbacks, which can be used just by adding them to callbacks config. See examples in callbacks config folder.
By default, the template contains the following callbacks:
- Model Checkpoint
- Early Stopping
- Model Summary
- Rich Progress Bar
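For reference, these defaults correspond to the following plain PyTorch Lightning callbacks; the sketch below constructs them directly in Python (the monitored metric name and thresholds are illustrative, the template configures them via YAML instead):

```python
from lightning import Trainer
from lightning.pytorch.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    RichModelSummary,
    RichProgressBar,
)

callbacks = [
    ModelCheckpoint(monitor="val/loss", mode="min", save_top_k=1),  # keep the best checkpoint
    EarlyStopping(monitor="val/loss", mode="min", patience=10),     # stop when val loss stops improving
    RichModelSummary(max_depth=1),                                  # print a model summary
    RichProgressBar(),                                              # rich-style progress bar
]

trainer = Trainer(callbacks=callbacks)
```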
Hydra + OmegaConf provide a flexible and efficient configuration management system that allows you to dynamically create hierarchical configurations by composition and override them through config files and the command line.
These powerful tools offer a simple and efficient way to manage and organize the various configurations in one place, and to construct complex configuration structures without limits, which can be essential in machine learning projects.
All of that enables you to easily switch between parameters and try different configurations without having to manually update the code.
The `hydra.main` decorator is used to load the Hydra config when launching the pipeline. Here the config is parsed by the Hydra grammar parser, merged, composed, and passed to the pipeline's main function.
import hydra
from omegaconf import DictConfig, OmegaConf

from src.train import train


@hydra.main(version_base="1.3", config_path=".", config_name="train.yaml")
def main(cfg: DictConfig) -> None:
    print(OmegaConf.to_yaml(cfg))
    train(cfg)


if __name__ == "__main__":
    main()
For instantiating an object from a config, the config should contain a `_target_` key with the class or function name to be instantiated, along with any other parameters to pass to it. Hydra provides `hydra.utils.instantiate()` (and its alias `hydra.utils.call()`) for instantiating objects and calling classes or functions. Prefer `instantiate` for creating objects and `call` for invoking functions.
loss:
  _target_: "torch.nn.CrossEntropyLoss"
metric:
  _target_: "torchmetrics.Accuracy"
  task: "multiclass"
  num_classes: 10
  top_k: 1
Based on such a config, you could instantiate the loss via `loss = hydra.utils.instantiate(config.loss)` and the metric via `metric = hydra.utils.instantiate(config.metric)`.
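As a small self-contained illustration (loading the snippet above from an inline string rather than from the project's config tree):

```python
import hydra
from omegaconf import OmegaConf

yaml_config = """
loss:
  _target_: torch.nn.CrossEntropyLoss
metric:
  _target_: torchmetrics.Accuracy
  task: multiclass
  num_classes: 10
  top_k: 1
"""

config = OmegaConf.create(yaml_config)

# Instantiate the objects described by the config.
loss = hydra.utils.instantiate(config.loss)      # torch.nn.CrossEntropyLoss()
metric = hydra.utils.instantiate(config.metric)  # torchmetrics.Accuracy(task="multiclass", ...)
```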
Hydra supports a few operations from the command line as well:
- Override an existing config value by passing it directly
- Add a new config value, which doesn't exist in the config, by using `+`
- Override a config value if it's already in the config, or add it otherwise, by using `++`
# train the model with the default config
python src/train.py
# train the model with the overridden parameter
python src/train.py model.compile=true
# train the model with the overridden parameter and add a new parameter
python src/train.py model.compile=true ++model.model_repo="repository"
Hydra provides out-of-the-box hyperparameter sweepers: Optuna, Nevergrad, or Ax.
You can define a hyperparameter search by adding a new config file to configs/hparams_search. See the example hyperparameter search config. With this method, there is no need to add extra code; everything is specified in a single configuration file. The only requirement is to return the optimized metric value from the launch file.
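In practice, that means the function decorated with @hydra.main returns the value being optimized, roughly like this (a sketch; it assumes train returns a dict of final metric values, and the metric key name is illustrative):

```python
import hydra
from omegaconf import DictConfig

from src.train import train  # assumed here to return a dict of final metric values


@hydra.main(version_base="1.3", config_path=".", config_name="train.yaml")
def main(cfg: DictConfig) -> float:
    metric_dict = train(cfg)
    # Return the single value the Optuna sweeper optimizes; the key name is illustrative.
    return metric_dict["val/acc_best"]


if __name__ == "__main__":
    main()
```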
Execute the search with:
python src/train.py -m hparams_search=mnist_optuna
The `optimization_results.yaml` file will be available under the `logs/task_name/multirun` folder.
Build docker container:
docker build -t deeptrainer \
--build-arg CUDA_VERSION=12.5.1 \
--build-arg OS_VERSION=22.04 \
--build-arg PYTHON_VERSION=3.11 \
--build-arg USER_ID=$(id -u) \
--build-arg GROUP_ID=$(id -g) \
--build-arg NAME=$(whoami) \
--build-arg WORKDIR_PATH=/test .
Run container:
docker run \
-it \
--rm \
--gpus all \
--name my_deeptrainer_container \
-v host/data:/data \
deeptrainer
To run Ruff as a linter on all files in the project, try any of the following:
ruff check # Lint all files in the current directory (and any subdirectories).
ruff check path/to/code/ # Lint all files in `/path/to/code` (and any subdirectories).
ruff check path/to/code/*.py # Lint all `.py` files in `/path/to/code`.
ruff check path/to/code/to/file.py # Lint `file.py`.
Or, to run Ruff as a formatter:
ruff format # Format all files in the current directory (and any subdirectories).
ruff format path/to/code/ # Format all files in `/path/to/code` (and any subdirectories).
ruff format path/to/code/*.py # Format all `.py` files in `/path/to/code`.
ruff format path/to/code/to/file.py # Format `file.py`.
Tests:
# run all tests
pytest
# run tests from specific file
pytest tests/test_train.py
# run all tests except the ones marked as slow
pytest -k "not slow"