diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index c1cd28bc..f68001ad 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -13,8 +13,8 @@ concurrency: env: NIXTLA_API_KEY: ${{ secrets.NIXTLA_DEV_API_KEY }} NIXTLA_BASE_URL: ${{ secrets.NIXTLA_DEV_BASE_URL }} - NIXTLA_API_KEY_CUSTOM: ${{ secrets.NIXTLA_DEV_API_KEY }} - NIXTLA_BASE_URL_CUSTOM: ${{ secrets.NIXTLA_DEV_BASE_URL }} + NIXTLA_API_KEY_CUSTOM: ${{ secrets.NIXTLA_API_KEY_CUSTOM }} + NIXTLA_BASE_URL_CUSTOM: ${{ secrets.NIXTLA_BASE_URL_CUSTOM }} API_KEY_FRED: ${{ secrets.API_KEY_FRED }} jobs: @@ -82,4 +82,4 @@ jobs: run: pip install uv && uv pip install --system ".[dev]" - name: Run tests - run: nbdev_test --timing --do_print --n_workers 0 --skip_file_re "computing_at_scale|distributed" \ No newline at end of file + run: nbdev_test --timing --do_print --n_workers 0 --skip_file_re "computing_at_scale|distributed" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 4c7499d9..8b9d930e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -17,10 +17,11 @@ repos: rev: v0.2.1 hooks: - id: ruff + files: 'nixtla' - repo: https://github.com/pre-commit/mirrors-mypy rev: v1.10.1 hooks: - id: mypy args: [--ignore-missing-imports] - exclude: "setup.py" + files: 'nixtla' diff --git a/action_files/models_performance/main.py b/action_files/models_performance/main.py index 22a631a5..0d20d995 100644 --- a/action_files/models_performance/main.py +++ b/action_files/models_performance/main.py @@ -184,7 +184,7 @@ def evaluate_benchmark_performace(self) -> Tuple[pd.DataFrame, pd.DataFrame]: h=self.h, n_windows=self.n_windows, step_size=self.h, - ).reset_index() + ) total_time = time() - init_time cv_model_df = cv_model_df.rename( columns={value: key for key, value in renamer.items()} diff --git a/nbs/src/nixtla_client.ipynb b/nbs/src/nixtla_client.ipynb index c6b0219e..d8d10431 100644 --- a/nbs/src/nixtla_client.ipynb +++ b/nbs/src/nixtla_client.ipynb @@ -964,6 +964,16 @@ " or 'Forecasting! :)' in validation.get('detail', '')\n", " )\n", "\n", + " def usage(self) -> dict[str, dict[str, int]]:\n", + " if self._is_azure:\n", + " raise NotImplementedError('usage is not implemented for Azure deployments')\n", + " with httpx.Client(**self._client_kwargs) as client:\n", + " resp = client.get('/usage')\n", + " body = resp.json()\n", + " if resp.status_code != 200:\n", + " raise ApiError(status_code=resp.status_code, body=body)\n", + " return body\n", + "\n", " def _distributed_forecast(\n", " self,\n", " df: DistributedDFType,\n", @@ -2222,6 +2232,22 @@ "nixtla_client.validate_api_key()" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "#| hide\n", + "# usage endpoint\n", + "client2 = NixtlaClient(\n", + " base_url=os.environ['NIXTLA_BASE_URL_CUSTOM'],\n", + " api_key=os.environ['NIXTLA_API_KEY_CUSTOM'],\n", + ")\n", + "usage = client2.usage()\n", + "assert sorted(usage.keys()) == ['minute', 'month']" + ] + }, { "cell_type": "code", "execution_count": null, diff --git a/nixtla/_modidx.py b/nixtla/_modidx.py index a12fd69f..1481f44a 100644 --- a/nixtla/_modidx.py +++ b/nixtla/_modidx.py @@ -64,6 +64,8 @@ 'nixtla/nixtla_client.py'), 'nixtla.nixtla_client.NixtlaClient.plot': ( 'src/nixtla_client.html#nixtlaclient.plot', 'nixtla/nixtla_client.py'), + 'nixtla.nixtla_client.NixtlaClient.usage': ( 'src/nixtla_client.html#nixtlaclient.usage', + 'nixtla/nixtla_client.py'), 'nixtla.nixtla_client.NixtlaClient.validate_api_key': ( 'src/nixtla_client.html#nixtlaclient.validate_api_key', 'nixtla/nixtla_client.py'), 'nixtla.nixtla_client._array_tails': ( 'src/nixtla_client.html#_array_tails', diff --git a/nixtla/nixtla_client.py b/nixtla/nixtla_client.py index cbc1ccaa..7d403e05 100644 --- a/nixtla/nixtla_client.py +++ b/nixtla/nixtla_client.py @@ -893,6 +893,16 @@ def validate_api_key(self, log: bool = True) -> bool: "message", "" ) == "success" or "Forecasting! :)" in validation.get("detail", "") + def usage(self) -> dict[str, dict[str, int]]: + if self._is_azure: + raise NotImplementedError("usage is not implemented for Azure deployments") + with httpx.Client(**self._client_kwargs) as client: + resp = client.get("/usage") + body = resp.json() + if resp.status_code != 200: + raise ApiError(status_code=resp.status_code, body=body) + return body + def _distributed_forecast( self, df: DistributedDFType,