Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Minor] Add function to retrieve regressor coefficients #1597

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "2ee6d94d",
"metadata": {},
"source": [
"# Retrieving regressor coefficients"
]
},
{
"cell_type": "markdown",
"id": "61d31237-c428-483a-bac1-419dddad3000",
"metadata": {},
"source": [
"Understanding the coefficients of various components in a forecasting model is crucial as it provides insights into how different factors influence the predicted values. We will demonstrate how to retrieve these coefficients using specific functions provided in NeuralProphet.\n",
"\n",
"The following functions are available:\n",
"- get_future_regressor_coefficients: Retrieves the coefficients for future regressors.\n",
"- get_event_coefficients: Retrieves the coefficients for events and holidays.\n",
"- get_lagged_regressor_coefficients: Retrieves the coefficients for lagged regressors.\n",
"- get_ar_coefficients: Retrieves the coefficients for autoregressive lags.\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "6575cb59",
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"from neuralprophet import NeuralProphet\n",
"\n",
"# Load tutorial datasets \n",
"df = pd.read_csv(\"https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial04.csv\")\n",
"\n",
"df1 = pd.read_csv(\"https://github.com/ourownstory/neuralprophet-data/raw/main/kaggle-energy/datasets/tutorial01.csv\")\n"
]
},
{
"cell_type": "markdown",
"id": "0d2ae750",
"metadata": {},
"source": [
"## Future regressors\n",
"\n",
"Useful for understanding the impact of external variables that are known in advance, such as temperature in this example. "
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95511f2b",
"metadata": {},
"outputs": [],
"source": [
"m = NeuralProphet(epochs=10)\n",
"\n",
"# Add the new future regressor\n",
"m.add_future_regressor(\"temperature\")\n",
"\n",
"\n",
"# Continue training the model and making a prediction\n",
"metrics = m.fit(df)\n",
"\n",
"print(\"Future regressor coefficients:\", m.model.get_future_regressor_coefficients())"
]
},
{
"cell_type": "markdown",
"id": "455b60e1",
"metadata": {},
"source": [
"## Events\n",
"\n",
"Helps in assessing the effect of specific events or holidays on the forecasted values."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "ffd52d2b",
"metadata": {},
"outputs": [],
"source": [
"m = NeuralProphet(epochs=10)\n",
"\n",
"# Add holidays for the US as events \n",
"m.add_country_holidays(\"US\")\n",
"\n",
"metrics = m.fit(df1)\n",
"\n",
"print(\"Event coefficients:\", m.model.get_event_coefficients())"
]
},
{
"cell_type": "markdown",
"id": "757056b4",
"metadata": {},
"source": [
"## Lagged regressors\n",
"\n",
"Lagged regressor coefficients are useful for understanding the influence of past values of external variables on the forecast."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c61347cb-bea9-4732-a7f6-4c05aa496354",
"metadata": {},
"outputs": [],
"source": [
"m = NeuralProphet(epochs=10)\n",
"\n",
"# Add temperature of last three days as lagged regressor\n",
"m.add_lagged_regressor(\"temperature\", n_lags=3)\n",
"\n",
"metrics = m.fit(df)\n",
"print(m.model.get_lagged_regressor_coefficients())"
]
},
{
"cell_type": "markdown",
"id": "a9440659",
"metadata": {},
"source": [
"## Autoregressive\n",
"\n",
"Useful for understanding how past values of the time series itself influence future predictions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "feff9910",
"metadata": {},
"outputs": [],
"source": [
"m = NeuralProphet(n_lags=5, epochs=10)\n",
"\n",
"metrics = m.fit(df1)\n",
"\n",
"print(\"AR coefficients:\", m.model.get_ar_coefficients())"
]
},
{
"cell_type": "markdown",
"id": "bc77b042",
"metadata": {},
"source": [
"## Visualizing coefficients\n",
"\n",
"With the Neuralprophet plotting features it is easy to automatically create plots for model parameters that visulize the previously discussed coefficients."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "8f90dd1b",
"metadata": {},
"outputs": [],
"source": [
"m = NeuralProphet(\n",
" n_lags=10, # Autogression\n",
" epochs=10\n",
")\n",
"\n",
"# Add the new future regressor\n",
"m.add_future_regressor(\"temperature\")\n",
"\n",
"# Add holidays for the US as events\n",
"m.add_country_holidays(\"US\")\n",
"\n",
"metrics = m.fit(df)\n",
"\n",
"print(m.model.get_future_regressor_coefficients())\n",
"print(m.model.get_event_coefficients())\n",
"print(m.model.get_ar_coefficients())\n",
"\n",
"m.plot_parameters()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.4"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
72 changes: 34 additions & 38 deletions neuralprophet/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,8 +324,6 @@ def get_valid_configuration( # move to utils
# Identify components to be plotted
# as dict, minimum: {plot_name}
plot_components = []
if validator == "plot_parameters":
quantile_index = m.model.quantiles.index(quantile)

# Plot trend
if "trend" in components:
Expand Down Expand Up @@ -418,38 +416,32 @@ def get_valid_configuration( # move to utils
multiplicative_events = []
if "events" in components:
additive_events_flag = False
muliplicative_events_flag = False
multiplicative_events_flag = False
event_configs = {}
if m.config_events is not None:
for event, configs in m.config_events.items():
if validator == "plot_components" and configs.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and configs.mode == "multiplicative":
muliplicative_events_flag = True
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(event)
weight_list = [
(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()
]
if configs.mode == "additive":
additive_events = additive_events + weight_list
elif configs.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list

event_configs.update(m.config_events)
if m.config_country_holidays is not None:
for country_holiday in m.config_country_holidays.holiday_names:
if validator == "plot_components" and m.config_country_holidays.mode == "additive":
additive_events_flag = True
elif validator == "plot_components" and m.config_country_holidays.mode == "multiplicative":
muliplicative_events_flag = True
elif validator == "plot_parameters":
event_params = m.model.get_event_weights(country_holiday)
weight_list = [
(key, param.detach().numpy()[quantile_index, :]) for key, param in event_params.items()
]
if m.config_country_holidays.mode == "additive":
additive_events = additive_events + weight_list
elif m.config_country_holidays.mode == "multiplicative":
multiplicative_events = multiplicative_events + weight_list
event_configs.update(
{holiday: m.config_country_holidays for holiday in m.config_country_holidays.holiday_names}
)

if event_configs:
if validator == "plot_components":
additive_events_flag = any(config.mode == "additive" for config in event_configs.values())
multiplicative_events_flag = any(config.mode == "multiplicative" for config in event_configs.values())

elif validator == "plot_parameters":
event_coefficients = m.model.get_event_coefficients()
for _, row in event_coefficients.iterrows():
event = row["regressor"]
mode = row["regressor_mode"]
coef = row["coef"]
weight_tuple = (event, coef)

if mode == "additive":
additive_events.append(weight_tuple)
elif mode == "multiplicative":
multiplicative_events.append(weight_tuple)

if additive_events_flag:
plot_components.append(
Expand All @@ -458,7 +450,7 @@ def get_valid_configuration( # move to utils
"comp_name": "events_additive",
}
)
if muliplicative_events_flag:
if multiplicative_events_flag:
plot_components.append(
{
"plot_name": "Multiplicative Events",
Expand Down Expand Up @@ -488,11 +480,15 @@ def get_valid_configuration( # move to utils
}
)
elif validator == "plot_parameters":
regressor_param = m.model.future_regressors.get_reg_weights(regressor)[quantile_index, :]
if configs.mode == "additive":
additive_future_regressors.append((regressor, regressor_param.detach().numpy()))
elif configs.mode == "multiplicative":
multiplicative_future_regressors.append((regressor, regressor_param.detach().numpy()))
future_regressor_coefficients = m.model.get_future_regressor_coefficients()
for _, row in future_regressor_coefficients.iterrows():
regressor = row["regressor"]
mode = row["regressor_mode"]
coef = row["coef"]
if mode == "additive":
additive_future_regressors.append((regressor, coef))
elif mode == "multiplicative":
multiplicative_future_regressors.append((regressor, coef))

# Plot quantiles as a separate component, if present
# If multiple steps in the future are predicted, only plot quantiles if highlight_forecast_step_n is set
Expand Down
Loading
Loading