Skip to content

Commit

Permalink
Update: CLI ツールの実装とバリデーションを改善
Browse files Browse the repository at this point in the history
  • Loading branch information
tsukumijima committed Oct 21, 2024
1 parent 006b98f commit b78b181
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 49 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,6 @@ cython_debug/
.python-version
build/
dist/

*.aivm
*.aivmx
6 changes: 3 additions & 3 deletions aivmlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ def generate_aivm_metadata(
style_vectors_file.seek(0)

# Style-Bert-VITS2 系の音声合成モデルの場合
if model_architecture.startswith('Style-Bert-VITS2'):
if model_architecture in [ModelArchitecture.StyleBertVITS2, ModelArchitecture.StyleBertVITS2JPExtra]:

# ハイパーパラメータファイル (JSON) を読み込んだ後、Pydantic でバリデーションする
hyper_parameters_content = hyper_parameters_file.read().decode('utf-8')
Expand Down Expand Up @@ -147,7 +147,7 @@ def validate_aivm_metadata(raw_metadata: dict[str, str]) -> AivmMetadata:
# ハイパーパラメータのバリデーション
if 'aivm_hyper_parameters' in raw_metadata:
try:
if aivm_manifest.model_architecture.startswith('Style-Bert-VITS2'):
if aivm_manifest.model_architecture in [ModelArchitecture.StyleBertVITS2, ModelArchitecture.StyleBertVITS2JPExtra]:
aivm_hyper_parameters = StyleBertVITS2HyperParameters.model_validate_json(raw_metadata['aivm_hyper_parameters'])
else:
raise AivmValidationError(f"Unsupported hyper-parameters for model architecture: {aivm_manifest.model_architecture}.")
Expand Down Expand Up @@ -395,7 +395,7 @@ def apply_aivm_manifest_to_hyper_parameters(aivm_metadata: AivmMetadata) -> None
"""

# Style-Bert-VITS2 系の音声合成モデルの場合
if aivm_metadata.manifest.model_architecture.startswith('Style-Bert-VITS2'):
if aivm_metadata.manifest.model_architecture in [ModelArchitecture.StyleBertVITS2, ModelArchitecture.StyleBertVITS2JPExtra]:

# スタイルベクトルが設定されていなければエラー
if aivm_metadata.style_vectors is None:
Expand Down
162 changes: 116 additions & 46 deletions aivmlib/__main__.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,13 @@

import rich
import traceback
import typer
from pathlib import Path
from rich.rule import Rule
from rich.style import Style
from typing import Annotated, Union

from aivmlib import (
generate_aivm_metadata,
read_aivm_metadata,
read_aivmx_metadata,
write_aivm_metadata,
write_aivmx_metadata,
)
import aivmlib
from aivmlib.schemas.aivm_manifest import ModelArchitecture


Expand All @@ -24,38 +19,39 @@ def show_metadata(
file_path: Annotated[Path, typer.Argument(help='Path to the AIVM / AIVMX file')]
):
"""
指定されたパスの AIVM / AIVMX ファイル内に記録されている AIVM メタデータを見やすく出力する
指定されたパスの AIVM / AIVMX ファイル内に格納されている AIVM メタデータを見やすく出力する
"""

try:
with file_path.open('rb') as file:
if file_path.suffix == '.aivmx':
metadata = read_aivmx_metadata(file)
metadata = aivmlib.read_aivmx_metadata(file)
else:
metadata = read_aivm_metadata(file)
metadata = aivmlib.read_aivm_metadata(file)

for speaker in metadata.manifest.speakers:
speaker.icon = '(Image Base64 DataURL)'
for style in speaker.styles:
style.icon = '(Image Base64 DataURL)'
if style.icon:
style.icon = '(Image Base64 DataURL)'
for sample in style.voice_samples:
sample.audio = '(Audio Base64 DataURL)'
rich.print(Rule(title='AIVM Manifest:', characters='=', style=Style(color='#E33157')))
rich.print(Rule(title='AIVM Manifest:', characters='=', style=Style(color='#41A2EC')))
rich.print(metadata.manifest)
rich.print(Rule(title='Hyper Parameters:', characters='=', style=Style(color='#E33157')))
rich.print(Rule(title='Hyper Parameters:', characters='=', style=Style(color='#41A2EC')))
rich.print(metadata.hyper_parameters)
rich.print(Rule(characters='=', style=Style(color='#E33157')))
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
except Exception as e:
rich.print(Rule(characters='=', style=Style(color='#E33157')))
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'[red]Error reading AIVM or AIVMX file: {e}[/red]')
rich.print(Rule(characters='=', style=Style(color='#E33157')))
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))


@app.command()
def create_aivm(
output_path: Annotated[Path, typer.Option('-o', '--output', help='Path to the output AIVM file')],
safetensors_model_path: Annotated[Path, typer.Option('-m', '--model', help='Path to the Safetensors model file')],
hyper_parameters_path: Annotated[Path, typer.Option('-h', '--hyper-parameters', help='Path to the hyper parameters file')],
hyper_parameters_path: Annotated[Union[Path, None], typer.Option('-h', '--hyper-parameters', help='Path to the hyper parameters file (optional)')] = None,
style_vectors_path: Annotated[Union[Path, None], typer.Option('-s', '--style-vectors', help='Path to the style vectors file (optional)')] = None,
model_architecture: Annotated[ModelArchitecture, typer.Option('-a', '--model-architecture', help='Model architecture')] = ModelArchitecture.StyleBertVITS2JPExtra,
):
Expand All @@ -64,31 +60,68 @@ def create_aivm(
それを書き込んだ仮の AIVM ファイルを生成する
"""

# 拡張子チェック
if safetensors_model_path.suffix != '.safetensors':
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print('[red]Safetensors model file must have a .safetensors extension.[/red]')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return
if output_path.suffix != '.aivm':
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print('[red]Output file must have a .aivm extension.[/red]')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return

try:
# アーキテクチャに合わせて未指定のファイルパスを自動設定
if model_architecture in [ModelArchitecture.StyleBertVITS2, ModelArchitecture.StyleBertVITS2JPExtra]:
model_dir = safetensors_model_path.parent
if not hyper_parameters_path:
hyper_parameters_path = model_dir / 'config.json'
if not style_vectors_path:
style_vectors_path = model_dir / 'style_vectors.npy'

# 必要なファイルが存在しない場合はエラーを発生させる
if not hyper_parameters_path.exists():
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'[red]Hyper parameters file not found: {hyper_parameters_path}[/red]')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return
if not style_vectors_path.exists():
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'[red]Style vectors file not found: {style_vectors_path}[/red]')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return
else:
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'[red]Model architecture {model_architecture} is not supported.[/red]')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return

