Skip to content

Commit

Permalink
add constrain_with_previous_records parameter to PointerNetworkTaskMo…
Browse files Browse the repository at this point in the history
…duleForEnd2EndRE
  • Loading branch information
ArneBinder committed Feb 20, 2024
1 parent 3fd0233 commit 8567dc3
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 0 deletions.
4 changes: 4 additions & 0 deletions src/pie_modules/taskmodules/pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,7 @@ def __init__(
none_label: str = "none",
loop_dummy_relation_name: str = "loop",
constrained_generation: bool = False,
constrain_with_previous_records: bool = True,
# generic pointer network
label_tokens: Optional[Dict[str, str]] = None,
label_representations: Optional[Dict[str, str]] = None,
Expand Down Expand Up @@ -226,6 +227,7 @@ def __init__(
self.none_label = none_label
self.loop_dummy_relation_name = loop_dummy_relation_name
self.constrained_generation = constrained_generation
self.constrain_with_previous_records = constrain_with_previous_records
# will be set in _post_prepare()
self.relation_encoder_decoder: BinaryRelationEncoderDecoder

Expand Down Expand Up @@ -641,6 +643,7 @@ def decode_annotations(
encoding=encoding.labels,
input_length=self.tokenizer.model_max_length,
stop_ids=[self.eos_id],
disrespect_decoded_annotations=not self.constrain_with_previous_records,
)
return self.postprocess_decoded_relations(decoded_relations), errors
except Exception as e:
Expand Down Expand Up @@ -676,6 +679,7 @@ def get_follow_up_candidates(self, previous_ids: List[int], input_len: int) -> S
input_length=input_len,
stop_ids=[self.eos_id],
decoded_annotations=decoded_relations,
disrespect_decoded_annotations=not self.constrain_with_previous_records,
)
successfully_decoded = previous_ids[: len(previous_ids) - len(remaining)]
self.cache_decoded.add(tuple(previous_ids), (decoded_relations, successfully_decoded))
Expand Down
2 changes: 2 additions & 0 deletions tests/taskmodules/test_pointer_network_for_end2end_re.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ def test_prepared_config(taskmodule, config):
"entities": "labeled_spans",
"relations": "binary_relations",
},
"constrain_with_previous_records": True,
"constrained_generation": False,
"label_tokens": None,
"label_representations": None,
Expand Down Expand Up @@ -226,6 +227,7 @@ def test_prepared_config(taskmodule, config):
"entities": "labeled_spans",
"relations": "binary_relations",
},
"constrain_with_previous_records": True,
"constrained_generation": False,
"label_tokens": None,
"label_representations": None,
Expand Down

0 comments on commit 8567dc3

Please sign in to comment.