Add special token to PRM vocabulary if not present #2646

Draft · wants to merge 6 commits into base: main
11 changes: 11 additions & 0 deletions examples/scripts/prm.py
@@ -58,6 +58,7 @@
get_peft_config,
get_quantization_config,
)
from trl.trainer.utils import is_token_in_vocab


if __name__ == "__main__":
@@ -89,6 +90,16 @@
# Align padding tokens between tokenizer and model
model.config.pad_token_id = tokenizer.pad_token_id

# Check if the step separator is in the vocabulary; if it's not, add it
if not is_token_in_vocab(tokenizer, training_args.step_separator):
    tokenizer.add_special_tokens({"additional_special_tokens": [training_args.step_separator]})
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=training_args.resize_to_multiple_of)

Comment on lines +93 to +102

Member: WDYT of having it inside the PRM trainer?

Collaborator (Author): makes sense, will update the code
if model_config.use_peft and model_config.lora_task_type != "TOKEN_CLS":
warnings.warn(
"You are using a `task_type` that is different than `TOKEN_CLS` for PEFT. This will lead to silent bugs"
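Not part of the diff: a rough sketch of what the reviewer's suggestion above (moving the check into the PRM trainer) could look like. The helper name `_maybe_add_step_separator` is hypothetical, and `args.step_separator` / `args.resize_to_multiple_of` are the `PRMConfig` fields added in this PR; treat it as a sketch, not the merged implementation.

from trl.trainer.utils import is_token_in_vocab


def _maybe_add_step_separator(model, tokenizer, args):
    """Hypothetical helper: add the step separator to the vocabulary and resize embeddings if it is missing."""
    if args.step_separator is not None and not is_token_in_vocab(tokenizer, args.step_separator):
        tokenizer.add_special_tokens({"additional_special_tokens": [args.step_separator]})
        model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=args.resize_to_multiple_of)

`PRMTrainer.__init__` could call such a helper before tokenizing the dataset, so the separator is guaranteed to map to a single token ID when the step rewards are aligned to it.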
13 changes: 13 additions & 0 deletions tests/test_utils.py
@@ -30,6 +30,7 @@
flush_left,
generate_model_card,
get_peft_config,
is_token_in_vocab,
pad,
)

@@ -451,3 +452,15 @@ def test_no_tensors(self):
expected_mask = torch.tensor([[1, 1, 1], [1, 1, 0]])

self.assertTrue(torch.equal(new_mask, expected_mask))


class TestIsTokenInVocab(unittest.TestCase):
    def setUp(self):
        self.tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

    def test_token_in_vocab(self):
        for token in ["<|im_start|>", "\n", "\n\n", "a"]:
            self.assertTrue(is_token_in_vocab(self.tokenizer, token))

    def test_token_not_in_vocab(self):
        self.assertFalse(is_token_in_vocab(self.tokenizer, "<step_token>"))
24 changes: 19 additions & 5 deletions trl/trainer/prm_config.py
@@ -39,8 +39,11 @@ class PRMConfig(TrainingArguments):
Maximum length of the completion used for truncation. The completion is the concatenation of the steps.
disable_dropout (`bool`, *optional*, defaults to `True`):
Whether to disable dropout in the model.
step_separator (`str`, *optional*, defaults to `"\n"`):
Separator used to separate each step of the reasoning process.
step_separator (`str`, *optional*, defaults to `"<|step_token|>"`):
Separator used to separate each step of the reasoning process. It is also used as the token at which the
reward for each step is predicted. If the token is not in the vocabulary, it will be added.
resize_to_multiple_of (`int`, *optional*, defaults to `64`):
Resize the token embeddings to a multiple of this value. Only takes effect if `step_separator` is not `None`.
train_on_last_step_only (`bool`, *optional*, defaults to `False`):
Whether to train only on the last step.
dataset_num_proc (`int`, *optional*, defaults to `None`):
@@ -73,9 +76,20 @@ class PRMConfig(TrainingArguments):
default=True,
metadata={"help": "Whether to disable dropout in the model and reference model."},
)
step_separator: str = field(
default="\n",
metadata={"help": "Separator used to separate each step of the reasoning process."},
step_separator: Optional[str] = field(
default="<|step_token|>",
metadata={
"help": (
"Separator used to separate each step of the reasoning process. It is also used as the token at which "
"the reward for each step is predicted. If the token is not in the vocabulary, it will be added."
)
},
)
resize_to_multiple_of: Optional[int] = field(
default=64,
metadata={
"help": "Resize the token embeddings to a multiple of this value. Only takes effect if `step_separator` is not `None`."
},
)
train_on_last_step_only: bool = field(
default=False,
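For illustration (not part of the diff): a minimal sketch of how the two new config fields behave at runtime, assuming this PR's `PRMConfig`. The checkpoint is the tiny test model used in `tests/test_utils.py`, and its token-classification head is freshly initialized here.

from transformers import AutoModelForTokenClassification, AutoTokenizer

from trl import PRMConfig
from trl.trainer.utils import is_token_in_vocab

training_args = PRMConfig(output_dir="prm-out", step_separator="<|step_token|>", resize_to_multiple_of=64)

model_id = "trl-internal-testing/tiny-Qwen2ForCausalLM-2.5"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForTokenClassification.from_pretrained(model_id, num_labels=2)

if not is_token_in_vocab(tokenizer, training_args.step_separator):
    tokenizer.add_special_tokens({"additional_special_tokens": [training_args.step_separator]})
    # The embedding matrix grows to fit the new token and is padded up to a multiple of 64
    model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=training_args.resize_to_multiple_of)

print(len(tokenizer), model.get_input_embeddings().num_embeddings)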
19 changes: 19 additions & 0 deletions trl/trainer/utils.py
@@ -1647,3 +1647,22 @@ def flush_left(mask: torch.Tensor, *tensors: torch.Tensor) -> tuple[torch.Tensor
return mask
else:
return mask, *tensors


def is_token_in_vocab(tokenizer: PreTrainedTokenizerBase, token: str) -> bool:
    """Check whether `token` exists as a single token in the tokenizer's vocabulary."""
    # Try a direct vocabulary lookup
    if token in tokenizer.get_vocab():
        return True

    # Try the escaped representation: repr() handles special characters such as "\n", and the slicing strips the
    # surrounding quotes
    escaped = repr(token)[1:-1]
    if escaped in tokenizer.get_vocab():
        return True

    # Fall back to checking whether the token encodes to a single ID
    try:
        token_ids = tokenizer.encode(token, add_special_tokens=False)
        return len(token_ids) == 1
    except Exception:
        return False
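
A short usage sketch of the new helper (not part of the diff), mirroring the tests above. The escaped-repr() and single-ID fallbacks exist for tokens such as "\n", which byte-level tokenizers store under a different surface form.

from transformers import AutoTokenizer

from trl.trainer.utils import is_token_in_vocab

tokenizer = AutoTokenizer.from_pretrained("trl-internal-testing/tiny-Qwen2ForCausalLM-2.5")

print(is_token_in_vocab(tokenizer, "<|im_start|>"))  # True: an existing special token
print(is_token_in_vocab(tokenizer, "\n"))            # True: resolved via one of the fallback checks
print(is_token_in_vocab(tokenizer, "<step_token>"))  # False: would need add_special_tokens before training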