Skip to content

Commit

Permalink
Add test for tags (#67)
Browse files Browse the repository at this point in the history
  • Loading branch information
Samoed authored Dec 8, 2024
1 parent 0a70364 commit ad69779
Show file tree
Hide file tree
Showing 22 changed files with 374 additions and 97 deletions.
6 changes: 5 additions & 1 deletion autointent/modules/prediction/_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,11 @@ def dump(self, path: str) -> None:
"""
dump_dir = Path(path)

metadata = AdaptivePredictorDumpMetadata(r=self._r, tags=self.tags, n_classes=self.n_classes)
metadata = AdaptivePredictorDumpMetadata(
r=self._r,
tags=[t.model_dump() for t in self.tags] if self.tags else None, # type: ignore[misc]
n_classes=self.n_classes,
)

with (dump_dir / self.metadata_dict_name).open("w") as file:
json.dump(metadata, file, indent=4)
Expand Down
4 changes: 3 additions & 1 deletion autointent/modules/prediction/_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,9 +164,11 @@ def dump(self, path: str) -> None:
)

dump_dir = Path(path)
metadata_json = self.metadata
metadata_json["tags"] = [tag.model_dump() for tag in metadata_json["tags"]] if metadata_json["tags"] else None # type: ignore[misc]

with (dump_dir / self.metadata_dict_name).open("w") as file:
json.dump(self.metadata, file, indent=4)
json.dump(metadata_json, file, indent=4)

def load(self, path: str) -> None:
"""
Expand Down
10 changes: 7 additions & 3 deletions autointent/modules/prediction/_tunable.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,9 +170,11 @@ def dump(self, path: str) -> None:
)

dump_dir = Path(path)
metadata_json = self.metadata
metadata_json["tags"] = [tag.model_dump() for tag in metadata_json["tags"]] if metadata_json["tags"] else None # type: ignore[misc]

with (dump_dir / self.metadata_dict_name).open("w") as file:
json.dump(self.metadata, file, indent=4)
json.dump(metadata_json, file, indent=4)

def load(self, path: str) -> None:
"""
Expand All @@ -183,9 +185,11 @@ def load(self, path: str) -> None:
dump_dir = Path(path)

with (dump_dir / self.metadata_dict_name).open() as file:
metadata: TunablePredictorDumpMetadata = json.load(file)
metadata = json.load(file)

self.metadata = metadata
metadata["tags"] = [Tag(**tag) for tag in metadata["tags"]] if metadata["tags"] else None

self.metadata: TunablePredictorDumpMetadata = metadata
self.thresh = np.array(metadata["thresh"])
self.multilabel = metadata["multilabel"]
self.tags = metadata["tags"]
Expand Down
42 changes: 25 additions & 17 deletions autointent/modules/prediction/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,31 @@ def apply_tags(labels: npt.NDArray[Any], scores: npt.NDArray[Any], tags: list[Ta
:param tags: List of `Tag` objects, where each tag specifies mutually exclusive intent IDs.
:return: Adjusted array of shape (n_samples, n_classes) with binary labels.
"""
n_samples, _ = labels.shape
res = np.copy(labels)

for i in range(n_samples):
sample_labels = labels[i].astype(bool)
sample_scores = scores[i]

for tag in tags:
if any(sample_labels[idx] for idx in tag.intent_ids):
# Find the index of the class with the highest score among the tagged indices
max_score_index = max(tag.intent_ids, key=lambda idx: sample_scores[idx])
# Set all other tagged indices to 0 in the result
for idx in tag.intent_ids:
if idx != max_score_index:
res[i, idx] = 0

return res
labels = labels.copy()

for tag in tags:
intent_ids = tag.intent_ids

labels_sub = labels[:, intent_ids]
scores_sub = scores[:, intent_ids]

assigned = labels_sub == 1
num_assigned = assigned.sum(axis=1)

assigned_scores = np.where(assigned, scores_sub, -np.inf)

samples_to_adjust = np.where(num_assigned > 1)[0]

if samples_to_adjust.size > 0:
assigned_scores_adjust = assigned_scores[samples_to_adjust, :]
idx_max_adjust = assigned_scores_adjust.argmax(axis=1)

labels_sub[samples_to_adjust, :] = 0
labels_sub[samples_to_adjust, idx_max_adjust] = 1

labels[:, intent_ids] = labels_sub

return labels


