This repository has been archived by the owner on Apr 15, 2022. It is now read-only.

Commit

Merge pull request #99 from splicemachine/DBAAS-4971
Dbaas 4971
Ben Epstein authored Jan 13, 2021
2 parents b063a9b + bfd2786 commit 90f874f
Showing 3 changed files with 36 additions and 16 deletions.
17 changes: 11 additions & 6 deletions splicemachine/features/feature_store.py
@@ -25,6 +25,7 @@
class FeatureStore:
def __init__(self, splice_ctx: PySpliceContext) -> None:
self.splice_ctx = splice_ctx
self.mlflow_ctx = None
self.feature_sets = [] # Cache of newly created feature sets
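A minimal sketch of why the None default matters (class and method names below are hypothetical; only the attribute-plus-truthiness pattern comes from this change):

class Example:
    def __init__(self):
        self.mlflow_ctx = None          # attribute always exists; None means "no mlflow context attached yet"

    def register_mlflow_ctx(self, ctx):
        self.mlflow_ctx = ctx

    def maybe_log(self):
        if self.mlflow_ctx:             # truthiness check replaces hasattr(self, 'mlflow_ctx')
            self.mlflow_ctx.log_metrics({'demo': 1.0})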

def register_splice_context(self, splice_ctx: PySpliceContext) -> None:
@@ -166,7 +167,7 @@ def get_feature_vector(self, features: List[Union[str, Feature]],
Gets a feature vector given a list of Features and primary key values for their corresponding Feature Sets
:param features: List of str Feature names or Features
:param join_key_values: (dict) join key vals to get the proper Feature values formatted as {join_key_column_name: join_key_value}
:param join_key_values: (dict) join key values to get the proper Feature values formatted as {join_key_column_name: join_key_value}
:param return_sql: Whether to return the SQL needed to get the vector or the values themselves. Default False
:return: Pandas Dataframe or str (SQL statement)
"""
@@ -333,7 +334,7 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
if current_values_only:
ts.start_time = ts.end_time

if hasattr(self, 'mlflow_ctx'):
if self.mlflow_ctx and not return_sql:
self.mlflow_ctx._active_training_set: TrainingSet = ts
ts._register_metadata(self.mlflow_ctx)
return sql if return_sql else self.splice_ctx.df(sql)
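A usage-level illustration of the new guard (hedged: fs is assumed to be a FeatureStore with splicemachine's mlflow support attached, and TOTAL_SPEND is an invented feature name):

sql_only = fs.get_training_set(['TOTAL_SPEND'], return_sql=True)   # builds the SQL string only; no TrainingSet is registered
train_df = fs.get_training_set(['TOTAL_SPEND'])                     # Spark DataFrame; with mlflow_ctx set, this TrainingSet
                                                                     # becomes the active one and its metadata is logged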
@@ -394,7 +395,7 @@ def get_training_set_from_view(self, training_view: str, features: Union[List[Fe
sql = _generate_training_set_history_sql(tvw, features, feature_sets, start_time=start_time, end_time=end_time)

# Link this to mlflow for model deployment
if hasattr(self, 'mlflow_ctx') and not return_sql:
if self.mlflow_ctx and not return_sql:
ts = TrainingSet(training_view=tvw, features=features,
start_time=start_time, end_time=end_time)
self.mlflow_ctx._active_training_set: TrainingSet = ts
@@ -735,12 +736,16 @@ def __log_mlflow_results(self, name, rounds, mlflow_results):
:param name: MLflow run name
:param rounds: Number of rounds of feature elimination that were run
:param mlflow_results: The params / metrics to log
:return:
"""
with self.mlflow_ctx.start_run(run_name=name):
try:
if self.mlflow_ctx.active_run():
self.mlflow_ctx.start_run(run_name=name)
for r in range(rounds):
with self.mlflow_ctx.start_run(run_name=f'Round {r}', nested=True):
self.mlflow_ctx.log_metrics(mlflow_results[r])
finally:
self.mlflow_ctx.end_run()
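The same parent/child run shape, sketched against the public mlflow API rather than the splicemachine mlflow_ctx wrapper used above (run name and metrics are invented):

import mlflow

round_results = [{'f1': 0.91}, {'f1': 0.93}]                  # stand-in for mlflow_results
try:
    if not mlflow.active_run():                               # reuse an already-open run, otherwise start one
        mlflow.start_run(run_name='feature_elimination_demo')
    for r, metrics in enumerate(round_results):
        with mlflow.start_run(run_name=f'Round {r}', nested=True):
            mlflow.log_metrics(metrics)                       # one nested child run per elimination round
finally:
    mlflow.end_run()                                          # mirror the change above: always close the parent run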


def __prune_features_for_elimination(self, features) -> List[Feature]:
"""
@@ -814,7 +819,7 @@ def run_feature_elimination(self, df, features: List[Union[str, Feature]], label
round_metrics[row['name']] = row['score']
mlflow_results.append(round_metrics)

if log_mlflow and hasattr(self, 'mlflow_ctx'):
if log_mlflow and self.mlflow_ctx:
run_name = mlflow_run_name or f'feature_elimination_{label}'
self.__log_mlflow_results(run_name, rnd, mlflow_results)

28 changes: 21 additions & 7 deletions splicemachine/features/training_set.py
@@ -2,6 +2,7 @@
from .feature import Feature
from typing import List, Optional
from datetime import datetime
from splicemachine import SpliceMachineException

class TrainingSet:
"""
@@ -19,7 +20,7 @@ def __init__(self,
):
self.training_view = training_view
self.features = features
self.start_time = start_time or datetime.min
self.start_time = start_time or datetime(year=1900,month=1,day=1) # Saw problems with spark handling datetime.min
self.end_time = end_time or datetime.today()
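A minimal check of the defaulted window (plain datetime only; the Spark interaction mentioned in the comment above is not reproduced here):

from datetime import datetime

start_time, end_time = None, None
window = (start_time or datetime(year=1900, month=1, day=1),  # floor pinned to 1900-01-01 instead of datetime.min (year 1)
          end_time or datetime.today())
print(window)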

def _register_metadata(self, mlflow_ctx):
@@ -32,9 +33,22 @@ def _register_metadata(self, mlflow_ctx):
"""
if mlflow_ctx.active_run():
print("There is an active mlflow run, your training set will be logged to that run.")
mlflow_ctx.lp("splice.feature_store.training_set",self.training_view.name)
mlflow_ctx.lp("splice.feature_store.training_set_start_time",str(self.start_time))
mlflow_ctx.lp("splice.feature_store.training_set_end_time",str(self.end_time))
mlflow_ctx.lp("splice.feature_store.training_set_num_features", len(self.features))
for i,f in enumerate(self.features):
mlflow_ctx.lp(f'splice.feature_store.training_set_feature_{i}',f.name)
try:
mlflow_ctx.lp("splice.feature_store.training_set",self.training_view.name)
mlflow_ctx.lp("splice.feature_store.training_set_start_time",str(self.start_time))
mlflow_ctx.lp("splice.feature_store.training_set_end_time",str(self.end_time))
mlflow_ctx.lp("splice.feature_store.training_set_num_features", len(self.features))
for i,f in enumerate(self.features):
mlflow_ctx.lp(f'splice.feature_store.training_set_feature_{i}',f.name)
except:
raise SpliceMachineException("It looks like your active run already has a Training Set logged to it. "
"You cannot get a new active Training Set during an active run if you "
"already have an active Training Set. If you've called fs.get_training_set "
"or fs.get_training_set_from_view before starting this run, then that "
"Training Set was logged to the current active run. If you call "
"fs.get_training_set or fs.get_training_set_from_view before starting an "
"mlflow run, all following runs will assume that Training Set to be the "
"active Training Set, and will log the Training Set as metadata. For more "
"information, refer to the documentation. If you'd like to use a new "
"Training Set, end the current run, call one of the mentioned functions, "
"and start your new run.") from None
7 changes: 4 additions & 3 deletions splicemachine/spark/context.py
@@ -124,16 +124,17 @@ def replaceDataframeSchema(self, dataframe, schema_table_name):

def fileToTable(self, file_path, schema_table_name, primary_keys=None, drop_table=False, **pandas_args):
"""
Load a file from the local filesystem and create a new table (or recreate an existing table), and load the data
from the file into the new table
Load a file from the local filesystem or from a remote location and create a new table
(or recreate an existing table), and load the data from the file into the new table. Any file_path that can be
read by pandas should work here.
:param file_path: The local file to load
:param schema_table_name: The schema.table name
:param primary_keys: List[str] of primary keys for the table. Default None
:param drop_table: Whether or not to drop the table. If this is False and the table already exists, the
function will fail. Default False
:param pandas_args: Extra parameters to be passed into the pd.read_csv function. Any parameters accepted
in pd.read_csv will work here
in pd.read_csv will work here
:return: None
"""
import pandas as pd
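A hedged usage sketch for the broadened fileToTable (splice is assumed to be a connected PySpliceContext; the path, table name, and key column are invented):

splice.fileToTable(
    '/tmp/customers.csv',            # anything pandas can read, local or remote
    'RETAIL.CUSTOMERS',              # schema.table to create (or recreate)
    primary_keys=['CUSTOMER_ID'],    # optional primary keys for the new table
    drop_table=True,                 # without this, the call fails if the table already exists
    sep=','                          # extra kwargs are handed straight to pd.read_csv
)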
