Better convert. #384
Hey @Narsil, I just reviewed the PR. I mostly focused on the huggingface_hub integration and less on the shared_tensors/discard_names logic. The current version won't work on a revision other than main, but that should be very quick to fix (see comments).
-        main_commit = api.list_repo_commits(model_id)[0].commit_id
-        discussions = api.get_repo_discussions(repo_id=model_id)
+        main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
+        discussions = api.get_repo_discussions(repo_id=model_id, revision=revision)
`get_repo_discussions` lists all discussions/PRs in the Community Tab. It doesn't have a `revision` parameter, so `previous_pr` will always fail (return None).
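For reference, a minimal sketch of how the lookup could work without a revision filter, assuming the existing title/status matching is kept: list every discussion, keep the open PRs with the expected title, and compare each PR's parent commit against the commit being converted. `find_previous_pr` and its `base_commit` argument are hypothetical names, and `api` is only assumed to behave like `huggingface_hub.HfApi`.

```python
from typing import Any, Optional


def find_previous_pr(api: Any, model_id: str, pr_title: str, base_commit: str) -> Optional[Any]:
    """Return an open PR titled `pr_title` that sits directly on top of `base_commit`.

    Hypothetical helper: `api` is assumed to expose `get_repo_discussions` and
    `list_repo_commits` the way `huggingface_hub.HfApi` does.
    """
    try:
        # No `revision` keyword here: the call lists every discussion/PR of the repo.
        discussions = api.get_repo_discussions(repo_id=model_id)
    except Exception:
        return None
    for discussion in discussions:
        if not (discussion.is_pull_request and discussion.status == "open"):
            continue
        if discussion.title != pr_title:
            continue
        # Assumes commits come newest first (as the code under review does) and that
        # the PR holds a single commit, so index 1 is the commit it was opened on.
        commits = api.list_repo_commits(model_id, revision=discussion.git_reference)
        if len(commits) > 1 and commits[1].commit_id == base_commit:
            return discussion
    return None
```

The revision only matters when computing `base_commit`; the discussion listing itself stays repo-wide.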
Oops :(
@@ -101,12 +157,12 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> Conversio
     return operations, errors


-def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
+def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
     pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)
Use the `revision` parameter to download `pt_filename`?
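A minimal sketch of what that could look like, assuming `revision` is threaded through from the new keyword-only argument. `download_checkpoint` and its injectable `download` parameter are hypothetical (the latter only makes the sketch testable offline); `hf_hub_download` itself does accept a `revision` keyword.

```python
from typing import Callable, Optional


def download_checkpoint(
    model_id: str,
    *,
    revision: Optional[str],
    folder: str,
    token: Optional[str],
    download: Optional[Callable[..., str]] = None,
) -> str:
    """Download pytorch_model.bin from the requested revision (hypothetical helper)."""
    if download is None:
        # Imported lazily so the sketch can be exercised without huggingface_hub.
        from huggingface_hub import hf_hub_download

        download = hf_hub_download
    return download(
        repo_id=model_id,
        filename="pytorch_model.bin",
        revision=revision,  # None falls back to the default revision (main)
        token=token,
        cache_dir=folder,
    )
```

Without the `revision` argument the download silently comes from `main`, so the converted weights could differ from the revision the PR is opened against.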
Forgot! Thanks.
     filenames = set(s.rfilename for s in info.siblings)

-    with TemporaryDirectory() as d:
+    with TemporaryDirectory(prefix=os.getenv("HF_HOME", "") + "/") as d:
(nit) You can use `huggingface_hub.constants.HF_HOME` to retrieve the user's HF home (it checks the `HF_HOME` or `XDG_CACHE_HOME` env variable and defaults to `~/.cache/huggingface`).
(nit 2) I would set a default directory at `HF_HOME + "/safetensors_converter"` just in case the converter crashes without cleaning up the folder afterwards (at least all tmp directories will be in the same place).
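Both nits could be combined roughly as below. This is a stdlib-only sketch: `resolve_hf_home` is a hypothetical stand-in that mirrors the lookup order described above (in the converter itself, `huggingface_hub.constants.HF_HOME` would be used instead), and `converter_tmpdir` is likewise an invented name. It passes the base directory through `dir=` rather than `prefix=`.

```python
import os
from tempfile import TemporaryDirectory


def resolve_hf_home() -> str:
    """Mirror the described lookup: HF_HOME, else XDG_CACHE_HOME/huggingface, else ~/.cache/huggingface."""
    default_cache = os.path.join(os.path.expanduser("~"), ".cache")
    xdg_cache = os.getenv("XDG_CACHE_HOME", default_cache)
    return os.path.expanduser(os.getenv("HF_HOME", os.path.join(xdg_cache, "huggingface")))


def converter_tmpdir() -> TemporaryDirectory:
    """Create a temporary directory under <HF_HOME>/safetensors_converter."""
    base = os.path.join(resolve_hf_home(), "safetensors_converter")
    os.makedirs(base, exist_ok=True)  # TemporaryDirectory needs an existing parent
    return TemporaryDirectory(dir=base)
```

Used as `with converter_tmpdir() as d: ...`, the directory is still removed on a clean exit; after a crash, leftovers all sit under one known folder.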
Indeed. I actually removed that part to keep everything in a real temporary folder (`/tmp`). It prevents cache reuse, but should let the OS clean up correctly.
I had set this up for testing, to use a real SSD on the machine I was using (`/tmp` was mounted on a slower disk).
     try:
-        main_commit = api.list_repo_commits(model_id)[0].commit_id
-        discussions = api.get_repo_discussions(repo_id=model_id)
+        main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
No need to list the whole commit history to get the latest commit. This information is available as `model_info(...).sha`.
(nit) I would also rename the variable to something like `revision_commit` instead of `main_commit` (since we are pulling from `revision` and not `main`).
-        main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
+        revision_commit = api.model_info(model_id, revision=revision).sha
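As a self-contained illustration of the suggestion, assuming `api` behaves like `huggingface_hub.HfApi`: `model_info(...)` returns the repo metadata for the given revision, and its `sha` field is the commit that revision currently points to. `get_revision_commit` is a hypothetical helper name, and the `None` check is a defensive assumption rather than something the PR requires.

```python
from typing import Any, Optional


def get_revision_commit(api: Any, model_id: str, revision: Optional[str] = None) -> str:
    """Return the commit sha `revision` points to, using a single metadata call."""
    sha = api.model_info(model_id, revision=revision).sha
    if sha is None:
        # `sha` is optional in the returned metadata, so fail loudly here
        # rather than compare against None later.
        raise ValueError(f"No commit sha returned for {model_id}@{revision}")
    return sha
```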
Agreed! (It made sense before; I only added the revision.)
What does this PR do?
Fixes # (issue) or description of the problem this PR solves.