Skip to content

Commit

Permalink
[ODSC-47260] Global Explainer Class : Feature/forecast explain (#367)
Browse files Browse the repository at this point in the history
  • Loading branch information
codeloop authored Oct 11, 2023
2 parents a03cd9c + c802c18 commit e006d73
Show file tree
Hide file tree
Showing 6 changed files with 141 additions and 4 deletions.
1 change: 1 addition & 0 deletions ads/opctl/operator/lowcode/forecast/environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ dependencies:
- datapane
- cerberus
- sktime
- shap
- autots[additional]
- optuna==2.9.0
- oracle-automlx==23.2.3
Expand Down
3 changes: 3 additions & 0 deletions ads/opctl/operator/lowcode/forecast/model/arima.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,3 +168,6 @@ def _generate_report(self):
ds_forecast_col,
ci_col_names,
)

def explain_model(self) -> dict:
    """
    Placeholder for model explanation.

    No explanation is computed here: the method performs no work and
    returns ``None``.
    """
    return None
127 changes: 123 additions & 4 deletions ads/opctl/operator/lowcode/forecast/model/automlx.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*--
import traceback

# Copyright (c) 2023 Oracle and/or its affiliates.
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
Expand All @@ -13,12 +14,18 @@

from .. import utils
from .base_model import ForecastOperatorBaseModel
from ..operator_config import ForecastOperatorConfig



# TODO: ODSC-44785 Fix the error message, before GA.
class AutoMLXOperatorModel(ForecastOperatorBaseModel):
"""Class representing AutoMLX operator model."""

def __init__(self, config: ForecastOperatorConfig):
    """
    Initialize the operator model.

    Parameters
    ----------
    config (ForecastOperatorConfig): The operator configuration, forwarded
        unchanged to the base class initializer.
    """
    super().__init__(config)
    # Starts empty; explain_model stores one entry per series id in here.
    self.global_explanation = dict()

