From 8e2d845f4e32a0f04cf07a68cc15a7cb4cd9262b Mon Sep 17 00:00:00 2001 From: Arne Binder Date: Tue, 7 Nov 2023 23:59:51 +0100 Subject: [PATCH] make mypy happy --- src/pie_datasets/statistics.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pie_datasets/statistics.py b/src/pie_datasets/statistics.py index 3c0c8850..0a1ae0e4 100644 --- a/src/pie_datasets/statistics.py +++ b/src/pie_datasets/statistics.py @@ -1,6 +1,6 @@ import logging from collections import defaultdict -from typing import Any, Dict, List, Optional, Type, Union +from typing import Any, Callable, Dict, List, Optional, Type, Union from pytorch_ie.annotations import Span from pytorch_ie.core import Document, DocumentStatistic @@ -110,7 +110,7 @@ def __init__( f"this collector, but be aware that the results may be wrong for your own aggregation " f"functions that rely on zero values." ) - self.aggregation_functions = { + self.aggregation_functions: Dict[str, Callable[[List], Any]] = { name: func for name, func in self.aggregation_functions.items() if name not in ["mean", "std", "min"] @@ -226,7 +226,7 @@ def __init__( f"this collector, but be aware that the results may be wrong for your own aggregation " f"functions that rely on zero values." ) - self.aggregation_functions = { + self.aggregation_functions: Dict[str, Callable[[List], Any]] = { name: func for name, func in self.aggregation_functions.items() if name not in ["mean", "std", "min"]