Skip to content

Commit

Permalink
[Tool]: Fix the issue of safetensors conversion LLama error (InternLM…
Browse files Browse the repository at this point in the history
  • Loading branch information
yuanphoenix authored Apr 11, 2024
1 parent c4108d3 commit 2db5604
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions tools/convert2llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def convert(src, tgt):

head_dim = config.hidden_size // config.num_attention_heads
num_key_value_groups = config.num_attention_heads \
// config.num_key_value_heads
// config.num_key_value_heads

# load index json file
index_file = 'pytorch_model.bin.index.json'
Expand Down Expand Up @@ -138,7 +138,11 @@ def convert(src, tgt):
index_dict['weight_map'][k] = filename

print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
torch.save(llama_states, os.path.join(tgt, filename))
if filename.endswith('.safetensors'):
from safetensors.torch import save_file
save_file(llama_states, os.path.join(tgt, filename), metadata={"format": "pt"})
else:
torch.save(llama_states, os.path.join(tgt, filename))
del states

print('Saving config and tokenizer...', flush=True)
Expand Down

0 comments on commit 2db5604

Please sign in to comment.