with hyper_parameters_path.open('rb') as hyper_parameters_file:
style_vectors_file = style_vectors_path.open('rb') if style_vectors_path else None
metadata = generate_aivm_metadata(model_architecture, hyper_parameters_file, style_vectors_file)
if style_vectors_file:
style_vectors_file.close()

with safetensors_model_path.open('rb') as safetensors_file:
new_aivm_file_content = write_aivm_metadata(safetensors_file, metadata)
with output_path.open('wb') as f:
f.write(new_aivm_file_content)
rich.print(Rule(characters='=', style=Style(color='#E33157')))
rich.print(f'Generated AIVM file: {output_path}')
rich.print(Rule(characters='=', style=Style(color='#E33157')))
with style_vectors_path.open('rb') as style_vectors_file:
metadata = aivmlib.generate_aivm_metadata(model_architecture, hyper_parameters_file, style_vectors_file)

with safetensors_model_path.open('rb') as safetensors_file:
new_aivm_file_content = aivmlib.write_aivm_metadata(safetensors_file, metadata)
with output_path.open('wb') as f:
f.write(new_aivm_file_content)

rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'Generated AIVM file: {output_path}')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
except Exception as e:
rich.print(Rule(characters='=', style=Style(color='#E33157')))
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'[red]Error creating AIVM file: {e}[/red]')
rich.print(Rule(characters='=', style=Style(color='#E33157')))
rich.print(traceback.format_exc())
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))


