Skip to content

Commit

Permalink
chore(format): run black on dev (#627)
Browse files Browse the repository at this point in the history
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
  • Loading branch information
github-actions[bot] and github-actions[bot] authored Jul 24, 2024
1 parent a21bafa commit 4991dfd
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 9 deletions.
11 changes: 8 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,7 +549,9 @@ def _infer_code(

return [
GPT.GenerationOutputs(
ids=token_ids, hiddens=hidden_states, attentions=[],
ids=token_ids,
hiddens=hidden_states,
attentions=[],
),
]

Expand Down Expand Up @@ -640,8 +642,11 @@ def _refine_text(
del_all(logits_warpers)
del_all(logits_processors)

return GPT.GenerationOutputs(ids=token_ids, hiddens=hidden_states, attentions=[],
)
return GPT.GenerationOutputs(
ids=token_ids,
hiddens=hidden_states,
attentions=[],
)

emb = gpt(input_ids, text_mask)

Expand Down
14 changes: 8 additions & 6 deletions ChatTTS/model/velocity/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -782,14 +782,16 @@ def _make_tensor_with_pad(
padded_x = []
for x_i in x:
pad_i = pad
if isinstance(x[0][0],tuple):
if isinstance(x[0][0], tuple):
pad_i = (0,) * len(x[0][0])
padded_x.append(_pad_to_max(x_i, max_len, pad_i))

return torch.tensor(padded_x,
dtype=dtype,
device=device,
pin_memory=pin_memory and str(device) == "cpu")

return torch.tensor(
padded_x,
dtype=dtype,
device=device,
pin_memory=pin_memory and str(device) == "cpu",
)


def _get_graph_batch_size(batch_size: int) -> int:
Expand Down

0 comments on commit 4991dfd

Please sign in to comment.