A PyTorch implementation of the original DDPM (Denoising Diffusion Probabilistic Models), with a focus on improving training efficiency. The model is trained on CIFAR-10.
Install PyTorch 2.0+ and FFCV.
To install FFCV without using Conda:
sudo apt install libopencv-dev libturbojpeg-dev
pip install -r requirements.txt
To train the model:
python ddpm.py
Configurations are in ddpm.py, specified by two classes: ModelConfig and TrainerConfig.
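As a rough sketch of how these two classes might be laid out (the field names and defaults below are hypothetical; the actual attributes are defined in ddpm.py):

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    # hypothetical fields: architecture and diffusion-process settings
    image_size: int = 32        # CIFAR-10 resolution
    base_channels: int = 128    # width of the U-Net
    timesteps: int = 1000       # diffusion steps, as in the original DDPM

@dataclass
class TrainerConfig:
    # hypothetical fields: optimization settings
    batch_size: int = 320
    lr: float = 2e-4
    epochs: int = 100
```

Grouping the hyperparameters into dataclasses keeps the training script free of scattered constants.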
We measure wall-clock time and floating-point operations (FLOPs).
Time is measured with torch.cuda.Event: we record the elapsed time for one full iteration (forward pass, backward pass, and weight update).
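A minimal sketch of this timing scheme (the function and argument names are our own, not the repo's):

```python
import torch

def time_one_iteration(model, batch, optimizer, loss_fn):
    # CUDA kernels launch asynchronously, so host-side timers under-count;
    # cuda.Event records timestamps directly on the GPU stream
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    loss = loss_fn(model(batch))  # forward
    loss.backward()               # backward
    optimizer.step()              # weight update
    optimizer.zero_grad(set_to_none=True)
    end.record()
    torch.cuda.synchronize()      # wait until both events have been recorded
    return start.elapsed_time(end)  # elapsed milliseconds
```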
We estimate the FLOPs of the forward pass using torchinfo.summary and multiply by 3 to approximate a full training step, since the backward pass costs roughly twice the forward pass.
Our target GPU, the A100, has 19.5 TFLOP/s of FP32 compute, 312 TFLOP/s of FP16 tensor-core compute, and 1.555 TB/s of memory bandwidth.
| Version | TFLOP/s | Speedup | MFU (%) |
|---|---|---|---|
| Baseline | 2.096 | 1.00x | 10.75 |
| Scale Up Batch Size | 2.404 | 1.15x | 12.33 |
| torch.compile | 3.24 | 1.54x | 16.62 |
| Use PyTorch SDPA | 6.06 | 2.89x | 31.08 |
| Mixed-Precision Training | 16.47 | 7.86x | 5.28 |
| Use FFCV for Data Loading | 16.70 | 7.97x | 5.35 |
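MFU (Model FLOPs Utilization) here is achieved throughput divided by the hardware peak for the numeric format in use; the table values can be checked with plain arithmetic:

```python
# A100 peak throughput in each numeric format
FP32_PEAK_TFLOPS = 19.5
FP16_PEAK_TFLOPS = 312.0

def mfu(achieved_tflops: float, peak_tflops: float) -> float:
    # Model FLOPs Utilization: share of the hardware peak actually used
    return 100.0 * achieved_tflops / peak_tflops

# the FP32 runs are judged against the FP32 peak...
assert round(mfu(2.096, FP32_PEAK_TFLOPS), 2) == 10.75
# ...while mixed precision is judged against the much larger FP16 peak
assert round(mfu(16.47, FP16_PEAK_TFLOPS), 2) == 5.28
```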
We scale the batch size up to 320, the largest that fits in GPU memory.
Using Flash Attention, we significantly lower our memory usage, which allows us to scale the batch size to 2048.
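Switching to PyTorch's SDPA is a one-call change; the toy tensors below stand in for the U-Net's attention activations:

```python
import torch
import torch.nn.functional as F

# toy shapes: (batch, heads, sequence length, head dim)
q = torch.randn(2, 4, 64, 32)
k = torch.randn(2, 4, 64, 32)
v = torch.randn(2, 4, 64, 32)

# dispatches to a fused Flash Attention kernel when the backend supports it,
# never materializing the full (seq_len x seq_len) attention matrix
out = F.scaled_dot_product_attention(q, k, v)
```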
Lowering to 16-bit precision reduces the memory requirements further, so we increase the batch size to 3072.
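One mixed-precision training step might be sketched as follows, using torch.autocast with gradient scaling (the names are ours; the repo's training loop may differ):

```python
import torch

def amp_step(model, batch, optimizer, scaler):
    # autocast runs each op in FP16 where it is numerically safe
    with torch.autocast(device_type="cuda", dtype=torch.float16):
        loss = model(batch).mean()
    scaler.scale(loss).backward()  # scale loss up to avoid FP16 gradient underflow
    scaler.step(optimizer)         # unscales grads; skips the step on inf/nan
    scaler.update()                # adapts the scale factor for the next step
    optimizer.zero_grad(set_to_none=True)
    return loss
```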
Although throughput increases several-fold, the A100 has 16x more FP16 compute than FP32 compute (312 vs. 19.5 TFLOP/s), so the MFU decreases.
Future work includes optimizing diffusion models for inference and implementing custom kernels.