Skip to content

Commit

Permalink
add test case
Browse files (browse the repository at this point in the history)
  • Loading branch information
ArneBinder committed Nov 12, 2024
1 parent fa3cc45 commit e746d06
Showing 1 changed file with 14 additions and 3 deletions.
17 changes: 14 additions & 3 deletions tests/taskmodules/test_re_text_classification_with_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -817,8 +817,6 @@ def test_encode_with_allow_discontinuous_text_and_binary_relations():
tokenizer_name_or_path=tokenizer_name_or_path,
max_window=128,
allow_discontinuous_text=True,
add_argument_indices_to_input=True,
add_global_attention_mask_to_input=True,
)
texts = [
"Loren ipsun dolor sit anet, consectetur adipisci elit, sed eiusnod tenpor incidunt ut labore et dolore nagna aliqua.",
Expand Down Expand Up @@ -879,7 +877,15 @@ def test_encode_with_allow_discontinuous_text_and_binary_relations():
)
doc.binary_relations.append(rel_consecutive)

taskmodule.prepare([doc])
# Test a document where everything is already included in one argument frame.
doc2 = TextDocumentWithLabeledSpansAndBinaryRelations("A founded B.", id="123")
doc2.labeled_spans.append(LabeledSpan(start=0, end=1, label="PER"))
doc2.labeled_spans.append(LabeledSpan(start=10, end=11, label="PER"))
assert doc2.labeled_spans.resolve() == [("PER", "A"), ("PER", "B")]
rel = BinaryRelation(head=doc2.labeled_spans[0], tail=doc2.labeled_spans[1], label="relation")
doc2.binary_relations.append(rel)

taskmodule.prepare([doc, doc2])
encoded = taskmodule.encode_input(doc)

decoded_arg_start = taskmodule.tokenizer.decode(encoded[0].inputs["input_ids"])
Expand All @@ -901,6 +907,11 @@ def test_encode_with_allow_discontinuous_text_and_binary_relations():
== "[CLS] ex ea connodi consequatur. [H] Quis aute iure reprehenderit in voluptate velit esse cillun dolore eu fugiat nulla pariatur. [/H] [T] Excepteur sint obcaecat cupiditat non proident, sunt in culpa qui officia deserunt nollit anin id est laborun. [/T] [SEP]"
)

encoded2 = taskmodule.encode_input(doc2)
assert len(encoded2) == 1
decoded2 = taskmodule.tokenizer.decode(encoded2[0].inputs["input_ids"])
assert decoded2 == "[CLS] [H] A [/H] founded [T] B [/T]. [SEP]"


@pytest.fixture(scope="module")
def taskmodule_with_add_argument_indices(documents):
Expand Down

0 comments on commit e746d06

Please sign in to comment.