Skip to content

Commit

Permalink
Merge pull request #49 from ArneBinder/fix_MetricWithArbitraryCounts_…
Browse files Browse the repository at this point in the history
…get_counts

convert values to `LongTensor` in `MetricWithArbitraryCounts.get_counts()`
  • Loading branch information
ArneBinder authored Jan 23, 2024
2 parents b2a09af + a79f71c commit d331d7d
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/pie_modules/taskmodules/metrics/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Dict, Optional

import torch
from torch import LongTensor
from torch import LongTensor, Tensor
from torchmetrics import Metric

logger = logging.getLogger(__name__)
Expand All @@ -27,10 +27,12 @@ def get_counts(self, key_prefix: str = "counts_") -> Dict[Optional[str], LongTen
result = {}
for k, v in self.metric_state.items():
if k.startswith(key_prefix):
if not isinstance(v, LongTensor):
if not isinstance(v, Tensor):
raise ValueError(
f"Expected metric state for key {k} to be a LongTensor, but got {type(v)}."
)
if not isinstance(v, LongTensor):
v = v.long()
key = k[len(key_prefix) :] or None
result[key] = v
return result

0 comments on commit d331d7d

Please sign in to comment.