@runtime_dependency(
module="automl",
err_msg=(
Expand Down Expand Up @@ -134,8 +141,30 @@ def _build_model(self) -> pd.DataFrame:
self.outputs = outputs_legacy
self.data = data_merged
return outputs_merged

@runtime_dependency(
module="datapane",
err_msg=(
"Please run `pip3 install datapane` to install the required dependencies for report generation."
),
)
def _generate_report(self):
"""
Generate the report for the automlx model.
Parameters
----------
None
Returns
-------
- model_description (datapane.Text): A Text component containing the description of the automlx model.
- other_sections (List[Union[datapane.Text, datapane.Blocks]]): A list of Text and Blocks components representing various sections of the report.
- forecast_col_name (str): The name of the forecasted column.
- train_metrics (bool): A boolean value indicating whether to include train metrics in the report.
- ds_column_series (pd.Series): The pd.Series object representing the datetime column of the dataset.
- ds_forecast_col (pd.Series): The pd.Series object representing the forecasted column.
- ci_col_names (List[str]): A list of column names for the confidence interval in the report.
"""
import datapane as dp

"""The method that needs to be implemented on the particular model level."""
Expand All @@ -159,12 +188,35 @@ def _generate_report(self):
selected_models_section = dp.Blocks(
"### Best Selected model ", dp.Table(selected_df)
)

all_sections = [selected_models_text, selected_models_section]

# Check if the "explain_model" key is present in the "model_kwargs" dictionary of the "self.spec" object
if self.spec.model_kwargs.get("explain_model"):
# If the key is present, call the "explain_model" method
self.explain_model()

# Create a markdown text block for the global explanation section
global_explanation_text = dp.Text(
f"## Global Explanation of Models \n "
"The following tables provide the feature attribution for the global explainability."
)

# Convert the global explanation data to a DataFrame
global_explanation_df = pd.DataFrame(self.global_explanation)

# Create a markdown section for the global explainability
global_explanation_section = dp.Blocks(
"### Global Explainability ",
dp.Table(global_explanation_df/global_explanation_df.sum(axis=0) * 100)
)

# Append the global explanation text and section to the "all_sections" list
all_sections = all_sections + [global_explanation_text, global_explanation_section]

model_description = dp.Text(
"The automlx model automatically preprocesses, selects and engineers "
"high-quality features in your dataset, which then given to an automatically "
"chosen and optimized machine learning model.."
"The AutoMLx model automatically preprocesses, selects and engineers "
"high-quality features in your dataset, which are then provided for further processing."
)
other_sections = all_sections
forecast_col_name = "yhat"
Expand All @@ -182,3 +234,70 @@ def _generate_report(self):
ds_forecast_col,
ci_col_names,
)

def _custom_predict_automlx(self, data):
"""
Predicts the future values of a time series using the AutoMLX model.
Parameters
----------
data (numpy.ndarray): The input data to be used for prediction.
Returns
-------
numpy.ndarray: The predicted future values of the time series.
"""
temp = 0
data_temp = pd.DataFrame(
data,
columns=[col for col in self.dataset_cols],
)

return self.models.get(self.series_id).forecast(
X=data_temp,
periods=data_temp.shape[0]
)[self.series_id]

@runtime_dependency(
    module="shap",
    err_msg=(
        "Please run `pip3 install shap` to install the required dependencies for model explanation."
    ),
)
def explain_model(self) -> dict:
    """
    Generates a global explanation for each target series using the SHAP
    (SHapley Additive exPlanations) ``KernelExplainer``.

    For every series id in ``self.target_columns`` the SHAP values of the
    non-target columns are computed (``nsamples=50``) and the mean absolute
    SHAP value per column is stored in ``self.global_explanation`` under
    that series id.

    Returns:
        dict: ``self.global_explanation``; maps each series id to a dict of
        ``{feature name: average absolute SHAP value}``.
    """
    from shap import KernelExplainer

    datetime_col = self.spec.datetime_column.name
    horizon = self.spec.horizon.periods

    for series_id in self.target_columns:
        # _custom_predict_automlx reads self.series_id and self.dataset_cols,
        # so both must be set before the explainer invokes the callback.
        self.series_id = series_id
        indexed = self.full_data_dict.get(series_id).set_index(datetime_col)
        self.dataset_cols = indexed.drop(series_id, axis=1).columns

        # Drop the last `horizon` rows (presumably the rows reserved for the
        # forecast horizon) and keep only the feature columns. The same frame
        # serves as background data and as the samples to explain.
        history = indexed[:-horizon][list(self.dataset_cols)]

        kernel_explnr = KernelExplainer(
            model=self._custom_predict_automlx,
            data=history,
        )
        kernel_explnr_vals = kernel_explnr.shap_values(history, nsamples=50)

        self.global_explanation[series_id] = dict(
            zip(
                self.dataset_cols,
                np.average(np.absolute(kernel_explnr_vals), axis=0),
            )
        )

    return self.global_explanation
8 changes: 8 additions & 0 deletions ads/opctl/operator/lowcode/forecast/model/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ def generate_report(self):
ci_col_names,
) = self._generate_report()


report_sections = []
title_text = dp.Text("# Forecast Report")

Expand Down Expand Up @@ -440,3 +441,10 @@ def _build_model(self) -> pd.DataFrame:
Build the model.
The method that needs to be implemented on the particular model level.
"""

@abstractmethod
def explain_model(self) -> dict:
    """
    Explain the model using global & local explanations.

    Abstract hook: concrete model classes are expected to override it.
    This base implementation always raises.

    Raises
    ------
    NotImplementedError: Always, when the base implementation is invoked.
    """
    raise NotImplementedError
3 changes: 3 additions & 0 deletions ads/opctl/operator/lowcode/forecast/model/neuralprophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,6 @@ def _generate_report(self):
ds_forecast_col,
ci_col_names,
)

def explain_model(self) -> dict:
    """
    Placeholder for model explanation.

    No explanation is computed here: the method performs no work and
    returns ``None``.
    """
    return None
3 changes: 3 additions & 0 deletions ads/opctl/operator/lowcode/forecast/model/prophet.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,3 +321,6 @@ def _generate_report(self):
ds_forecast_col,
ci_col_names,
)

def explain_model(self) -> dict:
    """
    Placeholder for model explanation.

    No explanation is computed here: the method performs no work and
    returns ``None``.
    """
    return None

0 comments on commit e006d73

Please sign in to comment.