From 9d7c437de454afab4d4959c8f320617f8a514046 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Thu, 25 Jul 2024 01:35:15 +0800 Subject: [PATCH] doc: add vLLM instruction --- ChatTTS/model/gpt.py | 5 +++-- README.md | 9 +++++++-- requirements.txt | 2 -- setup.py | 1 - 4 files changed, 10 insertions(+), 7 deletions(-) diff --git a/ChatTTS/model/gpt.py b/ChatTTS/model/gpt.py index f7fdbe616..5c9f63283 100644 --- a/ChatTTS/model/gpt.py +++ b/ChatTTS/model/gpt.py @@ -5,7 +5,6 @@ import gc from pathlib import Path -from safetensors.torch import save_file import torch import torch.nn as nn import torch.nn.functional as F @@ -92,6 +91,8 @@ def __init__( def from_pretrained(self, file_path: str): if self.is_vllm and platform.system().lower() == "linux": + from safetensors.torch import save_file + from .velocity.llm import LLM from .velocity.post_model import PostModel @@ -104,7 +105,7 @@ def from_pretrained(self, file_path: str): gpt.gpt.save_pretrained(vllm_folder / "gpt") post_model = ( PostModel( - int(self.gpt.config.hidden_size), + int(gpt.gpt.config.hidden_size), self.num_audio_tokens, self.num_text_tokens, ) diff --git a/README.md b/README.md index 41bfc830b..e31408cc0 100644 --- a/README.md +++ b/README.md @@ -101,7 +101,12 @@ conda activate chattts pip install -r requirements.txt ``` -#### Optional: Install TransformerEngine if using NVIDIA GPU (Linux only) +#### Optional: Install vLLM (Linux only) +```bash +pip install safetensors vllm==0.2.7 torchaudio +``` + +#### Unrecommended Optional: Install TransformerEngine if using NVIDIA GPU (Linux only) > [!Note] > The installation process is very slow. @@ -113,7 +118,7 @@ pip install -r requirements.txt pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable ``` -#### Optional: Install FlashAttention-2 (mainly NVIDIA GPU) +#### Unrecommended Optional: Install FlashAttention-2 (mainly NVIDIA GPU) > [!Note] > See supported devices at the [Hugging Face Doc](https://huggingface.co/docs/transformers/perf_infer_gpu_one#flashattention-2). diff --git a/requirements.txt b/requirements.txt index 8d8066224..75066bb96 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,5 +14,3 @@ WeTextProcessing; sys_platform == 'linux' nemo_text_processing; sys_platform == 'linux' av pydub -safetensors -vllm>=0.2.7; sys_platform == 'linux' diff --git a/setup.py b/setup.py index c5fe0a69e..da7b609e6 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,6 @@ "transformers>=4.41.1", "vector_quantize_pytorch", "vocos", - "safetensors", ], platforms="any", classifiers=[