Downloading and caching checkpoints #42

Open · sidharthrajaram wants to merge 8 commits into main

@sidharthrajaram sidharthrajaram commented Dec 7, 2023

Adds the ability to download and cache checkpoints when loading them (similar to the Transformers API).

  • checkpoint_path in Phonemizer.from_checkpoint() now accepts a checkpoint filename or a checkpoint URL, in addition to a local file path as before. In every case, the value passed must end with .pt. See the test results below.
  • If only a checkpoint name is provided, the loader tries to retrieve the checkpoint from DEFAULT_MODEL_BUCKET (see dp/model/model.py).
  • Caching is handled by cached_path.
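
The resolution rules in the list above can be sketched roughly as follows. This is a minimal illustration, not the code in this PR: `resolve_checkpoint_location` is a hypothetical helper name, the order of the checks is an assumption, and the bucket URL is inferred from the 403 error in the test results below rather than copied from dp/model/model.py.

```python
import os
from urllib.parse import urlparse

# Assumption: inferred from the URL in the 403 error under "Test cases";
# the real constant is defined in dp/model/model.py and may differ.
DEFAULT_MODEL_BUCKET = (
    "https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer"
)


def resolve_checkpoint_location(
    checkpoint_path: str, bucket: str = DEFAULT_MODEL_BUCKET
) -> str:
    """Hypothetical helper: map a user-supplied value to a local path or a URL."""
    # Every accepted form must end with .pt.
    if not checkpoint_path.endswith(".pt"):
        raise ValueError(f"{checkpoint_path} is not a valid model file (.pt).")
    # A full URL is passed through unchanged.
    if urlparse(checkpoint_path).scheme in ("http", "https", "s3"):
        return checkpoint_path
    # An existing local file keeps the previous behaviour.
    if os.path.exists(checkpoint_path):
        return checkpoint_path
    # Otherwise treat the value as a checkpoint name inside the default bucket.
    return f"{bucket.rstrip('/')}/{checkpoint_path}"
```

The returned string would then be handed to the caching layer, which downloads remote files on first use and returns a local path.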

Why?

  • Convenient checkpoint loading, with behavior familiar from Transformers/Hugging Face.
  • Users no longer have to manage where checkpoints are stored.
  • Makes it easier to fetch models, for example when your own trained checkpoints are hosted on S3, MinIO, etc.

Test cases:

Invalid file (no .pt extension)

>>> phonemizer = Phonemizer.from_checkpoint('test')
ValueError: test is not a valid model file (.pt).

Invalid checkpoint file (name not available in the default bucket)

>>> phonemizer = Phonemizer.from_checkpoint('test.pt')
...
requests.exceptions.HTTPError: 403 Client Error: Forbidden for url: https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/test.pt

Checkpoint name provided

>>> phonemizer = Phonemizer.from_checkpoint('en_us_cmudict_forward.pt')
Loading model from /PATH/TO/.cache/cached_path/6c84425c...
>>>

Cached model loaded

>>> phonemizer = Phonemizer.from_checkpoint('en_us_cmudict_forward.pt')
Loading model from /PATH/TO/.cache/cached_path/6c84425c...
>>>

URL to checkpoint provided

>>> phonemizer = Phonemizer.from_checkpoint('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')
Loading model from /PATH/TO/.cache/cached_path/3f662135...
>>>

Cached model loaded (with just checkpoint name)

>>> phonemizer = Phonemizer.from_checkpoint('en_us_cmudict_ipa_forward.pt')
en_us_cmudict_ipa_forward.pt already exists in cache.
Loading model from /PATH/TO/.cache/cached_path/3f662135...
>>>

Cached model loaded (with checkpoint URL)

>>> phonemizer = Phonemizer.from_checkpoint('https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer/en_us_cmudict_ipa_forward.pt')
en_us_cmudict_ipa_forward.pt already exists in cache.
Loading model from /PATH/TO/.cache/cached_path/3f662135...
>>>
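
The last two cases hit the same cache entry (3f662135...) whether the checkpoint is given by name or by full URL. One plausible reading is that the name is first expanded to the bucket URL and the cache is keyed on that URL. The sketch below illustrates this idea only: `to_url` and `cache_key` are hypothetical names, the bucket URL is inferred from the test output above, and the SHA-256 key is a simplified stand-in that is not claimed to match what cached_path actually computes.

```python
import hashlib

# Assumption: bucket URL inferred from the test output above.
BUCKET = "https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/DeepPhonemizer"


def to_url(name_or_url: str) -> str:
    """Hypothetical: expand a bare checkpoint name to its bucket URL."""
    if name_or_url.startswith(("http://", "https://")):
        return name_or_url
    return f"{BUCKET}/{name_or_url}"


def cache_key(name_or_url: str) -> str:
    """Illustrative cache key: SHA-256 of the resolved URL."""
    return hashlib.sha256(to_url(name_or_url).encode("utf-8")).hexdigest()


by_name = cache_key("en_us_cmudict_ipa_forward.pt")
by_url = cache_key(f"{BUCKET}/en_us_cmudict_ipa_forward.pt")
print(by_name == by_url)  # the two spellings share one cache entry
```

Under this scheme a different checkpoint name resolves to a different URL and therefore to a different cache entry.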
