Skip to content

Commit

Permalink
dev(narugo): add wd14 syncer, ci skip
Browse files Browse the repository at this point in the history
  • Loading branch information
narugo1992 committed May 12, 2024
1 parent 090c281 commit 859e829
Show file tree
Hide file tree
Showing 3 changed files with 196 additions and 0 deletions.
44 changes: 44 additions & 0 deletions .github/workflows/wd14.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
name: Sync WD14 Models

on:
  # push:
  workflow_dispatch:
  schedule:
    - cron: '30 18 * * *'

jobs:
  sync:
    # Fixed copy-paste: this workflow syncs WD14 tagger models, not Waifu2x ONNX.
    name: Sync WD14 Models
    runs-on: ${{ matrix.os }}
    strategy:
      fail-fast: false
      matrix:
        os:
          - 'ubuntu-latest'
        python-version:
          - '3.8'

    steps:
      - name: Checkout code
        uses: actions/checkout@v3
        with:
          fetch-depth: 20
      - name: Set up python ${{ matrix.python-version }}
        uses: actions/setup-python@v4
        with:
          python-version: ${{ matrix.python-version }}
      - name: Set up python dependencies
        run: |
          pip install --upgrade pip
          pip install --upgrade flake8 setuptools wheel twine
          if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
          if [ -f requirements-build.txt ]; then pip install -r requirements-build.txt; fi
          if [ -f requirements-test.txt ]; then pip install -r requirements-test.txt; fi
          # Fixed: the guard previously tested requirements-test.txt while installing
          # requirements-zoo.txt, which fails when only the former exists.
          if [ -f requirements-zoo.txt ]; then pip install -r requirements-zoo.txt; fi
          pip install --upgrade build
      - name: Sync Models
        env:
          HF_TOKEN: ${{ secrets.HF_TOKEN }}
          GH_ACCESS_TOKEN: ${{ secrets.GH_ACCESS_TOKEN }}
        run: |
          python -m zoo.wd14.sync
Empty file added zoo/wd14/__init__.py
Empty file.
152 changes: 152 additions & 0 deletions zoo/wd14/sync.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
import os.path
import re
from functools import lru_cache

import numpy as np
import onnx
import onnxruntime
import pandas as pd
from ditk import logging
from hbutils.string import plural_word
from hbutils.system import TemporaryDirectory
from hfutils.operate import upload_directory_as_directory
from huggingface_hub import hf_hub_download
from onnx.helper import make_tensor_value_info
from tqdm import tqdm

from imgutils.tagging.wd14 import MODEL_NAMES

logging.try_init_root(logging.INFO)


@lru_cache()
def _get_model_file(name: str) -> str:
    """Download the ONNX model file for the given wd14 model and return its local path.

    The download is cached both on disk (by huggingface_hub) and per-process
    (by ``lru_cache``); the cache is cleared after each model in the main loop.
    """
    repo = MODEL_NAMES[name]
    return hf_hub_download(repo_id=repo, filename='model.onnx')


@lru_cache()
def _get_model_tags_length(name: str) -> int:
    """Return how many tags the given wd14 model predicts.

    Reads the model's ``selected_tags.csv`` from its source repository; the
    row count equals the width of the tagger's classification output.
    """
    csv_file = hf_hub_download(
        repo_id=MODEL_NAMES[name],
        filename='selected_tags.csv',
    )
    tags_df = pd.read_csv(csv_file)
    return len(tags_df)


def _seg_split(text):
return tuple(filter(bool, re.split(r'[./]+', text)))


# Name fragments that identify the final fully-connected (classification) layer
# in v2-generation wd14 models; matched against any segment of a node name.
_FC_KEYWORDS_FOR_V2 = {'predictions_dense'}
# Segment prefixes (after _seg_split) of the final fc layer in v3-generation
# models; used to locate the tensor feeding the classifier head, which is
# exposed as the embeddings output.
_FC_NODE_PREFIXES_FOR_V3 = {
    "SwinV2_v3": ('core_model', 'head', 'fc'),
    "ConvNext_v3": ('core_model', 'head', 'fc'),
    "ViT_v3": ('core_model', 'head'),
}

def _make_fc_predicate(model_name):
    """Build a predicate telling whether a node/tensor name belongs to the
    model's final fully-connected (classification) layer.

    v3 models are matched by a known segment prefix; older models are matched
    by keyword anywhere in the split name.
    """
    if model_name in _FC_NODE_PREFIXES_FOR_V3:
        prefix = _FC_NODE_PREFIXES_FOR_V3[model_name]

        def _is_fc(name):
            return _seg_split(name)[:len(prefix)] == prefix
    else:
        def _is_fc(name):
            return any(seg in _FC_KEYWORDS_FOR_V2 for seg in _seg_split(name))

    return _is_fc


