Skip to content

Commit

Permalink
add tests for example code (dataset, taskmodule, metric) (#149)
Browse files Browse the repository at this point in the history
  • Loading branch information
ArneBinder authored Nov 24, 2023
1 parent a91c6e0 commit 46bacb9
Show file tree
Hide file tree
Showing 8 changed files with 1,274 additions and 0 deletions.
1 change: 1 addition & 0 deletions src/metrics/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .f1 import F1Metric
5 changes: 5 additions & 0 deletions tests/dataset_builders/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from pathlib import Path

DATASET_BUILDER_BASE_PATH = Path("dataset_builders")
PIE_BASE_PATH = DATASET_BUILDER_BASE_PATH / "pie"
HF_BASE_PATH = DATASET_BUILDER_BASE_PATH / "hf"
118 changes: 118 additions & 0 deletions tests/dataset_builders/pie/test_conll2003.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import datasets
import pytest
from pie_datasets import DatasetDict
from pytorch_ie.core import Document
from pytorch_ie.documents import TextDocumentWithLabeledSpans

from dataset_builders.pie.conll2003.conll2003 import Conll2003
from tests.dataset_builders import PIE_BASE_PATH

DATASET_NAME = "conll2003"
PIE_DATASET_PATH = PIE_BASE_PATH / DATASET_NAME
HF_DATASET_PATH = Conll2003.BASE_DATASET_PATH
SPLIT_NAMES = {"train", "validation", "test"}
SPLIT_SIZES = {"train": 14041, "validation": 3250, "test": 3453}


@pytest.fixture(params=[config.name for config in Conll2003.BUILDER_CONFIGS], scope="module")
def dataset_name(request):
return request.param


@pytest.fixture(scope="module")
def hf_dataset(dataset_name):
return datasets.load_dataset(str(HF_DATASET_PATH), name=dataset_name)


def test_hf_dataset(hf_dataset):
assert set(hf_dataset) == SPLIT_NAMES
split_sizes = {split_name: len(ds) for split_name, ds in hf_dataset.items()}
assert split_sizes == SPLIT_SIZES


@pytest.fixture(scope="module")
def hf_example(hf_dataset):
return hf_dataset["train"][0]


def test_hf_example(hf_example, dataset_name):
if dataset_name == "conll2003":
assert hf_example == {
"chunk_tags": [11, 21, 11, 12, 21, 22, 11, 12, 0],
"id": "0",
"ner_tags": [3, 0, 7, 0, 0, 0, 7, 0, 0],
"pos_tags": [22, 42, 16, 21, 35, 37, 16, 21, 7],
"tokens": ["EU", "rejects", "German", "call", "to", "boycott", "British", "lamb", "."],
}
else:
raise ValueError(f"Unknown dataset name: {dataset_name}")


@pytest.fixture(scope="module")
def document(hf_example, hf_dataset):
conll2003 = Conll2003()
generate_document_kwargs = conll2003._generate_document_kwargs(hf_dataset["train"])
document = conll2003._generate_document(example=hf_example, **generate_document_kwargs)
return document


def test_document(document, dataset_name):
assert isinstance(document, Document)
if dataset_name == "conll2003":
assert document.text == "EU rejects German call to boycott British lamb ."
entities = list(document.entities)
assert len(entities) == 3
assert str(entities[0]) == "EU"
assert str(entities[1]) == "German"
assert str(entities[2]) == "British"
else:
raise ValueError(f"Unknown dataset name: {dataset_name}")


@pytest.fixture(scope="module")
def pie_dataset(dataset_name):
return DatasetDict.load_dataset(str(PIE_DATASET_PATH), name=dataset_name)


def test_pie_dataset(pie_dataset):
assert set(pie_dataset) == SPLIT_NAMES
split_sizes = {split_name: len(ds) for split_name, ds in pie_dataset.items()}
assert split_sizes == SPLIT_SIZES


@pytest.fixture(scope="module", params=list(Conll2003.DOCUMENT_CONVERTERS))
def converter_document_type(request):
return request.param


@pytest.fixture(scope="module")
def converted_pie_dataset(pie_dataset, converter_document_type):
pie_dataset_converted = pie_dataset.to_document_type(document_type=converter_document_type)
return pie_dataset_converted


def test_converted_pie_dataset(converted_pie_dataset, converter_document_type):
assert set(converted_pie_dataset) == SPLIT_NAMES
split_sizes = {split_name: len(ds) for split_name, ds in converted_pie_dataset.items()}
assert split_sizes == SPLIT_SIZES
for ds in converted_pie_dataset.values():
for document in ds:
assert isinstance(document, converter_document_type)


@pytest.fixture(scope="module")
def converted_document(converted_pie_dataset):
return converted_pie_dataset["train"][0]


def test_converted_document(converted_document, converter_document_type):
assert isinstance(converted_document, converter_document_type)
if converter_document_type == TextDocumentWithLabeledSpans:
assert converted_document.text == "EU rejects German call to boycott British lamb ."
entities = list(converted_document.labeled_spans)
assert len(entities) == 3
assert str(entities[0]) == "EU"
assert str(entities[1]) == "German"
assert str(entities[2]) == "British"
else:
raise ValueError(f"Unknown converter document type: {converter_document_type}")
Empty file added tests/unit/document/__init__.py
Empty file.
Empty file added tests/unit/metrics/__init__.py
Empty file.
109 changes: 109 additions & 0 deletions tests/unit/metrics/test_f1.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from dataclasses import dataclass

import pytest
from pytorch_ie.annotations import LabeledSpan
from pytorch_ie.core import AnnotationLayer, annotation_field
from pytorch_ie.documents import TextBasedDocument

from src.metrics import F1Metric


@pytest.fixture
def documents():
@dataclass
class TextDocumentWithEntities(TextBasedDocument):
entities: AnnotationLayer[LabeledSpan] = annotation_field(target="text")

# a test sentence with two entities
doc1 = TextDocumentWithEntities(
text="The quick brown fox jumps over the lazy dog.",
)
doc1.entities.append(LabeledSpan(start=4, end=19, label="animal"))
doc1.entities.append(LabeledSpan(start=35, end=43, label="animal"))
assert str(doc1.entities[0]) == "quick brown fox"
assert str(doc1.entities[1]) == "lazy dog"

# a second test sentence with a different text and a single entity (a company)
doc2 = TextDocumentWithEntities(text="Apple is a great company.")
doc2.entities.append(LabeledSpan(start=0, end=5, label="company"))
assert str(doc2.entities[0]) == "Apple"

documents = [doc1, doc2]

# add predictions
# correct
documents[0].entities.predictions.append(LabeledSpan(start=4, end=19, label="animal"))
# correct, but duplicate, this should not be counted
documents[0].entities.predictions.append(LabeledSpan(start=4, end=19, label="animal"))
# correct
documents[0].entities.predictions.append(LabeledSpan(start=35, end=43, label="animal"))
# wrong label
documents[0].entities.predictions.append(LabeledSpan(start=35, end=43, label="cat"))
# correct
documents[1].entities.predictions.append(LabeledSpan(start=0, end=5, label="company"))
# wrong span
documents[1].entities.predictions.append(LabeledSpan(start=10, end=15, label="company"))

return documents


def test_f1(documents):
metric = F1Metric(layer="entities")
metric(documents)
# tp, fp, fn for micro
assert dict(metric.counts) == {"MICRO": (3, 2, 0)}
assert metric.compute() == {"MICRO": {"f1": 0.7499999999999999, "p": 0.6, "r": 1.0}}


def test_f1_per_label(documents):
metric = F1Metric(layer="entities", labels=["animal", "company", "cat"])
metric(documents)
# tp, fp, fn for micro and per label
assert dict(metric.counts) == {
"MICRO": (3, 2, 0),
"cat": (0, 1, 0),
"company": (1, 1, 0),
"animal": (2, 0, 0),
}
assert metric.compute() == {
"MACRO": {"f1": 0.5555555555555556, "p": 0.5, "r": 0.6666666666666666},
"MICRO": {"f1": 0.7499999999999999, "p": 0.6, "r": 1.0},
"cat": {"f1": 0.0, "p": 0.0, "r": 0.0},
"company": {"f1": 0.6666666666666666, "p": 0.5, "r": 1.0},
"animal": {"f1": 1.0, "p": 1.0, "r": 1.0},
}


def test_f1_per_label_no_labels(documents):
with pytest.raises(ValueError) as excinfo:
F1Metric(layer="entities", labels=[])
assert str(excinfo.value) == "labels cannot be empty"


def test_f1_per_label_not_allowed():
with pytest.raises(ValueError) as excinfo:
F1Metric(layer="entities", labels=["animal", "MICRO"])
assert (
str(excinfo.value)
== "labels cannot contain 'MICRO' or 'MACRO' because they are used to capture aggregated metrics"
)


# def test_f1_show_as_markdown(documents, caplog):
# metric = F1Metric(layer="entities", labels=["animal", "company", "cat"], show_as_markdown=True)
# metric(documents)
# caplog.set_level(logging.INFO)
# caplog.clear()
# metric.compute()
# assert len(caplog.records) == 1
# assert (
# caplog.records[0].message == "\n"
# "entities:\n"
# "| | f1 | p | r |\n"
# "|:--------|------:|----:|------:|\n"
# "| MACRO | 0.556 | 0.5 | 0.667 |\n"
# "| MICRO | 0.75 | 0.6 | 1 |\n"
# "| animal | 1 | 1 | 1 |\n"
# "| company | 0.667 | 0.5 | 1 |\n"
# "| cat | 0 | 0 | 0 |"
# )
Empty file.
Loading

0 comments on commit 46bacb9

Please sign in to comment.