From 98069aeac65553f7feb6b7302377c84834a387ee Mon Sep 17 00:00:00 2001
From: Nicolas Patry
Date: Fri, 17 Nov 2023 17:27:48 +0000
Subject: [PATCH] Better convert.

---
 bindings/python/convert.py | 241 ++++++++++++++++++-------------------
 1 file changed, 115 insertions(+), 126 deletions(-)

diff --git a/bindings/python/convert.py b/bindings/python/convert.py
index 9ad459e0..a700382d 100644
--- a/bindings/python/convert.py
+++ b/bindings/python/convert.py
@@ -3,7 +3,6 @@
 import os
 import shutil
 from collections import defaultdict
-from inspect import signature
 from tempfile import TemporaryDirectory
 from typing import Dict, List, Optional, Set, Tuple

@@ -11,8 +10,7 @@
 from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
 from huggingface_hub.file_download import repo_folder_name
-from safetensors.torch import load_file, save_file
-from transformers import AutoConfig
+from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete


 COMMIT_DESCRIPTION = """
@@ -34,20 +32,78 @@
 ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]


+def _remove_duplicate_names(
+    state_dict: Dict[str, torch.Tensor],
+    *,
+    preferred_names: Optional[List[str]] = None,
+    discard_names: Optional[List[str]] = None,
+) -> Dict[str, List[str]]:
+    if preferred_names is None:
+        preferred_names = []
+    preferred_names = set(preferred_names)
+    if discard_names is None:
+        discard_names = []
+    discard_names = set(discard_names)
+
+    shareds = _find_shared_tensors(state_dict)
+    to_remove = defaultdict(list)
+    for shared in shareds:
+        complete_names = set(
+            name for name in shared if _is_complete(state_dict[name])
+        )
+        if not complete_names:
+            if len(shared) == 1:
+                # Force contiguous
+                name = list(shared)[0]
+                state_dict[name] = state_dict[name].clone()
+                complete_names = {name}
+            else:
+                raise RuntimeError(
+                    f"Error while trying to find names to remove when saving the state dict: found no suitable name to keep amongst {shared}. None of them covers the entire storage. Refusing to save/load the model since you could be storing much more memory than needed. Please refer to https://huggingface.co/docs/safetensors/torch_shared_tensors for more information, or open an issue."
+                )

-class AlreadyExists(Exception):
-    pass
+        keep_name = sorted(complete_names)[0]
+
+        # Mechanism to preferentially select keys to keep
+        # coming from the on-disk file, to allow
+        # loading models saved with a different choice
+        # of keep_name
+        preferred = complete_names.difference(discard_names)
+        if preferred:
+            keep_name = sorted(preferred)[0]
+
+        if preferred_names:
+            preferred = preferred_names.intersection(complete_names)
+            if preferred:
+                keep_name = sorted(preferred)[0]
+        for name in sorted(shared):
+            if name != keep_name:
+                to_remove[keep_name].append(name)
+    return to_remove
+
+def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
+    try:
+        import transformers
+        import json
+
+        config_filename = hf_hub_download(
+            model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
+        )
+        with open(config_filename, "r") as f:
+            config = json.load(f)
+        architecture = config["architectures"][0]

+        class_ = getattr(transformers, architecture)

-def shared_pointers(tensors):
-    ptrs = defaultdict(list)
-    for k, v in tensors.items():
-        ptrs[v.data_ptr()].append(k)
-    failing = []
-    for ptr, names in ptrs.items():
-        if len(names) > 1:
-            failing.append(names)
-    return failing
+        # The name of this attribute depends on the transformers version.
+        discard_names = getattr(class_, "_tied_weights_keys", [])
+
+    except Exception:
+        discard_names = []
+    return discard_names
+
+class AlreadyExists(Exception):
+    pass


 def check_file_size(sf_filename: str, pt_filename: str):
@@ -70,8 +126,8 @@ def rename(pt_filename: str) -> str:
     return local


-def convert_multi(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
-    filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
+def convert_multi(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
+    filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
     with open(filename, "r") as f:
         data = json.load(f)

@@ -82,7 +138,7 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> Conversio
         sf_filename = rename(pt_filename)
         sf_filename = os.path.join(folder, sf_filename)
-        convert_file(pt_filename, sf_filename)
+        convert_file(pt_filename, sf_filename, discard_names=discard_names)
         local_filenames.append(sf_filename)

     index = os.path.join(folder, "model.safetensors.index.json")
@@ -101,12 +157,12 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> Conversio
     return operations, errors


-def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
-    pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)
+def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
+    pt_filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin", token=token, cache_dir=folder)

     sf_name = "model.safetensors"
     sf_filename = os.path.join(folder, sf_name)
-    convert_file(pt_filename, sf_filename)
+    convert_file(pt_filename, sf_filename, discard_names=discard_names)
     operations = [CommitOperationAdd(path_in_repo=sf_name, path_or_fileobj=sf_filename)]
     errors: List[Tuple[str, "Exception"]] = []
     return operations, errors
@@ -115,21 +171,25 @@ def convert_single(model_id: str, folder: str, token: Optional[str]) -> Conversi
 def convert_file(
     pt_filename: str,
     sf_filename: str,
+    discard_names: List[str],
 ):
     loaded = torch.load(pt_filename, map_location="cpu")
     if "state_dict" in loaded:
         loaded = loaded["state_dict"]
-    shared = shared_pointers(loaded)
-    for shared_weights in shared:
-        for name in shared_weights[1:]:
-            loaded.pop(name)
-
-    # For tensors to be contiguous
+    to_removes = _remove_duplicate_names(loaded, discard_names=discard_names)
+
+    metadata = {"format": "pt"}
+    for kept_name, to_remove_group in to_removes.items():
+        for to_remove in to_remove_group:
+            if to_remove not in metadata:
+                metadata[to_remove] = kept_name
+            del loaded[to_remove]
+    # Force tensors to be contiguous
     loaded = {k: v.contiguous() for k, v in loaded.items()}

     dirname = os.path.dirname(sf_filename)
     os.makedirs(dirname, exist_ok=True)
-    save_file(loaded, sf_filename, metadata={"format": "pt"})
+    save_file(loaded, sf_filename, metadata=metadata)
     check_file_size(sf_filename, pt_filename)
     reloaded = load_file(sf_filename)
     for k in loaded:
@@ -155,79 +215,10 @@ def create_diff(pt_infos: Dict[str, List[str]], sf_infos: Dict[str, List[str]])
     return "\n".join(errors)


-def check_final_model(model_id: str, folder: str, token: Optional[str]):
-    config = hf_hub_download(repo_id=model_id, filename="config.json", token=token, cache_dir=folder)
-    shutil.copy(config, os.path.join(folder, "config.json"))
-    config = AutoConfig.from_pretrained(folder)
-
-    import transformers
-
-    class_ = getattr(transformers, config.architectures[0])
-    (pt_model, pt_infos) = class_.from_pretrained(folder, output_loading_info=True)
-    (sf_model, sf_infos) = class_.from_pretrained(folder, output_loading_info=True)
-
-    if pt_infos != sf_infos:
-        error_string = create_diff(pt_infos, sf_infos)
-        raise ValueError(f"Different infos when reloading the model: {error_string}")
-
-    pt_params = pt_model.state_dict()
-    sf_params = sf_model.state_dict()
-
-    pt_shared = shared_pointers(pt_params)
-    sf_shared = shared_pointers(sf_params)
-    if pt_shared != sf_shared:
-        raise RuntimeError("The reconstructed model is wrong, shared tensors are different {shared_pt} != {shared_tf}")
-
-    sig = signature(pt_model.forward)
-    input_ids = torch.arange(10).unsqueeze(0)
-    pixel_values = torch.randn(1, 3, 224, 224)
-    input_values = torch.arange(1000).float().unsqueeze(0)
-    # Hardcoded for whisper basically
-    input_features = torch.zeros((1, 80, 3000))
-    kwargs = {}
-    if "input_ids" in sig.parameters:
-        kwargs["input_ids"] = input_ids
-    if "input_features" in sig.parameters:
-        kwargs["input_features"] = input_features
-    if "decoder_input_ids" in sig.parameters:
-        kwargs["decoder_input_ids"] = input_ids
-    if "pixel_values" in sig.parameters:
-        kwargs["pixel_values"] = pixel_values
-    if "input_values" in sig.parameters:
-        kwargs["input_values"] = input_values
-    if "bbox" in sig.parameters:
-        kwargs["bbox"] = torch.zeros((1, 10, 4)).long()
-    if "image" in sig.parameters:
-        kwargs["image"] = pixel_values
-
-    if torch.cuda.is_available():
-        pt_model = pt_model.cuda()
-        sf_model = sf_model.cuda()
-        kwargs = {k: v.cuda() for k, v in kwargs.items()}
-
-    try:
-        pt_logits = pt_model(**kwargs)[0]
-    except Exception as e:
-        try:
-            # Musicgen special exception.
-            decoder_input_ids = torch.ones((input_ids.shape[0] * pt_model.decoder.num_codebooks, 1), dtype=torch.long)
-            if torch.cuda.is_available():
-                decoder_input_ids = decoder_input_ids.cuda()
-
-            kwargs["decoder_input_ids"] = decoder_input_ids
-            pt_logits = pt_model(**kwargs)[0]
-        except Exception:
-            raise e
-    sf_logits = sf_model(**kwargs)[0]
-
-    torch.testing.assert_close(sf_logits, pt_logits)
-    print(f"Model {model_id} is ok !")
-
-
-def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discussion"]:
+def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision: Optional[str] = None) -> Optional["Discussion"]:
     try:
-        main_commit = api.list_repo_commits(model_id)[0].commit_id
+        main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
         discussions = api.get_repo_discussions(repo_id=model_id)
     except Exception:
         return None
     for discussion in discussions:
@@ -239,7 +230,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str) -> Optional["Discuss
     return None


-def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
+def convert_generic(model_id: str, *, revision: Optional[str], folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
     operations = []
     errors = []

@@ -247,7 +238,7 @@ def convert_generic(model_id: str, folder: str, filenames: Set[str], token: Opti
     for filename in filenames:
         prefix, ext = os.path.splitext(filename)
         if ext in extensions:
-            pt_filename = hf_hub_download(model_id, filename=filename, token=token, cache_dir=folder)
+            pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder)
             dirname, raw_filename = os.path.split(filename)
             if raw_filename == "pytorch_model.bin":
                 # XXX: This is a special case to handle `transformers` and the
                 sf_in_repo = f"{prefix}.safetensors"
             sf_filename = os.path.join(folder, sf_in_repo)
             try:
-                convert_file(pt_filename, sf_filename)
+                convert_file(pt_filename, sf_filename, discard_names=[])
                 operations.append(CommitOperationAdd(path_in_repo=sf_in_repo, path_or_fileobj=sf_filename))
             except Exception as e:
                 errors.append((pt_filename, e))
     return operations, errors


-def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
+def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
     pr_title = "Adding `safetensors` variant of this model"
-    info = api.model_info(model_id)
+    info = api.model_info(model_id, revision=revision)
     filenames = set(s.rfilename for s in info.siblings)

-    with TemporaryDirectory() as d:
+    with TemporaryDirectory(dir=os.getenv("HF_HOME")) as d:
         folder = os.path.join(d, repo_folder_name(repo_id=model_id, repo_type="models"))
         os.makedirs(folder)
         new_pr = None
         try:
             operations = None
-            pr = previous_pr(api, model_id, pr_title)
+            pr = previous_pr(api, model_id, pr_title, revision=revision)

             library_name = getattr(info, "library_name", None)
             if any(filename.endswith(".safetensors") for filename in filenames) and not force:
@@ -285,19 +276,21 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn
                 new_pr = pr
                 raise AlreadyExists(f"Model {model_id} already has an open PR check out {url}")
             elif library_name == "transformers":
"transformers": + + discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token) if "pytorch_model.bin" in filenames: - operations, errors = convert_single(model_id, folder, token=api.token) + operations, errors = convert_single(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names) elif "pytorch_model.bin.index.json" in filenames: - operations, errors = convert_multi(model_id, folder, token=api.token) + operations, errors = convert_multi(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names) else: raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert") - check_final_model(model_id, folder, token=api.token) else: - operations, errors = convert_generic(model_id, folder, filenames, token=api.token) + operations, errors = convert_generic(model_id, revision=revision, folder=folder, filenames=filenames, token=api.token) if operations: new_pr = api.create_commit( repo_id=model_id, + revision=revision, operations=operations, commit_message=pr_title, commit_description=COMMIT_DESCRIPTION, @@ -324,6 +317,11 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn type=str, help="The name of the model on the hub to convert. E.g. `gpt2` or `facebook/wav2vec2-base-960h`", ) + parser.add_argument( + "--revision", + type=str, + help="The revision to convert", + ) parser.add_argument( "--force", action="store_true", @@ -346,26 +344,17 @@ def convert(api: "HfApi", model_id: str, force: bool = False) -> Tuple["CommitIn " Continue [Y/n] ?" ) if txt.lower() in {"", "y"}: - try: - commit_info, errors = convert(api, model_id, force=args.force) - string = f""" + commit_info, errors = convert(api, model_id, revision=args.revision, force=args.force) + string = f""" ### Success 🔥 Yay! This model was successfully converted and a PR was open using your token, here: [{commit_info.pr_url}]({commit_info.pr_url}) - """ - if errors: - string += "\nErrors during conversion:\n" - string += "\n".join( - f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors - ) - print(string) - except Exception as e: - print( - f""" -### Error 😢😢😢 - -{e} - """ + """ + if errors: + string += "\nErrors during conversion:\n" + string += "\n".join( + f"Error while converting {filename}: {e}, skipped conversion" for filename, e in errors ) + print(string) else: print(f"Answer was `{txt}` aborting.")