Skip to content

Commit

Permalink
ENH: pydanticをv2へ (VOICEVOX#1184)
Browse files Browse the repository at this point in the history
* ENH: pydanticをv2へ

* FIX: json indent=4

* FIX: 変数名

* FIX: コメント追加

* FIX: スキーマー修正

* Update voicevox_engine/app/application.py

---------

Co-authored-by: Hiroshiba <[email protected]>
  • Loading branch information
sabonerune and Hiroshiba authored Jun 10, 2024
1 parent c88b3ad commit 2513168
Show file tree
Hide file tree
Showing 30 changed files with 285 additions and 165 deletions.
155 changes: 112 additions & 43 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ pyopenjtalk = { git = "https://github.com/VOICEVOX/pyopenjtalk", rev = "b35fc89f
semver = "^3.0.0"
platformdirs = "^4.2.0"
soxr = "^0.3.6"
pydantic = "^1.10.15"
pydantic = "^2.7.3"
starlette = "^0.37.0"

[tool.poetry.group.dev.dependencies]
Expand Down
4 changes: 3 additions & 1 deletion requirements-build.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
altgraph==0.17.4 ; python_version >= "3.11" and python_version < "3.12"
annotated-types==0.6.0 ; python_version >= "3.11" and python_version < "3.12"
anyio==4.3.0 ; python_version >= "3.11" and python_version < "3.12"
cffi==1.16.0 ; python_version >= "3.11" and python_version < "3.12"
click==8.1.7 ; python_version >= "3.11" and python_version < "3.12"
Expand All @@ -15,7 +16,8 @@ packaging==24.0 ; python_version >= "3.11" and python_version < "3.12"
pefile==2023.2.7 ; python_version >= "3.11" and python_version < "3.12" and sys_platform == "win32"
platformdirs==4.2.2 ; python_version >= "3.11" and python_version < "3.12"
pycparser==2.22 ; python_version >= "3.11" and python_version < "3.12"
pydantic==1.10.15 ; python_version >= "3.11" and python_version < "3.12"
pydantic-core==2.18.4 ; python_version >= "3.11" and python_version < "3.12"
pydantic==2.7.3 ; python_version >= "3.11" and python_version < "3.12"
pyinstaller-hooks-contrib==2024.6 ; python_version >= "3.11" and python_version < "3.12"
pyinstaller==5.13.2 ; python_version >= "3.11" and python_version < "3.12"
pyopenjtalk @ git+https://github.com/VOICEVOX/pyopenjtalk@b35fc89fe42948a28e33aed886ea145a51113f88 ; python_version >= "3.11" and python_version < "3.12"
Expand Down
4 changes: 3 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
annotated-types==0.6.0 ; python_version >= "3.11" and python_version < "3.12"
anyio==4.3.0 ; python_version >= "3.11" and python_version < "3.12"
attrs==23.2.0 ; python_version >= "3.11" and python_version < "3.12"
authlib==1.3.0 ; python_version >= "3.11" and python_version < "3.12"
Expand Down Expand Up @@ -68,7 +69,8 @@ prettytable==3.10.0 ; python_version >= "3.11" and python_version < "3.12"
ptyprocess==0.7.0 ; python_version >= "3.11" and python_version < "3.12"
pycodestyle==2.11.1 ; python_version >= "3.11" and python_version < "3.12"
pycparser==2.22 ; python_version >= "3.11" and python_version < "3.12"
pydantic==1.10.15 ; python_version >= "3.11" and python_version < "3.12"
pydantic-core==2.18.4 ; python_version >= "3.11" and python_version < "3.12"
pydantic==2.7.3 ; python_version >= "3.11" and python_version < "3.12"
pyflakes==3.2.0 ; python_version >= "3.11" and python_version < "3.12"
pygments==2.18.0 ; python_version >= "3.11" and python_version < "3.12"
pyopenjtalk @ git+https://github.com/VOICEVOX/pyopenjtalk@b35fc89fe42948a28e33aed886ea145a51113f88 ; python_version >= "3.11" and python_version < "3.12"
Expand Down
4 changes: 3 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
annotated-types==0.6.0 ; python_version >= "3.11" and python_version < "3.12"
anyio==4.3.0 ; python_version >= "3.11" and python_version < "3.12"
cffi==1.16.0 ; python_version >= "3.11" and python_version < "3.12"
click==8.1.7 ; python_version >= "3.11" and python_version < "3.12"
Expand All @@ -11,7 +12,8 @@ markupsafe==2.1.5 ; python_version >= "3.11" and python_version < "3.12"
numpy==1.26.4 ; python_version >= "3.11" and python_version < "3.12"
platformdirs==4.2.2 ; python_version >= "3.11" and python_version < "3.12"
pycparser==2.22 ; python_version >= "3.11" and python_version < "3.12"
pydantic==1.10.15 ; python_version >= "3.11" and python_version < "3.12"
pydantic-core==2.18.4 ; python_version >= "3.11" and python_version < "3.12"
pydantic==2.7.3 ; python_version >= "3.11" and python_version < "3.12"
pyopenjtalk @ git+https://github.com/VOICEVOX/pyopenjtalk@b35fc89fe42948a28e33aed886ea145a51113f88 ; python_version >= "3.11" and python_version < "3.12"
python-multipart==0.0.9 ; python_version >= "3.11" and python_version < "3.12"
pyworld==0.3.4 ; python_version >= "3.11" and python_version < "3.12"
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 5 additions & 3 deletions test/e2e/test_speakers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,13 @@
from test.utility import hash_long_string

from fastapi.testclient import TestClient
from pydantic import parse_obj_as
from pydantic import TypeAdapter
from syrupy.assertion import SnapshotAssertion

from voicevox_engine.metas.Metas import Speaker

_speaker_list_adapter = TypeAdapter(list[Speaker])


def test_話者一覧が取得できる(
client: TestClient, snapshot_json: SnapshotAssertion
Expand All @@ -23,7 +25,7 @@ def test_話者一覧が取得できる(
def test_話者の情報を取得できる(
client: TestClient, snapshot_json: SnapshotAssertion
) -> None:
speakers = parse_obj_as(list[Speaker], client.get("/speakers").json())
speakers = _speaker_list_adapter.validate_python(client.get("/speakers").json())
for speaker in speakers:
response = client.get(
"/speaker_info", params={"speaker_uuid": speaker.speaker_uuid}
Expand All @@ -44,7 +46,7 @@ def test_歌手一覧が取得できる(
def test_歌手の情報を取得できる(
client: TestClient, snapshot_json: SnapshotAssertion
) -> None:
singers = parse_obj_as(list[Speaker], client.get("/singers").json())
singers = _speaker_list_adapter.validate_python(client.get("/singers").json())
for singer in singers:
response = client.get(
"/singer_info", params={"speaker_uuid": singer.speaker_uuid}
Expand Down
10 changes: 5 additions & 5 deletions test/unit/setting/test_setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def test_setting_handler_load_not_exist_file() -> None:
# Expects
true_setting = {"allow_origin": None, "cors_policy_mode": CorsPolicyMode.localapps}
# Outputs
setting = settings.dict()
setting = settings.model_dump()
# Test
assert true_setting == setting

Expand All @@ -29,7 +29,7 @@ def test_setting_handler_load_exist_file_1() -> None:
# Expects
true_setting = {"allow_origin": None, "cors_policy_mode": CorsPolicyMode.localapps}
# Outputs
setting = settings.dict()
setting = settings.model_dump()
# Test
assert true_setting == setting

Expand All @@ -43,7 +43,7 @@ def test_setting_handler_load_exist_file_2() -> None:
# Expects
true_setting = {"allow_origin": None, "cors_policy_mode": "all"}
# Outputs
setting = settings.dict()
setting = settings.model_dump()
# Test
assert true_setting == setting

Expand All @@ -60,7 +60,7 @@ def test_setting_handler_load_exist_file_3() -> None:
"cors_policy_mode": CorsPolicyMode.localapps,
}
# Outputs
setting = settings.dict()
setting = settings.model_dump()
# Test
assert true_setting == setting

Expand All @@ -76,7 +76,7 @@ def test_setting_handler_save(tmp_path: Path) -> None:
# Outputs
setting_loader.save(new_setting)
# NOTE: `.load()` の正常動作を前提とする
setting = setting_loader.load().dict()
setting = setting_loader.load().model_dump()
# Test
assert true_setting == setting

Expand Down
5 changes: 2 additions & 3 deletions test/utility.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import hashlib
import io
import json
from typing import Any

import numpy as np
import soundfile as sf
from pydantic.json import pydantic_encoder
from fastapi.encoders import jsonable_encoder


def round_floats(value: Any, round_value: int) -> Any:
Expand All @@ -22,7 +21,7 @@ def round_floats(value: Any, round_value: int) -> Any:

def pydantic_to_native_type(value: Any) -> Any:
"""pydanticの型をnativeな型に変換する"""
return json.loads(json.dumps(value, default=pydantic_encoder))
return jsonable_encoder(value)


def hash_long_string(value: Any) -> Any:
Expand Down
1 change: 1 addition & 0 deletions voicevox_engine/app/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def generate_app(
title=engine_manifest.name,
description=f"{engine_manifest.brand_name} の音声合成エンジンです。",
version=__version__,
separate_input_output_schemas=False, # Pydantic V1 のときのスキーマに合わせるため
)
app = configure_middlewares(app, cors_policy_mode, allow_origin)
app = configure_global_exception_handlers(app)
Expand Down
18 changes: 12 additions & 6 deletions voicevox_engine/app/openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,17 @@ def custom_openapi() -> Any:
openapi_schema = get_openapi(
title=app.title,
version=app.version,
openapi_version=app.openapi_version,
summary=app.summary,
description=app.description,
routes=app.routes,
tags=app.openapi_tags,
servers=app.servers,
terms_of_service=app.terms_of_service,
contact=app.contact,
license_info=app.license_info,
routes=app.routes,
webhooks=app.webhooks.routes,
tags=app.openapi_tags,
servers=app.servers,
separate_input_output_schemas=app.separate_input_output_schemas,
)
if manage_library:
additional_models: list[type[BaseModel]] = [
Expand All @@ -35,10 +39,12 @@ def custom_openapi() -> Any:
]
for model in additional_models:
# ref_templateを指定しない場合、definitionsを参照してしまうので、手動で指定する
schema = model.schema(ref_template="#/components/schemas/{model}")
schema = model.model_json_schema(
ref_template="#/components/schemas/{model}"
)
# definitionsは既存のモデルを重複して定義するため、不要なので削除
if "definitions" in schema:
del schema["definitions"]
if "$defs" in schema:
del schema["$defs"]
openapi_schema["components"]["schemas"][schema["title"]] = schema
app.openapi_schema = openapi_schema
return openapi_schema
Expand Down
5 changes: 4 additions & 1 deletion voicevox_engine/app/routers/engine_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel, Field
from pydantic.json_schema import SkipJsonSchema

from voicevox_engine import __version__
from voicevox_engine.core.core_adapter import DeviceSupport
Expand Down Expand Up @@ -47,7 +48,9 @@ async def core_versions() -> list[str]:
return core_manager.versions()

@router.get("/supported_devices")
def supported_devices(core_version: str | None = None) -> SupportedDevicesInfo:
def supported_devices(
core_version: str | SkipJsonSchema[None] = None,
) -> SupportedDevicesInfo:
"""対応デバイスの一覧を取得します。"""
supported_devices = core_manager.get_core(core_version).supported_devices
if supported_devices is None:
Expand Down
5 changes: 3 additions & 2 deletions voicevox_engine/app/routers/morphing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

import soundfile
from fastapi import APIRouter, HTTPException, Query
from pydantic.json_schema import SkipJsonSchema
from starlette.background import BackgroundTask
from starlette.responses import FileResponse

Expand Down Expand Up @@ -45,7 +46,7 @@ def generate_morphing_router(
summary="指定したスタイルに対してエンジン内の話者がモーフィングが可能か判定する",
)
def morphable_targets(
base_style_ids: list[StyleId], core_version: str | None = None
base_style_ids: list[StyleId], core_version: str | SkipJsonSchema[None] = None
) -> list[dict[str, MorphableTargetInfo]]:
"""
指定されたベーススタイルに対してエンジン内の各話者がモーフィング機能を利用可能か返します。
Expand Down Expand Up @@ -83,7 +84,7 @@ def _synthesis_morphing(
base_style_id: Annotated[StyleId, Query(alias="base_speaker")],
target_style_id: Annotated[StyleId, Query(alias="target_speaker")],
morph_rate: Annotated[float, Query(ge=0.0, le=1.0)],
core_version: str | None = None,
core_version: str | SkipJsonSchema[None] = None,
) -> FileResponse:
"""
指定された2種類のスタイルで音声を合成、指定した割合でモーフィングした音声を得ます。
Expand Down
3 changes: 2 additions & 1 deletion voicevox_engine/app/routers/setting.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from fastapi import APIRouter, Depends, Form, Request, Response
from fastapi.templating import Jinja2Templates
from pydantic.json_schema import SkipJsonSchema

from voicevox_engine.engine_manifest import BrandName
from voicevox_engine.setting.model import CorsPolicyMode
Expand Down Expand Up @@ -53,7 +54,7 @@ def setting_get(request: Request) -> Response:
)
def setting_post(
cors_policy_mode: Annotated[CorsPolicyMode, Form()],
allow_origin: Annotated[str | None, Form()] = None,
allow_origin: Annotated[str | SkipJsonSchema[None], Form()] = None,
) -> None:
"""
設定を更新します。
Expand Down
21 changes: 14 additions & 7 deletions voicevox_engine/app/routers/speaker.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@
from typing import Literal

from fastapi import APIRouter, HTTPException
from pydantic import parse_obj_as
from pydantic import TypeAdapter
from pydantic.json_schema import SkipJsonSchema

from voicevox_engine.core.core_initializer import CoreManager
from voicevox_engine.metas.Metas import Speaker, SpeakerInfo
Expand All @@ -25,14 +26,16 @@ def generate_speaker_router(
router = APIRouter(tags=["その他"])

@router.get("/speakers")
def speakers(core_version: str | None = None) -> list[Speaker]:
def speakers(core_version: str | SkipJsonSchema[None] = None) -> list[Speaker]:
"""話者情報の一覧を取得します。"""
core = core_manager.get_core(core_version)
speakers = metas_store.load_combined_metas(core.speakers)
return filter_speakers_and_styles(speakers, "speaker")

@router.get("/speaker_info")
def speaker_info(speaker_uuid: str, core_version: str | None = None) -> SpeakerInfo:
def speaker_info(
speaker_uuid: str, core_version: str | SkipJsonSchema[None] = None
) -> SpeakerInfo:
"""
指定されたspeaker_uuidの話者に関する情報をjson形式で返します。
画像や音声はbase64エンコードされたものが返されます。
Expand All @@ -43,6 +46,8 @@ def speaker_info(speaker_uuid: str, core_version: str | None = None) -> SpeakerI
core_version=core_version,
)

_speaker_list_adapter = TypeAdapter(list[Speaker])

# FIXME: この関数をどこかに切り出す
def _speaker_info(
speaker_uuid: str,
Expand Down Expand Up @@ -73,8 +78,8 @@ def _speaker_info(
# ...

# 該当話者を検索する
speakers = parse_obj_as(
list[Speaker], core_manager.get_core(core_version).speakers
speakers = _speaker_list_adapter.validate_python(
core_manager.get_core(core_version).speakers, from_attributes=True
)
speakers = filter_speakers_and_styles(speakers, speaker_or_singer)
speaker = next(
Expand Down Expand Up @@ -135,14 +140,16 @@ def _speaker_info(
return spk_info

@router.get("/singers")
def singers(core_version: str | None = None) -> list[Speaker]:
def singers(core_version: str | SkipJsonSchema[None] = None) -> list[Speaker]:
"""歌手情報の一覧を取得します"""
core = core_manager.get_core(core_version)
singers = metas_store.load_combined_metas(core.speakers)
return filter_speakers_and_styles(singers, "singer")

@router.get("/singer_info")
def singer_info(speaker_uuid: str, core_version: str | None = None) -> SpeakerInfo:
def singer_info(
speaker_uuid: str, core_version: str | SkipJsonSchema[None] = None
) -> SpeakerInfo:
"""
指定されたspeaker_uuidの歌手に関する情報をjson形式で返します。
画像や音声はbase64エンコードされたものが返されます。
Expand Down
Loading

0 comments on commit 2513168

Please sign in to comment.