diff --git a/tools/convert2llama.py b/tools/convert2llama.py
index 14629244..de65b588 100644
--- a/tools/convert2llama.py
+++ b/tools/convert2llama.py
@@ -60,7 +60,7 @@ def convert(src, tgt):
 
     head_dim = config.hidden_size // config.num_attention_heads
     num_key_value_groups = config.num_attention_heads \
-        // config.num_key_value_heads
+        // config.num_key_value_heads
 
     # load index json file
     index_file = 'pytorch_model.bin.index.json'
@@ -138,7 +138,11 @@ def convert(src, tgt):
 
             index_dict['weight_map'][k] = filename
         print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
-        torch.save(llama_states, os.path.join(tgt, filename))
+        if filename.endswith('.safetensors'):
+            from safetensors.torch import save_file
+            save_file(llama_states, os.path.join(tgt, filename), metadata={"format": "pt"})
+        else:
+            torch.save(llama_states, os.path.join(tgt, filename))
         del states
 
     print('Saving config and tokenizer...', flush=True)
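For reference, a minimal sketch of the round trip this branch enables, assuming the safetensors package is installed; the state dict and filename below are hypothetical stand-ins for the shards that convert() actually writes:

```python
import torch
from safetensors.torch import load_file, save_file

# Hypothetical shard; stands in for the llama_states dict built by convert().
states = {'model.embed_tokens.weight': torch.randn(8, 4)}
filename = 'model-00001-of-00001.safetensors'  # hypothetical output name

# Mirrors the diff: safetensors shards go through save_file(),
# everything else keeps the torch.save() path.
if filename.endswith('.safetensors'):
    save_file(states, filename, metadata={"format": "pt"})
    reloaded = load_file(filename)
else:
    torch.save(states, filename)
    reloaded = torch.load(filename)

assert torch.equal(states['model.embed_tokens.weight'],
                   reloaded['model.embed_tokens.weight'])
```

The `metadata={"format": "pt"}` tag matches what Hugging Face tooling writes for PyTorch-layout safetensors shards, presumably so downstream loaders recognize the checkpoint format.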