diff --git a/nbs/src/core/models.ipynb b/nbs/src/core/models.ipynb index 0f84e10e8..7d363f530 100644 --- a/nbs/src/core/models.ipynb +++ b/nbs/src/core/models.ipynb @@ -103,6 +103,7 @@ "from statsforecast.utils import (\n", " _calculate_sigma,\n", " _calculate_intervals,\n", + " _ensure_float,\n", " _naive,\n", " _quantiles,\n", " _repeat_val,\n", @@ -898,7 +899,24 @@ " raise Exception(\n", " 'predict and forward methods are not equal with ' \n", " 'levels for fitted values '\n", - " )" + " )\n", + "\n", + "def _test_fitted_sparse(model_factory):\n", + " y1 = np.array([2, 5, 0, 1, 3, 0, 1, 1, 0], dtype=np.float32)\n", + " y2 = np.array([0, 0, 1, 0, 0, 7, 1, 0, 1], dtype=np.float32)\n", + " y3 = np.array([0, 0, 1, 0, 0, 7, 1, 0, 0], dtype=np.float32)\n", + " y4 = np.zeros(9, dtype=np.float32)\n", + " for y in [y1, y2, y3, y4]:\n", + " expected_fitted = np.hstack(\n", + " [\n", + " model_factory().forecast(y=y[:i + 1], h=1)['mean']\n", + " for i in range(y.size - 1)]\n", + " )\n", + " np.testing.assert_allclose(\n", + " model_factory().forecast(y=y, h=1, fitted=True)['fitted'],\n", + " np.append(np.nan, expected_fitted),\n", + " atol=1e-6,\n", + " )" ] }, { @@ -7138,6 +7156,51 @@ "outputs": [], "source": [ "#| exporti\n", + "def _chunk_forecast(y, aggregation_level):\n", + " lost_remainder_data = len(y) % aggregation_level\n", + " y_cut = y[lost_remainder_data:]\n", + " aggregation_sums = _chunk_sums(y_cut, aggregation_level)\n", + " sums_forecast, _ = _optimized_ses_forecast(aggregation_sums)\n", + " return sums_forecast\n", + "\n", + "@njit(nogil=NOGIL, cache=CACHE)\n", + "def _expand_fitted_demand(fitted: np.ndarray, y: np.ndarray) -> np.ndarray:\n", + " out = np.empty_like(y)\n", + " out[0] = np.nan\n", + " fitted_idx = 0\n", + " for i in range(1, y.size):\n", + " if y[i - 1] > 0:\n", + " fitted_idx += 1\n", + " out[i] = fitted[fitted_idx]\n", + " elif fitted_idx > 0:\n", + " # if this entry is zero, the model didn't change\n", + " out[i] = out[i - 1]\n", + " else:\n", + " # if we haven't seen any demand, use naive\n", + " out[i] = y[i - 1]\n", + " return out\n", + "\n", + "@njit(nogil=NOGIL, cache=CACHE)\n", + "def _expand_fitted_intervals(fitted: np.ndarray, y: np.ndarray) -> np.ndarray:\n", + " out = np.empty_like(y)\n", + " out[0] = np.nan\n", + " fitted_idx = 0\n", + " for i in range(1, y.size):\n", + " if y[i - 1] != 0:\n", + " fitted_idx += 1\n", + " if fitted[fitted_idx] == 0:\n", + " # to avoid division by zero\n", + " out[i] = 1\n", + " else:\n", + " out[i] = fitted[fitted_idx]\n", + " elif fitted_idx > 0:\n", + " # if this entry is zero, the model didn't change\n", + " out[i] = out[i - 1]\n", + " else:\n", + " # if we haven't seen any intervals, use 1 to avoid division by zero\n", + " out[i] = 1\n", + " return out\n", + " \n", "def _adida(\n", " y: np.ndarray, # time series\n", " h: int, # forecasting horizon\n", @@ -7146,19 +7209,30 @@ " if (y == 0).all():\n", " res = {'mean': np.zeros(h, dtype=np.float32)}\n", " if fitted:\n", - " res['fitted'] = y.copy()\n", + " res['fitted'] = np.zeros(y.size, dtype=np.float32)\n", + " res['fitted'][0] = np.nan\n", " return res\n", + " y = _ensure_float(y)\n", " y_intervals = _intervals(y)\n", " mean_interval = y_intervals.mean()\n", " aggregation_level = round(mean_interval)\n", - " lost_remainder_data = len(y) % aggregation_level\n", - " y_cut = y[lost_remainder_data:]\n", - " aggregation_sums = _chunk_sums(y_cut, aggregation_level)\n", - " sums_forecast, sums_fitted = _optimized_ses_forecast(aggregation_sums)\n", + " sums_forecast = _chunk_forecast(y, aggregation_level)\n", " forecast = sums_forecast / aggregation_level\n", " res = {'mean': _repeat_val(val=forecast, h=h)}\n", " if fitted:\n", - " res['fitted'] = sums_fitted / aggregation_level\n", + " warnings.warn(\"Computing fitted values for ADIDA is very expensive\")\n", + " fitted_aggregation_levels = np.round(\n", + " y_intervals.cumsum() / np.arange(1, y_intervals.size + 1)\n", + " )\n", + " fitted_aggregation_levels = _expand_fitted_intervals(\n", + " np.append(np.nan, fitted_aggregation_levels), y\n", + " )[1:].astype(np.int32)\n", + "\n", + " sums_fitted = np.empty(y.size - 1, dtype=y.dtype)\n", + " for i, agg_lvl in enumerate(fitted_aggregation_levels):\n", + " sums_fitted[i] = _chunk_forecast(y[:i+1], agg_lvl)\n", + "\n", + " res['fitted'] = np.append(np.nan, sums_fitted / fitted_aggregation_levels)\n", " return res" ] }, @@ -7222,8 +7296,8 @@ " self :\n", " ADIDA fitted model.\n", " \"\"\" \n", - " self.model_ = _adida(y=y, h=1, fitted=True)\n", - " self.model_['sigma'] = _calculate_sigma(y - self.model_['fitted'], y.size)\n", + " self.model_ = _adida(y=y, h=1, fitted=False)\n", + " self._y = y\n", " self._store_cs(y=y, X=X)\n", " return self\n", "\n", @@ -7276,11 +7350,13 @@ " forecasts : dict \n", " Dictionary with entries `fitted` for point predictions and `level_*` for probabilistic predictions.\n", " \"\"\"\n", - " res = {'fitted': self.model_['fitted']}\n", + " fitted = _adida(y=self._y, h=1, fitted=True)['fitted']\n", + " res = {'fitted': fitted}\n", " if level is not None:\n", - " res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n", + " sigma = _calculate_sigma(self._y - fitted, self._y.size) \n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", " return res\n", - " \n", + "\n", " def forecast(\n", " self, \n", " y: np.ndarray,\n", @@ -7328,7 +7404,8 @@ " \"to calculate them\"\n", " )\n", " if fitted:\n", - " res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n", + " sigma = _calculate_sigma(y - res['fitted'], y.size)\n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", " return res" ] }, @@ -7352,7 +7429,9 @@ "adida = ADIDA()\n", "test_class(adida, x=ap, h=12, skip_insample=False)\n", "test_class(adida, x=deg_ts, h=12, skip_insample=False)\n", - "fcst_adida = adida.forecast(ap, 12)" + "fcst_adida = adida.forecast(ap, 12)\n", + "\n", + "_test_fitted_sparse(ADIDA)" ] }, { @@ -7471,19 +7550,25 @@ " h: int, # forecasting horizon\n", " fitted: bool, # fitted values\n", "): \n", + " y = _ensure_float(y)\n", + " # demand\n", " yd = _demand(y)\n", - " yi = _intervals(y)\n", " if not yd.size: #no demand\n", - " return _naive(y=y, h=h, fitted=fitted)\n", + " return _naive(y=y, h=h, fitted=fitted) \n", " ydp, ydf = _ses_forecast(yd, 0.1)\n", + "\n", + " # intervals\n", + " yi = _intervals(y)\n", " yip, yif = _ses_forecast(yi, 0.1)\n", + "\n", " if yip != 0.0:\n", " mean = ydp / yip\n", " else:\n", " mean = ydp\n", " out = {'mean': _repeat_val(val=mean, h=h)}\n", " if fitted:\n", - " yif[yif == 0.0] = 1.0\n", + " ydf = _expand_fitted_demand(np.append(ydf, ydp), y)\n", + " yif = _expand_fitted_intervals(np.append(yif, yip), y) \n", " out['fitted'] = ydf / yif\n", " return out" ] @@ -7651,7 +7736,8 @@ " \"You have to instantiate the class with `prediction_intervals` to calculate them\"\n", " )\n", " if fitted:\n", - " res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n", + " sigma = _calculate_sigma(y - res['fitted'], y.size)\n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", " return res" ] }, @@ -7662,10 +7748,12 @@ "outputs": [], "source": [ "#| hide\n", - "croston = CrostonClassic()\n", - "test_class(croston, x=ap, h=12, skip_insample=False)\n", - "test_class(croston, x=deg_ts, h=12, skip_insample=False)\n", - "fcst_croston = croston.forecast(ap, 12)\n" + "croston = CrostonClassic(prediction_intervals=ConformalIntervals(2, 1))\n", + "test_class(croston, x=ap, h=12, skip_insample=False, level=[80])\n", + "test_class(croston, x=deg_ts, h=12, skip_insample=False, level=[80])\n", + "fcst_croston = croston.forecast(ap, 12)\n", + "\n", + "_test_fitted_sparse(CrostonClassic)" ] }, { @@ -7784,19 +7872,39 @@ " h: int, # forecasting horizon\n", " fitted: bool, # fitted values\n", " ):\n", + " y = _ensure_float(y)\n", + " # demand\n", " yd = _demand(y)\n", - " yi = _intervals(y)\n", " if not yd.size:\n", " return _naive(y=y, h=h, fitted=fitted)\n", - " ydp, ydf = _optimized_ses_forecast(yd)\n", - " yip, yif = _optimized_ses_forecast(yi)\n", + " ydp, _ = _optimized_ses_forecast(yd)\n", + "\n", + " # intervals\n", + " yi = _intervals(y)\n", + " yip, _ = _optimized_ses_forecast(yi)\n", + "\n", " if yip != 0.0:\n", " mean = ydp / yip\n", " else:\n", " mean = ydp\n", " out = {'mean': _repeat_val(val=mean, h=h)}\n", " if fitted:\n", - " yif[yif == 0.0] = 1.0\n", + " warnings.warn(\"Computing fitted values for CrostonOptimized is very expensive\")\n", + " ydf = np.empty(yd.size + 1, dtype=y.dtype)\n", + " ydf[0] = np.nan\n", + " for i in range(yd.size):\n", + " ydf[i + 1] = _optimized_ses_forecast(yd[:i + 1])[0]\n", + "\n", + " yif = np.empty(yi.size + 1, dtype=y.dtype)\n", + " yif[0] = np.nan\n", + " for i in range(yi.size):\n", + " yiff = _optimized_ses_forecast(yi[:i + 1])[0]\n", + " if yiff == 0:\n", + " yiff = 1.0\n", + " yif[i + 1] = yiff\n", + "\n", + " ydf = _expand_fitted_demand(ydf, y)\n", + " yif = _expand_fitted_intervals(yif, y)\n", " out['fitted'] = ydf / yif\n", " return out" ] @@ -7862,8 +7970,8 @@ " self : \n", " CrostonOptimized fitted model.\n", " \"\"\" \n", - " self.model_ = _croston_optimized(y=y, h=1, fitted=True)\n", - " self.model_['sigma'] = _calculate_sigma(y - self.model_['fitted'], y.size)\n", + " self.model_ = _croston_optimized(y=y, h=1, fitted=False)\n", + " self._y = y\n", " self._store_cs(y=y, X=X)\n", " return self\n", " \n", @@ -7912,10 +8020,12 @@ " -------\n", " forecasts : dict \n", " Dictionary with entries `fitted` for point predictions and `level_*` for probabilistic predictions.\n", - " \"\"\" \n", - " res = {'fitted': self.model_['fitted']}\n", + " \"\"\"\n", + " fitted = _croston_optimized(y=self._y, h=1, fitted=True)['fitted']\n", + " res = {'fitted': fitted}\n", " if level is not None:\n", - " res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n", + " sigma = _calculate_sigma(self._y - fitted, self._y.size) \n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", " return res\n", "\n", " def forecast(\n", @@ -7961,7 +8071,8 @@ " else:\n", " raise Exception(\"You must pass `prediction_intervals` to compute them.\")\n", " if fitted:\n", - " res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n", + " sigma = _calculate_sigma(y - res['fitted'], y.size)\n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", " return res" ] }, @@ -7972,10 +8083,12 @@ "outputs": [], "source": [ "#| hide\n", - "croston_op = CrostonOptimized()\n", - "test_class(croston_op, x=ap, h=12, skip_insample=False)\n", - "test_class(croston_op, x=deg_ts, h=12, skip_insample=False)\n", - "fcst_croston_op = croston_op.forecast(ap, 12)" + "croston_op = CrostonOptimized(prediction_intervals=ConformalIntervals(2, 1))\n", + "test_class(croston_op, x=ap, h=12, skip_insample=False, level=[80])\n", + "test_class(croston_op, x=deg_ts, h=12, skip_insample=False, level=[80])\n", + "fcst_croston_op = croston_op.forecast(ap, 12)\n", + "\n", + "_test_fitted_sparse(CrostonOptimized)" ] }, { @@ -8267,7 +8380,8 @@ " \"to calculate them\"\n", " )\n", " if fitted:\n", - " res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n", + " sigma = _calculate_sigma(y - res['fitted'], y.size)\n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", " return res" ] }, @@ -8278,10 +8392,12 @@ "outputs": [], "source": [ "#| hide\n", - "croston_sba = CrostonSBA()\n", - "test_class(croston_sba, x=ap, h=12, skip_insample=False)\n", - "test_class(croston_sba, x=deg_ts, h=12, skip_insample=False)\n", - "fcst_croston_sba = croston_sba.forecast(ap, 12)" + "croston_sba = CrostonSBA(prediction_intervals=ConformalIntervals(2, 1))\n", + "test_class(croston_sba, x=ap, h=12, skip_insample=False, level=[80])\n", + "test_class(croston_sba, x=deg_ts, h=12, skip_insample=False, level=[80])\n", + "fcst_croston_sba = croston_sba.forecast(ap, 12)\n", + "\n", + "_test_fitted_sparse(CrostonSBA)" ] }, { @@ -8403,24 +8519,29 @@ " if (y == 0).all():\n", " res = {'mean': np.zeros(h, dtype=np.float32)}\n", " if fitted:\n", - " res['fitted'] = y.copy()\n", + " res['fitted'] = np.zeros(y.size, dtype=np.float32)\n", + " res['fitted'][0] = np.nan\n", " return res\n", + " y = _ensure_float(y) \n", " y_intervals = _intervals(y)\n", " mean_interval = y_intervals.mean().item()\n", " max_aggregation_level = round(mean_interval)\n", " forecasts = np.empty(max_aggregation_level, np.float32)\n", - " fitted_vals = np.empty((y.size, max_aggregation_level), dtype=np.float32)\n", " for aggregation_level in range(1, max_aggregation_level + 1):\n", " lost_remainder_data = len(y) % aggregation_level\n", " y_cut = y[lost_remainder_data:]\n", " aggregation_sums = _chunk_sums(y_cut, aggregation_level)\n", - " forecast, fit = _optimized_ses_forecast(aggregation_sums)\n", + " forecast, _ = _optimized_ses_forecast(aggregation_sums)\n", " forecasts[aggregation_level - 1] = forecast / aggregation_level\n", - " fitted_vals[:, aggregation_level - 1] = fit / aggregation_level\n", " forecast = forecasts.mean()\n", " res = {'mean': _repeat_val(val=forecast, h=h)}\n", " if fitted:\n", - " res['fitted'] = fitted_vals.mean(axis=1)\n", + " warnings.warn(\"Computing fitted values for IMAPA is very expensive.\")\n", + " fitted_vals = np.empty_like(y)\n", + " fitted_vals[0] = np.nan\n", + " for i in range(y.size - 1):\n", + " fitted_vals[i + 1] = _imapa(y[:i+1], h=1, fitted=False)['mean'].item()\n", + " res['fitted'] = fitted_vals\n", " return res" ] }, @@ -8480,8 +8601,8 @@ " self : \n", " IMAPA fitted model.\n", " \"\"\"\n", - " self.model_ = _imapa(y=y, h=1, fitted=True)\n", - " self.model_['sigma'] = _calculate_sigma(y - self.model_['fitted'], y.size)\n", + " self.model_ = _imapa(y=y, h=1, fitted=False)\n", + " self._y = y\n", " self._store_cs(y=y, X=X)\n", " return self\n", " \n", @@ -8534,9 +8655,11 @@ " forecasts : dict \n", " Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions.\n", " \"\"\"\n", - " res = {'fitted': self.model_['fitted']}\n", + " fitted = _imapa(y=self._y, h=1, fitted=True)['fitted']\n", + " res = {'fitted': fitted}\n", " if level is not None:\n", - " res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n", + " sigma = _calculate_sigma(self._y - fitted, self._y.size)\n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", " return res\n", " \n", " def forecast(\n", @@ -8586,7 +8709,8 @@ " \"to calculate them\"\n", " )\n", " if fitted:\n", - " res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n", + " sigma = _calculate_sigma(y - res['fitted'], y.size)\n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", " return res" ] }, @@ -8597,10 +8721,12 @@ "outputs": [], "source": [ "#| hide\n", - "imapa = IMAPA()\n", - "test_class(imapa, x=ap, h=12, skip_insample=False)\n", - "test_class(imapa, x=deg_ts, h=12, skip_insample=False)\n", - "fcst_imapa = imapa.forecast(ap, 12)" + "imapa = IMAPA(prediction_intervals=ConformalIntervals(2, 1))\n", + "test_class(imapa, x=ap, h=12, skip_insample=False, level=[80])\n", + "test_class(imapa, x=deg_ts, h=12, skip_insample=False, level=[80])\n", + "fcst_imapa = imapa.forecast(ap, 12)\n", + "\n", + "_test_fitted_sparse(IMAPA)" ] }, { @@ -8724,14 +8850,17 @@ " if (y == 0).all():\n", " res = {'mean': np.zeros(h, dtype=np.float32)}\n", " if fitted:\n", - " res['fitted'] = y.copy()\n", + " res['fitted'] = np.zeros(y.size, dtype=np.float32)\n", + " res['fitted'][0] = np.nan\n", " return res\n", + " y = _ensure_float(y)\n", " yd = _demand(y)\n", " yp = _probability(y)\n", " ypf, ypft = _ses_forecast(yp, alpha_p)\n", " ydf, ydft = _ses_forecast(yd, alpha_d)\n", " res = {'mean': _repeat_val(val=ypf * ydf, h=h)}\n", " if fitted:\n", + " ydft = _expand_fitted_demand(np.append(ydft, ydf), y)\n", " res['fitted'] = ypft * ydft\n", " return res" ] @@ -8924,7 +9053,8 @@ " else:\n", " raise Exception(\"You must pass `prediction_intervals` to compute them.\")\n", " if fitted:\n", - " res = _add_fitted_pi(res=res, se=self.model_['sigma'], level=level)\n", + " sigma = _calculate_sigma(y - res['fitted'], y.size)\n", + " res = _add_fitted_pi(res=res, se=sigma, level=level)\n", " return res " ] }, @@ -8935,10 +9065,12 @@ "outputs": [], "source": [ "#| hide\n", - "tsb = TSB(alpha_d=0.9, alpha_p=0.1)\n", - "test_class(tsb, x=ap, h=12, skip_insample=False)\n", - "test_class(tsb, x=deg_ts, h=12, skip_insample=False)\n", - "fcst_tsb = tsb.forecast(ap, 12)" + "tsb = TSB(alpha_d=0.9, alpha_p=0.1, prediction_intervals=ConformalIntervals(2, 1))\n", + "test_class(tsb, x=ap, h=12, skip_insample=False, level=[80])\n", + "test_class(tsb, x=deg_ts, h=12, skip_insample=False, level=[80])\n", + "fcst_tsb = tsb.forecast(ap, 12)\n", + "\n", + "_test_fitted_sparse(lambda: TSB(alpha_d=0.9, alpha_p=0.1))" ] }, { diff --git a/nbs/src/utils.ipynb b/nbs/src/utils.ipynb index 277260183..5f0bb2fc8 100644 --- a/nbs/src/utils.ipynb +++ b/nbs/src/utils.ipynb @@ -301,7 +301,12 @@ " fitted_vals = np.full(y.size, np.nan, np.float32)\n", " fitted_vals[1:] = np.roll(y, 1)[1:]\n", " fcst['fitted'] = fitted_vals\n", - " return fcst" + " return fcst\n", + "\n", + "def _ensure_float(x: np.ndarray) -> np.ndarray:\n", + " if x.dtype not in (np.float32, np.float64):\n", + " x = x.astype(np.float32)\n", + " return x" ] }, { diff --git a/settings.ini b/settings.ini index cdf3a7806..6b231f72d 100644 --- a/settings.ini +++ b/settings.ini @@ -8,7 +8,7 @@ author = Nixtla author_email = business@nixtla.io copyright = Nixtla Inc. branch = main -version = 1.7.2 +version = 1.7.3 min_python = 3.8 audience = Developers language = English diff --git a/statsforecast/__init__.py b/statsforecast/__init__.py index 72a28c666..f7e604890 100644 --- a/statsforecast/__init__.py +++ b/statsforecast/__init__.py @@ -1,4 +1,4 @@ -__version__ = "1.7.2" +__version__ = "1.7.3" __all__ = ["StatsForecast"] from .core import StatsForecast from .distributed import fugue # noqa diff --git a/statsforecast/_modidx.py b/statsforecast/_modidx.py index a8057e91b..d74e8acd3 100644 --- a/statsforecast/_modidx.py +++ b/statsforecast/_modidx.py @@ -672,6 +672,8 @@ 'statsforecast.models._add_fitted_pi': ( 'src/core/models.html#_add_fitted_pi', 'statsforecast/models.py'), 'statsforecast.models._adida': ('src/core/models.html#_adida', 'statsforecast/models.py'), + 'statsforecast.models._chunk_forecast': ( 'src/core/models.html#_chunk_forecast', + 'statsforecast/models.py'), 'statsforecast.models._chunk_sums': ('src/core/models.html#_chunk_sums', 'statsforecast/models.py'), 'statsforecast.models._croston_classic': ( 'src/core/models.html#_croston_classic', 'statsforecast/models.py'), @@ -679,6 +681,10 @@ 'statsforecast/models.py'), 'statsforecast.models._croston_sba': ('src/core/models.html#_croston_sba', 'statsforecast/models.py'), 'statsforecast.models._demand': ('src/core/models.html#_demand', 'statsforecast/models.py'), + 'statsforecast.models._expand_fitted_demand': ( 'src/core/models.html#_expand_fitted_demand', + 'statsforecast/models.py'), + 'statsforecast.models._expand_fitted_intervals': ( 'src/core/models.html#_expand_fitted_intervals', + 'statsforecast/models.py'), 'statsforecast.models._get_conformal_method': ( 'src/core/models.html#_get_conformal_method', 'statsforecast/models.py'), 'statsforecast.models._historic_average': ( 'src/core/models.html#_historic_average', @@ -764,6 +770,7 @@ 'statsforecast.utils._calculate_intervals': ( 'src/utils.html#_calculate_intervals', 'statsforecast/utils.py'), 'statsforecast.utils._calculate_sigma': ('src/utils.html#_calculate_sigma', 'statsforecast/utils.py'), + 'statsforecast.utils._ensure_float': ('src/utils.html#_ensure_float', 'statsforecast/utils.py'), 'statsforecast.utils._naive': ('src/utils.html#_naive', 'statsforecast/utils.py'), 'statsforecast.utils._quantiles': ('src/utils.html#_quantiles', 'statsforecast/utils.py'), 'statsforecast.utils._repeat_val': ('src/utils.html#_repeat_val', 'statsforecast/utils.py'), diff --git a/statsforecast/models.py b/statsforecast/models.py index 32f4396b6..ff50a64ca 100644 --- a/statsforecast/models.py +++ b/statsforecast/models.py @@ -40,6 +40,7 @@ from statsforecast.utils import ( _calculate_sigma, _calculate_intervals, + _ensure_float, _naive, _quantiles, _repeat_val, @@ -3783,6 +3784,54 @@ def forecast( return res # %% ../nbs/src/core/models.ipynb 289 +def _chunk_forecast(y, aggregation_level): + lost_remainder_data = len(y) % aggregation_level + y_cut = y[lost_remainder_data:] + aggregation_sums = _chunk_sums(y_cut, aggregation_level) + sums_forecast, _ = _optimized_ses_forecast(aggregation_sums) + return sums_forecast + + +@njit(nogil=NOGIL, cache=CACHE) +def _expand_fitted_demand(fitted: np.ndarray, y: np.ndarray) -> np.ndarray: + out = np.empty_like(y) + out[0] = np.nan + fitted_idx = 0 + for i in range(1, y.size): + if y[i - 1] > 0: + fitted_idx += 1 + out[i] = fitted[fitted_idx] + elif fitted_idx > 0: + # if this entry is zero, the model didn't change + out[i] = out[i - 1] + else: + # if we haven't seen any demand, use naive + out[i] = y[i - 1] + return out + + +@njit(nogil=NOGIL, cache=CACHE) +def _expand_fitted_intervals(fitted: np.ndarray, y: np.ndarray) -> np.ndarray: + out = np.empty_like(y) + out[0] = np.nan + fitted_idx = 0 + for i in range(1, y.size): + if y[i - 1] != 0: + fitted_idx += 1 + if fitted[fitted_idx] == 0: + # to avoid division by zero + out[i] = 1 + else: + out[i] = fitted[fitted_idx] + elif fitted_idx > 0: + # if this entry is zero, the model didn't change + out[i] = out[i - 1] + else: + # if we haven't seen any intervals, use 1 to avoid division by zero + out[i] = 1 + return out + + def _adida( y: np.ndarray, # time series h: int, # forecasting horizon @@ -3791,19 +3840,30 @@ def _adida( if (y == 0).all(): res = {"mean": np.zeros(h, dtype=np.float32)} if fitted: - res["fitted"] = y.copy() + res["fitted"] = np.zeros(y.size, dtype=np.float32) + res["fitted"][0] = np.nan return res + y = _ensure_float(y) y_intervals = _intervals(y) mean_interval = y_intervals.mean() aggregation_level = round(mean_interval) - lost_remainder_data = len(y) % aggregation_level - y_cut = y[lost_remainder_data:] - aggregation_sums = _chunk_sums(y_cut, aggregation_level) - sums_forecast, sums_fitted = _optimized_ses_forecast(aggregation_sums) + sums_forecast = _chunk_forecast(y, aggregation_level) forecast = sums_forecast / aggregation_level res = {"mean": _repeat_val(val=forecast, h=h)} if fitted: - res["fitted"] = sums_fitted / aggregation_level + warnings.warn("Computing fitted values for ADIDA is very expensive") + fitted_aggregation_levels = np.round( + y_intervals.cumsum() / np.arange(1, y_intervals.size + 1) + ) + fitted_aggregation_levels = _expand_fitted_intervals( + np.append(np.nan, fitted_aggregation_levels), y + )[1:].astype(np.int32) + + sums_fitted = np.empty(y.size - 1, dtype=y.dtype) + for i, agg_lvl in enumerate(fitted_aggregation_levels): + sums_fitted[i] = _chunk_forecast(y[: i + 1], agg_lvl) + + res["fitted"] = np.append(np.nan, sums_fitted / fitted_aggregation_levels) return res # %% ../nbs/src/core/models.ipynb 290 @@ -3863,8 +3923,8 @@ def fit( self : ADIDA fitted model. """ - self.model_ = _adida(y=y, h=1, fitted=True) - self.model_["sigma"] = _calculate_sigma(y - self.model_["fitted"], y.size) + self.model_ = _adida(y=y, h=1, fitted=False) + self._y = y self._store_cs(y=y, X=X) return self @@ -3917,9 +3977,11 @@ def predict_in_sample(self, level: Optional[List[int]] = None): forecasts : dict Dictionary with entries `fitted` for point predictions and `level_*` for probabilistic predictions. """ - res = {"fitted": self.model_["fitted"]} + fitted = _adida(y=self._y, h=1, fitted=True)["fitted"] + res = {"fitted": fitted} if level is not None: - res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + sigma = _calculate_sigma(self._y - fitted, self._y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) return res def forecast( @@ -3969,7 +4031,8 @@ def forecast( "to calculate them" ) if fitted: - res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + sigma = _calculate_sigma(y - res["fitted"], y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) return res # %% ../nbs/src/core/models.ipynb 302 @@ -3978,19 +4041,25 @@ def _croston_classic( h: int, # forecasting horizon fitted: bool, # fitted values ): + y = _ensure_float(y) + # demand yd = _demand(y) - yi = _intervals(y) if not yd.size: # no demand return _naive(y=y, h=h, fitted=fitted) ydp, ydf = _ses_forecast(yd, 0.1) + + # intervals + yi = _intervals(y) yip, yif = _ses_forecast(yi, 0.1) + if yip != 0.0: mean = ydp / yip else: mean = ydp out = {"mean": _repeat_val(val=mean, h=h)} if fitted: - yif[yif == 0.0] = 1.0 + ydf = _expand_fitted_demand(np.append(ydf, ydp), y) + yif = _expand_fitted_intervals(np.append(yif, yip), y) out["fitted"] = ydf / yif return out @@ -4154,7 +4223,8 @@ def forecast( "You have to instantiate the class with `prediction_intervals` to calculate them" ) if fitted: - res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + sigma = _calculate_sigma(y - res["fitted"], y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) return res # %% ../nbs/src/core/models.ipynb 314 @@ -4163,19 +4233,39 @@ def _croston_optimized( h: int, # forecasting horizon fitted: bool, # fitted values ): + y = _ensure_float(y) + # demand yd = _demand(y) - yi = _intervals(y) if not yd.size: return _naive(y=y, h=h, fitted=fitted) - ydp, ydf = _optimized_ses_forecast(yd) - yip, yif = _optimized_ses_forecast(yi) + ydp, _ = _optimized_ses_forecast(yd) + + # intervals + yi = _intervals(y) + yip, _ = _optimized_ses_forecast(yi) + if yip != 0.0: mean = ydp / yip else: mean = ydp out = {"mean": _repeat_val(val=mean, h=h)} if fitted: - yif[yif == 0.0] = 1.0 + warnings.warn("Computing fitted values for CrostonOptimized is very expensive") + ydf = np.empty(yd.size + 1, dtype=y.dtype) + ydf[0] = np.nan + for i in range(yd.size): + ydf[i + 1] = _optimized_ses_forecast(yd[: i + 1])[0] + + yif = np.empty(yi.size + 1, dtype=y.dtype) + yif[0] = np.nan + for i in range(yi.size): + yiff = _optimized_ses_forecast(yi[: i + 1])[0] + if yiff == 0: + yiff = 1.0 + yif[i + 1] = yiff + + ydf = _expand_fitted_demand(ydf, y) + yif = _expand_fitted_intervals(yif, y) out["fitted"] = ydf / yif return out @@ -4236,8 +4326,8 @@ def fit( self : CrostonOptimized fitted model. """ - self.model_ = _croston_optimized(y=y, h=1, fitted=True) - self.model_["sigma"] = _calculate_sigma(y - self.model_["fitted"], y.size) + self.model_ = _croston_optimized(y=y, h=1, fitted=False) + self._y = y self._store_cs(y=y, X=X) return self @@ -4287,9 +4377,11 @@ def predict_in_sample(self, level: Optional[List[int]] = None): forecasts : dict Dictionary with entries `fitted` for point predictions and `level_*` for probabilistic predictions. """ - res = {"fitted": self.model_["fitted"]} + fitted = _croston_optimized(y=self._y, h=1, fitted=True)["fitted"] + res = {"fitted": fitted} if level is not None: - res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + sigma = _calculate_sigma(self._y - fitted, self._y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) return res def forecast( @@ -4335,7 +4427,8 @@ def forecast( else: raise Exception("You must pass `prediction_intervals` to compute them.") if fitted: - res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + sigma = _calculate_sigma(y - res["fitted"], y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) return res # %% ../nbs/src/core/models.ipynb 326 @@ -4512,7 +4605,8 @@ def forecast( "to calculate them" ) if fitted: - res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + sigma = _calculate_sigma(y - res["fitted"], y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) return res # %% ../nbs/src/core/models.ipynb 338 @@ -4524,24 +4618,29 @@ def _imapa( if (y == 0).all(): res = {"mean": np.zeros(h, dtype=np.float32)} if fitted: - res["fitted"] = y.copy() + res["fitted"] = np.zeros(y.size, dtype=np.float32) + res["fitted"][0] = np.nan return res + y = _ensure_float(y) y_intervals = _intervals(y) mean_interval = y_intervals.mean().item() max_aggregation_level = round(mean_interval) forecasts = np.empty(max_aggregation_level, np.float32) - fitted_vals = np.empty((y.size, max_aggregation_level), dtype=np.float32) for aggregation_level in range(1, max_aggregation_level + 1): lost_remainder_data = len(y) % aggregation_level y_cut = y[lost_remainder_data:] aggregation_sums = _chunk_sums(y_cut, aggregation_level) - forecast, fit = _optimized_ses_forecast(aggregation_sums) + forecast, _ = _optimized_ses_forecast(aggregation_sums) forecasts[aggregation_level - 1] = forecast / aggregation_level - fitted_vals[:, aggregation_level - 1] = fit / aggregation_level forecast = forecasts.mean() res = {"mean": _repeat_val(val=forecast, h=h)} if fitted: - res["fitted"] = fitted_vals.mean(axis=1) + warnings.warn("Computing fitted values for IMAPA is very expensive.") + fitted_vals = np.empty_like(y) + fitted_vals[0] = np.nan + for i in range(y.size - 1): + fitted_vals[i + 1] = _imapa(y[: i + 1], h=1, fitted=False)["mean"].item() + res["fitted"] = fitted_vals return res # %% ../nbs/src/core/models.ipynb 339 @@ -4597,8 +4696,8 @@ def fit( self : IMAPA fitted model. """ - self.model_ = _imapa(y=y, h=1, fitted=True) - self.model_["sigma"] = _calculate_sigma(y - self.model_["fitted"], y.size) + self.model_ = _imapa(y=y, h=1, fitted=False) + self._y = y self._store_cs(y=y, X=X) return self @@ -4651,9 +4750,11 @@ def predict_in_sample(self, level: Optional[List[int]] = None): forecasts : dict Dictionary with entries `mean` for point predictions and `level_*` for probabilistic predictions. """ - res = {"fitted": self.model_["fitted"]} + fitted = _imapa(y=self._y, h=1, fitted=True)["fitted"] + res = {"fitted": fitted} if level is not None: - res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + sigma = _calculate_sigma(self._y - fitted, self._y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) return res def forecast( @@ -4703,7 +4804,8 @@ def forecast( "to calculate them" ) if fitted: - res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + sigma = _calculate_sigma(y - res["fitted"], y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) return res # %% ../nbs/src/core/models.ipynb 350 @@ -4717,14 +4819,17 @@ def _tsb( if (y == 0).all(): res = {"mean": np.zeros(h, dtype=np.float32)} if fitted: - res["fitted"] = y.copy() + res["fitted"] = np.zeros(y.size, dtype=np.float32) + res["fitted"][0] = np.nan return res + y = _ensure_float(y) yd = _demand(y) yp = _probability(y) ypf, ypft = _ses_forecast(yp, alpha_p) ydf, ydft = _ses_forecast(yd, alpha_d) res = {"mean": _repeat_val(val=ypf * ydf, h=h)} if fitted: + ydft = _expand_fitted_demand(np.append(ydft, ydf), y) res["fitted"] = ypft * ydft return res @@ -4739,21 +4844,21 @@ def __init__( ): """TSB model. - Teunter-Syntetos-Babai: A modification of Croston's method that replaces the inter-demand + Teunter-Syntetos-Babai: A modification of Croston's method that replaces the inter-demand intervals with the demand probability $d_t$, which is defined as follows. $$ d_t = \\begin{cases} - 1 & \\text{if demand occurs at time t} \\\ + 1 & \\text{if demand occurs at time t} \\\ 0 & \\text{otherwise.} \\end{cases} $$ - Hence, the forecast is given by + Hence, the forecast is given by $$\hat{y}_t= \hat{d}_t\hat{z_t}$$ - Both $d_t$ and $z_t$ are forecasted using SES. The smooting paramaters of each may differ, + Both $d_t$ and $z_t$ are forecasted using SES. The smooting paramaters of each may differ, like in the optimized Croston's method. References @@ -4763,11 +4868,11 @@ def __init__( Parameters ---------- alpha_d : float - Smoothing parameter for demand. + Smoothing parameter for demand. alpha_p : float - Smoothing parameter for probability. - alias : str - Custom name of the model. + Smoothing parameter for probability. + alias : str + Custom name of the model. prediction_intervals : Optional[ConformalIntervals] Information to compute conformal prediction intervals. By default, the model will compute the native prediction @@ -4900,7 +5005,8 @@ def forecast( else: raise Exception("You must pass `prediction_intervals` to compute them.") if fitted: - res = _add_fitted_pi(res=res, se=self.model_["sigma"], level=level) + sigma = _calculate_sigma(y - res["fitted"], y.size) + res = _add_fitted_pi(res=res, se=sigma, level=level) return res # %% ../nbs/src/core/models.ipynb 363 diff --git a/statsforecast/utils.py b/statsforecast/utils.py index 21de5854c..89a812ed5 100644 --- a/statsforecast/utils.py +++ b/statsforecast/utils.py @@ -282,6 +282,12 @@ def _naive( fcst["fitted"] = fitted_vals return fcst + +def _ensure_float(x: np.ndarray) -> np.ndarray: + if x.dtype not in (np.float32, np.float64): + x = x.astype(np.float32) + return x + # %% ../nbs/src/utils.ipynb 19 # Functions used for calculating prediction intervals def _quantiles(level):