Skip to content

Commit

Permalink
Merge pull request #124 from ArneBinder/sequence_classification_with_…
Browse files Browse the repository at this point in the history
…pooler/outsource_setup_pooler

encapsulate `get_pooler_and_output_size` in `SequenceClassificationModelWithPoolerBase`
  • Loading branch information
ArneBinder authored Oct 5, 2024
2 parents 4c4e0a7 + 4c5e6f6 commit c9be340
Showing 1 changed file with 15 additions and 4 deletions.
19 changes: 15 additions & 4 deletions src/pie_modules/models/sequence_classification_with_pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,10 +136,7 @@ def __init__(
if isinstance(pooler, str):
pooler = {"type": pooler}
self.pooler_config = pooler or {}
self.pooler, pooler_output_dim = get_pooler_and_output_size(
config=self.pooler_config,
input_dim=self.model.config.hidden_size,
)
self.pooler, pooler_output_dim = self.setup_pooler(input_dim=self.model.config.hidden_size)
self.classifier = self.setup_classifier(pooler_output_dim=pooler_output_dim)
self.loss_fct = self.setup_loss_fct()

Expand All @@ -158,6 +155,20 @@ def setup_classifier(self, pooler_output_dim: int) -> Callable:
def setup_loss_fct(self) -> Callable:
    """Return the loss function used by this model.

    ``__init__`` calls this method and stores the result as ``self.loss_fct``.
    The body here is empty, so this implementation itself returns ``None``.

    NOTE(review): presumably this is a hook that concrete subclasses override
    (a decorator may precede this definition outside the visible diff
    context) -- confirm against the full file.

    Returns:
        The loss function callable.
    """
    pass

def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]:
    """Create the pooler from ``self.pooler_config``.

    The pooler turns the hidden states of the base model into a representation
    of the input sequence(s) that the classifier can consume. It is a callable
    that receives those hidden states (plus any additional model inputs whose
    names are prefixed with "pooler_") and returns the pooled output.

    Args:
        input_dim: The input dimension of the pooler, i.e. the hidden size of
            the base model.

    Returns:
        A tuple ``(pooler, output_dim)`` consisting of the pooler and the
        output dimension of the pooler.
    """
    # Construction is delegated to the shared factory; this method only
    # supplies the stored pooler configuration and the requested input size.
    pooler_and_output_dim = get_pooler_and_output_size(
        input_dim=input_dim,
        config=self.pooler_config,
    )
    return pooler_and_output_dim

def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor:
output = self.model(**model_inputs)
hidden_state = output.last_hidden_state
Expand Down

0 comments on commit c9be340

Please sign in to comment.