Skip to content

Commit

Permalink
Document huggingface_hub.get_safetensors_metadata (#417)
Browse files Browse the repository at this point in the history
* Document huggingface_hub.get_safetensors_metadata

* Update docs/source/metadata_parsing.mdx

Co-authored-by: Mishig <[email protected]>

* Update docs/source/metadata_parsing.mdx

Co-authored-by: Mishig <[email protected]>

* add import line

* Update docs/source/metadata_parsing.mdx

Co-authored-by: Mishig <[email protected]>

---------

Co-authored-by: Mishig <[email protected]>
  • Loading branch information
Wauplin and mishig25 authored Jan 4, 2024
1 parent a2afb0a commit 3c68e59
Showing 1 changed file with 31 additions and 30 deletions.
61 changes: 31 additions & 30 deletions docs/source/metadata_parsing.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -92,40 +92,41 @@ export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;

### Python

In this example python script, we are parsing metadata of [gpt2](https://huggingface.co/gpt2/blob/main/model.safetensors).
[`huggingface_hub`](https://huggingface.co/docs/huggingface_hub) provides a Python API to parse safetensors metadata.
Use [`get_safetensors_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.get_safetensors_metadata) to get all safetensors metadata of a model.
Depending on if the model is sharded or not, one or multiple safetensors files will be parsed.

```python
import requests # pip install requests
import struct

def parse_single_file(url):
# Fetch the first 8 bytes of the file
headers = {'Range': 'bytes=0-7'}
response = requests.get(url, headers=headers)
# Interpret the bytes as a little-endian unsigned 64-bit integer
length_of_header = struct.unpack('<Q', response.content)[0]
# Fetch length_of_header bytes starting from the 9th byte
headers = {'Range': f'bytes=8-{7 + length_of_header}'}
response = requests.get(url, headers=headers)
# Interpret the response as a JSON object
header = response.json()
return header

url = "https://huggingface.co/gpt2/resolve/main/model.safetensors"
header = parse_single_file(url)

print(header)
# {
# "__metadata__": { "format": "pt" },
# "h.10.ln_1.weight": {
# "dtype": "F32",
# "shape": [768],
# "data_offsets": [223154176, 223157248]
# },
# ...
# }
>>> from huggingface_hub import get_safetensors_metadata

# Parse repo with single weights file
>>> metadata = get_safetensors_metadata("bigscience/bloomz-560m")
>>> metadata
SafetensorsRepoMetadata(
metadata=None,
sharded=False,
weight_map={'h.0.input_layernorm.bias': 'model.safetensors', ...},
files_metadata={'model.safetensors': SafetensorsFileMetadata(...)}
)
>>> metadata.files_metadata["model.safetensors"].metadata
{'format': 'pt'}

# Parse repo with sharded model (i.e. multiple weights files)
>>> metadata = get_safetensors_metadata("bigscience/bloom")
Parse safetensors files: 100%|██████████████████████████████████████████| 72/72 [00:12<00:00, 5.78it/s]
>>> metadata
SafetensorsRepoMetadata(metadata={'total_size': 352494542848}, sharded=True, weight_map={...}, files_metadata={...})
>>> len(metadata.files_metadata)
72 # All safetensors files have been fetched

# Parse repo that is not a safetensors repo
>>> get_safetensors_metadata("runwayml/stable-diffusion-v1-5")
NotASafetensorsRepoError: 'runwayml/stable-diffusion-v1-5' is not a safetensors repo. Couldn't find 'model.safetensors.index.json' or 'model.safetensors' files.
```

To parse the metadata of a single safetensors file, use [`parse_safetensors_file_metadata`](https://huggingface.co/docs/huggingface_hub/package_reference/hf_api#huggingface_hub.HfApi.parse_safetensors_file_metadata).


## Example output

For instance, here are the number of params per dtype for a few models on the HuggingFace Hub. Also see [this issue](https://github.com/huggingface/safetensors/issues/44) for more examples of usage.
Expand Down

0 comments on commit 3c68e59

Please sign in to comment.