Skip to content

Commit

Permalink
Make NDict (hopefully) lighter to import (#342)
Browse files Browse the repository at this point in the history
sync

Co-authored-by: Sagi Polaczek <[email protected]>
  • Loading branch information
SagiPolaczek and Sagi Polaczek authored Jan 22, 2024
1 parent e91e0c5 commit 995d85a
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions fuse/utils/ndict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

import copy
import types
import numpy
import torch
from numpy import ndarray
from torch import Tensor
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -273,7 +273,7 @@ def pop(self, key: str) -> Any:
del self[key]
return res

def indices(self, indices: numpy.ndarray) -> dict:
def indices(self, indices: ndarray) -> dict:
"""
Extract the specified indices from each element in the dictionary (if possible)
Expand All @@ -286,7 +286,7 @@ def indices(self, indices: numpy.ndarray) -> dict:
for key in all_keys:
try:
value = self[key]
if isinstance(value, (numpy.ndarray, torch.Tensor)):
if isinstance(value, (ndarray, Tensor)):
new_value = value[indices]
elif isinstance(value, Sequence):
new_value = [item for i, item in enumerate(value) if indices[i]]
Expand Down

0 comments on commit 995d85a

Please sign in to comment.