Skip to content

Commit

Permalink
remove SpanSimilarityModel in favor of new SequencePairSimilarityModelWithPooler
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Sep 12, 2024
1 parent 508bac3 commit 2e2d44d
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 229 deletions.
6 changes: 4 additions & 2 deletions src/pie_modules/models/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from .sequence_classification_with_pooler import SequenceClassificationModelWithPooler
from .sequence_classification_with_pooler import (
SequenceClassificationModelWithPooler,
SequencePairSimilarityModelWithPooler,
)
from .simple_extractive_question_answering import SimpleExtractiveQuestionAnsweringModel
from .simple_generative import SimpleGenerativeModel
from .simple_sequence_classification import SimpleSequenceClassificationModel
from .simple_token_classification import SimpleTokenClassificationModel
from .span_similarity import SpanSimilarityModel
from .span_tuple_classification import SpanTupleClassificationModel
from .token_classification_with_seq2seq_encoder_and_crf import (
TokenClassificationModelWithSeq2SeqEncoderAndCrf,
Expand Down
63 changes: 63 additions & 0 deletions src/pie_modules/models/sequence_classification_with_pooler.py
Original file line number Diff line number Diff line change
Expand Up @@ -266,3 +266,66 @@ def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
probabilities = torch.sigmoid(outputs.logits)
labels = (probabilities > self.multi_label_threshold).to(torch.long)
return {"labels": labels, "probabilities": probabilities}


@PyTorchIEModel.register()
class SequencePairSimilarityModelWithPooler(
    SequenceClassificationModelWithPoolerBase,
):
    """Score a pair of sequences by the cosine similarity of their pooled representations.

    Each member of the pair is turned into a pooled vector via ``get_pooled_output``; the
    two vectors are then compared by the classifier, which is expected to be the cosine
    similarity returned by ``setup_classifier`` (the wiring happens in the base class).
    The similarity itself is used as the score: it is returned in the ``logits`` field of
    the output, passed to the loss when targets are given, and thresholded in ``decode``.

    NOTE(review): cosine similarity lies in [-1, 1], whereas ``nn.BCELoss`` only accepts
    inputs in [0, 1]. A negative similarity would make the loss computation fail. Confirm
    that the pooled representations cannot yield negative similarities, or rescale/clamp
    the score before the loss.

    Args:
        label_threshold: The similarity score above which a pair is labeled as positive
            (label 1) in ``decode``.
        **kwargs: Forwarded unchanged to the base class constructor.
    """

    def __init__(self, label_threshold: float = 0.5, **kwargs):
        super().__init__(**kwargs)
        # The attribute name is kept as is for consistency with the multi-label model;
        # other code may read it under this name.
        self.multi_label_threshold = label_threshold

    def setup_classifier(self, pooler_output_dim: int) -> Callable:
        """Return the function that compares the two pooled outputs.

        Cosine similarity has no trainable parameters, so ``pooler_output_dim`` is unused.
        """
        return torch.nn.functional.cosine_similarity

    def setup_loss_fct(self):
        """Return the loss: binary cross-entropy applied directly to the scores."""
        return nn.BCELoss()

    def forward(
        self,
        inputs: InputType,
        targets: Optional[TargetType] = None,
        return_hidden_states: bool = False,
    ) -> OutputType:
        """Compute the similarity score (and optionally the loss) for a batch of pairs.

        Args:
            inputs: Mapping that contains the model inputs under ``"encoding"`` and
                ``"encoding_pair"``, plus pooler arguments prefixed with ``"pooler_"``
                (first sequence) and ``"pooler_pair_"`` (second sequence).
            targets: Optional mapping with the gold ``"labels"``. If given, the loss is
                added to the result.
            return_hidden_states: Not supported yet.

        Returns:
            A ``SequenceClassifierOutput`` with ``logits`` set to the similarity scores
            and, if targets were provided, ``loss``.

        Raises:
            NotImplementedError: If ``return_hidden_states`` is set.
        """
        # Fail fast: there is no point in running both encoder passes (and the loss)
        # only to reject the request afterwards.
        if return_hidden_states:
            raise NotImplementedError("return_hidden_states is not yet implemented")

        sanitized_inputs = separate_arguments_by_prefix(
            # Note that the order of the prefixes is important because one is a prefix of
            # the other, so we need to start with the longer!
            arguments=inputs,
            prefixes=["pooler_pair_", "pooler_"],
        )
        model_inputs = sanitized_inputs["remaining"]

        pooled_output = self.get_pooled_output(
            model_inputs=model_inputs["encoding"],
            pooler_inputs=sanitized_inputs["pooler_"],
        )
        pooled_output_pair = self.get_pooled_output(
            model_inputs=model_inputs["encoding_pair"],
            pooler_inputs=sanitized_inputs["pooler_pair_"],
        )

        # These are similarity scores rather than unnormalized logits; the field name
        # is dictated by the output type.
        logits = self.classifier(pooled_output, pooled_output_pair)

        result = {"logits": logits}
        if targets is not None:
            result["loss"] = self.loss_fct(logits, targets["labels"])

        return SequenceClassifierOutput(**result)

    def decode(self, inputs: InputType, outputs: OutputType) -> TargetType:
        """Convert model outputs into binary labels and their scores.

        No sigmoid is applied: the scores in ``outputs.logits`` are used as
        probabilities as they are.

        Returns:
            A dict with ``"labels"`` (1 where the score exceeds the threshold, else 0,
            as ``torch.long``) and ``"probabilities"`` (the unmodified scores).
        """
        probabilities = outputs.logits
        labels = (probabilities > self.multi_label_threshold).to(torch.long)
        return {"labels": labels, "probabilities": probabilities}
217 changes: 0 additions & 217 deletions src/pie_modules/models/span_similarity.py

This file was deleted.

8 changes: 4 additions & 4 deletions src/pie_modules/taskmodules/cross_text_binary_coref.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,8 @@ def encode_input(
"encoding_pair": encoding_pair,
"pooler_start_indices": start,
"pooler_end_indices": end,
"pooler_start_indices_pair": start_pair,
"pooler_end_indices_pair": end_pair,
"pooler_pair_start_indices": start_pair,
"pooler_pair_end_indices": end_pair,
},
metadata={"candidate_annotation": coref_rel},
)
Expand All @@ -201,13 +201,13 @@ def collate(
)

