Skip to content

Commit

Permalink
Apply isort and black reformatting
Browse files Browse the repository at this point in the history
Signed-off-by: suiyoubi <[email protected]>
  • Loading branch information
suiyoubi committed Jan 5, 2025
1 parent e497e5b commit d5ba8f3
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions nemo/collections/llm/bert/model/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,12 @@ class BertEmbeddingMiniConfig(BertEmbeddingConfig):


class BertEmbeddingHead(nn.Module):
"""Performs mean pooling on the token embeddings.
"""
"""Performs mean pooling on the token embeddings."""

def __init__(
self, word_embedding_dimension: int, pooling_mode_mean_tokens: bool = True,
self,
word_embedding_dimension: int,
pooling_mode_mean_tokens: bool = True,
):
super(BertEmbeddingHead, self).__init__()

Expand All @@ -128,8 +129,7 @@ def __init__(
self.pooling_mode_mean_tokens = pooling_mode_mean_tokens

def forward(self, token_embeddings: Tensor, attention_mask: Tensor):
""" Forward function for embedding head. Performs mean pooling.
"""
"""Forward function for embedding head. Performs mean pooling."""
token_embeddings = token_embeddings.permute(1, 0, 2)
input_mask_expanded = attention_mask.unsqueeze(-1).expand(token_embeddings.size()).float()
sum_embeddings = torch.sum(token_embeddings * input_mask_expanded, 1)
Expand Down

0 comments on commit d5ba8f3

Please sign in to comment.