Skip to content

Commit

Permalink
Split: Refactor for discrete values, add tests, rename
Browse files Browse the repository at this point in the history
  • Loading branch information
janezd committed Oct 23, 2022
1 parent 6b5c959 commit 1cee531
Show file tree
Hide file tree
Showing 4 changed files with 301 additions and 222 deletions.
184 changes: 0 additions & 184 deletions orangecontrib/prototypes/widgets/icons/Split.svg

This file was deleted.

33 changes: 33 additions & 0 deletions orangecontrib/prototypes/widgets/icons/TextToColumns.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import numpy as np

from AnyQt.QtCore import Qt
Expand All @@ -13,63 +15,88 @@
from orangewidget.settings import Setting


def get_substrings(values, delimiter):
return sorted({ss.strip() for s in values for ss in s.split(delimiter)}
- {""})


# TODO: Replace with Table.get_column after merging and releasing
# https://github.com/biolab/orange3/pull/6058
def get_column(data, attr):
if attr not in data.domain and attr.compute_value is not None:
return attr.compute_value(data)
return data.get_column_view(attr)[0]


class SplitColumn:
def __init__(self, data, attr, delimiter):
self.attr = attr
self.delimiter = delimiter
column = set(get_column(data, self.attr))
self.new_values = tuple(get_substrings(column, self.delimiter))

column = self.get_string_values(data, self.attr)
values = [s.split(self.delimiter) for s in column]
self.new_values = tuple(sorted({val if val else "?" for vals in
values for val in vals}))
def __call__(self, data):
column = get_column(data, self.attr)
values = [{ss.strip() for ss in s.split(self.delimiter)}
for s in column]
return {v: np.array([i for i, xs in enumerate(values) if v in xs])
for v in self.new_values}

def __eq__(self, other):
return self.attr == other.attr and self.delimiter == \
other.delimiter and self.new_values == other.new_values
return self.attr == other.attr \
and self.delimiter == other.delimiter \
and self.new_values == other.new_values

def __hash__(self):
return hash((self.attr, self.delimiter, self.new_values))

def __call__(self, data):
column = self.get_string_values(data, self.attr)
values = [set(s.split(self.delimiter)) for s in column]
shared_data = {v: [i for i, xs in enumerate(values) if v in xs] for v
in self.new_values}
return shared_data

@staticmethod
def get_string_values(data, var):
# turn discrete to string variable
column = data.get_column_view(var)[0]
if var.is_discrete:
return [var.str_val(x) for x in column]
return column


class OneHotStrings(SharedComputeValue):

def __init__(self, fn, new_feature):
super().__init__(fn)
self.new_feature = new_feature

def __eq__(self, other):
return self.compute_shared == other.compute_shared \
and self.new_feature == other.new_feature

def __hash__(self):
return hash((self.compute_shared, self.new_feature))

def compute(self, data, shared_data):
indices = shared_data[self.new_feature]
col = np.zeros(len(data))
col[indices] = 1
return col

def __eq__(self, other):
return super().__eq__(other) and self.new_feature == other.new_feature

class OWSplit(OWWidget):
name = "Split"
description = "Split string variables to create discrete."
icon = "icons/Split.svg"
def __hash__(self):
return super().__hash__() ^ hash(self.new_feature)


class OneHotDiscrete:
def __init__(self, variable, delimiter, value):
self.variable = variable
self.value = value
self.delimiter = delimiter

def __call__(self, data):
column = get_column(data, self.variable).astype(float)
col = np.zeros(len(column))
col[np.isnan(column)] = np.nan
for val_idx, value in enumerate(self.variable.values):
if self.value in value.split(self.delimiter):
col[column == val_idx] = 1
return col

def __eq__(self, other):
return self.variable == other.variable \
and self.value == other.value \
and self.delimiter == other.delimiter

def __hash__(self):
return hash((self.variable, self.value, self.delimiter))


class OWTextToColumns(OWWidget):
name = "Text to Columns"
description = "Split text or categorical variables into binary indicators"
icon = "icons/TextToColumns.svg"
priority = 700

class Inputs:
Expand Down Expand Up @@ -129,12 +156,18 @@ def apply(self):
return
var = self.data.domain[self.attribute]

sc = SplitColumn(self.data, var, self.delimiter)
if var.is_discrete:
values = get_substrings(var.values, self.delimiter)
computer = partial(OneHotDiscrete, var, self.delimiter)
else:
sc = SplitColumn(self.data, var, self.delimiter)
values = sc.new_values
computer = partial(OneHotStrings, sc)
names = get_unique_names(self.data.domain, values, equal_numbers=False)

new_columns = tuple(DiscreteVariable(
get_unique_names(self.data.domain, v), values=("0", "1"),
compute_value=OneHotStrings(sc, v)
) for v in sc.new_values)
name, values=("0", "1"), compute_value=computer(value)
) for value, name in zip(values, names))

new_domain = Domain(
self.data.domain.attributes + new_columns,
Expand All @@ -145,5 +178,5 @@ def apply(self):


if __name__ == "__main__": # pragma: no cover
WidgetPreview(OWSplit).run(Table.from_file(
WidgetPreview(OWTextToColumns).run(Table.from_file(
"tests/orange-in-education.tab"))
Loading

0 comments on commit 1cee531

Please sign in to comment.