diff --git a/autointent/modules/prediction/_adaptive.py b/autointent/modules/prediction/_adaptive.py index 3eba03b1..ff503376 100644 --- a/autointent/modules/prediction/_adaptive.py +++ b/autointent/modules/prediction/_adaptive.py @@ -146,7 +146,11 @@ def dump(self, path: str) -> None: """ dump_dir = Path(path) - metadata = AdaptivePredictorDumpMetadata(r=self._r, tags=self.tags, n_classes=self.n_classes) + metadata = AdaptivePredictorDumpMetadata( + r=self._r, + tags=[t.model_dump() for t in self.tags] if self.tags else None, # type: ignore[misc] + n_classes=self.n_classes, + ) with (dump_dir / self.metadata_dict_name).open("w") as file: json.dump(metadata, file, indent=4) diff --git a/autointent/modules/prediction/_threshold.py b/autointent/modules/prediction/_threshold.py index 7b6e6a2d..757c7560 100644 --- a/autointent/modules/prediction/_threshold.py +++ b/autointent/modules/prediction/_threshold.py @@ -164,9 +164,11 @@ def dump(self, path: str) -> None: ) dump_dir = Path(path) + metadata_json = self.metadata + metadata_json["tags"] = [tag.model_dump() for tag in metadata_json["tags"]] if metadata_json["tags"] else None # type: ignore[misc] with (dump_dir / self.metadata_dict_name).open("w") as file: - json.dump(self.metadata, file, indent=4) + json.dump(metadata_json, file, indent=4) def load(self, path: str) -> None: """ diff --git a/autointent/modules/prediction/_tunable.py b/autointent/modules/prediction/_tunable.py index db8e3fb8..be07c26f 100644 --- a/autointent/modules/prediction/_tunable.py +++ b/autointent/modules/prediction/_tunable.py @@ -170,9 +170,11 @@ def dump(self, path: str) -> None: ) dump_dir = Path(path) + metadata_json = self.metadata + metadata_json["tags"] = [tag.model_dump() for tag in metadata_json["tags"]] if metadata_json["tags"] else None # type: ignore[misc] with (dump_dir / self.metadata_dict_name).open("w") as file: - json.dump(self.metadata, file, indent=4) + json.dump(metadata_json, file, indent=4) def load(self, path: str) -> None: """ @@ -183,9 +185,11 @@ def load(self, path: str) -> None: dump_dir = Path(path) with (dump_dir / self.metadata_dict_name).open() as file: - metadata: TunablePredictorDumpMetadata = json.load(file) + metadata = json.load(file) - self.metadata = metadata + metadata["tags"] = [Tag(**tag) for tag in metadata["tags"]] if metadata["tags"] else None + + self.metadata: TunablePredictorDumpMetadata = metadata self.thresh = np.array(metadata["thresh"]) self.multilabel = metadata["multilabel"] self.tags = metadata["tags"] diff --git a/autointent/modules/prediction/_utils.py b/autointent/modules/prediction/_utils.py index 539f4db7..edf32be2 100644 --- a/autointent/modules/prediction/_utils.py +++ b/autointent/modules/prediction/_utils.py @@ -21,23 +21,31 @@ def apply_tags(labels: npt.NDArray[Any], scores: npt.NDArray[Any], tags: list[Ta :param tags: List of `Tag` objects, where each tag specifies mutually exclusive intent IDs. :return: Adjusted array of shape (n_samples, n_classes) with binary labels. 
""" - n_samples, _ = labels.shape - res = np.copy(labels) - - for i in range(n_samples): - sample_labels = labels[i].astype(bool) - sample_scores = scores[i] - - for tag in tags: - if any(sample_labels[idx] for idx in tag.intent_ids): - # Find the index of the class with the highest score among the tagged indices - max_score_index = max(tag.intent_ids, key=lambda idx: sample_scores[idx]) - # Set all other tagged indices to 0 in the result - for idx in tag.intent_ids: - if idx != max_score_index: - res[i, idx] = 0 - - return res + labels = labels.copy() + + for tag in tags: + intent_ids = tag.intent_ids + + labels_sub = labels[:, intent_ids] + scores_sub = scores[:, intent_ids] + + assigned = labels_sub == 1 + num_assigned = assigned.sum(axis=1) + + assigned_scores = np.where(assigned, scores_sub, -np.inf) + + samples_to_adjust = np.where(num_assigned > 1)[0] + + if samples_to_adjust.size > 0: + assigned_scores_adjust = assigned_scores[samples_to_adjust, :] + idx_max_adjust = assigned_scores_adjust.argmax(axis=1) + + labels_sub[samples_to_adjust, :] = 0 + labels_sub[samples_to_adjust, idx_max_adjust] = 1 + + labels[:, intent_ids] = labels_sub + + return labels class WrongClassificationError(Exception): diff --git a/tests/assets/configs/multiclass.yaml b/tests/assets/configs/multiclass.yaml index 385455e0..e295c172 100644 --- a/tests/assets/configs/multiclass.yaml +++ b/tests/assets/configs/multiclass.yaml @@ -29,7 +29,7 @@ metric: prediction_accuracy search_space: - module_type: threshold - thresh: [0.5, [0.5, 0.5, 0.5]] + thresh: [0.5, [0.5, 0.5, 0.5, 0.5]] - module_type: tunable - module_type: argmax - module_type: jinoos \ No newline at end of file diff --git a/tests/assets/configs/multilabel.yaml b/tests/assets/configs/multilabel.yaml index b73ef952..d28e7d68 100644 --- a/tests/assets/configs/multilabel.yaml +++ b/tests/assets/configs/multilabel.yaml @@ -25,6 +25,6 @@ metric: prediction_accuracy search_space: - module_type: threshold - thresh: [0.5, [0.5, 0.5, 0.5]] + thresh: [0.5, [0.5, 0.5, 0.5, 0.5]] - module_type: tunable - module_type: adaptive \ No newline at end of file diff --git a/tests/assets/data/clinc_subset.json b/tests/assets/data/clinc_subset.json index b3967c57..80d90428 100644 --- a/tests/assets/data/clinc_subset.json +++ b/tests/assets/data/clinc_subset.json @@ -14,6 +14,14 @@ "id": 2, "name": "alarm", "description": "User wants to set or manage an alarm." + }, + { + "id": 3, + "name": "alarm reservation", + "tags": ["alarm", "reservation"], + "regexp_full_match": [], + "regexp_partial_match": [], + "description": "User wants to set or manage an alarm second time." 
} ], "train": [ @@ -138,20 +146,125 @@ "label": 2 }, { - "utterance": "how much is an overdraft fee for bank" + "utterance": "how much is an overdraft fee for bank", + "label": 3 + }, + { + "utterance": "where is the dipstick", + "label": 3 + }, + { + "utterance": "where is the dipstick", + "label": 3 + }, + { + "utterance": "where is the dipstick", + "label": 3 + }, + { + "utterance": "where is the dipstick", + "label": 3 + }, + { + "utterance": "where is the dipstick", + "label": 3 + }, + { + "utterance": "where is the dipstick", + "label": 3 + }, + { + "utterance": "how much is 1 share of aapl" + }, + { + "utterance": "how is glue made" + }, + { + "utterance": "how much is 1 share of aapl" + }, + { + "utterance": "how is glue made" + }, + { + "utterance": "how much is 1 share of aapl" + }, + { + "utterance": "how is glue made" + }, + { + "utterance": "how much is 1 share of aapl" + }, + { + "utterance": "how is glue made" + } + ], + "test": [ + { + "utterance": "can i make a reservation for redrobin", + "label": 0 + }, + { + "utterance": "does redrobin do reservations", + "label": 0 + }, + { + "utterance": "does acero in maplewood allow reservations", + "label": 0 + }, + { + "utterance": "i think my account is blocked", + "label": 1 + }, + { + "utterance": "why is my bank account stopping all transactions from going through", + "label": 1 + }, + { + "utterance": "what would cause me to be locked out of my bank account", + "label": 1 + }, + { + "utterance": "find out the reason why am i locked out of my bank account", + "label": 1 + }, + { + "utterance": "make sure my alarm is set for three thirty in the morning", + "label": 2 + }, + { + "utterance": "please set an alarm for mid day", + "label": 2 + }, + { + "utterance": "have an alarm set for three in the morning", + "label": 2 + }, + { + "utterance": "set an alarm for me for 10:00 and another one set for 4:00", + "label": 2 }, { - "utterance": "why are exponents preformed before multiplication in the order of operations" + "utterance": "set an alarm to go to sleep and another to wake up", + "label": 2 }, { - "utterance": "what size wipers does this car take" + "utterance": "how much is an overdraft fee for bank", + "label": 3 }, { - "utterance": "where is the dipstick" + "utterance": "why are exponents preformed before multiplication in the order of operations", + "label": 3 + }, + { + "utterance": "what size wipers does this car take", + "label": 3 }, { "utterance": "how much is 1 share of aapl" }, + { + "utterance": "how is glue made" + }, { "utterance": "how is glue made" } diff --git a/tests/context/datahandler/test_stratificaiton.py b/tests/context/datahandler/test_stratificaiton.py index 0b859260..2bf4cd0b 100644 --- a/tests/context/datahandler/test_stratificaiton.py +++ b/tests/context/datahandler/test_stratificaiton.py @@ -12,8 +12,8 @@ def test_train_test_split(dataset): assert Split.TRAIN in dataset assert Split.TEST in dataset - assert dataset[Split.TRAIN].num_rows == 24 - assert dataset[Split.TEST].num_rows == 6 + assert dataset[Split.TRAIN].num_rows == 29 + assert dataset[Split.TEST].num_rows == 8 assert dataset.get_n_classes(Split.TRAIN) == dataset.get_n_classes(Split.TEST) @@ -28,6 +28,6 @@ def test_multilabel_train_test_split(dataset): assert Split.TRAIN in dataset assert Split.TEST in dataset - assert dataset[Split.TRAIN].num_rows == 24 - assert dataset[Split.TEST].num_rows == 6 + assert dataset[Split.TRAIN].num_rows == 30 + assert dataset[Split.TEST].num_rows == 7 assert dataset.get_n_classes(Split.TRAIN) == 
dataset.get_n_classes(Split.TEST) diff --git a/tests/modules/prediction/conftest.py b/tests/modules/prediction/conftest.py index 4783a0f3..00b24572 100644 --- a/tests/modules/prediction/conftest.py +++ b/tests/modules/prediction/conftest.py @@ -1,3 +1,4 @@ +import numpy as np import pytest from autointent.context.data_handler import DataHandler @@ -43,3 +44,8 @@ def multilabel_fit_data(dataset): scores = scorer.predict(data_handler.validation_utterances(1) + data_handler.oos_utterances(1)) labels = data_handler.validation_labels(1) + [[0] * data_handler.n_classes] * len(data_handler.oos_utterances(1)) return scores, labels + + +@pytest.fixture +def scores(): + return np.array([[0.05, 0.9, 0, 0.05], [0.8, 0, 0.1, 0.1], [0, 0.2, 0.7, 0.1]]) diff --git a/tests/modules/prediction/test_adaptive.py b/tests/modules/prediction/test_adaptive.py index 258f480e..b77b7a61 100644 --- a/tests/modules/prediction/test_adaptive.py +++ b/tests/modules/prediction/test_adaptive.py @@ -8,9 +8,9 @@ def test_multilabel(multilabel_fit_data): predictor = AdaptivePredictor() predictor.fit(*multilabel_fit_data) - scores = np.array([[0.2, 0.9, 0], [0.8, 0, 0.6], [0, 0.4, 0.7]]) + scores = np.array([[0.2, 0.9, 0, 0], [0.8, 0, 0.6, 0], [0, 0.4, 0.7, 0]]) predictions = predictor.predict(scores) - desired = np.array([[0, 1, 0], [1, 0, 1], [0, 1, 1]]) + desired = np.array([[0, 1, 0, 0], [1, 0, 1, 0], [0, 1, 1, 0]]) np.testing.assert_array_equal(predictions, desired) diff --git a/tests/modules/prediction/test_argmax.py b/tests/modules/prediction/test_argmax.py index 0b39e5f1..ab7bbc3e 100644 --- a/tests/modules/prediction/test_argmax.py +++ b/tests/modules/prediction/test_argmax.py @@ -5,10 +5,9 @@ from autointent.modules.prediction._utils import InvalidNumClassesError, WrongClassificationError -def test_multiclass(multiclass_fit_data): +def test_multiclass(multiclass_fit_data, scores): predictor = ArgmaxPredictor() predictor.fit(*multiclass_fit_data) - scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]]) predictions = predictor.predict(scores) np.testing.assert_array_equal(predictions, np.array([1, 0, 2])) diff --git a/tests/modules/prediction/test_jinoos.py b/tests/modules/prediction/test_jinoos.py index e2934e06..8dd2b884 100644 --- a/tests/modules/prediction/test_jinoos.py +++ b/tests/modules/prediction/test_jinoos.py @@ -18,11 +18,9 @@ def detect_oos(scores: npt.NDArray[Any], labels: npt.NDArray[Any], thresh: float return labels -def test_predict_returns_correct_indices(multiclass_fit_data): +def test_predict_returns_correct_indices(multiclass_fit_data, scores): predictor = JinoosPredictor() predictor.fit(*multiclass_fit_data) - scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]]) - # inference predictions = predictor.predict(scores) desired = detect_oos(scores, np.array([1, 0, 2]), predictor.thresh) diff --git a/tests/modules/prediction/test_threshold.py b/tests/modules/prediction/test_threshold.py index a394b321..96c2f803 100644 --- a/tests/modules/prediction/test_threshold.py +++ b/tests/modules/prediction/test_threshold.py @@ -5,36 +5,26 @@ from autointent.modules.prediction._utils import InvalidNumClassesError -def test_multiclass(multiclass_fit_data): - predictor = ThresholdPredictor(0.5) - predictor.fit(*multiclass_fit_data) - scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]]) - predictions = predictor.predict(scores) - np.testing.assert_array_equal(predictions, np.array([1, 0, 2])) - - -def test_multilabel(multilabel_fit_data): - predictor = ThresholdPredictor(thresh=0.5) 
- predictor.fit(*multilabel_fit_data) - scores = np.array([[0.2, 0.9, 0], [0.8, 0, 0.6], [0, 0.4, 0.7]]) - predictions = predictor.predict(scores) - np.testing.assert_array_equal(predictions, np.array([[0, 1, 0], [1, 0, 1], [0, 0, 1]])) - - -def test_multiclass_list(multiclass_fit_data): - predictor = ThresholdPredictor(np.array([0.5, 0.5, 0.8])) - predictor.fit(*multiclass_fit_data) - scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]]) - predictions = predictor.predict(scores) - np.testing.assert_array_equal(predictions, np.array([1, 0, -1])) - - -def test_multilabel_list(multilabel_fit_data): - predictor = ThresholdPredictor(np.array([0.5, 0.5, 0.8])) - predictor.fit(*multilabel_fit_data) - scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]]) +@pytest.mark.parametrize( + ("fit_fixture", "threshold", "expected"), + [ + # Multiclass with a single scalar threshold + ("multiclass_fit_data", 0.5, np.array([1, 0, 2])), + # Multilabel with a single scalar threshold + ("multilabel_fit_data", 0.5, np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0]])), + # Multiclass with an array of thresholds + ("multiclass_fit_data", np.array([0.5, 0.5, 0.8, 0.5]), np.array([1, 0, -1])), + # Multilabel with an array of thresholds + ("multilabel_fit_data", np.array([0.5, 0.5, 0.8, 0.5]), np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0]])), + ], +) +def test_predict(fit_fixture, threshold, expected, request, scores): + fit_data = request.getfixturevalue(fit_fixture) + + predictor = ThresholdPredictor(threshold) + predictor.fit(*fit_data) predictions = predictor.predict(scores) - np.testing.assert_array_equal(predictions, np.array([[0, 1, 0], [1, 0, 0], [0, 0, 0]])) + np.testing.assert_array_equal(predictions, expected) def test_fails_on_wrong_n_classes_predict(multiclass_fit_data): diff --git a/tests/modules/prediction/test_tunable.py b/tests/modules/prediction/test_tunable.py index 03308695..c5dedf6f 100644 --- a/tests/modules/prediction/test_tunable.py +++ b/tests/modules/prediction/test_tunable.py @@ -5,23 +5,28 @@ from autointent.modules.prediction._utils import InvalidNumClassesError -def test_multiclass(multiclass_fit_data): - predictor = TunablePredictor() - predictor.fit(*multiclass_fit_data) - scores = np.array([[0.1, 0.9, 0], [0.8, 0, 0.2], [0, 0.3, 0.7]]) - - predictions = predictor.predict(scores) - desired = np.array([1, 0, 2]) - - np.testing.assert_array_equal(predictions, desired) - +@pytest.mark.parametrize( + ("fixture_name", "scores", "desired"), + [ + ( + "multiclass_fit_data", + np.array([[0.1, 0.9, 0, 0.5], [0.8, 0, 0.2, 0.5], [0, 0.3, 0.7, 0.5]]), + np.array([1, 0, 2]), + ), + ( + "multilabel_fit_data", + np.array([[0.1, 0.9, 0, 0.1], [0.8, 0, 0.1, 0.1], [0, 0.2, 0.7, 0.1]]), + np.array([[0, 1, 0, 0], [1, 0, 0, 0], [0, 0, 1, 0]]), + ), + ], +) +def test_predict_scenarios(request, fixture_name, scores, desired): + # Dynamically obtain fixture data + fit_data = request.getfixturevalue(fixture_name) -def test_multilabel(multilabel_fit_data): predictor = TunablePredictor() - predictor.fit(*multilabel_fit_data) - scores = np.array([[0.2, 0.9, 0], [0.8, 0, 0.6], [0, 0.4, 0.7]]) + predictor.fit(*fit_data) predictions = predictor.predict(scores) - desired = np.array([[0, 1, 0], [0, 0, 1], [0, 0, 1]]) np.testing.assert_array_equal(predictions, desired) diff --git a/tests/modules/prediction/test_utils.py b/tests/modules/prediction/test_utils.py new file mode 100644 index 00000000..50971c1b --- /dev/null +++ b/tests/modules/prediction/test_utils.py @@ -0,0 +1,88 @@ +import 
numpy as np +import pytest + +from autointent.context.data_handler import Tag +from autointent.modules.prediction._utils import apply_tags + + +def sample_data(): + labels = np.array([[1, 1, 0], [0, 1, 1]]) + scores = np.array([[0.9, 0.8, 0.1], [0.1, 0.8, 0.7]]) + tags = [Tag(name="mutual_exclusive", intent_ids=[0, 1])] + expected_labels = np.array([[1, 0, 0], [0, 1, 1]]) + return labels, scores, tags, expected_labels + + +def no_conflict_data(): + labels = np.array([[1, 0, 0], [0, 1, 0]]) + scores = np.array([[0.9, 0.2, 0.1], [0.1, 0.8, 0.3]]) + tags = [Tag(name="mutual_exclusive", intent_ids=[0, 1])] + expected_labels = labels.copy() + return labels, scores, tags, expected_labels + + +def multiple_tags_data(): + labels = np.array([[1, 1, 1, 0], [1, 0, 1, 1]]) + scores = np.array([[0.9, 0.8, 0.7, 0.6], [0.95, 0.85, 0.9, 0.8]]) + tags = [Tag(name="tag1", intent_ids=[0, 1]), Tag(name="tag2", intent_ids=[2, 3])] + expected_labels = np.array([[1, 0, 1, 0], [1, 0, 1, 0]]) + return labels, scores, tags, expected_labels + + +# Parametrized test function +@pytest.mark.parametrize( + ("labels", "scores", "tags", "expected_labels"), + [ + # Test case: No tags provided (no conflict) + ( + np.array([[1, 0, 0], [0, 1, 1]]), + np.array([[0.9, 0.2, 0.1], [0.1, 0.8, 0.7]]), + [], + np.array([[1, 0, 0], [0, 1, 1]]), + ), + # Test case: Single tag, no conflict + no_conflict_data(), + # Test case: Single tag with conflict + sample_data(), + # Test case: Multiple tags with conflicts + multiple_tags_data(), + # Test case: All intents conflict + ( + np.array([[1, 1, 1], [1, 1, 1]]), + np.array([[0.9, 0.85, 0.8], [0.95, 0.9, 0.88]]), + [Tag(name="all_conflict", intent_ids=[0, 1, 2])], + np.array([[1, 0, 0], [1, 0, 0]]), + ), + # Test case: No assigned intents + ( + np.array([[0, 0, 0], [0, 0, 0]]), + np.array([[0.1, 0.2, 0.3], [0.05, 0.1, 0.15]]), + [Tag(name="tag1", intent_ids=[0, 1])], + np.array([[0, 0, 0], [0, 0, 0]]), + ), + # Test case: Partial conflict + ( + np.array([[1, 1, 0], [0, 1, 1]]), + np.array([[0.7, 0.9, 0.1], [0.1, 0.8, 0.85]]), + [Tag(name="tag1", intent_ids=[0, 1]), Tag(name="tag2", intent_ids=[1, 2])], + np.array([[0, 1, 0], [0, 0, 1]]), + ), + # Test case: Overlapping tags + ( + np.array([[1, 1, 1], [1, 1, 0]]), + np.array([[0.9, 0.85, 0.8], [0.95, 0.9, 0.88]]), + [Tag(name="tag1", intent_ids=[0, 1]), Tag(name="tag2", intent_ids=[1, 2])], + np.array([[1, 0, 1], [1, 0, 0]]), + ), + # Test case: Conflict with same scores + ( + np.array([[1, 1], [1, 1]]), + np.array([[0.8, 0.8], [0.9, 0.9]]), + [Tag(name="tag1", intent_ids=[0, 1])], + np.array([[1, 0], [1, 0]]), + ), + ], +) +def test_apply_tags(labels, scores, tags, expected_labels): + adjusted_labels = apply_tags(labels, scores, tags) + np.testing.assert_array_equal(adjusted_labels, expected_labels) diff --git a/tests/modules/scoring/test_description.py b/tests/modules/scoring/test_description.py index be6b2773..40210dbb 100644 --- a/tests/modules/scoring/test_description.py +++ b/tests/modules/scoring/test_description.py @@ -9,8 +9,8 @@ @pytest.mark.parametrize( ("expected_prediction", "multilabel"), [ - ([[0.9, 0.9, 0.9], [0.9, 0.9, 0.9]], True), - ([[0.2, 0.3, 0.2], [0.2, 0.3, 0.2]], False), + ([[0.9, 0.9, 0.9, 0.9], [0.9, 0.9, 0.9, 0.9]], True), + ([[0.2, 0.3, 0.2, 0.2], [0.2, 0.3, 0.2, 0.2]], False), ], ) def test_description_scorer(dataset, expected_prediction, multilabel): diff --git a/tests/modules/scoring/test_dnnc.py b/tests/modules/scoring/test_dnnc.py index be0a1361..a6b50065 100644 --- a/tests/modules/scoring/test_dnnc.py +++ 
b/tests/modules/scoring/test_dnnc.py @@ -30,7 +30,7 @@ def test_base_dnnc(dataset, train_head, pred_score): "can you tell me why is my bank account frozen", ] predictions = scorer.predict(test_data) - np.testing.assert_almost_equal(np.array([[0.0, pred_score, 0.0]] * len(test_data)), predictions, decimal=0.5) + np.testing.assert_almost_equal(np.array([[0.0, pred_score, 0.0, 0.0]] * len(test_data)), predictions, decimal=0.5) predictions, metadata = scorer.predict_with_metadata(test_data) assert len(predictions) == len(test_data) diff --git a/tests/modules/scoring/test_knn.py b/tests/modules/scoring/test_knn.py index 3617d287..56a7139a 100644 --- a/tests/modules/scoring/test_knn.py +++ b/tests/modules/scoring/test_knn.py @@ -25,7 +25,16 @@ def test_base_knn(dataset): scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) predictions = scorer.predict(test_data) assert ( - predictions == np.array([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]) + predictions + == np.array( + [ + [0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + ] + ) ).all() predictions, metadata = scorer.predict_with_metadata(test_data) diff --git a/tests/modules/scoring/test_linear.py b/tests/modules/scoring/test_linear.py index 26070019..74ffe41b 100644 --- a/tests/modules/scoring/test_linear.py +++ b/tests/modules/scoring/test_linear.py @@ -24,12 +24,32 @@ def test_base_linear(dataset): np.testing.assert_almost_equal( np.array( [ - [0.33332719, 0.33334283, 0.33332997], - [0.33332507, 0.33334446, 0.33333046], - [0.33332806, 0.33334067, 0.33333127], - [0.33332788, 0.33334159, 0.33333053], - [0.33332806, 0.33334418, 0.33332775], - ], + [ + 0.01828613, + 0.93842264, + 0.02633502, + 0.01695622, + ], + [0.02662749, 0.89566195, 0.05008801, 0.02762255], + [ + 0.08131153, + 0.79191015, + 0.07896874, + 0.04780958, + ], + [ + 0.08382678, + 0.77043132, + 0.0826499, + 0.063092, + ], + [ + 0.01482186, + 0.9699848, + 0.00757169, + 0.00762165, + ], + ] ), predictions, decimal=2, diff --git a/tests/modules/scoring/test_mlknn.py b/tests/modules/scoring/test_mlknn.py index 8b16e991..97b064bb 100644 --- a/tests/modules/scoring/test_mlknn.py +++ b/tests/modules/scoring/test_mlknn.py @@ -18,7 +18,7 @@ def test_base_mlknn(dataset): }, { "utterance": "i am nost sure why my account is blocked", - "label": [0, 2], + "label": [0, 3], }, ], ) @@ -37,7 +37,23 @@ def test_base_mlknn(dataset): ] predictions = scorer.predict_labels(test_data) - assert (predictions == np.array([[0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0], [0, 1, 0]])).all() + assert ( + predictions + == np.array( + [ + [ + 0, + 1, + 0, + 0, + ], + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0], + [0, 1, 0, 0], + ] + ) + ).all() predictions, metadata = scorer.predict_with_metadata(test_data) assert len(predictions) == len(test_data) diff --git a/tests/modules/scoring/test_rerank_scorer.py b/tests/modules/scoring/test_rerank_scorer.py index 5a0e8351..28e93569 100644 --- a/tests/modules/scoring/test_rerank_scorer.py +++ b/tests/modules/scoring/test_rerank_scorer.py @@ -31,7 +31,16 @@ def test_base_rerank_scorer(dataset): scorer.fit(data_handler.train_utterances(0), data_handler.train_labels(0)) predictions = scorer.predict(test_data) assert ( - predictions == np.array([[0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0]]) + predictions + == np.array( + [ + [0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + [0.0, 1.0, 
0.0, 0.0], + [0.0, 1.0, 0.0, 0.0], + ] + ) ).all() predictions, metadata = scorer.predict_with_metadata(test_data) diff --git a/tests/nodes/test_predicton.py b/tests/nodes/test_predicton.py index 357f07ba..76484fde 100644 --- a/tests/nodes/test_predicton.py +++ b/tests/nodes/test_predicton.py @@ -41,7 +41,9 @@ def test_prediction_multiclass(scoring_optimizer_multiclass): load_path=trial.module_dump_dir, ) node = InferenceNode.from_config(config) - node.module.predict(np.array([[0.27486506, 0.31681463, 0.37459106], [0.2769358, 0.31536099, 0.37366978]])) + node.module.predict( + np.array([[0.27486506, 0.31681463, 0.37459106, 0.532], [0.2769358, 0.31536099, 0.37366978, 0.532]]) + ) node.module.clear_cache() gc.collect() torch.cuda.empty_cache() @@ -73,7 +75,11 @@ def test_prediction_multilabel(scoring_optimizer_multilabel): load_path=trial.module_dump_dir, ) node = InferenceNode.from_config(config) - node.module.predict(np.array([[0.27486506, 0.31681463, 0.37459106], [0.2769358, 0.31536099, 0.37366978]])) + node.module.predict( + np.array( + [[0.27486506, 0.31681463, 0.37459106, 0.37459106], [0.2769358, 0.31536099, 0.37366978, 0.37459106]] + ) + ) node.module.clear_cache() gc.collect() torch.cuda.empty_cache()
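
Below is a minimal, self-contained sketch of the vectorized tag-resolution logic that this patch introduces in autointent/modules/prediction/_utils.py (replacing the previous per-sample Python loop). It is illustrative only: Tag is stubbed here as a plain dataclass so the snippet runs standalone, whereas the real module imports Tag from autointent.context.data_handler; the demo values at the bottom mirror the sample_data() case from the new tests/modules/prediction/test_utils.py.

from dataclasses import dataclass

import numpy as np
import numpy.typing as npt


@dataclass
class Tag:
    # Stand-in for autointent.context.data_handler.Tag (assumption for this sketch).
    name: str
    intent_ids: list[int]


def apply_tags(labels: npt.NDArray, scores: npt.NDArray, tags: list[Tag]) -> npt.NDArray:
    """For each tag, keep only the highest-scoring assigned intent among its mutually exclusive intent_ids."""
    labels = labels.copy()

    for tag in tags:
        intent_ids = tag.intent_ids

        # Advanced indexing returns copies, so the adjusted columns are written back at the end.
        labels_sub = labels[:, intent_ids]
        scores_sub = scores[:, intent_ids]

        # Rows where more than one mutually exclusive intent is assigned.
        assigned = labels_sub == 1
        conflicting = np.where(assigned.sum(axis=1) > 1)[0]

        if conflicting.size > 0:
            # Among the assigned intents, pick the one with the highest score; zero out the rest.
            assigned_scores = np.where(assigned, scores_sub, -np.inf)
            winners = assigned_scores[conflicting].argmax(axis=1)
            labels_sub[conflicting, :] = 0
            labels_sub[conflicting, winners] = 1

        labels[:, intent_ids] = labels_sub

    return labels


if __name__ == "__main__":
    labels = np.array([[1, 1, 0], [0, 1, 1]])
    scores = np.array([[0.9, 0.8, 0.1], [0.1, 0.8, 0.7]])
    tags = [Tag(name="mutual_exclusive", intent_ids=[0, 1])]
    # Only the first row conflicts within {0, 1}; intent 0 wins (0.9 > 0.8).
    print(apply_tags(labels, scores, tags))  # [[1 0 0] [0 1 1]]

The key design point is that conflicts are resolved per tag with whole-array operations (argmax over the assigned scores of the conflicting rows) instead of iterating over samples, which is what allows the new parametrized tests to cover overlapping and all-conflict cases cheaply.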