Skip to content

Commit

Permalink
Parameter Fitter: Tests
Browse files Browse the repository at this point in the history
  • Loading branch information
VesnaT committed Oct 23, 2024
1 parent 42c9de4 commit 566bc51
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 19 deletions.
21 changes: 5 additions & 16 deletions Orange/widgets/evaluate/owparameterfitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,11 @@ def update_setters(self):
def update_grid(**settings):
self.grid_settings.update(**settings)
self.master.showGrid(
x=self.grid_settings[self.SHOW_GRID_LABEL],
y=self.grid_settings[self.SHOW_GRID_LABEL],
x=False, y=self.grid_settings[self.SHOW_GRID_LABEL],
alpha=self.grid_settings[Updater.ALPHA_LABEL] / 255)

self._setters[self.PLOT_BOX] = {self.GRID_LABEL: update_grid}

@property
def title_item(self):
return self.master.getPlotItem().titleLabel

@property
def axis_items(self):
return [value["item"] for value in
Expand All @@ -159,8 +154,7 @@ def __init__(self):
self.setMouseEnabled(False, False)
self.hideButtons()

self.showGrid(False, True)
self.showGrid(y=self.parameter_setter.DEFAULT_SHOW_GRID,
self.showGrid(x=False, y=self.parameter_setter.DEFAULT_SHOW_GRID,
alpha=self.parameter_setter.DEFAULT_ALPHA_GRID / 255)

self.tooltip_delegate = HelpEventDelegate(self.help_event)
Expand Down Expand Up @@ -284,7 +278,6 @@ class Inputs:
auto_commit = Setting(True)

class Error(OWWidget.Error):
domain_transform_err = Msg("{}")
unknown_err = Msg("{}")
not_enough_data = Msg(f"At least {N_FOLD} instances are needed.")
incompatible_learner = Msg("{}")
Expand Down Expand Up @@ -409,7 +402,6 @@ def handleNewSignals(self):
self.Warning.no_parameters.clear()
self.Error.incompatible_learner.clear()
self.Error.unknown_err.clear()
self.Error.domain_transform_err.clear()
self.clear()
if self._data is None or self._learner is None:
return
Expand Down Expand Up @@ -454,8 +446,8 @@ def _set_range_controls(self):
self.__spin_max.setMinimum(-MIN_MAX_SPIN)
self.minimum = self.initial_parameters[param.parameter_name]
if param.max is not None:
self.__spin_min.setMaximum(param.min)
self.__spin_max.setMaximum(param.min)
self.__spin_min.setMaximum(param.max)
self.__spin_max.setMaximum(param.max)
self.maximum = param.max
else:
self.__spin_min.setMaximum(MIN_MAX_SPIN)
Expand Down Expand Up @@ -484,10 +476,7 @@ def on_done(self, result: FitterResults):
self.graph.set_data(*result)

def on_exception(self, ex: Exception):
if isinstance(ex, DomainTransformationError):
self.Error.domain_transform_err(ex)
else:
self.Error.unknown_err(ex)
self.Error.unknown_err(ex)

def on_partial_result(self, _):
pass
Expand Down
188 changes: 185 additions & 3 deletions Orange/widgets/evaluate/tests/test_owparameterfitter.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,31 @@
# pylint: disable=missing-docstring,protected-access
import unittest
from unittest.mock import patch, Mock

import pyqtgraph as pg

from AnyQt.QtCore import QPointF
from AnyQt.QtGui import QFont
from AnyQt.QtWidgets import QToolTip

from Orange.classification import NaiveBayesLearner
from Orange.data import Table
from Orange.modelling import RandomForestLearner
from Orange.regression import PLSRegressionLearner
from Orange.widgets.evaluate.owparameterfitter import OWParameterFitter
from Orange.widgets.model.owrandomforest import OWRandomForest
from Orange.widgets.tests.base import WidgetTest
from Orange.widgets.tests.utils import simulate


class DummyLearner(PLSRegressionLearner):
def fitted_parameters(self):
return [
self.FittedParameter("n_components", "Foo", "foo", int, 1, None),
self.FittedParameter("n_components", "Bar", "bar", int, 1, 10),
self.FittedParameter("n_components", "Baz", "baz", int, None, 10),
self.FittedParameter("n_components", "Qux", "qux", int, None, None)
]


class TestOWParameterFitter(WidgetTest):
Expand All @@ -17,6 +36,8 @@ def setUpClass(cls):
cls._housing = Table("housing")
cls._naive_bayes = NaiveBayesLearner()
cls._pls = PLSRegressionLearner()
cls._rf = RandomForestLearner()
cls._dummy = DummyLearner()

def setUp(self):
self.widget = self.create_widget(OWParameterFitter)
Expand Down Expand Up @@ -48,21 +69,18 @@ def test_random_forest(self):

self.send_signal(self.widget.Inputs.learner, learner)
self.assertFalse(self.widget.Warning.no_parameters.is_shown())
self.assertFalse(self.widget.Error.domain_transform_err.is_shown())
self.assertFalse(self.widget.Error.unknown_err.is_shown())
self.assertFalse(self.widget.Error.not_enough_data.is_shown())
self.assertFalse(self.widget.Error.incompatible_learner.is_shown())

self.send_signal(self.widget.Inputs.data, self._heart)
self.assertFalse(self.widget.Warning.no_parameters.is_shown())
self.assertFalse(self.widget.Error.domain_transform_err.is_shown())
self.assertFalse(self.widget.Error.unknown_err.is_shown())
self.assertFalse(self.widget.Error.not_enough_data.is_shown())
self.assertFalse(self.widget.Error.incompatible_learner.is_shown())

self.send_signal(self.widget.Inputs.data, self._housing)
self.assertFalse(self.widget.Warning.no_parameters.is_shown())
self.assertFalse(self.widget.Error.domain_transform_err.is_shown())
self.assertFalse(self.widget.Error.unknown_err.is_shown())
self.assertFalse(self.widget.Error.not_enough_data.is_shown())
self.assertFalse(self.widget.Error.incompatible_learner.is_shown())
Expand All @@ -77,6 +95,31 @@ def test_plot(self):
x = self.widget.graph._FitterPlot__bar_item_cv.opts["x"]
self.assertEqual(list(x), [0.2, 1.2])

@patch.object(QToolTip, "showText")
def test_tooltip(self, show_text):
graph = self.widget.graph

self.assertFalse(self.widget.graph.help_event(Mock()))
self.assertIsNone(show_text.call_args)

self.send_signal(self.widget.Inputs.data, self._housing)
self.send_signal(self.widget.Inputs.learner, self._pls)
self.wait_until_finished()

for item in graph.items():
if isinstance(item, pg.BarGraphItem):
item.mapFromScene = Mock(return_value=QPointF(0.2, 0.2))

self.assertTrue(self.widget.graph.help_event(Mock()))
self.assertIn("Train:", show_text.call_args[0][1])
self.assertIn("CV:", show_text.call_args[0][1])

for item in graph.items():
if isinstance(item, pg.BarGraphItem):
item.mapFromScene = Mock(return_value=QPointF(0.5, 0.5))
self.assertFalse(self.widget.graph.help_event(Mock()))


def test_manual_steps(self):
self.send_signal(self.widget.Inputs.data, self._housing)
self.send_signal(self.widget.Inputs.learner, self._pls)
Expand Down Expand Up @@ -106,6 +149,145 @@ def test_steps_preview(self):
self.wait_until_finished()
self.assertEqual(self.widget.preview, "[10, 15, 20, 25]")

def test_on_parameter_changed(self):
self.send_signal(self.widget.Inputs.data, self._housing)
self.send_signal(self.widget.Inputs.learner, self._dummy)
self.wait_until_finished()

self.widget.commit.deferred = Mock()

for i in range(1, 4):
self.widget.commit.deferred.reset_mock()
simulate.combobox_activate_index(
self.widget.controls.parameter_index, i)
self.wait_until_finished()
self.widget.commit.deferred.assert_called_once()

def test_not_enough_data(self):
self.send_signal(self.widget.Inputs.data, self._housing[:5])
self.send_signal(self.widget.Inputs.learner, self._pls)
self.wait_until_finished()
self.assertTrue(self.widget.Error.not_enough_data.is_shown())
self.send_signal(self.widget.Inputs.data, None)
self.assertFalse(self.widget.Error.not_enough_data.is_shown())

def test_unknown_err(self):
self.send_signal(self.widget.Inputs.data, Table("iris")[:50])
self.send_signal(self.widget.Inputs.learner, self._rf)
self.wait_until_finished()
self.assertTrue(self.widget.Error.unknown_err.is_shown())
self.send_signal(self.widget.Inputs.data, None)
self.assertFalse(self.widget.Error.unknown_err.is_shown())

def test_fitted_parameters(self):
self.assertEqual(self.widget.fitted_parameters, [])

self.send_signal(self.widget.Inputs.data, self._housing)
self.assertEqual(self.widget.fitted_parameters, [])

self.send_signal(self.widget.Inputs.learner, self._pls)
self.assertEqual(len(self.widget.fitted_parameters), 1)
self.wait_until_finished()

self.send_signal(self.widget.Inputs.data, None)
self.assertEqual(self.widget.fitted_parameters, [])

def test_initial_parameters(self):
self.assertEqual(self.widget.initial_parameters, {})

self.send_signal(self.widget.Inputs.data, self._housing)
self.assertEqual(self.widget.initial_parameters, {})

self.send_signal(self.widget.Inputs.learner, self._pls)
self.assertEqual(len(self.widget.initial_parameters), 3)
self.wait_until_finished()

self.send_signal(self.widget.Inputs.learner, self._rf)
self.assertEqual(len(self.widget.initial_parameters), 13)
self.wait_until_finished()

self.send_signal(self.widget.Inputs.data, None)
self.assertEqual(self.widget.initial_parameters, {})

def test_saved_workflow(self):
self.send_signal(self.widget.Inputs.data, self._housing)
self.send_signal(self.widget.Inputs.learner, self._dummy)
self.wait_until_finished()
simulate.combobox_activate_index(
self.widget.controls.parameter_index, 2)
self.widget.controls.minimum.setValue(3)
self.widget.controls.maximum.setValue(6)
self.wait_until_finished()

settings = self.widget.settingsHandler.pack_data(self.widget)
widget = self.create_widget(OWParameterFitter,
stored_settings=settings)
self.send_signal(widget.Inputs.data, self._housing, widget=widget)
self.send_signal(widget.Inputs.learner, self._dummy, widget=widget)
self.wait_until_finished(widget=widget)
self.assertEqual(widget.controls.parameter_index.currentText(), "Baz")
self.assertEqual(widget.minimum, 3)
self.assertEqual(widget.maximum, 6)

def test_visual_settings(self):
graph = self.widget.graph

def test_settings():
font = QFont("Helvetica", italic=True, pointSize=20)
for item in graph.parameter_setter.axis_items:
self.assertFontEqual(item.label.font(), font)
font.setPointSize(15)
for item in graph.parameter_setter.axis_items:
self.assertFontEqual(item.style["tickFont"], font)
font.setPointSize(17)
for legend_item in graph.parameter_setter.legend_items:
self.assertFontEqual(legend_item[1].item.font(), font)
self.assertFalse(graph.getAxis("left").grid)

key, value = ("Fonts", "Font family", "Font family"), "Helvetica"
self.widget.set_visual_settings(key, value)

key, value = ("Fonts", "Axis title", "Font size"), 20
self.widget.set_visual_settings(key, value)
key, value = ("Fonts", "Axis title", "Italic"), True
self.widget.set_visual_settings(key, value)

key, value = ("Fonts", "Axis ticks", "Font size"), 15
self.widget.set_visual_settings(key, value)
key, value = ("Fonts", "Axis ticks", "Italic"), True
self.widget.set_visual_settings(key, value)

key, value = ("Fonts", "Legend", "Font size"), 17
self.widget.set_visual_settings(key, value)
key, value = ("Fonts", "Legend", "Italic"), True
self.widget.set_visual_settings(key, value)

key, value = ("Figure", "Gridlines", "Show"), False
self.widget.set_visual_settings(key, value)
key, value = ("Figure", "Gridlines", "Opacity"), 20
self.widget.set_visual_settings(key, value)

test_settings()

self.send_signal(self.widget.Inputs.learner, self._pls)
self.send_signal(self.widget.Inputs.data, self._heart[:10])
test_settings()

self.send_signal(self.widget.Inputs.data, None)
self.send_signal(self.widget.Inputs.learner, None)

self.send_signal(self.widget.Inputs.learner, self._pls)
self.send_signal(self.widget.Inputs.data, self._heart[:10])
test_settings()

def assertFontEqual(self, font1: QFont, font2: QFont):
self.assertEqual(font1.family(), font2.family())
self.assertEqual(font1.pointSize(), font2.pointSize())
self.assertEqual(font1.italic(), font2.italic())

def test_send_report(self):
self.assertEqual(1, 2)


if __name__ == "__main__":
unittest.main()

0 comments on commit 566bc51

Please sign in to comment.