Skip to content

Commit

Permalink
refactor/bugfix n_thread definiton
Browse files Browse the repository at this point in the history
  • Loading branch information
niklasmueboe committed Aug 16, 2024
1 parent acf65df commit ac11b9e
Show file tree
Hide file tree
Showing 10 changed files with 85 additions and 72 deletions.
32 changes: 29 additions & 3 deletions sainsc/_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import functools
import os
from typing import NoReturn
from typing import Callable, NoReturn, ParamSpec, TypeVar

import numpy as np
import pandas as pd
Expand All @@ -9,15 +10,40 @@


def _get_n_cpus() -> int:
return len(os.sched_getaffinity(0))
available_cpus = len(os.sched_getaffinity(0))
return min(available_cpus, 32)


P = ParamSpec("P")
T = TypeVar("T")


def _validate_n_threads(n_threads: int | None) -> int:
if n_threads is None:
n_threads = 0
if n_threads < 0:
raise ValueError("`n_threads` must be >= 0.")
else:
return n_threads if n_threads > 0 else _get_n_cpus()


def validate_threads(func: Callable[P, T]) -> Callable[P, T]:
@functools.wraps(func)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
n_threads = kwargs.get("n_threads", 0)
assert n_threads is None or isinstance(n_threads, int)
kwargs["n_threads"] = _validate_n_threads(n_threads)
return func(*args, **kwargs)

return wrapper


def _get_coordinate_index(
x: NDArray[np.integer],
y: NDArray[np.integer],
*,
name: str | None = None,
n_threads: int = 1,
n_threads: int | None = None,
) -> pd.Index:
x_i32: NDArray[np.int32] = x.astype(np.int32, copy=False)
y_i32: NDArray[np.int32] = y.astype(np.int32, copy=False)
Expand Down
10 changes: 5 additions & 5 deletions sainsc/_utils_rust.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,8 @@ class GridCounts:
resolution : float, optional
Resolution as nm / pixel.
n_threads : int, optional
Number of threads used for reading and processing file. If `None` this will
default to the number of logical CPUs.
Number of threads used for processing. If `None` or 0 this will default to
the number of logical CPUs.
Raises
------
Expand Down Expand Up @@ -128,8 +128,8 @@ class GridCounts:
Resolution of each coordinate unit in nm. The default is 1,000 i.e. measurements
are in um.
n_threads : int, optional
Number of threads used for initializing :py:class:`sainsc.LazyKDE`.
If `None` this will default to the number of logical CPUs.
Number of threads used for processing. If `None` or 0 this will default to
the number of logical CPUs.
Returns
-------
Expand Down Expand Up @@ -265,4 +265,4 @@ class GridCounts:
"""

@n_threads.setter
def n_threads(self, n_threads: int): ...
def n_threads(self, n_threads: int | None): ...
30 changes: 11 additions & 19 deletions sainsc/io/_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from scipy.sparse import csr_matrix

from .._typealias import _PathLike
from .._utils import _get_coordinate_index, _get_n_cpus, _raise_module_load_error
from .._utils import _get_coordinate_index, _raise_module_load_error, validate_threads
from ..lazykde import LazyKDE
from ._io_utils import (
_bin_coordinates,
Expand Down Expand Up @@ -97,6 +97,7 @@ def _prepare_gem_dataframe(
return df.select(["gene", "x", "y", "count"])


@validate_threads
def read_gem_file(
filepath: _PathLike, *, sep: str = "\t", n_threads: int | None = None, **kwargs
) -> pl.DataFrame:
Expand All @@ -115,7 +116,7 @@ def read_gem_file(
sep : str, optional
Separator used in :py:func:`polars.read_csv`.
n_threads : int, optional
Number of threads used for reading and processing file. If `None` this will
Number of threads used for reading file and processing. If `None` or 0 this will
default to the number of available CPUs.
kwargs
Other keyword arguments will be passed to :py:func:`polars.read_csv`.
Expand All @@ -130,9 +131,6 @@ def read_gem_file(
If count column has an unknown name.
"""

if n_threads is None:
n_threads = _get_n_cpus()

