
fix: get num_added_tokens from resize function (#344)
* get num_added_tokens

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

* remove extra code

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>

---------

Signed-off-by: Sukriti-Sharma4 <sukriti.sharma4@ibm.com>
Ssukriti committed Sep 19, 2024
1 parent 57cadc3 commit 36a554c
Showing 4 changed files with 15 additions and 33 deletions.
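At a glance, the change threads the count of newly added special tokens from the tokenizer-resize helper back up to the CLI entry point, so the vLLM post-processing no longer has to re-derive it from added_tokens.json. Below is a minimal, self-contained sketch of that call flow. The function names mirror the repo, but the bodies, signatures, and arguments are toy stand-ins and placeholders, not the real implementations.

# Sketch only: names mirror tuning/data/tokenizer_data_utils.py,
# tuning/sft_trainer.py and tuning/utils/merge_model_utils.py,
# but bodies and signatures are simplified stand-ins.

def tokenizer_and_embedding_resize(special_tokens_dict, tokenizer=None, model=None, multiple_of=1):
    # The real helper calls tokenizer.add_special_tokens() and resizes the model
    # embeddings; here we only model its new return value.
    num_new_tokens = len(special_tokens_dict.get("additional_special_tokens", []))
    return num_new_tokens  # new in this commit: report how many tokens were added


def train(special_tokens_dict, tokenizer=None, model=None):
    num_added_tokens = tokenizer_and_embedding_resize(special_tokens_dict, tokenizer, model)
    trainer = object()  # placeholder for the SFTTrainer built by the real function
    return trainer, num_added_tokens  # new in this commit: tuple instead of bare trainer


def post_process_vLLM_adapters_new_tokens(path_to_checkpoint, modified_checkpoint_path=None, num_added_tokens=0):
    # The real function slices the last num_added_tokens rows out of the adapter's
    # lm_head / embed_tokens weights so vLLM can load them.
    print(f"post-processing {path_to_checkpoint} with {num_added_tokens} new token(s)")


trainer, num_added_tokens = train({"additional_special_tokens": ["<pad>"]})
post_process_vLLM_adapters_new_tokens("path/to/checkpoint", num_added_tokens=num_added_tokens)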
tests/utils/test_merge_model_utils.py (2 changes: 1 addition & 1 deletion)
@@ -56,7 +56,7 @@ def test_post_process_vLLM_adapters_new_tokens():
     # do the post processing
     with tempfile.TemporaryDirectory() as tempdir:
         post_process_vLLM_adapters_new_tokens(
-            DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, tempdir
+            DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, tempdir, num_added_tokens=1
         )

     # check that new_embeddings.safetensors exist
tuning/data/tokenizer_data_utils.py (3 changes: 2 additions & 1 deletion)
@@ -25,7 +25,7 @@ def tokenizer_and_embedding_resize(
     tokenizer: transformers.PreTrainedTokenizer,
     model: transformers.PreTrainedModel,
     multiple_of: int = 1,
-):
+):
     """Resize tokenizer and embedding."""
     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
     embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
@@ -44,3 +44,4 @@ def tokenizer_and_embedding_resize(

         input_embeddings[-num_new_tokens:] = input_embeddings_avg
         output_embeddings[-num_new_tokens:] = output_embeddings_avg
+    return num_new_tokens
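Two numbers are easy to conflate here: the embedding matrix is padded up to a multiple of multiple_of, while the function now returns only the count of genuinely new special tokens. A quick worked example with hypothetical numbers (a 32,000-entry tokenizer gaining 5 tokens, multiple_of=8):

import math

old_vocab_size = 32000   # hypothetical starting tokenizer length
num_new_tokens = 5       # what tokenizer.add_special_tokens(...) would return
multiple_of = 8

new_len = old_vocab_size + num_new_tokens                              # 32005 tokenizer entries
embedding_size = int(multiple_of * math.ceil(new_len / multiple_of))   # 32008 embedding rows
print(num_new_tokens, embedding_size)                                  # -> 5 32008

The caller receives 5 (the number of added tokens), not 8 (the padded delta); that count is what the vLLM post-processing below slices by.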
tuning/sft_trainer.py (8 changes: 4 additions & 4 deletions)
@@ -266,7 +266,7 @@ def train(

     # TODO: lower priority but understand if resizing impacts inference quality and why its needed.
     # It makes sense if we manipulate tokenizer that we also save it and provide it to inference.
-    tokenizer_data_utils.tokenizer_and_embedding_resize(
+    num_added_tokens = tokenizer_data_utils.tokenizer_and_embedding_resize(
         special_tokens_dict=special_tokens_dict,
         tokenizer=tokenizer,
         model=model,
@@ -387,7 +387,7 @@ def train(

     trainer.train(resume_from_checkpoint)

-    return trainer
+    return trainer, num_added_tokens


 def save(path: str, trainer: SFTTrainer, log_level="WARNING"):
@@ -611,7 +611,7 @@ def main():
         combined_tracker_configs.aim_config = aim_config

     try:
-        trainer = train(
+        trainer, num_added_tokens = train(
             model_args=model_args,
             data_args=data_args,
             train_args=training_args,
@@ -669,7 +669,7 @@ def main():
             checkpoint_dir = training_args.save_model_dir
             if checkpoint_dir:
                 print(f"Post processing LoRA adapters in {checkpoint_dir}")
-                post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir)
+                post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir, num_added_tokens=num_added_tokens)
     except Exception as e:  # pylint: disable=broad-except
         logging.error(traceback.format_exc())
         write_termination_log(
tuning/utils/merge_model_utils.py (35 changes: 8 additions & 27 deletions)
@@ -120,21 +120,12 @@ def _copy_files_to_directory(src: str, dest: str, exclude_files: list[str] = None


 def post_process_vLLM_adapters_new_tokens(
-    path_to_checkpoint: str, modified_checkpoint_path: str = None
+    path_to_checkpoint: str, modified_checkpoint_path: str = None, num_added_tokens: int = 0
 ):
     # if not set, original checkpoint will be modified
     if not modified_checkpoint_path:
         modified_checkpoint_path = path_to_checkpoint

-    # Get all values of new token indexes
-    sorted_token_indexes = []
-    if os.path.isfile(os.path.join(path_to_checkpoint, "added_tokens.json")):
-        with open(
-            os.path.join(path_to_checkpoint, "added_tokens.json"), "r", encoding="utf-8"
-        ) as fp:
-            added_tokens = json.load(fp)
-            sorted_token_indexes = sorted(added_tokens.values())
-
     with safe_open(
         os.path.join(path_to_checkpoint, "adapter_model.safetensors"), framework="pt"
     ) as f:
@@ -145,7 +136,7 @@ def post_process_vLLM_adapters_new_tokens(
         for k in f.keys():
             if "lm_head.weight" in k or "embed_tokens.weight" in k:
                 embeddings_weights_in_adapters = True
-                if len(sorted_token_indexes) < 1:
+                if num_added_tokens == 0:
                     raise NotImplementedError(
                         "Seems like embeddings are resized without adding new tokens. \
                         Cannot be post-processed to load on vLLM. Try setting \
@@ -158,28 +149,18 @@ def post_process_vLLM_adapters_new_tokens(
             if "lm_head.weight" in k:
                 lm_head = f.get_tensor(k)
                 # pull out tensor values of new tokens
-                if len(sorted_token_indexes) == 1:
-                    new_output_embeddings = lm_head[
-                        sorted_token_indexes[0] : sorted_token_indexes[0] + 1
-                    ]
-                elif len(sorted_token_indexes) > 1:
-                    new_output_embeddings = lm_head[
-                        sorted_token_indexes[0] : sorted_token_indexes[-1]
-                    ]
+                new_output_embeddings = lm_head[
+                    -num_added_tokens :
+                ]
                 # vLLM requires renaming to output_embeddings
                 new_embeddings["output_embeddings"] = new_output_embeddings

             elif "embed_tokens.weight" in k:
                 embed_tokens = f.get_tensor(k)
                 # pull out tensor values of new tokens
-                if len(sorted_token_indexes) == 1:
-                    new_input_embeddings = embed_tokens[
-                        sorted_token_indexes[0] : sorted_token_indexes[0] + 1
-                    ]
-                elif len(sorted_token_indexes) > 1:
-                    new_input_embeddings = embed_tokens[
-                        sorted_token_indexes[0] : sorted_token_indexes[-1]
-                    ]
+                new_input_embeddings = embed_tokens[
+                    -num_added_tokens :
+                ]
                 # vLLM requires renaming to input_embeddings
                 new_embeddings["input_embeddings"] = new_input_embeddings
             else:
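The removed bookkeeping read token indices out of added_tokens.json; the new code instead assumes the added tokens occupy the final rows of the saved embedding weights and takes them with a negative slice. A toy illustration of that slice, using a stand-in tensor in place of the adapter's lm_head / embed_tokens weight:

import torch

num_added_tokens = 2

# Stand-in weight matrix: 10 "tokens" with 4-dimensional embeddings.
weight = torch.arange(40, dtype=torch.float32).reshape(10, 4)

new_rows = weight[-num_added_tokens:]   # rows 8 and 9, i.e. the appended tokens
print(new_rows.shape)                   # torch.Size([2, 4])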
