Skip to content

Commit

Permalink
update_dictのリファクタリング (#630)
Browse files Browse the repository at this point in the history
* 「ファンクションが間違っています」エラーをなくす

* .

* NamedTemporaryFileの廃止

* a

* random_string
  • Loading branch information
Hiroshiba authored Mar 12, 2023
1 parent a062880 commit eee7f4a
Showing 1 changed file with 53 additions and 45 deletions.
98 changes: 53 additions & 45 deletions voicevox_engine/user_dict.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import json
import shutil
import sys
import threading
import traceback
from pathlib import Path
from tempfile import NamedTemporaryFile
from typing import Dict, List, Optional
from uuid import UUID, uuid4

Expand All @@ -15,7 +13,7 @@

from .model import UserDictWord, WordTypes
from .part_of_speech_data import MAX_PRIORITY, MIN_PRIORITY, part_of_speech_data
from .utility import delete_file, engine_root, get_save_dir, mutex_wrapper
from .utility import engine_root, get_save_dir, mutex_wrapper

root_dir = engine_root()
save_dir = get_save_dir()
Expand Down Expand Up @@ -53,61 +51,71 @@ def update_dict(
user_dict_path: Path = user_dict_path,
compiled_dict_path: Path = compiled_dict_path,
):
with NamedTemporaryFile(encoding="utf-8", mode="w", delete=False) as f:
random_string = uuid4()
tmp_csv_path = save_dir / f".tmp.dict_csv-{random_string}"
tmp_compiled_path = save_dir / f".tmp.dict_compiled-{random_string}"

try:
# 辞書.csvを作成
csv_text = ""
if not default_dict_path.is_file():
print("Warning: Cannot find default dictionary.", file=sys.stderr)
return
default_dict = default_dict_path.read_text(encoding="utf-8")
if default_dict == default_dict.rstrip():
default_dict += "\n"
f.write(default_dict)
csv_text += default_dict
user_dict = read_dict(user_dict_path=user_dict_path)
for word_uuid in user_dict:
word = user_dict[word_uuid]
f.write(
(
"{surface},{context_id},{context_id},{cost},{part_of_speech},"
+ "{part_of_speech_detail_1},{part_of_speech_detail_2},"
+ "{part_of_speech_detail_3},{inflectional_type},"
+ "{inflectional_form},{stem},{yomi},{pronunciation},"
+ "{accent_type}/{mora_count},{accent_associative_rule}\n"
).format(
surface=word.surface,
context_id=word.context_id,
cost=priority2cost(word.context_id, word.priority),
part_of_speech=word.part_of_speech,
part_of_speech_detail_1=word.part_of_speech_detail_1,
part_of_speech_detail_2=word.part_of_speech_detail_2,
part_of_speech_detail_3=word.part_of_speech_detail_3,
inflectional_type=word.inflectional_type,
inflectional_form=word.inflectional_form,
stem=word.stem,
yomi=word.yomi,
pronunciation=word.pronunciation,
accent_type=word.accent_type,
mora_count=word.mora_count,
accent_associative_rule=word.accent_associative_rule,
)
csv_text += (
"{surface},{context_id},{context_id},{cost},{part_of_speech},"
+ "{part_of_speech_detail_1},{part_of_speech_detail_2},"
+ "{part_of_speech_detail_3},{inflectional_type},"
+ "{inflectional_form},{stem},{yomi},{pronunciation},"
+ "{accent_type}/{mora_count},{accent_associative_rule}\n"
).format(
surface=word.surface,
context_id=word.context_id,
cost=priority2cost(word.context_id, word.priority),
part_of_speech=word.part_of_speech,
part_of_speech_detail_1=word.part_of_speech_detail_1,
part_of_speech_detail_2=word.part_of_speech_detail_2,
part_of_speech_detail_3=word.part_of_speech_detail_3,
inflectional_type=word.inflectional_type,
inflectional_form=word.inflectional_form,
stem=word.stem,
yomi=word.yomi,
pronunciation=word.pronunciation,
accent_type=word.accent_type,
mora_count=word.mora_count,
accent_associative_rule=word.accent_associative_rule,
)
tmp_dict_path = Path(NamedTemporaryFile(delete=False).name).resolve()
pyopenjtalk.create_user_dict(
str(Path(f.name).resolve(strict=True)),
str(tmp_dict_path),
)
delete_file(f.name)
if not tmp_dict_path.is_file():
raise RuntimeError("辞書のコンパイル時にエラーが発生しました。")
pyopenjtalk.unset_user_dict()
try:
shutil.move(tmp_dict_path, compiled_dict_path) # ドライブを跨ぐためPath.replaceが使えない
except OSError:
traceback.print_exc()
if tmp_dict_path.exists():
delete_file(tmp_dict_path.name)
finally:
tmp_csv_path.write_text(csv_text, encoding="utf-8")

# 辞書.csvをOpenJTalk用にコンパイル
pyopenjtalk.create_user_dict(str(tmp_csv_path), str(tmp_compiled_path))
if not tmp_compiled_path.is_file():
raise RuntimeError("辞書のコンパイル時にエラーが発生しました。")

# コンパイル済み辞書の置き換え・読み込み
pyopenjtalk.unset_user_dict()
tmp_compiled_path.replace(compiled_dict_path)
if compiled_dict_path.is_file():
pyopenjtalk.set_user_dict(str(compiled_dict_path.resolve(strict=True)))

except Exception as e:
print("Error: Failed to update dictionary.", file=sys.stderr)
traceback.print_exc(file=sys.stderr)
raise e

finally:
# 後処理
if tmp_csv_path.exists():
tmp_csv_path.unlink()
if tmp_compiled_path.exists():
tmp_compiled_path.unlink()


@mutex_wrapper(mutex_user_dict)
def read_dict(user_dict_path: Path = user_dict_path) -> Dict[str, UserDictWord]:
Expand Down

0 comments on commit eee7f4a

Please sign in to comment.