Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better convert. #384

Merged
merged 1 commit into from
Nov 17, 2023
Merged

Better convert. #384

merged 1 commit into from
Nov 17, 2023

Conversation

Narsil
Copy link
Collaborator

@Narsil Narsil commented Nov 17, 2023

What does this PR do?

Fixes # (issue) or description of the problem this PR solves.

@Narsil Narsil merged commit 1799438 into main Nov 17, 2023
9 of 10 checks passed
@Narsil Narsil deleted the better_convert branch November 17, 2023 17:28
Copy link
Contributor

@Wauplin Wauplin left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hey @Narsil, just reviewed the PR. I mostly focused on the huggingface_hub integration and less on the shared_tensors/discard_names logic. The current version won't work on a different revision but will be very quick to fix (see comments).

main_commit = api.list_repo_commits(model_id)[0].commit_id
discussions = api.get_repo_discussions(repo_id=model_id)
main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
discussions = api.get_repo_discussions(repo_id=model_id, revision=revision)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

get_repo_discussions lists all discussions/PRs in the Community Tab. It doesn't have a revision parameter => previous_pr will always fail (return None)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops :(

@@ -101,12 +157,12 @@ def convert_multi(model_id: str, folder: str, token: Optional[str]) -> Conversio
return operations, errors


def convert_single(model_id: str, folder: str, token: Optional[str]) -> ConversionResult:
def convert_single(model_id: str, *, revision: Optional[str], folder: str, token: Optional[str], discard_names: List[str]) -> ConversionResult:
pt_filename = hf_hub_download(repo_id=model_id, filename="pytorch_model.bin", token=token, cache_dir=folder)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Use revision parameter to download pt_filename?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Forgot ! Thanks

filenames = set(s.rfilename for s in info.siblings)

with TemporaryDirectory() as d:
with TemporaryDirectory(prefix=os.getenv("HF_HOME", "") + "/") as d:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) you can use huggingface_hub.constants.HF_HOME to retrieve the user hf home (will check for HF_HOME or XDG_CACHE_HOME env variable + default to ~/.cache/huggingface)

(nit 2) I would set a default directory in HF_HOME + "/safetensors_converter" just in case the converter crash without cleaning the folder afterwards (at least all tmp directories will be in the same place)

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Indeed. I removed that part actually to keep everything in a real temporary folder (/tmp). It prevents cache reuse, but should make the OS clean up correctly.

I had set this up for testing to use a real SSD on the machine I was using (/tmp was mounted to a slower disk)

try:
main_commit = api.list_repo_commits(model_id)[0].commit_id
discussions = api.get_repo_discussions(repo_id=model_id)
main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to list all commit history to get the last one. This information is available as model_info(...).sha

(nit) I would also rename the variable to something like revision_commit instead of main_commit (since pulling from revision and not main)

Suggested change
main_commit = api.list_repo_commits(model_id, revision=revision)[0].commit_id
revision_commit = api.model_info(model_id, revision=revision).sha

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agreed ! (Made send before only adding the revision)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants