Skip to content

Commit

Permalink
tqdm.auto, Dataset_numpy, setting allow_overwrite to False by default
Browse files Browse the repository at this point in the history
  • Loading branch information
RichieHakim committed Aug 28, 2024
1 parent 687106c commit b7eafe7
Show file tree
Hide file tree
Showing 9 changed files with 100 additions and 17 deletions.
2 changes: 1 addition & 1 deletion bnpm/ca2p_preprocessing.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import scipy.signal
import matplotlib.pyplot as plt
import torch
from tqdm import tqdm
from tqdm.auto import tqdm

import time
import gc
Expand Down
5 changes: 3 additions & 2 deletions bnpm/decomposition.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,11 +762,12 @@ def fit(self, X):
X.moveaxis(self.batch_dimension, 0),
torch.arange(X.shape[self.batch_dimension]),
)
self.kwargs_dataloader.pop('batch_size', None)
kwargs_dataloader_tmp = copy.deepcopy(self.kwargs_dataloader)
kwargs_dataloader_tmp.pop('batch_size', None)
dataloader = torch.utils.data.DataLoader(
dataset,
batch_size=self.batch_size,
**self.kwargs_dataloader,
**kwargs_dataloader_tmp,
)

kwargs_tmp = copy.deepcopy(self.kwargs_CP_NN_HALS)
Expand Down
14 changes: 7 additions & 7 deletions bnpm/file_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pathlib import Path
import zipfile

from tqdm import tqdm
from tqdm.auto import tqdm

from . import path_helpers

Expand Down Expand Up @@ -92,7 +92,7 @@ def prepare_path(path, mkdir=False, exist_ok=True):
return str(path_obj)

### Custom functions for preparing paths for saving and loading files and directories
def prepare_filepath_for_saving(filepath, mkdir=False, allow_overwrite=True):
def prepare_filepath_for_saving(filepath, mkdir=False, allow_overwrite=False):
return prepare_path(filepath, mkdir=mkdir, exist_ok=allow_overwrite)
def prepare_filepath_for_loading(filepath, must_exist=True):
path = prepare_path(filepath, mkdir=False, exist_ok=must_exist)
Expand All @@ -116,7 +116,7 @@ def pickle_save(
mode='wb',
zipCompress=False,
mkdir=False,
allow_overwrite=True,
allow_overwrite=False,
library='pickle',
**kwargs_zipfile,
):
Expand Down Expand Up @@ -221,7 +221,7 @@ def pickle_load(
return pickle.load(f)


def json_save(obj, filepath, indent=4, mode='w', mkdir=False, allow_overwrite=True):
def json_save(obj, filepath, indent=4, mode='w', mkdir=False, allow_overwrite=False):
"""
Saves an object to a json file.
Uses json.dump.
Expand Down Expand Up @@ -267,7 +267,7 @@ def json_load(filepath, mode='r'):
return json.load(f)


def yaml_save(obj, filepath, indent=4, mode='w', mkdir=False, allow_overwrite=True):
def yaml_save(obj, filepath, indent=4, mode='w', mkdir=False, allow_overwrite=False):
"""
Saves an object to a yaml file.
Uses yaml.dump.
Expand Down Expand Up @@ -363,7 +363,7 @@ def matlab_save(
obj,
filepath,
mkdir=False,
allow_overwrite=True,
allow_overwrite=False,
clean_string=True,
list_to_objArray=True,
none_to_nan=True,
Expand Down Expand Up @@ -470,7 +470,7 @@ def download_file(
hash_type='MD5',
hash_hex=None,
mkdir=False,
allow_overwrite=True,
allow_overwrite=False,
write_mode='wb',
verbose=True,
chunk_size=1024,
Expand Down
2 changes: 1 addition & 1 deletion bnpm/neural_networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import torch
import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm

class RegressionRNN(torch.nn.Module):
"""
Expand Down
2 changes: 1 addition & 1 deletion bnpm/parallel_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import multiprocessing as mp
from functools import partial
import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm

class ParallelExecutionError(Exception):
"""
Expand Down
2 changes: 1 addition & 1 deletion bnpm/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,7 +1261,7 @@ def make_rsync_command(
Implemented by casperdcl here: https://github.com/tqdm/tqdm/issues/311#issuecomment-387066847
"""
try:
from tqdm import tqdm
from tqdm.auto import tqdm
except ImportError:
class _TqdmWrap(object):
# tqdm not installed - construct and return dummy/basic versions
Expand Down
2 changes: 1 addition & 1 deletion bnpm/similarity.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import scipy.optimize
from numba import njit, prange, jit
import torch
from tqdm import tqdm
from tqdm.auto import tqdm

from . import indexing, torch_helpers

Expand Down
86 changes: 84 additions & 2 deletions bnpm/torch_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import torch
from torch.utils.data import Dataset
import numpy as np
from tqdm import tqdm
from tqdm.auto import tqdm

from . import indexing
from . import misc
Expand Down Expand Up @@ -523,7 +523,89 @@ def __getitem__(
idx (int):
The index of the requested sample.
"""
return self.X[idx], idx
return self.X[idx]


class Dataset_numpy(Dataset):
"""
Creates a PyTorch dataset from a numpy array.
RH 2024
Args:
X (np.ndarray):
The data from which to create the dataset.
axis (int):
The dimension along which to sample the data.
device (str):
The device where the tensors will be stored.
dtype (torch.dtype):
The data type to use for the tensor.
Attributes:
X (np.ndarray or np.memmap):
The data from the numpy file.
n_samples (int):
The number of samples in the dataset.
Returns:
(torch.utils.data.Dataset):
A PyTorch dataset.
"""
def __init__(
self,
X: Union[np.ndarray, np.memmap],
axis: int = 0,
device: str = 'cpu',
dtype: torch.dtype = torch.float32,
):
"""
Initializes the Dataset_NumpyFile with the provided parameters.
"""
assert isinstance(X, (np.ndarray, np.memmap)), 'X must be a numpy array or memmap.'
self.X = X
self.n_samples = self.X.shape[axis]
self.is_memmap = isinstance(self.X, np.memmap)
self.axis = axis
self.device = device
self.dtype = dtype

def __len__(self) -> int:
"""
Returns the number of samples in the dataset.
Returns:
(int):
n_samples (int):
The number of samples in the dataset.
"""
return self.n_samples

def __getitem__(
self,
idx: int,
) -> Tuple[torch.Tensor, int]:
"""
Returns a single sample and its index from the dataset.
Args:
idx (int):
The index of the sample to return.
Returns:
sample (torch.Tensor):
The requested sample from the dataset.
"""
arr = np.take(self.X, idx, axis=self.axis)
if self.is_memmap:
arr = np.array(arr)
return torch.as_tensor(arr, dtype=self.dtype, device=self.device)

def close(self):
"""
Closes the numpy file.
"""
if self.is_memmap:
self.X.close()


class BatchRandomSampler(torch.utils.data.Sampler):
Expand Down
2 changes: 1 addition & 1 deletion bnpm/video.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import torchvision
import numpy as np
import cv2
from tqdm import tqdm
from tqdm.auto import tqdm


###############################################################################
Expand Down

0 comments on commit b7eafe7

Please sign in to comment.