Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

🐦‍🔥 6x faster GRPO with multi-step optimization #2899

Merged
merged 16 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 134 additions & 0 deletions tests/test_grpo_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,115 @@

from trl import GRPOConfig, GRPOTrainer
from trl.import_utils import is_vllm_available
from trl.trainer.grpo_trainer import RepeatRandomSampler


if is_peft_available():
from peft import LoraConfig, PeftModel


class RepeatRandomSamplerTester(unittest.TestCase):
def test_sampler(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=2)
# Should output something like [4, 4, 3, 3, 0, 0, 1, 1, 2, 2, 6, 6, 5, 5]
sampled = list(sampler)
# Check that the length is doubled
assert len(sampled) == 2 * len(dataset)
# Check that all indexes are present
assert set(sampled) == set(range(len(dataset)))
# Check that each element is repeated twice
assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))

def test_sampler_no_repeat(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=1)
# Should output something like [4, 3, 0, 1, 2, 6, 5]
sampled = list(sampler)
# Check that the length is the same
assert len(sampled) == len(dataset)
# Check that all indexes are present
assert set(sampled) == set(range(len(dataset)))

def test_sampler_with_batch_size(self):
dataset = ["a", "b", "c", "d", "e", "f", "g", "h"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
# Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6, 5, 7, 5, 7]
sampled = list(sampler)
# Check that the length is doubled
assert len(sampled) == 2 * len(dataset)
# Check that all indexes are present
assert set(sampled) == set(range(len(dataset)))
# Check that each element is repeated as expected
assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))

def test_sampler_with_batch_size_and_drop(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=1, batch_size=2, repeat_count=2)
# Should output something like [4, 3, 4, 3, 0, 1, 0, 1, 2, 6, 2, 6]
sampled = list(sampler)
# Check that the length is doubled
assert len(sampled) == 2 * (
len(dataset) - 1
) # one element is dropped, because it's not enough to form a batch
# Check that the sampled indexes are a subset of the dataset indexes
assert set(sampled).issubset(set(range(len(dataset))))
# Check that each element is repeated as expected
assert all(sampled[i : i + 1] == sampled[i + 2 : i + 3] for i in range(0, len(sampled), 4))

def test_sampler_with_mini_repeat_count_and_batch_size_1(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=3, repeat_count=2)
# Should output something like [4, 4, 3, 3, 0, 0, 4, 4, 3, 3, 0, 0,
# 1, 1, 2, 2, 6, 6, 1, 1, 2, 2, 6, 6]
sampled = list(sampler)
# Check that the length is quadrupled
assert len(sampled) == 4 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch
# Check that the sampled indexes are a subset of the dataset indexes
assert set(sampled).issubset(set(range(len(dataset))))
# Check that each element is repeated as expected
assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
# Check that the batch is repeated as expected
assert sampled[0:6] == sampled[6:12]
assert sampled[12:18] == sampled[18:24]

def test_sampler_with_mini_repeat_count_and_batch_size_2(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=3, batch_size=2, repeat_count=2)
# Should output something like [4, 4, 4, 3, 3, 3, 4, 4, 4, 3, 3, 3,
# 0, 0, 0, 1, 1, 1, 0, 0, 0, 1, 1, 1,
# 2, 2, 2, 6, 6, 6, 2, 2, 2, 6, 6, 6]
sampled = list(sampler)
# Check that the length is sextupled
assert len(sampled) == 6 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch
# Check that the sampled indexes are a subset of the dataset indexes
assert set(sampled).issubset(set(range(len(dataset))))
# Check that each element is repeated as expected
assert all(sampled[i] == sampled[i + 1] == sampled[i + 2] for i in range(0, len(sampled), 3))
# Check that the batch is repeated as expected
assert sampled[0:6] == sampled[6:12]
assert sampled[12:18] == sampled[18:24]
assert sampled[24:30] == sampled[30:36]

