diff --git a/.github/workflows/wd14.yml b/.github/workflows/wd14.yml
new file mode 100644
index 00000000000..eaa58ddbfef
--- /dev/null
+++ b/.github/workflows/wd14.yml
@@ -0,0 +1,44 @@
+name: Sync WD14 Models
+
+on:
+  # push:
+  workflow_dispatch:
+  schedule:
+    - cron: '30 18 * * *'
+
+jobs:
+  sync:
+    name: Sync WD14 Models
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os:
+          - 'ubuntu-latest'
+        python-version:
+          - '3.8'
+
+    steps:
+      - name: Checkout code
+        uses: actions/checkout@v3
+        with:
+          fetch-depth: 20
+      - name: Set up python ${{ matrix.python-version }}
+        uses: actions/setup-python@v4
+        with:
+          python-version: ${{ matrix.python-version }}
+      - name: Set up python dependencies
+        run: |
+          pip install --upgrade pip
+          pip install --upgrade flake8 setuptools wheel twine
+          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
+          if [ -f requirements-build.txt ]; then pip install -r requirements-build.txt; fi
+          if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi
+          if [ -f requirements-zoo.txt ]; then pip install -r requirements-zoo.txt; fi
+          pip install --upgrade build
+      - name: Sync Models
+        env:
+          HF_TOKEN: ${{ secrets.HF_TOKEN }}
+          GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }}
+        run: |
+          python -m zoo.wd14.sync
diff --git a/zoo/wd14/__init__.py b/zoo/wd14/__init__.py
new file mode 100644
index 00000000000..e69de29bb2d
diff --git a/zoo/wd14/sync.py b/zoo/wd14/sync.py
new file mode 100644
index 00000000000..8c1785c87e5
--- /dev/null
+++ b/zoo/wd14/sync.py
@@ -0,0 +1,160 @@
+import os.path
+import re
+from functools import lru_cache
+
+import numpy as np
+import onnx
+import onnxruntime
+import pandas as pd
+from ditk import logging
+from hbutils.string import plural_word
+from hbutils.system import TemporaryDirectory
+from hfutils.operate import upload_directory_as_directory
+from huggingface_hub import hf_hub_download
+from onnx.helper import make_tensor_value_info
+from tqdm import tqdm
+
+from imgutils.tagging.wd14 import MODEL_NAMES
+
+logging.try_init_root(logging.INFO)
+
+
+@lru_cache()
+def _get_model_file(name) -> str:
+    return hf_hub_download(
+        repo_id=MODEL_NAMES[name],
+        filename='model.onnx'
+    )
+
+
+@lru_cache()
+def _get_model_tags_length(name) -> int:
+    return len(pd.read_csv(hf_hub_download(
+        repo_id=MODEL_NAMES[name],
+        filename='selected_tags.csv',
+    )))
+
+
+def _seg_split(text):
+    return tuple(filter(bool, re.split(r'[./]+', text)))
+
+
+# Keywords/prefixes used to locate the final classification (fc) layer of each
+# model graph; the input tensor of that layer is the embedding we want to expose.
+_FC_KEYWORDS_FOR_V2 = {'predictions_dense'}
+_FC_NODE_PREFIXES_FOR_V3 = {
+    "SwinV2_v3": ('core_model', 'head', 'fc'),
+    "ConvNext_v3": ('core_model', 'head', 'fc'),
+    "ViT_v3": ('core_model', 'head'),
+}
+
+if __name__ == '__main__':
+    with TemporaryDirectory() as td:
+        records = []
+        for model_name in tqdm(MODEL_NAMES):
+            model_file = _get_model_file(model_name)
+            logging.info(f'Model name: {model_name!r}, model file: {model_file!r}')
+            logging.info(f'Loading model {model_name!r} ...')
+            model = onnx.load(model_file)
+            embs_outputs = []
+            if model_name in _FC_NODE_PREFIXES_FOR_V3:
+                prefix = _FC_NODE_PREFIXES_FOR_V3[model_name]
+
+
+                def _is_fc(name):
+                    return _seg_split(name)[:len(prefix)] == prefix
+            else:
+                def _is_fc(name):
+                    return any(seg in _FC_KEYWORDS_FOR_V2 for seg in _seg_split(name))
+
+            # Collect the inputs of fc nodes that are not themselves produced by
+            # fc nodes -- these tensors are the pre-classifier embeddings.
+            for node in model.graph.node:
+                if _is_fc(node.name):
+                    for input_name in node.input:
+                        if not _is_fc(input_name):
+                            logging.info(f'Input {input_name!r} for fc layer {node.name!r}.')
+                            embs_outputs.append(input_name)
+
+            logging.info(f'Embedding outputs: {embs_outputs!r}.')
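+            # Each tagger is expected to expose exactly one embedding tensor;
+            # the assert below fails loudly if the heuristic matched 0 or 2+ inputs.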
+            assert len(embs_outputs) == 1, f'Outputs: {embs_outputs!r}'
+            # make_tensor_value_info(name=embs_outputs[0], elem_type=onnx.TensorProto.FLOAT, )
+            model.graph.output.extend([onnx.ValueInfoProto(name=embs_outputs[0])])
+
+            logging.info('Analysing via onnxruntime ...')
+            session = onnxruntime.InferenceSession(model.SerializeToString())
+            input_data = np.random.randn(1, 448, 448, 3).astype(np.float32)
+            assert len(session.get_inputs()) == 1
+            assert len(session.get_outputs()) == 2
+            assert session.get_outputs()[1].name == embs_outputs[0]
+
+            tags_data, embeddings = session.run([], {session.get_inputs()[0].name: input_data})
+            logging.info(f'Tag output, shape: {tags_data.shape!r}, dtype: {tags_data.dtype!r}')
+            logging.info(f'Embeddings output, shape: {embeddings.shape!r}, dtype: {embeddings.dtype!r}')
+            assert tags_data.shape == (1, _get_model_tags_length(model_name))
+            assert len(embeddings.shape) == 2 and embeddings.shape[0] == 1
+            emb_width = embeddings.shape[-1]
+
+            # Reload the pristine model and append the embedding output again,
+            # this time with a fully-specified static shape.
+            logging.info('Remaking model ...')
+            model = onnx.load(model_file)
+            model.graph.output.extend([make_tensor_value_info(
+                name=embs_outputs[0],
+                elem_type=onnx.TensorProto.FLOAT,
+                shape=embeddings.shape,
+            )])
+
+            onnx_file = os.path.join(td, MODEL_NAMES[model_name], 'model.onnx')
+            os.makedirs(os.path.dirname(onnx_file), exist_ok=True)
+            onnx.save_model(model, onnx_file)
+
+            logging.info(f'Loading and testing the exported model {onnx_file!r}.')
+            session = onnxruntime.InferenceSession(onnx_file)
+            assert len(session.get_inputs()) == 1
+            assert len(session.get_outputs()) == 2
+            assert session.get_outputs()[1].name == embs_outputs[0]
+            assert session.get_outputs()[1].shape == [1, emb_width]
+
+            tags_data, embeddings = session.run([], {session.get_inputs()[0].name: input_data})
+            logging.info(f'Tag output, shape: {tags_data.shape!r}, dtype: {tags_data.dtype!r}')
+            logging.info(f'Embeddings output, shape: {embeddings.shape!r}, dtype: {embeddings.dtype!r}')
+            assert tags_data.shape == (1, _get_model_tags_length(model_name))
+            assert embeddings.shape == (1, emb_width)
+
+            records.append({
+                'Name': model_name,
+                'Source Repository': f'[{MODEL_NAMES[model_name]}](https://huggingface.co/{MODEL_NAMES[model_name]})',
+                'Tags Count': _get_model_tags_length(model_name),
+                'Embedding Width': emb_width,
+            })
+            _get_model_file.cache_clear()
+            _get_model_tags_length.cache_clear()
+
+        df_records = pd.DataFrame(records)
+        with open(os.path.join(td, 'README.md'), 'w') as f:
+            print('---', file=f)
+            print('license: apache-2.0', file=f)
+            print('language:', file=f)
+            print('- en', file=f)
+            print('---', file=f)
+            print('', file=f)
+
+            print(
+                'These are ONNX models based on [SmilingWolf](https://huggingface.co/SmilingWolf)\'s wd14 anime taggers, '
+                'with the embeddings output added as the second output.', file=f)
+            print('', file=f)
+            print(f'{plural_word(len(df_records), "model")} in total: ', file=f)
+            print('', file=f)
+            print(df_records.to_markdown(index=False), file=f)
+
+        upload_directory_as_directory(
+            repo_id='deepghs/wd14_tagger_with_embeddings',
+            repo_type='model',
+            local_directory=td,
+            path_in_repo='.',
+            message=f'Upload {plural_word(len(df_records), "model")}',
+            clear=True,
+        )
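
As a quick sanity reference, here is a minimal sketch of how the exported dual-output models could be consumed downstream. The per-model subdirectory name ('SmilingWolf/wd-swinv2-tagger-v3') is an assumption mirroring how the sync script names directories after MODEL_NAMES entries, not a confirmed path; substitute a real source repo id.

    import numpy as np
    import onnxruntime
    from huggingface_hub import hf_hub_download

    # Hypothetical path inside the target repo; the sync script mirrors the
    # source repo id (MODEL_NAMES[name]) as the subdirectory name.
    model_path = hf_hub_download(
        repo_id='deepghs/wd14_tagger_with_embeddings',
        filename='SmilingWolf/wd-swinv2-tagger-v3/model.onnx',
    )

    session = onnxruntime.InferenceSession(model_path)
    dummy = np.random.randn(1, 448, 448, 3).astype(np.float32)  # NHWC, as in the sync script

    # Output 0: per-tag probabilities; output 1: the appended embedding tensor.
    tags, embedding = session.run([], {session.get_inputs()[0].name: dummy})
    print(tags.shape, embedding.shape)  # (1, n_tags) and (1, emb_width)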