From c1d31ea8fcdc655dce53b11fd955c82a9c1abbd8 Mon Sep 17 00:00:00 2001
From: Fangjun Kuang
Date: Wed, 20 Sep 2023 19:12:37 +0800
Subject: [PATCH] Reduce whisper decoder file size with onnx export

---
 scripts/whisper/export-onnx.py | 38 +++++++++++++++++++++++++++++-----
 1 file changed, 33 insertions(+), 5 deletions(-)

diff --git a/scripts/whisper/export-onnx.py b/scripts/whisper/export-onnx.py
index fbb0b132b..46594d12f 100755
--- a/scripts/whisper/export-onnx.py
+++ b/scripts/whisper/export-onnx.py
@@ -200,10 +200,25 @@ def forward(
 
         x = self.textDecoder.ln(x)
 
-        logits = (
-            x
-            @ torch.transpose(self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1)
-        ).float()
+        if False:
+            # x.shape (1, 3, 384)
+            # weight.shape (51684, 384)
+
+            logits = (
+                x
+                @ torch.transpose(
+                    self.textDecoder.token_embedding.weight.to(x.dtype), 0, 1
+                )
+            ).float()
+        else:
+            logits = (
+                torch.matmul(
+                    self.textDecoder.token_embedding.weight.to(x.dtype),
+                    x.permute(0, 2, 1),
+                )
+                .permute(0, 2, 1)
+                .float()
+            )
 
         return logits, n_layer_self_k_cache, n_layer_self_v_cache
 
@@ -246,6 +261,19 @@ def main():
     opset_version = 13
 
     model = whisper.load_model(name)
+    print(
+        f"number of model parameters: {name}",
+        sum(p.numel() for p in model.parameters()),
+    )
+    print(
+        f"number of encoder parameters: {name}",
+        sum(p.numel() for p in model.encoder.parameters()),
+    )
+    print(
+        f"number of decoder parameters: {name}",
+        sum(p.numel() for p in model.decoder.parameters()),
+    )
+
     convert_tokens(name=name, model=model)
 
     # write tokens
@@ -419,7 +447,7 @@ def main():
             },
         )
 
-    if 'large' in args.model:
+    if "large" in args.model:
         # it causes errors for large models, so skip it.
         return
     # Generate int8 quantization models
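
Note (editorial, not part of the patch): the size win appears to come from how
torch.onnx.export's default constant folding treats the final projection. In
the old form, x @ transpose(token_embedding.weight, 0, 1), the folded
transpose is materialized as a second initializer, so the
(vocab_size, n_state) matrix is stored in the ONNX file twice: once for the
embedding Gather and once, transposed, for the logits MatMul. With the shapes
in the patch comment (51684 x 384, float32), that duplicate alone is roughly
80 MB. The rewritten matmul(weight, x.permute(0, 2, 1)).permute(0, 2, 1)
feeds the original weight initializer directly into MatMul, so the Gather and
the MatMul share one copy. The following standalone sketch reproduces the
effect on a toy tied-embedding head; all names in it are illustrative, not
from the repo:

    #!/usr/bin/env python3
    # Standalone sketch (illustrative, not from the repo): compare ONNX file
    # sizes for the two logits formulations used in this patch.
    import os

    import torch


    class TiedHead(torch.nn.Module):
        def __init__(self, vocab_size: int, n_state: int, use_transpose: bool):
            super().__init__()
            self.token_embedding = torch.nn.Embedding(vocab_size, n_state)
            self.use_transpose = use_transpose  # resolved at trace time

        def forward(self, tokens: torch.Tensor) -> torch.Tensor:
            # Exported as a Gather over the embedding weight initializer.
            x = self.token_embedding(tokens)
            w = self.token_embedding.weight
            if self.use_transpose:
                # Old form: constant folding bakes transpose(w) into the file
                # as a second full copy of the (vocab_size, n_state) matrix.
                return (x @ torch.transpose(w, 0, 1)).float()
            # New form: MatMul consumes the same initializer the Gather uses.
            return torch.matmul(w, x.permute(0, 2, 1)).permute(0, 2, 1).float()


    def main():
        tokens = torch.randint(0, 1000, (1, 3))
        for tag, use_transpose in [("transposed", True), ("shared", False)]:
            model = TiedHead(vocab_size=1000, n_state=64, use_transpose=use_transpose)
            filename = f"head-{tag}.onnx"
            torch.onnx.export(
                model,
                (tokens,),
                filename,
                opset_version=13,
                do_constant_folding=True,  # the default
            )
            print(filename, os.path.getsize(filename), "bytes")


    if __name__ == "__main__":
        main()

Running the sketch should show head-transposed.onnx at roughly twice the size
of head-shared.onnx, mirroring the reduction this patch achieves for the
Whisper decoder export.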