Skip to content

Commit

Permalink
Adding the 2nd set of functionality and many more widgets
Browse files Browse the repository at this point in the history
  • Loading branch information
jamartinh committed Nov 26, 2015
1 parent e64cb69 commit 0efff45
Show file tree
Hide file tree
Showing 9 changed files with 61 additions and 43 deletions.
6 changes: 4 additions & 2 deletions orangecontrib/spark/base/spark_ml_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ class OWSparkEstimator(OWSparkTransformer):

get_modules = get_estimators

def fit(self):
self.out_model = self.method.fit(self.in_df)
def apply(self):
    """Fit a fresh instance of the selected estimator and emit the model.

    Instantiates ``self.method``, fits it on ``self.in_df`` using the map
    returned by ``self.build_param_map()``, stores the fitted result in
    ``self.out_model`` and sends it on the "Model" channel.
    """
    estimator = self.method()
    fit_params = self.build_param_map()
    fitted = estimator.fit(self.in_df, params=fit_params)
    self.out_model = fitted
    self.send("Model", self.out_model)
16 changes: 13 additions & 3 deletions orangecontrib/spark/base/spark_ml_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def __init__(self):

self.action_box = gui.widgetBox(self.box)
# Action Button
self.create_sc_btn = gui.button(self.action_box, self, label = 'Apply', callback = self.transform)
self.create_sc_btn = gui.button(self.action_box, self, label = 'Apply', callback = self.apply)

def refresh_method(self, text):

Expand Down Expand Up @@ -108,6 +108,16 @@ def get_input(self, obj):
self.in_df = obj
self.refresh_method(self.gui_parameters['method'].get_value())

def transform(self):
self.out_df = self.method.transform(self.in_df)
def build_param_map(self):
    """Collect the current GUI values into a parameter dictionary.

    For every key in ``self.method_parameters`` the matching entry of
    ``self.gui_parameters`` supplies both the value (``get_usable_value()``)
    and the dictionary key (``get_param_name(<method class name>, key)``).

    Returns:
        dict: generated parameter name -> value read from the GUI.
    """
    collected = {}
    for key in self.method_parameters:
        usable = self.gui_parameters[key].get_usable_value()
        qualified = self.gui_parameters[key].get_param_name(self.method.__name__, key)
        collected[qualified] = usable
    return collected

def apply(self):
    """Run a fresh instance of the selected transformer and emit the result.

    Instantiates ``self.method``, calls its ``transform`` on ``self.in_df``
    with the map returned by ``self.build_param_map()``, stores the result
    in ``self.out_df`` and sends it on the "DataFrame" channel.
    """
    transformer = self.method()
    transform_params = self.build_param_map()
    transformed = transformer.transform(self.in_df, params=transform_params)
    self.out_df = transformed
    self.send("DataFrame", self.out_df)
