Skip to content

Commit

Permalink
Merge pull request #76 from deepghs/dev/lora
Browse files Browse the repository at this point in the history
dev(narugo): add model metadata reader and writer
  • Loading branch information
narugo1992 authored Mar 19, 2024
2 parents cc802fc + 153de31 commit 4cf5184
Show file tree
Hide file tree
Showing 9 changed files with 799 additions and 0 deletions.
1 change: 1 addition & 0 deletions .github/workflows/doc.yml
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ jobs:
sudo apt-get install -y make wget curl cloc graphviz pandoc
dot -V
python -m pip install -r requirements.txt
python -m pip install -r requirements-model.txt
python -m pip install -r requirements-doc.txt
- name: Prepare dataset
uses: nick-fields/retry@v2
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ jobs:
python -m pip install --upgrade pip
pip install --upgrade flake8 setuptools wheel twine
pip install -r requirements.txt
pip install -r requirements-model.txt
pip install -r requirements-test.txt
- name: Test the basic environment
shell: bash
Expand Down
1 change: 1 addition & 0 deletions docs/source/api_doc/sd/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,5 @@ imgutils.sd
:maxdepth: 3

metadata
model

22 changes: 22 additions & 0 deletions docs/source/api_doc/sd/model.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
imgutils.sd.model
====================================

.. currentmodule:: imgutils.sd.model

.. automodule:: imgutils.sd.model


read_metadata
------------------------------------------

.. autofunction:: read_metadata



save_with_metadata
------------------------------------------

.. autofunction:: save_with_metadata



1 change: 1 addition & 0 deletions imgutils/sd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@
Utilities for dealing with data from `AUTOMATIC1111/stable-diffusion-webui <https://github.com/AUTOMATIC1111/stable-diffusion-webui>`_.
"""
from .metadata import parse_sdmeta_from_text, get_sdmeta_from_image, SDMetaData
from .model import read_metadata, save_with_metadata
77 changes: 77 additions & 0 deletions imgutils/sd/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""
Overview:
A utility for reading and writing metadata from/to model files in the A41 WebUI format.

.. note::
``torch`` and ``safetensors`` are required by this model.
Please install them with the following command before start using this module.

.. code:: shell

pip install dghs-imgutils[model]
"""

from typing import Dict

try:
import torch
except (ImportError, ModuleNotFoundError): # pragma: no cover
torch = None

try:
import safetensors.torch
except (ImportError, ModuleNotFoundError): # pragma: no cover
safetensors = None


def _check_env():
"""
Checks if the required dependencies (Safetensors and Torch) are installed.
Raises EnvironmentError if they are not installed.
"""
if not safetensors:
raise EnvironmentError(
'Safetensors not installed. Please use "pip install dghs-imgutils[model]".') # pragma: no cover
if not torch:
raise EnvironmentError(
'Torch not installed. Please use "pip install dghs-imgutils[model]".') # pragma: no cover


def read_metadata(model_file: str) -> Dict[str, str]:
"""
Reads metadata from a model file and returns it as a dictionary.

:param model_file: The path to the model file.
:type model_file: str
:return: The metadata extracted from the model file.
:rtype: Dict[str, str]
"""
_check_env()
with safetensors.safe_open(model_file, 'pt') as f:
return f.metadata()


def save_with_metadata(src_model_file: str, dst_model_file: str, metadata: Dict[str, str], clear: bool = False):
"""
Saves a model file with metadata. Optionally, existing metadata can be cleared before adding new metadata.

:param src_model_file: The path to the source model file.
:type src_model_file: str
:param dst_model_file: The path to save the new model file.
:type dst_model_file: str
:param metadata: The metadata to add to the model file.
:type metadata: Dict[str, str]
:param clear: Whether to clear existing metadata before adding new metadata. Default is False.
:type clear: bool
"""
_check_env()
with safetensors.safe_open(src_model_file, framework='pt') as f:
if clear:
new_metadata = {**(metadata or {})}
else:
new_metadata = {**f.metadata(), **(metadata or {})}
safetensors.torch.save_file(
tensors={key: f.get_tensor(key) for key in f.keys()},
filename=dst_model_file,
metadata=new_metadata,
)
2 changes: 2 additions & 0 deletions requirements-model.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch
safetensors
694 changes: 694 additions & 0 deletions test/sd/test_model.py

Large diffs are not rendered by default.

Binary file added test/testfile/surtr_arknights.safetensors
Binary file not shown.

0 comments on commit 4cf5184

Please sign in to comment.