Skip to content

Commit

Permalink
dev(narugo): add model metadata reader and writer
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed Mar 19, 2024
1 parent cc802fc commit 7540259
Show file tree
Hide file tree
Showing 7 changed files with 739 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jobs:
sudo apt-get install -y make wget curl cloc graphviz pandoc
dot -V
python -m pip install -r requirements.txt
python -m pip install -r requirements-model.txt
python -m pip install -r requirements-doc.txt
- name: Prepare dataset
uses: nick-fields/retry@v2
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ jobs:
python -m pip install --upgrade pip
pip install --upgrade flake8 setuptools wheel twine
pip install -r requirements.txt
pip install -r requirements-model.txt
pip install -r requirements-test.txt
- name: Test the basic environment
shell: bash
Expand Down
1 change: 1 addition & 0 deletions imgutils/sd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
Utilities for dealing with data from `AUTOMATIC1111/stable-diffusion-webui <https://github.com/AUTOMATIC1111/stable-diffusion-webui>`_.
"""
from .metadata import parse_sdmeta_from_text, get_sdmeta_from_image, SDMetaData
from .model import read_metadata, save_with_metadata
40 changes: 40 additions & 0 deletions imgutils/sd/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
from typing import Dict

try:
import torch
except (ImportError, ModuleNotFoundError): # pragma: no cover
torch = None

try:
import safetensors.torch
except (ImportError, ModuleNotFoundError): # pragma: no cover
safetensors = None


def _check_env():
if not safetensors:
raise EnvironmentError(
'Safetensors not installed. Please use "pip install dghs-imgutils[model]".') # pragma: no cover
if not torch:
raise EnvironmentError(
'Torch not installed. Please use "pip install dghs-imgutils[model]".') # pragma: no cover


def read_metadata(model_file: str) -> Dict[str, str]:
_check_env()
with safetensors.safe_open(model_file, 'pt') as f:
return f.metadata()


def save_with_metadata(src_model_file: str, dst_model_file: str, metadata: Dict[str, str], clear: bool = False):
_check_env()
with safetensors.safe_open(src_model_file, framework='pt') as f:
if clear:
new_metadata = {**(metadata or {})}
else:
new_metadata = {**f.metadata(), **(metadata or {})}
safetensors.torch.save_file(
tensors={key: f.get_tensor(key) for key in f.keys()},
filename=dst_model_file,
metadata=new_metadata,
)
2 changes: 2 additions & 0 deletions requirements-model.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch
safetensors
694 changes: 694 additions & 0 deletions test/sd/test_model.py

Large diffs are not rendered by default.

Binary file added test/testfile/surtr_arknights.safetensors
Binary file not shown.

0 comments on commit 7540259

Please sign in to comment.