From e4370a00839e1a725424d106aa86779f7c2b2254 Mon Sep 17 00:00:00 2001 From: RichieHakim Date: Fri, 31 May 2024 13:34:15 -0400 Subject: [PATCH] Update `estimate_array_size` function to handle sparse arrays --- bnpm/misc.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/bnpm/misc.py b/bnpm/misc.py index cd6105b..9231643 100644 --- a/bnpm/misc.py +++ b/bnpm/misc.py @@ -7,6 +7,7 @@ from contextlib import contextmanager, ExitStack import numpy as np +import scipy.sparse def estimate_array_size( @@ -22,6 +23,9 @@ def estimate_array_size( RH 2021 Args: + array (np.ndarray or torch.Tensor or scipy.sparse): + array to estimate size of. If supplied, then 'numel' and + 'input_shape' are ignored numel (int): number of elements in the array. If None, then 'input_shape' is used instead @@ -48,7 +52,11 @@ def estimate_array_size( ## Either array supplied or numel or input_shape assert sum([(array is not None), (numel is not None), (input_shape is not None)]) == 1, 'Exactly one of array, numel, or input_shape must be supplied' - if array is not None: + if scipy.sparse.issparse(array): + assert hasattr(array, 'shape'), f'array must have a shape attribute' + numel = array.nnz + bitsize = array.dtype.itemsize * 8 + elif array is not None: assert hasattr(array, 'shape'), f'array must have a shape attribute' input_shape = array.shape