From 4196245d06098503bcbe6f59b39f63f9946cfaf1 Mon Sep 17 00:00:00 2001 From: Benedikt Fuchs Date: Fri, 16 Aug 2024 16:40:42 +0200 Subject: [PATCH] make onnx export work again --- flair/embeddings/transformer.py | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) diff --git a/flair/embeddings/transformer.py b/flair/embeddings/transformer.py index b614444b3..5499dbed3 100644 --- a/flair/embeddings/transformer.py +++ b/flair/embeddings/transformer.py @@ -65,7 +65,7 @@ def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tens @torch.jit.script_if_tracing def truncate_hidden_states(hidden_states: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor: - return hidden_states[:, :, : input_ids.size()[1]] + return hidden_states[:, :, : input_ids.size(1)] @torch.jit.script_if_tracing @@ -95,14 +95,12 @@ def combine_strided_tensors( if selected_sentences.size(0) > 1: start_part = selected_sentences[0, : half_stride + 1] mid_part = selected_sentences[:, half_stride + 1 : max_length - 1 - half_stride] - mid_part = torch.reshape(mid_part, (mid_part.shape[0] * mid_part.shape[1],) + mid_part.shape[2:]) - end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1 :] + mid_part = torch.reshape(mid_part, (mid_part.size(0) * mid_part.size(1),) + mid_part.size()[2:]) + end_part = selected_sentences[selected_sentences.size(0) - 1, max_length - half_stride - 1 :] sentence_hidden_state = torch.cat((start_part, mid_part, end_part), dim=0) - sentence_hidden_states[sentence_id, : sentence_hidden_state.shape[0]] = torch.cat( - (start_part, mid_part, end_part), dim=0 - ) + sentence_hidden_states[sentence_id, : sentence_hidden_state.size(0)] = sentence_hidden_state else: - sentence_hidden_states[sentence_id, : selected_sentences.shape[1]] = selected_sentences[0, :] + sentence_hidden_states[sentence_id, : selected_sentences.size(1)] = selected_sentences[0, :] return sentence_hidden_states