From 90917948d3c097249e3a7c773eb3c014c247c273 Mon Sep 17 00:00:00 2001 From: "Ido Amos Ido.Amos@ibm.com" Date: Wed, 26 Jun 2024 10:17:23 -0400 Subject: [PATCH] fixed type annotation errors and formatting --- .../classification/metrics_classification_common.py | 12 ++++++------ fuse/eval/metrics/metrics_common.py | 5 +---- 2 files changed, 7 insertions(+), 10 deletions(-) diff --git a/fuse/eval/metrics/classification/metrics_classification_common.py b/fuse/eval/metrics/classification/metrics_classification_common.py index f1220d8c..33abb01b 100644 --- a/fuse/eval/metrics/classification/metrics_classification_common.py +++ b/fuse/eval/metrics/classification/metrics_classification_common.py @@ -365,9 +365,9 @@ def __init__( def mcc_wrapper( self, - pred: Optional[str] = None, - target: Optional[str] = None, - sample_weight: Optional[str] = None, + pred: Union[List, np.ndarray], + target: Union[List, np.ndarray], + sample_weight: Optional[Union[List, np.ndarray, None]] = None, **kwargs: dict, ) -> float: """ @@ -404,9 +404,9 @@ def __init__( def balanced_acc_wrapper( self, - pred: Optional[str] = None, - target: Optional[str] = None, - sample_weight: Optional[str] = None, + pred: Union[List, np.ndarray], + target: Union[List, np.ndarray], + sample_weight: Optional[Union[List, np.ndarray, None]] = None, **kwargs: dict, ) -> float: """ diff --git a/fuse/eval/metrics/metrics_common.py b/fuse/eval/metrics/metrics_common.py index 7893e1d2..464503ab 100644 --- a/fuse/eval/metrics/metrics_common.py +++ b/fuse/eval/metrics/metrics_common.py @@ -160,10 +160,7 @@ def collect(self, batch: Dict) -> None: batch_to_collect = {} for name, key in self._keys_to_collect.items(): - try: - value = batch[key] - except: - print(self._keys_to_collect) + value = batch[key] # collect distributed if dist.is_initialized() and self._collect_distributed: