diff --git a/src/pie_modules/taskmodules/metrics/common.py b/src/pie_modules/taskmodules/metrics/common.py index 7d36868e9..18c8361eb 100644 --- a/src/pie_modules/taskmodules/metrics/common.py +++ b/src/pie_modules/taskmodules/metrics/common.py @@ -3,7 +3,7 @@ from typing import Dict, Optional import torch -from torch import LongTensor +from torch import LongTensor, Tensor from torchmetrics import Metric logger = logging.getLogger(__name__) @@ -27,10 +27,12 @@ def get_counts(self, key_prefix: str = "counts_") -> Dict[Optional[str], LongTen result = {} for k, v in self.metric_state.items(): if k.startswith(key_prefix): - if not isinstance(v, LongTensor): + if not isinstance(v, Tensor): raise ValueError( f"Expected metric state for key {k} to be a LongTensor, but got {type(v)}." ) + if not isinstance(v, LongTensor): + v = v.long() key = k[len(key_prefix) :] or None result[key] = v return result