@app.command()
def create_aivmx(
output_path: Annotated[Path, typer.Option('-o', '--output', help='Path to the output AIVMX file')],
onnx_model_path: Annotated[Path, typer.Option('-m', '--model', help='Path to the ONNX model file')],
hyper_parameters_path: Annotated[Path, typer.Option('-h', '--hyper-parameters', help='Path to the hyper parameters file')],
hyper_parameters_path: Annotated[Union[Path, None], typer.Option('-h', '--hyper-parameters', help='Path to the hyper parameters file (optional)')] = None,
style_vectors_path: Annotated[Union[Path, None], typer.Option('-s', '--style-vectors', help='Path to the style vectors file (optional)')] = None,
model_architecture: Annotated[ModelArchitecture, typer.Option('-a', '--model-architecture', help='Model architecture')] = ModelArchitecture.StyleBertVITS2JPExtra,
):
Expand All @@ -97,24 +130,61 @@ def create_aivmx(
それを書き込んだ仮の AIVMX ファイルを生成する
"""

# 拡張子チェック
if onnx_model_path.suffix != '.onnx':
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print('[red]ONNX model file must have a .onnx extension.[/red]')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return
if output_path.suffix != '.aivmx':
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print('[red]Output file must have a .aivmx extension.[/red]')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return

try:
# アーキテクチャに合わせて未指定のファイルパスを自動設定
if model_architecture in [ModelArchitecture.StyleBertVITS2, ModelArchitecture.StyleBertVITS2JPExtra]:
model_dir = onnx_model_path.parent
if not hyper_parameters_path:
hyper_parameters_path = model_dir / 'config.json'
if not style_vectors_path:
style_vectors_path = model_dir / 'style_vectors.npy'

# 必要なファイルが存在しない場合はエラーを発生させる
if not hyper_parameters_path.exists():
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'[red]Hyper parameters file not found: {hyper_parameters_path}[/red]')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return
if not style_vectors_path.exists():
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'[red]Style vectors file not found: {style_vectors_path}[/red]')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return
else:
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'[red]Model architecture {model_architecture} is not supported.[/red]')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
return

with hyper_parameters_path.open('rb') as hyper_parameters_file:
style_vectors_file = style_vectors_path.open('rb') if style_vectors_path else None
metadata = generate_aivm_metadata(model_architecture, hyper_parameters_file, style_vectors_file)
if style_vectors_file:
style_vectors_file.close()

with onnx_model_path.open('rb') as onnx_file:
new_aivmx_file_content = write_aivmx_metadata(onnx_file, metadata)
with output_path.open('wb') as f:
f.write(new_aivmx_file_content)
rich.print(Rule(characters='=', style=Style(color='#E33157')))
rich.print(f'Generated AIVMX file: {output_path}')
rich.print(Rule(characters='=', style=Style(color='#E33157')))
with style_vectors_path.open('rb') as style_vectors_file:
metadata = aivmlib.generate_aivm_metadata(model_architecture, hyper_parameters_file, style_vectors_file)

with onnx_model_path.open('rb') as onnx_file:
new_aivmx_file_content = aivmlib.write_aivmx_metadata(onnx_file, metadata)
with output_path.open('wb') as f:
f.write(new_aivmx_file_content)

rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'Generated AIVMX file: {output_path}')
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
except Exception as e:
rich.print(Rule(characters='=', style=Style(color='#E33157')))
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))
rich.print(f'[red]Error creating AIVMX file: {e}[/red]')
rich.print(Rule(characters='=', style=Style(color='#E33157')))
rich.print(traceback.format_exc())
rich.print(Rule(characters='=', style=Style(color='#41A2EC')))


if __name__ == '__main__':
Expand Down

0 comments on commit b78b181

Please sign in to comment.