diff --git a/requirements.txt b/requirements.txt index aa33f5b..a719e5d 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,7 @@ # --------- pytorch-ie --------- # pytorch-ie>=0.28.0,<0.30.0 pie-datasets>=0.8.1,<0.9.0 -pie-modules>=0.9.0,<=0.10.5 +git+https://github.com/ArneBinder/pie-modules.git # --------- hydra --------- # hydra-core>=1.3.0 diff --git a/tests/unit/serializer/test_brat.py b/tests/unit/serializer/test_brat.py index a50dc2f..486e900 100644 --- a/tests/unit/serializer/test_brat.py +++ b/tests/unit/serializer/test_brat.py @@ -108,31 +108,39 @@ def document(): document = TextDocumentWithLabeledSpansAndBinaryRelations( text="Harry lives in Berlin. He works at DFKI.", id="tmp_1" ) - predicted_labeled_spans = [ - LabeledSpan(start=0, end=5, label="PERSON"), # Harry - LabeledSpan(start=15, end=21, label="LOCATION"), # Berlin - ] - document.labeled_spans.predictions.extend(predicted_labeled_spans) - labeled_spans = [ - LabeledSpan(start=0, end=5, label="PERSON"), # Harry - LabeledSpan(start=15, end=21, label="LOCATION"), # Berlin - LabeledSpan(start=35, end=39, label="ORGANIZATION"), # DFKI - ] - - document.labeled_spans.extend(labeled_spans) - - predicted_binary_relations = [ - BinaryRelation(head=labeled_spans[0], tail=labeled_spans[1], label="lives_in"), - ] - document.binary_relations.predictions.extend(predicted_binary_relations) + document.labeled_spans.predictions.extend( + [ + LabeledSpan(start=0, end=5, label="PERSON"), # Harry + LabeledSpan(start=15, end=21, label="LOCATION"), # Berlin + ] + ) - binary_relations = [ - BinaryRelation(head=labeled_spans[0], tail=labeled_spans[1], label="lives_in"), - BinaryRelation(head=labeled_spans[0], tail=labeled_spans[2], label="works_at"), - ] + document.labeled_spans.extend( + [ + LabeledSpan(start=0, end=5, label="PERSON"), # Harry + LabeledSpan(start=15, end=21, label="LOCATION"), # Berlin + LabeledSpan(start=35, end=39, label="ORGANIZATION"), # DFKI + ] + ) - document.binary_relations.extend(binary_relations) + document.binary_relations.predictions.extend( + [ + BinaryRelation( + head=document.labeled_spans[0], tail=document.labeled_spans[1], label="lives_in" + ), + ] + ) + document.binary_relations.extend( + [ + BinaryRelation( + head=document.labeled_spans[0], tail=document.labeled_spans[1], label="lives_in" + ), + BinaryRelation( + head=document.labeled_spans[0], tail=document.labeled_spans[2], label="works_at" + ), + ] + ) return document @@ -362,35 +370,50 @@ def document_with_multispan(): document = TextDocumentWithLabeledSpansAndBinaryRelations( text="Harry lives in Berlin, Germany. He works at DFKI.", id="tmp" ) - entities = [ - LabeledSpan(start=0, end=5, label="PERSON"), - LabeledSpan(start=15, end=21, label="LOCATION"), - LabeledSpan(start=23, end=30, label="LOCATION"), - LabeledSpan(start=44, end=48, label="ORGANIZATION"), - ] - for ent in entities: - document.labeled_spans.predictions.append(ent) - - relations = [ - BinaryRelation(head=entities[0], tail=entities[1], label="lives_in"), - BinaryRelation( - head=entities[1], tail=entities[2], label="parts_of_same" - ), # should be removed - BinaryRelation( - head=entities[1], tail=entities[3], label="parts_of_same" - ), # should be removed - BinaryRelation(head=entities[0], tail=entities[3], label="works_at"), - BinaryRelation( - head=entities[3], tail=entities[1], label="located_in" - ), # tail should be a new merged entity - BinaryRelation( - head=entities[3], tail=entities[2], label="located_in" - ), # tail should be a new merged entity - ] - + document.labeled_spans.predictions.extend( + [ + LabeledSpan(start=0, end=5, label="PERSON"), + LabeledSpan(start=15, end=21, label="LOCATION"), + LabeledSpan(start=23, end=30, label="LOCATION"), + LabeledSpan(start=44, end=48, label="ORGANIZATION"), + ] + ) # add relations as predictions - for rel in relations: - document.binary_relations.predictions.append(rel) + + document.binary_relations.predictions.extend( + [ + BinaryRelation( + head=document.labeled_spans.predictions[0], + tail=document.labeled_spans.predictions[1], + label="lives_in", + ), + BinaryRelation( + head=document.labeled_spans.predictions[1], + tail=document.labeled_spans.predictions[2], + label="parts_of_same", + ), # should be removed + BinaryRelation( + head=document.labeled_spans.predictions[1], + tail=document.labeled_spans.predictions[3], + label="parts_of_same", + ), # should be removed + BinaryRelation( + head=document.labeled_spans.predictions[0], + tail=document.labeled_spans.predictions[3], + label="works_at", + ), + BinaryRelation( + head=document.labeled_spans.predictions[3], + tail=document.labeled_spans.predictions[1], + label="located_in", + ), # tail should be a new merged entity + BinaryRelation( + head=document.labeled_spans.predictions[3], + tail=document.labeled_spans.predictions[2], + label="located_in", + ), # tail should be a new merged entity + ] + ) return document