class WrongClassificationError(Exception):
Expand Down
2 changes: 1 addition & 1 deletion tests/assets/configs/multiclass.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
metric: prediction_accuracy
search_space:
- module_type: threshold
thresh: [0.5, [0.5, 0.5, 0.5]]
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
- module_type: tunable
- module_type: argmax
- module_type: jinoos
2 changes: 1 addition & 1 deletion tests/assets/configs/multilabel.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@
metric: prediction_accuracy
search_space:
- module_type: threshold
thresh: [0.5, [0.5, 0.5, 0.5]]
thresh: [0.5, [0.5, 0.5, 0.5, 0.5]]
- module_type: tunable
- module_type: adaptive
121 changes: 117 additions & 4 deletions tests/assets/data/clinc_subset.json
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,14 @@
"id": 2,
"name": "alarm",
"description": "User wants to set or manage an alarm."
},
{
"id": 3,
"name": "alarm reservation",
"tags": ["alarm", "reservation"],
"regexp_full_match": [],
"regexp_partial_match": [],
"description": "User wants to set or manage an alarm second time."
}
],
"train": [
Expand Down Expand Up @@ -138,20 +146,125 @@
"label": 2
},
{
"utterance": "how much is an overdraft fee for bank"
"utterance": "how much is an overdraft fee for bank",
"label": 3
},
{
"utterance": "where is the dipstick",
"label": 3
},
{
"utterance": "where is the dipstick",
"label": 3
},
{
"utterance": "where is the dipstick",
"label": 3
},
{
"utterance": "where is the dipstick",
"label": 3
},
{
"utterance": "where is the dipstick",
"label": 3
},
{
"utterance": "where is the dipstick",
"label": 3
},
{
"utterance": "how much is 1 share of aapl"
},
{
"utterance": "how is glue made"
},
{
"utterance": "how much is 1 share of aapl"
},
{
"utterance": "how is glue made"
},
{
"utterance": "how much is 1 share of aapl"
},
{
"utterance": "how is glue made"
},
{
"utterance": "how much is 1 share of aapl"
},
{
"utterance": "how is glue made"
}
],
"test": [
{
"utterance": "can i make a reservation for redrobin",
"label": 0
},
{
"utterance": "does redrobin do reservations",
"label": 0
},
{
"utterance": "does acero in maplewood allow reservations",
"label": 0
},
{
"utterance": "i think my account is blocked",
"label": 1
},
{
"utterance": "why is my bank account stopping all transactions from going through",
"label": 1
},
{
"utterance": "what would cause me to be locked out of my bank account",
"label": 1
},
{
"utterance": "find out the reason why am i locked out of my bank account",
"label": 1
},
{
"utterance": "make sure my alarm is set for three thirty in the morning",
"label": 2
},
{
"utterance": "please set an alarm for mid day",
"label": 2
},
{
"utterance": "have an alarm set for three in the morning",
"label": 2
},
{
"utterance": "set an alarm for me for 10:00 and another one set for 4:00",
"label": 2
},
{
"utterance": "why are exponents preformed before multiplication in the order of operations"
"utterance": "set an alarm to go to sleep and another to wake up",
"label": 2
},
{
"utterance": "what size wipers does this car take"
"utterance": "how much is an overdraft fee for bank",
"label": 3
},
{
"utterance": "where is the dipstick"
"utterance": "why are exponents preformed before multiplication in the order of operations",
"label": 3
},
{
"utterance": "what size wipers does this car take",
"label": 3
},
{
"utterance": "how much is 1 share of aapl"
},
{
"utterance": "how is glue made"
},
{
"utterance": "how is glue made"
}
Expand Down
8 changes: 4 additions & 4 deletions tests/context/datahandler/test_stratificaiton.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ def test_train_test_split(dataset):

assert Split.TRAIN in dataset
assert Split.TEST in dataset
assert dataset[Split.TRAIN].num_rows == 24
assert dataset[Split.TEST].num_rows == 6
assert dataset[Split.TRAIN].num_rows == 29
assert dataset[Split.TEST].num_rows == 8
assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST)


Expand All @@ -28,6 +28,6 @@ def test_multilabel_train_test_split(dataset):

assert Split.TRAIN in dataset
assert Split.TEST in dataset
assert dataset[Split.TRAIN].num_rows == 24
assert dataset[Split.TEST].num_rows == 6
assert dataset[Split.TRAIN].num_rows == 30
assert dataset[Split.TEST].num_rows == 7
assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST)
6 changes: 6 additions & 0 deletions tests/modules/prediction/conftest.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import numpy as np
import pytest

from autointent.context.data_handler import DataHandler
Expand Down Expand Up @@ -43,3 +44,8 @@ def multilabel_fit_data(dataset):
scores = scorer.predict(data_handler.validation_utterances(1) + data_handler.oos_utterances(1))
labels = data_handler.validation_labels(1) + [[0] * data_handler.n_classes] * len(data_handler.oos_utterances(1))
return scores, labels


