Skip to content

Commit

Permalink
remove simple_ prefix from tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder committed Nov 29, 2023
1 parent c329620 commit ecf0c73
Showing 1 changed file with 74 additions and 76 deletions.
150 changes: 74 additions & 76 deletions tests/taskmodules/test_pointer_network_for_joint_taskmodule.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

logger = logging.getLogger(__name__)

FIXTURES_DIR = FIXTURES_ROOT / "taskmodules" / "gmam_taskmodule"
# FIXTURES_DIR = FIXTURES_ROOT / "taskmodules" / "gmam_taskmodule"


def _config_to_str(cfg: Dict[str, str]) -> str:
Expand Down Expand Up @@ -41,7 +41,7 @@ def config_str(config):


@pytest.fixture(scope="module")
def simple_document():
def document():
@dataclass
class ExampleDocument(TextBasedDocument):
entities: AnnotationList[LabeledSpan] = annotation_field(target="text")
Expand All @@ -63,13 +63,13 @@ class ExampleDocument(TextBasedDocument):
return doc


def test_simple_document(simple_document):
spans = simple_document.entities
def test_document(document):
spans = document.entities
assert len(spans) == 3
assert (str(spans[0]), spans[0].label) == ("dummy text", "content")
assert (str(spans[1]), spans[1].label) == ("nothing", "topic")
assert (str(spans[2]), spans[2].label) == ("me", "person")
relations = simple_document.relations
relations = document.relations
assert len(relations) == 2
assert (str(relations[0].head), relations[0].label, str(relations[0].tail)) == (
"dummy text",
Expand All @@ -81,42 +81,42 @@ def test_simple_document(simple_document):
"no_relation",
"me",
)
sentences = simple_document.sentences
sentences = document.sentences
assert len(sentences) == 2
assert str(sentences[0]) == "This is a dummy text about nothing."
assert str(sentences[1]) == "Trust me."


SIMPLE_CONFIGS = [{}, {"partition_layer_name": "sentences"}]
SIMPLE_CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in SIMPLE_CONFIGS}
CONFIGS = [{}, {"partition_layer_name": "sentences"}]
CONFIG_DICT = {_config_to_str(cfg): cfg for cfg in CONFIGS}


@pytest.fixture(scope="module", params=SIMPLE_CONFIG_DICT.keys())
def simple_config_str(request):
@pytest.fixture(scope="module", params=CONFIG_DICT.keys())
def config_str(request):
return request.param


@pytest.fixture(scope="module")
def simple_config(simple_config_str):
return SIMPLE_CONFIG_DICT[simple_config_str]
def config(config_str):
return CONFIG_DICT[config_str]


@pytest.fixture(scope="module")
def simple_taskmodule(simple_document, simple_config):
def taskmodule(document, config):
taskmodule = PointerNetworkForJointTaskModule(
text_field_name="text",
span_layer_name="entities",
relation_layer_name="relations",
exclude_annotation_names={"relations": ["no_relation"]},
**simple_config,
**config,
)

taskmodule.prepare(documents=[simple_document])
taskmodule.prepare(documents=[document])
return taskmodule


