Commit 091c8a5

nngokhale authored, with ZhengHongming888 and yafshar

Fix deepspeed crash with Sentence Transformer Trainer (huggingface#1328)

Co-authored-by: ZhengHongming888 <hongming.zheng@intel.com>
Co-authored-by: Yaser Afshar <yaser.afshar@intel.com>

1 parent 1a52079 · commit 091c8a5

File tree: 12 files changed (+236, -33 lines)

examples/sentence-transformers-training/nli/README.md (+26, -2)

@@ -4,6 +4,8 @@ Given two sentences (premise and hypothesis), the task of Natural Language Infer
 
 The paper in [Conneau et al.](https://arxiv.org/abs/1705.02364) shows that NLI data can be quite useful when training Sentence Embedding methods. In [Sentence-BERT-Paper](https://arxiv.org/abs/1908.10084) NLI as a first fine-tuning step for sentence embedding methods has been used.
 
+# General Models
+
 ## Single-card Training
 
 To pre-train on the NLI task:
@@ -46,7 +48,29 @@ For multi-card training you can use the script of [gaudi_spawn.py](https://githu
 HABANA_VISIBLE_MODULES="2,3" python ../../gaudi_spawn.py --use_deepspeed --world_size 2 training_nli.py bert-base-uncased
 ```
 
-## Dataset
+
+# Large Models (intfloat/e5-mistral-7b-instruct)
+
+## Single-card Training with LoRA+gradient_checkpointing
+
+Pretraining the `intfloat/e5-mistral-7b-instruct` model requires approximately 130GB of memory, which exceeds the capacity of a single HPU (Gaudi 2 with 98GB memory). To address this, we can utilize LoRA and gradient checkpointing techniques to reduce the memory requirements, making it feasible to train the model on a single HPU.
+
+```bash
+python training_nli.py intfloat/e5-mistral-7b-instruct --peft --lora_target_module "q_proj" "k_proj" "v_proj" --learning_rate 1e-5
+```
+
+## Multi-card Training with Deepspeed Zero2/3
+
+Pretraining the `intfloat/e5-mistral-7b-instruct` model requires approximately 130GB of memory, which exceeds the capacity of a single HPU (Gaudi 2 with 98GB memory). To address this, we can use the Zero2/Zero3 stages of DeepSpeed (model parallelism) to reduce the memory requirements.
+
+Our tests have shown that training this model requires at least four HPUs when using DeepSpeed Zero2.
+
+```bash
+python ../../gaudi_spawn.py --world_size 4 --use_deepspeed training_nli.py intfloat/e5-mistral-7b-instruct --deepspeed ds_config.json --bf16 --no-use_hpu_graphs_for_training --learning_rate 1e-7
+```
+In the above command, we need to enable lazy mode with a learning rate of `1e-7` and configure DeepSpeed using the `ds_config.json` file. To further reduce memory usage, change the stage to 3 (DeepSpeed Zero3) in the `ds_config.json` file.
+
+# Dataset
 
 We combine [SNLI](https://huggingface.co/datasets/stanfordnlp/snli) and [MultiNLI](https://huggingface.co/datasets/nyu-mll/multi_nli) into a dataset we call [AllNLI](https://huggingface.co/datasets/sentence-transformers/all-nli). These two datasets contain sentence pairs and one of three labels: entailment, neutral, contradiction:
 
@@ -58,7 +82,7 @@ We combine [SNLI](https://huggingface.co/datasets/stanfordnlp/snli) and [MultiNL
 
 We format AllNLI in a few different subsets, compatible with different loss functions. See [triplet subset of AllNLI](https://huggingface.co/datasets/sentence-transformers/all-nli/viewer/triplet) as example.
 
-## SoftmaxLoss
+# SoftmaxLoss
 
 <img src="https://raw.githubusercontent.com/UKPLab/sentence-transformers/master/docs/img/SBERT_SoftmaxLoss.png" alt="SBERT SoftmaxLoss" width="250"/>
 

examples/sentence-transformers-training/nli/ds_config.json (new file, +16)

@@ -0,0 +1,16 @@
+{
+    "steps_per_print": 1,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "gradient_accumulation_steps": "auto",
+    "bf16": {
+        "enabled": true
+    },
+    "gradient_clipping": 1.0,
+    "zero_optimization": {
+        "stage": 2,
+        "overlap_comm": false,
+        "reduce_scatter": false,
+        "contiguous_gradients": false
+    }
+}
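The README above suggests switching this file to DeepSpeed ZeRO-3 for further memory savings. A minimal sketch of that variant, assuming only the `stage` value changes and no extra ZeRO-3 options are needed (none are part of this commit):

```json
{
    "steps_per_print": 1,
    "train_batch_size": "auto",
    "train_micro_batch_size_per_gpu": "auto",
    "gradient_accumulation_steps": "auto",
    "bf16": {
        "enabled": true
    },
    "gradient_clipping": 1.0,
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": false,
        "reduce_scatter": false,
        "contiguous_gradients": false
    }
}
```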

examples/sentence-transformers-training/nli/training_nli.py (+46, -12)

@@ -4,8 +4,8 @@
 STS benchmark dataset
 """
 
+import argparse
 import logging
-import sys
 from datetime import datetime
 
 from datasets import load_dataset
@@ -28,16 +28,43 @@ def main():
     logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
 
     # You can specify any Hugging Face pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
-    model_name = sys.argv[1] if len(sys.argv) > 1 else "bert-base-uncased"
-    train_batch_size = 16
+    parser = argparse.ArgumentParser()
+    parser.add_argument("model_name", help="model name or path", default="bert-base-uncased", nargs="?")
+    parser.add_argument("--peft", help="use LoRA", action="store_true", default=False)
+    parser.add_argument("--lora_target_modules", nargs="+", default=["query", "key", "value"])
+    parser.add_argument("--bf16", help="use bf16", action="store_true", default=False)
+    parser.add_argument(
+        "--use_hpu_graphs_for_training",
+        help="use hpu graphs for training",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+    )
+    parser.add_argument("--learning_rate", help="learning rate", type=float, default=5e-5)
+    parser.add_argument("--deepspeed", help="deepspeed config file", default=None)
+    parser.add_argument("--train_batch_size", help="train batch size", default=16, type=int)
+    args = parser.parse_args()
 
     output_dir = (
-        "output/training_nli_" + model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
+        "output/training_nli_" + args.model_name.replace("/", "-") + "-" + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
     )
 
     # 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically
     # create one with "mean" pooling.
-    model = SentenceTransformer(model_name)
+    model = SentenceTransformer(args.model_name)
+    if args.peft:
+        from peft import LoraConfig, get_peft_model
+
+        peft_config = LoraConfig(
+            r=16,
+            lora_alpha=64,
+            lora_dropout=0.05,
+            bias="none",
+            inference_mode=False,
+            target_modules=args.lora_target_modules,
+        )
+        model = get_peft_model(model, peft_config)
+        model.print_trainable_parameters()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
 
     # 2. Load the AllNLI dataset: https://huggingface.co/datasets/sentence-transformers/all-nli
     # We'll start with 10k training samples, but you can increase this to get a stronger model
@@ -66,16 +93,16 @@ def main():
     dev_evaluator(model)
 
     # 5. Define the training arguments
-    args = SentenceTransformerGaudiTrainingArguments(
+    stargs = SentenceTransformerGaudiTrainingArguments(
         # Required parameter:
         output_dir=output_dir,
         # Optional training parameters:
         num_train_epochs=1,
-        per_device_train_batch_size=train_batch_size,
-        per_device_eval_batch_size=train_batch_size,
+        per_device_train_batch_size=args.train_batch_size,
+        per_device_eval_batch_size=args.train_batch_size,
         warmup_ratio=0.1,
         # fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
-        # bf16=False,  # Set to True if you have a GPU that supports BF16
+        bf16=args.bf16,  # Set to True if you have a GPU that supports BF16
         # Optional tracking/debugging parameters:
         evaluation_strategy="steps",
         eval_steps=100,
@@ -87,16 +114,18 @@ def main():
         use_habana=True,
         gaudi_config_name="Habana/bert-base-uncased",
         use_lazy_mode=True,
-        use_hpu_graphs=True,
+        use_hpu_graphs=args.use_hpu_graphs_for_training,
         use_hpu_graphs_for_inference=False,
-        use_hpu_graphs_for_training=True,
+        use_hpu_graphs_for_training=args.use_hpu_graphs_for_training,
         dataloader_drop_last=True,
+        learning_rate=args.learning_rate,
+        deepspeed=args.deepspeed,
     )
 
     # 6. Create the trainer & start training
     trainer = SentenceTransformerGaudiTrainer(
         model=model,
-        args=args,
+        args=stargs,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         loss=train_loss,
@@ -119,6 +148,11 @@ def main():
     final_output_dir = f"{output_dir}/final"
     model.save(final_output_dir)
 
+    if args.peft:
+        model.eval()
+        model = model.merge_and_unload()
+        model.save_pretrained(f"{output_dir}/merged")
+
 
 if __name__ == "__main__":
     main()
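As a quick sanity check of the merged checkpoint written by the new `--peft` branch above, one might load it back as a regular Sentence Transformers model. A minimal, hypothetical sketch; the `merged_dir` value is an illustrative placeholder, not a path produced verbatim by this commit:

```python
# Hypothetical follow-up: load the merged LoRA checkpoint written by
# model.save_pretrained(f"{output_dir}/merged") and run a quick similarity check.
from sentence_transformers import SentenceTransformer, util

merged_dir = "output/training_nli_intfloat-e5-mistral-7b-instruct-<timestamp>/merged"  # placeholder path

model = SentenceTransformer(merged_dir)
embeddings = model.encode(
    ["A man is playing a guitar.", "Someone is performing music."],
    convert_to_tensor=True,
)
print(util.cos_sim(embeddings[0], embeddings[1]))  # cosine similarity of the two sentence embeddings
```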

examples/sentence-transformers-training/sts/README.md (+27, -2)

@@ -5,6 +5,8 @@ Semantic Textual Similarity (STS) assigns a score on the similarity of two texts
 - **[training_stsbenchmark.py](training_stsbenchmark.py)** - This example shows how to create a SentenceTransformer model from scratch by using a pre-trained transformer model (e.g. [`distilbert-base-uncased`](https://huggingface.co/distilbert/distilbert-base-uncased)) together with a pooling layer.
 - **[training_stsbenchmark_continue_training.py](training_stsbenchmark_continue_training.py)** - This example shows how to continue training on STS data for a previously created & trained SentenceTransformer model (e.g. [`all-mpnet-base-v2`](https://huggingface.co/sentence-transformers/all-mpnet-base-v2)).
 
+# General Models
+
 ## Single-card Training
 
 To fine tune on the STS task:
@@ -33,7 +35,30 @@ For multi-card training you can use the script of [gaudi_spawn.py](https://githu
 HABANA_VISIBLE_MODULES="2,3" python ../../gaudi_spawn.py --use_deepspeed --world_size 2 training_stsbenchmark.py bert-base-uncased
 ```
 
-## Training data
+
+# Large Models (intfloat/e5-mistral-7b-instruct Model)
+
+## Single-card Training with LoRA+gradient_checkpointing
+
+Pretraining the `intfloat/e5-mistral-7b-instruct` model requires approximately 130GB of memory, which exceeds the capacity of a single HPU (Gaudi 2 with 98GB memory). To address this, we can utilize LoRA and gradient checkpointing techniques to reduce the memory requirements, making it feasible to train the model on a single HPU.
+
+```bash
+python training_stsbenchmark.py intfloat/e5-mistral-7b-instruct --peft --lora_target_modules "q_proj" "k_proj" "v_proj"
+```
+
+## Multi-card Training with Deepspeed Zero2/3
+
+Pretraining the `intfloat/e5-mistral-7b-instruct` model requires approximately 130GB of memory, which exceeds the capacity of a single HPU (Gaudi 2 with 98GB memory). To address this, we can use the Zero2/Zero3 stages of DeepSpeed (model parallelism) to reduce the memory requirements.
+
+Our tests have shown that training this model requires at least four HPUs when using DeepSpeed Zero2.
+
+```bash
+python ../../gaudi_spawn.py --world_size 4 --use_deepspeed training_stsbenchmark.py intfloat/e5-mistral-7b-instruct --deepspeed ds_config.json --bf16 --no-use_hpu_graphs_for_training --learning_rate 1e-7
+```
+
+In the above command, we need to enable lazy mode with a learning rate of `1e-7` and configure DeepSpeed using the `ds_config.json` file. To further reduce memory usage, change the stage to 3 (DeepSpeed Zero3) in the `ds_config.json` file.
+
+# Training data
 
 Here is a simplified version of our training data:
 
@@ -70,7 +95,7 @@ train_dataset = load_dataset("sentence-transformers/stsb", split="train")
 # })
 ```
 
-## Loss Function
+# Loss Function
 
 <img src="https://raw.githubusercontent.com/UKPLab/sentence-transformers/master/docs/img/SBERT_Siamese_Network.png" alt="SBERT Siamese Network Architecture" width="250"/>
 

examples/sentence-transformers-training/sts/ds_config.json (new file, +16)

@@ -0,0 +1,16 @@
+{
+    "steps_per_print": 1,
+    "train_batch_size": "auto",
+    "train_micro_batch_size_per_gpu": "auto",
+    "gradient_accumulation_steps": "auto",
+    "bf16": {
+        "enabled": true
+    },
+    "gradient_clipping": 1.0,
+    "zero_optimization": {
+        "stage": 2,
+        "overlap_comm": false,
+        "reduce_scatter": false,
+        "contiguous_gradients": false
+    }
+}

examples/sentence-transformers-training/sts/training_stsbenchmark.py (+45, -9)

@@ -4,8 +4,8 @@
 
 """
 
+import argparse
 import logging
-import sys
 from datetime import datetime
 
 from datasets import load_dataset
@@ -25,19 +25,48 @@ def main():
     logging.basicConfig(format="%(asctime)s - %(message)s", datefmt="%Y-%m-%d %H:%M:%S", level=logging.INFO)
 
     # You can specify any Hugging Face pre-trained model here, for example, bert-base-uncased, roberta-base, xlm-roberta-base
-    model_name = sys.argv[1] if len(sys.argv) > 1 else "distilbert-base-uncased"
+    parser = argparse.ArgumentParser()
+    parser.add_argument("model_name", help="model name or path", default="distilbert-base-uncased", nargs="?")
+    parser.add_argument("--peft", help="use LoRA", action="store_true", default=False)
+    parser.add_argument("--lora_target_modules", nargs="+", default=["q_lin", "k_lin", "v_lin"])
+    parser.add_argument("--bf16", help="use bf16", action="store_true", default=False)
+    parser.add_argument(
+        "--use_hpu_graphs_for_training",
+        help="use hpu graphs for training",
+        action=argparse.BooleanOptionalAction,
+        default=True,
+    )
+    parser.add_argument("--learning_rate", help="learning rate", type=float, default=5e-5)
+    parser.add_argument("--deepspeed", help="deepspeed config file", default=None)
+    args = parser.parse_args()
+
     train_batch_size = 16
     num_epochs = 1
     output_dir = (
         "output/training_stsbenchmark_"
-        + model_name.replace("/", "-")
+        + args.model_name.replace("/", "-")
         + "-"
         + datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
     )
 
     # 1. Here we define our SentenceTransformer model. If not already a Sentence Transformer model, it will automatically
     # create one with "mean" pooling.
-    model = SentenceTransformer(model_name)
+    model = SentenceTransformer(args.model_name)
+
+    if args.peft:
+        from peft import LoraConfig, get_peft_model
+
+        peft_config = LoraConfig(
+            r=16,
+            lora_alpha=64,
+            lora_dropout=0.05,
+            bias="none",
+            inference_mode=False,
+            target_modules=args.lora_target_modules,
+        )
+        model = get_peft_model(model, peft_config)
+        model.print_trainable_parameters()
+        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
 
     # 2. Load the STSB dataset: https://huggingface.co/datasets/sentence-transformers/stsb
     train_dataset = load_dataset("sentence-transformers/stsb", split="train")
@@ -61,7 +90,7 @@ def main():
     )
 
     # 5. Define the training arguments
-    args = SentenceTransformerGaudiTrainingArguments(
+    stargs = SentenceTransformerGaudiTrainingArguments(
         # Required parameter:
         output_dir=output_dir,
         # Optional training parameters:
@@ -70,7 +99,7 @@ def main():
         per_device_eval_batch_size=train_batch_size,
         warmup_ratio=0.1,
         # fp16=True,  # Set to False if you get an error that your GPU can't run on FP16
-        # bf16=True,  # Set to True if you have a GPU that supports BF16
+        bf16=args.bf16,  # Set to True if you have a GPU that supports BF16
         # Optional tracking/debugging parameters:
         evaluation_strategy="steps",
         eval_steps=100,
@@ -82,16 +111,18 @@ def main():
         use_habana=True,
         gaudi_config_name="Habana/distilbert-base-uncased",
         use_lazy_mode=True,
-        use_hpu_graphs=True,
+        use_hpu_graphs=args.use_hpu_graphs_for_training,
         use_hpu_graphs_for_inference=False,
-        use_hpu_graphs_for_training=True,
+        use_hpu_graphs_for_training=args.use_hpu_graphs_for_training,
+        learning_rate=args.learning_rate,
+        deepspeed=args.deepspeed,
     )
 
     # 6. Create the trainer & start training
     # trainer = SentenceTransformerTrainer(
     trainer = SentenceTransformerGaudiTrainer(
         model=model,
-        args=args,
+        args=stargs,
         train_dataset=train_dataset,
         eval_dataset=eval_dataset,
         loss=train_loss,
@@ -113,6 +144,11 @@ def main():
     final_output_dir = f"{output_dir}/final"
     model.save(final_output_dir)
 
+    if args.peft:
+        model.eval()
+        model = model.merge_and_unload()
+        model.save_pretrained(f"{output_dir}/merged")
+
 
 if __name__ == "__main__":
     main()
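Note that the two example scripts default to different `--lora_target_modules` (BERT-style `query`/`key`/`value` in `training_nli.py`, DistilBERT-style `q_lin`/`k_lin`/`v_lin` here, and the READMEs pass `q_proj`/`k_proj`/`v_proj` for the Mistral-based model). A small, hypothetical helper for checking which projection names a given backbone actually exposes before launching training:

```python
# Hypothetical helper: list the attention-projection module names of a model so
# suitable --lora_target_modules values can be passed on the command line.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("distilbert-base-uncased")  # or e.g. intfloat/e5-mistral-7b-instruct
candidates = {"query", "key", "value", "q_lin", "k_lin", "v_lin", "q_proj", "k_proj", "v_proj"}
found = sorted({name.rsplit(".", 1)[-1] for name, _ in model.named_modules()} & candidates)
print(found)  # expected for distilbert-base-uncased: ['k_lin', 'q_lin', 'v_lin']
```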

optimum/habana/sentence_transformers/__init__.py (+1, -1)

@@ -19,5 +19,5 @@
 from .st_gaudi_trainer import SentenceTransformerGaudiTrainer
 from .st_gaudi_training_args import SentenceTransformerGaudiTrainingArguments
 from .st_gaudi_encoder import st_gaudi_encode
-from .st_gaudi_transformer_tokenize import st_gaudi_transformer_tokenize
+from .st_gaudi_transformer import st_gaudi_transformer_tokenize, st_gaudi_transformer_save
 from .st_gaudi_data_collator import st_gaudi_data_collator_call

optimum/habana/sentence_transformers/modeling_utils.py (+2)

@@ -25,6 +25,7 @@ def adapt_sentence_transformers_to_gaudi():
     from optimum.habana.sentence_transformers import (
         st_gaudi_data_collator_call,
         st_gaudi_encode,
+        st_gaudi_transformer_save,
         st_gaudi_transformer_tokenize,
     )
 
@@ -33,6 +34,7 @@ def adapt_sentence_transformers_to_gaudi():
     from sentence_transformers.models import Transformer
 
     Transformer.tokenize = st_gaudi_transformer_tokenize
+    Transformer.save = st_gaudi_transformer_save
 
     from sentence_transformers.data_collator import SentenceTransformerDataCollator
 
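The patched `Transformer.save` only takes effect once the adaptation function runs. A minimal sketch of how a caller might activate it, assuming `adapt_sentence_transformers_to_gaudi()` is invoked before the model is constructed (that call site is not part of the diff shown here):

```python
# Minimal sketch: activate the Gaudi monkey-patches, then build and save a model.
# After the call, Transformer.tokenize and Transformer.save point at the
# st_gaudi_* implementations patched in modeling_utils.py above.
from optimum.habana.sentence_transformers.modeling_utils import adapt_sentence_transformers_to_gaudi
from sentence_transformers import SentenceTransformer

adapt_sentence_transformers_to_gaudi()  # patches sentence_transformers in place

model = SentenceTransformer("bert-base-uncased")
model.save("output/patched-save-example")  # illustrative path; goes through the patched Transformer.save
```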
