Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Split: Refactor for discrete values, add tests, rename to "Text to Columns" #253

Merged
merged 1 commit into from
Nov 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 0 additions & 184 deletions orangecontrib/prototypes/widgets/icons/Split.svg

This file was deleted.

33 changes: 33 additions & 0 deletions orangecontrib/prototypes/widgets/icons/TextToColumns.svg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from functools import partial

import numpy as np

from AnyQt.QtCore import Qt
Expand All @@ -13,63 +15,81 @@
from orangewidget.settings import Setting


def get_substrings(values, delimiter):
    """Return the sorted, distinct, non-empty substrings found in `values`.

    Every string in `values` is split on `delimiter`; each piece is stripped
    of surrounding whitespace and pieces that end up empty are discarded.

    An empty `delimiter` means "do not split": `str.split` raises
    `ValueError` for an empty separator, so each string is then taken as a
    single piece instead.

    Args:
        values: iterable of strings.
        delimiter: separator string; may be empty.

    Returns:
        A sorted list of unique, stripped, non-empty substrings.
    """
    substrings = set()
    for s in values:
        pieces = s.split(delimiter) if delimiter else (s,)
        substrings.update(piece.strip() for piece in pieces)
    substrings.discard("")
    return sorted(substrings)


class SplitColumn:
    """Callable that maps each substring of a column to the rows containing it.

    On construction, the distinct substrings of column `attr` of `data` are
    collected with `get_substrings` and stored as the tuple `new_values`.
    Calling the instance on a data object returns, for every stored
    substring, the indices of the rows whose value contains it.
    """

    def __init__(self, data, attr, delimiter):
        """Collect the substrings that occur in column `attr` of `data`.

        Args:
            data: object providing `get_column(attr)`; the column is
                presumably a sequence of strings -- verify against caller.
            attr: the column to split.
            delimiter: separator string.
        """
        self.attr = attr
        self.delimiter = delimiter
        # Deduplicate first so each distinct string is split only once.
        column = set(data.get_column(self.attr))
        self.new_values = tuple(get_substrings(column, self.delimiter))

    def __call__(self, data):
        """Return a dict mapping each value in `new_values` to row indices.

        The dict follows the order of `new_values`. Every entry is an
        integer NumPy array, possibly empty, so it is always usable as an
        index array (an empty float array would be rejected by NumPy).
        """
        column = data.get_column(self.attr)
        delimiter = self.delimiter
        rows = {v: [] for v in self.new_values}
        for i, s in enumerate(column):
            # str.split rejects an empty separator; treat it as "no split".
            pieces = s.split(delimiter) if delimiter else (s,)
            # A set, so that a row is listed once even if a piece repeats.
            for piece in {p.strip() for p in pieces}:
                if piece in rows:
                    rows[piece].append(i)
        return {v: np.array(indices, dtype=int)
                for v, indices in rows.items()}

    def __eq__(self, other):
        if not isinstance(other, SplitColumn):
            return NotImplemented
        return self.attr == other.attr \
            and self.delimiter == other.delimiter \
            and self.new_values == other.new_values

    def __hash__(self):
        return hash((self.attr, self.delimiter, self.new_values))


class OneHotStrings(SharedComputeValue):
    """Compute one indicator column from shared per-value row indices.

    `fn` is passed on to `SharedComputeValue`; its result reaches `compute`
    as `shared_data`, which is indexed by `new_feature` to obtain the rows
    that get the value 1.
    """

    def __init__(self, fn, new_feature):
        super().__init__(fn)
        # Key under which this column's row indices are found in shared_data.
        self.new_feature = new_feature

    def compute(self, data, shared_data):
        """Return a float column of len(data): 1 at the indexed rows, else 0."""
        indices = shared_data[self.new_feature]
        col = np.zeros(len(data))
        col[indices] = 1
        return col

    def __eq__(self, other):
        # The base class compares its own state; this adds the feature key.
        return super().__eq__(other) and self.new_feature == other.new_feature

    def __hash__(self):
        return super().__hash__() ^ hash(self.new_feature)


class OneHotDiscrete:
    """Indicator column for one substring of a discrete variable's values.

    Calling the instance on a data object yields a float column that is 1
    where the row's value, split on `delimiter`, contains `value`, NaN where
    the row is missing (NaN), and 0 elsewhere.
    """

    def __init__(self, variable, delimiter, value):
        self.variable = variable
        self.value = value
        self.delimiter = delimiter

    def __call__(self, data):
        """Return the indicator column for `data`.

        The column of `variable` is read as floats that index into
        `variable.values`.
        """
        column = data.get_column(self.variable).astype(float)
        col = np.zeros(len(column))
        col[np.isnan(column)] = np.nan
        delimiter = self.delimiter
        for val_idx, value in enumerate(self.variable.values):
            # str.split rejects an empty separator; treat it as "no split".
            pieces = value.split(delimiter) if delimiter else [value]
            # Substrings are collected with surrounding whitespace stripped
            # (see get_substrings), so compare against stripped pieces too;
            # the raw comparison is kept for backward compatibility.
            if self.value in pieces \
                    or self.value in (piece.strip() for piece in pieces):
                col[column == val_idx] = 1
        return col

    def __eq__(self, other):
        if not isinstance(other, OneHotDiscrete):
            return NotImplemented
        return self.variable == other.variable \
            and self.value == other.value \
            and self.delimiter == other.delimiter

    def __hash__(self):
        return hash((self.variable, self.value, self.delimiter))


class OWTextToColumns(OWWidget):
name = "Text to Columns"
description = "Split text or categorical variables into binary indicators"
icon = "icons/TextToColumns.svg"
keywords = ["split"]
priority = 700

class Inputs:
Expand Down Expand Up @@ -129,12 +149,18 @@ def apply(self):
return
var = self.data.domain[self.attribute]

sc = SplitColumn(self.data, var, self.delimiter)
if var.is_discrete:
values = get_substrings(var.values, self.delimiter)
computer = partial(OneHotDiscrete, var, self.delimiter)
else:
sc = SplitColumn(self.data, var, self.delimiter)
values = sc.new_values
computer = partial(OneHotStrings, sc)
names = get_unique_names(self.data.domain, values, equal_numbers=False)

new_columns = tuple(DiscreteVariable(
get_unique_names(self.data.domain, v), values=("0", "1"),
compute_value=OneHotStrings(sc, v)
) for v in sc.new_values)
name, values=("0", "1"), compute_value=computer(value)
) for value, name in zip(values, names))

new_domain = Domain(
self.data.domain.attributes + new_columns,
Expand All @@ -145,5 +171,5 @@ def apply(self):


if __name__ == "__main__":  # pragma: no cover
    # Manual preview: open the widget on a sample table.
    WidgetPreview(OWTextToColumns).run(Table.from_file(
        "tests/orange-in-education.tab"))
Loading
Loading