diff --git a/tests/test_utils.py b/tests/test_utils.py
index 64ab68bf74..871ec6c737 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -27,7 +27,6 @@
 from trl.trainer.utils import (
     DataCollatorForChatML,
     batch_generation,
-    compute_token_accuracy,
     decode_and_strip_padding,
     flush_left,
     generate_model_card,
@@ -456,60 +455,6 @@ def test_no_tensors(self):
         self.assertTrue(torch.equal(new_mask, expected_mask))
 
 
-class TestComputeTokenAccuracy(unittest.TestCase):
-    def test_basic_accuracy(self):
-        # Test basic accuracy computation
-        logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]])  # Shape: [2, 2, 2]
-        labels = torch.tensor([[1, 0], [1, 0]])  # Shape: [2, 2]
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertAlmostEqual(accuracy, 0.75)  # 3 correct out of 4 tokens
-
-    def test_with_ignore_index(self):
-        # Test accuracy computation with ignored tokens
-        logits = torch.tensor([[[0.9, 0.1], [0.8, 0.2]], [[0.3, 0.7], [0.6, 0.4]]])
-        labels = torch.tensor([[1, -100], [1, 0]])  # -100 is ignored
-        accuracy = compute_token_accuracy(logits, labels, ignore_index=-100)
-        self.assertAlmostEqual(accuracy, 2 / 3)  # 2 correct out of 3 non-ignored tokens
-
-    def test_all_ignored(self):
-        # Test case where all tokens are ignored
-        logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
-        labels = torch.tensor([[-100, -100]])
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertEqual(accuracy, 0.0)  # No valid tokens to compute accuracy
-
-    def test_perfect_accuracy(self):
-        # Test case with 100% accuracy
-        logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
-        labels = torch.tensor([[1, 0]])
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertEqual(accuracy, 1.0)  # All predictions correct
-
-    def test_zero_accuracy(self):
-        # Test case with 0% accuracy
-        logits = torch.tensor([[[0.1, 0.9], [0.8, 0.2]]])
-        labels = torch.tensor([[0, 1]])
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertEqual(accuracy, 0.0)  # All predictions wrong
-
-    def test_batch_accuracy(self):
-        # Test accuracy computation across multiple batches
-        logits = torch.tensor(
-            [
-                [[0.9, 0.1], [0.8, 0.2], [0.3, 0.7]],  # Batch 1
-                [[0.2, 0.8], [0.7, 0.3], [0.6, 0.4]],  # Batch 2
-            ]
-        )
-        labels = torch.tensor(
-            [
-                [1, 0, 1],  # Batch 1
-                [1, 0, -100],  # Batch 2 (last token ignored)
-            ]
-        )
-        accuracy = compute_token_accuracy(logits, labels)
-        self.assertAlmostEqual(accuracy, 0.8)
-
-
 class TestSelectiveLogSoftmax(unittest.TestCase):
     @parameterized.expand([(torch.float64,), (torch.float32,), (torch.float16,), (torch.bfloat16,)])
     def test_selective_log_softmax(self, dtype):
diff --git a/trl/__init__.py b/trl/__init__.py
index 188b05056d..9a17e8e873 100644
--- a/trl/__init__.py
+++ b/trl/__init__.py
@@ -102,7 +102,7 @@
         "XPOTrainer",
     ],
     "trainer.callbacks": ["MergeModelCallback", "RichProgressCallback", "SyncRefModelCallback"],
-    "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config", "compute_token_accuracy"],
+    "trainer.utils": ["get_kbit_device_map", "get_peft_config", "get_quantization_config"],
 }
 
 try:
@@ -204,7 +204,7 @@
         XPOTrainer,
     )
     from .trainer.callbacks import RichProgressCallback, SyncRefModelCallback
-    from .trainer.utils import compute_token_accuracy, get_kbit_device_map, get_peft_config, get_quantization_config
+    from .trainer.utils import get_kbit_device_map, get_peft_config, get_quantization_config
 
     try:
         if not is_diffusers_available():
diff --git a/trl/trainer/__init__.py b/trl/trainer/__init__.py
index 9ef887864a..85968218cc 100644
--- a/trl/trainer/__init__.py
+++ b/trl/trainer/__init__.py
@@ -76,7 +76,6 @@
         "disable_dropout_in_model",
         "empty_cache",
         "peft_module_casting_to_bf16",
-        "compute_token_accuracy",
     ],
     "xpo_config": ["XPOConfig"],
     "xpo_trainer": ["XPOTrainer"],
@@ -145,7 +144,6 @@
         DataCollatorForCompletionOnlyLM,
         RunningMoments,
         compute_accuracy,
-        compute_token_accuracy,
         disable_dropout_in_model,
         empty_cache,
         peft_module_casting_to_bf16,
diff --git a/trl/trainer/sft_trainer.py b/trl/trainer/sft_trainer.py
index 510ceb5348..a3d6829fd4 100644
--- a/trl/trainer/sft_trainer.py
+++ b/trl/trainer/sft_trainer.py
@@ -45,13 +45,7 @@
 
 from ..data_utils import is_conversational, maybe_apply_chat_template, maybe_convert_to_chatml, pack_examples
 from .sft_config import SFTConfig
-from .utils import (
-    ConstantLengthDataset,
-    compute_token_accuracy,
-    generate_model_card,
-    get_comet_experiment_url,
-    peft_module_casting_to_bf16,
-)
+from .utils import ConstantLengthDataset, generate_model_card, get_comet_experiment_url, peft_module_casting_to_bf16
 
 
 if is_peft_available():
@@ -60,8 +54,6 @@
 
 if is_liger_kernel_available():
     from liger_kernel.transformers import AutoLigerKernelForCausalLM
-else:
-    AutoLigerKernelForCausalLM = None
 
 if is_wandb_available():
     import wandb
@@ -185,12 +177,13 @@ def __init__(
             )
         if isinstance(model, str):
             model = self._create_model_from_path(model, args)
+        self.use_liger = is_liger_kernel_available() and isinstance(model, AutoLigerKernelForCausalLM)
 
         # PEFT configuration and model wrapping
         if peft_config is not None:
             model = self._prepare_peft_model(model, peft_config, args)
 
-        # 3. Handle the tokenizer
+        # Handle the tokenizer
         if processing_class is None:
             processing_class = AutoTokenizer.from_pretrained(model.config._name_or_path)
             if processing_class.pad_token is None:
@@ -279,8 +272,10 @@ def _create_model_from_path(self, model_path: str, args: SFTConfig) -> PreTraine
         if args.use_liger:
             if not is_liger_kernel_available():
                 raise ImportError("Please install Liger-kernel for use_liger=True")
-            return AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
-        return AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
+            model = AutoLigerKernelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
+        else:
+            model = AutoModelForCausalLM.from_pretrained(model_path, **model_init_kwargs)
+        return model
 
     def _prepare_peft_model(self, model: PreTrainedModel, peft_config: Any, args: SFTConfig) -> PreTrainedModel:
         """Prepares a model for PEFT training."""
@@ -471,21 +466,28 @@ def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=N
         )
 
         # Compute token accuracy if we have labels and if the model is not using Liger (no logits)
-        use_liger = self.args.use_liger or (
-            AutoLigerKernelForCausalLM is not None and isinstance(model, AutoLigerKernelForCausalLM)
-        )
-        if "labels" in inputs and not use_liger:
+        if "labels" in inputs and not self.use_liger:
             shift_logits = outputs.logits[..., :-1, :].contiguous()
             shift_labels = inputs["labels"][..., 1:].contiguous()
 
-            # Gather logits and labels from all GPUs first
-            shift_logits = self.accelerator.gather_for_metrics(shift_logits)
-            shift_labels = self.accelerator.gather_for_metrics(shift_labels)
+            # Get predictions
+            predictions = shift_logits.argmax(dim=-1)
+
+            # Create mask for non-padding tokens (assuming ignore_index is -100)
+            mask = shift_labels != -100
+
+            # Calculate accuracy only on non-padding tokens
+            correct_predictions = (predictions == shift_labels) & mask
+            total_tokens = mask.sum()
+            correct_tokens = correct_predictions.sum()
+
+            # Gather the correct_tokens and total_tokens across all processes
+            correct_tokens = self.accelerator.gather_for_metrics(correct_tokens)
+            total_tokens = self.accelerator.gather_for_metrics(total_tokens)
 
-            # Then compute accuracy on the gathered tensors
-            if self.accelerator.is_main_process:
-                accuracy = compute_token_accuracy(shift_logits, shift_labels)
-                self._metrics["mean_token_accuracy"].append(accuracy)
+            # Compute the mean token accuracy and log it
+            accuracy = (correct_tokens.sum() / total_tokens.sum()).item() if total_tokens.sum() > 0 else 0.0
+            self._metrics["mean_token_accuracy"].append(accuracy)
 
         return (loss, outputs) if return_outputs else loss
 
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index ea603b9637..7a20645535 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -1650,27 +1650,6 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
     return mask, *tensors
 
 
-def compute_token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
-    """
-    Compute the mean token accuracy.
-    """
-    # Get predictions
-    predictions = logits.argmax(dim=-1)
-
-    # Create mask for non-padding tokens (assuming pad_token_id is ignore_index)
-    mask = labels != ignore_index
-
-    # Calculate accuracy only on non-padding tokens
-    correct_predictions = (predictions == labels) & mask
-    total_tokens = mask.sum()
-    correct_tokens = correct_predictions.sum()
-
-    # Calculate accuracy
-    accuracy = correct_tokens.item() / total_tokens.item() if total_tokens > 0 else 0.0
-
-    return accuracy
-
-
 def selective_log_softmax(logits, index):
     """
     A memory-efficient implementation of the common `log_softmax -> gather` operation.
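For reference, here is a minimal single-process sketch of the token-accuracy logic this patch inlines into `SFTTrainer.compute_loss`. The helper name `token_accuracy` and the toy tensors are illustrative only, not part of TRL; the in-trainer version additionally reduces the scalar `correct_tokens`/`total_tokens` counts across processes with `accelerator.gather_for_metrics` before dividing, rather than gathering the full `[batch, seq_len, vocab_size]` logits as the removed implementation did:

```python
import torch


def token_accuracy(logits: torch.Tensor, labels: torch.Tensor, ignore_index: int = -100) -> float:
    """Mean next-token accuracy over non-ignored positions.

    logits: [batch, seq_len, vocab_size]; labels: [batch, seq_len].
    Illustrative stand-in for the logic inlined into SFTTrainer.compute_loss.
    """
    # Shift so that the logits at position t are scored against the label at t + 1
    shift_logits = logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()

    # Greedy predictions, and a mask that excludes ignored (e.g. padding) positions
    predictions = shift_logits.argmax(dim=-1)
    mask = shift_labels != ignore_index

    # Count correct predictions only where the mask is True
    correct_tokens = ((predictions == shift_labels) & mask).sum()
    total_tokens = mask.sum()
    return (correct_tokens / total_tokens).item() if total_tokens > 0 else 0.0


# Toy check: batch of 1, sequence of 3 tokens, vocabulary of 2.
logits = torch.tensor([[[0.9, 0.1], [0.2, 0.8], [0.7, 0.3]]])
labels = torch.tensor([[1, 0, -100]])  # final position ignored
print(token_accuracy(logits, labels))  # 1.0: the only scored position is correct
```

Gathering two scalar counts per step instead of per-token logits also means every process computes the same ratio, so the metric no longer needs the old `is_main_process` guard before being appended to `self._metrics`.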