diff --git a/src/pie_modules/models/__init__.py b/src/pie_modules/models/__init__.py index f64038f80..e454a38e1 100644 --- a/src/pie_modules/models/__init__.py +++ b/src/pie_modules/models/__init__.py @@ -3,6 +3,7 @@ from .simple_generative import SimpleGenerativeModel from .simple_sequence_classification import SimpleSequenceClassificationModel from .simple_token_classification import SimpleTokenClassificationModel +from .span_similarity import SpanSimilarityModel from .span_tuple_classification import SpanTupleClassificationModel from .token_classification_with_seq2seq_encoder_and_crf import ( TokenClassificationModelWithSeq2SeqEncoderAndCrf, diff --git a/src/pie_modules/models/simple_similarity.py b/src/pie_modules/models/span_similarity.py similarity index 94% rename from src/pie_modules/models/simple_similarity.py rename to src/pie_modules/models/span_similarity.py index 6d3766f2c..ee436c1c0 100644 --- a/src/pie_modules/models/simple_similarity.py +++ b/src/pie_modules/models/span_similarity.py @@ -32,7 +32,7 @@ @PyTorchIEModel.register() -class SimpleSimilarityModel( +class SpanSimilarityModel( ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType], RequiresModelNameOrPath, ): @@ -63,11 +63,13 @@ def __init__( self, model_name_or_path: str, tokenizer_vocab_size: Optional[int] = None, + similarity_threshold: float = 0.5, classifier_dropout: Optional[float] = None, learning_rate: float = 1e-5, task_learning_rate: Optional[float] = None, warmup_proportion: float = 0.1, - # TODO: use "mention_pooling" per default? + # TODO: use "mention_pooling" per default? But this requires + # to also set num_indices=1 in the pooler_config pooler: Optional[Union[Dict[str, Any], str]] = None, freeze_base_model: bool = False, hidden_dim: Optional[int] = None, @@ -77,6 +79,7 @@ def __init__( self.save_hyperparameters() + self.similarity_threshold = similarity_threshold self.learning_rate = learning_rate self.task_learning_rate = task_learning_rate self.warmup_proportion = warmup_proportion @@ -108,6 +111,10 @@ def __init__( if isinstance(pooler, str): pooler = {"type": pooler} + if pooler is not None: + if pooler["type"] == "mention_pooling": + # we have only one index (span) per input to pool + pooler["num_indices"] = 1 self.pooler_config = pooler or {} self.pooler, pooler_output_dim = get_pooler_and_output_size( config=self.pooler_config, @@ -170,7 +177,7 @@ def forward( return SequenceClassifierOutput(**result) def decode(self, inputs: InputType, outputs: OutputType) -> TargetType: - labels = (outputs.logits > 0.5).to(torch.long) + labels = (outputs.logits >= self.similarity_threshold).to(torch.long) return {"labels": labels, "probabilities": outputs.logits}