Skip to content

Commit

Permalink
Merge branch 'master' into op_replace_any
Browse files Browse the repository at this point in the history
  • Loading branch information
sivanravidos authored Jan 22, 2024
2 parents f6ee01d + 995d85a commit f40df75
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions fuse/utils/ndict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

import copy
import types
import numpy
import torch
from numpy import ndarray
from torch import Tensor
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -273,7 +273,7 @@ def pop(self, key: str) -> Any:
del self[key]
return res

def indices(self, indices: numpy.ndarray) -> dict:
def indices(self, indices: ndarray) -> dict:
"""
Extract the specified indices from each element in the dictionary (if possible)
Expand All @@ -286,7 +286,7 @@ def indices(self, indices: numpy.ndarray) -> dict:
for key in all_keys:
try:
value = self[key]
if isinstance(value, (numpy.ndarray, torch.Tensor)):
if isinstance(value, (ndarray, Tensor)):
new_value = value[indices]
elif isinstance(value, Sequence):
new_value = [item for i, item in enumerate(value) if indices[i]]
Expand Down

0 comments on commit f40df75

Please sign in to comment.