chore(format): run black on dev (#643)
Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
github-actions[bot] authored Jul 29, 2024
1 parent 8f49418 commit e4cd66e
Showing 2 changed files with 46 additions and 42 deletions.
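
This commit is mechanical: it applies black's default formatting to the two files below. As a minimal sketch using black's Python API (assuming the black package is installed; the repository's actual CI invocation is not shown on this page), the first one-line change in exporter.py can be reproduced like so:

import black

# black removes the stray space before the colon in the lambda.
src = "torch.cuda.is_available = lambda : False\n"
formatted = black.format_str(src, mode=black.Mode())
print(formatted, end="")  # torch.cuda.is_available = lambda: False

Running the black command-line tool over the files produces the same result as this API call.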
83 changes: 42 additions & 41 deletions examples/onnx/exporter.py
@@ -11,25 +11,27 @@
from examples.onnx.gpt import GPT

# disable cuda
torch.cuda.is_available = lambda : False
torch.cuda.is_available = lambda: False

# add args to control which modules to export
parser = argparse.ArgumentParser()
parser.add_argument("--gpt", action="store_true", help="trace gpt")
parser.add_argument("--decoder", action="store_true", help="trace decoder")
parser.add_argument("--vocos", action="store_true", help="trace vocos")
parser.add_argument("--pth_dir", default="./assets", type=str, help="path to the pth model directory")
parser.add_argument("--out_dir", default="./tmp", type=str, help="path to output directory")
parser.add_argument(
"--pth_dir", default="./assets", type=str, help="path to the pth model directory"
)
parser.add_argument(
"--out_dir", default="./tmp", type=str, help="path to output directory"
)

args = parser.parse_args()
chattts_config = Config()


def export_gpt():
gpt_model = GPT(
gpt_config=asdict(chattts_config.gpt),
use_flash_attn=False
).eval()
gpt_model.from_pretrained(asdict(chattts_config.path)['gpt_ckpt_path'])
gpt_model = GPT(gpt_config=asdict(chattts_config.gpt), use_flash_attn=False).eval()
gpt_model.from_pretrained(asdict(chattts_config.path)["gpt_ckpt_path"])
gpt_model = gpt_model.eval()
for param in gpt_model.parameters():
param.requires_grad = False
@@ -49,13 +51,13 @@ def export_gpt():

folder = os.path.join(args.out_dir, "gpt")
os.makedirs(folder, exist_ok=True)

for param in gpt_model.emb_text.parameters():
param.requires_grad = False

for param in gpt_model.emb_code.parameters():
param.requires_grad = False

for param in gpt_model.head_code.parameters():
param.requires_grad = False

@@ -68,7 +70,7 @@ def __init__(self, *args, **kwargs) -> None:

def forward(self, input_ids):
return gpt_model.emb_text(input_ids)

def convert_embedding_text():
model = EmbeddingText()
input_ids = torch.tensor([range(SEQ_LENGTH)])
@@ -84,7 +86,6 @@ def convert_embedding_text():
opset_version=15,
)


class EmbeddingCode(torch.nn.Module):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@@ -94,11 +95,11 @@ def forward(self, input_ids):
-1, -1, gpt_model.num_vq
) # for forward_first_code
code_emb = [
gpt_model.emb_code[i](input_ids[:, :, i]) for i in range(gpt_model.num_vq)
gpt_model.emb_code[i](input_ids[:, :, i])
for i in range(gpt_model.num_vq)
]
return torch.stack(code_emb, 2).sum(2)


def convert_embedding_code():
model = EmbeddingCode()
input_ids = torch.tensor([range(SEQ_LENGTH)])
@@ -114,18 +115,17 @@ def convert_embedding_code():
opset_version=15,
)


class EmbeddingCodeCache(torch.nn.Module): # for forward_next_code
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)

def forward(self, input_ids):
code_emb = [
gpt_model.emb_code[i](input_ids[:, :, i]) for i in range(gpt_model.num_vq)
gpt_model.emb_code[i](input_ids[:, :, i])
for i in range(gpt_model.num_vq)
]
return torch.stack(code_emb, 2).sum(2)


def convert_embedding_code_cache():
model = EmbeddingCodeCache()
input_ids = torch.tensor(
@@ -142,7 +142,6 @@ def convert_embedding_code_cache():
opset_version=15,
)


class Block(torch.nn.Module):
def __init__(self, layer_id):
super().__init__()
@@ -162,7 +161,6 @@ def forward(self, hidden_states, position_ids, attention_mask):
hidden_states = self.norm(hidden_states)
return hidden_states, present_k, present_v


def convert_block(layer_id):
model = Block(layer_id)
hidden_states = torch.randn((1, SEQ_LENGTH, HIDDEN_SIZE))
@@ -182,7 +180,6 @@ def convert_block(layer_id):
opset_version=15,
)


class BlockCache(torch.nn.Module):

def __init__(self, layer_id):
@@ -204,7 +201,6 @@ def forward(self, hidden_states, position_ids, attention_mask, past_k, past_v):
hidden_states = self.norm(hidden_states)
return hidden_states, present_k, present_v


def convert_block_cache(layer_id):
model = BlockCache(layer_id)
hidden_states = torch.randn((1, 1, HIDDEN_SIZE))
@@ -232,7 +228,6 @@ def convert_block_cache(layer_id):
opset_version=15,
)


class GreedyHead(torch.nn.Module):

def __init__(self):
@@ -242,7 +237,6 @@ def forward(self, m_logits):
_, token = torch.topk(m_logits.float(), 1)
return token


def convert_greedy_head_text():
model = GreedyHead()
m_logits = torch.randn(1, TEXT_VOCAB_SIZE)
@@ -258,7 +252,6 @@ def convert_greedy_head_text():
opset_version=15,
)


def convert_greedy_head_code():
model = GreedyHead()
m_logits = torch.randn(1, AUDIO_VOCAB_SIZE, gpt_model.num_vq)
@@ -282,18 +275,20 @@ def forward(self, hidden_states):
m_logits = gpt_model.head_text(hidden_states)
return m_logits


class LmHead_infer_code(torch.nn.Module):
def __init__(self):
super().__init__()

def forward(self, hidden_states):
m_logits = torch.stack(
[gpt_model.head_code[i](hidden_states) for i in range(gpt_model.num_vq)], 2
[
gpt_model.head_code[i](hidden_states)
for i in range(gpt_model.num_vq)
],
2,
)
return m_logits


def convert_lm_head_text():
model = LmHead_infer_text()
input = torch.randn(1, HIDDEN_SIZE)
@@ -309,7 +304,6 @@ def convert_lm_head_text():
opset_version=15,
)


def convert_lm_head_code():
model = LmHead_infer_code()
input = torch.randn(1, HIDDEN_SIZE)
@@ -343,23 +337,26 @@ def convert_lm_head_code():
convert_greedy_head_text()
convert_greedy_head_code()


def export_decoder():
decoder = (
DVAE(
decoder_config=asdict(chattts_config.decoder),
dim=chattts_config.decoder.idim,
).eval()
)
decoder = DVAE(
decoder_config=asdict(chattts_config.decoder),
dim=chattts_config.decoder.idim,
).eval()
decoder.load_state_dict(
torch.load(asdict(chattts_config.path)['decoder_ckpt_path'], weights_only=True, mmap=True)
torch.load(
asdict(chattts_config.path)["decoder_ckpt_path"],
weights_only=True,
mmap=True,
)
)

for param in decoder.parameters():
param.requires_grad = False
rand_input = torch.rand([1, 768, 1024], requires_grad=False)

def mydec(_inp):
return decoder(_inp, mode='decode')
return decoder(_inp, mode="decode")

jitmodel = jit.trace(mydec, [rand_input])
jit.save(jitmodel, f"{args.out_dir}/decoder_jit.pt")
@@ -371,11 +368,15 @@ def export_vocos():
)
backbone = instantiate_class(args=(), init=asdict(chattts_config.vocos.backbone))
head = instantiate_class(args=(), init=asdict(chattts_config.vocos.head))
vocos = (
Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head).eval()
vocos = Vocos(
feature_extractor=feature_extractor, backbone=backbone, head=head
).eval()
vocos.load_state_dict(
torch.load(
asdict(chattts_config.path)["vocos_ckpt_path"], weights_only=True, mmap=True
)
)
vocos.load_state_dict(torch.load(asdict(chattts_config.path)['vocos_ckpt_path'], weights_only=True, mmap=True))


for param in vocos.parameters():
param.requires_grad = False
rand_input = torch.rand([1, 100, 2048], requires_grad=False)
@@ -414,4 +415,4 @@ def myvocos(_inp):
if args.vocos:
export_vocos()

print("Done. Please check the files in", args.out_dir)
print("Done. Please check the files in", args.out_dir)
5 changes: 4 additions & 1 deletion examples/onnx/gpt.py
@@ -6,6 +6,7 @@
from torch.nn.utils.parametrizations import weight_norm
from .modeling_llama import LlamaModel, LlamaConfig


class GPT(nn.Module):
def __init__(
self,
@@ -71,7 +72,9 @@ def __init__(
)

def from_pretrained(self, file_path: str):
self.load_state_dict(torch.load(file_path, weights_only=True, mmap=True), strict=False)
self.load_state_dict(
torch.load(file_path, weights_only=True, mmap=True), strict=False
)

def _build_llama(
self,
