Skip to content

Commit

Permalink
lint
Browse files Browse the repository at this point in the history
  • Loading branch information
Samoed committed Dec 20, 2024
1 parent e4985b2 commit 3dcb1c0
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 6 deletions.
3 changes: 2 additions & 1 deletion autointent/_pipeline/_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def fit(self, dataset: Dataset, force_multilabel: bool = False) -> Context:
predictions = self.predict(context.data_handler.test_utterances())
for metric_name, metric in PREDICTION_METRICS_MULTILABEL.items():
context.optimization_info.pipeline_metrics[metric_name] = metric(
context.data_handler.test_labels(), predictions,
context.data_handler.test_labels(),
predictions,
)

return context
Expand Down
2 changes: 1 addition & 1 deletion autointent/metrics/scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ def scoring_neg_coverage(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -

def scoring_neg_ranking_loss(labels: LABELS_VALUE_TYPE, scores: SCORES_VALUE_TYPE) -> float:
"""
supports multilabel.
Supports multilabel.
Compute the average number of label pairs that are incorrectly ordered given y_score
weighted by the size of the label set and the number of labels not in the label set.
Expand Down
8 changes: 4 additions & 4 deletions tests/context/datahandler/test_data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,7 @@ def test_sample_validation(label):
"validation_1": mock_split(),
"test": mock_split(),
},
]
],
)
def test_dataset_initialization(mapping):
dataset = Dataset.from_dict(mapping)
Expand All @@ -151,8 +151,8 @@ def test_dataset_initialization(mapping):
{"train": mock_split(), "validation": mock_split(), "validation_0": mock_split()},
{"train": mock_split(), "validation": mock_split(), "validation_1": mock_split()},
{"train": mock_split(), "validation": mock_split(), "validation_0": mock_split(), "validation_1": mock_split()},
{"train": mock_split(), "oos": mock_split()}
]
{"train": mock_split(), "oos": mock_split()},
],
)
def test_dataset_validation(mapping):
with pytest.raises(ValueError):
Expand All @@ -169,7 +169,7 @@ def test_dataset_validation(mapping):
"test": [{"utterance": "Hello!", "label": 0}],
},
{"train": [{"utterance": "Hello!"}]},
]
],
)
def test_intents_validation(mapping):
with pytest.raises(ValueError):
Expand Down

0 comments on commit 3dcb1c0

Please sign in to comment.