From 8567dc372845faa27c841c8edb9c176bc64c2b2c Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Mon, 19 Feb 2024 20:56:09 +0100 Subject: [PATCH] add constrain_with_previous_records parameter to PointerNetworkTaskModuleForEnd2EndRE --- src/pie_modules/taskmodules/pointer_network_for_end2end_re.py | 4 ++++ tests/taskmodules/test_pointer_network_for_end2end_re.py | 2 ++ 2 files changed, 6 insertions(+) diff --git a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py index cdd9b1ef4..a5f66cedf 100644 --- a/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py +++ b/src/pie_modules/taskmodules/pointer_network_for_end2end_re.py @@ -185,6 +185,7 @@ def __init__( none_label: str = "none", loop_dummy_relation_name: str = "loop", constrained_generation: bool = False, + constrain_with_previous_records: bool = True, # generic pointer network label_tokens: Optional[Dict[str, str]] = None, label_representations: Optional[Dict[str, str]] = None, @@ -226,6 +227,7 @@ def __init__( self.none_label = none_label self.loop_dummy_relation_name = loop_dummy_relation_name self.constrained_generation = constrained_generation + self.constrain_with_previous_records = constrain_with_previous_records # will be set in _post_prepare() self.relation_encoder_decoder: BinaryRelationEncoderDecoder @@ -641,6 +643,7 @@ def decode_annotations( encoding=encoding.labels, input_length=self.tokenizer.model_max_length, stop_ids=[self.eos_id], + disrespect_decoded_annotations=not self.constrain_with_previous_records, ) return self.postprocess_decoded_relations(decoded_relations), errors except Exception as e: @@ -676,6 +679,7 @@ def get_follow_up_candidates(self, previous_ids: List[int], input_len: int) -> S input_length=input_len, stop_ids=[self.eos_id], decoded_annotations=decoded_relations, + disrespect_decoded_annotations=not self.constrain_with_previous_records, ) successfully_decoded = previous_ids[: len(previous_ids) - len(remaining)] self.cache_decoded.add(tuple(previous_ids), (decoded_relations, successfully_decoded)) diff --git a/tests/taskmodules/test_pointer_network_for_end2end_re.py b/tests/taskmodules/test_pointer_network_for_end2end_re.py index 9e94dd547..1115903b6 100644 --- a/tests/taskmodules/test_pointer_network_for_end2end_re.py +++ b/tests/taskmodules/test_pointer_network_for_end2end_re.py @@ -198,6 +198,7 @@ def test_prepared_config(taskmodule, config): "entities": "labeled_spans", "relations": "binary_relations", }, + "constrain_with_previous_records": True, "constrained_generation": False, "label_tokens": None, "label_representations": None, @@ -226,6 +227,7 @@ def test_prepared_config(taskmodule, config): "entities": "labeled_spans", "relations": "binary_relations", }, + "constrain_with_previous_records": True, "constrained_generation": False, "label_tokens": None, "label_representations": None,