Skip to content

Commit

Permalink
rename model to SpanSimilarityModel; add similarity_threshold paramet…
Browse files Browse the repository at this point in the history
…er; set num_indices when "mention_pooling" is used
  • Loading branch information
ArneBinder committed Sep 12, 2024
1 parent 7b3759f commit 508bac3
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 3 deletions.
1 change: 1 addition & 0 deletions src/pie_modules/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from .simple_generative import SimpleGenerativeModel
from .simple_sequence_classification import SimpleSequenceClassificationModel
from .simple_token_classification import SimpleTokenClassificationModel
from .span_similarity import SpanSimilarityModel
from .span_tuple_classification import SpanTupleClassificationModel
from .token_classification_with_seq2seq_encoder_and_crf import (
TokenClassificationModelWithSeq2SeqEncoderAndCrf,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


@PyTorchIEModel.register()
class SimpleSimilarityModel(
class SpanSimilarityModel(
ModelWithBoilerplate[InputType, OutputType, TargetType, StepOutputType],
RequiresModelNameOrPath,
):
Expand Down Expand Up @@ -63,11 +63,13 @@ def __init__(
self,
model_name_or_path: str,
tokenizer_vocab_size: Optional[int] = None,
similarity_threshold: float = 0.5,
classifier_dropout: Optional[float] = None,
learning_rate: float = 1e-5,
task_learning_rate: Optional[float] = None,
warmup_proportion: float = 0.1,
# TODO: use "mention_pooling" per default?
# TODO: use "mention_pooling" per default? But this requires
# to also set num_indices=1 in the pooler_config
pooler: Optional[Union[Dict[str, Any], str]] = None,
freeze_base_model: bool = False,
hidden_dim: Optional[int] = None,
Expand All @@ -77,6 +79,7 @@ def __init__(

self.save_hyperparameters()

self.similarity_threshold = similarity_threshold
self.learning_rate = learning_rate
self.task_learning_rate = task_learning_rate
self.warmup_proportion = warmup_proportion
Expand Down Expand Up @@ -108,6 +111,10 @@ def __init__(

if isinstance(pooler, str):
pooler = {"type": pooler}
if pooler is not None:
if pooler["type"] == "mention_pooling":
# we have only one index (span) per input to pool
pooler["num_indices"] = 1
self.pooler_config = pooler or {}
self.pooler, pooler_output_dim = get_pooler_and_output_size(
config=self.pooler_config,
Expand Down Expand Up @@ -170,7 +177,7 @@ def forward(
return SequenceClassifierOutput(**result)

def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
labels = (outputs.logits > 0.5).to(torch.long)
labels = (outputs.logits >= self.similarity_threshold).to(torch.long)

return {"labels": labels, "probabilities": outputs.logits}

Expand Down

0 comments on commit 508bac3

Please sign in to comment.