Skip to content

Commit

Permalink
Merge pull request #6528 from markotoplak/concurrent-pca
Browse files Browse the repository at this point in the history
[ENH] PCA widget runs a separate thread
  • Loading branch information
markotoplak authored Aug 11, 2023
2 parents 94b28dc + e94858f commit 9eb6490
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 46 deletions.
34 changes: 28 additions & 6 deletions Orange/projection/base.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import warnings

import copy
import inspect
import threading
Expand All @@ -9,6 +11,7 @@
from Orange.data.util import SharedComputeValue, get_unique_names
from Orange.misc.wrapper_meta import WrapperMeta
from Orange.preprocess import RemoveNaNRows
from Orange.util import dummy_callback, wrap_callback, OrangeDeprecationWarning
import Orange.preprocess

__all__ = ["LinearCombinationSql", "Projector", "Projection", "SklProjector",
Expand Down Expand Up @@ -44,17 +47,36 @@ def fit(self, X, Y=None):
raise NotImplementedError(
"Classes derived from Projector must overload method fit")

def __call__(self, data):
data = self.preprocess(data)
def __call__(self, data, progress_callback=None):
if progress_callback is None:
progress_callback = dummy_callback
progress_callback(0, "Preprocessing...")
try:
cb = wrap_callback(progress_callback, end=0.1)
data = self.preprocess(data, progress_callback=cb)
except TypeError:
data = self.preprocess(data)
warnings.warn("A keyword argument 'progress_callback' has been "
"added to the preprocess() signature. Implementing "
"the method without the argument is deprecated and "
"will result in an error in the future.",
OrangeDeprecationWarning, stacklevel=2)
self.domain = data.domain
progress_callback(0.1, "Fitting...")
clf = self.fit(data.X, data.Y)
clf.pre_domain = data.domain
clf.name = self.name
progress_callback(1)
return clf

def preprocess(self, data):
for pp in self.preprocessors:
def preprocess(self, data, progress_callback=None):
if progress_callback is None:
progress_callback = dummy_callback
n_pps = len(self.preprocessors)
for i, pp in enumerate(self.preprocessors):
progress_callback(i / n_pps)
data = pp(data)
progress_callback(1)
return data

# Projectors implemented using `fit` access the `domain` through the
Expand Down Expand Up @@ -208,8 +230,8 @@ def _get_sklparams(self, values):
raise TypeError("Wrapper does not define '__wraps__'")
return params

def preprocess(self, data):
data = super().preprocess(data)
def preprocess(self, data, progress_callback=None):
data = super().preprocess(data, progress_callback)
if any(v.is_discrete and len(v.values) > 2
for v in data.domain.attributes):
raise ValueError("Wrapped scikit-learn methods do not support "
Expand Down
81 changes: 51 additions & 30 deletions Orange/widgets/unsupervised/owpca.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from Orange.preprocess import preprocess
from Orange.projection import PCA
from Orange.widgets import widget, gui, settings
from Orange.widgets.utils.concurrent import ConcurrentWidgetMixin
from Orange.widgets.utils.slidergraph import SliderGraph
from Orange.widgets.utils.widgetpreview import WidgetPreview
from Orange.widgets.widget import Input, Output
Expand All @@ -21,7 +22,7 @@
LINE_NAMES = ["component variance", "cumulative variance"]


class OWPCA(widget.OWWidget):
class OWPCA(widget.OWWidget, ConcurrentWidgetMixin):
name = "PCA"
description = "Principal component analysis with a scree-diagram."
icon = "icons/PCA.svg"
Expand Down Expand Up @@ -57,13 +58,13 @@ class Error(widget.OWWidget.Error):

def __init__(self):
super().__init__()
self.data = None
ConcurrentWidgetMixin.__init__(self)

self.data = None
self._pca = None
self._transformed = None
self._variance_ratio = None
self._cumulative = None
self._init_projector()

# Components Selection
form = QFormLayout()
Expand Down Expand Up @@ -114,6 +115,7 @@ def __init__(self):

@Inputs.data
def set_data(self, data):
self.cancel()
self.clear_messages()
self.clear()
self.information()
Expand All @@ -138,40 +140,57 @@ def set_data(self, data):
self.clear_outputs()
return

self._init_projector()

self.data = data
self.fit()

def fit(self):
self.cancel()
self.clear()
self.Warning.trivial_components.clear()
if self.data is None:
return

data = self.data

if self.normalize:
self._pca_projector.preprocessors = \
self._pca_preprocessors + [preprocess.Normalize(center=False)]
else:
self._pca_projector.preprocessors = self._pca_preprocessors
projector = self._create_projector()

if not isinstance(data, SqlTable):
pca = self._pca_projector(data)
variance_ratio = pca.explained_variance_ratio_
cumulative = numpy.cumsum(variance_ratio)

if numpy.isfinite(cumulative[-1]):
self.components_spin.setRange(0, len(cumulative))
self._pca = pca
self._variance_ratio = variance_ratio
self._cumulative = cumulative
self._setup_plot()
else:
self.Warning.trivial_components()
self.start(self._call_projector, data, projector)

@staticmethod
def _call_projector(data: Table, projector, state):

def callback(i: float, status=""):
state.set_progress_value(i * 100)
if status:
state.set_status(status)
if state.is_interruption_requested():
raise Exception # pylint: disable=broad-exception-raised

return projector(data, progress_callback=callback)

def on_done(self, result):
pca = result
variance_ratio = pca.explained_variance_ratio_
cumulative = numpy.cumsum(variance_ratio)

if numpy.isfinite(cumulative[-1]):
self.components_spin.setRange(0, len(cumulative))
self._pca = pca
self._variance_ratio = variance_ratio
self._cumulative = cumulative
self._setup_plot()
else:
self.Warning.trivial_components()

self.commit.now()

self.commit.now()
def on_partial_result(self, result):
pass

def onDeleteWidget(self):
self.shutdown()
super().onDeleteWidget()

def clear(self):
self._pca = None
Expand All @@ -184,7 +203,7 @@ def clear_outputs(self):
self.Outputs.transformed_data.send(None)
self.Outputs.data.send(None)
self.Outputs.components.send(None)
self.Outputs.pca.send(self._pca_projector)
self.Outputs.pca.send(self._create_projector())

def _setup_plot(self):
if self._pca is None:
Expand Down Expand Up @@ -251,10 +270,13 @@ def _update_normalize(self):
if self.data is None:
self._invalidate_selection()

def _init_projector(self):
self._pca_projector = PCA(n_components=MAX_COMPONENTS, random_state=0)
self._pca_projector.component = self.ncomponents
self._pca_preprocessors = PCA.preprocessors
def _create_projector(self):
projector = PCA(n_components=MAX_COMPONENTS, random_state=0)
projector.component = self.ncomponents # for use as a Scorer
if self.normalize:
projector.preprocessors = \
PCA.preprocessors + [preprocess.Normalize(center=False)]
return projector

def _nselected_components(self):
"""Return the number of selected components."""
Expand Down Expand Up @@ -338,11 +360,10 @@ def commit(self):
numpy.hstack((self.data.metas, transformed.X)),
ids=self.data.ids)

self._pca_projector.component = self.ncomponents
self.Outputs.transformed_data.send(transformed)
self.Outputs.components.send(components)
self.Outputs.data.send(data)
self.Outputs.pca.send(self._pca_projector)
self.Outputs.pca.send(self._create_projector())

def send_report(self):
if self.data is None:
Expand Down
23 changes: 13 additions & 10 deletions Orange/widgets/unsupervised/tests/test_owpca.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

from Orange.data import Table, Domain, ContinuousVariable, TimeVariable
from Orange.preprocess import preprocess
from Orange.preprocess.preprocess import Normalize
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.tests.utils import table_dense_sparse, possible_duplicate_table
from Orange.widgets.unsupervised.owpca import OWPCA
Expand All @@ -33,7 +32,7 @@ def test_constant_data(self):
# Ignore the warning: the test checks whether the widget shows
# Warning.trivial_components when this happens
with np.errstate(invalid="ignore"):
self.send_signal(self.widget.Inputs.data, data)
self.send_signal(self.widget.Inputs.data, data, wait=5000)
self.assertTrue(self.widget.Warning.trivial_components.is_shown())
self.assertIsNone(self.get_output(self.widget.Outputs.transformed_data))
self.assertIsNone(self.get_output(self.widget.Outputs.components))
Expand All @@ -56,12 +55,12 @@ def test_limit_components(self):
X = np.random.RandomState(0).rand(101, 101)
data = Table.from_numpy(None, X)
self.widget.ncomponents = 100
self.send_signal(self.widget.Inputs.data, data)
self.send_signal(self.widget.Inputs.data, data, wait=5000)
tran = self.get_output(self.widget.Outputs.transformed_data)
self.assertEqual(len(tran.domain.attributes), 100)
self.widget.ncomponents = 101 # should not be accesible
with self.assertRaises(IndexError):
self.send_signal(self.widget.Inputs.data, data)
self.widget._setup_plot() # pylint: disable=protected-access

def test_migrate_settings_limits_components(self):
settings = dict(ncomponents=10)
Expand All @@ -84,9 +83,11 @@ def test_variance_shown(self):
self.send_signal(self.widget.Inputs.data, self.iris)
self.widget.maxp = 2
self.widget._setup_plot()
self.wait_until_finished()
var2 = self.widget.variance_covered
self.widget.ncomponents = 3
self.widget._update_selection_component_spin()
self.wait_until_finished()
var3 = self.widget.variance_covered
self.assertGreater(var3, var2)

Expand All @@ -98,10 +99,11 @@ def test_unique_domain_components(self):

def test_variance_attr(self):
self.widget.ncomponents = 2
self.send_signal(self.widget.Inputs.data, self.iris)
self.send_signal(self.widget.Inputs.data, self.iris, wait=5000)
self.wait_until_stop_blocking()
self.widget._variance_ratio = np.array([0.5, 0.25, 0.2, 0.05])
self.widget.commit.now()
self.wait_until_finished()

result = self.get_output(self.widget.Outputs.transformed_data)
pc1, pc2 = result.domain.attributes
Expand Down Expand Up @@ -162,17 +164,17 @@ def test_normalize_data(self, prepare_table):
# Enable checkbox
self.widget.controls.normalize.setChecked(True)
self.assertTrue(self.widget.controls.normalize.isChecked())
with patch.object(preprocess, "Normalize", wraps=Normalize) as normalize:
self.send_signal(self.widget.Inputs.data, data)
with patch.object(preprocess.Normalize, "__call__", wraps=lambda x: x) as normalize:
self.send_signal(self.widget.Inputs.data, data, wait=5000)
self.wait_until_stop_blocking()
self.assertTrue(self.widget.controls.normalize.isEnabled())
normalize.assert_called_once()

# Disable checkbox
self.widget.controls.normalize.setChecked(False)
self.assertFalse(self.widget.controls.normalize.isChecked())
with patch.object(preprocess, "Normalize", wraps=Normalize) as normalize:
self.send_signal(self.widget.Inputs.data, data)
with patch.object(preprocess.Normalize, "__call__", wraps=lambda x: x) as normalize:
self.send_signal(self.widget.Inputs.data, data, wait=5000)
self.wait_until_stop_blocking()
self.assertTrue(self.widget.controls.normalize.isEnabled())
normalize.assert_not_called()
Expand All @@ -185,13 +187,14 @@ def test_normalization_variance(self, prepare_table):
# Enable normalization
self.widget.controls.normalize.setChecked(True)
self.assertTrue(self.widget.normalize)
self.send_signal(self.widget.Inputs.data, data)
self.send_signal(self.widget.Inputs.data, data, wait=5000)
self.wait_until_stop_blocking()
variance_normalized = self.widget.variance_covered

# Disable normalization
self.widget.controls.normalize.setChecked(False)
self.assertFalse(self.widget.normalize)
self.wait_until_finished()
self.wait_until_stop_blocking()
variance_unnormalized = self.widget.variance_covered

Expand Down

0 comments on commit 9eb6490

Please sign in to comment.