Save serialized keras surrogate model (#25)
Co-authored-by: mikivee <mikivee>
mikivee authored Nov 6, 2024
1 parent 73f2ee8 commit 322a26a
Showing 7 changed files with 50 additions and 22 deletions.
5 changes: 2 additions & 3 deletions install-db-requirements.sh
@@ -15,10 +15,9 @@ cat >"$TEMP_REQUIREMENTS" <<'EOF'
 appnope==0.1.4 ; python_version >= "3.10" and python_version < "3.11" and (platform_system == "Darwin" or sys_platform == "darwin")
 colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.11" and (platform_system == "Windows" or sys_platform == "win32")
 cython==0.29.32 ; python_version >= "3.10" and python_version < "3.11"
-databricks-connect==13.1.0 ; python_version >= "3.10" and python_version < "3.11"
 distro==1.9.0 ; python_version >= "3.10" and python_version < "3.11"
-dmlbootstrap==0.1.3 ; python_version >= "3.10" and python_version < "3.11"
-dmlutils==0.7.0 ; python_version >= "3.10" and python_version < "3.11"
+dmlbootstrap==0.2.0 ; python_version >= "3.10" and python_version < "3.11"
+dmlutils==0.8.0 ; python_version >= "3.10" and python_version < "3.11"
 duckdb==1.1.1 ; python_version >= "3.10" and python_version < "3.11"
 et-xmlfile==1.1.0 ; python_version >= "3.10" and python_version < "3.11"
 flask==2.2.5 ; python_version >= "3.10" and python_version < "3.11"
5 changes: 2 additions & 3 deletions requirements-db-14.3.txt
@@ -3,10 +3,9 @@
 appnope==0.1.4 ; python_version >= "3.10" and python_version < "3.11" and (platform_system == "Darwin" or sys_platform == "darwin")
 colorama==0.4.6 ; python_version >= "3.10" and python_version < "3.11" and (platform_system == "Windows" or sys_platform == "win32")
 cython==0.29.32 ; python_version >= "3.10" and python_version < "3.11"
-databricks-connect==13.1.0 ; python_version >= "3.10" and python_version < "3.11"
 distro==1.9.0 ; python_version >= "3.10" and python_version < "3.11"
-dmlbootstrap==0.1.3 ; python_version >= "3.10" and python_version < "3.11"
-dmlutils==0.7.0 ; python_version >= "3.10" and python_version < "3.11"
+dmlbootstrap==0.2.0 ; python_version >= "3.10" and python_version < "3.11"
+dmlutils==0.8.0 ; python_version >= "3.10" and python_version < "3.11"
 duckdb==1.1.1 ; python_version >= "3.10" and python_version < "3.11"
 et-xmlfile==1.1.0 ; python_version >= "3.10" and python_version < "3.11"
 flask==2.2.5 ; python_version >= "3.10" and python_version < "3.11"
5 changes: 2 additions & 3 deletions requirements-test-14.3.txt
@@ -52,7 +52,6 @@ cython==0.29.32 ; python_version >= "3.10" and python_version < "3.11"
 dacite==1.8.1 ; python_version >= "3.10" and python_version < "3.11"
 databricks-automl-runtime==0.2.20 ; python_version >= "3.10" and python_version < "3.11"
 databricks-cli==0.18.0 ; python_version >= "3.10" and python_version < "3.11"
-databricks-connect==13.1.0 ; python_version >= "3.10" and python_version < "3.11"
 databricks-feature-engineering==0.2.0 ; python_version >= "3.10" and python_version < "3.11"
 databricks-sdk==0.1.6 ; python_version >= "3.10" and python_version < "3.11"
 dataclasses-json==0.6.3 ; python_version >= "3.10" and python_version < "3.11"
@@ -66,8 +65,8 @@ dill==0.3.6 ; python_version >= "3.10" and python_version < "3.11"
 diskcache==5.6.3 ; python_version >= "3.10" and python_version < "3.11"
 distlib==0.3.7 ; python_version >= "3.10" and python_version < "3.11"
 distro==1.9.0 ; python_version >= "3.10" and python_version < "3.11"
-dmlbootstrap==0.1.3 ; python_version >= "3.10" and python_version < "3.11"
-dmlutils==0.7.0 ; python_version >= "3.10" and python_version < "3.11"
+dmlbootstrap==0.2.0 ; python_version >= "3.10" and python_version < "3.11"
+dmlutils==0.8.0 ; python_version >= "3.10" and python_version < "3.11"
 docstring-to-markdown==0.11 ; python_version >= "3.10" and python_version < "3.11"
 duckdb==1.1.1 ; python_version >= "3.10" and python_version < "3.11"
 entrypoints==0.4 ; python_version >= "3.10" and python_version < "3.11"
2 changes: 1 addition & 1 deletion scripts/model_evaluation.py
@@ -3,7 +3,7 @@
 
 # COMMAND ----------
 
-# MAGIC %pip install seaborn==v0.13.0
+# MAGIC %pip install mlflow==2.13.0 seaborn==v0.13.0
 # MAGIC dbutils.library.restartPython()
 
 # COMMAND ----------
19 changes: 13 additions & 6 deletions scripts/model_training.py
@@ -32,16 +32,21 @@
# MAGIC
# MAGIC ---
# MAGIC #### Cluster/User Requirements
# MAGIC - Access Mode: Single User or Shared (not No Isolation Shared)
# MAGIC - Runtime: >= Databricks Runtime 14.3 ML (or >= Databricks Runtime 14.3 + `%pip install databricks-feature-engineering`)
# MAGIC - Node type: Single Node. Because of [this issue](https://kb.databricks.com/en_US/libraries/apache-spark-jobs-fail-with-environment-directory-not-found-error), worker nodes cannot access the directory needed to run inference on a keras-trained model, so the `score_batch()` function throws an OSError.
# MAGIC - Can be run on CPU or GPU, with a 2x speedup on GPU
# MAGIC - Cluster-level packages: `gcsfs==2023.5.0`, `mlflow==2.13.0` (newer than the default, which is required to pass `code_paths` when logging)
# MAGIC - `USE CATALOG`, `CREATE SCHEMA` privileges on the `ml` Unity Catalog (ask Miki for access)
# MAGIC

# COMMAND ----------

# we need a newer version of MLflow in order to use a custom loss
%pip install mlflow==2.13.0

# COMMAND ----------

dbutils.library.restartPython()

# COMMAND ----------

# DBTITLE 1,Set debug mode
# this controls the training parameters; test mode uses a much smaller training set and fewer epochs
dbutils.widgets.dropdown("mode", "test", ["test", "production"])
@@ -208,8 +213,7 @@ def convert_feature_dataframe_to_dict(
     for weather features contain len 8760 arrays.
     Returns:
-    - The preprocessed feature data in format {feature_name (str) :
-        np.array of shape [N] for building model features and shape [N,8760] for weather features}
+    - The preprocessed feature data in format {feature_name (str) : np.array of shape [N]}
     """
     return {
         col: np.array(feature_df[col])
@@ -283,6 +287,9 @@ def convert_feature_dataframe_to_dict(
 # skip registering model for now..
 # mlflow.register_model(f"runs:/{run_id}/{sm.artifact_path}", str(sm))
 
+# serialize the keras model and save to GCP
+sm.save_keras_model(run_id=run_id)
+
 # COMMAND ----------
 
 # MAGIC %md ## Evaluate Model
7 changes: 2 additions & 5 deletions src/datagen.py
@@ -352,9 +352,7 @@ def convert_dataframe_to_dict(self, feature_df: pd.DataFrame) -> Dict[str, np.nd
         Returns
         -------
-        Dict[str,np.ndarray]: The preprocessed feature data in format {feature_name (str):
-            np.array of shape [len(feature_df)] for building model features
-            and shape [len(feature_df), 8760] for weather features}
+        Dict[str,np.ndarray]: preprocessed feature data in format {feature_name (str): np.array of shape [N]}
         """
         return {col: np.array(feature_df[col]) for col in self.building_features + ["weather_file_city_index"]}
 
@@ -378,8 +376,7 @@ def __getitem__(self, index: int) -> Tuple[Dict[str, np.ndarray], Dict[str, np.n
         Returns
         -------
-        - X (dict): features for batch in format {feature_name (str):
-            np.array of shape [batch_size] for building model features and shape [batch_size, 8760] for weather features}
+        - X (dict): features for batch in format {feature_name (str): np.array of shape [batch_size]}
         - y (dict) : targets for the batch in format {target_name (str): np.array of shape [batch_size]}
         """
         # subset rows of targets and building features to batch
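
As a quick illustration of the condensed return format described in the updated docstrings (not part of this commit), here is a minimal sketch of the same conversion pattern; the `sqft` column is a made-up example, while `weather_file_city_index` comes from the diff above:

import numpy as np
import pandas as pd

# hypothetical two-row feature frame
feature_df = pd.DataFrame({"sqft": [1200, 2400], "weather_file_city_index": [3, 7]})

# same pattern as convert_dataframe_to_dict: each column becomes a 1-D array of shape [N]
feature_dict = {col: np.array(feature_df[col]) for col in feature_df.columns}
# -> {'sqft': array([1200, 2400]), 'weather_file_city_index': array([3, 7])}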
29 changes: 28 additions & 1 deletion src/surrogate_model.py
@@ -1,5 +1,6 @@
 import numpy as np
-from typing import Any, Dict, List, Tuple
+import os
+from typing import Any, Dict, List, Tuple, Optional
 
 import mlflow
 import pyspark.sql.functions as F
@@ -9,6 +10,7 @@
 from pyspark.sql.types import ArrayType, DoubleType
 from tensorflow import keras
 from tensorflow.keras import layers, models
+from tensorflow.python.lib.io import file_io
 
 from src.datagen import DataGenerator
 
@@ -279,6 +281,31 @@ def get_model_uri(self, run_id: str = None, version: int = None, verbose: bool =
         else:
             return f"runs:/{run_id}/{self.artifact_path}"
 
+    def save_keras_model(self, run_id):
+        """
+        Saves the keras model for the given run ID to Google Cloud Storage.
+        Parameters:
+        - run_id (str): The unique identifier for the MLflow run associated with the model to be saved.
+        """
+        fname = f"sumo_{self.name}_{run_id}.keras"
+        gcp_model_dir = "gs://the-cube/export/surrogate_model/"
+
+        # load mlflow model
+        mlflow_model = mlflow.pyfunc.load_model(model_uri=self.get_model_uri(run_id=run_id))
+        # extract keras model
+        keras_model = mlflow_model.unwrap_python_model().model
+
+        # save locally
+        keras_model.save(fname)
+        # then copy to gcp
+        with file_io.FileIO(fname, mode="rb") as f_local:
+            with file_io.FileIO(os.path.join(gcp_model_dir, fname), mode="wb+") as f_gcp:
+                f_gcp.write(f_local.read())
+        # delete local file
+        os.remove(fname)
+
     def score_batch(
         self,
         test_data: DataFrame,
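
A minimal usage sketch (not part of this commit) for round-tripping the serialized model: it assumes a trained SurrogateModel instance `sm`, an MLflow `run_id`, and a cluster with `gcsfs` installed, and it reuses the bucket path and filename pattern from `save_keras_model` above:

import gcsfs
from tensorflow import keras

# serialize the trained keras model and copy it to GCS (the new method in this commit)
sm.save_keras_model(run_id=run_id)

# later, or on another machine: copy the .keras file back down and load it;
# compile=False skips deserializing the custom loss used during training
fs = gcsfs.GCSFileSystem()
fname = f"sumo_{sm.name}_{run_id}.keras"
fs.get(f"gs://the-cube/export/surrogate_model/{fname}", fname)
reloaded = keras.models.load_model(fname, compile=False)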
