diff --git a/ChatTTS/core.py b/ChatTTS/core.py index 94659b52a..bfc728437 100644 --- a/ChatTTS/core.py +++ b/ChatTTS/core.py @@ -291,33 +291,44 @@ def _load( self.dvae = dvae self.logger.log(logging.INFO, "dvae loaded.") - if gpt_config_path: cfg = OmegaConf.load(gpt_config_path) self.num_vq = 4 if not os.path.exists("asset/vllm_model"): gpt = GPT( - **cfg, use_flash_attn=use_flash_attn, device=device, logger=self.logger + **cfg, + use_flash_attn=use_flash_attn, + device=device, + logger=self.logger, ).eval() assert gpt_ckpt_path, "gpt_ckpt_path should not be None" - gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True)) + gpt.load_state_dict( + torch.load(gpt_ckpt_path, weights_only=True, mmap=True) + ) gpt.prepare(compile=compile and "cuda" in str(device)) self.gpt = gpt pathlib.Path("asset/vllm_model").mkdir(parents=True, exist_ok=True) self.gpt.gpt.save_pretrained("asset/vllm_model/gpt") - self.post_model = Post_model( - cfg.gpt_config.hidden_size, - cfg.num_audio_tokens, - cfg.num_text_tokens, - device = device - ).to(device).eval() - + self.post_model = ( + Post_model( + cfg.gpt_config.hidden_size, + cfg.num_audio_tokens, + cfg.num_text_tokens, + device=device, + ) + .to(device) + .eval() + ) + self.post_model.emb_code = self.gpt.emb_code self.post_model.emb_text = self.gpt.emb_text self.post_model.head_text = self.gpt.head_text self.post_model.head_code = self.gpt.head_code - save_file(self.post_model.state_dict(), "asset/vllm_model/post_model.safetensors") - + save_file( + self.post_model.state_dict(), + "asset/vllm_model/post_model.safetensors", + ) + self.num_audio_tokens = cfg.num_audio_tokens spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt") assert os.path.exists( @@ -331,15 +342,15 @@ def _load( ) self.std, self.mean = spk_stat.requires_grad_(False).chunk(2) self.logger.log(logging.INFO, "gpt loaded.") - + self.hidden_size = cfg.gpt_config.hidden_size self.gpt = LLM( model="asset/vllm_model/gpt", - num_audio_tokens = cfg.num_audio_tokens, - num_text_tokens = cfg.num_text_tokens, + num_audio_tokens=cfg.num_audio_tokens, + num_text_tokens=cfg.num_text_tokens, post_model_path="asset/vllm_model/post_model.safetensors", ) - + if dvae_config_path: cfg = OmegaConf.load(dvae_config_path) dvae = DVAE(**cfg, coef=coef).to(device).eval() @@ -369,7 +380,7 @@ def _load( self.coef = coef return self.has_loaded() - + def _infer( self, text, @@ -506,7 +517,7 @@ def destroy(self): del_all(self.ids) # del_all(self.attentions) # del_all(self.hiddens) - + @torch.no_grad() def _infer_code( self, @@ -548,7 +559,7 @@ def _infer_code( text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text] else: text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text] - + input_ids, attention_mask, text_mask = self.tokenizer.encode( text, self.num_vq, @@ -556,7 +567,7 @@ def _infer_code( device=self.device, ) start_idx = input_ids.shape[-2] - + num_code = self.num_audio_tokens - 1 logits_warpers, logits_processors = gen_logits( @@ -565,34 +576,35 @@ def _infer_code( top_K=params.top_K, repetition_penalty=params.repetition_penalty, ) - + sample_params = SamplingParams( temperature=temperature, max_new_token=params.max_new_token, - max_tokens = 8192, + max_tokens=8192, min_new_token=params.min_new_token, logits_processors=(logits_warpers, logits_processors), - eos_token = num_code, + eos_token=num_code, infer_text=False, - start_idx=start_idx + start_idx=start_idx, ) input_ids = [i.tolist() for i in input_ids] - + result = gpt.generate( None, sample_params, 
input_ids, ) - + token_ids = [] hidden_states = [] for i in result: token_ids.append(torch.tensor(i.outputs[0].token_ids)) - hidden_states.append(i.outputs[0].hidden_states.to(torch.float32).to(self.device)) - return [self.GenerationOutputs( - ids=token_ids, - hiddens=hidden_states - ),] + hidden_states.append( + i.outputs[0].hidden_states.to(torch.float32).to(self.device) + ) + return [ + self.GenerationOutputs(ids=token_ids, hiddens=hidden_states), + ] @torch.no_grad() def _refine_text( @@ -602,7 +614,7 @@ def _refine_text( params: RefineTextParams, ): - gpt:LLM = self.gpt + gpt: LLM = self.gpt if not isinstance(text, list): text = [text] @@ -614,7 +626,7 @@ def _refine_text( self.num_vq, device=self.device, ) - + start_idx = input_ids.shape[-2] # print(start_idx) logits_warpers, logits_processors = gen_logits( @@ -627,26 +639,19 @@ def _refine_text( sample_params = SamplingParams( temperature=params.temperature, max_new_token=params.max_new_token, - max_tokens = 8192, + max_tokens=8192, min_new_token=params.min_new_token, logits_processors=(logits_warpers, logits_processors), - eos_token = self.tokenizer.eos_token, + eos_token=self.tokenizer.eos_token, infer_text=True, - start_idx=start_idx + start_idx=start_idx, ) input_ids = [i.tolist() for i in input_ids] - - result = gpt.generate( - None, - sample_params, - input_ids - ) + + result = gpt.generate(None, sample_params, input_ids) token_ids = [] hidden_states = [] for i in result: token_ids.append(torch.tensor(i.outputs[0].token_ids)) hidden_states.append(i.outputs[0].hidden_states) - return self.GenerationOutputs( - ids=token_ids, - hiddens=hidden_states - ) + return self.GenerationOutputs(ids=token_ids, hiddens=hidden_states) diff --git a/ChatTTS/model/velocity/block_manager.py b/ChatTTS/model/velocity/block_manager.py index 199a3a278..ad69aa1b9 100644 --- a/ChatTTS/model/velocity/block_manager.py +++ b/ChatTTS/model/velocity/block_manager.py @@ -1,4 +1,5 @@ """A block manager that manages token blocks.""" + import enum from typing import Dict, List, Optional, Set, Tuple @@ -31,9 +32,9 @@ def __init__( # Initialize the free blocks. self.free_blocks: BlockTable = [] for i in range(num_blocks): - block = PhysicalTokenBlock(device=device, - block_number=i, - block_size=block_size) + block = PhysicalTokenBlock( + device=device, block_number=i, block_size=block_size + ) self.free_blocks.append(block) def allocate(self) -> PhysicalTokenBlock: @@ -63,6 +64,7 @@ class AllocStatus(enum.Enum): 3. Never: seq_group can never be allocated. The seq_group is too large to allocated in GPU. """ + OK = enum.auto() LATER = enum.auto() NEVER = enum.auto() @@ -85,18 +87,15 @@ def __init__( self.block_sliding_window = None if sliding_window is not None: - assert sliding_window % block_size == 0, (sliding_window, - block_size) + assert sliding_window % block_size == 0, (sliding_window, block_size) self.block_sliding_window = sliding_window // block_size self.watermark = watermark assert watermark >= 0.0 self.watermark_blocks = int(watermark * num_gpu_blocks) - self.gpu_allocator = BlockAllocator(Device.GPU, block_size, - num_gpu_blocks) - self.cpu_allocator = BlockAllocator(Device.CPU, block_size, - num_cpu_blocks) + self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks) + self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks) # Mapping: seq_id -> BlockTable. 
self.block_tables: Dict[int, BlockTable] = {} @@ -106,13 +105,11 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0] num_required_blocks = len(seq.logical_token_blocks) if self.block_sliding_window is not None: - num_required_blocks = min(num_required_blocks, - self.block_sliding_window) + num_required_blocks = min(num_required_blocks, self.block_sliding_window) num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks() # Use watermark to avoid frequent cache eviction. - if (self.num_total_gpu_blocks - num_required_blocks < - self.watermark_blocks): + if self.num_total_gpu_blocks - num_required_blocks < self.watermark_blocks: return AllocStatus.NEVER if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks: return AllocStatus.OK @@ -127,8 +124,10 @@ def allocate(self, seq_group: SequenceGroup) -> None: # Allocate new physical token blocks that will store the prompt tokens. block_table: BlockTable = [] for logical_idx in range(len(seq.logical_token_blocks)): - if (self.block_sliding_window is not None - and logical_idx >= self.block_sliding_window): + if ( + self.block_sliding_window is not None + and logical_idx >= self.block_sliding_window + ): block = block_table[logical_idx % self.block_sliding_window] else: block = self.gpu_allocator.allocate() @@ -153,11 +152,14 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]: block_table = self.block_tables[seq.seq_id] if len(block_table) < len(logical_blocks): - if (self.block_sliding_window - and len(block_table) >= self.block_sliding_window): + if ( + self.block_sliding_window + and len(block_table) >= self.block_sliding_window + ): # re-use a block - block_table.append(block_table[len(block_table) % - self.block_sliding_window]) + block_table.append( + block_table[len(block_table) % self.block_sliding_window] + ) else: # The sequence has a new logical block. # Allocate a new physical block. @@ -188,7 +190,8 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None: block.ref_count += 1 def _get_physical_blocks( - self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]: + self, seq_group: SequenceGroup + ) -> List[PhysicalTokenBlock]: # NOTE: Here, we assume that the physical blocks are only shared by # the sequences in the same group. blocks: Set[PhysicalTokenBlock] = set() diff --git a/ChatTTS/model/velocity/configs.py b/ChatTTS/model/velocity/configs.py index 30d6c9afa..c578f468a 100644 --- a/ChatTTS/model/velocity/configs.py +++ b/ChatTTS/model/velocity/configs.py @@ -79,7 +79,7 @@ def __init__( enforce_eager: bool = False, max_context_len_to_capture: Optional[int] = None, num_audio_tokens: int = 1024, - num_text_tokens: int = 80 + num_text_tokens: int = 80, ) -> None: self.model = model self.tokenizer = tokenizer @@ -95,22 +95,24 @@ def __init__( self.max_context_len_to_capture = max_context_len_to_capture self.num_audio_tokens = num_audio_tokens self.num_text_tokens = num_text_tokens - + if os.environ.get("VLLM_USE_MODELSCOPE", "False").lower() == "true": # download model from ModelScope hub, # lazy import so that modelscope is not required for normal use. 
- from modelscope.hub.snapshot_download import snapshot_download # pylint: disable=C - model_path = snapshot_download(model_id=model, - cache_dir=download_dir, - revision=revision) + from modelscope.hub.snapshot_download import ( + snapshot_download, + ) # pylint: disable=C + + model_path = snapshot_download( + model_id=model, cache_dir=download_dir, revision=revision + ) self.model = model_path self.download_dir = model_path self.tokenizer = model_path self.hf_config = get_config(self.model, trust_remote_code, revision) self.dtype = _get_and_verify_dtype(self.hf_config, dtype) - self.max_model_len = _get_and_verify_max_len(self.hf_config, - max_model_len) + self.max_model_len = _get_and_verify_max_len(self.hf_config, max_model_len) self._verify_load_format() self._verify_tokenizer_mode() self._verify_quantization() @@ -118,30 +120,32 @@ def __init__( def _verify_load_format(self) -> None: load_format = self.load_format.lower() - supported_load_format = [ - "auto", "pt", "safetensors", "npcache", "dummy" - ] + supported_load_format = ["auto", "pt", "safetensors", "npcache", "dummy"] rocm_not_supported_load_format = [] if load_format not in supported_load_format: raise ValueError( f"Unknown load format: {self.load_format}. Must be one of " - "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'.") + "'auto', 'pt', 'safetensors', 'npcache', or 'dummy'." + ) if is_hip() and load_format in rocm_not_supported_load_format: rocm_supported_load_format = [ - f for f in supported_load_format + f + for f in supported_load_format if (f not in rocm_not_supported_load_format) ] raise ValueError( - f"load format \'{load_format}\' is not supported in ROCm. " + f"load format '{load_format}' is not supported in ROCm. " f"Supported load format are " - f"{rocm_supported_load_format}") + f"{rocm_supported_load_format}" + ) # TODO: Remove this check once HF updates the pt weights of Mixtral. architectures = getattr(self.hf_config, "architectures", []) if "MixtralForCausalLM" in architectures and load_format == "pt": raise ValueError( "Currently, the 'pt' format is not supported for Mixtral. " - "Please use the 'safetensors' format instead. ") + "Please use the 'safetensors' format instead. " + ) self.load_format = load_format def _verify_tokenizer_mode(self) -> None: @@ -149,7 +153,8 @@ def _verify_tokenizer_mode(self) -> None: if tokenizer_mode not in ["auto", "slow"]: raise ValueError( f"Unknown tokenizer mode: {self.tokenizer_mode}. Must be " - "either 'auto' or 'slow'.") + "either 'auto' or 'slow'." + ) self.tokenizer_mode = tokenizer_mode def _verify_quantization(self) -> None: @@ -169,27 +174,32 @@ def _verify_quantization(self) -> None: "Quantization method specified in the model config " f"({hf_quant_method}) does not match the quantization " f"method specified in the `quantization` argument " - f"({self.quantization}).") + f"({self.quantization})." + ) if self.quantization is not None: if self.quantization not in supported_quantization: raise ValueError( f"Unknown quantization method: {self.quantization}. Must " - f"be one of {supported_quantization}.") - if is_hip( - ) and self.quantization in rocm_not_supported_quantization: + f"be one of {supported_quantization}." + ) + if is_hip() and self.quantization in rocm_not_supported_quantization: raise ValueError( f"{self.quantization} quantization is currently not supported " - f"in ROCm.") - logger.warning(f"{self.quantization} quantization is not fully " - "optimized yet. The speed can be slower than " - "non-quantized models.") + f"in ROCm." 
+ ) + logger.warning( + f"{self.quantization} quantization is not fully " + "optimized yet. The speed can be slower than " + "non-quantized models." + ) def _verify_cuda_graph(self) -> None: if self.max_context_len_to_capture is None: self.max_context_len_to_capture = self.max_model_len - self.max_context_len_to_capture = min(self.max_context_len_to_capture, - self.max_model_len) + self.max_context_len_to_capture = min( + self.max_context_len_to_capture, self.max_model_len + ) def verify_with_parallel_config( self, @@ -201,7 +211,8 @@ def verify_with_parallel_config( raise ValueError( f"Total number of attention heads ({total_num_attention_heads})" " must be divisible by tensor parallel size " - f"({tensor_parallel_size}).") + f"({tensor_parallel_size})." + ) total_num_hidden_layers = self.hf_config.num_hidden_layers pipeline_parallel_size = parallel_config.pipeline_parallel_size @@ -209,7 +220,8 @@ def verify_with_parallel_config( raise ValueError( f"Total number of hidden layers ({total_num_hidden_layers}) " "must be divisible by pipeline parallel size " - f"({pipeline_parallel_size}).") + f"({pipeline_parallel_size})." + ) def get_sliding_window(self) -> Optional[int]: return getattr(self.hf_config, "sliding_window", None) @@ -233,9 +245,11 @@ def get_total_num_kv_heads(self) -> int: falcon_model_types = ["falcon", "RefinedWeb", "RefinedWebModel"] new_decoder_arch_falcon = ( self.hf_config.model_type in falcon_model_types - and getattr(self.hf_config, "new_decoder_architecture", False)) - if not new_decoder_arch_falcon and getattr(self.hf_config, - "multi_query", False): + and getattr(self.hf_config, "new_decoder_architecture", False) + ) + if not new_decoder_arch_falcon and getattr( + self.hf_config, "multi_query", False + ): # Multi-query attention, only one KV head. # Currently, tensor parallelism is not supported in this case. return 1 @@ -265,8 +279,7 @@ def get_num_kv_heads(self, parallel_config: "ParallelConfig") -> int: # the tensor parallel size. We will replicate the KV heads in the # case where the number of KV heads is smaller than the tensor # parallel size so each GPU has at least one KV head. - return max(1, - total_num_kv_heads // parallel_config.tensor_parallel_size) + return max(1, total_num_kv_heads // parallel_config.tensor_parallel_size) def get_num_layers(self, parallel_config: "ParallelConfig") -> int: total_num_hidden_layers = self.hf_config.num_hidden_layers @@ -304,7 +317,8 @@ def _verify_args(self) -> None: if self.gpu_memory_utilization > 1.0: raise ValueError( "GPU memory utilization must be less than 1.0. Got " - f"{self.gpu_memory_utilization}.") + f"{self.gpu_memory_utilization}." + ) def verify_with_parallel_config( self, @@ -316,9 +330,11 @@ def verify_with_parallel_config( num_gpus_per_node = parallel_config.tensor_parallel_size cpu_memory_usage = self.swap_space_bytes * num_gpus_per_node - msg = (f"{cpu_memory_usage / _GB:.2f} GiB out of " - f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " - "allocated for the swap space.") + msg = ( + f"{cpu_memory_usage / _GB:.2f} GiB out of " + f"the {total_cpu_memory / _GB:.2f} GiB total CPU memory is " + "allocated for the swap space." + ) if cpu_memory_usage > 0.7 * total_cpu_memory: raise ValueError("Too large swap space. 
" + msg) elif cpu_memory_usage > 0.4 * total_cpu_memory: @@ -355,8 +371,7 @@ def __init__( def _verify_args(self) -> None: if self.pipeline_parallel_size > 1: - raise NotImplementedError( - "Pipeline parallelism is not supported yet.") + raise NotImplementedError("Pipeline parallelism is not supported yet.") class SchedulerConfig: @@ -398,12 +413,14 @@ def _verify_args(self) -> None: "This effectively limits the maximum sequence length to " "max_num_batched_tokens and makes vLLM reject longer " "sequences. Please increase max_num_batched_tokens or " - "decrease max_model_len.") + "decrease max_model_len." + ) if self.max_num_batched_tokens < self.max_num_seqs: raise ValueError( f"max_num_batched_tokens ({self.max_num_batched_tokens}) must " "be greater than or equal to max_num_seqs " - f"({self.max_num_seqs}).") + f"({self.max_num_seqs})." + ) _STR_DTYPE_TO_TORCH_DTYPE = { @@ -447,11 +464,14 @@ def _get_and_verify_dtype( if is_hip() and torch_dtype == torch.float32: rocm_supported_dtypes = [ - k for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() + k + for k, v in _STR_DTYPE_TO_TORCH_DTYPE.items() if (k not in _ROCM_NOT_SUPPORTED_DTYPE) ] - raise ValueError(f"dtype \'{dtype}\' is not supported in ROCm. " - f"Supported dtypes are {rocm_supported_dtypes}") + raise ValueError( + f"dtype '{dtype}' is not supported in ROCm. " + f"Supported dtypes are {rocm_supported_dtypes}" + ) # Verify the dtype. if torch_dtype != config_dtype: @@ -502,7 +522,8 @@ def _get_and_verify_max_len( "The model's config.json does not contain any of the following " "keys to determine the original maximum length of the model: " f"{possible_keys}. Assuming the model's maximum length is " - f"{default_max_len}.") + f"{default_max_len}." + ) derived_max_model_len = default_max_len rope_scaling = getattr(hf_config, "rope_scaling", None) @@ -510,8 +531,7 @@ def _get_and_verify_max_len( assert "factor" in rope_scaling scaling_factor = rope_scaling["factor"] if rope_scaling["type"] == "yarn": - derived_max_model_len = rope_scaling[ - "original_max_position_embeddings"] + derived_max_model_len = rope_scaling["original_max_position_embeddings"] derived_max_model_len *= scaling_factor if max_model_len is None: @@ -522,20 +542,22 @@ def _get_and_verify_max_len( f"the derived max_model_len ({max_len_key}={derived_max_model_len}" " in model's config.json). This may lead to incorrect model " "outputs or CUDA errors. Make sure the value is correct and " - "within the model context size.") + "within the model context size." 
+ ) return int(max_model_len) @dataclass class EngineArgs: """Arguments for vLLM engine.""" + model: str tokenizer: Optional[str] = None - tokenizer_mode: str = 'auto' + tokenizer_mode: str = "auto" trust_remote_code: bool = False download_dir: Optional[str] = None - load_format: str = 'auto' - dtype: str = 'auto' + load_format: str = "auto" + dtype: str = "auto" seed: int = 0 max_model_len: Optional[int] = None worker_use_ray: bool = False @@ -556,14 +578,13 @@ class EngineArgs: max_context_len_to_capture: int = 8192 num_audio_tokens: int = 1024 num_text_tokens: int = 80 - + def __post_init__(self): if self.tokenizer is None: self.tokenizer = self.model @staticmethod - def add_cli_args( - parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: """Shared CLI arguments for vLLM engine.""" # NOTE: If you update any of the arguments below, please also @@ -571,162 +592,198 @@ def add_cli_args( # Model arguments parser.add_argument( - '--model', + "--model", type=str, - default='facebook/opt-125m', - help='name or path of the huggingface model to use') + default="facebook/opt-125m", + help="name or path of the huggingface model to use", + ) parser.add_argument( - '--tokenizer', + "--tokenizer", type=str, default=EngineArgs.tokenizer, - help='name or path of the huggingface tokenizer to use') + help="name or path of the huggingface tokenizer to use", + ) parser.add_argument( - '--revision', + "--revision", type=str, default=None, - help='the specific model version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') + help="the specific model version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) parser.add_argument( - '--tokenizer-revision', + "--tokenizer-revision", type=str, default=None, - help='the specific tokenizer version to use. It can be a branch ' - 'name, a tag name, or a commit id. If unspecified, will use ' - 'the default version.') - parser.add_argument('--tokenizer-mode', - type=str, - default=EngineArgs.tokenizer_mode, - choices=['auto', 'slow'], - help='tokenizer mode. "auto" will use the fast ' - 'tokenizer if available, and "slow" will ' - 'always use the slow tokenizer.') - parser.add_argument('--trust-remote-code', - action='store_true', - help='trust remote code from huggingface') - parser.add_argument('--download-dir', - type=str, - default=EngineArgs.download_dir, - help='directory to download and load the weights, ' - 'default to the default cache dir of ' - 'huggingface') + help="the specific tokenizer version to use. It can be a branch " + "name, a tag name, or a commit id. If unspecified, will use " + "the default version.", + ) + parser.add_argument( + "--tokenizer-mode", + type=str, + default=EngineArgs.tokenizer_mode, + choices=["auto", "slow"], + help='tokenizer mode. 
"auto" will use the fast ' + 'tokenizer if available, and "slow" will ' + "always use the slow tokenizer.", + ) + parser.add_argument( + "--trust-remote-code", + action="store_true", + help="trust remote code from huggingface", + ) + parser.add_argument( + "--download-dir", + type=str, + default=EngineArgs.download_dir, + help="directory to download and load the weights, " + "default to the default cache dir of " + "huggingface", + ) parser.add_argument( - '--load-format', + "--load-format", type=str, default=EngineArgs.load_format, - choices=['auto', 'pt', 'safetensors', 'npcache', 'dummy'], - help='The format of the model weights to load. ' + choices=["auto", "pt", "safetensors", "npcache", "dummy"], + help="The format of the model weights to load. " '"auto" will try to load the weights in the safetensors format ' - 'and fall back to the pytorch bin format if safetensors format ' - 'is not available. ' + "and fall back to the pytorch bin format if safetensors format " + "is not available. " '"pt" will load the weights in the pytorch bin format. ' '"safetensors" will load the weights in the safetensors format. ' '"npcache" will load the weights in pytorch format and store ' - 'a numpy cache to speed up the loading. ' + "a numpy cache to speed up the loading. " '"dummy" will initialize the weights with random values, ' - 'which is mainly for profiling.') + "which is mainly for profiling.", + ) parser.add_argument( - '--dtype', + "--dtype", type=str, default=EngineArgs.dtype, - choices=[ - 'auto', 'half', 'float16', 'bfloat16', 'float', 'float32' - ], - help='data type for model weights and activations. ' + choices=["auto", "half", "float16", "bfloat16", "float", "float32"], + help="data type for model weights and activations. " 'The "auto" option will use FP16 precision ' - 'for FP32 and FP16 models, and BF16 precision ' - 'for BF16 models.') - parser.add_argument('--max-model-len', - type=int, - default=None, - help='model context length. If unspecified, ' - 'will be automatically derived from the model.') + "for FP32 and FP16 models, and BF16 precision " + "for BF16 models.", + ) + parser.add_argument( + "--max-model-len", + type=int, + default=None, + help="model context length. 
If unspecified, " + "will be automatically derived from the model.", + ) # Parallel arguments - parser.add_argument('--worker-use-ray', - action='store_true', - help='use Ray for distributed serving, will be ' - 'automatically set when using more than 1 GPU') - parser.add_argument('--pipeline-parallel-size', - '-pp', - type=int, - default=EngineArgs.pipeline_parallel_size, - help='number of pipeline stages') - parser.add_argument('--tensor-parallel-size', - '-tp', - type=int, - default=EngineArgs.tensor_parallel_size, - help='number of tensor parallel replicas') parser.add_argument( - '--max-parallel-loading-workers', + "--worker-use-ray", + action="store_true", + help="use Ray for distributed serving, will be " + "automatically set when using more than 1 GPU", + ) + parser.add_argument( + "--pipeline-parallel-size", + "-pp", + type=int, + default=EngineArgs.pipeline_parallel_size, + help="number of pipeline stages", + ) + parser.add_argument( + "--tensor-parallel-size", + "-tp", type=int, - help='load model sequentially in multiple batches, ' - 'to avoid RAM OOM when using tensor ' - 'parallel and large models') + default=EngineArgs.tensor_parallel_size, + help="number of tensor parallel replicas", + ) + parser.add_argument( + "--max-parallel-loading-workers", + type=int, + help="load model sequentially in multiple batches, " + "to avoid RAM OOM when using tensor " + "parallel and large models", + ) # KV cache arguments - parser.add_argument('--block-size', - type=int, - default=EngineArgs.block_size, - choices=[8, 16, 32], - help='token block size') + parser.add_argument( + "--block-size", + type=int, + default=EngineArgs.block_size, + choices=[8, 16, 32], + help="token block size", + ) # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). - parser.add_argument('--seed', - type=int, - default=EngineArgs.seed, - help='random seed') - parser.add_argument('--swap-space', - type=int, - default=EngineArgs.swap_space, - help='CPU swap space size (GiB) per GPU') parser.add_argument( - '--gpu-memory-utilization', + "--seed", type=int, default=EngineArgs.seed, help="random seed" + ) + parser.add_argument( + "--swap-space", + type=int, + default=EngineArgs.swap_space, + help="CPU swap space size (GiB) per GPU", + ) + parser.add_argument( + "--gpu-memory-utilization", type=float, default=EngineArgs.gpu_memory_utilization, - help='the fraction of GPU memory to be used for ' - 'the model executor, which can range from 0 to 1.' - 'If unspecified, will use the default value of 0.9.') - parser.add_argument('--max-num-batched-tokens', - type=int, - default=EngineArgs.max_num_batched_tokens, - help='maximum number of batched tokens per ' - 'iteration') - parser.add_argument('--max-num-seqs', - type=int, - default=EngineArgs.max_num_seqs, - help='maximum number of sequences per iteration') - parser.add_argument('--max-paddings', - type=int, - default=EngineArgs.max_paddings, - help='maximum number of paddings in a batch') - parser.add_argument('--disable-log-stats', - action='store_true', - help='disable logging statistics') + help="the fraction of GPU memory to be used for " + "the model executor, which can range from 0 to 1." 
+ "If unspecified, will use the default value of 0.9.", + ) + parser.add_argument( + "--max-num-batched-tokens", + type=int, + default=EngineArgs.max_num_batched_tokens, + help="maximum number of batched tokens per " "iteration", + ) + parser.add_argument( + "--max-num-seqs", + type=int, + default=EngineArgs.max_num_seqs, + help="maximum number of sequences per iteration", + ) + parser.add_argument( + "--max-paddings", + type=int, + default=EngineArgs.max_paddings, + help="maximum number of paddings in a batch", + ) + parser.add_argument( + "--disable-log-stats", + action="store_true", + help="disable logging statistics", + ) # Quantization settings. - parser.add_argument('--quantization', - '-q', - type=str, - choices=['awq', 'gptq', 'squeezellm', None], - default=None, - help='Method used to quantize the weights. If ' - 'None, we first check the `quantization_config` ' - 'attribute in the model config file. If that is ' - 'None, we assume the model weights are not ' - 'quantized and use `dtype` to determine the data ' - 'type of the weights.') - parser.add_argument('--enforce-eager', - action='store_true', - help='Always use eager-mode PyTorch. If False, ' - 'will use eager mode and CUDA graph in hybrid ' - 'for maximal performance and flexibility.') - parser.add_argument('--max-context-len-to-capture', - type=int, - default=EngineArgs.max_context_len_to_capture, - help='maximum context length covered by CUDA ' - 'graphs. When a sequence has context length ' - 'larger than this, we fall back to eager mode.') + parser.add_argument( + "--quantization", + "-q", + type=str, + choices=["awq", "gptq", "squeezellm", None], + default=None, + help="Method used to quantize the weights. If " + "None, we first check the `quantization_config` " + "attribute in the model config file. If that is " + "None, we assume the model weights are not " + "quantized and use `dtype` to determine the data " + "type of the weights.", + ) + parser.add_argument( + "--enforce-eager", + action="store_true", + help="Always use eager-mode PyTorch. If False, " + "will use eager mode and CUDA graph in hybrid " + "for maximal performance and flexibility.", + ) + parser.add_argument( + "--max-context-len-to-capture", + type=int, + default=EngineArgs.max_context_len_to_capture, + help="maximum context length covered by CUDA " + "graphs. When a sequence has context length " + "larger than this, we fall back to eager mode.", + ) return parser @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': + def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. 
@@ -736,52 +793,73 @@ def from_cli_args(cls, args: argparse.Namespace) -> 'EngineArgs': def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: - model_config = ModelConfig(self.model, self.tokenizer, - self.tokenizer_mode, self.trust_remote_code, - self.download_dir, self.load_format, - self.dtype, self.seed, self.revision, - self.tokenizer_revision, self.max_model_len, - self.quantization, self.enforce_eager, - self.max_context_len_to_capture, - self.num_audio_tokens, self.num_text_tokens, - ) - cache_config = CacheConfig(self.block_size, - self.gpu_memory_utilization, - self.swap_space, - model_config.get_sliding_window()) - parallel_config = ParallelConfig(self.pipeline_parallel_size, - self.tensor_parallel_size, - self.worker_use_ray, - self.max_parallel_loading_workers) - scheduler_config = SchedulerConfig(self.max_num_batched_tokens, - self.max_num_seqs, - model_config.max_model_len, - self.max_paddings) + model_config = ModelConfig( + self.model, + self.tokenizer, + self.tokenizer_mode, + self.trust_remote_code, + self.download_dir, + self.load_format, + self.dtype, + self.seed, + self.revision, + self.tokenizer_revision, + self.max_model_len, + self.quantization, + self.enforce_eager, + self.max_context_len_to_capture, + self.num_audio_tokens, + self.num_text_tokens, + ) + cache_config = CacheConfig( + self.block_size, + self.gpu_memory_utilization, + self.swap_space, + model_config.get_sliding_window(), + ) + parallel_config = ParallelConfig( + self.pipeline_parallel_size, + self.tensor_parallel_size, + self.worker_use_ray, + self.max_parallel_loading_workers, + ) + scheduler_config = SchedulerConfig( + self.max_num_batched_tokens, + self.max_num_seqs, + model_config.max_model_len, + self.max_paddings, + ) return model_config, cache_config, parallel_config, scheduler_config @dataclass class AsyncEngineArgs(EngineArgs): """Arguments for asynchronous vLLM engine.""" + engine_use_ray: bool = False disable_log_requests: bool = False max_log_len: Optional[int] = None @staticmethod - def add_cli_args( - parser: argparse.ArgumentParser) -> argparse.ArgumentParser: + def add_cli_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser: parser = EngineArgs.add_cli_args(parser) - parser.add_argument('--engine-use-ray', - action='store_true', - help='use Ray to start the LLM engine in a ' - 'separate process as the server process.') - parser.add_argument('--disable-log-requests', - action='store_true', - help='disable logging requests') - parser.add_argument('--max-log-len', - type=int, - default=None, - help='max number of prompt characters or prompt ' - 'ID numbers being printed in log. ' - 'Default: unlimited.') + parser.add_argument( + "--engine-use-ray", + action="store_true", + help="use Ray to start the LLM engine in a " + "separate process as the server process.", + ) + parser.add_argument( + "--disable-log-requests", + action="store_true", + help="disable logging requests", + ) + parser.add_argument( + "--max-log-len", + type=int, + default=None, + help="max number of prompt characters or prompt " + "ID numbers being printed in log. 
" + "Default: unlimited.", + ) return parser diff --git a/ChatTTS/model/velocity/llama.py b/ChatTTS/model/velocity/llama.py index 415b09d86..8e6c8a896 100644 --- a/ChatTTS/model/velocity/llama.py +++ b/ChatTTS/model/velocity/llama.py @@ -31,19 +31,26 @@ from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.attention import PagedAttention from vllm.model_executor.layers.layernorm import RMSNorm -from vllm.model_executor.layers.linear import (LinearMethodBase, - MergedColumnParallelLinear, - QKVParallelLinear, - RowParallelLinear) +from vllm.model_executor.layers.linear import ( + LinearMethodBase, + MergedColumnParallelLinear, + QKVParallelLinear, + RowParallelLinear, +) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.sampler import Sampler from vllm.model_executor.layers.vocab_parallel_embedding import ( - VocabParallelEmbedding, ParallelLMHead) + VocabParallelEmbedding, + ParallelLMHead, +) from vllm.model_executor.parallel_utils.parallel_state import ( - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, +) from vllm.model_executor.sampling_metadata import SamplingMetadata -from vllm.model_executor.weight_utils import (default_weight_loader, - hf_model_weights_iterator) +from vllm.model_executor.weight_utils import ( + default_weight_loader, + hf_model_weights_iterator, +) from vllm.sequence import SamplerOutput KVCache = Tuple[torch.Tensor, torch.Tensor] @@ -60,16 +67,19 @@ def __init__( ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( - hidden_size, [intermediate_size] * 2, + hidden_size, + [intermediate_size] * 2, bias=False, - linear_method=linear_method) - self.down_proj = RowParallelLinear(intermediate_size, - hidden_size, - bias=False, - linear_method=linear_method) + linear_method=linear_method, + ) + self.down_proj = RowParallelLinear( + intermediate_size, hidden_size, bias=False, linear_method=linear_method + ) if hidden_act != "silu": - raise ValueError(f"Unsupported activation: {hidden_act}. " - "Only silu is supported for now.") + raise ValueError( + f"Unsupported activation: {hidden_act}. " + "Only silu is supported for now." 
+ ) self.act_fn = SiluAndMul() def forward(self, x): @@ -136,10 +146,9 @@ def __init__( base=rope_theta, rope_scaling=rope_scaling, ) - self.attn = PagedAttention(self.num_heads, - self.head_dim, - self.scaling, - num_kv_heads=self.num_kv_heads) + self.attn = PagedAttention( + self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads + ) def forward( self, @@ -168,8 +177,7 @@ def __init__( self.hidden_size = config.hidden_size rope_theta = getattr(config, "rope_theta", 10000) rope_scaling = getattr(config, "rope_scaling", None) - max_position_embeddings = getattr(config, "max_position_embeddings", - 8192) + max_position_embeddings = getattr(config, "max_position_embeddings", 8192) self.self_attn = LlamaAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, @@ -185,10 +193,10 @@ def __init__( hidden_act=config.hidden_act, linear_method=linear_method, ) - self.input_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) - self.post_attention_layernorm = RMSNorm(config.hidden_size, - eps=config.rms_norm_eps) + self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm( + config.hidden_size, eps=config.rms_norm_eps + ) def forward( self, @@ -203,8 +211,7 @@ def forward( residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: - hidden_states, residual = self.input_layernorm( - hidden_states, residual) + hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, @@ -213,8 +220,7 @@ def forward( ) # Fully Connected - hidden_states, residual = self.post_attention_layernorm( - hidden_states, residual) + hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @@ -234,10 +240,12 @@ def __init__( config.vocab_size, config.hidden_size, ) - self.layers = nn.ModuleList([ - LlamaDecoderLayer(config, linear_method) - for _ in range(config.num_hidden_layers) - ]) + self.layers = nn.ModuleList( + [ + LlamaDecoderLayer(config, linear_method) + for _ in range(config.num_hidden_layers) + ] + ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) def forward( @@ -261,11 +269,13 @@ def forward( hidden_states, _ = self.norm(hidden_states, residual) return hidden_states - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -276,15 +286,15 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -300,10 +310,10 @@ def load_weights(self, if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) + class LlamaForCausalLM(nn.Module): def __init__( @@ -325,8 +335,7 @@ def forward( kv_caches: List[KVCache], input_metadata: InputMetadata, ) -> torch.Tensor: - hidden_states = self.model(input_ids, positions, kv_caches, - input_metadata) + hidden_states = self.model(input_ids, positions, kv_caches, input_metadata) return hidden_states def sample( @@ -334,15 +343,18 @@ def sample( hidden_states: torch.Tensor, sampling_metadata: SamplingMetadata, ) -> Optional[SamplerOutput]: - next_tokens = self.sampler(self.lm_head.weight, hidden_states, - sampling_metadata) + next_tokens = self.sampler( + self.lm_head.weight, hidden_states, sampling_metadata + ) return next_tokens - def load_weights(self, - model_name_or_path: str, - cache_dir: Optional[str] = None, - load_format: str = "auto", - revision: Optional[str] = None): + def load_weights( + self, + model_name_or_path: str, + cache_dir: Optional[str] = None, + load_format: str = "auto", + revision: Optional[str] = None, + ): stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), @@ -353,15 +365,15 @@ def load_weights(self, ] params_dict = dict(self.named_parameters()) for name, loaded_weight in hf_model_weights_iterator( - model_name_or_path, cache_dir, load_format, revision): + model_name_or_path, cache_dir, load_format, revision + ): if "rotary_emb.inv_freq" in name: continue - if ("rotary_emb.cos_cached" in name - or "rotary_emb.sin_cached" in name): + if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name: # Models trained using ColossalAI may include these tensors in # the checkpoint. Skip them. 
continue - for (param_name, weight_name, shard_id) in stacked_params_mapping: + for param_name, weight_name, shard_id in stacked_params_mapping: if weight_name not in name: continue name = name.replace(weight_name, param_name) @@ -377,6 +389,5 @@ def load_weights(self, if name.endswith(".bias") and name not in params_dict: continue param = params_dict[name] - weight_loader = getattr(param, "weight_loader", - default_weight_loader) + weight_loader = getattr(param, "weight_loader", default_weight_loader) weight_loader(param, loaded_weight) diff --git a/ChatTTS/model/velocity/llm.py b/ChatTTS/model/velocity/llm.py index 9668c87cf..98a90af26 100644 --- a/ChatTTS/model/velocity/llm.py +++ b/ChatTTS/model/velocity/llm.py @@ -103,15 +103,14 @@ def __init__( swap_space=swap_space, enforce_eager=enforce_eager, max_context_len_to_capture=max_context_len_to_capture, - num_audio_tokens = num_audio_tokens, - num_text_tokens = num_text_tokens, + num_audio_tokens=num_audio_tokens, + num_text_tokens=num_text_tokens, **kwargs, ) self.llm_engine = LLMEngine.from_engine_args(engine_args, post_model_path) self.request_counter = Counter() - def get_tokenizer( - self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: + def get_tokenizer(self) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: return self.llm_engine.tokenizer def set_tokenizer( @@ -146,28 +145,29 @@ def generate( completions in the same order as the input prompts. """ if prompts is None and prompt_token_ids is None: - raise ValueError("Either prompts or prompt_token_ids must be " - "provided.") + raise ValueError("Either prompts or prompt_token_ids must be " "provided.") if isinstance(prompts, str): # Convert a single prompt to a list. prompts = [prompts] - if (prompts is not None and prompt_token_ids is not None - and len(prompts) != len(prompt_token_ids)): - raise ValueError("The lengths of prompts and prompt_token_ids " - "must be the same.") + if ( + prompts is not None + and prompt_token_ids is not None + and len(prompts) != len(prompt_token_ids) + ): + raise ValueError( + "The lengths of prompts and prompt_token_ids " "must be the same." + ) if sampling_params is None: # Use default sampling params. sampling_params = SamplingParams() # Add requests to the engine. - num_requests = len(prompts) if prompts is not None else len( - prompt_token_ids) + num_requests = len(prompts) if prompts is not None else len(prompt_token_ids) for i in range(num_requests): prompt = prompts[i] if prompts is not None else None - token_ids = None if prompt_token_ids is None else prompt_token_ids[ - i] + token_ids = None if prompt_token_ids is None else prompt_token_ids[i] self._add_request(prompt, sampling_params, token_ids) - + rtns = self._run_engine(use_tqdm) for i, rtn in enumerate(rtns): token_ids = rtn.outputs[0].token_ids @@ -176,7 +176,7 @@ def generate( token_ids[j] = token_id[0] else: token_ids[j] = list(token_id) - + return rtns def _add_request( @@ -186,8 +186,9 @@ def _add_request( prompt_token_ids: Optional[List[int]], ) -> None: request_id = str(next(self.request_counter)) - self.llm_engine.add_request(request_id, prompt, sampling_params, - prompt_token_ids) + self.llm_engine.add_request( + request_id, prompt, sampling_params, prompt_token_ids + ) def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. 
diff --git a/ChatTTS/model/velocity/llm_engine.py b/ChatTTS/model/velocity/llm_engine.py index 4a72c0c3f..66dd205ff 100644 --- a/ChatTTS/model/velocity/llm_engine.py +++ b/ChatTTS/model/velocity/llm_engine.py @@ -2,11 +2,9 @@ from collections import defaultdict import os import time -from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, - Union) +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig from ChatTTS.model.velocity.scheduler import Scheduler, SchedulerOutputs from ChatTTS.model.velocity.configs import EngineArgs from vllm.engine.metrics import record_metrics @@ -14,12 +12,18 @@ from vllm.logger import init_logger from ChatTTS.model.velocity.output import RequestOutput from ChatTTS.model.velocity.sampling_params import SamplingParams -from ChatTTS.model.velocity.sequence import (SamplerOutput, Sequence, SequenceGroup, - SequenceGroupOutput, SequenceOutput, SequenceStatus) -from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - get_tokenizer) +from ChatTTS.model.velocity.sequence import ( + SamplerOutput, + Sequence, + SequenceGroup, + SequenceGroupOutput, + SequenceOutput, + SequenceStatus, +) +from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port import numpy as np + if ray: from ray.util.scheduling_strategies import PlacementGroupSchedulingStrategy @@ -85,7 +89,7 @@ def __init__( f"enforce_eager={model_config.enforce_eager}, " f"seed={model_config.seed}), " f"post_model_path={post_model_path!r}" - ) + ) # TODO(woosuk): Print more configs in debug mode. self.model_config = model_config @@ -125,8 +129,9 @@ def _init_workers(self): # before CUDA_VISIBLE_DEVICES is set in the Worker from ChatTTS.model.velocity.worker import Worker - assert self.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") + assert ( + self.parallel_config.world_size == 1 + ), "Ray is required if parallel_config.world_size > 1." self.workers: List[Worker] = [] distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}" @@ -138,13 +143,12 @@ def _init_workers(self): rank=0, distributed_init_method=distributed_init_method, is_driver_worker=True, - post_model_path = self.post_model_path + post_model_path=self.post_model_path, ) self._run_workers("init_model") self._run_workers("load_model") - def _init_workers_ray(self, placement_group: "PlacementGroup", - **ray_remote_kwargs): + def _init_workers_ray(self, placement_group: "PlacementGroup", **ray_remote_kwargs): if self.parallel_config.tensor_parallel_size == 1: num_gpus = self.cache_config.gpu_memory_utilization else: @@ -181,20 +185,22 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", raise ValueError( "Ray does not allocate any GPUs on the driver node. Consider " "adjusting the Ray placement group or running the driver on a " - "GPU node.") + "GPU node." 
+ ) driver_node_id, driver_gpu_ids = ray.get( - self.driver_dummy_worker.get_node_and_gpu_ids.remote()) + self.driver_dummy_worker.get_node_and_gpu_ids.remote() + ) worker_node_and_gpu_ids = ray.get( - [worker.get_node_and_gpu_ids.remote() for worker in self.workers]) + [worker.get_node_and_gpu_ids.remote() for worker in self.workers] + ) node_workers = defaultdict(list) node_gpus = defaultdict(list) node_workers[driver_node_id].append(0) node_gpus[driver_node_id].extend(driver_gpu_ids) - for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, - start=1): + for i, (node_id, gpu_ids) in enumerate(worker_node_and_gpu_ids, start=1): node_workers[node_id].append(i) node_gpus[node_id].extend(gpu_ids) for node_id, gpu_ids in node_gpus.items(): @@ -216,10 +222,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", parallel_config = copy.deepcopy(self.parallel_config) scheduler_config = copy.deepcopy(self.scheduler_config) - for rank, (worker, (node_id, - _)) in enumerate(zip(self.workers, - worker_node_and_gpu_ids), - start=1): + for rank, (worker, (node_id, _)) in enumerate( + zip(self.workers, worker_node_and_gpu_ids), start=1 + ): local_rank = node_workers[node_id].index(rank) worker.init_worker.remote( lambda rank=rank, local_rank=local_rank: Worker( @@ -229,7 +234,8 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", local_rank, rank, distributed_init_method, - )) + ) + ) driver_rank = 0 driver_local_rank = node_workers[driver_node_id].index(driver_rank) @@ -246,8 +252,7 @@ def _init_workers_ray(self, placement_group: "PlacementGroup", self._run_workers("init_model") self._run_workers( "load_model", - max_concurrent_workers=self.parallel_config. - max_parallel_loading_workers, + max_concurrent_workers=self.parallel_config.max_parallel_loading_workers, ) def _verify_args(self) -> None: @@ -270,13 +275,16 @@ def _init_cache(self) -> None: num_gpu_blocks = min(b[0] for b in num_blocks) num_cpu_blocks = min(b[1] for b in num_blocks) # FIXME(woosuk): Change to debug log. - logger.info(f"# GPU blocks: {num_gpu_blocks}, " - f"# CPU blocks: {num_cpu_blocks}") + logger.info( + f"# GPU blocks: {num_gpu_blocks}, " f"# CPU blocks: {num_cpu_blocks}" + ) if num_gpu_blocks <= 0: - raise ValueError("No available memory for the cache blocks. " - "Try increasing `gpu_memory_utilization` when " - "initializing the engine.") + raise ValueError( + "No available memory for the cache blocks. " + "Try increasing `gpu_memory_utilization` when " + "initializing the engine." + ) max_seq_len = self.cache_config.block_size * num_gpu_blocks if self.model_config.max_model_len > max_seq_len: raise ValueError( @@ -284,7 +292,8 @@ def _init_cache(self) -> None: "is larger than the maximum number of tokens that can be " f"stored in KV cache ({max_seq_len}). Try increasing " "`gpu_memory_utilization` or decreasing `max_model_len` when " - "initializing the engine.") + "initializing the engine." + ) self.cache_config.num_gpu_blocks = num_gpu_blocks self.cache_config.num_cpu_blocks = num_cpu_blocks @@ -296,7 +305,9 @@ def _init_cache(self) -> None: self._run_workers("warm_up_model") @classmethod - def from_engine_args(cls, engine_args: EngineArgs, post_model_path=None) -> "LLMEngine": + def from_engine_args( + cls, engine_args: EngineArgs, post_model_path=None + ) -> "LLMEngine": """Creates an LLM engine from the engine arguments.""" # Create the engine configs. 
engine_configs = engine_args.create_engine_configs() @@ -304,11 +315,12 @@ def from_engine_args(cls, engine_args: EngineArgs, post_model_path=None) -> "LLM # Initialize the cluster. placement_group = initialize_cluster(parallel_config) # Create the LLM engine. - engine = cls(*engine_configs, - placement_group, - log_stats=not engine_args.disable_log_stats, - post_model_path = post_model_path - ) + engine = cls( + *engine_configs, + placement_group, + log_stats=not engine_args.disable_log_stats, + post_model_path=post_model_path, + ) return engine def add_request( @@ -337,7 +349,7 @@ def add_request( """ if arrival_time is None: arrival_time = time.monotonic() - + assert prompt_token_ids is not None, "prompt_token_ids must be provided" # Create the sequences. block_size = self.cache_config.block_size @@ -345,8 +357,7 @@ def add_request( seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) # Create the sequence group. - seq_group = SequenceGroup(request_id, [seq], sampling_params, - arrival_time) + seq_group = SequenceGroup(request_id, [seq], sampling_params, arrival_time) # Add the sequence group to the scheduler. self.scheduler.add_seq_group(seq_group) @@ -383,13 +394,13 @@ def _check_beam_search_early_stopping( if early_stopping is True: return True - current_worst_score = (current_worst_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + current_worst_score = current_worst_seq.get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ) if early_stopping is False: - highest_attainable_score = (best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ) else: assert early_stopping == "never" if length_penalty > 0.0: @@ -397,26 +408,27 @@ def _check_beam_search_early_stopping( # sequences. The highest attainable score calculation is # based on the longest possible sequence length in this case. max_possible_length = max( - best_running_seq.get_prompt_len() + - sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id, - seq_len=max_possible_length)) + best_running_seq.get_prompt_len() + sampling_params.max_tokens, + self.scheduler_config.max_model_len, + ) + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + seq_len=max_possible_length, + ) else: # Otherwise, beam search will prefer shorter sequences. The # highest attainable score calculation is based on the current # sequence length. 
- highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + ) return current_worst_score >= highest_attainable_score - def _process_sequence_group_outputs(self, seq_group: SequenceGroup, - outputs: SequenceGroupOutput) -> None: + def _process_sequence_group_outputs( + self, seq_group: SequenceGroup, outputs: SequenceGroupOutput + ) -> None: # Process prompt logprobs prompt_logprobs = outputs.prompt_logprobs if prompt_logprobs is not None: @@ -426,10 +438,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, samples = outputs.samples parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING) existing_finished_seqs = seq_group.get_finished_seqs() - parent_child_dict = { - parent_seq.seq_id: [] - for parent_seq in parent_seqs - } + parent_child_dict = {parent_seq.seq_id: [] for parent_seq in parent_seqs} for sample in samples: parent_child_dict[sample.parent_seq_id].append(sample) # List of (child, parent) @@ -437,8 +446,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Process the child samples for each parent sequence for parent in parent_seqs: - child_samples: List[SequenceOutput] = parent_child_dict[ - parent.seq_id] + child_samples: List[SequenceOutput] = parent_child_dict[parent.seq_id] if len(child_samples) == 0: # This parent sequence has no children samples. Remove # the parent sequence from the sequence group since it will @@ -451,27 +459,29 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, for child_sample in child_samples[:-1]: new_child_seq_id = next(self.seq_counter) child = parent.fork(new_child_seq_id) - child.append_token_id(child_sample.output_token, - child_sample.logprobs, - child_sample.hidden_states, - child_sample.finished - ) + child.append_token_id( + child_sample.output_token, + child_sample.logprobs, + child_sample.hidden_states, + child_sample.finished, + ) child_seqs.append((child, parent)) # Continue the parent sequence for the last child sample. # We reuse the parent sequence here to reduce redundant memory # copies, especially when using non-beam search sampling methods. last_child_sample = child_samples[-1] - parent.append_token_id(last_child_sample.output_token, - last_child_sample.logprobs, - last_child_sample.hidden_states, - last_child_sample.finished - ) + parent.append_token_id( + last_child_sample.output_token, + last_child_sample.logprobs, + last_child_sample.hidden_states, + last_child_sample.finished, + ) child_seqs.append((parent, parent)) for seq, _ in child_seqs: # self._decode_sequence(seq, seq_group.sampling_params) self._check_stop(seq, seq_group.sampling_params) - + # Non-beam search case if not seq_group.sampling_params.use_beam_search: # For newly created child sequences, add them to the sequence group @@ -501,16 +511,18 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # Select the newly finished sequences with the highest scores # to replace existing finished sequences. 
# Tuple of (seq, parent, is_new) - existing_finished_seqs = [(seq, None, False) - for seq in existing_finished_seqs] - new_finished_seqs = [(seq, parent, True) for seq, parent in child_seqs - if seq.is_finished()] + existing_finished_seqs = [(seq, None, False) for seq in existing_finished_seqs] + new_finished_seqs = [ + (seq, parent, True) for seq, parent in child_seqs if seq.is_finished() + ] all_finished_seqs = existing_finished_seqs + new_finished_seqs # Sort the finished sequences by their scores. - all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), - reverse=True) + all_finished_seqs.sort( + key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ), + reverse=True, + ) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: # A newly generated child sequence finishes and has a high @@ -532,13 +544,16 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # select the top beam_width sequences from the running # sequences for the next iteration to continue the beam # search. - running_child_seqs = [(seq, parent) for seq, parent in child_seqs - if not seq.is_finished()] + running_child_seqs = [ + (seq, parent) for seq, parent in child_seqs if not seq.is_finished() + ] # Sort the running sequences by their scores. - running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), - reverse=True) + running_child_seqs.sort( + key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, eos_token_id=self.tokenizer.eos_token_id + ), + reverse=True, + ) # Check if we can stop the beam search. if len(running_child_seqs) == 0: @@ -553,7 +568,10 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, current_worst_seq = all_finished_seqs[beam_width - 1][0] stop_beam_search = self._check_beam_search_early_stopping( seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, current_worst_seq) + seq_group.sampling_params, + best_running_seq, + current_worst_seq, + ) if stop_beam_search: # Stop the beam search and remove all the running sequences from @@ -593,8 +611,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, self.scheduler.free_seq(seq) def _process_model_outputs( - self, output: SamplerOutput, - scheduler_outputs: SchedulerOutputs) -> List[RequestOutput]: + self, output: SamplerOutput, scheduler_outputs: SchedulerOutputs + ) -> List[RequestOutput]: # Update the scheduled sequence groups with the model outputs. scheduled_seq_groups = scheduler_outputs.scheduled_seq_groups for seq_group, outputs in zip(scheduled_seq_groups, output): @@ -605,15 +623,15 @@ def _process_model_outputs( # Create the outputs. request_outputs: List[RequestOutput] = [] - for seq_group in (scheduled_seq_groups + - scheduler_outputs.ignored_seq_groups): + for seq_group in scheduled_seq_groups + scheduler_outputs.ignored_seq_groups: request_output = RequestOutput.from_seq_group(seq_group) request_outputs.append(request_output) if self.log_stats: # Log the system stats. 
- self._log_system_stats(scheduler_outputs.prompt_run, - scheduler_outputs.num_batched_tokens) + self._log_system_stats( + scheduler_outputs.prompt_run, scheduler_outputs.num_batched_tokens + ) return request_outputs def step(self) -> List[RequestOutput]: @@ -626,7 +644,7 @@ def step(self) -> List[RequestOutput]: the sequences and returns the newly generated results. """ seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule() - + if not scheduler_outputs.is_empty(): # Execute the model. all_outputs = self._run_workers( @@ -636,7 +654,8 @@ def step(self) -> List[RequestOutput]: "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, "blocks_to_copy": scheduler_outputs.blocks_to_copy, - }) + }, + ) # Only the driver worker returns the sampling results. output = all_outputs[0] @@ -662,11 +681,14 @@ def _log_system_stats( return # Discard the old stats. - self.num_prompt_tokens = [(t, n) for t, n in self.num_prompt_tokens - if now - t < _LOGGING_INTERVAL_SEC] - self.num_generation_tokens = [(t, n) - for t, n in self.num_generation_tokens - if now - t < _LOGGING_INTERVAL_SEC] + self.num_prompt_tokens = [ + (t, n) for t, n in self.num_prompt_tokens if now - t < _LOGGING_INTERVAL_SEC + ] + self.num_generation_tokens = [ + (t, n) + for t, n in self.num_generation_tokens + if now - t < _LOGGING_INTERVAL_SEC + ] if len(self.num_prompt_tokens) > 1: total_num_tokens = sum(n for _, n in self.num_prompt_tokens[:-1]) @@ -675,23 +697,20 @@ def _log_system_stats( else: avg_prompt_throughput = 0.0 if len(self.num_generation_tokens) > 1: - total_num_tokens = sum(n - for _, n in self.num_generation_tokens[:-1]) + total_num_tokens = sum(n for _, n in self.num_generation_tokens[:-1]) window = now - self.num_generation_tokens[0][0] avg_generation_throughput = total_num_tokens / window else: avg_generation_throughput = 0.0 total_num_gpu_blocks = self.cache_config.num_gpu_blocks - num_free_gpu_blocks = ( - self.scheduler.block_manager.get_num_free_gpu_blocks()) + num_free_gpu_blocks = self.scheduler.block_manager.get_num_free_gpu_blocks() num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks total_num_cpu_blocks = self.cache_config.num_cpu_blocks if total_num_cpu_blocks > 0: - num_free_cpu_blocks = ( - self.scheduler.block_manager.get_num_free_cpu_blocks()) + num_free_cpu_blocks = self.scheduler.block_manager.get_num_free_cpu_blocks() num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks else: @@ -707,29 +726,32 @@ def _log_system_stats( cpu_cache_usage=cpu_cache_usage, ) - logger.info("Avg prompt throughput: " - f"{avg_prompt_throughput:.1f} tokens/s, " - "Avg generation throughput: " - f"{avg_generation_throughput:.1f} tokens/s, " - f"Running: {len(self.scheduler.running)} reqs, " - f"Swapped: {len(self.scheduler.swapped)} reqs, " - f"Pending: {len(self.scheduler.waiting)} reqs, " - f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, " - f"CPU KV cache usage: {cpu_cache_usage * 100:.1f}%") + logger.info( + "Avg prompt throughput: " + f"{avg_prompt_throughput:.1f} tokens/s, " + "Avg generation throughput: " + f"{avg_generation_throughput:.1f} tokens/s, " + f"Running: {len(self.scheduler.running)} reqs, " + f"Swapped: {len(self.scheduler.swapped)} reqs, " + f"Pending: {len(self.scheduler.waiting)} reqs, " + f"GPU KV cache usage: {gpu_cache_usage * 100:.1f}%, " + f"CPU KV cache usage: 
{cpu_cache_usage * 100:.1f}%" + ) self.last_logging_time = now def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" - (new_tokens, new_output_text, prefix_offset, - read_offset) = detokenize_incrementally( - self.tokenizer, - all_input_ids=seq.get_token_ids(), - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms.spaces_between_special_tokens, - ) + (new_tokens, new_output_text, prefix_offset, read_offset) = ( + detokenize_incrementally( + self.tokenizer, + all_input_ids=seq.get_token_ids(), + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms.spaces_between_special_tokens, + ) + ) if seq.tokens is None: seq.tokens = new_tokens else: @@ -738,21 +760,20 @@ def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: seq.read_offset = read_offset seq.output_text += new_output_text - def _check_stop(self, seq: Sequence, - sampling_params: SamplingParams) -> None: + def _check_stop(self, seq: Sequence, sampling_params: SamplingParams) -> None: """Stop the finished sequences.""" for stop_str in sampling_params.stop: if seq.output_text.endswith(stop_str): if not sampling_params.include_stop_str_in_output: # Truncate the output text so that the stop string is # not included in the output. - seq.output_text = seq.output_text[:-len(stop_str)] + seq.output_text = seq.output_text[: -len(stop_str)] seq.status = SequenceStatus.FINISHED_STOPPED return if seq.data.finished: seq.status = SequenceStatus.FINISHED_STOPPED return - + for token_id in seq.get_last_token_id(): if token_id == sampling_params.eos_token: seq.status = SequenceStatus.FINISHED_STOPPED @@ -769,11 +790,12 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has generated the EOS token. - if ((not sampling_params.ignore_eos) - and seq.get_last_token_id()[0] == sampling_params.eos_token): + if (not sampling_params.ignore_eos) and seq.get_last_token_id()[ + 0 + ] == sampling_params.eos_token: seq.status = SequenceStatus.FINISHED_STOPPED return - + def _run_workers( self, method: str, @@ -786,8 +808,7 @@ def _run_workers( """Runs the given method on all workers.""" if max_concurrent_workers: - raise NotImplementedError( - "max_concurrent_workers is not supported yet.") + raise NotImplementedError("max_concurrent_workers is not supported yet.") # Start the ray workers first. ray_worker_outputs = [ @@ -801,8 +822,9 @@ def _run_workers( driver_kwargs = kwargs # Start the driver worker after all the ray workers. - driver_worker_output = getattr(self.driver_worker, - method)(*driver_args, **driver_kwargs) + driver_worker_output = getattr(self.driver_worker, method)( + *driver_args, **driver_kwargs + ) # Get the results of the ray workers. 
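# --- illustrative sketch (editor's note, not part of the patch) ---
# The _run_workers pattern above: dispatch the method to the remote (Ray)
# workers first so they run concurrently, execute the same method on the local
# driver worker, then collect the remote results. Ray is stood in for by a
# thread pool here, since the actual remote-dispatch call is elided in this hunk.

from concurrent.futures import ThreadPoolExecutor
from typing import Any, List

def run_workers(driver: Any, workers: List[Any], method: str, *args: Any, **kwargs: Any) -> List[Any]:
    with ThreadPoolExecutor(max_workers=max(len(workers), 1)) as pool:
        # 1. Start the remote workers first, without waiting on them.
        futures = [pool.submit(getattr(w, method), *args, **kwargs) for w in workers]
        # 2. Run the driver worker locally; only its sampling results are
        #    consumed by the engine.
        driver_output = getattr(driver, method)(*args, **kwargs)
        # 3. Join the remote workers so every rank has finished the step.
        return [driver_output] + [f.result() for f in futures]
# --- end sketch ---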
if self.workers: diff --git a/ChatTTS/model/velocity/model_loader.py b/ChatTTS/model/velocity/model_loader.py index bb4605875..40de6d960 100644 --- a/ChatTTS/model/velocity/model_loader.py +++ b/ChatTTS/model/velocity/model_loader.py @@ -1,4 +1,5 @@ """Utilities for selecting and loading models.""" + import contextlib from typing import Type @@ -8,10 +9,10 @@ from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.weight_utils import (get_quant_config, - initialize_dummy_weights) +from vllm.model_executor.weight_utils import get_quant_config, initialize_dummy_weights import importlib + @contextlib.contextmanager def _set_default_torch_dtype(dtype: torch.dtype): """Sets the default torch dtype to the given dtype.""" @@ -22,19 +23,24 @@ def _set_default_torch_dtype(dtype: torch.dtype): def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: - model_cls = getattr(importlib.import_module("ChatTTS.model.velocity.llama"), "LlamaModel", None) + model_cls = getattr( + importlib.import_module("ChatTTS.model.velocity.llama"), "LlamaModel", None + ) return model_cls + def get_model(model_config: ModelConfig) -> nn.Module: model_class = _get_model_architecture(model_config.hf_config) # Get the (maybe quantized) linear method. linear_method = None if model_config.quantization is not None: - quant_config = get_quant_config(model_config.quantization, - model_config.model, - model_config.hf_config, - model_config.download_dir) + quant_config = get_quant_config( + model_config.quantization, + model_config.model, + model_config.hf_config, + model_config.download_dir, + ) capability = torch.cuda.get_device_capability() capability = capability[0] * 10 + capability[1] if capability < quant_config.get_min_capability(): @@ -42,13 +48,15 @@ def get_model(model_config: ModelConfig) -> nn.Module: f"The quantization method {model_config.quantization} is not " "supported for the current GPU. " f"Minimum capability: {quant_config.get_min_capability()}. " - f"Current capability: {capability}.") + f"Current capability: {capability}." + ) supported_dtypes = quant_config.get_supported_act_dtypes() if model_config.dtype not in supported_dtypes: raise ValueError( f"{model_config.dtype} is not supported for quantization " f"method {model_config.quantization}. Supported dtypes: " - f"{supported_dtypes}") + f"{supported_dtypes}" + ) linear_method = quant_config.get_linear_method() with _set_default_torch_dtype(model_config.dtype): @@ -62,6 +70,10 @@ def get_model(model_config: ModelConfig) -> nn.Module: initialize_dummy_weights(model) else: # Load the weights from the cached or downloaded files. 
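# --- illustrative sketch (editor's note, not part of the patch) ---
# The body of _set_default_torch_dtype is elided above; the usual shape of such
# a helper, assuming the standard torch.get_default_dtype /
# torch.set_default_dtype APIs, is: remember the old default, switch to the
# requested dtype while the model is constructed, then restore it.

import contextlib
import torch

@contextlib.contextmanager
def set_default_torch_dtype(dtype: torch.dtype):
    old_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default even if model construction raises.
        torch.set_default_dtype(old_dtype)

# Usage mirrors get_model(): parameters created inside the block default to `dtype`.
# with set_default_torch_dtype(torch.float16):
#     model = model_class(config, linear_method)
# --- end sketch ---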
- model.load_weights(model_config.model, model_config.download_dir, - model_config.load_format, model_config.revision) + model.load_weights( + model_config.model, + model_config.download_dir, + model_config.load_format, + model_config.revision, + ) return model.eval() diff --git a/ChatTTS/model/velocity/model_runner.py b/ChatTTS/model/velocity/model_runner.py index 86ed4a730..5b0f2c2d8 100644 --- a/ChatTTS/model/velocity/model_runner.py +++ b/ChatTTS/model/velocity/model_runner.py @@ -10,9 +10,17 @@ from ChatTTS.model.velocity.model_loader import get_model from vllm.model_executor import InputMetadata, SamplingMetadata from vllm.model_executor.parallel_utils.communication_op import ( - broadcast, broadcast_object_list) + broadcast, + broadcast_object_list, +) from ChatTTS.model.velocity.sampling_params import SamplingParams, SamplingType -from ChatTTS.model.velocity.sequence import SamplerOutput, SequenceData, SequenceGroupMetadata, SequenceGroupOutput, SequenceOutput +from ChatTTS.model.velocity.sequence import ( + SamplerOutput, + SequenceData, + SequenceGroupMetadata, + SequenceGroupOutput, + SequenceOutput, +) from vllm.utils import in_wsl from ChatTTS.model.velocity.post_model import Post_model, Sampler from safetensors.torch import safe_open @@ -34,18 +42,19 @@ def __init__( parallel_config: ParallelConfig, scheduler_config: SchedulerConfig, is_driver_worker: bool = False, - post_model_path: str = None + post_model_path: str = None, ): self.model_config = model_config self.parallel_config = parallel_config self.scheduler_config = scheduler_config self.is_driver_worker = is_driver_worker self.post_model_path = post_model_path - + # model_config can be None in tests/samplers/test_sampler.py. # FIXME(woosuk): This is a hack to make the tests work. Refactor this. - self.sliding_window = (model_config.get_sliding_window() - if model_config is not None else None) + self.sliding_window = ( + model_config.get_sliding_window() if model_config is not None else None + ) self.model = None self.block_size = None # Set after initial profiling. @@ -54,7 +63,9 @@ def __init__( self.max_context_len_to_capture = ( self.model_config.max_context_len_to_capture - if self.model_config is not None else 0) + if self.model_config is not None + else 0 + ) # When using CUDA graph, the input block tables must be padded to # max_context_len_to_capture. However, creating the block table in # Python can be expensive. 
To optimize this, we cache the block table @@ -68,28 +79,27 @@ def __init__( def load_model(self) -> None: self.model = get_model(self.model_config) self.post_model = Post_model( - self.model_config.get_hidden_size(), - self.model_config.num_audio_tokens, - self.model_config.num_text_tokens - ) + self.model_config.get_hidden_size(), + self.model_config.num_audio_tokens, + self.model_config.num_text_tokens, + ) state_dict_tensors = {} with safe_open(self.post_model_path, framework="pt", device=0) as f: for k in f.keys(): state_dict_tensors[k] = f.get_tensor(k) self.post_model.load_state_dict(state_dict_tensors) self.post_model.to(next(self.model.parameters())).eval() - self.sampler = Sampler( - self.post_model, - self.model_config.num_audio_tokens, - 4 - ) + self.sampler = Sampler(self.post_model, self.model_config.num_audio_tokens, 4) + def set_block_size(self, block_size: int) -> None: self.block_size = block_size - max_num_blocks = (self.max_context_len_to_capture + block_size - - 1) // block_size + max_num_blocks = ( + self.max_context_len_to_capture + block_size - 1 + ) // block_size self.graph_block_tables = np.zeros( - (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32) + (max(_BATCH_SIZES_TO_CAPTURE), max_num_blocks), dtype=np.int32 + ) def _prepare_prompt( self, @@ -145,18 +155,15 @@ def _prepare_prompt( slot_mapping[-1].append(slot) max_prompt_len = max(prompt_lens) - input_tokens = _make_tensor_with_pad(input_tokens, - max_prompt_len, - pad=0, - dtype=torch.long) - input_positions = _make_tensor_with_pad(input_positions, - max_prompt_len, - pad=0, - dtype=torch.long) - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_prompt_len, - pad=_PAD_SLOT_ID, - dtype=torch.long) + input_tokens = _make_tensor_with_pad( + input_tokens, max_prompt_len, pad=0, dtype=torch.long + ) + input_positions = _make_tensor_with_pad( + input_positions, max_prompt_len, pad=0, dtype=torch.long + ) + slot_mapping = _make_tensor_with_pad( + slot_mapping, max_prompt_len, pad=_PAD_SLOT_ID, dtype=torch.long + ) input_metadata = InputMetadata( is_prompt=True, @@ -192,8 +199,11 @@ def _prepare_decode( position = seq_len - 1 input_positions.append([position]) - context_len = seq_len if self.sliding_window is None else min( - seq_len, self.sliding_window) + context_len = ( + seq_len + if self.sliding_window is None + else min(seq_len, self.sliding_window) + ) context_lens.append(context_len) block_table = seq_group_metadata.block_tables[seq_id] @@ -203,8 +213,7 @@ def _prepare_decode( slot_mapping.append([slot]) if self.sliding_window is not None: - sliding_window_blocks = (self.sliding_window // - self.block_size) + sliding_window_blocks = self.sliding_window // self.block_size block_table = block_table[-sliding_window_blocks:] block_tables.append(block_table) @@ -213,7 +222,8 @@ def _prepare_decode( use_captured_graph = ( not self.model_config.enforce_eager and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] - and max_context_len <= self.max_context_len_to_capture) + and max_context_len <= self.max_context_len_to_capture + ) if use_captured_graph: # Pad the input tokens, positions, and slot mapping to match the # batch size of the captured graph. 
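# --- illustrative sketch (editor's note, not part of the patch) ---
# What _make_tensor_with_pad does in _prepare_prompt above and _prepare_decode
# below: right-pad every per-sequence list to a common length and stack the
# result into one rectangular tensor. A minimal stand-in (the real helper also
# takes device/pin_memory arguments, see the end of model_runner.py in this diff):

from typing import List
import torch

def make_tensor_with_pad(x: List[List[int]], max_len: int, pad: int, dtype: torch.dtype) -> torch.Tensor:
    padded = [row + [pad] * (max_len - len(row)) for row in x]
    return torch.tensor(padded, dtype=dtype)

# e.g. prompts of length 3 and 1 padded to max_prompt_len=3 with pad=0:
# make_tensor_with_pad([[5, 6, 7], [9]], 3, 0, torch.long)
# -> tensor([[5, 6, 7],
#            [9, 0, 0]])
# --- end sketch ---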
@@ -227,24 +237,16 @@ def _prepare_decode( block_tables.append([]) batch_size = graph_batch_size - input_tokens = _make_tensor_with_pad(input_tokens, - max_len=1, - pad=0, - dtype=torch.long, - device="cuda") - input_positions = _make_tensor_with_pad(input_positions, - max_len=1, - pad=0, - dtype=torch.long, - device="cuda") - slot_mapping = _make_tensor_with_pad(slot_mapping, - max_len=1, - pad=_PAD_SLOT_ID, - dtype=torch.long, - device="cuda") - context_lens = torch.tensor(context_lens, - dtype=torch.int, - device="cuda") + input_tokens = _make_tensor_with_pad( + input_tokens, max_len=1, pad=0, dtype=torch.long, device="cuda" + ) + input_positions = _make_tensor_with_pad( + input_positions, max_len=1, pad=0, dtype=torch.long, device="cuda" + ) + slot_mapping = _make_tensor_with_pad( + slot_mapping, max_len=1, pad=_PAD_SLOT_ID, dtype=torch.long, device="cuda" + ) + context_lens = torch.tensor(context_lens, dtype=torch.int, device="cuda") if use_captured_graph: # The shape of graph_block_tables is @@ -252,7 +254,7 @@ def _prepare_decode( input_block_tables = self.graph_block_tables[:batch_size] for i, block_table in enumerate(block_tables): if block_table: - input_block_tables[i, :len(block_table)] = block_table + input_block_tables[i, : len(block_table)] = block_table block_tables = torch.tensor(input_block_tables, device="cuda") else: block_tables = _make_tensor_with_pad( @@ -297,34 +299,38 @@ def _prepare_sample( # NOTE: prompt token positions do not need sample, skip categorized_sample_indices_start_idx += prompt_len - 1 - categorized_sample_indices[ - sampling_params.sampling_type].append( - categorized_sample_indices_start_idx) + categorized_sample_indices[sampling_params.sampling_type].append( + categorized_sample_indices_start_idx + ) categorized_sample_indices_start_idx += 1 if sampling_params.prompt_logprobs is not None: selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + prompt_len - 1)) - selected_token_indices.append(selected_token_start_idx + - prompt_len - 1) + range( + selected_token_start_idx, + selected_token_start_idx + prompt_len - 1, + ) + ) + selected_token_indices.append(selected_token_start_idx + prompt_len - 1) selected_token_start_idx += max_prompt_len else: num_seqs = len(seq_ids) selected_token_indices.extend( - range(selected_token_start_idx, - selected_token_start_idx + num_seqs)) + range(selected_token_start_idx, selected_token_start_idx + num_seqs) + ) selected_token_start_idx += num_seqs - categorized_sample_indices[ - sampling_params.sampling_type].extend( - range(categorized_sample_indices_start_idx, - categorized_sample_indices_start_idx + num_seqs)) + categorized_sample_indices[sampling_params.sampling_type].extend( + range( + categorized_sample_indices_start_idx, + categorized_sample_indices_start_idx + num_seqs, + ) + ) categorized_sample_indices_start_idx += num_seqs - selected_token_indices = _async_h2d(selected_token_indices, - dtype=torch.long, - pin_memory=not self.in_wsl) + selected_token_indices = _async_h2d( + selected_token_indices, dtype=torch.long, pin_memory=not self.in_wsl + ) categorized_sample_indices = { t: _async_h2d(seq_ids, dtype=torch.int, pin_memory=not self.in_wsl) for t, seq_ids in categorized_sample_indices.items() @@ -353,14 +359,17 @@ def prepare_input_tensors( is_prompt = seq_group_metadata_list[0].is_prompt # Prepare input tensors. 
if is_prompt: - (input_tokens, input_positions, input_metadata, - prompt_lens) = self._prepare_prompt(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata, prompt_lens) = ( + self._prepare_prompt(seq_group_metadata_list) + ) else: - (input_tokens, input_positions, input_metadata - ) = self._prepare_decode(seq_group_metadata_list) + (input_tokens, input_positions, input_metadata) = self._prepare_decode( + seq_group_metadata_list + ) prompt_lens = [] - sampling_metadata = self._prepare_sample(seq_group_metadata_list, - prompt_lens) + sampling_metadata = self._prepare_sample( + seq_group_metadata_list, prompt_lens + ) def get_size_or_none(x: Optional[torch.Tensor]): return x.size() if x is not None else None @@ -369,24 +378,15 @@ def get_size_or_none(x: Optional[torch.Tensor]): # its shape and then broadcast the tensor to avoid high # serialization cost. py_data = { - "input_tokens_size": - input_tokens.size(), - "input_positions_size": - input_positions.size(), - "is_prompt": - input_metadata.is_prompt, - "slot_mapping_size": - get_size_or_none(input_metadata.slot_mapping), - "max_context_len": - input_metadata.max_context_len, - "context_lens_size": - get_size_or_none(input_metadata.context_lens), - "block_tables_size": - get_size_or_none(input_metadata.block_tables), - "use_cuda_graph": - input_metadata.use_cuda_graph, - "selected_token_indices_size": - sampling_metadata.selected_token_indices.size(), + "input_tokens_size": input_tokens.size(), + "input_positions_size": input_positions.size(), + "is_prompt": input_metadata.is_prompt, + "slot_mapping_size": get_size_or_none(input_metadata.slot_mapping), + "max_context_len": input_metadata.max_context_len, + "context_lens_size": get_size_or_none(input_metadata.context_lens), + "block_tables_size": get_size_or_none(input_metadata.block_tables), + "use_cuda_graph": input_metadata.use_cuda_graph, + "selected_token_indices_size": sampling_metadata.selected_token_indices.size(), } broadcast_object_list([py_data], src=0) # TODO(zhuohan): Combine the broadcasts or set async_op=True. 
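# --- illustrative sketch (editor's note, not part of the patch) ---
# The metadata exchange in prepare_input_tensors: the driver (rank 0) first
# broadcasts a small Python object holding each tensor's *shape*, then
# broadcasts the tensor data itself; the other ranks allocate empty CUDA
# buffers from the received shapes and fill them in the second broadcast,
# avoiding the cost of pickling large tensors. Sketched here directly against
# torch.distributed; the diff routes these calls through vLLM's
# communication_op wrappers.

from typing import Optional
import torch
import torch.distributed as dist

def exchange_tensor(t: Optional[torch.Tensor], is_driver: bool) -> torch.Tensor:
    if is_driver:
        assert t is not None
        dist.broadcast_object_list([{"size": tuple(t.size()), "dtype": t.dtype}], src=0)
        dist.broadcast(t, src=0)                 # bulk data, no pickling
        return t
    meta = [None]
    dist.broadcast_object_list(meta, src=0)      # cheap: shape/dtype only
    buf = torch.empty(*meta[0]["size"], dtype=meta[0]["dtype"], device="cuda")
    dist.broadcast(buf, src=0)
    return buf
# --- end sketch ---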
@@ -403,39 +403,38 @@ def get_size_or_none(x: Optional[torch.Tensor]): receving_list = [None] broadcast_object_list(receving_list, src=0) py_data = receving_list[0] - input_tokens = torch.empty(*py_data["input_tokens_size"], - dtype=torch.long, - device="cuda") + input_tokens = torch.empty( + *py_data["input_tokens_size"], dtype=torch.long, device="cuda" + ) broadcast(input_tokens, src=0) - input_positions = torch.empty(*py_data["input_positions_size"], - dtype=torch.long, - device="cuda") + input_positions = torch.empty( + *py_data["input_positions_size"], dtype=torch.long, device="cuda" + ) broadcast(input_positions, src=0) if py_data["slot_mapping_size"] is not None: - slot_mapping = torch.empty(*py_data["slot_mapping_size"], - dtype=torch.long, - device="cuda") + slot_mapping = torch.empty( + *py_data["slot_mapping_size"], dtype=torch.long, device="cuda" + ) broadcast(slot_mapping, src=0) else: slot_mapping = None if py_data["context_lens_size"] is not None: - context_lens = torch.empty(*py_data["context_lens_size"], - dtype=torch.int, - device="cuda") + context_lens = torch.empty( + *py_data["context_lens_size"], dtype=torch.int, device="cuda" + ) broadcast(context_lens, src=0) else: context_lens = None if py_data["block_tables_size"] is not None: - block_tables = torch.empty(*py_data["block_tables_size"], - dtype=torch.int, - device="cuda") + block_tables = torch.empty( + *py_data["block_tables_size"], dtype=torch.int, device="cuda" + ) broadcast(block_tables, src=0) else: block_tables = None selected_token_indices = torch.empty( - *py_data["selected_token_indices_size"], - dtype=torch.long, - device="cuda") + *py_data["selected_token_indices_size"], dtype=torch.long, device="cuda" + ) broadcast(selected_token_indices, src=0) input_metadata = InputMetadata( is_prompt=py_data["is_prompt"], @@ -463,7 +462,8 @@ def execute_model( kv_caches: List[Tuple[torch.Tensor, torch.Tensor]], ) -> Optional[SamplerOutput]: input_tokens, input_positions, input_metadata, sampling_metadata = ( - self.prepare_input_tensors(seq_group_metadata_list)) + self.prepare_input_tensors(seq_group_metadata_list) + ) # print(sampling_metadata.seq_data) seq_groups = [] input_tokens_history = [] @@ -476,14 +476,16 @@ def execute_model( else: tokens_history = [list(token) for token in tokens_history] input_tokens_history.append(tokens_history) - input_tokens_history = torch.tensor(input_tokens_history).to(input_tokens.device) - # token_ids = rtn.outputs[0].token_ids - # for j, token_id in enumerate(token_ids): - # if len(token_id) == 1: - # token_ids[j] = token_id[0] - # else: - # token_ids[j] = list(token_id) - + input_tokens_history = torch.tensor(input_tokens_history).to( + input_tokens.device + ) + # token_ids = rtn.outputs[0].token_ids + # for j, token_id in enumerate(token_ids): + # if len(token_id) == 1: + # token_ids[j] = token_id[0] + # else: + # token_ids[j] = list(token_id) + # Execute the model. 
# print("it1",input_tokens) if len(input_tokens.shape) == 2: @@ -494,25 +496,29 @@ def execute_model( # print("it2",input_tokens.shape) text_mask = input_tokens != 0 text_mask = text_mask[:, :, 0] - + if input_metadata.use_cuda_graph: graph_batch_size = input_tokens.shape[0] model_executable = self.graph_runners[graph_batch_size] else: model_executable = self.model - + infer_text = sampling_metadata.seq_groups[0][1].infer_text temperture = sampling_metadata.seq_groups[0][1].temperature if not infer_text: temperture = torch.tensor(temperture).to(input_tokens.device) - logits_processors, logits_warpers = sampling_metadata.seq_groups[0][1].logits_processors + logits_processors, logits_warpers = sampling_metadata.seq_groups[0][ + 1 + ].logits_processors # print(logits_processors, logits_warpers) min_new_token = sampling_metadata.seq_groups[0][1].min_new_token eos_token = sampling_metadata.seq_groups[0][1].eos_token start_idx = sampling_metadata.seq_groups[0][1].start_idx if input_tokens.shape[-2] == 1: if infer_text: - input_emb: torch.Tensor = self.post_model.emb_text(input_tokens[:, :, 0]) + input_emb: torch.Tensor = self.post_model.emb_text( + input_tokens[:, :, 0] + ) else: code_emb = [ self.post_model.emb_code[i](input_tokens[:, :, i]) @@ -531,7 +537,11 @@ def execute_model( # print(hidden_states.shape) # print(input_tokens) idx_next, logprob, finish = self.sampler.sample( - inputs_ids=input_tokens if input_tokens_history.shape[-2] == 0 else input_tokens_history, + inputs_ids=( + input_tokens + if input_tokens_history.shape[-2] == 0 + else input_tokens_history + ), hidden_states=hidden_states, infer_text=infer_text, temperature=temperture, @@ -540,11 +550,11 @@ def execute_model( min_new_token=min_new_token, now_length=1, eos_token=eos_token, - start_idx=start_idx + start_idx=start_idx, ) # print(logprob.shape, idx_next.shape) if len(logprob.shape) == 2: - logprob = logprob[:,None,:] + logprob = logprob[:, None, :] logprob = torch.gather(logprob, -1, idx_next.transpose(-1, -2))[:, :, 0] # print("测试",idx_next.shape, logprob.shape) # Sample the next token. @@ -557,14 +567,16 @@ def execute_model( idx_next_i = idx_next[i, 0, :].cpu().tolist() logprob_i = logprob[i].cpu().tolist() result = SequenceGroupOutput( - samples = [SequenceOutput( - parent_seq_id=seq_groups[i], - logprobs={tuple(idx_next_i):logprob_i}, - output_token=tuple(idx_next_i), - hidden_states=hidden_states[i].cpu(), - finished=finish[i].item(), - ),], - prompt_logprobs = None + samples=[ + SequenceOutput( + parent_seq_id=seq_groups[i], + logprobs={tuple(idx_next_i): logprob_i}, + output_token=tuple(idx_next_i), + hidden_states=hidden_states[i].cpu(), + finished=finish[i].item(), + ), + ], + prompt_logprobs=None, ) results.append(result) # print(results) @@ -575,7 +587,9 @@ def execute_model( def profile_run(self) -> None: # Enable top-k sampling to reflect the accurate memory usage. vocab_size = self.model_config.get_vocab_size() - sampling_params = SamplingParams(top_p=0.99, top_k=vocab_size - 1, infer_text=True) + sampling_params = SamplingParams( + top_p=0.99, top_k=vocab_size - 1, infer_text=True + ) max_num_batched_tokens = self.scheduler_config.max_num_batched_tokens max_num_seqs = self.scheduler_config.max_num_seqs @@ -583,8 +597,9 @@ def profile_run(self) -> None: # number of tokens equal to max_num_batched_tokens. 
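# --- illustrative sketch (editor's note, not part of the patch) ---
# How profile_run sizes its dummy sequences just below: the batched-token
# budget is spread as evenly as possible over max_num_seqs sequences, with the
# remainder handed out one token at a time to the first groups, so the lengths
# sum exactly to max_num_batched_tokens.

from typing import List

def profile_seq_lens(max_num_batched_tokens: int, max_num_seqs: int) -> List[int]:
    return [
        max_num_batched_tokens // max_num_seqs
        + (group_id < max_num_batched_tokens % max_num_seqs)
        for group_id in range(max_num_seqs)
    ]

# profile_seq_lens(10, 3) -> [4, 3, 3]   (sums to 10)
# --- end sketch ---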
seqs: List[SequenceGroupMetadata] = [] for group_id in range(max_num_seqs): - seq_len = (max_num_batched_tokens // max_num_seqs + - (group_id < max_num_batched_tokens % max_num_seqs)) + seq_len = max_num_batched_tokens // max_num_seqs + ( + group_id < max_num_batched_tokens % max_num_seqs + ) seq_data = SequenceData([0] * seq_len) seq = SequenceGroupMetadata( request_id=str(group_id), @@ -605,20 +620,28 @@ def profile_run(self) -> None: @torch.inference_mode() def capture_model(self, kv_caches: List[KVCache]) -> None: assert not self.model_config.enforce_eager - logger.info("Capturing the model for CUDA graphs. This may lead to " - "unexpected consequences if the model is not static. To " - "run the model in eager mode, set 'enforce_eager=True' or " - "use '--enforce-eager' in the CLI.") - logger.info("CUDA graphs can take additional 1~3 GiB memory per GPU. " - "If you are running out of memory, consider decreasing " - "`gpu_memory_utilization` or enforcing eager mode.") + logger.info( + "Capturing the model for CUDA graphs. This may lead to " + "unexpected consequences if the model is not static. To " + "run the model in eager mode, set 'enforce_eager=True' or " + "use '--enforce-eager' in the CLI." + ) + logger.info( + "CUDA graphs can take additional 1~3 GiB memory per GPU. " + "If you are running out of memory, consider decreasing " + "`gpu_memory_utilization` or enforcing eager mode." + ) start_time = time.perf_counter() # Prepare dummy inputs. These will be reused for all batch sizes. max_batch_size = max(_BATCH_SIZES_TO_CAPTURE) - input_emb = torch.zeros(max_batch_size, 1, self.model_config.get_hidden_size(), dtype=next(self.model.parameters()).dtype).cuda() - input_positions = torch.zeros(max_batch_size, 1, - dtype=torch.long).cuda() + input_emb = torch.zeros( + max_batch_size, + 1, + self.model_config.get_hidden_size(), + dtype=next(self.model.parameters()).dtype, + ).cuda() + input_positions = torch.zeros(max_batch_size, 1, dtype=torch.long).cuda() slot_mapping = torch.empty(max_batch_size, 1, dtype=torch.long).cuda() slot_mapping.fill_(_PAD_SLOT_ID) context_lens = torch.ones(max_batch_size, dtype=torch.int32).cuda() @@ -718,12 +741,15 @@ def forward( # Copy the input tensors to the input buffers. self.input_buffers["input_emb"].copy_(input_emb, non_blocking=True) self.input_buffers["positions"].copy_(positions, non_blocking=True) - self.input_buffers["slot_mapping"].copy_(input_metadata.slot_mapping, - non_blocking=True) - self.input_buffers["context_lens"].copy_(input_metadata.context_lens, - non_blocking=True) - self.input_buffers["block_tables"].copy_(input_metadata.block_tables, - non_blocking=True) + self.input_buffers["slot_mapping"].copy_( + input_metadata.slot_mapping, non_blocking=True + ) + self.input_buffers["context_lens"].copy_( + input_metadata.context_lens, non_blocking=True + ) + self.input_buffers["block_tables"].copy_( + input_metadata.block_tables, non_blocking=True + ) # Run the graph. 
self.graph.replay() @@ -749,10 +775,12 @@ def _make_tensor_with_pad( pin_memory: bool = False, ) -> torch.Tensor: padded_x = [_pad_to_max(x_i, max_len, pad) for x_i in x] - return torch.tensor(padded_x, - dtype=dtype, - device=device, - pin_memory=pin_memory and str(device) == "cpu") + return torch.tensor( + padded_x, + dtype=dtype, + device=device, + pin_memory=pin_memory and str(device) == "cpu", + ) def _get_graph_batch_size(batch_size: int) -> int: diff --git a/ChatTTS/model/velocity/output.py b/ChatTTS/model/velocity/output.py index ea3c81d80..3413a3e2b 100644 --- a/ChatTTS/model/velocity/output.py +++ b/ChatTTS/model/velocity/output.py @@ -1,8 +1,12 @@ from typing import List, Optional import torch -from ChatTTS.model.velocity.sequence import (PromptLogprobs, SampleLogprobs, SequenceGroup, - SequenceStatus) +from ChatTTS.model.velocity.sequence import ( + PromptLogprobs, + SampleLogprobs, + SequenceGroup, + SequenceStatus, +) class CompletionOutput: @@ -41,13 +45,15 @@ def finished(self) -> bool: return self.finish_reason is not None def __repr__(self) -> str: - return (f"CompletionOutput(index={self.index}, " - f"text={self.text!r}, " - f"token_ids={self.token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}, " - f"logprobs={self.logprobs}, " - f"finish_reason={self.finish_reason}, " - f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None})") + return ( + f"CompletionOutput(index={self.index}, " + f"text={self.text!r}, " + f"token_ids={self.token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}, " + f"logprobs={self.logprobs}, " + f"finish_reason={self.finish_reason}, " + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None})" + ) class RequestOutput: @@ -85,7 +91,8 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": seqs = seq_group.get_seqs() if seq_group.sampling_params.use_beam_search: sorting_key = lambda seq: seq.get_beam_search_score( - seq_group.sampling_params.length_penalty) + seq_group.sampling_params.length_penalty + ) else: sorting_key = lambda seq: seq.get_cumulative_logprob() sorted_seqs = sorted(seqs, key=sorting_key, reverse=True) @@ -101,12 +108,15 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": # logprobs are not requested. logprobs = None finshed_reason = SequenceStatus.get_finished_reason(seq.status) - output = CompletionOutput(seqs.index(seq), seq.output_text, - seq.get_output_token_ids(), - seq.get_cumulative_logprob(), logprobs, - finshed_reason, - seq.data.hidden_states - ) + output = CompletionOutput( + seqs.index(seq), + seq.output_text, + seq.get_output_token_ids(), + seq.get_cumulative_logprob(), + logprobs, + finshed_reason, + seq.data.hidden_states, + ) outputs.append(output) # Every sequence in the sequence group should have the same prompt. 
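# --- illustrative sketch (editor's note, not part of the patch) ---
# The ranking used by RequestOutput.from_seq_group in this hunk: sequences are
# ordered best-first, by the length-penalized beam score when beam search is
# enabled and by raw cumulative log probability otherwise. `SeqLike` and the
# n-best truncation are a simplified stand-in for the real Sequence class.

from typing import Callable, List

class SeqLike:
    def __init__(self, cumulative_logprob: float, length: int):
        self.cumulative_logprob = cumulative_logprob
        self.length = length

    def beam_score(self, length_penalty: float) -> float:
        # Same normalization as Sequence.get_beam_search_score in this diff.
        return self.cumulative_logprob / (self.length ** length_penalty)

def rank_outputs(seqs: List[SeqLike], use_beam_search: bool, length_penalty: float, n: int) -> List[SeqLike]:
    key: Callable[[SeqLike], float] = (
        (lambda s: s.beam_score(length_penalty)) if use_beam_search
        else (lambda s: s.cumulative_logprob)
    )
    return sorted(seqs, key=key, reverse=True)[:n]
# --- end sketch ---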
@@ -114,14 +124,21 @@ def from_seq_group(cls, seq_group: SequenceGroup) -> "RequestOutput": prompt_token_ids = seq_group.prompt_token_ids prompt_logprobs = seq_group.prompt_logprobs finished = seq_group.is_finished() - return cls(seq_group.request_id, prompt, prompt_token_ids, - prompt_logprobs, outputs, finished) + return cls( + seq_group.request_id, + prompt, + prompt_token_ids, + prompt_logprobs, + outputs, + finished, + ) def __repr__(self) -> str: - return (f"RequestOutput(request_id={self.request_id}, " - f"prompt={self.prompt!r}, " - f"prompt_token_ids={self.prompt_token_ids}, " - f"prompt_logprobs={self.prompt_logprobs}, " - f"outputs={self.outputs}, " - f"finished={self.finished})" - ) + return ( + f"RequestOutput(request_id={self.request_id}, " + f"prompt={self.prompt!r}, " + f"prompt_token_ids={self.prompt_token_ids}, " + f"prompt_logprobs={self.prompt_logprobs}, " + f"outputs={self.outputs}, " + f"finished={self.finished})" + ) diff --git a/ChatTTS/model/velocity/post_model.py b/ChatTTS/model/velocity/post_model.py index c38853b3a..89bc79dcc 100644 --- a/ChatTTS/model/velocity/post_model.py +++ b/ChatTTS/model/velocity/post_model.py @@ -12,13 +12,11 @@ from torch.functional import F from torch.nn.utils.parametrizations import weight_norm from typing import List, Callable + + class Post_model(nn.Module): def __init__( - self, - hidden_size: int, - num_audio_tokens: int, - num_text_tokens: int, - num_vq=4 + self, hidden_size: int, num_audio_tokens: int, num_text_tokens: int, num_vq=4 ): super().__init__() @@ -27,41 +25,24 @@ def __init__( self.model_dim = hidden_size self.emb_code = nn.ModuleList( - [ - nn.Embedding( - num_audio_tokens, - self.model_dim - ) - for _ in range(num_vq) - ], - ) - self.emb_text = nn.Embedding( - num_text_tokens, self.model_dim + [nn.Embedding(num_audio_tokens, self.model_dim) for _ in range(num_vq)], ) + self.emb_text = nn.Embedding(num_text_tokens, self.model_dim) self.head_text = weight_norm( - nn.Linear( - self.model_dim, - num_text_tokens, - bias=False - ), + nn.Linear(self.model_dim, num_text_tokens, bias=False), name="weight", ) self.head_code = nn.ModuleList( [ weight_norm( - nn.Linear( - self.model_dim, - num_audio_tokens, - bias=False - ), + nn.Linear(self.model_dim, num_audio_tokens, bias=False), name="weight", ) for _ in range(self.num_vq) ], ) - def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Tensor: """ get_emb @@ -90,112 +71,118 @@ def forward(self, input_ids: torch.Tensor, text_mask: torch.Tensor) -> torch.Ten del emb_text, emb_code, text_mask_inv return emb - + + class Sampler: - def __init__(self, - post_model: Post_model, - num_audio_tokens: int, - num_vq: int - ): + def __init__(self, post_model: Post_model, num_audio_tokens: int, num_vq: int): self.post_model = post_model self.device = next(self.post_model.parameters()).device self.num_audio_tokens = num_audio_tokens self.num_vq = num_vq - - def sample(self, - inputs_ids: torch.Tensor, - hidden_states: torch.Tensor, - infer_text: bool = False, - temperature: torch.Tensor = 1.0, - logits_processors: List[Callable] = [lambda logits_token, logits: logits,], - logits_warpers: List[Callable] = [lambda logits_token, logits: logits,], - min_new_token: int = 0, - now_length: int = 0, - eos_token: int = 0, - start_idx: int = 0, - ): - # print(inputs_ids.shape) - B = hidden_states.shape[0] - - end_idx = torch.zeros( - inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long + + def sample( + self, + inputs_ids: torch.Tensor, + hidden_states: torch.Tensor, 
+ infer_text: bool = False, + temperature: torch.Tensor = 1.0, + logits_processors: List[Callable] = [ + lambda logits_token, logits: logits, + ], + logits_warpers: List[Callable] = [ + lambda logits_token, logits: logits, + ], + min_new_token: int = 0, + now_length: int = 0, + eos_token: int = 0, + start_idx: int = 0, + ): + # print(inputs_ids.shape) + B = hidden_states.shape[0] + + end_idx = torch.zeros( + inputs_ids.shape[0], device=inputs_ids.device, dtype=torch.long + ) + finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() + if not infer_text: + temperature = ( + temperature.unsqueeze(0) + .expand(inputs_ids.shape[0], -1) + .contiguous() + .view(-1, 1) ) - finish = torch.zeros(inputs_ids.shape[0], device=inputs_ids.device).bool() - if not infer_text: - temperature = ( - temperature.unsqueeze(0) - .expand(inputs_ids.shape[0], -1) - .contiguous() - .view(-1, 1) - ) - - if infer_text: - logits: torch.Tensor = self.post_model.head_text(hidden_states) - else: - # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3) - logits = torch.empty( - hidden_states.size(0), - hidden_states.size(1), - self.num_audio_tokens, - self.num_vq, - dtype=torch.float, - device=self.device, - ) - for num_vq_iter in range(self.num_vq): - x: torch.Tensor = self.post_model.head_code[num_vq_iter](hidden_states) - logits[..., num_vq_iter] = x - del x - - del hidden_states - - # logits = logits[:, -1].float() - logits = logits.narrow(1, -1, 1).squeeze_(1).float() - - if not infer_text: - # logits = rearrange(logits, "b c n -> (b n) c") - logits = logits.permute(0, 2, 1) - logits = logits.reshape(-1, logits.size(2)) - # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c") - inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1) - logits_token = inputs_ids_sliced.reshape( - inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1), - -1, - ).to(self.device) - else: - logits_token = inputs_ids[:, start_idx:, 0].to(self.device) - - logits /= temperature - - for logitsProcessors in logits_processors: - logits = logitsProcessors(logits_token, logits) - - for logitsWarpers in logits_warpers: - logits = logitsWarpers(logits_token, logits) - - del logits_token - - if now_length < min_new_token: - logits[:, eos_token] = -torch.inf - - scores = F.softmax(logits, dim=-1) - idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) - if not infer_text: - scores = scores.reshape(B, -1, scores.shape[-1]) - if not infer_text: - # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) - idx_next = idx_next.view(-1, self.num_vq) - finish_or = idx_next.eq(eos_token).any(1) - finish.logical_or_(finish_or) - del finish_or - else: - finish_or = idx_next.eq(eos_token).any(1) - finish.logical_or_(finish_or) - del finish_or - - del inputs_ids - - not_finished = finish.logical_not().to(end_idx.device) - - end_idx.add_(not_finished.int()) - idx_next = idx_next[:, None, :] - return idx_next, torch.log(scores), finish, \ No newline at end of file + + if infer_text: + logits: torch.Tensor = self.post_model.head_text(hidden_states) + else: + # logits = torch.stack([self.head_code[i](hidden_states) for i in range(self.num_vq)], 3) + logits = torch.empty( + hidden_states.size(0), + hidden_states.size(1), + self.num_audio_tokens, + self.num_vq, + dtype=torch.float, + device=self.device, + ) + for num_vq_iter in range(self.num_vq): + x: torch.Tensor = self.post_model.head_code[num_vq_iter](hidden_states) + logits[..., num_vq_iter] = x + del x + + del 
hidden_states + + # logits = logits[:, -1].float() + logits = logits.narrow(1, -1, 1).squeeze_(1).float() + + if not infer_text: + # logits = rearrange(logits, "b c n -> (b n) c") + logits = logits.permute(0, 2, 1) + logits = logits.reshape(-1, logits.size(2)) + # logits_token = rearrange(inputs_ids[:, start_idx:], "b c n -> (b n) c") + inputs_ids_sliced = inputs_ids[:, start_idx:].permute(0, 2, 1) + logits_token = inputs_ids_sliced.reshape( + inputs_ids_sliced.size(0) * inputs_ids_sliced.size(1), + -1, + ).to(self.device) + else: + logits_token = inputs_ids[:, start_idx:, 0].to(self.device) + + logits /= temperature + + for logitsProcessors in logits_processors: + logits = logitsProcessors(logits_token, logits) + + for logitsWarpers in logits_warpers: + logits = logitsWarpers(logits_token, logits) + + del logits_token + + if now_length < min_new_token: + logits[:, eos_token] = -torch.inf + + scores = F.softmax(logits, dim=-1) + idx_next = torch.multinomial(scores, num_samples=1).to(finish.device) + if not infer_text: + scores = scores.reshape(B, -1, scores.shape[-1]) + if not infer_text: + # idx_next = rearrange(idx_next, "(b n) 1 -> b n", n=self.num_vq) + idx_next = idx_next.view(-1, self.num_vq) + finish_or = idx_next.eq(eos_token).any(1) + finish.logical_or_(finish_or) + del finish_or + else: + finish_or = idx_next.eq(eos_token).any(1) + finish.logical_or_(finish_or) + del finish_or + + del inputs_ids + + not_finished = finish.logical_not().to(end_idx.device) + + end_idx.add_(not_finished.int()) + idx_next = idx_next[:, None, :] + return ( + idx_next, + torch.log(scores), + finish, + ) diff --git a/ChatTTS/model/velocity/sampling_params.py b/ChatTTS/model/velocity/sampling_params.py index be3f9bf7f..e650fc546 100644 --- a/ChatTTS/model/velocity/sampling_params.py +++ b/ChatTTS/model/velocity/sampling_params.py @@ -1,4 +1,5 @@ """Sampling parameters for text generation.""" + from enum import IntEnum from functools import cached_property from typing import Callable, List, Optional, Union @@ -113,13 +114,20 @@ def __init__( prompt_logprobs: Optional[int] = None, skip_special_tokens: bool = True, spaces_between_special_tokens: bool = True, - logits_processors: Optional[List[LogitsProcessor]] = ([lambda logits_token, logits: logits,],[lambda logits_token, logits: logits,]), + logits_processors: Optional[List[LogitsProcessor]] = ( + [ + lambda logits_token, logits: logits, + ], + [ + lambda logits_token, logits: logits, + ], + ), min_new_token: int = 0, max_new_token: int = 8192, infer_text: bool = False, eos_token: int = 0, - spk_emb:str = None, - start_idx:int = 0, + spk_emb: str = None, + start_idx: int = 0, ) -> None: self.n = n self.best_of = best_of if best_of is not None else n @@ -173,42 +181,50 @@ def _verify_args(self) -> None: if self.n < 1: raise ValueError(f"n must be at least 1, got {self.n}.") if self.best_of < self.n: - raise ValueError(f"best_of must be greater than or equal to n, " - f"got n={self.n} and best_of={self.best_of}.") + raise ValueError( + f"best_of must be greater than or equal to n, " + f"got n={self.n} and best_of={self.best_of}." + ) if not -2.0 <= self.presence_penalty <= 2.0: - raise ValueError("presence_penalty must be in [-2, 2], got " - f"{self.presence_penalty}.") + raise ValueError( + "presence_penalty must be in [-2, 2], got " f"{self.presence_penalty}." 
+ ) if not -2.0 <= self.frequency_penalty <= 2.0: - raise ValueError("frequency_penalty must be in [-2, 2], got " - f"{self.frequency_penalty}.") + raise ValueError( + "frequency_penalty must be in [-2, 2], got " + f"{self.frequency_penalty}." + ) if not 0.0 < self.repetition_penalty <= 2.0: - raise ValueError("repetition_penalty must be in (0, 2], got " - f"{self.repetition_penalty}.") + raise ValueError( + "repetition_penalty must be in (0, 2], got " + f"{self.repetition_penalty}." + ) # if self.temperature < 0.0: # raise ValueError( # f"temperature must be non-negative, got {self.temperature}.") if not 0.0 < self.top_p <= 1.0: raise ValueError(f"top_p must be in (0, 1], got {self.top_p}.") if self.top_k < -1 or self.top_k == 0: - raise ValueError(f"top_k must be -1 (disable), or at least 1, " - f"got {self.top_k}.") + raise ValueError( + f"top_k must be -1 (disable), or at least 1, " f"got {self.top_k}." + ) if not 0.0 <= self.min_p <= 1.0: - raise ValueError("min_p must be in [0, 1], got " - f"{self.min_p}.") + raise ValueError("min_p must be in [0, 1], got " f"{self.min_p}.") if self.max_tokens < 1: - raise ValueError( - f"max_tokens must be at least 1, got {self.max_tokens}.") + raise ValueError(f"max_tokens must be at least 1, got {self.max_tokens}.") if self.logprobs is not None and self.logprobs < 0: - raise ValueError( - f"logprobs must be non-negative, got {self.logprobs}.") + raise ValueError(f"logprobs must be non-negative, got {self.logprobs}.") if self.prompt_logprobs is not None and self.prompt_logprobs < 0: - raise ValueError(f"prompt_logprobs must be non-negative, got " - f"{self.prompt_logprobs}.") + raise ValueError( + f"prompt_logprobs must be non-negative, got " f"{self.prompt_logprobs}." + ) def _verify_beam_search(self) -> None: if self.best_of == 1: - raise ValueError("best_of must be greater than 1 when using beam " - f"search. Got {self.best_of}.") + raise ValueError( + "best_of must be greater than 1 when using beam " + f"search. Got {self.best_of}." + ) if self.temperature > _SAMPLING_EPS: raise ValueError("temperature must be 0 when using beam search.") if self.top_p < 1.0 - _SAMPLING_EPS: @@ -218,22 +234,29 @@ def _verify_beam_search(self) -> None: if self.early_stopping not in [True, False, "never"]: raise ValueError( f"early_stopping must be True, False, or 'never', " - f"got {self.early_stopping}.") + f"got {self.early_stopping}." + ) def _verify_non_beam_search(self) -> None: if self.early_stopping is not False: - raise ValueError("early_stopping is not effective and must be " - "False when not using beam search.") - if (self.length_penalty < 1.0 - _SAMPLING_EPS - or self.length_penalty > 1.0 + _SAMPLING_EPS): + raise ValueError( + "early_stopping is not effective and must be " + "False when not using beam search." + ) + if ( + self.length_penalty < 1.0 - _SAMPLING_EPS + or self.length_penalty > 1.0 + _SAMPLING_EPS + ): raise ValueError( "length_penalty is not effective and must be the " - "default value of 1.0 when not using beam search.") + "default value of 1.0 when not using beam search." + ) def _verify_greedy_sampling(self) -> None: if self.best_of > 1: - raise ValueError("best_of must be 1 when using greedy sampling." - f"Got {self.best_of}.") + raise ValueError( + "best_of must be 1 when using greedy sampling." f"Got {self.best_of}." 
+ ) @cached_property def sampling_type(self) -> SamplingType: @@ -270,4 +293,4 @@ def __repr__(self) -> str: f"max_new_token={self.max_new_token}), " f"min_new_token={self.min_new_token}), " f"infer_text={self.infer_text})" - ) + ) diff --git a/ChatTTS/model/velocity/scheduler.py b/ChatTTS/model/velocity/scheduler.py index e93ca7fd6..4eb38d278 100644 --- a/ChatTTS/model/velocity/scheduler.py +++ b/ChatTTS/model/velocity/scheduler.py @@ -6,8 +6,13 @@ from ChatTTS.model.velocity.block_manager import AllocStatus, BlockSpaceManager from vllm.core.policy import PolicyFactory from vllm.logger import init_logger -from ChatTTS.model.velocity.sequence import (Sequence, SequenceData, SequenceGroup, - SequenceGroupMetadata, SequenceStatus) +from ChatTTS.model.velocity.sequence import ( + Sequence, + SequenceData, + SequenceGroup, + SequenceGroupMetadata, + SequenceStatus, +) logger = init_logger(__name__) @@ -21,6 +26,7 @@ class PreemptionMode(enum.Enum): recompute them when the sequences are resumed, treating the sequences as new prompts. """ + SWAP = enum.auto() RECOMPUTE = enum.auto() @@ -49,8 +55,12 @@ def __init__( def is_empty(self) -> bool: # NOTE: We do not consider the ignored sequence groups. - return (not self.scheduled_seq_groups and not self.blocks_to_swap_in - and not self.blocks_to_swap_out and not self.blocks_to_copy) + return ( + not self.scheduled_seq_groups + and not self.blocks_to_swap_in + and not self.blocks_to_swap_out + and not self.blocks_to_copy + ) class Scheduler: @@ -63,8 +73,10 @@ def __init__( self.scheduler_config = scheduler_config self.cache_config = cache_config - self.prompt_limit = min(self.scheduler_config.max_model_len, - self.scheduler_config.max_num_batched_tokens) + self.prompt_limit = min( + self.scheduler_config.max_model_len, + self.scheduler_config.max_num_batched_tokens, + ) # Instantiate the scheduling policy. self.policy = PolicyFactory.get_policy(policy_name="fcfs") @@ -73,7 +85,8 @@ def __init__( block_size=self.cache_config.block_size, num_gpu_blocks=self.cache_config.num_gpu_blocks, num_cpu_blocks=self.cache_config.num_cpu_blocks, - sliding_window=self.cache_config.sliding_window) + sliding_window=self.cache_config.sliding_window, + ) # TODO(zhuohan): Use deque instead of list for better performance. # Sequence groups in the WAITING state. @@ -89,7 +102,7 @@ def add_seq_group(self, seq_group: SequenceGroup) -> None: def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: if isinstance(request_id, str): - request_id = (request_id, ) + request_id = (request_id,) request_ids = set(request_id) for state_queue in [self.waiting, self.running, self.swapped]: # We need to reverse the list as we are removing elements @@ -129,8 +142,9 @@ def _schedule(self) -> SchedulerOutputs: scheduled: List[SequenceGroup] = [] # The total number of sequences on the fly, including the # requests in the generation phase. 
- num_curr_seqs = sum(seq_group.get_max_num_running_seqs() - for seq_group in self.running) + num_curr_seqs = sum( + seq_group.get_max_num_running_seqs() for seq_group in self.running + ) seq_lens: List[int] = [] # Optimization: We do not sort the waiting queue since the preempted @@ -139,16 +153,16 @@ def _schedule(self) -> SchedulerOutputs: while self.waiting: seq_group = self.waiting[0] - waiting_seqs = seq_group.get_seqs( - status=SequenceStatus.WAITING) + waiting_seqs = seq_group.get_seqs(status=SequenceStatus.WAITING) assert len(waiting_seqs) == 1, ( - "Waiting sequence group should have only one prompt " - "sequence.") + "Waiting sequence group should have only one prompt " "sequence." + ) num_prompt_tokens = waiting_seqs[0].get_len() if num_prompt_tokens > self.prompt_limit: logger.warning( f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds limit of {self.prompt_limit}") + f" and exceeds limit of {self.prompt_limit}" + ) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -162,7 +176,8 @@ def _schedule(self) -> SchedulerOutputs: elif can_allocate == AllocStatus.NEVER: logger.warning( f"Input prompt ({num_prompt_tokens} tokens) is too long" - f" and exceeds the capacity of block_manager") + f" and exceeds the capacity of block_manager" + ) for seq in waiting_seqs: seq.status = SequenceStatus.FINISHED_IGNORED ignored_seq_groups.append(seq_group) @@ -172,15 +187,13 @@ def _schedule(self) -> SchedulerOutputs: # If the number of batched tokens exceeds the limit, stop. new_seq_lens = seq_lens + [num_prompt_tokens] num_batched_tokens = len(new_seq_lens) * max(new_seq_lens) - if (num_batched_tokens > - self.scheduler_config.max_num_batched_tokens): + if num_batched_tokens > self.scheduler_config.max_num_batched_tokens: break # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): + if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs: break num_paddings = num_batched_tokens - sum(new_seq_lens) @@ -198,8 +211,7 @@ def _schedule(self) -> SchedulerOutputs: scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=scheduled, prompt_run=True, - num_batched_tokens=len(seq_lens) * - max(seq_lens) if seq_lens else 0, + num_batched_tokens=len(seq_lens) * max(seq_lens) if seq_lens else 0, blocks_to_swap_in=blocks_to_swap_in, blocks_to_swap_out=blocks_to_swap_out, blocks_to_copy=blocks_to_copy, @@ -239,8 +251,9 @@ def _schedule(self) -> SchedulerOutputs: # Swap in the sequence groups in the SWAPPED state if possible. self.swapped = self.policy.sort_by_priority(now, self.swapped) if not preempted: - num_curr_seqs = sum(seq_group.get_max_num_running_seqs() - for seq_group in self.running) + num_curr_seqs = sum( + seq_group.get_max_num_running_seqs() for seq_group in self.running + ) while self.swapped: seq_group = self.swapped[0] @@ -251,8 +264,7 @@ def _schedule(self) -> SchedulerOutputs: # The total number of sequences in the RUNNING state should not # exceed the maximum number of sequences. num_new_seqs = seq_group.get_max_num_running_seqs() - if (num_curr_seqs + num_new_seqs > - self.scheduler_config.max_num_seqs): + if num_curr_seqs + num_new_seqs > self.scheduler_config.max_num_seqs: break seq_group = self.swapped.pop(0) @@ -266,7 +278,8 @@ def _schedule(self) -> SchedulerOutputs: # sequences in the RUNNING state. 
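# --- illustrative sketch (editor's note, not part of the patch) ---
# The prompt-admission arithmetic used by _schedule above: because prompts are
# padded to the longest prompt in the batch, the token cost of a candidate
# batch is count * max(length), not sum(length). A waiting prompt is admitted
# only while that padded total stays within max_num_batched_tokens and the
# running-sequence count stays within max_num_seqs. (The max-padding cutoff in
# the original code is omitted here.)

from typing import List

def can_admit_prompt(
    seq_lens: List[int],          # prompt lengths already scheduled this step
    num_prompt_tokens: int,       # length of the next waiting prompt
    num_curr_seqs: int,
    num_new_seqs: int,
    max_num_batched_tokens: int,
    max_num_seqs: int,
) -> bool:
    new_seq_lens = seq_lens + [num_prompt_tokens]
    num_batched_tokens = len(new_seq_lens) * max(new_seq_lens)  # padded cost
    if num_batched_tokens > max_num_batched_tokens:
        return False
    return num_curr_seqs + num_new_seqs <= max_num_seqs

# e.g. with prompts [100, 80] already scheduled and a 300-token prompt waiting,
# the padded cost is 3 * 300 = 900 tokens, even though only 480 are real.
# --- end sketch ---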
num_batched_tokens = sum( seq_group.num_seqs(status=SequenceStatus.RUNNING) - for seq_group in self.running) + for seq_group in self.running + ) scheduler_outputs = SchedulerOutputs( scheduled_seq_groups=self.running, @@ -313,8 +326,7 @@ def free_seq(self, seq: Sequence) -> None: def free_finished_seq_groups(self) -> None: self.running = [ - seq_group for seq_group in self.running - if not seq_group.is_finished() + seq_group for seq_group in self.running if not seq_group.is_finished() ] def _allocate(self, seq_group: SequenceGroup) -> None: @@ -406,7 +418,8 @@ def _swap_out( # entire engine. raise RuntimeError( "Aborted due to the lack of CPU swap space. Please increase " - "the swap space to avoid this error.") + "the swap space to avoid this error." + ) mapping = self.block_manager.swap_out(seq_group) blocks_to_swap_out.update(mapping) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): diff --git a/ChatTTS/model/velocity/sequence.py b/ChatTTS/model/velocity/sequence.py index f5c2a09ad..4bc3f354d 100644 --- a/ChatTTS/model/velocity/sequence.py +++ b/ChatTTS/model/velocity/sequence.py @@ -1,4 +1,5 @@ """Sequence and its related classes.""" + import copy import enum from typing import Dict, List, Optional, Union @@ -12,6 +13,7 @@ class SequenceStatus(enum.Enum): """Status of a sequence.""" + WAITING = enum.auto() RUNNING = enum.auto() SWAPPED = enum.auto() @@ -69,10 +71,12 @@ def __init__( self.cumulative_logprob = 0.0 self.hidden_states: Optional[torch.Tensor] = None self.finished = False - + def append_token_id(self, token_id: int, logprob: float) -> None: if isinstance(self.cumulative_logprob, float): - self.cumulative_logprob = [0.0, ] * len(logprob) + self.cumulative_logprob = [ + 0.0, + ] * len(logprob) self.output_token_ids.append(token_id) for i in range(len(self.cumulative_logprob)): self.cumulative_logprob[i] += logprob[i] @@ -82,7 +86,7 @@ def append_hidden_states(self, hidden_states: torch.Tensor) -> None: self.hidden_states = hidden_states else: self.hidden_states = torch.cat([self.hidden_states, hidden_states], dim=0) - + def get_len(self) -> int: return len(self.output_token_ids) + len(self.prompt_token_ids) @@ -101,12 +105,14 @@ def get_last_token_id(self) -> int: return self.output_token_ids[-1] def __repr__(self) -> str: - return (f"SequenceData(" - f"prompt_token_ids={self.prompt_token_ids}, " - f"output_token_ids={self.output_token_ids}, " - f"cumulative_logprob={self.cumulative_logprob}), " - f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}, " - f"finished={self.finished})") + return ( + f"SequenceData(" + f"prompt_token_ids={self.prompt_token_ids}, " + f"output_token_ids={self.output_token_ids}, " + f"cumulative_logprob={self.cumulative_logprob}), " + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}, " + f"finished={self.finished})" + ) class Sequence: @@ -165,8 +171,7 @@ def _append_tokens_to_blocks(self, token_ids: List[int]) -> None: last_block = self.logical_token_blocks[-1] num_empty_slots = last_block.get_num_empty_slots() - last_block.append_tokens(token_ids[cursor:cursor + - num_empty_slots]) + last_block.append_tokens(token_ids[cursor : cursor + num_empty_slots]) cursor += num_empty_slots def append_token_id( @@ -174,7 +179,7 @@ def append_token_id( token_id: int, logprobs: Dict[int, float], hidden_states: Optional[torch.Tensor] = None, - finished: bool = False + finished: bool = False, ) -> None: assert token_id in logprobs self._append_tokens_to_blocks([token_id]) @@ 
-182,7 +187,7 @@ def append_token_id( self.data.append_token_id(token_id, logprobs[token_id]) self.data.append_hidden_states(hidden_states) self.data.finished = finished - + def get_len(self) -> int: return self.data.get_len() @@ -204,10 +209,12 @@ def get_output_token_ids(self) -> List[int]: def get_cumulative_logprob(self) -> float: return self.data.cumulative_logprob - def get_beam_search_score(self, - length_penalty: float = 0.0, - seq_len: Optional[int] = None, - eos_token_id: Optional[int] = None) -> float: + def get_beam_search_score( + self, + length_penalty: float = 0.0, + seq_len: Optional[int] = None, + eos_token_id: Optional[int] = None, + ) -> float: """Calculate the beam search score with length penalty. Adapted from @@ -218,8 +225,7 @@ def get_beam_search_score(self, seq_len = self.get_len() # NOTE: HF implementation does not count the EOS token # towards the length, we align with that here for testing. - if (eos_token_id is not None - and self.get_last_token_id() == eos_token_id): + if eos_token_id is not None and self.get_last_token_id() == eos_token_id: seq_len -= 1 return self.get_cumulative_logprob() / (seq_len**length_penalty) @@ -232,9 +238,11 @@ def fork(self, new_seq_id: int) -> "Sequence": return new_seq def __repr__(self) -> str: - return (f"Sequence(seq_id={self.seq_id}, " - f"status={self.status.name}, " - f"num_blocks={len(self.logical_token_blocks)})") + return ( + f"Sequence(seq_id={self.seq_id}, " + f"status={self.status.name}, " + f"num_blocks={len(self.logical_token_blocks)})" + ) class SequenceGroup: @@ -296,14 +304,10 @@ def get_seqs( if status is None: return list(self.seqs_dict.values()) else: - return [ - seq for seq in self.seqs_dict.values() if seq.status == status - ] + return [seq for seq in self.seqs_dict.values() if seq.status == status] def get_unfinished_seqs(self) -> List[Sequence]: - return [ - seq for seq in self.seqs_dict.values() if not seq.is_finished() - ] + return [seq for seq in self.seqs_dict.values() if not seq.is_finished()] def get_finished_seqs(self) -> List[Sequence]: return [seq for seq in self.seqs_dict.values() if seq.is_finished()] @@ -336,9 +340,11 @@ def is_finished(self) -> bool: return all(seq.is_finished() for seq in self.get_seqs()) def __repr__(self) -> str: - return (f"SequenceGroup(request_id={self.request_id}, " - f"sampling_params={self.sampling_params}, " - f"num_seqs={len(self.seqs_dict)})") + return ( + f"SequenceGroup(request_id={self.request_id}, " + f"sampling_params={self.sampling_params}, " + f"num_seqs={len(self.seqs_dict)})" + ) class SequenceGroupMetadata: @@ -386,27 +392,31 @@ def __init__( output_token: int, logprobs: Dict[int, float], hidden_states: Optional[torch.Tensor] = None, - finished: bool = False + finished: bool = False, ) -> None: self.parent_seq_id = parent_seq_id self.output_token = output_token self.logprobs = logprobs self.finished = finished self.hidden_states = hidden_states + def __repr__(self) -> str: - return (f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " - f"output_token={self.output_token}, " - f"logprobs={self.logprobs})," - f"finished={self.finished})," - f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}" - ) + return ( + f"SequenceOutput(parent_seq_id={self.parent_seq_id}, " + f"output_token={self.output_token}, " + f"logprobs={self.logprobs})," + f"finished={self.finished})," + f"hidden_states={self.hidden_states.shape if self.hidden_states is not None else None}" + ) def __eq__(self, other: object) -> bool: if not 
isinstance(other, SequenceOutput): raise NotImplementedError() - return (self.parent_seq_id == other.parent_seq_id - and self.output_token == other.output_token - and self.logprobs == other.logprobs) + return ( + self.parent_seq_id == other.parent_seq_id + and self.output_token == other.output_token + and self.logprobs == other.logprobs + ) class SequenceGroupOutput: @@ -421,14 +431,18 @@ def __init__( self.prompt_logprobs = prompt_logprobs def __repr__(self) -> str: - return (f"SequenceGroupOutput(samples={self.samples}, " - f"prompt_logprobs={self.prompt_logprobs})") + return ( + f"SequenceGroupOutput(samples={self.samples}, " + f"prompt_logprobs={self.prompt_logprobs})" + ) def __eq__(self, other: object) -> bool: if not isinstance(other, SequenceGroupOutput): raise NotImplementedError() - return (self.samples == other.samples - and self.prompt_logprobs == other.prompt_logprobs) + return ( + self.samples == other.samples + and self.prompt_logprobs == other.prompt_logprobs + ) # For each sequence group, we generate a list of SequenceOutput object, diff --git a/ChatTTS/model/velocity/worker.py b/ChatTTS/model/velocity/worker.py index 0162302bf..9578551d9 100644 --- a/ChatTTS/model/velocity/worker.py +++ b/ChatTTS/model/velocity/worker.py @@ -1,17 +1,15 @@ """A GPU worker class.""" + import os from typing import Dict, List, Optional, Tuple import torch import torch.distributed -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig from vllm.model_executor import set_random_seed -from vllm.model_executor.parallel_utils.communication_op import ( - broadcast_object_list) -from vllm.model_executor.parallel_utils.parallel_state import ( - initialize_model_parallel) +from vllm.model_executor.parallel_utils.communication_op import broadcast_object_list +from vllm.model_executor.parallel_utils.parallel_state import initialize_model_parallel from vllm.sequence import SamplerOutput, SequenceGroupMetadata from vllm.worker.cache_engine import CacheEngine from ChatTTS.model.velocity.model_runner import ModelRunner @@ -33,7 +31,7 @@ def __init__( local_rank: int, rank: int, distributed_init_method: str, - post_model_path:str, + post_model_path: str, is_driver_worker: bool = False, ) -> None: self.model_config = model_config @@ -44,12 +42,17 @@ def __init__( self.distributed_init_method = distributed_init_method self.is_driver_worker = is_driver_worker self.post_model_path = post_model_path - + if self.is_driver_worker: assert self.rank == 0, "The driver worker must have rank 0." - self.model_runner = ModelRunner(model_config, parallel_config, - scheduler_config, is_driver_worker, post_model_path) + self.model_runner = ModelRunner( + model_config, + parallel_config, + scheduler_config, + is_driver_worker, + post_model_path, + ) # Uninitialized cache engine. Will be initialized by # self.init_cache_engine(). self.cache_config = None @@ -74,8 +77,9 @@ def init_model(self) -> None: _check_if_gpu_supports_dtype(self.model_config.dtype) # Initialize the distributed environment. - _init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method) + _init_distributed_environment( + self.parallel_config, self.rank, self.distributed_init_method + ) # Initialize the model. 
set_random_seed(self.model_config.seed) @@ -105,10 +109,12 @@ def profile_num_available_blocks( peak_memory = total_gpu_memory - free_gpu_memory cache_block_size = CacheEngine.get_cache_block_size( - block_size, self.model_config, self.parallel_config) + block_size, self.model_config, self.parallel_config + ) num_gpu_blocks = int( - (total_gpu_memory * gpu_memory_utilization - peak_memory) // - cache_block_size) + (total_gpu_memory * gpu_memory_utilization - peak_memory) + // cache_block_size + ) num_cpu_blocks = int(cpu_swap_space // cache_block_size) num_gpu_blocks = max(num_gpu_blocks, 0) num_cpu_blocks = max(num_cpu_blocks, 0) @@ -117,8 +123,9 @@ def profile_num_available_blocks( def init_cache_engine(self, cache_config: CacheConfig) -> None: self.cache_config = cache_config - self.cache_engine = CacheEngine(self.cache_config, self.model_config, - self.parallel_config) + self.cache_engine = CacheEngine( + self.cache_config, self.model_config, self.parallel_config + ) self.cache_events = self.cache_engine.events self.gpu_cache = self.cache_engine.gpu_cache self.model_runner.set_block_size(self.cache_engine.block_size) @@ -171,10 +178,11 @@ def execute_model( assert blocks_to_swap_out is not None assert blocks_to_copy is not None block_swapping_info = [ - blocks_to_swap_in, blocks_to_swap_out, blocks_to_copy + blocks_to_swap_in, + blocks_to_swap_out, + blocks_to_copy, ] - broadcast_object_list([num_seq_groups] + block_swapping_info, - src=0) + broadcast_object_list([num_seq_groups] + block_swapping_info, src=0) else: # num_seq_groups, blocks_to_swap_in, blocks_to_swap_out, # blocks_to_copy (4 elements) @@ -189,8 +197,9 @@ def execute_model( if num_seq_groups == 0: return {} - output = self.model_runner.execute_model(seq_group_metadata_list, - self.gpu_cache) + output = self.model_runner.execute_model( + seq_group_metadata_list, self.gpu_cache + ) return output @@ -206,11 +215,13 @@ def _init_distributed_environment( raise RuntimeError( "torch.distributed is already initialized but the torch world " "size does not match parallel_config.world_size " - f"({torch_world_size} vs. {parallel_config.world_size}).") + f"({torch_world_size} vs. {parallel_config.world_size})." + ) elif not distributed_init_method: raise ValueError( "distributed_init_method must be set if torch.distributed " - "is not already initialized") + "is not already initialized" + ) else: torch.distributed.init_process_group( backend="nccl", @@ -221,8 +232,9 @@ def _init_distributed_environment( # A small all_reduce for warmup. torch.distributed.all_reduce(torch.zeros(1).cuda()) - initialize_model_parallel(parallel_config.tensor_parallel_size, - parallel_config.pipeline_parallel_size) + initialize_model_parallel( + parallel_config.tensor_parallel_size, parallel_config.pipeline_parallel_size + ) def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): @@ -234,4 +246,5 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): raise ValueError( "Bfloat16 is only supported on GPUs with compute capability " f"of at least 8.0. Your {gpu_name} GPU has compute capability " - f"{compute_capability[0]}.{compute_capability[1]}.") + f"{compute_capability[0]}.{compute_capability[1]}." 
+ ) diff --git a/test.py b/test.py index 07e9fae5d..2690b28fd 100644 --- a/test.py +++ b/test.py @@ -2,26 +2,28 @@ import torch import torchaudio import soundfile as sf + chat = ChatTTS.Chat() -chat.load(compile=False) # Set to True for better performance +chat.load(compile=False) # Set to True for better performance rand_spk = chat.sample_random_speaker() -print(rand_spk) # save it for later timbre recovery +print(rand_spk) # save it for later timbre recovery params_infer_code = ChatTTS.Chat.InferCodeParams( - spk_emb = rand_spk, # add sampled speaker - temperature = .3, # using custom temperature - top_P = 0.7, # top P decode - top_K = 20, # top K decode + spk_emb=rand_spk, # add sampled speaker + temperature=0.3, # using custom temperature + top_P=0.7, # top P decode + top_K=20, # top K decode ) params_refine_text = ChatTTS.Chat.RefineTextParams( - prompt='[oral_2][laugh_0][break_6]', + prompt="[oral_2][laugh_0][break_6]", ) texts = ["PUT YOUR 1st TEXT HERE", "PUT YOUR 2nd TEXT HERE"] -wavs = chat.infer(texts, +wavs = chat.infer( + texts, params_refine_text=params_refine_text, params_infer_code=params_infer_code, - ) +) # torchaudio.save("output1.wav", torch.from_numpy(wavs[0]), 24000) -sf.write("output1.wav", wavs[1], 24000) \ No newline at end of file +sf.write("output1.wav", wavs[1], 24000)
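
For context only (not part of the patch above): a minimal usage sketch that saves every waveform returned by chat.infer, rather than only wavs[1] as the reformatted test.py does. It assumes, as test.py suggests, that chat.infer returns one 24 kHz NumPy array per input text; the output filenames and the reuse of a single sampled speaker are illustrative choices, not something this diff introduces.

# Sketch only -- assumes chat.infer() yields one 24 kHz NumPy waveform
# per input text, as test.py above implies.
import ChatTTS
import soundfile as sf

chat = ChatTTS.Chat()
chat.load(compile=False)

params_infer_code = ChatTTS.Chat.InferCodeParams(
    spk_emb=chat.sample_random_speaker(),  # sample one timbre and reuse it
    temperature=0.3,
    top_P=0.7,
    top_K=20,
)

texts = ["PUT YOUR 1st TEXT HERE", "PUT YOUR 2nd TEXT HERE"]
wavs = chat.infer(texts, params_infer_code=params_infer_code)

# Write every generated waveform to its own file at 24 kHz.
for i, wav in enumerate(wavs):
    sf.write(f"output{i}.wav", wav, 24000)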