11 changes: 11 additions & 0 deletions orangecontrib/spark/utils/gui_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ def __init__(self, parent_widget, label = None, default_value = None, place_hold
self.gui_type = 'multiple'
callback_func = dummy_func if not callback_func else callback_func
self.widget = create_auto_combobox(parent_widget, self.list_values, callback_func)
self.widget.setStyleSheet("background-color: rgb(255, 255, 255);")
else:
self.gui_type = 'single'
self.widget = QtGui.QLineEdit(parent_widget)
Expand Down Expand Up @@ -67,6 +68,16 @@ def update(self, values):
else:
self.widget.setText(values)

def get_usable_value(self):
    """Return the current value as a float when possible, else unchanged.

    The result of ``self.get_value()`` is passed through ``float()``.
    Anything ``float()`` rejects is returned as-is: non-numeric text raises
    ValueError, and a value that is neither a string nor a number (for
    example ``None``) raises TypeError.
    """
    raw = self.get_value()
    try:
        return float(raw)
    except (TypeError, ValueError):
        # TypeError is caught too so a missing value falls back like
        # unparseable text instead of crashing the caller.
        return raw

def get_param_name(self, parent, name):
    """Build the qualified parameter name ``<parent>__<name>``.

    ``parent`` is converted with ``str()``; ``name`` is used as given and
    must already be a string.
    """
    return "__".join((str(parent), name))


def create_auto_combobox(parent_widget, values, callback_func = None):
combo = QtGui.QComboBox(parent_widget)
Expand Down
42 changes: 18 additions & 24 deletions orangecontrib/spark/widgets/pyspark_script_console.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import sys
import unicodedata

import Orange.data
from Orange.base import Learner, Model
from Orange.widgets import widget, gui
from Orange.widgets.settings import Setting
from Orange.widgets.utils import itemmodels
Expand Down Expand Up @@ -124,10 +122,11 @@ def keyPressEvent(self, event):
super().keyPressEvent(event)


class PythonConsole(QtGui.QPlainTextEdit, code.InteractiveConsole):
def __init__(self, locals = None, parent = None):
class PySparkConsole(QtGui.QPlainTextEdit, code.InteractiveConsole):
def __init__(self, locals = None, parent = None, sc = None):
QtGui.QPlainTextEdit.__init__(self, parent)
code.InteractiveConsole.__init__(self, locals)
self.sc = sc
self.history, self.historyInd = [""], 0
self.loop = self.interact()
next(self.loop)
Expand All @@ -146,10 +145,22 @@ def interact(self, banner = None):
sys.ps2 = "... "
cprt = ('Type "help", "copyright", "credits" or "license" '
'for more information.')

spark_logo = """
____ __
/ __/__ ___ _____/ /__
_\ \/ _ \/ _ `/ __/ '_/
/__ / .__/\_,_/_/ /_/\_\ version {version}
/_/
""".format(version = self.sc.version)
if banner is None:
self.write("Python %s on %s\n%s\n(%s)\n" %
(sys.version, sys.platform, cprt,
self.__class__.__name__))
self.write(spark_logo)
self.write("SparkContext available as sc, HiveContext available as sqlContext.")

else:
self.write("%s\n" % str(banner))
more = 0
Expand Down Expand Up @@ -344,7 +355,6 @@ def select_row(view, row):


from ..base.shared_spark_context import SharedSparkContext
import pyspark


class OWPySparkScript(SharedSparkContext, widget.OWWidget):
Expand All @@ -353,24 +363,8 @@ class OWPySparkScript(SharedSparkContext, widget.OWWidget):
icon = "icons/PythonScript.svg"
priority = 3150

inputs = [("in_data", Orange.data.Table, "setExampleTable",
widget.Default),
# ("in_distance", Orange.misc.SymMatrix, "setDistanceMatrix",
# widget.Default),
("in_learner", Learner, "setLearner",
widget.Default),
("in_classifier", Model, "setClassifier",
widget.Default),
("in_object", object, "setObject"),
("sc", pyspark.SparkContext, "setObject", widget.Default),
("hc", pyspark.SparkContext, "setObject", widget.Default),
]

outputs = [("out_data", Orange.data.Table,),
# ("out_distance", Orange.misc.SymMatrix, ),
("out_learner", Learner,),
("out_classifier", Model, widget.Dynamic),
("out_object", object, widget.Dynamic)]
inputs = [("in_object", object, "setObject")]
outputs = [("out_object", object, widget.Dynamic)]

libraryListSource = \
Setting([Script("Hello world", "print('Hello world')\n")])
Expand Down Expand Up @@ -494,7 +488,7 @@ def __init__(self):
self.__dict__['sc'] = self._sc
self.__dict__['hc'] = self._hc

self.console = PythonConsole(self.__dict__, self)
self.console = PySparkConsole(self.__dict__, self, sc = self.sc)
self.consoleBox.layout().addWidget(self.console)
self.console.document().setDefaultFont(QFont(defaultFont))
self.consoleBox.setAlignment(Qt.AlignBottom)
Expand Down
4 changes: 0 additions & 4 deletions orangecontrib/spark/widgets/spark_context.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@ class OWSparkContext(SharedSparkContext, widget.OWWidget):
name = "Context"
description = "Create a shared Spark (sc) and Hive (hc) Contexts"
icon = "icons/spark.png"
inputs = []
# outputs = [("SparkContext", SparkContext, widget.Default),
# ("HiveContext", HiveContext, widget.Default)]
# settingsHandler = settings.DomainContextHandler()

want_main_area = False
resizing_enabled = True
Expand Down
5 changes: 3 additions & 2 deletions orangecontrib/spark/widgets/spark_ml_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -362,8 +362,8 @@ def __init__(self):
layout.addWidget(box, 0, 2, 1, 1)

box = gui.widgetBox(self.controlArea, "label", addToLayout = False)
self.class_attrs = ClassVarListItemModel()
self.class_attrs_view = ClassVariableItemView(acceptedType = str)
self.class_attrs = VariablesListItemModel()
self.class_attrs_view = VariablesListItemView(acceptedType = str)
self.class_attrs_view.setModel(self.class_attrs)
self.class_attrs_view.selectionModel().selectionChanged.connect(partial(self.update_interface_state, self.class_attrs_view))
self.class_attrs_view.setMaximumHeight(24)
Expand Down Expand Up @@ -575,6 +575,7 @@ def commit(self):
metas = list(self.meta_attrs)
VA = VectorAssembler(inputCols = attributes, outputCol = 'features')
self.out_df = VA.transform(self.in_df)
print(class_var, type(class_var))
if len(class_var):
self.out_df = self.out_df.withColumn('label', self.out_df[class_var[0]])

Expand Down
12 changes: 7 additions & 5 deletions orangecontrib/spark/widgets/spark_ml_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,21 @@ def refresh_method(self, text):
if hasattr(self, 'values_box'):
self.values_box.hide()

def transform(self):
def apply(self):
metric_names = self.gui_parameters['metricName'].doc_text.split('(')[-1].replace(')', '').split('|')
values = { }
print(metric_names)
method_instance = self.method()
paramMap = self.build_param_map()

if self.in_df:
for metric in metric_names:
values[metric] = self.method.transform(self.in_df)
metricName = self.gui_parameters['metricName'].get_param_name()
paramMap[metricName] = metric
values[metric] = method_instance.apply(self.in_df, paramMap)
else:
for k in metric_names:
values[k] = round(5 * random.random() - 2.5, 2)

print(values.items())

# self.send("DataFrame", self.out_df)
self.table.clear()
self.table.resize(500, 500)
Expand Down
4 changes: 3 additions & 1 deletion orangecontrib/spark/widgets/spark_ml_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,5 +47,7 @@ def get_input_model(self, obj):

def transform(self):
if self.in_df and self.model:
self.out_df = self.model.transform(self.in_df)
model_instance = self.model()
#paramMap = self.build_param_map()
self.out_df = model_instance.transform(self.in_df)
self.send("DataFrame", self.out_df)
4 changes: 2 additions & 2 deletions orangecontrib/spark/widgets/spark_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@


class OWSparkSQLTableContext(SharedSparkContext, widget.OWWidget):
name = "SQL Table"
name = "Hive Table"
description = "Create a Spark DataFrame from a Hive Table"
icon = "icons/Table.svg"
outputs = [("DataFrame", pyspark.sql.DataFrame, widget.Dynamic)]
Expand All @@ -35,7 +35,7 @@ def __init__(self):
self.gui_parameters = OrderedDict()

if self.hc:
self.databases = self.hc.sql("show databases").toPandas()['result'].tolist()
self.databases = [i.result for i in self.hc.sql("show databases").collect()]

self.gui_parameters['database'] = GuiParam(parent_widget = box, list_values = self.databases, label = 'Database:', default_value = 'default',
callback_func = self.refresh_database)
Expand Down

0 comments on commit 0efff45

Please sign in to comment.