df = pl.read_csv(
Path(filepath),
separator=sep,
Expand All @@ -150,6 +148,7 @@ def read_gem_file(
return df


@validate_threads
def read_StereoSeq(
filepath: _PathLike,
*,
Expand Down Expand Up @@ -180,7 +179,7 @@ def read_StereoSeq(
sep : str, optional
Separator used in :py:func:`polars.read_csv`.
n_threads : int, optional
Number of threads used for reading and processing file. If `None` this will
Number of threads used for reading file and processing. If `None` or 0 this will
default to the number of available CPUs.
kwargs
Other keyword arguments will be passed to :py:func:`polars.read_csv`.
Expand All @@ -189,8 +188,6 @@ def read_StereoSeq(
-------
sainsc.LazyKDE
"""
if n_threads is None:
n_threads = _get_n_cpus()

df = read_gem_file(filepath, sep=sep, n_threads=n_threads, **kwargs)
df = _prepare_gem_dataframe(df, exon_count=exon_count, gene_name=gene_name)
Expand All @@ -213,6 +210,7 @@ def read_StereoSeq(
"""Patterns for Xenium controls"""


@validate_threads
def read_Xenium(
filepath: _PathLike,
*,
Expand All @@ -236,16 +234,13 @@ def read_Xenium(
'is_gene' column, as well.
column.
n_threads : int | None, optional
Number of threads used for reading and processing file. If `None` this will
Number of threads used for reading file and processing. If `None` or 0 this will
default to the number of available CPUs.
Returns
-------
sainsc.LazyKDE
"""
if n_threads is None:
n_threads = _get_n_cpus()

filepath = Path(filepath)
columns = list(_XENIUM_COLUMNS.keys())

Expand Down Expand Up @@ -284,6 +279,7 @@ def read_Xenium(
"""Patterns for Vizgen controls"""


@validate_threads
def read_Vizgen(
filepath: _PathLike,
*,
Expand All @@ -304,15 +300,13 @@ def read_Vizgen(
List of regex patterns to filter the 'gene' column,
:py:attr:`sainsc.io.VIZGEN_CTRLS` by default.
n_threads : int | None, optional
Number of threads used for reading and processing file. If `None` this will
Number of threads used for reading file and processing. If `None` or 0 this will
default to the number of available CPUs.
Returns
-------
sainsc.LazyKDE
"""
if n_threads is None:
n_threads = _get_n_cpus()

transcripts = pl.read_csv(
Path(filepath),
Expand All @@ -331,6 +325,7 @@ def read_Vizgen(
# Binned data


@validate_threads
def read_StereoSeq_bins(
filepath: _PathLike,
bin_size: int = 50,
Expand Down Expand Up @@ -368,7 +363,7 @@ def read_StereoSeq_bins(
sep : str, optional
Separator used in :py:func:`polars.read_csv`.
n_threads : int, optional
Number of threads used for reading and processing file. If `None` this will
Number of threads used for reading file and processing. If `None` or 0 this will
default to the number of available CPUs.
kwargs
Other keyword arguments will be passed to :py:func:`polars.read_csv`.
Expand All @@ -384,9 +379,6 @@ def read_StereoSeq_bins(
ModuleNotFoundError
If `spatialdata` is set to `True` but the package is not installed.
"""
if n_threads is None:
n_threads = _get_n_cpus()

df = read_gem_file(filepath, sep=sep, n_threads=n_threads, **kwargs)
df = _prepare_gem_dataframe(df, exon_count=exon_count, gene_name=gene_name)
df = _bin_coordinates(df, bin_size)
Expand Down
2 changes: 1 addition & 1 deletion sainsc/io/_io_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _bin_coordinates(df: pl.DataFrame, bin_size: float) -> pl.DataFrame:


def _categorical_coordinate(
x: NDArray[np.int32], y: NDArray[np.int32], *, n_threads: int = 1
x: NDArray[np.int32], y: NDArray[np.int32], *, n_threads: int | None = None
) -> tuple[NDArray[np.int32], NDArray[np.int32]]:
assert len(x) == len(y)

Expand Down
38 changes: 16 additions & 22 deletions sainsc/lazykde/_LazyKDE.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing_extensions import Self

from .._typealias import _Cmap, _Csx, _CsxArray, _Local_Max, _RangeTuple2D
from .._utils import _get_n_cpus, _raise_module_load_error
from .._utils import _raise_module_load_error, _validate_n_threads, validate_threads
from .._utils_rust import (
GridCounts,
cosinef32_and_celltypei8,
Expand Down Expand Up @@ -53,31 +53,27 @@ class LazyKDE:
of data in memory.
"""

def __init__(
self,
counts: GridCounts,
*,
n_threads: int | None = None,
):
@validate_threads
def __init__(self, counts: GridCounts, *, n_threads: int | None = None):
"""
Parameters
----------
counts : sainsc.GridCounts
Gene counts.
n_threads : int, optional
Number of threads used for reading and processing file. If `None` this will
default to the number of available CPUs.
Number of threads used for processing. If `None` or 0 this will default to
the number of available CPUs.
"""
if n_threads is None:
n_threads = _get_n_cpus()

self.counts: GridCounts = counts
"""
sainsc.GridCounts : Spatial gene counts.
"""

# n_threads is validated (decorator) and will be int
# but this can currently not be reflected in the type checker
assert isinstance(n_threads, int)
self.counts.n_threads = n_threads

self._threads = n_threads

self._kernel: NDArray[np.float32] | None = None
Expand All @@ -91,6 +87,7 @@ def __init__(
self._celltypes: list[str] | None = None

@classmethod
@validate_threads
def from_dataframe(
cls, df: pl.DataFrame | pd.DataFrame, *, n_threads: int | None = None, **kwargs
) -> Self:
Expand All @@ -105,8 +102,8 @@ def from_dataframe(
----------
df : polars.DataFrame | pandas.DataFrame
n_threads : int, optional
Number of threads used for reading and processing file. If `None` this will
default to the number of available CPUs.
Number of threads used for processing. If `None` or 0 this will default to
the number of available CPUs.
kwargs
Other keyword arguments are passed to
:py:meth:`sainsc.GridCounts.from_dataframe`.
Expand Down Expand Up @@ -1119,18 +1116,15 @@ def n_threads(self) -> int:
Raises
------
TypeError
If setting with a type other than `int` or less than 0.
ValueError
If setting with an `int` less than 0.
"""
return self._threads

@n_threads.setter
def n_threads(self, n_threads: int):
if isinstance(n_threads, int) and n_threads >= 0:
self._threads = n_threads if n_threads > 0 else _get_n_cpus()
self.counts.n_threads = self._threads
else:
raise TypeError("`n_threads` must be an `int` >= 0.")
def n_threads(self, n_threads: int | None):
self._threads = _validate_n_threads(n_threads)
self.counts.n_threads = self._threads

@property
def shape(self) -> tuple[int, int]:
Expand Down
12 changes: 6 additions & 6 deletions src/coordinates.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ pub fn coordinate_as_string<'py>(
y: PyReadonlyArray1<'py, CoordInt>,
n_threads: Option<usize>,
) -> PyResult<Bound<'py, PyArray1<PyFixedString<12>>>> {
match string_coordinate_index_(x.as_array(), y.as_array(), n_threads.unwrap_or(0)) {
match string_coordinate_index_(x.as_array(), y.as_array(), n_threads) {
Ok(string_coordinates) => Ok(string_coordinates
.map(|s| (*s).into())
.into_pyarray_bound(py)),
Expand All @@ -39,7 +39,7 @@ pub fn categorical_coordinate<'py>(
Bound<'py, PyArray1<CodeInt>>,
Bound<'py, PyArray2<CoordInt>>,
)> {
match categorical_coordinate_(x.as_array(), y.as_array(), n_threads.unwrap_or(0)) {
match categorical_coordinate_(x.as_array(), y.as_array(), n_threads) {
Ok((codes, coordinates)) => Ok((
codes.into_pyarray_bound(py),
coordinates.into_pyarray_bound(py),
Expand All @@ -54,7 +54,7 @@ pub fn categorical_coordinate<'py>(
fn string_coordinate_index_<'a, X, const N: usize>(
x: ArrayView1<'a, X>,
y: ArrayView1<'a, X>,
n_threads: usize,
n_threads: Option<usize>,
) -> Result<Array1<[u8; N]>, ThreadPoolBuildError>
where
X: Display,
Expand Down Expand Up @@ -82,7 +82,7 @@ where
fn categorical_coordinate_<'a, C, X>(
x: ArrayView1<'a, X>,
y: ArrayView1<'a, X>,
n_threads: usize,
n_threads: Option<usize>,
) -> Result<(Array1<C>, Array2<X>), ThreadPoolBuildError>
where
C: PrimInt + Sync + Send + AddAssign,
Expand Down Expand Up @@ -138,7 +138,7 @@ mod tests {
let a = array![0, 1, 99_999];
let b = array![0, 20, 99_999];

let a_b: Array1<[u8; 12]> = string_coordinate_index_(a.view(), b.view(), 1).unwrap();
let a_b: Array1<[u8; 12]> = string_coordinate_index_(a.view(), b.view(), None).unwrap();

let a_b_string: Vec<[u8; 12]> = vec![
[48, 95, 48, 0, 0, 0, 0, 0, 0, 0, 0, 0],
Expand All @@ -154,7 +154,7 @@ mod tests {
let b = array![0, 1, 0, 0, 1];

let (codes, coord): (Array1<i32>, Array2<i32>) =
categorical_coordinate_(a.view(), b.view(), 1).unwrap();
categorical_coordinate_(a.view(), b.view(), None).unwrap();

let codes_test = array![0, 1, 2, 0, 3];
let coord_test = array![[0, 0], [0, 1], [1, 0], [1, 1]];
Expand Down
4 changes: 2 additions & 2 deletions src/cosine.rs
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ macro_rules! build_cos_ct_fn {
counts.shape,
log,
chunk_size,
n_threads.unwrap_or(0),
n_threads
);

match cos_ct {
Expand All @@ -76,7 +76,7 @@ fn chunk_and_calculate_cosine<C, I, F, U>(
shape: (usize, usize),
log: bool,
chunk_size: (usize, usize),
n_threads: usize,
n_threads: Option<usize>,
) -> Result<(Array2<F>, Array2<F>, Array2<U>), Box<dyn Error>>
where
C: NumCast + Copy + Sync + Send + Default,
Expand Down
Loading

0 comments on commit ac11b9e

Please sign in to comment.