def test_sampler_with_mini_repeat_count_and_batch_size_3(self):
dataset = ["a", "b", "c", "d", "e", "f", "g"]
sampler = RepeatRandomSampler(dataset, mini_repeat_count=2, batch_size=2, repeat_count=3)
# Should output something like [4, 4, 3, 3, 4, 4, 3, 3, 4, 4, 3, 3,
# 0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1,
# 2, 2, 6, 6, 2, 2, 6, 6, 2, 2, 6, 6]
sampled = list(sampler)
# Check that the length is sextupled
assert len(sampled) == 6 * (len(dataset) - 1) # 1 element is dropped, because it's not enough to form a batch
# Check that the sampled indexes are a subset of the dataset indexes
assert set(sampled).issubset(set(range(len(dataset))))
# Check that each element is repeated as expected
assert all(sampled[i] == sampled[i + 1] for i in range(0, len(sampled), 2))
# Check that the batch is repeated as expected
assert sampled[0:4] == sampled[4:8] == sampled[8:12]
assert sampled[12:16] == sampled[16:20] == sampled[20:24]
assert sampled[24:28] == sampled[28:32] == sampled[32:36]


class GRPOTrainerTester(unittest.TestCase):
def test_init_minimal(self):
# Test that GRPOTrainer can be instantiated with only model, reward_model and train_dataset
Expand Down Expand Up @@ -96,6 +199,37 @@ def test_training_with_eval(self):

trainer.train()

def test_training_multiple_iterations(self):
dataset = load_dataset("trl-internal-testing/zen", "standard_prompt_only", split="train")

with tempfile.TemporaryDirectory() as tmp_dir:
training_args = GRPOConfig(
output_dir=tmp_dir,
learning_rate=0.1, # increase the learning rate to speed up the test
per_device_train_batch_size=3, # reduce the batch size to reduce memory usage
num_generations=3, # reduce the number of generations to reduce memory usage
max_completion_length=32, # reduce the completion length to reduce memory usage
num_iterations=2,
report_to="none",
)
trainer = GRPOTrainer(
model="trl-internal-testing/tiny-Qwen2ForCausalLM-2.5",
reward_funcs="trl-internal-testing/tiny-Qwen2ForSequenceClassification-2.5",
args=training_args,
train_dataset=dataset,
)

previous_trainable_params = {n: param.clone() for n, param in trainer.model.named_parameters()}

trainer.train()

self.assertIsNotNone(trainer.state.log_history[-1]["train_loss"])

# Check that the params have changed
for n, param in previous_trainable_params.items():
new_param = trainer.model.get_parameter(n)
self.assertFalse(torch.equal(param, new_param), f"Parameter {n} has not changed.")

@require_peft
def test_training_peft(self):
model = AutoModelForCausalLM.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
Expand Down
12 changes: 12 additions & 0 deletions trl/trainer/grpo_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,10 @@ class GRPOConfig(TrainingArguments):
beta (`float`, *optional*, defaults to `0.04`):
KL coefficient. If `0.0`, the reference model is not loaded, reducing memory usage and improving training
speed.
num_iterations (`int`, *optional*, defaults to `1`):
Number of iterations per batch (denoted as μ in the algorithm).
epsilon (`float`, *optional*, defaults to `0.2`):
Epsilon value for clipping.
reward_weights (`list[float]` or `None`, *optional*, defaults to `None`):
Weights for each reward function. Must match the number of reward functions. If `None`, all rewards are
weighted equally with weight `1.0`.
Expand Down Expand Up @@ -224,6 +228,14 @@ class GRPOConfig(TrainingArguments):
"training speed."
},
)
num_iterations: int = field(
default=1,
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we set another default? eg 4?

metadata={"help": "Number of iterations per batch (denoted as μ in the algorithm)."},
)
epsilon: float = field(
default=0.2,
metadata={"help": "Epsilon value for clipping."},
)
reward_weights: Optional[list[float]] = field(
default=None,
metadata={
Expand Down
Loading