From 1cee531ffaca496098bdde3e23f60bf15afd5aaf Mon Sep 17 00:00:00 2001 From: janezd Date: Mon, 11 Jul 2022 17:02:19 +0200 Subject: [PATCH] Split: Refactor for discrete values, add tests, rename --- .../prototypes/widgets/icons/Split.svg | 184 ---------------- .../widgets/icons/TextToColumns.svg | 33 +++ .../{owsplit.py => owtexttocolumns.py} | 109 ++++++---- .../widgets/tests/test_owtexttocolumns.py | 197 ++++++++++++++++++ 4 files changed, 301 insertions(+), 222 deletions(-) delete mode 100644 orangecontrib/prototypes/widgets/icons/Split.svg create mode 100644 orangecontrib/prototypes/widgets/icons/TextToColumns.svg rename orangecontrib/prototypes/widgets/{owsplit.py => owtexttocolumns.py} (56%) create mode 100644 orangecontrib/prototypes/widgets/tests/test_owtexttocolumns.py diff --git a/orangecontrib/prototypes/widgets/icons/Split.svg b/orangecontrib/prototypes/widgets/icons/Split.svg deleted file mode 100644 index 5594fb76..00000000 --- a/orangecontrib/prototypes/widgets/icons/Split.svg +++ /dev/null @@ -1,184 +0,0 @@ - - - - - - - - - - - - - - - - - - - - - - - - - diff --git a/orangecontrib/prototypes/widgets/icons/TextToColumns.svg b/orangecontrib/prototypes/widgets/icons/TextToColumns.svg new file mode 100644 index 00000000..bc42545f --- /dev/null +++ b/orangecontrib/prototypes/widgets/icons/TextToColumns.svg @@ -0,0 +1,33 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/orangecontrib/prototypes/widgets/owsplit.py b/orangecontrib/prototypes/widgets/owtexttocolumns.py similarity index 56% rename from orangecontrib/prototypes/widgets/owsplit.py rename to orangecontrib/prototypes/widgets/owtexttocolumns.py index ae8220b4..4097e506 100644 --- a/orangecontrib/prototypes/widgets/owsplit.py +++ b/orangecontrib/prototypes/widgets/owtexttocolumns.py @@ -1,3 +1,5 @@ +from functools import partial + import numpy as np from AnyQt.QtCore import Qt @@ -13,63 +15,88 @@ from orangewidget.settings import Setting +def get_substrings(values, delimiter): + return sorted({ss.strip() for s in values for ss in s.split(delimiter)} + - {""}) + + +# TODO: Replace with Table.get_column after merging and releasing +# https://github.com/biolab/orange3/pull/6058 +def get_column(data, attr): + if attr not in data.domain and attr.compute_value is not None: + return attr.compute_value(data) + return data.get_column_view(attr)[0] + + class SplitColumn: def __init__(self, data, attr, delimiter): self.attr = attr self.delimiter = delimiter + column = set(get_column(data, self.attr)) + self.new_values = tuple(get_substrings(column, self.delimiter)) - column = self.get_string_values(data, self.attr) - values = [s.split(self.delimiter) for s in column] - self.new_values = tuple(sorted({val if val else "?" for vals in - values for val in vals})) + def __call__(self, data): + column = get_column(data, self.attr) + values = [{ss.strip() for ss in s.split(self.delimiter)} + for s in column] + return {v: np.array([i for i, xs in enumerate(values) if v in xs]) + for v in self.new_values} def __eq__(self, other): - return self.attr == other.attr and self.delimiter == \ - other.delimiter and self.new_values == other.new_values + return self.attr == other.attr \ + and self.delimiter == other.delimiter \ + and self.new_values == other.new_values def __hash__(self): return hash((self.attr, self.delimiter, self.new_values)) - def __call__(self, data): - column = self.get_string_values(data, self.attr) - values = [set(s.split(self.delimiter)) for s in column] - shared_data = {v: [i for i, xs in enumerate(values) if v in xs] for v - in self.new_values} - return shared_data - - @staticmethod - def get_string_values(data, var): - # turn discrete to string variable - column = data.get_column_view(var)[0] - if var.is_discrete: - return [var.str_val(x) for x in column] - return column - class OneHotStrings(SharedComputeValue): - def __init__(self, fn, new_feature): super().__init__(fn) self.new_feature = new_feature - def __eq__(self, other): - return self.compute_shared == other.compute_shared \ - and self.new_feature == other.new_feature - - def __hash__(self): - return hash((self.compute_shared, self.new_feature)) - def compute(self, data, shared_data): indices = shared_data[self.new_feature] col = np.zeros(len(data)) col[indices] = 1 return col + def __eq__(self, other): + return super().__eq__(other) and self.new_feature == other.new_feature -class OWSplit(OWWidget): - name = "Split" - description = "Split string variables to create discrete." - icon = "icons/Split.svg" + def __hash__(self): + return super().__hash__() ^ hash(self.new_feature) + + +class OneHotDiscrete: + def __init__(self, variable, delimiter, value): + self.variable = variable + self.value = value + self.delimiter = delimiter + + def __call__(self, data): + column = get_column(data, self.variable).astype(float) + col = np.zeros(len(column)) + col[np.isnan(column)] = np.nan + for val_idx, value in enumerate(self.variable.values): + if self.value in value.split(self.delimiter): + col[column == val_idx] = 1 + return col + + def __eq__(self, other): + return self.variable == other.variable \ + and self.value == other.value \ + and self.delimiter == other.delimiter + + def __hash__(self): + return hash((self.variable, self.value, self.delimiter)) + + +class OWTextToColumns(OWWidget): + name = "Text to Columns" + description = "Split text or categorical variables into binary indicators" + icon = "icons/TextToColumns.svg" priority = 700 class Inputs: @@ -129,12 +156,18 @@ def apply(self): return var = self.data.domain[self.attribute] - sc = SplitColumn(self.data, var, self.delimiter) + if var.is_discrete: + values = get_substrings(var.values, self.delimiter) + computer = partial(OneHotDiscrete, var, self.delimiter) + else: + sc = SplitColumn(self.data, var, self.delimiter) + values = sc.new_values + computer = partial(OneHotStrings, sc) + names = get_unique_names(self.data.domain, values, equal_numbers=False) new_columns = tuple(DiscreteVariable( - get_unique_names(self.data.domain, v), values=("0", "1"), - compute_value=OneHotStrings(sc, v) - ) for v in sc.new_values) + name, values=("0", "1"), compute_value=computer(value) + ) for value, name in zip(values, names)) new_domain = Domain( self.data.domain.attributes + new_columns, @@ -145,5 +178,5 @@ def apply(self): if __name__ == "__main__": # pragma: no cover - WidgetPreview(OWSplit).run(Table.from_file( + WidgetPreview(OWTextToColumns).run(Table.from_file( "tests/orange-in-education.tab")) diff --git a/orangecontrib/prototypes/widgets/tests/test_owtexttocolumns.py b/orangecontrib/prototypes/widgets/tests/test_owtexttocolumns.py new file mode 100644 index 00000000..3fca2264 --- /dev/null +++ b/orangecontrib/prototypes/widgets/tests/test_owtexttocolumns.py @@ -0,0 +1,197 @@ +# pylint: disable=missing-docstring,unsubscriptable-object +import os +import unittest + +import numpy as np + +from Orange.data import Table, StringVariable, Domain, DiscreteVariable +from Orange.widgets.tests.base import WidgetTest + +from orangecontrib.prototypes.widgets.owtexttocolumns import \ + OWTextToColumns, SplitColumn, get_substrings, OneHotStrings, OneHotDiscrete + + +class TestComputation(unittest.TestCase): + def setUp(self): + domain = Domain([DiscreteVariable("x", values=("a c d", "bb d"))], None, + [StringVariable("foo"), StringVariable("bar")]) + self.data = Table.from_numpy( + domain, + np.array([1, 0, np.nan])[:, None], None, + [["a,bbb,d", "e;f o"], ["", "f o"], ["bbb,d", "e;a;o"]] + ) + + def test_get_string_values(self): + np.testing.assert_equal( + set(get_substrings({"a bc", "d,e", "", "f,a t", "t"}, " ")), + {"a", "bc", "d,e", "f,a", "t"}) + np.testing.assert_equal( + set(get_substrings({"a bc", "d,e", "", "f,a t", "t"}, ",")), + {"a bc", "d", "e", "f", "a t", "t"}) + + def test_split_column(self): + sc = SplitColumn(self.data, self.data.domain.metas[0], ",") + shared = sc(self.data) + self.assertEqual(set(sc.new_values), {"a", "bbb", "d"}) + self.assertEqual(set(shared), set(sc.new_values)) + np.testing.assert_equal(shared["a"], [0]) + np.testing.assert_equal(shared["bbb"], [0, 2]) + np.testing.assert_equal(shared["d"], [0, 2]) + + sc = SplitColumn(self.data, self.data.domain.metas[1], ";") + shared = sc(self.data) + self.assertEqual(set(sc.new_values), {"a", "e", "f o", "o"}) + self.assertEqual(set(shared), set(sc.new_values)) + np.testing.assert_equal(shared["a"], [2]) + np.testing.assert_equal(shared["e"], [0, 2]) + np.testing.assert_equal(shared["f o"], [0, 1]) + np.testing.assert_equal(shared["o"], [2]) + + def test_one_hot_strings(self): + attr = self.data.domain.metas[0] + sc = SplitColumn(self.data, attr, ",") + + oh = OneHotStrings(sc, "a") + np.testing.assert_equal(oh(self.data), [1, 0, 0]) + + oh = OneHotStrings(sc, "bbb") + np.testing.assert_equal(oh(self.data), [1, 0, 1]) + + data = Table.from_numpy( + Domain([], None, [attr]), + np.zeros((5, 0)), None, + np.array(["bbb,x,y", "", "bbb", "bbb,a", "foo"])[:, None]) + np.testing.assert_equal(oh(data), [1, 0, 1, 1, 0]) + + def test_one_hot_discrete(self): + attr = self.data.domain.attributes[0] + + oh = OneHotDiscrete(attr, " ", "a") + np.testing.assert_equal(oh(self.data), [0, 1, np.nan]) + + oh = OneHotDiscrete(attr, " ", "d") + np.testing.assert_equal(oh(self.data), [1, 1, np.nan]) + + data = Table.from_numpy( + Domain([attr], None), + np.array([1, 0, 1, 0, np.nan])[:, None]) + + oh = OneHotDiscrete(attr, " ", "a") + np.testing.assert_equal(oh(data), [0, 1, 0, 1, np.nan]) + + oh = OneHotDiscrete(attr, " ", "d") + np.testing.assert_equal(oh(data), [1, 1, 1, 1, np.nan]) + + def test_discrete_metas(self): + attr = DiscreteVariable("x", values=("a c d", "bb d")) + domain = Domain([], None, [attr]) + data = Table.from_numpy(domain, np.zeros((3, 0)), None, + np.array([1, 0, np.nan])[:, None]) + oh = OneHotDiscrete(attr, " ", "a") + np.testing.assert_equal(oh(data), [0, 1, np.nan]) + + +class TestOWTextToColumns(WidgetTest): + def setUp(self): + self.widget = self.create_widget(OWTextToColumns) + test_path = os.path.dirname(os.path.abspath(__file__)) + self.data = Table.from_file(os.path.join(test_path, "orange-in-education.tab")) + self._create_simple_corpus() + + def _set_attr(self, attr, widget=None): + if widget is None: + widget = self.widget + attr_combo = widget.controls.attribute + idx = attr_combo.model().indexOf(attr) + attr_combo.setCurrentIndex(idx) + attr_combo.activated.emit(idx) + + def _create_simple_corpus(self) -> None: + """ + Create a simple dataset with 4 documents. + """ + metas = np.array( + [ + ["foo,"], + ["bar,baz "], + ["foo,bar"], + [""], + ] + ) + text_var = StringVariable("foo") + domain = Domain([], metas=[text_var]) + self.small_table = Table.from_numpy( + domain, + X=np.empty((len(metas), 0)), + metas=metas, + ) + + def test_data(self): + """Basic functionality""" + self.send_signal(self.widget.Inputs.data, self.data) + self._set_attr(self.data.domain.attributes[1]) + output = self.get_output(self.widget.Outputs.data) + self.assertEqual(len(output.domain.attributes), + len(self.data.domain.attributes) + 3) + + def test_empty_data(self): + """Do not crash on empty data""" + self.send_signal(self.widget.Inputs.data, None) + + def test_discrete(self): + """No crash on data attributes of different types""" + self.send_signal(self.widget.Inputs.data, self.data) + self.assertEqual(self.widget.attribute, self.data.domain.metas[1]) + self._set_attr(self.data.domain.attributes[1]) + self.assertEqual(self.widget.attribute, self.data.domain.attributes[1]) + + def test_numeric_only(self): + """Error raised when only numeric variables given""" + housing = Table.from_file("housing") + self.send_signal(self.widget.Inputs.data, housing) + self.assertTrue(self.widget.Warning.no_disc.is_shown()) + + def test_split_nonexisting(self): + """Test splitting when delimiter doesn't exist""" + self.widget.delimiter = "|" + self.send_signal(self.widget.Inputs.data, self.data) + new_cols = set(self.data.get_column_view("Country")[0]) + self.assertFalse(any(self.widget.delimiter in v for v in new_cols)) + self.assertEqual(len(self.get_output( + self.widget.Outputs.data).domain.attributes), + len(self.data.domain.attributes) + len(new_cols)) + + def test_output_string(self): + "Test outputs; at the same time, test for duplicate variables" + self.widget.delimiter = "," + self.send_signal(self.widget.Inputs.data, self.small_table) + out = self.get_output(self.widget.Outputs.data) + self.assertEqual([attr.name for attr in out.domain.attributes], + ["bar", "baz", "foo (1)"]) + np.testing.assert_equal(out.X, + [[0, 0, 1], + [1, 1, 0], + [1, 0, 1], + [0, 0, 0]]) + + def test_output_discrete(self): + self.widget.delimiter = " " + attr = DiscreteVariable("x", values=("bar foo", "bar baz", "crux")) + data = Table.from_numpy( + Domain([attr], None), + np.array([1, 1, 0, 1, 2, np.nan])[:, None], None) + self.send_signal(self.widget.Inputs.data, data) + out = self.get_output(self.widget.Outputs.data) + self.assertEqual([attr.name for attr in out.domain.attributes], + ["x", "bar", "baz", "crux", "foo"]) + np.testing.assert_equal(out.X, + [[1, 1, 1, 0, 0], + [1, 1, 1, 0, 0], + [0, 1, 0, 0, 1], + [1, 1, 1, 0, 0], + [2, 0, 0, 1, 0], + [np.nan, np.nan, np.nan, np.nan, np.nan]]) + + +if __name__ == "__main__": + unittest.main()