From 78b35652a3ada2a08ea058c97a34d6e9ad786a39 Mon Sep 17 00:00:00 2001
From: Jack
Date: Fri, 18 Aug 2023 14:14:22 +0800
Subject: [PATCH] fix issue with obtaining the decoder layer number when
 converting the T5 model. (#17185)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

### Description
Fix an issue with obtaining the number of decoder layers when converting the T5 model: the export and verification code read `config.num_layers` (the encoder layer count) where the decoder layer count was needed, and now reads `config.num_decoder_layers`. This matters for models whose encoder and decoder depths differ.

### Motivation and Context
Fixes https://github.com/microsoft/onnxruntime/issues/17072.

Tested with the [byt5-small](https://huggingface.co/google/byt5-small/tree/main) model, which has 12 encoder layers and 4 decoder layers. Here is the log:

![image](https://github.com/microsoft/onnxruntime/assets/3481539/ff1b69c5-f485-4301-a333-9ee2a984df07)

A minimal sketch of the relevant `T5Config` behavior is appended after the diff below.
---
 .../transformers/models/t5/t5_decoder.py      | 20 ++++++++++++--------
 .../models/t5/t5_encoder_decoder_init.py      | 10 ++++++----
 2 files changed, 18 insertions(+), 12 deletions(-)

diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py
index 0b8f4919b8eb5..fe415aa7680fc 100644
--- a/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py
+++ b/onnxruntime/python/tools/transformers/models/t5/t5_decoder.py
@@ -100,7 +100,8 @@ def __init__(self, decoder, lm_head, config):
         )
 
     def forward(self, decoder_input_ids, encoder_attention_mask, *past):
-        past_key_values = PastKeyValuesHelper.group_by_layer(past, self.config.num_layers)
+        num_decoder_layers = self.config.num_decoder_layers
+        past_key_values = PastKeyValuesHelper.group_by_layer(past, num_decoder_layers)
 
         # This is a hack since only the third dimension of encoder_hidden_states is used here
         dummy_encoder_hidden_states = encoder_attention_mask.unsqueeze(2)
@@ -162,7 +163,7 @@ def create_dummy(
             T5DecoderInputs: dummy inputs for decoder
         """
         num_attention_heads: int = config.num_heads
-        num_layers: int = config.num_layers
+        num_layers: int = config.num_decoder_layers
         vocab_size: int = config.vocab_size
 
         # Do not use head_size = hidden_size / num_attention_heads here.
@@ -263,9 +264,11 @@ def export_onnx(
         )
         input_list = inputs.to_list()
 
-        past_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=False)
-        present_names = PastKeyValuesHelper.get_past_names(decoder.config.num_layers, present=True)
-        present_self_names = present_names[: 2 * decoder.config.num_layers]
+        num_decoder_layers = decoder.config.num_decoder_layers
+
+        past_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=False)
+        present_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=True)
+        present_self_names = present_names[: 2 * num_decoder_layers]
 
         input_past_names = past_names if isinstance(decoder, T5Decoder) else []
         output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
@@ -407,20 +410,21 @@ def verify_onnx(
         torch_outputs = model(*input_list)
 
         ort_outputs = T5DecoderHelper.onnxruntime_inference(ort_session, inputs)
+        num_decoder_layers = model.config.num_decoder_layers
 
         max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
         max_diff_all = max_diff
         logger.debug(f"logits max_diff={max_diff}")
 
-        for i in range(2 * model.config.num_layers):
+        for i in range(2 * num_decoder_layers):
             max_diff = numpy.amax(numpy.abs(torch_outputs[1][i].cpu().numpy() - ort_outputs[1 + i]))
             logger.debug(f"self attention past state {i} max_diff={max_diff}")
             max_diff_all = max(max_diff_all, max_diff)
 
         if isinstance(model, T5DecoderInit):
-            for i in range(2 * model.config.num_layers):
+            for i in range(2 * num_decoder_layers):
                 max_diff = numpy.amax(
-                    numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * model.config.num_layers + i])
+                    numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[1 + 2 * num_decoder_layers + i])
                 )
                 logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                 max_diff_all = max(max_diff_all, max_diff)
diff --git a/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py b/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py
index e3d600981ef0e..8870ca6f34780 100644
--- a/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py
+++ b/onnxruntime/python/tools/transformers/models/t5/t5_encoder_decoder_init.py
@@ -125,7 +125,7 @@ def export_onnx(
         )
         input_list = inputs.to_list()
 
-        present_names = PastKeyValuesHelper.get_past_names(model.config.num_layers, present=True)
+        present_names = PastKeyValuesHelper.get_past_names(model.config.num_decoder_layers, present=True)
 
         output_names = ["logits", "encoder_hidden_states", *present_names]
 
@@ -271,6 +271,8 @@ def verify_onnx(
         input_list = inputs.to_list()
         torch_outputs = model(*input_list)
 
+        num_decoder_layers = model.config.num_decoder_layers
+
         assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
         max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
         logger.debug(f"logits max_diff={max_diff}")
@@ -281,13 +283,13 @@ def verify_onnx(
         logger.debug(f"encoder_hidden_states max_diff={max_diff}")
         max_diff_all = max(max_diff_all, max_diff)
 
-        for i in range(2 * model.config.num_layers):
+        for i in range(2 * num_decoder_layers):
             max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
             logger.debug(f"self attention past state {i} max_diff={max_diff}")
 
-        for i in range(2 * model.config.num_layers):
+        for i in range(2 * num_decoder_layers):
             max_diff = numpy.amax(
-                numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * model.config.num_layers + i])
+                numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i])
             )
             logger.debug(f"cross attention past state {i} max_diff={max_diff}")
             max_diff_all = max(max_diff_all, max_diff)
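
For context (not part of the patch): a minimal sketch, assuming the Hugging Face `transformers` package, of why `num_layers` and `num_decoder_layers` must be kept distinct. The asymmetric depths below match the byt5-small case cited above.

```python
# Minimal sketch, assuming the Hugging Face `transformers` package is installed.
# T5Config stores encoder and decoder depths separately; conversion code must
# read num_decoder_layers when sizing the decoder's past/present key-value states.
from transformers import T5Config

# Asymmetric depths, as in google/byt5-small: 12 encoder layers, 4 decoder layers.
config = T5Config(num_layers=12, num_decoder_layers=4)
assert config.num_layers == 12         # encoder depth only
assert config.num_decoder_layers == 4  # decoder depth

# The decoder exposes one key and one value tensor per layer, so the
# verification loops above iterate over 2 * num_decoder_layers past states.
assert 2 * config.num_decoder_layers == 8

# When num_decoder_layers is omitted it defaults to num_layers, which is why
# symmetric models such as t5-small masked this bug.
symmetric = T5Config(num_layers=6)
assert symmetric.num_decoder_layers == 6
```

Before the fix, converting byt5-small sized the past-state name lists and verification loops with `num_layers` (12) instead of `num_decoder_layers` (4), mismatching the decoder's actual outputs.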