From 03c860a91eb7d59c6078658091fa312739ee08be Mon Sep 17 00:00:00 2001 From: Sagi Polaczek Date: Thu, 18 Jan 2024 15:26:58 +0200 Subject: [PATCH] sync --- fuse/utils/ndict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fuse/utils/ndict.py b/fuse/utils/ndict.py index 565ce2d6..ddc861da 100644 --- a/fuse/utils/ndict.py +++ b/fuse/utils/ndict.py @@ -22,8 +22,8 @@ import copy import types -import numpy -import torch +from numpy import ndarray +from torch import Tensor from typing import ( Any, Callable, @@ -273,7 +273,7 @@ def pop(self, key: str) -> Any: del self[key] return res - def indices(self, indices: numpy.ndarray) -> dict: + def indices(self, indices: ndarray) -> dict: """ Extract the specified indices from each element in the dictionary (if possible) @@ -286,7 +286,7 @@ def indices(self, indices: numpy.ndarray) -> dict: for key in all_keys: try: value = self[key] - if isinstance(value, (numpy.ndarray, torch.Tensor)): + if isinstance(value, (ndarray, Tensor)): new_value = value[indices] elif isinstance(value, Sequence): new_value = [item for i, item in enumerate(value) if indices[i]]