# -*- coding: utf-8 -*-
"""
@author:XuMing(xuming624@qq.com)
@description: Train an R1-style reasoning model with the GRPO RL algorithm.
"""
import os
from dataclasses import dataclass, field
from datetime import datetime
from typing import Optional
import re
from datasets import load_dataset
import torch
from loguru import logger
from transformers import AutoTokenizer
from transformers.trainer_utils import get_last_checkpoint
from trl import GRPOConfig, GRPOTrainer, ModelConfig, TrlParser
from peft import LoraConfig, TaskType
from latex2sympy2_extended import NormalizationConfig
from math_verify import LatexExtractionConfig, parse, verify
os.environ["TOKENIZERS_PARALLELISM"] = "FALSE"
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

@dataclass
class ScriptArguments:
    """
    Tokenizer and dataset arguments for the causal LM we fine-tune with GRPO.
    """
tokenizer_name_or_path: Optional[str] = field(
default=None, metadata={"help": "The tokenizer for weights initialization."}
)
# Dataset arguments
dataset_name: Optional[str] = field(
default="xiaodongguaAIGC/X-R1-750",
metadata={"help": "The name of the dataset to use (via the datasets library)."}
)
train_samples: Optional[int] = field(default=-1, metadata={"help": "Number of samples to train on, -1 for all"})
subset_name: Optional[str] = field(default="default",
metadata={"help": "Subset name, e.g., 'default', 'main'. default is 'default'"})
dataset_splits: Optional[str] = field(default="train", metadata={"help": "Split name"})
preprocessing_num_workers: Optional[int] = field(default=10,
metadata={"help": "Number of workers for preprocessing"})

def normalize_text(text):
    """Normalize text by collapsing extra whitespace and lowercasing.

    Unused by the default reward functions; kept as a helper for custom rewards.
    """
    if text is None:
        return ""
    return re.sub(r'\s+', ' ', text.strip().lower())

def extract_answer(text):
    """Extract the content between <answer> tags; fall back to the full text.

    Unused by the default reward functions; kept as a helper for custom rewards.
    """
    if text is None:
        return ""
    match = re.search(r'<answer>(.*?)</answer>', text, re.DOTALL)
    if match:
        return match.group(1).strip()
    return text.strip()

def accuracy_reward(completions, solution, **kwargs):
    """Reward 1.0 if the completion's parsed answer matches the ground-truth solution, else 0.0."""
contents = [completion[0]["content"] for completion in completions]
rewards = []
for content, sol in zip(contents, solution):
# First try latex parsing
gold_parsed = parse(
sol,
extraction_mode="first_match",
extraction_config=[LatexExtractionConfig()],
)
if len(gold_parsed) != 0:
# We require the answer to be provided in correct latex (no malformed operators)
answer_parsed = parse(
content,
extraction_config=[
LatexExtractionConfig(
normalization_config=NormalizationConfig(
nits=False,
malformed_operators=False,
basic_latex=True,
equations=True,
boxed="all",
units=True,
),
# Ensures that boxed is tried first
boxed_match_priority=0,
try_extract_without_anchor=False,
)
],
extraction_mode="first_match",
)
# Reward 1 if the content is the same as the ground truth, 0 otherwise
reward = float(verify(answer_parsed, gold_parsed))
logger.debug(f"predict_answer: {content}, \nground_truth: {sol}, \n"
f"answer_parsed: {answer_parsed}, gold_parsed: {gold_parsed}, reward: {reward}\n\n")
else:
# If the gold solution is not parseable, we reward 1 to skip this example
reward = 1.0
logger.debug(f"Failed to parse ground_truth: {sol}")
rewards.append(reward)
logger.debug(f'accuracy rewards: {rewards}')
return rewards

def format_reward(completions, **kwargs):
    """Reward function that checks if the completion has the <think>...</think><answer>...</answer> format."""
    pattern = r"^<think>.*?</think><answer>.*?</answer>$"
    completion_contents = [completion[0]["content"] for completion in completions]
    # re.DOTALL lets the reasoning and answer span multiple lines
    matches = [re.match(pattern, content, re.DOTALL) for content in completion_contents]
    rewards = [1.0 if match else 0.0 for match in matches]
    logger.debug(f'format rewards: {rewards}')
    return rewards
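
# A minimal sketch of how the reward functions are invoked (illustrative only;
# the completion below is a made-up example, not taken from the training data):
#
#   completions = [[{"role": "assistant",
#                    "content": "<think>2 + 2 = 4</think><answer>\\boxed{4}</answer>"}]]
#   format_reward(completions)                      # -> [1.0]
#   accuracy_reward(completions, solution=["$4$"])  # -> [1.0] if parsing succeeds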

SYSTEM_PROMPT = (
"A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant "
"first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning "
"process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., "
"<think> reasoning process here </think><answer> answer here </answer>"
)
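
# Note: format_reward's regex requires exactly this layout: the completion must
# start with <think> and end with </answer>, with nothing between </think> and <answer>.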

def get_checkpoint(training_args: GRPOConfig):
    """Return the most recent checkpoint in output_dir, or None if none exists."""
    last_checkpoint = None
    if os.path.isdir(training_args.output_dir):
        last_checkpoint = get_last_checkpoint(training_args.output_dir)
    return last_checkpoint

def find_all_linear_names(peft_model, int4=False, int8=False):
    """Find all linear layer names in the model, following the QLoRA paper."""
    cls = torch.nn.Linear
    if int4 or int8:
        import bitsandbytes as bnb
        if int4:
            cls = bnb.nn.Linear4bit
        elif int8:
            cls = bnb.nn.Linear8bitLt
    lora_module_names = set()
    for name, module in peft_model.named_modules():
        if isinstance(module, cls):
            # Skip the output head; it is not added to lora_module_names
            if 'lm_head' in name:
                continue
            if 'output_layer' in name:
                continue
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])
    return sorted(lora_module_names)
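
# Hypothetical usage sketch: when lora_target_modules is not set, one could load
# the model first and target every linear layer, e.g.
#   target_modules = find_all_linear_names(model)
# This script instead lets TRL load the model and passes target_modules through
# from model_args, so the helper is kept for customization.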

def grpo_train(
    model_args: ModelConfig, script_args: ScriptArguments, training_args: GRPOConfig
):
    """Run GRPO training end to end: data prep, trainer setup, training, and saving."""
    # Log only on the main process (rank -1 means non-distributed)
    is_main_process = training_args.local_rank in [-1, 0]
if is_main_process:
        logger.warning(
            f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu},"
            f" distributed training: {bool(training_args.local_rank != -1)}, 16-bit training: {training_args.fp16}"
        )
logger.info(f"Model parameters {model_args}")
logger.info(f"Script parameters {script_args}")
logger.info(f"Training parameters {training_args}")
# Load tokenizer
tokenizer = AutoTokenizer.from_pretrained(
(
script_args.tokenizer_name_or_path
if script_args.tokenizer_name_or_path
else model_args.model_name_or_path
),
revision=model_args.model_revision,
trust_remote_code=model_args.trust_remote_code,
)
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token
# Load datasets
dataset = load_dataset(script_args.dataset_name, script_args.subset_name, split=script_args.dataset_splits)
if script_args.train_samples > 0:
dataset = dataset.shuffle(seed=42).select(range(script_args.train_samples))
# Prepare dataset
with training_args.main_process_first(desc="Dataset preparation"):
dataset = dataset.map(
lambda x: {
'prompt': [
{'role': 'system', 'content': SYSTEM_PROMPT},
{'role': 'user', 'content': x['problem']}
],
'answer': x['solution']
},
num_proc=script_args.preprocessing_num_workers,
desc="Processing dataset" if is_main_process else None,
)
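    # Each example is now a conversational prompt; GRPOTrainer applies the
    # tokenizer's chat template and samples completions from it during training.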
# Split dataset
    # Fixed seed so every process computes the identical split
    train_test_split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset = train_test_split["train"]
test_dataset = train_test_split["test"]
if is_main_process:
logger.info("*** Initializing model kwargs ***")
# Model initialization
torch_dtype = (
model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
)
# Set up distributed training config
world_size = int(os.environ.get("WORLD_SIZE", "1"))
ddp = world_size != 1
if ddp:
training_args.device_map = {"": int(os.environ.get("LOCAL_RANK", "0"))}
        # Guard against a zero value when world_size exceeds the configured steps
        training_args.gradient_accumulation_steps = max(
            training_args.gradient_accumulation_steps // world_size, 1
        )
    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        attn_implementation=model_args.attn_implementation,
        torch_dtype=torch_dtype,
        # The KV cache is incompatible with gradient checkpointing
        use_cache=not training_args.gradient_checkpointing,
        device_map=training_args.device_map if ddp else "auto",
    )
training_args.model_init_kwargs = model_kwargs
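    # model_name_or_path is passed to GRPOTrainer as a string, so TRL loads the
    # model itself using these init kwargs.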
# Configure LoRA if enabled
peft_config = None
if model_args.use_peft:
if is_main_process:
logger.info("Fine-tuning method: LoRA(PEFT)")
target_modules = model_args.lora_target_modules if model_args.lora_target_modules else None
if is_main_process:
logger.info(f"Peft target_modules: {target_modules}")
peft_config = LoraConfig(
task_type=TaskType.CAUSAL_LM,
target_modules=target_modules,
inference_mode=False,
r=model_args.lora_r,
lora_alpha=model_args.lora_alpha,
lora_dropout=model_args.lora_dropout,
)
    else:
        if is_main_process:
            logger.info("Fine-tuning method: full-parameter training")
# Initialize GRPO trainer with distributed training support
trainer = GRPOTrainer(
model=model_args.model_name_or_path,
processing_class=tokenizer,
reward_funcs=[
accuracy_reward,
format_reward
],
args=training_args,
train_dataset=train_dataset,
eval_dataset=test_dataset if training_args.eval_strategy != "no" else None,
peft_config=peft_config,
)
logger.info("*** GRPO Trainer initialized ***")
logger.debug(f"Trainer: {trainer}")
# Training
last_checkpoint = get_checkpoint(training_args)
if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
if is_main_process:
logger.info(f"Checkpoint detected, resuming training at {last_checkpoint}.")
if is_main_process:
logger.info(
f'*** Starting training {datetime.now().strftime("%Y-%m-%d %H:%M:%S")} for '
f'{training_args.num_train_epochs} epochs ***'
)
train_result = trainer.train(resume_from_checkpoint=last_checkpoint)
# Log and save metrics on main process
if is_main_process:
metrics = train_result.metrics
metrics["train_samples"] = len(train_dataset)
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
logger.info("*** Training complete ***")
logger.info("*** Save model ***")
# Save model
trainer.model.config.use_cache = True
if is_main_process:
trainer.save_model(training_args.output_dir)
logger.info(f"Model saved to {training_args.output_dir}")
training_args.distributed_state.wait_for_everyone()
if is_main_process:
tokenizer.save_pretrained(training_args.output_dir)
logger.info(f"Tokenizer saved to {training_args.output_dir}")
# Create model card and save config
kwargs = {
"dataset_name": script_args.dataset_name,
"tags": ["r1", "grpo"],
}
trainer.create_model_card(**kwargs)
trainer.model.config.use_cache = True
trainer.model.config.save_pretrained(training_args.output_dir)
if is_main_process:
logger.info("*** Training complete! ***")

def main():
    parser = TrlParser((ModelConfig, ScriptArguments, GRPOConfig))
    model_args, script_args, training_args = parser.parse_args_and_config()
    # Run the main training loop
    grpo_train(model_args, script_args, training_args)


if __name__ == "__main__":
    main()
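
# Example launch (a sketch: the model name and output path are placeholders, and
# the exact flag set depends on the installed trl version's ModelConfig/GRPOConfig):
#   python grpo_training.py \
#       --model_name_or_path Qwen/Qwen2.5-0.5B-Instruct \
#       --dataset_name xiaodongguaAIGC/X-R1-750 \
#       --output_dir outputs/grpo-r1 \
#       --per_device_train_batch_size 1 \
#       --num_generations 4 \
#       --use_peft True --lora_r 8 --lora_alpha 16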