From d5ba8f377c932d9d7cf03142c292a686209cb628 Mon Sep 17 00:00:00 2001 From: suiyoubi Date: Sun, 5 Jan 2025 04:47:24 +0000 Subject: [PATCH] Apply isort and black reformatting Signed-off-by: suiyoubi --- nemo/collections/llm/bert/model/embedding.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/nemo/collections/llm/bert/model/embedding.py b/nemo/collections/llm/bert/model/embedding.py index 89a6df1c2930..289aedbd68d3 100644 --- a/nemo/collections/llm/bert/model/embedding.py +++ b/nemo/collections/llm/bert/model/embedding.py @@ -112,11 +112,12 @@ class BertEmbeddingMiniConfig(BertEmbeddingConfig): class BertEmbeddingHead(nn.Module): - """Performs mean pooling on the token embeddings. - """ + """Performs mean pooling on the token embeddings.""" def __init__( - self, word_embedding_dimension: int, pooling_mode_mean_tokens: bool = True, + self, + word_embedding_dimension: int, + pooling_mode_mean_tokens: bool = True, ): super(BertEmbeddingHead, self).__init__() @@ -128,8 +129,7 @@ def __init__( self.pooling_mode_mean_tokens = pooling_mode_mean_tokens def forward(self, token_embeddings: Tensor, attention_mask: Tensor): - """ Forward function for embedding head. Performs mean pooling. - """ + """Forward function for embedding head. Performs mean pooling.""" token_embeddings = token_embeddings.permute(1, 0, 2) input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float() sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)