From 995d85a2c411148d95d323140fa2966d2caea8b3 Mon Sep 17 00:00:00 2001 From: Sagi Polaczek <56922146+SagiPolaczek@users.noreply.github.com> Date: Mon, 22 Jan 2024 11:16:14 +0200 Subject: [PATCH 1/2] Make `NDict` (hopefully) lighter to import (#342) sync Co-authored-by: Sagi Polaczek --- fuse/utils/ndict.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/fuse/utils/ndict.py b/fuse/utils/ndict.py index a145f711..b849ca14 100644 --- a/fuse/utils/ndict.py +++ b/fuse/utils/ndict.py @@ -22,8 +22,8 @@ import copy import types -import numpy -import torch +from numpy import ndarray +from torch import Tensor from typing import ( Any, Callable, @@ -273,7 +273,7 @@ def pop(self, key: str) -> Any: del self[key] return res - def indices(self, indices: numpy.ndarray) -> dict: + def indices(self, indices: ndarray) -> dict: """ Extract the specified indices from each element in the dictionary (if possible) @@ -286,7 +286,7 @@ def indices(self, indices: numpy.ndarray) -> dict: for key in all_keys: try: value = self[key] - if isinstance(value, (numpy.ndarray, torch.Tensor)): + if isinstance(value, (ndarray, Tensor)): new_value = value[indices] elif isinstance(value, Sequence): new_value = [item for i, item in enumerate(value) if indices[i]] From a58d86983f217f3f23200a8cd032ca79b89901fe Mon Sep 17 00:00:00 2001 From: Sagi Polaczek <56922146+SagiPolaczek@users.noreply.github.com> Date: Sun, 28 Jan 2024 10:07:40 +0200 Subject: [PATCH 2/2] Support TSV files in `file_io` (#341) * done * done * support save tsv --------- Co-authored-by: Sagi Polaczek --- fuse/utils/file_io/file_io.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/fuse/utils/file_io/file_io.py b/fuse/utils/file_io/file_io.py index 99f349c4..8e28c531 100644 --- a/fuse/utils/file_io/file_io.py +++ b/fuse/utils/file_io/file_io.py @@ -291,6 +291,7 @@ def save_dataframe(df: pd.DataFrame, filename: str, **kwargs: dict) -> None: assert file_type in [ "csv", + "tsv", "hd5", "hdf5", "hdf", @@ 
-304,6 +305,8 @@ def save_dataframe(df: pd.DataFrame, filename: str, **kwargs: dict) -> None: df.to_pickle(filename, **kwargs) elif file_type == "csv": df.to_csv(filename, **kwargs) + elif file_type == "tsv": + df.to_csv(filename, sep="\t", **kwargs) elif file_type in ["hd5", "hdf5", "hdf"]: df.to_hdf(filename, **kwargs) elif file_type == "xslx": @@ -322,6 +325,7 @@ def read_dataframe(filename: str) -> pd.DataFrame: assert file_type in [ "csv", + "tsv", "hd5", "hdf5", "hdf", @@ -334,6 +338,8 @@ def read_dataframe(filename: str) -> pd.DataFrame: df = pd.read_pickle(filename) elif file_type == "csv": df = pd.read_csv(filename) + elif file_type == "tsv": + df = pd.read_csv(filename, sep="\t") elif file_type in ["hd5", "hdf5", "hdf"]: df = pd.read_hdf(filename) elif file_type == "xslx":