Skip to content

Commit

Permalink
Ignore closed PRs to avoid spam. (#385)
Browse files Browse the repository at this point in the history
  • Loading branch information
Narsil authored Nov 17, 2023
1 parent 1799438 commit 829bfa8
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions bindings/python/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file


COMMIT_DESCRIPTION = """
Expand All @@ -32,6 +32,7 @@

ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]


def _remove_duplicate_names(
state_dict: Dict[str, torch.Tensor],
*,
Expand All @@ -48,9 +49,7 @@ def _remove_duplicate_names(
shareds = _find_shared_tensors(state_dict)
to_remove = defaultdict(list)
for shared in shareds:
complete_names = set(
[name for name in shared if _is_complete(state_dict[name])]
)
complete_names = set([name for name in shared if _is_complete(state_dict[name])])
if not complete_names:
if len(shared) == 1:
# Force contiguous
Expand Down Expand Up @@ -81,11 +80,13 @@ def _remove_duplicate_names(
to_remove[keep_name].append(name)
return to_remove


def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
try:
import transformers
import json

import transformers

config_filename = hf_hub_download(
model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
)
Expand All @@ -98,10 +99,11 @@ def get_discard_names(model_id: str, revision: Optional[str], folder: str, token
# Name for this varible depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])

except Exception as e:
except Exception:
discard_names = []
return discard_names


class AlreadyExists(Exception):
pass

Expand All @@ -126,8 +128,12 @@ def rename(pt_filename: str) -> str:
return local


def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
def convert_multi(
model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]
) -> ConversionResult:
filename = hf_hub_download(
repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder
)
with open(filename, "r") as f:
data = json.load(f)

Expand Down Expand Up @@ -157,7 +163,9 @@ def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token:
return operations, errors


def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
def convert_single(
model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]
) -> ConversionResult:
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)

sf_name = "model.safetensors"
Expand Down Expand Up @@ -222,23 +230,27 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[st
except Exception:
return None
for discussion in discussions:
if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title:
commits = api.list_repo_commits(model_id, revision=discussion.git_reference)

if main_commit == commits[1].commit_id:
return discussion
return None


def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
def convert_generic(
model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]
) -> ConversionResult:
operations = []
errors = []

extensions = set([".bin", ".ckpt"])
for filename in filenames:
prefix, ext = os.path.splitext(filename)
if ext in extensions:
pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder)
pt_filename = hf_hub_download(
model_id, revision=revision, filename=filename, token=token, cache_dir=folder
)
dirname, raw_filename = os.path.split(filename)
if raw_filename == "pytorch_model.bin":
# XXX: This is a special case to handle `transformers` and the
Expand All @@ -255,7 +267,9 @@ def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filen
return operations, errors


def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
def convert(
api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False
) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
pr_title = "Adding `safetensors` variant of this model"
info = api.model_info(model_id, revision=revision)
filenames = set(s.rfilename for s in info.siblings)
Expand All @@ -279,13 +293,19 @@ def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force:

discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
if "pytorch_model.bin" in filenames:
operations, errors = convert_single(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names)
operations, errors = convert_single(
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
)
elif "pytorch_model.bin.index.json" in filenames:
operations, errors = convert_multi(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names)
operations, errors = convert_multi(
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
)
else:
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
else:
operations, errors = convert_generic(model_id, revision=revision, folder=folder, filenames=filenames, token=api.token)
operations, errors = convert_generic(
model_id, revision=revision, folder=folder, filenames=filenames, token=api.token
)

if operations:
new_pr = api.create_commit(
Expand Down

0 comments on commit 829bfa8

Please sign in to comment.