diff --git a/examples/scripts/prm.py b/examples/scripts/prm.py
index df3ab1702f..e363f594cb 100644
--- a/examples/scripts/prm.py
+++ b/examples/scripts/prm.py
@@ -58,6 +58,7 @@
     get_peft_config,
     get_quantization_config,
 )
+from trl.trainer.utils import is_token_in_vocab
 
 
 if __name__ == "__main__":
@@ -89,6 +90,11 @@
     # Align padding tokens between tokenizer and model
     model.config.pad_token_id = tokenizer.pad_token_id
 
+    # Check if the step separator is in the vocabulary; if it's not, add it
+    if not is_token_in_vocab(tokenizer, training_args.step_separator):
+        tokenizer.add_special_tokens({"additional_special_tokens": [training_args.step_separator]})
+        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=training_args.resize_to_multiple_of)
+
     if model_config.use_peft and model_config.lora_task_type != "TOKEN_CLS":
         warnings.warn(
             "You are using a `task_type` that is different than `TOKEN_CLS` for PEFT. This will lead to silent bugs"
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 0061dd5e5e..90c42c6a34 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -30,6 +30,7 @@
     flush_left,
     generate_model_card,
     get_peft_config,
+    is_token_in_vocab,
     pad,
 )
 
@@ -451,3 +452,15 @@ def test_no_tensors(self):
         new_mask = flush_left(mask)
         expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])
         self.assertTrue(torch.equal(new_mask, expected_mask))
+
+
+class TestIsTokenInVocab(unittest.TestCase):
+    def setUp(self):
+        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")
+
+    def test_token_in_vocab(self):
+        for token in ["<|im_start|>", "\n", "\n\n", "a"]:
+            self.assertTrue(is_token_in_vocab(self.tokenizer, token))
+
+    def test_token_not_in_vocab(self):
+        self.assertFalse(is_token_in_vocab(self.tokenizer, ""))
diff --git a/trl/trainer/prm_config.py b/trl/trainer/prm_config.py
index ddff854859..1152d61e55 100644
--- a/trl/trainer/prm_config.py
+++ b/trl/trainer/prm_config.py
@@ -39,8 +39,11 @@ class PRMConfig(TrainingArguments):
             Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
         disable_dropout (`bool`, *optional*, defaults to `True`):
             Whether to disable dropout in the model.
-        step_separator (`str`, *optional*, defaults to `"\n"`):
-            Separator used to separate each step of the reasoning process.
+        step_separator (`str` or `None`, *optional*, defaults to `"<|step_token|>"`):
+            Separator used to separate each step of the reasoning process. It is also used as the token at which the
+            reward for each step is predicted. If the token is not in the vocabulary, it is added.
+        resize_to_multiple_of (`int` or `None`, *optional*, defaults to `64`):
+            Resize the token embeddings to a multiple of this value if the step separator is added.
         train_on_last_step_only (`bool`, *optional*, defaults to `False`):
             Whether to train only on the last step.
         dataset_num_proc (`int`, *optional*, defaults to `None`):
@@ -73,9 +76,20 @@ class PRMConfig(TrainingArguments):
         default=True,
         metadata={"help": "Whether to disable dropout in the model and reference model."},
     )
-    step_separator: str = field(
-        default="\n",
-        metadata={"help": "Separator used to separate each step of the reasoning process."},
+    step_separator: Optional[str] = field(
+        default="<|step_token|>",
+        metadata={
+            "help": (
+                "Separator used to separate each step of the reasoning process. It is also used as the token at "
+                "which the reward for each step is predicted. If the token is not in the vocabulary, it is added."
+            )
+        },
+    )
+    resize_to_multiple_of: Optional[int] = field(
+        default=64,
+        metadata={
+            "help": "Resize the token embeddings to a multiple of this value if the step separator is added."
+        },
     )
     train_on_last_step_only: bool = field(
         default=False,
diff --git a/trl/trainer/utils.py b/trl/trainer/utils.py
index 719d952f1f..bcc1c3b6b5 100644
--- a/trl/trainer/utils.py
+++ b/trl/trainer/utils.py
@@ -1647,3 +1647,22 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
         return mask
     else:
         return mask, *tensors
+
+
+def is_token_in_vocab(tokenizer: PreTrainedTokenizerBase, token: str) -> bool:
+    """Check whether `token` maps to a single entry in the tokenizer's vocabulary."""
+    # Try a direct vocabulary lookup
+    if token in tokenizer.get_vocab():
+        return True
+
+    # Try the escaped representation; repr() handles special characters and [1:-1] strips the surrounding quotes
+    escaped = repr(token)[1:-1]
+    if escaped in tokenizer.get_vocab():
+        return True
+
+    # Fall back to checking whether the token encodes to a single ID
+    try:
+        token_ids = tokenizer.encode(token, add_special_tokens=False)
+        return len(token_ids) == 1
+    except Exception:
+        return False
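
Usage sketch (not part of the diff): the snippet below mirrors what the new block in `examples/scripts/prm.py` does with the `is_token_in_vocab` helper added in `trl/trainer/utils.py`. The checkpoint name, the hard-coded separator, and the multiple of 64 are illustrative placeholders, assuming a token-classification PRM backbone.

```python
# Minimal sketch of the vocabulary check + embedding resize flow introduced above.
# The checkpoint name, separator, and multiple are illustrative, not prescribed by the diff.
from transformers import AutoModelForTokenClassification, AutoTokenizer

from trl.trainer.utils import is_token_in_vocab

checkpoint = "Qwen/Qwen2.5-0.5B"  # placeholder backbone
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForTokenClassification.from_pretrained(checkpoint, num_labels=2)

step_separator = "<|step_token|>"  # default from the updated PRMConfig
if not is_token_in_vocab(tokenizer, step_separator):
    # Register the separator as a special token and grow the embedding matrix,
    # padding its size to a multiple of 64 (the default `resize_to_multiple_of`).
    tokenizer.add_special_tokens({"additional_special_tokens": [step_separator]})
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64)

print(tokenizer.convert_tokens_to_ids(step_separator))  # now resolves to a single ID
```

Padding the embedding size to a multiple of 64 is purely a throughput optimization (better tensor-core utilization); it does not change tokenization.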