Skip to content

Commit

Permalink
fix fitted values for sparse models (#775)
Browse files Browse the repository at this point in the history
  • Loading branch information
jmoralez authored Feb 5, 2024
1 parent 9d75d22 commit 28f567b
Show file tree
Hide file tree
Showing 7 changed files with 364 additions and 108 deletions.
254 changes: 193 additions & 61 deletions nbs/src/core/models.ipynb

Large diffs are not rendered by default.

7 changes: 6 additions & 1 deletion nbs/src/utils.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,12 @@
" fitted_vals = np.full(y.size, np.nan, np.float32)\n",
" fitted_vals[1:] = np.roll(y, 1)[1:]\n",
" fcst['fitted'] = fitted_vals\n",
" return fcst"
" return fcst\n",
"\n",
"def _ensure_float(x: np.ndarray) -> np.ndarray:\n",
" if x.dtype not in (np.float32, np.float64):\n",
" x = x.astype(np.float32)\n",
" return x"
]
},
{
Expand Down
2 changes: 1 addition & 1 deletion settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ author = Nixtla
author_email = [email protected]
copyright = Nixtla Inc.
branch = main
version = 1.7.2
version = 1.7.3
min_python = 3.8
audience = Developers
language = English
Expand Down
2 changes: 1 addition & 1 deletion statsforecast/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "1.7.2"
__version__ = "1.7.3"
__all__ = ["StatsForecast"]
from .core import StatsForecast
from .distributed import fugue # noqa
7 changes: 7 additions & 0 deletions statsforecast/_modidx.py
Original file line number Diff line number Diff line change
Expand Up @@ -672,13 +672,19 @@
'statsforecast.models._add_fitted_pi': ( 'src/core/models.html#_add_fitted_pi',
'statsforecast/models.py'),
'statsforecast.models._adida': ('src/core/models.html#_adida', 'statsforecast/models.py'),
'statsforecast.models._chunk_forecast': ( 'src/core/models.html#_chunk_forecast',
'statsforecast/models.py'),
'statsforecast.models._chunk_sums': ('src/core/models.html#_chunk_sums', 'statsforecast/models.py'),
'statsforecast.models._croston_classic': ( 'src/core/models.html#_croston_classic',
'statsforecast/models.py'),
'statsforecast.models._croston_optimized': ( 'src/core/models.html#_croston_optimized',
'statsforecast/models.py'),
'statsforecast.models._croston_sba': ('src/core/models.html#_croston_sba', 'statsforecast/models.py'),
'statsforecast.models._demand': ('src/core/models.html#_demand', 'statsforecast/models.py'),
'statsforecast.models._expand_fitted_demand': ( 'src/core/models.html#_expand_fitted_demand',
'statsforecast/models.py'),
'statsforecast.models._expand_fitted_intervals': ( 'src/core/models.html#_expand_fitted_intervals',
'statsforecast/models.py'),
'statsforecast.models._get_conformal_method': ( 'src/core/models.html#_get_conformal_method',
'statsforecast/models.py'),
'statsforecast.models._historic_average': ( 'src/core/models.html#_historic_average',
Expand Down Expand Up @@ -764,6 +770,7 @@
'statsforecast.utils._calculate_intervals': ( 'src/utils.html#_calculate_intervals',
'statsforecast/utils.py'),
'statsforecast.utils._calculate_sigma': ('src/utils.html#_calculate_sigma', 'statsforecast/utils.py'),
'statsforecast.utils._ensure_float': ('src/utils.html#_ensure_float', 'statsforecast/utils.py'),
'statsforecast.utils._naive': ('src/utils.html#_naive', 'statsforecast/utils.py'),
'statsforecast.utils._quantiles': ('src/utils.html#_quantiles', 'statsforecast/utils.py'),
'statsforecast.utils._repeat_val': ('src/utils.html#_repeat_val', 'statsforecast/utils.py'),
Expand Down
Loading

0 comments on commit 28f567b

Please sign in to comment.