diff --git a/tests/utils/test_merge_model_utils.py b/tests/utils/test_merge_model_utils.py
index e6b5c2687..b82665f0f 100644
--- a/tests/utils/test_merge_model_utils.py
+++ b/tests/utils/test_merge_model_utils.py
@@ -56,7 +56,7 @@ def test_post_process_vLLM_adapters_new_tokens():
     # do the post processing
     with tempfile.TemporaryDirectory() as tempdir:
         post_process_vLLM_adapters_new_tokens(
-            DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, tempdir
+            DUMMY_TUNED_LLAMA_WITH_ADDED_TOKENS, tempdir, num_added_tokens=1
         )
 
         # check that new_embeddings.safetensors exist
diff --git a/tuning/data/tokenizer_data_utils.py b/tuning/data/tokenizer_data_utils.py
index 1edca167a..c16e89fc1 100644
--- a/tuning/data/tokenizer_data_utils.py
+++ b/tuning/data/tokenizer_data_utils.py
@@ -25,7 +25,7 @@ def tokenizer_and_embedding_resize(
     tokenizer: transformers.PreTrainedTokenizer,
     model: transformers.PreTrainedModel,
     multiple_of: int = 1,
-):
+):
     """Resize tokenizer and embedding."""
     num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
     embedding_size = int(multiple_of * math.ceil(len(tokenizer) / multiple_of))
@@ -44,3 +44,4 @@ def tokenizer_and_embedding_resize(
 
         input_embeddings[-num_new_tokens:] = input_embeddings_avg
         output_embeddings[-num_new_tokens:] = output_embeddings_avg
+    return num_new_tokens
diff --git a/tuning/sft_trainer.py b/tuning/sft_trainer.py
index 2463ccadc..530aa10d4 100644
--- a/tuning/sft_trainer.py
+++ b/tuning/sft_trainer.py
@@ -266,7 +266,7 @@ def train(
 
     # TODO: lower priority but understand if resizing impacts inference quality and why its needed.
     # It makes sense if we manipulate tokenizer that we also save it and provide it to inference.
-    tokenizer_data_utils.tokenizer_and_embedding_resize(
+    num_added_tokens = tokenizer_data_utils.tokenizer_and_embedding_resize(
         special_tokens_dict=special_tokens_dict,
         tokenizer=tokenizer,
         model=model,
@@ -387,7 +387,7 @@ def train(
 
     trainer.train(resume_from_checkpoint)
 
-    return trainer
+    return trainer, num_added_tokens
 
 
 def save(path: str, trainer: SFTTrainer, log_level="WARNING"):
@@ -611,7 +611,7 @@ def main():
         combined_tracker_configs.aim_config = aim_config
 
     try:
-        trainer = train(
+        trainer, num_added_tokens = train(
             model_args=model_args,
             data_args=data_args,
             train_args=training_args,
@@ -669,7 +669,7 @@ def main():
                 checkpoint_dir = training_args.save_model_dir
             if checkpoint_dir:
                 print(f"Post processing LoRA adapters in {checkpoint_dir}")
-                post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir)
+                post_process_vLLM_adapters_new_tokens(path_to_checkpoint=checkpoint_dir, num_added_tokens=num_added_tokens)
     except Exception as e:  # pylint: disable=broad-except
         logging.error(traceback.format_exc())
         write_termination_log(
diff --git a/tuning/utils/merge_model_utils.py b/tuning/utils/merge_model_utils.py
index 3ab39095d..984ed8fbb 100644
--- a/tuning/utils/merge_model_utils.py
+++ b/tuning/utils/merge_model_utils.py
@@ -120,21 +120,12 @@ def _copy_files_to_directory(src: str, dest: str, exclude_files: list[str] = None):
 
 
 def post_process_vLLM_adapters_new_tokens(
-    path_to_checkpoint: str, modified_checkpoint_path: str = None
+    path_to_checkpoint: str, modified_checkpoint_path: str = None, num_added_tokens: int = 0
 ):
     # if not set, original checkpoint will be modified
     if not modified_checkpoint_path:
         modified_checkpoint_path = path_to_checkpoint
 
-    # Get all values of new token indexes
-    sorted_token_indexes = []
-    if os.path.isfile(os.path.join(path_to_checkpoint, "added_tokens.json")):
-        with open(
-            os.path.join(path_to_checkpoint, "added_tokens.json"), "r", encoding="utf-8"
-        ) as fp:
-            added_tokens = json.load(fp)
-            sorted_token_indexes = sorted(added_tokens.values())
-
     with safe_open(
         os.path.join(path_to_checkpoint, "adapter_model.safetensors"), framework="pt"
     ) as f:
@@ -145,7 +136,7 @@ def post_process_vLLM_adapters_new_tokens(
         for k in f.keys():
             if "lm_head.weight" in k or "embed_tokens.weight" in k:
                 embeddings_weights_in_adapters = True
-                if len(sorted_token_indexes) < 1:
+                if num_added_tokens == 0:
                     raise NotImplementedError(
                         "Seems like embeddings are resized without adding new tokens. \
                         Cannot be post-processed to load on vLLM. Try setting \
@@ -158,28 +149,18 @@ def post_process_vLLM_adapters_new_tokens(
             if "lm_head.weight" in k:
                 lm_head = f.get_tensor(k)
                 # pull out tensor values of new tokens
-                if len(sorted_token_indexes) == 1:
-                    new_output_embeddings = lm_head[
-                        sorted_token_indexes[0] : sorted_token_indexes[0] + 1
-                    ]
-                elif len(sorted_token_indexes) > 1:
-                    new_output_embeddings = lm_head[
-                        sorted_token_indexes[0] : sorted_token_indexes[-1]
-                    ]
+                new_output_embeddings = lm_head[-num_added_tokens:]
                 # vLLM requires renaming to output_embeddings
                 new_embeddings["output_embeddings"] = new_output_embeddings
             elif "embed_tokens.weight" in k:
                 embed_tokens = f.get_tensor(k)
                 # pull out tensor values of new tokens
-                if len(sorted_token_indexes) == 1:
-                    new_input_embeddings = embed_tokens[
-                        sorted_token_indexes[0] : sorted_token_indexes[0] + 1
-                    ]
-                elif len(sorted_token_indexes) > 1:
-                    new_input_embeddings = embed_tokens[
-                        sorted_token_indexes[0] : sorted_token_indexes[-1]
-                    ]
+                new_input_embeddings = embed_tokens[-num_added_tokens:]
                 # vLLM requires renaming to input_embeddings
                 new_embeddings["input_embeddings"] = new_input_embeddings
             else: