Scalable Protein Language Model Finetuning with Distributed Learning and Advanced Training Techniques such as LoRA.
This project explores scalable and efficient finetuning of protein language models such as ESM-2 using advanced training techniques, including FSDP (Fully Sharded Data Parallel) and LoRA (Low-Rank Adaptation).
Highlights
- Distributed training: Leverage distributed computing for finetuning large protein language models on multiple GPUs.
- Advanced techniques: Explore LoRA and other methods to improve finetuning efficiency and performance.
- Reproducibility: Track and manage finetuning experiments using tools like MLflow.
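To make the LoRA approach concrete, here is a minimal sketch of wrapping ESM-2 with LoRA adapters using Hugging Face `transformers` and `peft`. This is illustrative rather than the project's actual training code; the rank, alpha, and target module names are assumptions.

```python
# Minimal LoRA sketch with Hugging Face transformers + peft (illustrative,
# not the project's exact implementation).
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from peft import LoraConfig, TaskType, get_peft_model

model_name = "facebook/esm2_t6_8M_UR50D"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=100,                         # one output per GO term
    problem_type="multi_label_classification",
)

# Inject low-rank adapters into the attention projections; only these small
# matrices (and the classification head) are updated during finetuning.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=16,
    target_modules=["query", "value"],      # assumed ESM attention module names
    lora_dropout=0.05,
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()          # typically well under 1% of all weights

# Toy forward pass on a made-up protein fragment.
batch = tokenizer(["MKTAYIAKQR"], return_tensors="pt", padding=True)
logits = model(**batch).logits              # shape: (1, 100)
```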
- Clone the repo:
git clone https://github.com/naity/finetune-esm.git
- Run the `train.py` script to see a list of available parameters:
python finetune-esm/train.py --help
The `requirements.txt` file lists the Python packages required to run the scripts. Install them with:
pip install -r requirements.txt
In this example, we will finetune ESM-2 for the CAFA 5 Protein Function Prediction Challenge to predict the biological function of a protein based on its primary sequence. I have already preprocessed the data and formatted the problem as a multi-class, multi-label problem. This means that for a given protein sequence, we will predict whether it is positive for each of the 100 preselected Gene Ontology (GO) terms. Thus, the target for each protein sequence is a binary vector with a length of 100.
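To illustrate the setup (with made-up values rather than real CAFA 5 data), the model's 100 logits for a sequence are scored against its binary target vector with binary cross-entropy, and per-term probabilities are thresholded independently:

```python
# Illustration of the multi-label target format (hypothetical values).
import torch
import torch.nn.functional as F

num_go_terms = 100
logits = torch.randn(1, num_go_terms)        # model outputs for one sequence
targets = torch.zeros(1, num_go_terms)       # 1.0 where a GO term applies
targets[0, [3, 17, 42]] = 1.0                # e.g. positive for three GO terms

loss = F.binary_cross_entropy_with_logits(logits, targets)
preds = (torch.sigmoid(logits) > 0.5).int()  # independent yes/no per GO term
```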
The processed datasets can be downloaded from here. Details about the preprocessing steps can be found in the `notebooks/cafa5_data_processing.ipynb` notebook.
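If you want to sanity-check the downloaded files before training, a quick inspection along these lines should work (a sketch; the exact parquet schema is defined in the notebook above):

```python
# Sketch for inspecting the processed data; see
# notebooks/cafa5_data_processing.ipynb for the actual schema.
import numpy as np
import pandas as pd

df = pd.read_parquet("data/cafa5/top100_train_split.parquet")
targets = np.load("data/cafa5/train_bp_top100_targets.npy")

print(df.shape, targets.shape)   # expect targets.shape[1] == 100
print(df.head())
```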
Run the following example command to finetune ESM-2 models with the processed datasets. Here, we use the smallest model, `esm2_t6_8M_UR50D`, with 1 GPU and the LoRA approach. To finetune a larger model on multiple GPUs, adjust `--num-workers` and/or `--num-devices` accordingly (see the FSDP sketch after the command below).
python finetune_esm/train.py \
--experiment-name esm2_t6_8M_UR50D_lora \
--dataset-loc data/cafa5/top100_train_split.parquet \
--targets-loc data/cafa5/train_bp_top100_targets.npy \
--esm-model esm2_t6_8M_UR50D \
--num-workers 1 \
--num-devices 1 \
--training-mode lora \
--learning-rate 0.0001 \
--num-epochs 5
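For larger ESM-2 checkpoints on multiple GPUs, FSDP shards parameters, gradients, and optimizer state across devices. The snippet below is a hedged sketch of how such a run could be configured with PyTorch Lightning; it is not the project's exact trainer setup, and `lit_module`/`data_module` are placeholders.

```python
# Hedged sketch of a multi-GPU FSDP configuration with PyTorch Lightning
# (not the project's exact trainer setup).
import lightning as L
from lightning.pytorch.strategies import FSDPStrategy

trainer = L.Trainer(
    accelerator="gpu",
    devices=4,                   # analogous to --num-devices
    strategy=FSDPStrategy(),     # shard params/grads/optimizer state across GPUs
    precision="bf16-mixed",
    max_epochs=5,
)
# trainer.fit(lit_module, datamodule=data_module)  # placeholder LightningModule/DataModule
```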
Once training is done, we can use MLflow to view the experiment using:
mlflow server --host 127.0.0.1 --port 8080 --backend-store-uri ./finetune_results/mlflow
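The UI will then be available at http://127.0.0.1:8080. Runs can also be queried programmatically; the sketch below points the MLflow Python API at the same local store (which metrics are available depends on what the run logged):

```python
# Sketch: query the same local MLflow store from Python instead of the UI.
import mlflow

mlflow.set_tracking_uri("./finetune_results/mlflow")  # matches --backend-store-uri above
runs = mlflow.search_runs(experiment_names=["esm2_t6_8M_UR50D_lora"])
# Metric columns (e.g. "metrics.<name>") depend on what the training run logged.
print(runs[["run_id", "status", "start_time"]].head())
```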
Below are screenshots of the example experiment, where we can view the logged parameters, artifacts, and metric visualizations.
Roadmap
- Data Processing
- Training
- Serving
See the open issues for a full list of proposed features (and known issues).
Contributions are what make the open source community such an amazing place to learn, inspire, and create. Any contributions you make are greatly appreciated.
If you have a suggestion that would make this better, please fork the repo and create a pull request. You can also simply open an issue with the tag "enhancement". Don't forget to give the project a star! Thanks again!
- Fork the Project
- Create your Feature Branch (`git checkout -b feature/AmazingFeature`)
- Commit your Changes (`git commit -m 'Add some AmazingFeature'`)
- Push to the Branch (`git push origin feature/AmazingFeature`)
- Open a Pull Request
Distributed under the MIT License. See `LICENSE.txt` for more information.