Skip to content

Commit

Permalink
chore: Add user agent in the rest API call headers (#18)
Browse files Browse the repository at this point in the history
* chore: Add user agent in the rest API call headers

* chore: Add user agent in the rest API call headers

* chore: ugprade dbapi driver version, add smoke test for sql operator
  • Loading branch information
zongsizhang authored Sep 19, 2024
1 parent c6cd9ff commit 0226da0
Show file tree
Hide file tree
Showing 6 changed files with 212 additions and 159 deletions.
28 changes: 27 additions & 1 deletion airflow_providers_wherobots/hooks/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,13 @@
Hook for WhereRobots API
"""

import platform
from functools import cached_property
from typing import Any, Optional

import requests
from importlib import metadata
from airflow.version import version as airflow_version
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from requests import PreparedRequest, Response
Expand Down Expand Up @@ -61,6 +64,20 @@ def __exit__(self, exc_type, exc_val, exc_tb):
def conn(self) -> Connection:
return self.get_connection(self.wherobots_conn_id)

@cached_property
def user_agent_header(self):
try:
package_version = metadata.version("airflow-providers-wherobots")
except metadata.PackageNotFoundError:
package_version = "unknown"
python_version = platform.python_version()
system = platform.system().lower()
header_value = (
f"airflow-providers-wherobots/{package_version} os/{system}"
f" python/{python_version} airflow/{airflow_version}"
)
return {"User-Agent": header_value}

def _api_call(
self,
method: str,
Expand All @@ -71,7 +88,12 @@ def _api_call(
auth = WherobotsAuth(self.conn.password)
url = "https://" + self.conn.host.rstrip("/") + endpoint
resp = self.session.request(
url=url, method=method, json=payload, auth=auth, params=params
url=url,
method=method,
json=payload,
auth=auth,
params=params,
headers=self.user_agent_header,
)
try:
resp.raise_for_status()
Expand All @@ -98,3 +120,7 @@ def get_run_logs(self, run_id: str, start: int, size: int = 500) -> LogsResponse
params = {"cursor": start, "size": size}
resp_json = self._api_call("GET", f"/runs/{run_id}/logs", params=params).json()
return LogsResponse.model_validate(resp_json)


if __name__ == "__main__":
metadata.version("airflow-providers-wherobots")
Loading

0 comments on commit 0226da0

Please sign in to comment.