diff --git a/fuse/utils/ndict.py b/fuse/utils/ndict.py index a145f711..b849ca14 100644 --- a/fuse/utils/ndict.py +++ b/fuse/utils/ndict.py @@ -22,8 +22,8 @@ import copy import types -import numpy -import torch +from numpy import ndarray +from torch import Tensor from typing import ( Any, Callable, @@ -273,7 +273,7 @@ def pop(self, key: str) -> Any: del self[key] return res - def indices(self, indices: numpy.ndarray) -> dict: + def indices(self, indices: ndarray) -> dict: """ Extract the specified indices from each element in the dictionary (if possible) @@ -286,7 +286,7 @@ def indices(self, indices: numpy.ndarray) -> dict: for key in all_keys: try: value = self[key] - if isinstance(value, (numpy.ndarray, torch.Tensor)): + if isinstance(value, (ndarray, Tensor)): new_value = value[indices] elif isinstance(value, Sequence): new_value = [item for i, item in enumerate(value) if indices[i]]