diff --git a/splicemachine/features/feature_store.py b/splicemachine/features/feature_store.py
index 89d4bc6..28781e2 100644
--- a/splicemachine/features/feature_store.py
+++ b/splicemachine/features/feature_store.py
@@ -25,6 +25,7 @@ class FeatureStore:
 
     def __init__(self, splice_ctx: PySpliceContext) -> None:
         self.splice_ctx = splice_ctx
+        self.mlflow_ctx = None
         self.feature_sets = []  # Cache of newly created feature sets
 
     def register_splice_context(self, splice_ctx: PySpliceContext) -> None:
@@ -166,7 +167,7 @@ def get_feature_vector(self, features: List[Union[str, Feature]],
         Gets a feature vector given a list of Features and primary key values for their corresponding Feature Sets
 
         :param features: List of str Feature names or Features
-        :param join_key_values: (dict) join key vals to get the proper Feature values formatted as {join_key_column_name: join_key_value}
+        :param join_key_values: (dict) join key values to get the proper Feature values formatted as {join_key_column_name: join_key_value}
         :param return_sql: Whether to return the SQL needed to get the vector or the values themselves. Default False
         :return: Pandas Dataframe or str (SQL statement)
         """
@@ -735,12 +736,16 @@ def __log_mlflow_results(self, name, rounds, mlflow_results):
         :param name: MLflow run name
         :param rounds: Number of rounds of feature elimination that were run
         :param mlflow_results: The params / metrics to log
-        :return:
         """
-        with self.mlflow_ctx.start_run(run_name=name):
+        try:
+            if not self.mlflow_ctx.active_run():
+                self.mlflow_ctx.start_run(run_name=name)
             for r in range(rounds):
                 with self.mlflow_ctx.start_run(run_name=f'Round {r}', nested=True):
                     self.mlflow_ctx.log_metrics(mlflow_results[r])
+        finally:
+            self.mlflow_ctx.end_run()
+
 
     def __prune_features_for_elimination(self, features) -> List[Feature]:
         """
diff --git a/splicemachine/spark/context.py b/splicemachine/spark/context.py
index a715afe..c704872 100755
--- a/splicemachine/spark/context.py
+++ b/splicemachine/spark/context.py
@@ -124,8 +124,9 @@ def replaceDataframeSchema(self, dataframe, schema_table_name):
 
     def fileToTable(self, file_path, schema_table_name, primary_keys=None, drop_table=False, **pandas_args):
         """
-        Load a file from the local filesystem and create a new table (or recreate an existing table), and load the data
-        from the file into the new table
+        Load a file from the local filesystem or from a remote location and create a new table
+        (or recreate an existing table), and load the data from the file into the new table. Any file_path that can be
+        read by pandas should work here.
 
-        :param file_path: The local file to load
+        :param file_path: The file to load, either a local path or any remote path pandas can read
         :param schema_table_name: The schema.table name
@@ -133,7 +134,7 @@ def fileToTable(self, file_path, schema_table_name, primary_keys=None, drop_tabl
         :param drop_table: Whether or not to drop the table. If this is False and the table already exists, the function
             will fail. Default False
        :param pandas_args: Extra parameters to be passed into the pd.read_csv function. Any parameters accepted
-        in pd.read_csv will work here
+            in pd.read_csv will work here
         :return: None
         """
        import pandas as pd
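
Reviewer note: the `__log_mlflow_results` hunk relies on MLflow's standard parent/nested-run pattern. Below is a minimal, self-contained sketch of that pattern against vanilla `mlflow`; the function name, run names, and metric values are illustrative only, and the patched `mlflow_ctx` used in the diff is assumed to mirror these vanilla semantics.

```python
import mlflow

def log_rounds(name: str, rounds: int, results: dict) -> None:
    """Log per-round metrics as nested runs under one parent run."""
    try:
        # Start a parent run only if none is already active; the nested
        # 'Round' runs then attach to whichever run is current.
        if not mlflow.active_run():
            mlflow.start_run(run_name=name)
        for r in range(rounds):
            with mlflow.start_run(run_name=f'Round {r}', nested=True):
                mlflow.log_metrics(results[r])
    finally:
        # Always close the parent run, even if a round fails to log
        mlflow.end_run()

# Illustrative call: two rounds of feature-elimination metrics
log_rounds('feature_elimination', 2, {0: {'mse': 0.41}, 1: {'mse': 0.38}})
```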
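
Similarly, a hedged sketch of how the broadened `fileToTable` contract reads from the caller's side; the SparkSession wiring, table name, URL, and `read_csv` keyword arguments are all assumptions for illustration, not part of this change.

```python
from pyspark.sql import SparkSession
from splicemachine.spark import PySpliceContext

spark = SparkSession.builder.getOrCreate()
splice = PySpliceContext(spark)  # assumes JDBC connection details are configured elsewhere

# Local file, as before (DEMO.USERS and the ID primary key are hypothetical)
splice.fileToTable('/tmp/users.csv', 'DEMO.USERS', primary_keys=['ID'], drop_table=True)

# Remote location: any path pd.read_csv can open should work, and extra
# keyword arguments are forwarded to pd.read_csv via **pandas_args
splice.fileToTable('https://example.com/users.psv', 'DEMO.USERS',
                   drop_table=True, sep='|')
```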