Skip to content

Commit

Permalink
[Tool]: Update tools/convert2llama.py to support safetensors format (
Browse files Browse the repository at this point in the history
  • Loading branch information
gaoyang07 authored Apr 10, 2024
1 parent 861327b commit c4108d3
Showing 1 changed file with 44 additions and 11 deletions.
55 changes: 44 additions & 11 deletions tools/convert2llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,28 @@
from transformers import AutoConfig, LlamaConfig, LlamaTokenizer


def weight_load(fp, **kwargs):
    """Load a model state dict from a checkpoint file.

    Args:
        fp: Path to the checkpoint file (``.bin`` or ``.safetensors``).
        **kwargs: Extra options. ``is_safetensors`` (bool, default False)
            selects the safetensors loader; all remaining kwargs are
            forwarded to ``torch.load``.

    Returns:
        dict: Mapping of parameter names to tensors.

    Raises:
        ImportError: If ``is_safetensors`` is True but the ``safetensors``
            package is not installed.
    """
    is_safetensors = kwargs.pop('is_safetensors', False)

    # Guard clause: the common `.bin` path needs no optional dependency.
    if not is_safetensors:
        return torch.load(fp, **kwargs)

    try:
        from safetensors import safe_open
    except ImportError as e:
        # Chain the original error so the missing-dependency cause stays
        # visible in the traceback (the original raise dropped it).
        raise ImportError(
            'Before loading ckpts in the `safetensors` format, '
            'please install the `safetensors` package first.') from e

    # Use a context manager so the underlying file handle is closed
    # promptly (the original left it open until garbage collection).
    with safe_open(fp, framework='pt') as model:
        return {k: model.get_tensor(k) for k in model.keys()}


def save_conifg(config, tgt):
config_dict = config.to_dict()
unnecessary_keys = [
Expand Down Expand Up @@ -41,19 +63,29 @@ def convert(src, tgt):
// config.num_key_value_heads

# load index json file
index_file = os.path.join(src, 'pytorch_model.bin.index.json')
if os.path.exists(index_file):
with open(index_file) as fp:
index_file = 'pytorch_model.bin.index.json'
if os.path.exists(os.path.join(src, index_file)):
with open(os.path.join(src, index_file)) as fp:
index_dict = json.load(fp)
index_dict['weight_map'] = {}
else:
index_dict = None
index_file = 'model.safetensors.index.json'
if os.path.exists(os.path.join(src, index_file)):
with open(os.path.join(src, index_file)) as fp:
index_dict = json.load(fp)
index_dict['weight_map'] = {}
else:
index_dict = None

os.makedirs(tgt, exist_ok=True)
for filename in tqdm(os.listdir(src)):
if not filename.endswith('.bin'):
if not any(filename.endswith(ext) for ext in ('.bin', '.safetensors')):
continue
states = torch.load(os.path.join(src, filename))

print(f'Loading {os.path.join(src, filename)}...', flush=True)
states = weight_load(os.path.join(src, filename),
is_safetensors=filename.endswith('.safetensors'))

llama_states = {}
for k, v in states.copy().items():
if 'wqkv' in k:
Expand Down Expand Up @@ -104,23 +136,24 @@ def convert(src, tgt):
if index_dict is not None:
for k in llama_states:
index_dict['weight_map'][k] = filename
print(f"Saving to {os.path.join(tgt, filename)}...", flush=True)

print(f'Saving to {os.path.join(tgt, filename)}...', flush=True)
torch.save(llama_states, os.path.join(tgt, filename))
del states

print('Saving config and tokenizer...')
print('Saving config and tokenizer...', flush=True)
# index.json
if index_dict is not None:
with open(os.path.join(tgt, 'pytorch_model.bin.index.json'),
'w') as fp:
with open(os.path.join(tgt, index_file), 'w') as fp:
json.dump(index_dict, fp, indent=2)
# tokenizer
tokenizer = LlamaTokenizer.from_pretrained(src)
tokenizer.init_kwargs.pop('auto_map', None)
tokenizer.save_pretrained(tgt)
# config
save_conifg(config, tgt)
print('Done!')

print('Done!', flush=True)


def parse_args():
Expand Down

0 comments on commit c4108d3

Please sign in to comment.