Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix RESpanPairClassificationTaskModule #101

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 34 additions & 14 deletions src/pie_modules/taskmodules/re_span_pair_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,25 @@ def construct_argument_marker(pos: str, label: Optional[str] = None, role: str =

def inject_markers_into_text(
    text: str, positions_and_markers: List[Tuple[int, str]]
) -> Tuple[str, Dict[int, Tuple[int, List[str]]]]:
    """Inject markers into the text at the given positions.

    The markers are processed in sorted order of ``(position, marker)``, so markers that
    share a position end up in the text in lexicographic order of their marker strings.

    Args:
        text: The text to inject the markers into.
        positions_and_markers: A list of tuples where each tuple contains the position in the text
            where the marker should be injected and the marker text itself.

    Returns:
        A tuple containing the text with the markers injected and a dictionary mapping each
        original position to a tuple of (1) the new position *before* any marker injected at
        that position and (2) the list of markers injected there, in the order in which they
        appear in the new text. The new position *after* all markers of a position is thus
        the first entry plus the summed lengths of the markers in the second entry.
    """
    offset = 0
    original2new_pos: Dict[int, Tuple[int, List[str]]] = dict()
    for original_pos, marker in sorted(positions_and_markers):
        new_pos = original_pos + offset
        text = text[:new_pos] + marker + text[new_pos:]
        if original_pos in original2new_pos:
            # Keep the position recorded for the first marker at this original position:
            # it is the position before *all* markers injected here. Overwriting it with
            # the current value would already include the lengths of the earlier markers,
            # which callers add again via the marker list.
            original2new_pos[original_pos][1].append(marker)
        else:
            original2new_pos[original_pos] = (new_pos, [marker])
        offset += len(marker)
    return text, original2new_pos


Expand Down Expand Up @@ -505,21 +517,23 @@ def inject_markers_for_labeled_spans(

if isinstance(document, TextDocumentWithLabeledPartitions):
# create "dummy" markers for the partitions so that entries for these positions are created
# in original2new_pos
# in original_pos2new_pos_and_markers
for labeled_partition in document.labeled_partitions:
positions_and_markers.append((labeled_partition.start, ""))
positions_and_markers.append((labeled_partition.end, ""))

# inject markers into the text
marked_text, original2new_pos = inject_markers_into_text(
marked_text, original_pos2new_pos_and_markers = inject_markers_into_text(
document.text, positions_and_markers
)

# construct new spans
old2new_spans = dict()
for labeled_span in document.labeled_spans:
start = original2new_pos[labeled_span.start]
end = original2new_pos[labeled_span.end]
start_before_markers, markers = original_pos2new_pos_and_markers[labeled_span.start]
# we use just the span *without* the markers as new span
start = start_before_markers + sum(len(marker) for marker in markers)
end = original_pos2new_pos_and_markers[labeled_span.end][0]
new_span = LabeledSpan(start=start, end=end, label=labeled_span.label)
old2new_spans[labeled_span] = new_span

Expand All @@ -546,9 +560,13 @@ def inject_markers_for_labeled_spans(
new_document.binary_relations.extend(old2new_relations.values())
if isinstance(document, TextDocumentWithLabeledPartitions):
for labeled_partition in document.labeled_partitions:
new_start = original2new_pos[labeled_partition.start]
new_end = original2new_pos[labeled_partition.end]
new_labeled_partitions = labeled_partition.copy(start=new_start, end=new_end)
# we use the span *including* the markers as new span
start, _ = original_pos2new_pos_and_markers[labeled_partition.start]
end_before_markers, markers = original_pos2new_pos_and_markers[
labeled_partition.end
]
end = end_before_markers + sum(len(marker) for marker in markers)
new_labeled_partitions = labeled_partition.copy(start=start, end=end)
new_document.labeled_partitions.append(new_labeled_partitions)

new2old_spans = {new_span: old_span for old_span, new_span in old2new_spans.items()}
Expand Down Expand Up @@ -657,7 +675,7 @@ def encode_target(
get_relation_argument_spans_and_roles(relation)
].append(relation)
label_indices = [] # list of label indices
candidate_relations = []
# candidate_relations = []
for candidate_relation in task_encoding.metadata["candidate_relations"]:
candidate_roles_and_args = get_relation_argument_spans_and_roles(candidate_relation)
gold_relations = gold_roles_and_args2relation.get(candidate_roles_and_args, [])
Expand All @@ -678,9 +696,9 @@ def encode_target(
label_idx = PAD_VALUES["labels"]

label_indices.append(label_idx)
candidate_relations.append(candidate_relation)
# candidate_relations.append(candidate_relation)

task_encoding.metadata["candidate_relations"] = candidate_relations
# task_encoding.metadata["candidate_relations"] = candidate_relations
target: TargetEncodingType = {"labels": to_tensor("labels", label_indices)}

self._maybe_log_example(task_encoding=task_encoding, target=target)
Expand Down Expand Up @@ -711,7 +729,9 @@ def _maybe_log_example(
):
logger.info(f"relation {i}: {label}")
for j, arg_idx in enumerate(tuple_indices):
arg_tokens = tokens[span_start_indices[arg_idx] : span_end_indices[arg_idx]]
arg_tokens = tokens[
span_start_indices[arg_idx] : span_end_indices[arg_idx] + 1
]
logger.info(f"\targ {j}: {' '.join([str(x) for x in arg_tokens])}")

self._logged_examples_counter += 1
Expand Down
81 changes: 38 additions & 43 deletions tests/models/test_span_tuple_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def batch():
]
),
"span_start_indices": tensor([[1, 9, 0, 0], [4, 12, 18, 0], [4, 12, 17, 21]]),
"span_end_indices": tensor([[7, 12, 0, 0], [10, 15, 21, 0], [10, 15, 20, 24]]),
"span_end_indices": tensor([[6, 11, 0, 0], [9, 14, 20, 0], [9, 14, 19, 23]]),
"tuple_indices": tensor(
[[[0, 1], [-1, -1], [-1, -1]], [[0, 1], [0, 2], [2, 1]], [[0, 1], [2, 3], [3, 2]]]
),
Expand Down Expand Up @@ -351,46 +351,41 @@ def test_forward_logits(batch, model):
tensor(
[
[
-0.23075447976589203,
0.08129829168319702,
-0.26441076397895813,
0.3208361268043518,
-0.3551301658153534,
0.09493370354175568,
-0.15801358222961426,
0.5679908990859985,
],
[
-0.2247302085161209,
0.21453489363193512,
-0.20609508454799652,
0.2984844148159027,
-0.266460657119751,
0.16119083762168884,
-0.10706772655248642,
0.5230874419212341,
],
[
-0.0552724152803421,
0.18319237232208252,
-0.14115819334983826,
0.23137536644935608,
-0.11953088641166687,
0.1623934805393219,
-0.04825110733509064,
0.43645235896110535,
],
[-0.2047966569662094, 0.17388933897018433, -0.06319254636764526, 0.4306640625],
[
-0.2897184491157532,
0.17462071776390076,
-0.12004873156547546,
0.1817789375782013,
-0.3208402395248413,
0.09282125532627106,
-0.05495951324701309,
0.4880615472793579,
],
[
-0.3101494312286377,
0.18245069682598114,
-0.13525372743606567,
0.28625163435935974,
-0.4020463228225708,
0.2283128798007965,
0.013205204159021378,
0.3972089886665344,
],
[
-0.33728304505348206,
0.22038179636001587,
-0.0482308566570282,
0.25237396359443665,
],
[
-0.3835912048816681,
0.20549766719341278,
0.15333643555641174,
0.23370930552482605,
-0.2575981616973877,
0.0700659453868866,
-0.010283984243869781,
0.4580671489238739,
],
]
),
Expand All @@ -402,7 +397,7 @@ def test_step(batch, model, config):
loss = model._step("train", batch)
assert loss is not None
if config == {}:
torch.testing.assert_close(loss, torch.tensor(1.3911350965499878))
torch.testing.assert_close(loss, torch.tensor(1.3872407674789429))
else:
raise ValueError(f"Unknown config: {config}")

Expand All @@ -413,7 +408,7 @@ def test_training_step_and_on_epoch_end(batch, model, config):
loss = model.training_step(batch, batch_idx=0)
assert loss is not None
if config == {}:
torch.testing.assert_close(loss, torch.tensor(1.3911350965499878))
torch.testing.assert_close(loss, torch.tensor(1.3872407674789429))
else:
raise ValueError(f"Unknown config: {config}")

Expand All @@ -427,7 +422,7 @@ def test_validation_step_and_on_epoch_end(batch, model, config):
assert loss is not None
metric_values = {k: v.item() for k, v in metric.compute().items()}
if config == {}:
torch.testing.assert_close(loss, torch.tensor(1.3911350965499878))
torch.testing.assert_close(loss, torch.tensor(1.3872407674789429))
assert metric_values == {
"macro/f1": 0.14814814925193787,
"micro/f1": 0.2857142984867096,
Expand All @@ -449,7 +444,7 @@ def test_test_step_and_on_epoch_end(batch, model, config):
assert loss is not None
metric_values = {k: v.item() for k, v in metric.compute().items()}
if config == {}:
torch.testing.assert_close(loss, torch.tensor(1.3911350965499878))
torch.testing.assert_close(loss, torch.tensor(1.3872407674789429))
assert metric_values == {
"macro/f1": 0.14814814925193787,
"micro/f1": 0.2857142984867096,
Expand Down Expand Up @@ -483,21 +478,21 @@ def test_predict_and_predict_step(model, batch, config, test_step):
tensor(
[
[
[0.1973, 0.2695, 0.1907, 0.3425],
[0.1586, 0.2488, 0.1932, 0.3993],
[-1.0000, -1.0000, -1.0000, -1.0000],
[-1.0000, -1.0000, -1.0000, -1.0000],
],
[
[0.1902, 0.2951, 0.1938, 0.3209],
[0.2213, 0.2809, 0.2031, 0.2947],
[0.1859, 0.2958, 0.2203, 0.2979],
[0.1692, 0.2596, 0.1985, 0.3727],
[0.1944, 0.2578, 0.2088, 0.3390],
[0.1818, 0.2655, 0.2095, 0.3432],
],
[
[0.1772, 0.2900, 0.2111, 0.3217],
[0.1699, 0.2968, 0.2269, 0.3064],
[0.1571, 0.2831, 0.2687, 0.2912],
[0.1650, 0.2495, 0.2152, 0.3704],
[0.1511, 0.2839, 0.2289, 0.3361],
[0.1750, 0.2429, 0.2241, 0.3580],
],
],
]
),
)
else:
Expand Down
Loading
Loading