chore(format): run black on dev (#614)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
Authored by github-actions[bot] on Jul 21, 2024
1 parent: c817ae9 · commit: 776f2c4
Showing 15 changed files with 1,197 additions and 968 deletions.
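All changes below are mechanical: the tree was reformatted with the Black code formatter, so only keyword-argument spacing, line wrapping, and call layout change, never behavior. As a rough illustration only (the exact command and configuration used by the project's formatting workflow are not shown on this page), the same normalization can be reproduced with Black's Python API. The SamplingParams snippet is a hypothetical input written in the pre-format style visible on the removed ("-") side of the diff:

import black

# Hypothetical snippet mimicking the pre-Black style on the removed side of the diff.
src = 'sample_params = SamplingParams(max_tokens = 8192, eos_token = num_code, start_idx=start_idx)\n'

# Black strips the spaces around keyword "=" signs and, because the line exceeds
# the default 88-character limit, wraps the call across multiple lines.
print(black.format_str(src, mode=black.Mode()))

The command-line equivalent is running python -m black . from the repository root, which is the usual way such automated formatting commits are produced.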
97 changes: 51 additions & 46 deletions ChatTTS/core.py
@@ -291,33 +291,44 @@ def _load(
self.dvae = dvae
self.logger.log(logging.INFO, "dvae loaded.")


if gpt_config_path:
cfg = OmegaConf.load(gpt_config_path)
self.num_vq = 4
if not os.path.exists("asset/vllm_model"):
gpt = GPT(
- **cfg, use_flash_attn=use_flash_attn, device=device, logger=self.logger
+ **cfg,
+ use_flash_attn=use_flash_attn,
+ device=device,
+ logger=self.logger,
).eval()
assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
- gpt.load_state_dict(torch.load(gpt_ckpt_path, weights_only=True, mmap=True))
+ gpt.load_state_dict(
+     torch.load(gpt_ckpt_path, weights_only=True, mmap=True)
+ )
gpt.prepare(compile=compile and "cuda" in str(device))
self.gpt = gpt
pathlib.Path("asset/vllm_model").mkdir(parents=True, exist_ok=True)
self.gpt.gpt.save_pretrained("asset/vllm_model/gpt")
- self.post_model = Post_model(
-     cfg.gpt_config.hidden_size,
-     cfg.num_audio_tokens,
-     cfg.num_text_tokens,
-     device = device
- ).to(device).eval()

+ self.post_model = (
+     Post_model(
+         cfg.gpt_config.hidden_size,
+         cfg.num_audio_tokens,
+         cfg.num_text_tokens,
+         device=device,
+     )
+     .to(device)
+     .eval()
+ )

self.post_model.emb_code = self.gpt.emb_code
self.post_model.emb_text = self.gpt.emb_text
self.post_model.head_text = self.gpt.head_text
self.post_model.head_code = self.gpt.head_code
- save_file(self.post_model.state_dict(), "asset/vllm_model/post_model.safetensors")

+ save_file(
+     self.post_model.state_dict(),
+     "asset/vllm_model/post_model.safetensors",
+ )

self.num_audio_tokens = cfg.num_audio_tokens
spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
assert os.path.exists(
@@ -331,15 +342,15 @@ def _load(
)
self.std, self.mean = spk_stat.requires_grad_(False).chunk(2)
self.logger.log(logging.INFO, "gpt loaded.")

self.hidden_size = cfg.gpt_config.hidden_size
self.gpt = LLM(
model="asset/vllm_model/gpt",
- num_audio_tokens = cfg.num_audio_tokens,
- num_text_tokens = cfg.num_text_tokens,
+ num_audio_tokens=cfg.num_audio_tokens,
+ num_text_tokens=cfg.num_text_tokens,
post_model_path="asset/vllm_model/post_model.safetensors",
)

if dvae_config_path:
cfg = OmegaConf.load(dvae_config_path)
dvae = DVAE(**cfg, coef=coef).to(device).eval()
@@ -369,7 +380,7 @@ def _load(
self.coef = coef

return self.has_loaded()

def _infer(
self,
text,
@@ -506,7 +517,7 @@ def destroy(self):
del_all(self.ids)
# del_all(self.attentions)
# del_all(self.hiddens)

@torch.no_grad()
def _infer_code(
self,
@@ -548,15 +559,15 @@ def _infer_code(
text = [f"[Stts][spk_emb]{txt_smp}{i}[Ptts]" for i in text]
else:
text = [f"[Stts][empty_spk]{txt_smp}{i}[Ptts]" for i in text]

input_ids, attention_mask, text_mask = self.tokenizer.encode(
text,
self.num_vq,
prompt_str=params.spk_smp,
device=self.device,
)
start_idx = input_ids.shape[-2]

num_code = self.num_audio_tokens - 1

logits_warpers, logits_processors = gen_logits(
@@ -565,34 +576,35 @@ def _infer_code(
top_K=params.top_K,
repetition_penalty=params.repetition_penalty,
)

sample_params = SamplingParams(
temperature=temperature,
max_new_token=params.max_new_token,
- max_tokens = 8192,
+ max_tokens=8192,
min_new_token=params.min_new_token,
logits_processors=(logits_warpers, logits_processors),
- eos_token = num_code,
+ eos_token=num_code,
infer_text=False,
- start_idx=start_idx
+ start_idx=start_idx,
)
input_ids = [i.tolist() for i in input_ids]

result = gpt.generate(
None,
sample_params,
input_ids,
)

token_ids = []
hidden_states = []
for i in result:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
- hidden_states.append(i.outputs[0].hidden_states.to(torch.float32).to(self.device))
- return [self.GenerationOutputs(
-     ids=token_ids,
-     hiddens=hidden_states
- ),]
+ hidden_states.append(
+     i.outputs[0].hidden_states.to(torch.float32).to(self.device)
+ )
+ return [
+     self.GenerationOutputs(ids=token_ids, hiddens=hidden_states),
+ ]

@torch.no_grad()
def _refine_text(
@@ -602,7 +614,7 @@ def _refine_text(
params: RefineTextParams,
):

- gpt:LLM = self.gpt
+ gpt: LLM = self.gpt

if not isinstance(text, list):
text = [text]
@@ -614,7 +626,7 @@ def _refine_text(
self.num_vq,
device=self.device,
)

start_idx = input_ids.shape[-2]
# print(start_idx)
logits_warpers, logits_processors = gen_logits(
@@ -627,26 +639,19 @@ def _refine_text(
sample_params = SamplingParams(
temperature=params.temperature,
max_new_token=params.max_new_token,
- max_tokens = 8192,
+ max_tokens=8192,
min_new_token=params.min_new_token,
logits_processors=(logits_warpers, logits_processors),
- eos_token = self.tokenizer.eos_token,
+ eos_token=self.tokenizer.eos_token,
infer_text=True,
- start_idx=start_idx
+ start_idx=start_idx,
)
input_ids = [i.tolist() for i in input_ids]

- result = gpt.generate(
-     None,
-     sample_params,
-     input_ids
- )

+ result = gpt.generate(None, sample_params, input_ids)
token_ids = []
hidden_states = []
for i in result:
token_ids.append(torch.tensor(i.outputs[0].token_ids))
hidden_states.append(i.outputs[0].hidden_states)
- return self.GenerationOutputs(
-     ids=token_ids,
-     hiddens=hidden_states
- )
+ return self.GenerationOutputs(ids=token_ids, hiddens=hidden_states)
43 changes: 23 additions & 20 deletions ChatTTS/model/velocity/block_manager.py
@@ -1,4 +1,5 @@
"""A block manager that manages token blocks."""

import enum
from typing import Dict, List, Optional, Set, Tuple

@@ -31,9 +32,9 @@ def __init__(
# Initialize the free blocks.
self.free_blocks: BlockTable = []
for i in range(num_blocks):
- block = PhysicalTokenBlock(device=device,
-     block_number=i,
-     block_size=block_size)
+ block = PhysicalTokenBlock(
+     device=device, block_number=i, block_size=block_size
+ )
self.free_blocks.append(block)

def allocate(self) -> PhysicalTokenBlock:
@@ -63,6 +64,7 @@ class AllocStatus(enum.Enum):
3. Never: seq_group can never be allocated.
The seq_group is too large to be allocated in GPU.
"""

OK = enum.auto()
LATER = enum.auto()
NEVER = enum.auto()
@@ -85,18 +87,15 @@ def __init__(

self.block_sliding_window = None
if sliding_window is not None:
- assert sliding_window % block_size == 0, (sliding_window,
-     block_size)
+ assert sliding_window % block_size == 0, (sliding_window, block_size)
self.block_sliding_window = sliding_window // block_size

self.watermark = watermark
assert watermark >= 0.0

self.watermark_blocks = int(watermark * num_gpu_blocks)
- self.gpu_allocator = BlockAllocator(Device.GPU, block_size,
-     num_gpu_blocks)
- self.cpu_allocator = BlockAllocator(Device.CPU, block_size,
-     num_cpu_blocks)
+ self.gpu_allocator = BlockAllocator(Device.GPU, block_size, num_gpu_blocks)
+ self.cpu_allocator = BlockAllocator(Device.CPU, block_size, num_cpu_blocks)
# Mapping: seq_id -> BlockTable.
self.block_tables: Dict[int, BlockTable] = {}

@@ -106,13 +105,11 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus:
seq = seq_group.get_seqs(status=SequenceStatus.WAITING)[0]
num_required_blocks = len(seq.logical_token_blocks)
if self.block_sliding_window is not None:
- num_required_blocks = min(num_required_blocks,
-     self.block_sliding_window)
+ num_required_blocks = min(num_required_blocks, self.block_sliding_window)
num_free_gpu_blocks = self.gpu_allocator.get_num_free_blocks()

# Use watermark to avoid frequent cache eviction.
- if (self.num_total_gpu_blocks - num_required_blocks <
-     self.watermark_blocks):
+ if self.num_total_gpu_blocks - num_required_blocks < self.watermark_blocks:
return AllocStatus.NEVER
if num_free_gpu_blocks - num_required_blocks >= self.watermark_blocks:
return AllocStatus.OK
@@ -127,8 +124,10 @@ def allocate(self, seq_group: SequenceGroup) -> None:
# Allocate new physical token blocks that will store the prompt tokens.
block_table: BlockTable = []
for logical_idx in range(len(seq.logical_token_blocks)):
- if (self.block_sliding_window is not None
-     and logical_idx >= self.block_sliding_window):
+ if (
+     self.block_sliding_window is not None
+     and logical_idx >= self.block_sliding_window
+ ):
block = block_table[logical_idx % self.block_sliding_window]
else:
block = self.gpu_allocator.allocate()
@@ -153,11 +152,14 @@ def append_slot(self, seq: Sequence) -> Optional[Tuple[int, int]]:
block_table = self.block_tables[seq.seq_id]

if len(block_table) < len(logical_blocks):
- if (self.block_sliding_window
-     and len(block_table) >= self.block_sliding_window):
+ if (
+     self.block_sliding_window
+     and len(block_table) >= self.block_sliding_window
+ ):
# re-use a block
- block_table.append(block_table[len(block_table) %
-     self.block_sliding_window])
+ block_table.append(
+     block_table[len(block_table) % self.block_sliding_window]
+ )
else:
# The sequence has a new logical block.
# Allocate a new physical block.
@@ -188,7 +190,8 @@ def fork(self, parent_seq: Sequence, child_seq: Sequence) -> None:
block.ref_count += 1

def _get_physical_blocks(
- self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
+ self, seq_group: SequenceGroup
+ ) -> List[PhysicalTokenBlock]:
# NOTE: Here, we assume that the physical blocks are only shared by
# the sequences in the same group.
blocks: Set[PhysicalTokenBlock] = set()