def test_simple_taskmodule(simple_taskmodule):
tm = simple_taskmodule
def test_taskmodule(taskmodule):
tm = taskmodule
assert tm.prepared_attributes == {
"span_labels": ["content", "person", "topic"],
"relation_labels": ["is_about"],
Expand All @@ -143,24 +143,24 @@ def test_simple_taskmodule(simple_taskmodule):


@pytest.fixture()
def simple_encoded_inputs(simple_taskmodule, simple_document):
return simple_taskmodule.encode_input(simple_document)
def encoded_inputs(taskmodule, document):
return taskmodule.encode_input(document)


@pytest.fixture()
def simple_encoded_input(simple_encoded_inputs):
return simple_encoded_inputs[0]
def encoded_input(encoded_inputs):
return encoded_inputs[0]


def test_simple_encoded_input(simple_encoded_input, simple_taskmodule):
assert simple_encoded_input is not None
if simple_taskmodule.partition_layer_name is None:
assert simple_encoded_input.inputs == {
def test_encoded_input(encoded_input, taskmodule):
assert encoded_input is not None
if taskmodule.partition_layer_name is None:
assert encoded_input.inputs == {
"src_tokens": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2],
"src_seq_len": 13,
}
assert set(simple_encoded_input.metadata) == {"token2char", "char2token", "tokenized_span"}
token2char = simple_encoded_input.metadata["token2char"]
assert set(encoded_input.metadata) == {"token2char", "char2token", "tokenized_span"}
token2char = encoded_input.metadata["token2char"]
assert token2char == [
(0, 0),
(0, 4),
Expand All @@ -176,7 +176,7 @@ def test_simple_encoded_input(simple_encoded_input, simple_taskmodule):
(44, 45),
(0, 0),
]
char2token = simple_encoded_input.metadata["char2token"]
char2token = encoded_input.metadata["char2token"]
assert char2token == {
0: [1],
1: [1],
Expand Down Expand Up @@ -216,25 +216,25 @@ def test_simple_encoded_input(simple_encoded_input, simple_taskmodule):
43: [10],
44: [11],
}
assert simple_encoded_input.metadata.get("partition") is None
tokenized_span = simple_encoded_input.metadata["tokenized_span"]
text = simple_encoded_input.document.text
assert encoded_input.metadata.get("partition") is None
tokenized_span = encoded_input.metadata["tokenized_span"]
text = encoded_input.document.text
assert (
text[tokenized_span.start : tokenized_span.end]
== "This is a dummy text about nothing. Trust me."
)
elif simple_taskmodule.partition_layer_name == "sentences":
assert simple_encoded_input.inputs == {
elif taskmodule.partition_layer_name == "sentences":
assert encoded_input.inputs == {
"src_tokens": [0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 2],
"src_seq_len": 10,
}
assert set(simple_encoded_input.metadata) == {
assert set(encoded_input.metadata) == {
"token2char",
"char2token",
"partition",
"tokenized_span",
}
token2char = simple_encoded_input.metadata["token2char"]
token2char = encoded_input.metadata["token2char"]
assert token2char == [
(0, 0),
(0, 4),
Expand All @@ -247,7 +247,7 @@ def test_simple_encoded_input(simple_encoded_input, simple_taskmodule):
(34, 35),
(0, 0),
]
char2token = simple_encoded_input.metadata["char2token"]
char2token = encoded_input.metadata["char2token"]
assert char2token == {
0: [1],
1: [1],
Expand Down Expand Up @@ -279,34 +279,34 @@ def test_simple_encoded_input(simple_encoded_input, simple_taskmodule):
33: [7],
34: [8],
}
partition = simple_encoded_input.metadata.get("partition")
partition = encoded_input.metadata.get("partition")
assert (partition.start, partition.end) == (0, 35)
tokenized_span = simple_encoded_input.metadata["tokenized_span"]
text = simple_encoded_input.document.text
tokenized_span = encoded_input.metadata["tokenized_span"]
text = encoded_input.document.text
assert (
text[tokenized_span.start : tokenized_span.end]
== "This is a dummy text about nothing."
)
else:
raise Exception(f"unknown partition_layer_name: {simple_taskmodule.partition_layer_name}")
raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}")


@pytest.fixture()
def simple_task_encodings(simple_taskmodule, simple_encoded_inputs):
for encoded_input in simple_encoded_inputs:
targets = simple_taskmodule.encode_target(encoded_input)
def task_encodings(taskmodule, encoded_inputs):
for encoded_input in encoded_inputs:
targets = taskmodule.encode_target(encoded_input)
encoded_input.targets = targets
return simple_encoded_inputs
return encoded_inputs


@pytest.fixture()
def simple_task_encoding(simple_taskmodule, simple_task_encodings):
return simple_task_encodings[0]
def task_encoding(taskmodule, task_encodings):
return task_encodings[0]


def test_encode_target_with_dummy_relations(simple_task_encoding, simple_taskmodule):
targets = simple_task_encoding.targets
if simple_taskmodule.partition_layer_name is None:
def test_encode_target_with_dummy_relations(task_encoding, taskmodule):
targets = task_encoding.targets
if taskmodule.partition_layer_name is None:
assert targets["tgt_tokens"] == [0, 14, 14, 2, 11, 12, 6, 4, 17, 17, 5, 3, 3, 3, 3, 1]
assert targets["tgt_seq_len"] == 16
assert targets["CPM_tag"] == [
Expand All @@ -326,7 +326,7 @@ def test_encode_target_with_dummy_relations(simple_task_encoding, simple_taskmod
[0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
elif simple_taskmodule.partition_layer_name == "sentences":
elif taskmodule.partition_layer_name == "sentences":
assert targets["tgt_tokens"] == [0, 14, 14, 2, 11, 12, 6, 4, 1]
assert targets["tgt_seq_len"] == 9
assert targets["CPM_tag"] == [
Expand All @@ -340,16 +340,16 @@ def test_encode_target_with_dummy_relations(simple_task_encoding, simple_taskmod
[0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
]
else:
raise Exception(f"unknown partition_layer_name: {simple_taskmodule.partition_layer_name}")
raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}")


@pytest.fixture()
def simple_batch(simple_taskmodule, simple_task_encodings):
return simple_taskmodule.collate(simple_task_encodings)
def batch(taskmodule, task_encodings):
return taskmodule.collate(task_encodings)


def test_simple_collate(simple_batch, simple_taskmodule):
inputs, targets = simple_batch
def test_collate(batch, taskmodule):
inputs, targets = batch
for tensor in inputs.values():
assert isinstance(tensor, torch.Tensor)
assert tensor.dtype == torch.int64
Expand All @@ -358,7 +358,7 @@ def test_simple_collate(simple_batch, simple_taskmodule):
assert tensor.dtype == torch.int64
inputs_lists = {k: inputs[k].tolist() for k in sorted(inputs)}
targets_lists = {k: targets[k].tolist() for k in sorted(targets)}
if simple_taskmodule.partition_layer_name is None:
if taskmodule.partition_layer_name is None:
assert inputs_lists == {
"src_seq_len": [13],
"src_tokens": [[0, 713, 16, 10, 34759, 2788, 59, 1085, 4, 3101, 162, 4, 2]],
Expand Down Expand Up @@ -386,7 +386,7 @@ def test_simple_collate(simple_batch, simple_taskmodule):
"tgt_seq_len": [16],
"tgt_tokens": [[0, 14, 14, 2, 11, 12, 6, 4, 17, 17, 5, 3, 3, 3, 3, 1]],
}
elif simple_taskmodule.partition_layer_name == "sentences":
elif taskmodule.partition_layer_name == "sentences":
assert inputs_lists == {
"src_seq_len": [10, 5],
"src_tokens": [
Expand Down Expand Up @@ -421,35 +421,35 @@ def test_simple_collate(simple_batch, simple_taskmodule):
"tgt_tokens": [[0, 14, 14, 2, 11, 12, 6, 4, 1], [0, 9, 9, 5, 3, 3, 3, 3, 1]],
}
else:
raise Exception(f"unknown partition_layer_name: {simple_taskmodule.partition_layer_name}")
raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}")


@pytest.fixture()
def simple_unbatched_output(simple_taskmodule, simple_batch):
inputs, targets = simple_batch
def unbatched_output(taskmodule, batch):
inputs, targets = batch
# because the model is trained to reproduce the target tokens, we can just use them as model prediction
model_output = {"pred": targets["tgt_tokens"]}
return simple_taskmodule.unbatch_output(model_output)
return taskmodule.unbatch_output(model_output)


@pytest.fixture()
def simple_task_outputs(simple_unbatched_output):
return simple_unbatched_output
def task_outputs(unbatched_output):
return unbatched_output


@pytest.fixture()
def simple_task_output(simple_task_outputs):
return simple_task_outputs[0]
def task_output(task_outputs):
return task_outputs[0]


def test_simple_task_output(simple_task_output, simple_taskmodule):
output_list = simple_task_output.tolist()
if simple_taskmodule.partition_layer_name is None:
def test_task_output(task_output, taskmodule):
output_list = task_output.tolist()
if taskmodule.partition_layer_name is None:
assert output_list == [0, 14, 14, 2, 11, 12, 6, 4, 17, 17, 5, 3, 3, 3, 3, 1]
elif simple_taskmodule.partition_layer_name == "sentences":
elif taskmodule.partition_layer_name == "sentences":
assert output_list == [0, 14, 14, 2, 11, 12, 6, 4, 1]
else:
raise Exception(f"unknown partition_layer_name: {simple_taskmodule.partition_layer_name}")
raise Exception(f"unknown partition_layer_name: {taskmodule.partition_layer_name}")


def _test_annotations_from_output(task_encodings, task_outputs, taskmodule, layer_names_expected):
Expand Down Expand Up @@ -496,12 +496,10 @@ def _test_annotations_from_output(task_encodings, task_outputs, taskmodule, laye
document[layer_name].predictions.clear()


def test_simple_annotations_from_output(
simple_task_encodings, simple_task_outputs, simple_taskmodule
):
def test_annotations_from_output(task_encodings, task_outputs, taskmodule):
_test_annotations_from_output(
taskmodule=simple_taskmodule,
task_encodings=simple_task_encodings,
task_outputs=simple_task_outputs,
taskmodule=taskmodule,
task_encodings=task_encodings,
task_outputs=task_outputs,
layer_names_expected={"entities", "relations"},
)

0 comments on commit ecf0c73

Please sign in to comment.