
Denoising Diffusion Probabilistic Models

A PyTorch implementation of the original DDPM, with a focus on improving training efficiency.

DDPM trained on CIFAR10

Setup

Install PyTorch (version 2.0 or later) and FFCV.
To install FFCV without Conda:

sudo apt install libopencv-dev libturbojpeg-dev
pip install -r requirements.txt

To train the model:

python ddpm.py

Configurations live in ddpm.py and are specified by two classes: ModelConfig and TrainerConfig.
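
As a rough illustration only (the actual fields are defined in ddpm.py; the ones below are assumptions, not the repository's definitions), the two configuration classes might look like:

from dataclasses import dataclass

@dataclass
class ModelConfig:
    # Hypothetical fields for illustration; see ddpm.py for the real ones.
    img_size: int = 32
    in_channels: int = 3
    dim: int = 128
    n_timesteps: int = 1000

@dataclass
class TrainerConfig:
    # Hypothetical fields for illustration; see ddpm.py for the real ones.
    batch_size: int = 320
    lr: float = 2e-4
    n_epochs: int = 100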

Performance Optimizations

How Performance is Measured

We measure wall-clock time and floating-point operations.
Time is measured with torch.cuda.Event, recording the time elapsed for one training iteration (forward pass, backward pass, and weight update).
We estimate the FLOPs of the forward pass with torchinfo.summary and multiply by 3 to approximate the combined forward and backward passes, since the backward pass costs roughly twice the forward pass.
Our target GPU, an A100, has 19.5 TFLOP/s of FP32 compute, 312 TFLOP/s of FP16 compute, and 1.555 TB/s of memory bandwidth.
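
A minimal sketch of this measurement, assuming a placeholder model, batch, and loss (the real U-Net, data pipeline, and FLOP accounting live in ddpm.py and may differ):

import torch
import torch.nn as nn
from torchinfo import summary

# Placeholder model and batch for illustration only.
model = nn.Sequential(nn.Conv2d(3, 64, 3, padding=1), nn.ReLU(), nn.Conv2d(64, 3, 3, padding=1)).cuda()
optimizer = torch.optim.Adam(model.parameters())
batch = torch.randn(320, 3, 32, 32, device='cuda')

# torchinfo reports multiply-accumulates; x2 for FLOPs, x3 for forward + backward.
fwd_flops = 2 * summary(model, input_data=batch, verbose=0).total_mult_adds
flops_per_iter = 3 * fwd_flops

start = torch.cuda.Event(enable_timing=True)
end = torch.cuda.Event(enable_timing=True)
start.record()
loss = model(batch).square().mean()  # stand-in loss for one iteration
loss.backward()
optimizer.step()
optimizer.zero_grad()
end.record()
torch.cuda.synchronize()
secs = start.elapsed_time(end) / 1e3  # elapsed_time() returns milliseconds

achieved_tflops = flops_per_iter / secs / 1e12
mfu = 100 * achieved_tflops / 19.5    # 19.5 TFLOP/s = A100 peak FP32
print(f'{achieved_tflops:.2f} TFLOP/s, MFU {mfu:.2f}%')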

Optimization Techniques

Version                   | TFLOP/s | Speedup | MFU (%)
--------------------------|---------|---------|--------
Baseline                  |   2.096 |   1.00x |   10.75
Scale Up Batch Size       |   2.404 |   1.15x |   12.33
torch.compile             |   3.24  |   1.54x |   16.62
Use PyTorch SDPA          |   6.06  |   2.89x |   31.08
Mixed-Precision Training  |  16.47  |   7.86x |    5.28
Use FFCV for Data Loading |  16.70  |   7.97x |    5.35

Scale Up Batch Size

We scale the batch size up to 320, the largest batch that fits in GPU memory.

Use PyTorch SDPA

Using Flash Attention through PyTorch SDPA, we significantly lower memory usage, which allows us to scale the batch size up to 2048.
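
A sketch of the change (the tensors below are generic examples, not the repository's attention module): replacing a hand-written softmax(QK^T / sqrt(d)) V with torch.nn.functional.scaled_dot_product_attention lets PyTorch dispatch to the fused Flash Attention kernel, which never materializes the full attention score matrix.

import torch
import torch.nn.functional as F

# q, k, v: (batch, n_heads, seq_len, head_dim); shapes are illustrative.
q = torch.randn(8, 4, 256, 64, device='cuda', dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Manual attention: materializes the (seq_len, seq_len) score matrix.
scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)
out_manual = scores.softmax(dim=-1) @ v

# SDPA: fused kernel (Flash Attention when available), much lower peak memory.
out_sdpa = F.scaled_dot_product_attention(q, k, v)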

Mixed-Precision Training

Lowering to 16-bit precision reduces memory requirements, so we further increase the batch size to 3072.
Although throughput increases from 6.06 to 16.47 TFLOP/s, the GPU has far more FP16 compute (312 vs. 19.5 TFLOP/s), so the MFU decreases.
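
A minimal sketch of a mixed-precision training loop, assuming a standard autocast/GradScaler setup with placeholder model, loss, and loader (the actual loop is in ddpm.py):

import torch
import torch.nn as nn

# Placeholders for illustration; the real model, loss, and loader are in ddpm.py.
model = nn.Conv2d(3, 3, 3, padding=1).cuda()
optimizer = torch.optim.Adam(model.parameters())
loader = [torch.randn(3072, 3, 32, 32, device='cuda') for _ in range(2)]

scaler = torch.cuda.amp.GradScaler()
for batch in loader:
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type='cuda', dtype=torch.float16):
        loss = model(batch).square().mean()  # forward pass runs in FP16 where safe
    scaler.scale(loss).backward()  # scale the loss to avoid FP16 gradient underflow
    scaler.step(optimizer)         # unscales gradients, then updates weights
    scaler.update()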

Use FFCV for Data Loading
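
FFCV replaces the PyTorch DataLoader with a compiled loading pipeline that reads from a preprocessed .beton file. A rough sketch of an FFCV CIFAR-10 loader (the path, batch size, and pipeline choices below are assumptions, not the repository's code):

import torch
from ffcv.loader import Loader, OrderOption
from ffcv.fields.decoders import SimpleRGBImageDecoder, IntDecoder
from ffcv.transforms import ToTensor, ToDevice, ToTorchImage, Convert

# 'cifar_train.beton' is a placeholder for a dataset written with ffcv.writer.DatasetWriter.
loader = Loader(
    'cifar_train.beton',
    batch_size=3072,
    num_workers=8,
    order=OrderOption.RANDOM,
    pipelines={
        'image': [SimpleRGBImageDecoder(), ToTensor(), ToDevice(torch.device('cuda')),
                  ToTorchImage(), Convert(torch.float16)],
        'label': [IntDecoder(), ToTensor(), ToDevice(torch.device('cuda'))],
    },
)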

Future Work

Future work includes optimizing diffusion models for inference and implementing custom kernels.

References
