Skip to content

Commit

Permalink
[FEAT] Add polars support (#305)
Browse files: browse the repository at this point in the history
  • Loading branch information
elephaint authored Nov 27, 2024
1 parent 8af6b55 commit e7e25ad
Show file tree
Hide file tree
Showing 30 changed files with 7,823 additions and 9,203 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,4 +31,4 @@ jobs:
run: pip install uv && uv pip install --system ".[dev]"

- name: Tests
run: nbdev_test --do_print --timing
run: nbdev_test --do_print --timing --n_workers 0
18 changes: 5 additions & 13 deletions action_files/test_models/src/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
import pickle
import pandas as pd

from statsforecast.models import ETS
os.environ['NIXTLA_ID_AS_COL'] = '1'

from statsforecast.models import AutoETS
from statsforecast.core import StatsForecast

from hierarchicalforecast.utils import aggregate
Expand Down Expand Up @@ -35,31 +37,21 @@ def get_data():
spec = [
['Country'],
['Country', 'State'],
# ['Country', 'Purpose'],
['Country', 'State', 'Region'],
# ['Country', 'State', 'Purpose'],
['Country', 'State', 'Region', 'Purpose']
]

Y_df, S_df, tags = aggregate(Y_df, spec)
Y_df = Y_df.reset_index()

# Train/Test Splits
Y_test_df = Y_df.groupby('unique_id').tail(8)
Y_train_df = Y_df.drop(Y_test_df.index)

Y_test_df = Y_test_df.set_index('unique_id')
Y_train_df = Y_train_df.set_index('unique_id')

sf = StatsForecast(df=Y_train_df,
models=[ETS(season_length=4, model='ZZA')],
sf = StatsForecast(models=[AutoETS(season_length=4, model='ZZA')],
freq='QS', n_jobs=-1)
Y_hat_df = sf.forecast(h=8, fitted=True)
Y_hat_df = sf.forecast(df=Y_train_df, h=8, fitted=True)
Y_fitted_df = sf.forecast_fitted_values()

Y_test_df = Y_test_df.reset_index()
Y_train_df = Y_train_df.reset_index()

# Save Data
if not os.path.exists('./data'):
os.makedirs('./data')
Expand Down
11 changes: 4 additions & 7 deletions action_files/test_models/src/evaluation.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import os
import pickle
import numpy as np
import pandas as pd

from hierarchicalforecast.evaluation import HierarchicalEvaluation

os.environ['NIXTLA_ID_AS_COL'] = '1'

def rmse(y, y_hat):
return np.mean(np.sqrt(np.mean((y-y_hat)**2, axis=1)))
Expand All @@ -22,16 +24,11 @@ def evaluate():
Y_test_df = pd.read_csv('data/Y_test.csv')
Y_train_df = pd.read_csv('data/Y_train.csv')

Y_rec_df = Y_rec_df.set_index('unique_id')
Y_test_df = Y_test_df.set_index('unique_id')
Y_train_df = Y_train_df.set_index('unique_id')

with open('data/tags.pickle', 'rb') as handle:
tags = pickle.load(handle)

eval_tags = {}
eval_tags['Total'] = tags['Country']
# eval_tags['Purpose'] = tags['Country/Purpose']
eval_tags['State'] = tags['Country/State']
eval_tags['Regions'] = tags['Country/State/Region']
eval_tags['Bottom'] = tags['Country/State/Region/Purpose']
Expand All @@ -42,10 +39,10 @@ def evaluate():
Y_hat_df=Y_rec_df, Y_test_df=Y_test_df,
tags=eval_tags, Y_df=Y_train_df
)
evaluation = evaluation.drop('Overall')
evaluation = evaluation.query("level != 'Overall'").set_index(['level', 'metric'])

evaluation.columns = ['Base'] + models
evaluation = evaluation.applymap('{:.2f}'.format)
evaluation = evaluation.map('{:.2f}'.format)
return evaluation


Expand Down
5 changes: 2 additions & 3 deletions action_files/test_models/src/models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

os.environ['NIXTLA_ID_AS_COL'] = '1'

import fire
import pandas as pd

Expand Down Expand Up @@ -37,14 +39,11 @@ def main():
OptimalCombination(method='ols'),
OptimalCombination(method='wls_struct'),
ERM(method='closed'),
# ERM(method='reg'), # This is so insanely slow that we don't run it
# ERM(method='reg_bu'), # This is so insanely slow that we don't run it
]
hrec = HierarchicalReconciliation(reconcilers=reconcilers)
Y_rec_df = hrec.reconcile(Y_hat_df=Y_hat_df,
Y_df=Y_fitted_df, S=S_df, tags=tags)

Y_rec_df = Y_rec_df.reset_index()
execution_times = pd.Series(hrec.execution_times).reset_index()

if not os.path.exists('./data'):
Expand Down
6 changes: 2 additions & 4 deletions hierarchicalforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
'hierarchicalforecast/core.py'),
'hierarchicalforecast.core.HierarchicalReconciliation.__init__': ( 'src/core.html#hierarchicalreconciliation.__init__',
'hierarchicalforecast/core.py'),
'hierarchicalforecast.core.HierarchicalReconciliation._prepare_Y': ( 'src/core.html#hierarchicalreconciliation._prepare_y',
'hierarchicalforecast/core.py'),
'hierarchicalforecast.core.HierarchicalReconciliation._prepare_fit': ( 'src/core.html#hierarchicalreconciliation._prepare_fit',
'hierarchicalforecast/core.py'),
'hierarchicalforecast.core.HierarchicalReconciliation.bootstrap_reconcile': ( 'src/core.html#hierarchicalreconciliation.bootstrap_reconcile',
Expand Down Expand Up @@ -196,14 +198,10 @@
'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils._shrunk_covariance_schaferstrimmer_with_nans': ( 'src/utils.html#_shrunk_covariance_schaferstrimmer_with_nans',
'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils._to_summing_matrix': ( 'src/utils.html#_to_summing_matrix',
'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils._to_upper_hierarchy': ( 'src/utils.html#_to_upper_hierarchy',
'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils.aggregate': ( 'src/utils.html#aggregate',
'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils.aggregate_before': ( 'src/utils.html#aggregate_before',
'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils.cov2corr': ( 'src/utils.html#cov2corr',
'hierarchicalforecast/utils.py'),
'hierarchicalforecast.utils.is_strictly_hierarchical': ( 'src/utils.html#is_strictly_hierarchical',
Expand Down
Loading

0 comments on commit e7e25ad

Please sign in to comment.