inputs = {
k: self.tokenizer.pad(v, return_tensors="pt")
k: self.tokenizer.pad(v, return_tensors="pt").data
if k in ["encoding", "encoding_pair"]
else torch.tensor(v)
for k, v in inputs_dict.items()
}
for k, v in inputs.items():
if k.startswith("pooler_start_indices") or k.startswith("pooler_end_indices"):
if k.startswith("pooler_") and k.endswith("_indices"):
inputs[k] = v.unsqueeze(-1)

if not task_encodings[0].has_targets:
Expand Down
12 changes: 6 additions & 6 deletions tests/taskmodules/test_cross_text_binary_coref.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,8 +218,8 @@ def test_encode_input(task_encodings_without_target, taskmodule):
toks[start:end]
for toks, start, end in zip(
tokens_pair,
inputs_dict["pooler_start_indices_pair"],
inputs_dict["pooler_end_indices_pair"],
inputs_dict["pooler_pair_start_indices"],
inputs_dict["pooler_pair_end_indices"],
)
]
assert span_tokens == [["she"], ["she"], ["C"], ["C"]]
Expand Down Expand Up @@ -279,10 +279,10 @@ def test_collate(batch, taskmodule):
assert set(inputs) == {
"pooler_end_indices",
"encoding_pair",
"pooler_end_indices_pair",
"pooler_pair_end_indices",
"pooler_start_indices",
"encoding",
"pooler_start_indices_pair",
"pooler_pair_start_indices",
}
torch.testing.assert_close(
inputs["encoding"]["input_ids"],
Expand Down Expand Up @@ -325,10 +325,10 @@ def test_collate(batch, taskmodule):
torch.testing.assert_close(inputs["pooler_start_indices"], torch.tensor([[2], [2], [4], [4]]))
torch.testing.assert_close(inputs["pooler_end_indices"], torch.tensor([[3], [3], [5], [5]]))
torch.testing.assert_close(
inputs["pooler_start_indices_pair"], torch.tensor([[1], [3], [1], [3]])
inputs["pooler_pair_start_indices"], torch.tensor([[1], [3], [1], [3]])
)
torch.testing.assert_close(
inputs["pooler_end_indices_pair"], torch.tensor([[2], [5], [2], [5]])
inputs["pooler_pair_end_indices"], torch.tensor([[2], [5], [2], [5]])
)

torch.testing.assert_close(targets, {"labels": torch.tensor([0.0, 0.0, 0.0, 0.0])})
Expand Down

0 comments on commit 2e2d44d

Please sign in to comment.