diff --git a/autointent/_pipeline/_pipeline.py b/autointent/_pipeline/_pipeline.py
index 200eb6f9..be30963b 100644
--- a/autointent/_pipeline/_pipeline.py
+++ b/autointent/_pipeline/_pipeline.py
@@ -135,7 +135,8 @@ def fit(self, dataset: Dataset, force_multilabel: bool = False) -> Context:
         predictions = self.predict(context.data_handler.test_utterances())
         for metric_name, metric in PREDICTION_METRICS_MULTILABEL.items():
             context.optimization_info.pipeline_metrics[metric_name] = metric(
-                context.data_handler.test_labels(), predictions,
+                context.data_handler.test_labels(),
+                predictions,
             )
         return context
 
diff --git a/autointent/metrics/scoring.py b/autointent/metrics/scoring.py
index c16f7361..727d75b9 100644
--- a/autointent/metrics/scoring.py
+++ b/autointent/metrics/scoring.py
@@ -247,7 +247,7 @@ def scoring_neg_coverage(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -
 
 def scoring_neg_ranking_loss(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
     """
-    supports multilabel.
+    Supports multilabel.
 
     Compute the average number of label pairs that are incorrectly ordered given y_score
     weighted by the size of the label set and the number of labels not in the label set.
diff --git a/tests/context/datahandler/test_data_handler.py b/tests/context/datahandler/test_data_handler.py
index e6cec8b4..0e6811ce 100644
--- a/tests/context/datahandler/test_data_handler.py
+++ b/tests/context/datahandler/test_data_handler.py
@@ -129,7 +129,7 @@ def test_sample_validation(label):
             "validation_1": mock_split(),
             "test": mock_split(),
         },
-    ]
+    ],
 )
 def test_dataset_initialization(mapping):
     dataset = Dataset.from_dict(mapping)
@@ -151,8 +151,8 @@
         {"train": mock_split(), "validation": mock_split(), "validation_0": mock_split()},
         {"train": mock_split(), "validation": mock_split(), "validation_1": mock_split()},
         {"train": mock_split(), "validation": mock_split(), "validation_0": mock_split(), "validation_1": mock_split()},
-        {"train": mock_split(), "oos": mock_split()}
-    ]
+        {"train": mock_split(), "oos": mock_split()},
+    ],
 )
 def test_dataset_validation(mapping):
     with pytest.raises(ValueError):
@@ -169,7 +169,7 @@
             "test": [{"utterance": "Hello!", "label": 0}],
         },
         {"train": [{"utterance": "Hello!"}]},
-    ]
+    ],
 )
 def test_intents_validation(mapping):
     with pytest.raises(ValueError):
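
Note on the scoring.py hunk: the docstring being touched describes label ranking loss as defined in scikit-learn. Below is a minimal illustrative sketch of what the metric computes, assuming scoring_neg_ranking_loss negates sklearn.metrics.label_ranking_loss (an assumption; the implementation itself is not part of this diff):

    # Hypothetical sketch, not the autointent implementation.
    from sklearn.metrics import label_ranking_loss

    labels = [[1, 0, 0], [0, 0, 1]]                # multilabel ground truth
    scores = [[0.75, 0.5, 1.0], [1.0, 0.2, 0.1]]   # per-label predicted scores

    # For each sample, count the (relevant, irrelevant) label pairs where the
    # irrelevant label outscores the relevant one, normalize by the number of
    # such pairs, then average over samples.
    loss = label_ranking_loss(labels, scores)      # 0.75 for this data
    print(-loss)                                   # negated: higher is better

Judging by the scoring_neg_ prefix, the negation presumably keeps this metric consistent with the other scoring_* functions, where larger values indicate better rankings.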