def _find_embedding_output(model, model_name) -> str:
    """Return the name of the tensor feeding the final fc layer (the embeddings).

    Scans the graph for fc nodes and collects their non-fc inputs; exactly one
    such tensor is expected, otherwise the model layout is unknown and we abort.
    """
    _is_fc = _make_fc_predicate(model_name)
    embs_outputs = []
    for node in model.graph.node:
        if _is_fc(node.name):
            for input_name in node.input:
                if not _is_fc(input_name):
                    logging.info(f'Input {input_name!r} for fc layer {node.name!r}.')
                    embs_outputs.append(input_name)

    logging.info(f'Embedding outputs: {embs_outputs!r}.')
    assert len(embs_outputs) == 1, f'Outputs: {embs_outputs!r}'
    return embs_outputs[0]


if __name__ == '__main__':
    with TemporaryDirectory() as td:
        records = []
        for model_name in tqdm(MODEL_NAMES):
            model_file = _get_model_file(model_name)
            logging.info(f'Model name: {model_name!r}, model file: {model_file!r}')
            logging.info(f'Loading model {model_name!r} ...')
            model = onnx.load(model_file)
            emb_output = _find_embedding_output(model, model_name)
            # First pass: expose the embedding tensor without shape info, only to
            # probe its actual shape through a test inference below.
            model.graph.output.extend([onnx.ValueInfoProto(name=emb_output)])

            logging.info('Analysing via onnxruntime ...')
            session = onnxruntime.InferenceSession(model.SerializeToString())
            # wd14 taggers take a single NHWC float32 input of 448x448 RGB.
            input_data = np.random.randn(1, 448, 448, 3).astype(np.float32)
            assert len(session.get_inputs()) == 1
            assert len(session.get_outputs()) == 2
            assert session.get_outputs()[1].name == emb_output

            tags_data, embeddings = session.run([], {session.get_inputs()[0].name: input_data})
            logging.info(f'Tag output, shape: {tags_data.shape!r}, dtype: {tags_data.dtype!r}')
            logging.info(f'Embeddings output, shape: {embeddings.shape!r}, dtype: {embeddings.dtype!r}')
            assert tags_data.shape == (1, _get_model_tags_length(model_name))
            assert len(embeddings.shape) == 2 and embeddings.shape[0] == 1
            emb_width = embeddings.shape[-1]

            # Second pass: rebuild the model from scratch with a fully-typed
            # embedding output, now that the concrete shape is known.
            logging.info('Remaking model ...')
            model = onnx.load(model_file)
            model.graph.output.extend([make_tensor_value_info(
                name=emb_output,
                elem_type=onnx.TensorProto.FLOAT,
                shape=embeddings.shape,
            )])

            onnx_file = os.path.join(td, MODEL_NAMES[model_name], 'model.onnx')
            os.makedirs(os.path.dirname(onnx_file), exist_ok=True)
            onnx.save_model(model, onnx_file)

            logging.info(f'Loading and testing for the exported model {onnx_file!r}.')
            session = onnxruntime.InferenceSession(onnx_file)
            assert len(session.get_inputs()) == 1
            assert len(session.get_outputs()) == 2
            assert session.get_outputs()[1].name == emb_output
            assert session.get_outputs()[1].shape == [1, emb_width]

            tags_data, embeddings = session.run([], {session.get_inputs()[0].name: input_data})
            logging.info(f'Tag output, shape: {tags_data.shape!r}, dtype: {tags_data.dtype!r}')
            logging.info(f'Embeddings output, shape: {embeddings.shape!r}, dtype: {embeddings.dtype!r}')
            assert tags_data.shape == (1, _get_model_tags_length(model_name))
            assert embeddings.shape == (1, emb_width)

            records.append({
                'Name': model_name,
                'Source Repository': f'[{MODEL_NAMES[model_name]}](https://huggingface.co/{MODEL_NAMES[model_name]})',
                'Tags Count': _get_model_tags_length(model_name),
                'Embedding Width': emb_width,
            })
            # Drop the per-model caches so the downloaded files can be reclaimed
            # before processing the next (large) model.
            _get_model_file.cache_clear()
            _get_model_tags_length.cache_clear()

        df_records = pd.DataFrame(records)
        with open(os.path.join(td, 'README.md'), 'w') as f:
            print('---', file=f)
            print('license: apache-2.0', file=f)
            print('language:', file=f)
            print('- en', file=f)
            print('---', file=f)
            print('', file=f)

            print(
                f'This is onnx models based on [SmilingWolf](https://huggingface.co/SmilingWolf)\'s wd14 anime taggers, '
                f'which added the embeddings output as the second output.', file=f)
            print(f'', file=f)
            print(f'{plural_word(len(df_records), "model")} in total: ', file=f)
            print(f'', file=f)
            print(df_records.to_markdown(index=False), file=f)

        upload_directory_as_directory(
            repo_id='deepghs/wd14_tagger_with_embeddings',
            repo_type='model',
            local_directory=td,
            path_in_repo='.',
            # Fixed: plural_word expects the singular form ('model'), matching the
            # README text above; passing 'models' produced '1 models' for one model.
            message=f'Upload {plural_word(len(df_records), "model")}',
            clear=True,
        )

0 comments on commit 859e829

Please sign in to comment.