Skip to content

Commit

Permalink
Merge branch 'master' into optimize-mean-subtoken-pooling
Browse files Browse the repository at this point in the history
  • Loading branch information
helpmefindaname authored Aug 23, 2024
2 parents 4d5ede7 + 3d8f078 commit 0bfd4f8
Showing 1 changed file with 5 additions and 7 deletions.
12 changes: 5 additions & 7 deletions flair/embeddings/transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ def pad_sequence_embeddings(all_hidden_states: List[torch.Tensor]) -> torch.Tens

@torch.jit.script_if_tracing
def truncate_hidden_states(hidden_states: torch.Tensor, input_ids: torch.Tensor) -> torch.Tensor:
return hidden_states[:, :, : input_ids.size()[1]]
return hidden_states[:, :, : input_ids.size(1)]


@torch.jit.script_if_tracing
Expand Down Expand Up @@ -95,14 +95,12 @@ def combine_strided_tensors(
if selected_sentences.size(0) > 1:
start_part = selected_sentences[0, : half_stride + 1]
mid_part = selected_sentences[:, half_stride + 1 : max_length - 1 - half_stride]
mid_part = torch.reshape(mid_part, (mid_part.shape[0] * mid_part.shape[1],) + mid_part.shape[2:])
end_part = selected_sentences[selected_sentences.shape[0] - 1, max_length - half_stride - 1 :]
mid_part = torch.reshape(mid_part, (mid_part.size(0) * mid_part.size(1),) + mid_part.size()[2:])
end_part = selected_sentences[selected_sentences.size(0) - 1, max_length - half_stride - 1 :]
sentence_hidden_state = torch.cat((start_part, mid_part, end_part), dim=0)
sentence_hidden_states[sentence_id, : sentence_hidden_state.shape[0]] = torch.cat(
(start_part, mid_part, end_part), dim=0
)
sentence_hidden_states[sentence_id, : sentence_hidden_state.size(0)] = sentence_hidden_state
else:
sentence_hidden_states[sentence_id, : selected_sentences.shape[1]] = selected_sentences[0, :]
sentence_hidden_states[sentence_id, : selected_sentences.size(1)] = selected_sentences[0, :]

return sentence_hidden_states

Expand Down

0 comments on commit 0bfd4f8

Please sign in to comment.