Skip to content

Commit

Permalink
Merge branch 'op_replace_any' of github.com:BiomedSciAI/fuse-med-ml i…
Browse files Browse the repository at this point in the history
…nto op_replace_any
  • Loading branch information
Sivan Ravid committed Jan 28, 2024
2 parents d8935cb + 7e0dd96 commit e7b6de2
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 4 deletions.
6 changes: 6 additions & 0 deletions fuse/utils/file_io/file_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def save_dataframe(df: pd.DataFrame, filename: str, **kwargs: dict) -> None:

assert file_type in [
"csv",
"tsv",
"hd5",
"hdf5",
"hdf",
Expand All @@ -304,6 +305,8 @@ def save_dataframe(df: pd.DataFrame, filename: str, **kwargs: dict) -> None:
df.to_pickle(filename, **kwargs)
elif file_type == "csv":
df.to_csv(filename, **kwargs)
elif file_type == "tsv":
df.to_csv(filename, sep="\t" ** kwargs)
elif file_type in ["hd5", "hdf5", "hdf"]:
df.to_hdf(filename, **kwargs)
elif file_type == "xslx":
Expand All @@ -322,6 +325,7 @@ def read_dataframe(filename: str) -> pd.DataFrame:

assert file_type in [
"csv",
"tsv",
"hd5",
"hdf5",
"hdf",
Expand All @@ -334,6 +338,8 @@ def read_dataframe(filename: str) -> pd.DataFrame:
df = pd.read_pickle(filename)
elif file_type == "csv":
df = pd.read_csv(filename)
elif file_type == "tsv":
df = pd.read_csv(filename, sep="\t")
elif file_type in ["hd5", "hdf5", "hdf"]:
df = pd.read_hdf(filename)
elif file_type == "xslx":
Expand Down
8 changes: 4 additions & 4 deletions fuse/utils/ndict.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@

import copy
import types
import numpy
import torch
from numpy import ndarray
from torch import Tensor
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -273,7 +273,7 @@ def pop(self, key: str) -> Any:
del self[key]
return res

def indices(self, indices: numpy.ndarray) -> dict:
def indices(self, indices: ndarray) -> dict:
"""
Extract the specified indices from each element in the dictionary (if possible)
Expand All @@ -286,7 +286,7 @@ def indices(self, indices: numpy.ndarray) -> dict:
for key in all_keys:
try:
value = self[key]
if isinstance(value, (numpy.ndarray, torch.Tensor)):
if isinstance(value, (ndarray, Tensor)):
new_value = value[indices]
elif isinstance(value, Sequence):
new_value = [item for i, item in enumerate(value) if indices[i]]
Expand Down

0 comments on commit e7b6de2

Please sign in to comment.