From ab00b106b631c9a44143c08b621b0c5f9f5ccc27 Mon Sep 17 00:00:00 2001 From: Nicolas Patry Date: Fri, 17 Nov 2023 19:03:11 +0100 Subject: [PATCH] Ignore closed PRs to avoid spam. --- bindings/python/convert.py | 52 ++++++++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/bindings/python/convert.py b/bindings/python/convert.py index a700382d..a61476dd 100644 --- a/bindings/python/convert.py +++ b/bindings/python/convert.py @@ -10,7 +10,7 @@ from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download from huggingface_hub.file_download import repo_folder_name -from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete +from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file COMMIT_DESCRIPTION = """ @@ -32,6 +32,7 @@ ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]] + def _remove_duplicate_names( state_dict: Dict[str, torch.Tensor], *, @@ -48,9 +49,7 @@ def _remove_duplicate_names( shareds = _find_shared_tensors(state_dict) to_remove = defaultdict(list) for shared in shareds: - complete_names = set( - [name for name in shared if _is_complete(state_dict[name])] - ) + complete_names = set([name for name in shared if _is_complete(state_dict[name])]) if not complete_names: if len(shared) == 1: # Force contiguous @@ -81,11 +80,13 @@ def _remove_duplicate_names( to_remove[keep_name].append(name) return to_remove + def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]: try: - import transformers import json + import transformers + config_filename = hf_hub_download( model_id, revision=revision, filename="config.json", token=token, cache_dir=folder ) @@ -98,10 +99,11 @@ def get_discard_names(model_id: str, revision: Optional[str], folder: str, token # Name for this varible depends on transformers version. discard_names = getattr(class_, "_tied_weights_keys", []) - except Exception as e: + except Exception: discard_names = [] return discard_names + class AlreadyExists(Exception): pass @@ -126,8 +128,12 @@ def rename(pt_filename: str) -> str: return local -def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult: - filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder) +def convert_multi( + model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str] +) -> ConversionResult: + filename = hf_hub_download( + repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder + ) with open(filename, "r") as f: data = json.load(f) @@ -157,7 +163,9 @@ def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token: return operations, errors -def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult: +def convert_single( + model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str] +) -> ConversionResult: pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder) sf_name = "model.safetensors" @@ -222,7 +230,7 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[st except Exception: return None for discussion in discussions: - if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title: + if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title: commits = api.list_repo_commits(model_id, revision=discussion.git_reference) if main_commit == commits[1].commit_id: @@ -230,7 +238,9 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[st return None -def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult: +def convert_generic( + model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str] +) -> ConversionResult: operations = [] errors = [] @@ -238,7 +248,9 @@ def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filen for filename in filenames: prefix, ext = os.path.splitext(filename) if ext in extensions: - pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder) + pt_filename = hf_hub_download( + model_id, revision=revision, filename=filename, token=token, cache_dir=folder + ) dirname, raw_filename = os.path.split(filename) if raw_filename == "pytorch_model.bin": # XXX: This is a special case to handle `transformers` and the @@ -255,7 +267,9 @@ def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filen return operations, errors -def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]: +def convert( + api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False +) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]: pr_title = "Adding `safetensors` variant of this model" info = api.model_info(model_id, revision=revision) filenames = set(s.rfilename for s in info.siblings) @@ -279,13 +293,19 @@ def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force: discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token) if "pytorch_model.bin" in filenames: - operations, errors = convert_single(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names) + operations, errors = convert_single( + model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names + ) elif "pytorch_model.bin.index.json" in filenames: - operations, errors = convert_multi(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names) + operations, errors = convert_multi( + model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names + ) else: raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert") else: - operations, errors = convert_generic(model_id, revision=revision, folder=folder, filenames=filenames, token=api.token) + operations, errors = convert_generic( + model_id, revision=revision, folder=folder, filenames=filenames, token=api.token + ) if operations: new_pr = api.create_commit(