diff --git a/splicemachine/features/feature_store.py b/splicemachine/features/feature_store.py
index 90bb081..47923d2 100644
--- a/splicemachine/features/feature_store.py
+++ b/splicemachine/features/feature_store.py
@@ -25,6 +25,7 @@ class FeatureStore:
     def __init__(self, splice_ctx: PySpliceContext) -> None:
         self.splice_ctx = splice_ctx
+        self.mlflow_ctx = None
         self.feature_sets = [] # Cache of newly created feature sets
 
     def register_splice_context(self, splice_ctx: PySpliceContext) -> None:
@@ -166,7 +167,7 @@ def get_feature_vector(self, features: List[Union[str, Feature]],
         Gets a feature vector given a list of Features and primary key values for their corresponding Feature Sets
 
         :param features: List of str Feature names or Features
-        :param join_key_values: (dict) join key vals to get the proper Feature values formatted as {join_key_column_name: join_key_value}
+        :param join_key_values: (dict) join key values to get the proper Feature values formatted as {join_key_column_name: join_key_value}
         :param return_sql: Whether to return the SQL needed to get the vector or the values themselves. Default False
         :return: Pandas Dataframe or str (SQL statement)
         """
@@ -333,7 +334,7 @@ def get_training_set(self, features: Union[List[Feature], List[str]], current_va
         if current_values_only:
             ts.start_time = ts.end_time
 
-        if hasattr(self, 'mlflow_ctx'):
+        if self.mlflow_ctx and not return_sql:
             self.mlflow_ctx._active_training_set: TrainingSet = ts
             ts._register_metadata(self.mlflow_ctx)
         return sql if return_sql else self.splice_ctx.df(sql)
@@ -394,7 +395,7 @@ def get_training_set_from_view(self, training_view: str, features: Union[List[Fe
         sql = _generate_training_set_history_sql(tvw, features, feature_sets, start_time=start_time, end_time=end_time)
 
         # Link this to mlflow for model deployment
-        if hasattr(self, 'mlflow_ctx') and not return_sql:
+        if self.mlflow_ctx and not return_sql:
             ts = TrainingSet(training_view=tvw, features=features, start_time=start_time,
                              end_time=end_time)
             self.mlflow_ctx._active_training_set: TrainingSet = ts
@@ -735,12 +736,16 @@ def __log_mlflow_results(self, name, rounds, mlflow_results):
         :param name: MLflow run name
         :param rounds: Number of rounds of feature elimination that were run
         :param mlflow_results: The params / metrics to log
-        :return:
         """
-        with self.mlflow_ctx.start_run(run_name=name):
+        try:
+            if self.mlflow_ctx.active_run():
+                self.mlflow_ctx.start_run(run_name=name)
             for r in range(rounds):
                 with self.mlflow_ctx.start_run(run_name=f'Round {r}', nested=True):
                     self.mlflow_ctx.log_metrics(mlflow_results[r])
+        finally:
+            self.mlflow_ctx.end_run()
+
 
     def __prune_features_for_elimination(self, features) -> List[Feature]:
         """
@@ -814,7 +819,7 @@ def run_feature_elimination(self, df, features: List[Union[str, Feature]], label
                 round_metrics[row['name']] = row['score']
             mlflow_results.append(round_metrics)
 
-        if log_mlflow and hasattr(self, 'mlflow_ctx'):
+        if log_mlflow and self.mlflow_ctx:
             run_name = mlflow_run_name or f'feature_elimination_{label}'
             self.__log_mlflow_results(run_name, rnd, mlflow_results)
 
diff --git a/splicemachine/features/training_set.py b/splicemachine/features/training_set.py
index 80bda62..a317282 100644
--- a/splicemachine/features/training_set.py
+++ b/splicemachine/features/training_set.py
@@ -2,6 +2,7 @@
 from .feature import Feature
 from typing import List, Optional
 from datetime import datetime
+from splicemachine import SpliceMachineException
 
 class TrainingSet:
     """
@@ -19,7 +20,7 @@ def __init__(self,
                  ):
         self.training_view = training_view
         self.features = features
-        self.start_time = start_time or datetime.min
+        self.start_time = start_time or datetime(year=1900,month=1,day=1) # Saw problems with spark handling datetime.min
         self.end_time = end_time or datetime.today()
 
     def _register_metadata(self, mlflow_ctx):
@@ -32,9 +33,22 @@ def _register_metadata(self, mlflow_ctx):
         """
         if mlflow_ctx.active_run():
             print("There is an active mlflow run, your training set will be logged to that run.")
-            mlflow_ctx.lp("splice.feature_store.training_set",self.training_view.name)
-            mlflow_ctx.lp("splice.feature_store.training_set_start_time",str(self.start_time))
-            mlflow_ctx.lp("splice.feature_store.training_set_end_time",str(self.end_time))
-            mlflow_ctx.lp("splice.feature_store.training_set_num_features", len(self.features))
-            for i,f in enumerate(self.features):
-                mlflow_ctx.lp(f'splice.feature_store.training_set_feature_{i}',f.name)
+            try:
+                mlflow_ctx.lp("splice.feature_store.training_set",self.training_view.name)
+                mlflow_ctx.lp("splice.feature_store.training_set_start_time",str(self.start_time))
+                mlflow_ctx.lp("splice.feature_store.training_set_end_time",str(self.end_time))
+                mlflow_ctx.lp("splice.feature_store.training_set_num_features", len(self.features))
+                for i,f in enumerate(self.features):
+                    mlflow_ctx.lp(f'splice.feature_store.training_set_feature_{i}',f.name)
+            except:
+                raise SpliceMachineException("It looks like your active run already has a Training Set logged to it. "
+                                             "You cannot get a new active Training Set during an active run if you "
+                                             "already have an active Training Set. If you've called fs.get_training_set "
+                                             "or fs.get_training_set_from_view before starting this run, then that "
+                                             "Training Set was logged to the current active run. If you call "
+                                             "fs.get_training_set or fs.get_training_set_from_view before starting an "
+                                             "mlflow run, all following runs will assume that Training Set to be the "
+                                             "active Training Set, and will log the Training Set as metadata. For more "
+                                             "information, refer to the documentation. If you'd like to use a new "
+                                             "Training Set, end the current run, call one of the mentioned functions, "
+                                             "and start your new run.") from None
diff --git a/splicemachine/spark/context.py b/splicemachine/spark/context.py
index a715afe..c704872 100755
--- a/splicemachine/spark/context.py
+++ b/splicemachine/spark/context.py
@@ -124,8 +124,9 @@ def replaceDataframeSchema(self, dataframe, schema_table_name):
 
     def fileToTable(self, file_path, schema_table_name, primary_keys=None, drop_table=False, **pandas_args):
         """
-        Load a file from the local filesystem and create a new table (or recreate an existing table), and load the data
-        from the file into the new table
+        Load a file from the local filesystem or from a remote location and create a new table
+        (or recreate an existing table), and load the data from the file into the new table. Any file_path that can be
+        read by pandas should work here.
 
         :param file_path: The local file to load
         :param schema_table_name: The schema.table name
@@ -133,7 +134,7 @@ def fileToTable(self, file_path, schema_table_name, primary_keys=None, drop_tabl
         :param drop_table: Whether or not to drop the table. If this is False and the table already exists, the function
             will fail. Default False
         :param pandas_args: Extra parameters to be passed into the pd.read_csv function. Any parameters accepted
-        in pd.read_csv will work here
+            in pd.read_csv will work here
         :return: None
         """
         import pandas as pd
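
Reviewer notes. The sketches below are illustrative usage, not part of the patch. This first one walks the new mlflow_ctx guard; the view name, feature names, and the pre-existing SparkSession `spark` (with JDBC connection details coming from the environment) are assumptions.

    from splicemachine.spark import PySpliceContext
    from splicemachine.features import FeatureStore

    splice = PySpliceContext(spark)  # assumes SparkSession `spark` already exists
    fs = FeatureStore(splice)

    # mlflow_ctx now defaults to None, so the old hasattr() checks are gone:
    # with no mlflow context attached, the training set is returned but never
    # registered as run metadata.
    df = fs.get_training_set_from_view('customer_view', features=['TOTAL_SPEND'])

    # return_sql=True now also skips mlflow registration, since no training
    # set DataFrame is materialized.
    sql = fs.get_training_set_from_view('customer_view', features=['TOTAL_SPEND'],
                                        return_sql=True)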
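The new SpliceMachineException in _register_metadata is raised when logging the splice.feature_store.* params fails, which, per the message, happens when the active run already carries a Training Set. A hedged sketch of the sequence the message warns about, assuming `mlflow` is the splicemachine mlflow context paired with `fs`:

    with mlflow.start_run(run_name='train_v2'):
        # The first fetch inside the run logs the training-set params to it.
        fs.get_training_set(['TOTAL_SPEND'])
        # A second fetch tries to log the same param keys again; mlflow rejects
        # duplicate params, and _register_metadata re-raises that as a
        # SpliceMachineException with guidance instead of a raw mlflow error.
        fs.get_training_set(['AVG_BASKET'])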
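Finally, a sketch of the widened fileToTable contract; the paths, table name, and column names are placeholders:

    # Local path, as before:
    splice.fileToTable('/tmp/customers.csv', 'RETAIL.CUSTOMERS',
                       primary_keys=['CUSTOMER_ID'], drop_table=True)

    # Remote location: anything pandas.read_csv can open should work, e.g. an
    # HTTP(S) URL; extra keyword arguments are forwarded to pd.read_csv.
    splice.fileToTable('https://example.com/customers.csv', 'RETAIL.CUSTOMERS',
                       sep='|', header=0)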