diff --git a/orangecontrib/spectroscopy/widgets/owspectra.py b/orangecontrib/spectroscopy/widgets/owspectra.py index 06a877259..4558d0bc0 100644 --- a/orangecontrib/spectroscopy/widgets/owspectra.py +++ b/orangecontrib/spectroscopy/widgets/owspectra.py @@ -2,12 +2,16 @@ import sys from collections import defaultdict import random +import time import warnings from xml.sax.saxutils import escape try: import dask import dask.array as da + import dask.distributed + dask_client = dask.distributed.Client(processes=False, n_workers=2, + threads_per_worker=4, dashboard_address=None) except ImportError: dask = None @@ -304,6 +308,8 @@ def compute_averages(data: Orange.data.Table, color_var, subset_indices, def progress_interrupt(i: float): if state.is_interruption_requested(): + if future: + future.cancel() raise InterruptException def _split_by_color_value(data, color_var): @@ -324,7 +330,12 @@ def _split_by_color_value(data, color_var): results = [] + future = None + + is_dask = dask and isinstance(data.X, dask.array.Array) + dsplit = _split_by_color_value(data, color_var) + compute_dask = [] for colorv, indices in dsplit.items(): for part in [None, "subset", "selection"]: progress_interrupt(0) @@ -335,11 +346,11 @@ def _split_by_color_value(data, color_var): elif part == "subset": part_selection = indices & subset_indices if np.any(part_selection): - if dask and isinstance(data.X, da.Array): + if is_dask: subset = data.X[part_selection] - std = da.nanstd(subset, axis=0) - mean = da.nanmean(subset, axis=0) - std, mean = dask.compute(std, mean) + compute_dask.extend([da.nanstd(subset, axis=0), + da.nanmean(subset, axis=0)]) + std, mean = None, None else: std = apply_columns_numpy(data.X, lambda x: bottleneck.nanstd(x, axis=0), @@ -349,7 +360,19 @@ def _split_by_color_value(data, color_var): lambda x: bottleneck.nanmean(x, axis=0), part_selection, callback=progress_interrupt) - results.append((colorv, part, mean, std, part_selection)) + results.append([colorv, part, mean, std, part_selection]) + + if is_dask: + future = dask_client.compute(dask.array.vstack(compute_dask)) + while not future.done(): + progress_interrupt(0) + time.sleep(0.1) + if future.cancelled(): + return + computed = future.result() + for i, lr in enumerate(results): + lr[2] = computed[i*2] + lr[3] = computed[i*2+1] progress_interrupt(0) return results @@ -412,13 +435,28 @@ def show(self): @staticmethod def compute_curves(x, ys, sampled_indices, state: TaskState): + is_dask = dask and isinstance(ys, dask.array.Array) + def progress_interrupt(i: float): if state.is_interruption_requested(): + if future: + future.cancel() raise InterruptException + future = None + progress_interrupt(0) - ys = np.asarray(ys[sampled_indices]) + ys = ys[sampled_indices] + if is_dask: + future = dask_client.compute(ys) + while not future.done(): + progress_interrupt(0) + time.sleep(0.1) + if future.cancelled(): + return + ys = future.result() ys[np.isinf(ys)] = np.nan # remove infs that could ruin display + progress_interrupt(0) return x, ys, sampled_indices