Skip to content

Commit

Permalink
Make use_cupy an argument, not global function
Browse files Browse the repository at this point in the history
  • Loading branch information
sevagh committed Sep 4, 2021
1 parent d527421 commit aff41a9
Showing 1 changed file with 12 additions and 12 deletions.
24 changes: 12 additions & 12 deletions museval/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,12 +65,6 @@
MAX_SOURCES = 100


# allows one to disable cupy even if its available
def disable_cupy():
global use_cupy
use_cupy = False


# fft plans take up space, you might need to call this between large tracks
def clear_cupy_cache():
# cupy disable fft caching to free blocks
Expand Down Expand Up @@ -159,7 +153,8 @@ def bss_eval(reference_sources, estimated_sources,
compute_permutation=False,
filters_len=512,
framewise_filters=False,
bsseval_sources_version=False
bsseval_sources_version=False,
use_cupy=False
):
"""BSS_EVAL version 4.
Expand Down Expand Up @@ -220,6 +215,10 @@ def bss_eval(reference_sources, estimated_sources,
those to also be zeroed in the references, and hence not evaluated,
artificially boosting results. For this reason, SiSEC always uses
the `bss_eval_images` version, corresponding to ``False``.
use_cupy: bool, optional
if ``True``, uses CuPy to speed up BSS evaluation. You may need to call
``clear_cupy_cache()`` between evaluations of large signals, which fill
up the GPU memory with some FFT caches in CuPy.
Returns
-------
Expand Down Expand Up @@ -286,13 +285,13 @@ def bss_eval(reference_sources, estimated_sources,
def compute_GsfC(win=slice(0, nsampl)):
# First compute the references correlations
G, sf = _compute_reference_correlations(
reference_sources[:, win], filters_len
reference_sources[:, win], filters_len, use_cupy=use_cupy
)
# compute the interference distortion filters
C = np.zeros((nsrc, nsrc, nchan, filters_len, nchan))
for jtrue in range(nsrc):
C[jtrue] = _compute_projection_filters(
G, sf, estimated_sources[jtrue, win]
G, sf, estimated_sources[jtrue, win], use_cupy=use_cupy
)
return (G, sf, C)

Expand All @@ -304,7 +303,8 @@ def compute_Cj(win=slice(0, nsampl)):
Cj[jtrue, jest] = _compute_projection_filters(
G[jtrue, jtrue],
sf[jtrue],
estimated_sources[jest, win]
estimated_sources[jest, win],
use_cupy=use_cupy
)
return Cj

Expand Down Expand Up @@ -540,7 +540,7 @@ def _reshape_G(G):
return G


def _compute_reference_correlations(reference_sources, filters_len):
def _compute_reference_correlations(reference_sources, filters_len, use_cupy=False):
"""Compute the inner products between delayed versions of reference_sources
reference is nsrc X nsamp X nchan.
Returns
Expand Down Expand Up @@ -592,7 +592,7 @@ def _compute_reference_correlations(reference_sources, filters_len):
return G, sf


def _compute_projection_filters(G, sf, estimated_source):
def _compute_projection_filters(G, sf, estimated_source, use_cupy=False):
"""Least-squares projection of estimated source on the subspace spanned by
delayed versions of reference sources, with delays between 0 and
filters_len-1
Expand Down

0 comments on commit aff41a9

Please sign in to comment.