From 31e83013625aa743c82f82e26bb876399580c936 Mon Sep 17 00:00:00 2001 From: Sebastian Niehus Date: Wed, 3 Apr 2024 10:30:55 +0200 Subject: [PATCH] WIP - Prepare class for label counts --- .../use_cases/classify/classify.py | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/src/intelligence_layer/use_cases/classify/classify.py b/src/intelligence_layer/use_cases/classify/classify.py index dfef894b4..3ec5ab2f4 100644 --- a/src/intelligence_layer/use_cases/classify/classify.py +++ b/src/intelligence_layer/use_cases/classify/classify.py @@ -63,14 +63,22 @@ class SingleLabelClassifyEvaluation(BaseModel): correct: bool +class SingleLabelClassifyLabelCounts(BaseModel): + ## TODO: DocString + count_correctly_assigned: int = 0 + count_incorrectly_assigned: int = 0 + + class AggregatedSingleLabelClassifyEvaluation(BaseModel): """The aggregated evaluation of a single label classify implementation against a dataset. Attributes: percentage_correct: Percentage of answers that were considered to be correct + assignments_per_label: Mapping stating how often the respective label was assigned correctly or incorrectly """ percentage_correct: float + assignments_per_label: Mapping[str, SingleLabelClassifyLabelCounts] class SingleLabelClassifyAggregationLogic( @@ -78,12 +86,16 @@ class SingleLabelClassifyAggregationLogic( SingleLabelClassifyEvaluation, AggregatedSingleLabelClassifyEvaluation ] ): - def aggregate( + def aggregate( # TODO: Need the specific labels here self, evaluations: Iterable[SingleLabelClassifyEvaluation] ) -> AggregatedSingleLabelClassifyEvaluation: acc = MeanAccumulator() + label_counts = Mapping[str, SingleLabelClassifyLabelCounts] for evaluation in evaluations: acc.add(1.0 if evaluation.correct else 0.0) + ## TODO: Add info to label counts + + return AggregatedSingleLabelClassifyEvaluation(percentage_correct=acc.extract())