Skip to content

Commit

Permalink
過去バージョンを読み込むユーザーディレクトリを作る (#458)
Browse files Browse the repository at this point in the history
* 過去バージョンを読み込むユーザーディレクトリを作る

* fix flake8

* importまとめ
  • Loading branch information
Hiroshiba authored Aug 29, 2022
1 parent 8057ee4 commit 9ae873c
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 33 deletions.
77 changes: 49 additions & 28 deletions voicevox_engine/synthesis_engine/make_synthesis_engines.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import json
import sys
import traceback
from pathlib import Path
from typing import Dict, List, Optional

from ..utility import engine_root
from ..utility import engine_root, get_save_dir
from .core_wrapper import CoreWrapper, load_runtime_lib
from .synthesis_engine import SynthesisEngine, SynthesisEngineBase

Expand Down Expand Up @@ -68,34 +67,56 @@ def make_synthesis_engines(
runtime_dirs = [p.expanduser() for p in runtime_dirs]

load_runtime_lib(runtime_dirs)

synthesis_engines = {}
for core_dir in voicelib_dirs:
try:
core = CoreWrapper(use_gpu, core_dir, cpu_num_threads, load_all_models)
metas = json.loads(core.metas())
core_version = metas[0]["version"]
if core_version in synthesis_engines:
print(
"Warning: Core loading is skipped because of version duplication.",
file=sys.stderr,
)

if not enable_mock:

def load_core_library(core_dir: Path, suppress_error: bool = False):
"""
指定されたディレクトリにあるコアを読み込む。
ユーザーディレクトリの場合は存在しないこともあるので、エラーを抑制すると良い。
"""
try:
core = CoreWrapper(use_gpu, core_dir, cpu_num_threads, load_all_models)
metas = json.loads(core.metas())
core_version = metas[0]["version"]
if core_version in synthesis_engines:
print(
"Warning: Core loading is skipped because of version duplication.",
file=sys.stderr,
)
else:
synthesis_engines[core_version] = SynthesisEngine(core=core)
except Exception:
if not suppress_error:
raise

for core_dir in voicelib_dirs:
load_core_library(core_dir)

# ユーザーディレクトリにあるコアを読み込む
user_voicelib_dirs = []
core_libraries_dir = get_save_dir() / "core_libraries"
core_libraries_dir.mkdir(exist_ok=True)
user_voicelib_dirs.append(core_libraries_dir)
for path in core_libraries_dir.glob("*"):
if not path.is_dir():
continue
synthesis_engines[core_version] = SynthesisEngine(core=core)
except Exception:
if not enable_mock:
raise
traceback.print_exc()
print(
"Notice: mock-library will be used. Try re-run with valid --voicevox_dir",
file=sys.stderr,
user_voicelib_dirs.append(path)

for core_dir in user_voicelib_dirs:
load_core_library(core_dir, suppress_error=True)

else:
# モック追加
from ..dev.core import metas as mock_metas
from ..dev.core import supported_devices as mock_supported_devices
from ..dev.synthesis_engine import MockSynthesisEngine

if "0.0.0" not in synthesis_engines:
synthesis_engines["0.0.0"] = MockSynthesisEngine(
speakers=mock_metas(), supported_devices=mock_supported_devices()
)
from ..dev.core import metas as mock_metas
from ..dev.core import supported_devices as mock_supported_devices
from ..dev.synthesis_engine import MockSynthesisEngine

if "0.0.0" not in synthesis_engines:
synthesis_engines["0.0.0"] = MockSynthesisEngine(
speakers=mock_metas(), supported_devices=mock_supported_devices()
)

return synthesis_engines
6 changes: 2 additions & 4 deletions voicevox_engine/user_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,15 @@

import numpy as np
import pyopenjtalk
from appdirs import user_data_dir
from fastapi import HTTPException
from pydantic import conint

from .model import UserDictWord, WordTypes
from .part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data
from .utility import engine_root
from .utility import engine_root, get_save_dir

root_dir = engine_root()
# FIXME: ファイル保存場所をエンジン固有のIDが入ったものにする
save_dir = Path(user_data_dir("voicevox-engine"))
save_dir = get_save_dir()

if not save_dir.is_dir():
save_dir.mkdir(parents=True)
Expand Down
3 changes: 2 additions & 1 deletion voicevox_engine/utility/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
connect_base64_waves,
decode_base64_waves,
)
from .engine_root import engine_root
from .path_utility import engine_root, get_save_dir

__all__ = [
"ConnectBase64WavesException",
"connect_base64_waves",
"decode_base64_waves",
"engine_root",
"get_save_dir",
]
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sys
from pathlib import Path

from appdirs import user_data_dir


def engine_root() -> Path:
# nuitkaビルドをした際はグローバルに__compiled__が含まれる
Expand All @@ -15,3 +17,10 @@ def engine_root() -> Path:
root_dir = Path(__file__).parents[2]

return root_dir.resolve(strict=True)


def get_save_dir():
# FIXME: ファイル保存場所をエンジン固有のIDが入ったものにする
# FIXME: Windowsは`voicevox-engine/voicevox-engine`ディレクトリに保存されているので
# `VOICEVOX/voicevox-engine`に変更する
return Path(user_data_dir("voicevox-engine"))

0 comments on commit 9ae873c

Please sign in to comment.