Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ignore closed PRs to avoid spam. #385

Merged
merged 1 commit into from
Nov 17, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 36 additions & 16 deletions bindings/python/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@

from huggingface_hub import CommitInfo, CommitOperationAdd, Discussion, HfApi, hf_hub_download
from huggingface_hub.file_download import repo_folder_name
from safetensors.torch import save_file, load_file, _find_shared_tensors, _is_complete
from safetensors.torch import _find_shared_tensors, _is_complete, load_file, save_file


COMMIT_DESCRIPTION = """
Expand All @@ -32,6 +32,7 @@

ConversionResult = Tuple[List["CommitOperationAdd"], List[Tuple[str, "Exception"]]]


def _remove_duplicate_names(
state_dict: Dict[str, torch.Tensor],
*,
Expand All @@ -48,9 +49,7 @@ def _remove_duplicate_names(
shareds = _find_shared_tensors(state_dict)
to_remove = defaultdict(list)
for shared in shareds:
complete_names = set(
[name for name in shared if _is_complete(state_dict[name])]
)
complete_names = set([name for name in shared if _is_complete(state_dict[name])])
if not complete_names:
if len(shared) == 1:
# Force contiguous
Expand Down Expand Up @@ -81,11 +80,13 @@ def _remove_duplicate_names(
to_remove[keep_name].append(name)
return to_remove


def get_discard_names(model_id: str, revision: Optional[str], folder: str, token: Optional[str]) -> List[str]:
try:
import transformers
import json

import transformers

config_filename = hf_hub_download(
model_id, revision=revision, filename="config.json", token=token, cache_dir=folder
)
Expand All @@ -98,10 +99,11 @@ def get_discard_names(model_id: str, revision: Optional[str], folder: str, token
# Name for this varible depends on transformers version.
discard_names = getattr(class_, "_tied_weights_keys", [])

except Exception as e:
except Exception:
discard_names = []
return discard_names


class AlreadyExists(Exception):
pass

Expand All @@ -126,8 +128,12 @@ def rename(pt_filename: str) -> str:
return local


def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
filename = hf_hub_download(repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder)
def convert_multi(
model_id: str, *, revision=Optional[str], folder: str, token: Optional[str], discard_names: List[str]
) -> ConversionResult:
filename = hf_hub_download(
repo_id=model_id, revision=revision, filename="pytorch_model.bin.index.json", token=token, cache_dir=folder
)
with open(filename, "r") as f:
data = json.load(f)

Expand Down Expand Up @@ -157,7 +163,9 @@ def convert_multi(model_id: str, *, revision=Optional[str], folder: str, token:
return operations, errors


def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
def convert_single(
model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]
) -> ConversionResult:
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)

sf_name = "model.safetensors"
Expand Down Expand Up @@ -222,23 +230,27 @@ def previous_pr(api: "HfApi", model_id: str, pr_title: str, revision=Optional[st
except Exception:
return None
for discussion in discussions:
if discussion.status == "open" and discussion.is_pull_request and discussion.title == pr_title:
if discussion.status in {"open", "closed"} and discussion.is_pull_request and discussion.title == pr_title:
commits = api.list_repo_commits(model_id, revision=discussion.git_reference)

if main_commit == commits[1].commit_id:
return discussion
return None


def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]) -> ConversionResult:
def convert_generic(
model_id: str, *, revision=Optional[str], folder: str, filenames: Set[str], token: Optional[str]
) -> ConversionResult:
operations = []
errors = []

extensions = set([".bin", ".ckpt"])
for filename in filenames:
prefix, ext = os.path.splitext(filename)
if ext in extensions:
pt_filename = hf_hub_download(model_id, revision=revision, filename=filename, token=token, cache_dir=folder)
pt_filename = hf_hub_download(
model_id, revision=revision, filename=filename, token=token, cache_dir=folder
)
dirname, raw_filename = os.path.split(filename)
if raw_filename == "pytorch_model.bin":
# XXX: This is a special case to handle `transformers` and the
Expand All @@ -255,7 +267,9 @@ def convert_generic(model_id: str, *, revision=Optional[str], folder: str, filen
return operations, errors


def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
def convert(
api: "HfApi", model_id: str, revision: Optional[str] = None, force: bool = False
) -> Tuple["CommitInfo", List[Tuple[str, "Exception"]]]:
pr_title = "Adding `safetensors` variant of this model"
info = api.model_info(model_id, revision=revision)
filenames = set(s.rfilename for s in info.siblings)
Expand All @@ -279,13 +293,19 @@ def convert(api: "HfApi", model_id: str, revision: Optional[str] = None, force:

discard_names = get_discard_names(model_id, revision=revision, folder=folder, token=api.token)
if "pytorch_model.bin" in filenames:
operations, errors = convert_single(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names)
operations, errors = convert_single(
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
)
elif "pytorch_model.bin.index.json" in filenames:
operations, errors = convert_multi(model_id, revision=revision, folder=folder, token=api.token, discard_names = discard_names)
operations, errors = convert_multi(
model_id, revision=revision, folder=folder, token=api.token, discard_names=discard_names
)
else:
raise RuntimeError(f"Model {model_id} doesn't seem to be a valid pytorch model. Cannot convert")
else:
operations, errors = convert_generic(model_id, revision=revision, folder=folder, filenames=filenames, token=api.token)
operations, errors = convert_generic(
model_id, revision=revision, folder=folder, filenames=filenames, token=api.token
)

if operations:
new_pr = api.create_commit(
Expand Down
Loading