@pytest.fixture
def scores():
return np.array([[0.05, 0.9, 0, 0.05], [0.8, 0, 0.1, 0.1], [0, 0.2, 0.7, 0.1]])
4 changes: 2 additions & 2 deletions tests/modules/prediction/test_adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@
def test_multilabel(multilabel_fit_data):
predictor = AdaptivePredictor()
predictor.fit(*multilabel_fit_data)
scores = np.array([[0.2, 0.9, 0], [0.8, 0, 0.6], [0, 0.4, 0.7]])
scores = np.array([[0.2, 0.9, 0, 0], [0.8, 0, 0.6, 0], [0, 0.4, 0.7, 0]])
predictions = predictor.predict(scores)
desired = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]])
desired = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 1, 0]])

np.testing.assert_array_equal(predictions, desired)

Expand Down
3 changes: 1 addition & 2 deletions tests/modules/prediction/test_argmax.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,9 @@
from autointent.modules.prediction._utils import InvalidNumClassesError, WrongClassificationError


def test_multiclass(multiclass_fit_data):
def test_multiclass(multiclass_fit_data, scores):
predictor = ArgmaxPredictor()
predictor.fit(*multiclass_fit_data)
scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]])
predictions = predictor.predict(scores)
np.testing.assert_array_equal(predictions, np.array([1, 0, 2]))

Expand Down
4 changes: 1 addition & 3 deletions tests/modules/prediction/test_jinoos.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,9 @@ def detect_oos(scores: npt.NDArray[Any], labels: npt.NDArray[Any], thresh: float
return labels


def test_predict_returns_correct_indices(multiclass_fit_data):
def test_predict_returns_correct_indices(multiclass_fit_data, scores):
predictor = JinoosPredictor()
predictor.fit(*multiclass_fit_data)
scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]])

# inference
predictions = predictor.predict(scores)
desired = detect_oos(scores, np.array([1, 0, 2]), predictor.thresh)
Expand Down
48 changes: 19 additions & 29 deletions tests/modules/prediction/test_threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,36 +5,26 @@
from autointent.modules.prediction._utils import InvalidNumClassesError


def test_multiclass(multiclass_fit_data):
predictor = ThresholdPredictor(0.5)
predictor.fit(*multiclass_fit_data)
scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]])
predictions = predictor.predict(scores)
np.testing.assert_array_equal(predictions, np.array([1, 0, 2]))


def test_multilabel(multilabel_fit_data):
predictor = ThresholdPredictor(thresh=0.5)
predictor.fit(*multilabel_fit_data)
scores = np.array([[0.2, 0.9, 0], [0.8, 0, 0.6], [0, 0.4, 0.7]])
predictions = predictor.predict(scores)
np.testing.assert_array_equal(predictions, np.array([[0, 1, 0], [1, 0, 1], [0, 0, 1]]))


def test_multiclass_list(multiclass_fit_data):
predictor = ThresholdPredictor(np.array([0.5, 0.5, 0.8]))
predictor.fit(*multiclass_fit_data)
scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]])
predictions = predictor.predict(scores)
np.testing.assert_array_equal(predictions, np.array([1, 0, -1]))


def test_multilabel_list(multilabel_fit_data):
predictor = ThresholdPredictor(np.array([0.5, 0.5, 0.8]))
predictor.fit(*multilabel_fit_data)
scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]])
@pytest.mark.parametrize(
("fit_fixture", "threshold", "expected"),
[
# Multiclass with a single scalar threshold
("multiclass_fit_data", 0.5, np.array([1, 0, 2])),
# Multilabel with a single scalar threshold
("multilabel_fit_data", 0.5, np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0]])),
# Multiclass with an array of thresholds
("multiclass_fit_data", np.array([0.5, 0.5, 0.8, 0.5]), np.array([1, 0, -1])),
# Multilabel with an array of thresholds
("multilabel_fit_data", np.array([0.5, 0.5, 0.8, 0.5]), np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]])),
],
)
def test_predict(fit_fixture, threshold, expected, request, scores):
fit_data = request.getfixturevalue(fit_fixture)

predictor = ThresholdPredictor(threshold)
predictor.fit(*fit_data)
predictions = predictor.predict(scores)
np.testing.assert_array_equal(predictions, np.array([[0, 1, 0], [1, 0, 0], [0, 0, 0]]))
np.testing.assert_array_equal(predictions, expected)


def test_fails_on_wrong_n_classes_predict(multiclass_fit_data):
Expand Down
Loading

0 comments on commit ad69779

Please sign in to comment.