🪂 Don't gather logits in SFT to avoid hanging #2890

Merged · 3 commits · Feb 18, 2025
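As context for the `compute_loss` diff below, a minimal sketch of the accuracy-logging pattern this PR switches to (the `accelerator` argument stands in for the trainer's 🤗 Accelerate `Accelerator`, and the helper name is illustrative, not part of the PR): each process reduces its shifted logits to two scalar token counts, and only those scalars are gathered across GPUs instead of the full logits tensor.

```python
import torch

def mean_token_accuracy(shift_logits: torch.Tensor, shift_labels: torch.Tensor, accelerator) -> float:
    """Illustrative sketch of the pattern used in the compute_loss diff below."""
    predictions = shift_logits.argmax(dim=-1)   # [batch, seq_len]
    mask = shift_labels != -100                 # -100 marks ignored/padded positions
    correct_tokens = ((predictions == shift_labels) & mask).sum()
    total_tokens = mask.sum()
    # Gather only two scalar counts per process, not the full logits tensor.
    correct_tokens = accelerator.gather_for_metrics(correct_tokens)
    total_tokens = accelerator.gather_for_metrics(total_tokens)
    total = total_tokens.sum()
    return (correct_tokens.sum() / total).item() if total > 0 else 0.0
```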
55 changes: 0 additions & 55 deletions tests/test_utils.py
@@ -27,7 +27,6 @@
from trl.trainer.utils import (
DataCollatorForChatML,
batch_generation,
compute_token_accuracy,
decode_and_strip_padding,
flush_left,
generate_model_card,
@@ -456,60 +455,6 @@ def test_no_tensors(self):
self.assertTrue(torch.equal(new_mask, expected_mask))


class TestComputeTokenAccuracy(unittest.TestCase):
def test_basic_accuracy(self):
# Test basic accuracy computation
logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]]) # Shape: [2, 2, 2]
labels = torch.tensor([[1, 0], [1, 0]]) # Shape: [2, 2]
accuracy = compute_token_accuracy(logits, labels)
self.assertAlmostEqual(accuracy, 0.75) # 3 correct out of 4 tokens

def test_with_ignore_index(self):
# Test accuracy computation with ignored tokens
logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]])
labels = torch.tensor([[1, -100], [1, 0]]) # -100 is ignored
accuracy = compute_token_accuracy(logits, labels, ignore_index=-100)
self.assertAlmostEqual(accuracy, 2 / 3) # 2 correct out of 3 non-ignored tokens

def test_all_ignored(self):
# Test case where all tokens are ignored
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[-100, -100]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 0.0) # No valid tokens to compute accuracy

def test_perfect_accuracy(self):
# Test case with 100% accuracy
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[1, 0]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 1.0) # All predictions correct

def test_zero_accuracy(self):
# Test case with 0% accuracy
logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
labels = torch.tensor([[0, 1]])
accuracy = compute_token_accuracy(logits, labels)
self.assertEqual(accuracy, 0.0) # All predictions wrong

def test_batch_accuracy(self):
# Test accuracy computation across multiple batches
logits = torch.tensor(
[
[[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]], # Batch 1
[[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]], # Batch 2
]
)
labels = torch.tensor(
[
[1, 0, 1], # Batch 1
[1, 0, -100], # Batch 2 (last token ignored)
]
)
accuracy = compute_token_accuracy(logits, labels)
self.assertAlmostEqual(accuracy, 0.8)


class TestSelectiveLogSoftmax(unittest.TestCase):
@parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)])
def test_selective_log_softmax(self, dtype):
4 changes: 2 additions & 2 deletions trl/__init__.py
@@ -102,7 +102,7 @@
"XPOTrainer",
],
"trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "compute_token_accuracy"],
"trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
}

try:
@@ -204,7 +204,7 @@
XPOTrainer,
)
from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
from .trainer.utils import compute_token_accuracy, get_kbit_device_map, get_peft_config, get_quantization_config
from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config

try:
if not is_diffusers_available():
2 changes: 0 additions & 2 deletions trl/trainer/__init__.py
@@ -76,7 +76,6 @@
"disable_dropout_in_model",
"empty_cache",
"peft_module_casting_to_bf16",
"compute_token_accuracy",
],
"xpo_config": ["XPOConfig"],
"xpo_trainer": ["XPOTrainer"],
@@ -145,7 +144,6 @@
DataCollatorForCompletionOnlyLM,
RunningMoments,
compute_accuracy,
compute_token_accuracy,
disable_dropout_in_model,
empty_cache,
peft_module_casting_to_bf16,
48 changes: 25 additions & 23 deletions trl/trainer/sft_trainer.py
@@ -45,13 +45,7 @@

from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples
from .sft_config import SFTConfig
from .utils import (
ConstantLengthDataset,
compute_token_accuracy,
generate_model_card,
get_comet_experiment_url,
peft_module_casting_to_bf16,
)
from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16


if is_peft_available():
@@ -60,8 +54,6 @@

if is_liger_kernel_available():
from liger_kernel.transformers import AutoLigerKernelForCausalLM
else:
AutoLigerKernelForCausalLM = None

if is_wandb_available():
import wandb
@@ -185,12 +177,13 @@ def __init__(
)
if isinstance(model, str):
model = self._create_model_from_path(model, args)
self.use_liger = is_liger_kernel_available() and isinstance(model, AutoLigerKernelForCausalLM)

# PEFT configuration and model wrapping
if peft_config is not None:
model = self._prepare_peft_model(model, peft_config, args)

# 3. Handle the tokenizer
# Handle the tokenizer
if processing_class is None:
processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
if processing_class.pad_token is None:
@@ -279,8 +272,10 @@ def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTrainedModel:
if args.use_liger:
if not is_liger_kernel_available():
raise ImportError("Please install Liger-kernel for use_liger=True")
return AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
else:
model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
return model

def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
"""Prepares a model for PEFT training."""
@@ -471,21 +466,28 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
)

# Compute token accuracy if we have labels and if the model is not using Liger (no logits)
use_liger = self.args.use_liger or (
AutoLigerKernelForCausalLM is not None and isinstance(model, AutoLigerKernelForCausalLM)
)
if "labels" in inputs and not use_liger:
if "labels" in inputs and not self.use_liger:
shift_logits = outputs.logits[..., :-1, :].contiguous()
shift_labels = inputs["labels"][..., 1:].contiguous()

# Gather logits and labels from all GPUs first
shift_logits = self.accelerator.gather_for_metrics(shift_logits)
shift_labels = self.accelerator.gather_for_metrics(shift_labels)
# Get predictions
predictions = shift_logits.argmax(dim=-1)

# Create mask for non-padding tokens (assuming ignore_index is -100)
mask = shift_labels != -100

# Calculate accuracy only on non-padding tokens
correct_predictions = (predictions == shift_labels) & mask
total_tokens = mask.sum()
correct_tokens = correct_predictions.sum()

# Gather the correct_tokens and total_tokens across all processes
correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
total_tokens = self.accelerator.gather_for_metrics(total_tokens)

# Then compute accuracy on the gathered tensors
if self.accelerator.is_main_process:
accuracy = compute_token_accuracy(shift_logits, shift_labels)
self._metrics["mean_token_accuracy"].append(accuracy)
# Compute the mean token accuracy and log it
accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0
self._metrics["mean_token_accuracy"].append(accuracy)

return (loss, outputs) if return_outputs else loss

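For a rough sense of why gathering counts instead of logits matters here, the sketch below compares the per-step communication volume of the two approaches; the batch, sequence-length, and vocabulary sizes are illustrative assumptions, not values from this PR.

```python
import torch

# Assumed (illustrative) per-process shapes; not taken from the PR.
batch, seq_len, vocab = 8, 2048, 128_256

# Previous approach: gather the full shifted logits tensor from every process.
logits_bytes = batch * (seq_len - 1) * vocab * 2      # bf16 elements, roughly 4.2 GB per process

# New approach: gather two integer counts per process.
counts_bytes = 2 * torch.tensor(0).element_size()     # two int64 scalars, 16 bytes

print(f"{logits_bytes / 1e9:.1f} GB vs {counts_bytes} B gathered per process per step")
```

A blocking collective of that size on every training step is the kind of operation the PR title associates with hangs; two scalars are negligible by comparison.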
21 changes: 0 additions & 21 deletions trl/trainer/utils.py
@@ -1650,27 +1650,6 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
return mask, *tensors


def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
"""
Compute the mean token accuracy.
"""
# Get predictions
predictions = logits.argmax(dim=-1)

# Create mask for non-padding tokens (assuming pad_token_id is ignore_index)
mask = labels != ignore_index

# Calculate accuracy only on non-padding tokens
correct_predictions = (predictions == labels) & mask
total_tokens = mask.sum()
correct_tokens = correct_predictions.sum()

# Calculate accuracy
accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0

return accuracy


def selective_log_softmax(logits, index):
"""
A memory-efficient implementation of the common `log_softmax -> gather` operation.