From 4c5e6f6ea2d571ef8a501b8950d059e2123b9e03 Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Sat, 5 Oct 2024 15:44:33 +0200 Subject: [PATCH] implement SequenceClassificationModelWithPoolerBase.setup_pooler() --- .../sequence_classification_with_pooler.py | 19 +++++++++++++++---- 1 file changed, 15 insertions(+), 4 deletions(-) diff --git a/src/pie_modules/models/sequence_classification_with_pooler.py b/src/pie_modules/models/sequence_classification_with_pooler.py index 400011064..719dfd8a5 100644 --- a/src/pie_modules/models/sequence_classification_with_pooler.py +++ b/src/pie_modules/models/sequence_classification_with_pooler.py @@ -136,10 +136,7 @@ def __init__( if isinstance(pooler, str): pooler = {"type": pooler} self.pooler_config = pooler or {} - self.pooler, pooler_output_dim = get_pooler_and_output_size( - config=self.pooler_config, - input_dim=self.model.config.hidden_size, - ) + self.pooler, pooler_output_dim = self.setup_pooler(input_dim=self.model.config.hidden_size) self.classifier = self.setup_classifier(pooler_output_dim=pooler_output_dim) self.loss_fct = self.setup_loss_fct() @@ -158,6 +155,20 @@ def setup_classifier(self, pooler_output_dim: int) -> Callable: def setup_loss_fct(self) -> Callable: pass + def setup_pooler(self, input_dim: int) -> Tuple[Callable, int]: + """Set up the pooler. The pooler is used to get a representation of the input sequence(s) + that can be used by the classifier. It is a callable that takes the hidden states of the + base model (and additional model inputs that are prefixed with "pooler_") and returns the + pooled output. + + Args: + input_dim: The input dimension of the pooler, i.e. the hidden size of the base model. + + Returns: + A tuple with the pooler and the output dimension of the pooler. + """ + return get_pooler_and_output_size(config=self.pooler_config, input_dim=input_dim) + def get_pooled_output(self, model_inputs, pooler_inputs) -> torch.FloatTensor: output = self.model(**model_inputs) hidden_state = output.last_hidden_state