Skip to content

Commit

Permalink
feat: Add option to poll logs while running (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
zongsizhang authored Sep 5, 2024
1 parent 30805e9 commit 638aa71
Show file tree
Hide file tree
Showing 25 changed files with 386 additions and 135 deletions.
Binary file added .DS_Store
Binary file not shown.
File renamed without changes.
2 changes: 0 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,9 @@ repos:
- id: ruff
stages: [ pre-commit ]
args: [ --fix ]
files: ^airflow_providers_wherobots/
# Run the formatter.
- id: ruff-format
stages: [ pre-commit ]
files: ^airflow_providers_wherobots/

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.7.0
Expand Down
26 changes: 19 additions & 7 deletions airflow_providers_wherobots/hooks/rest_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,16 @@
import requests
from airflow.hooks.base import BaseHook
from airflow.models import Connection
from requests import PreparedRequest
from requests import PreparedRequest, Response
from requests.adapters import HTTPAdapter, Retry
from requests.auth import AuthBase

from airflow_providers_wherobots.hooks.base import DEFAULT_CONN_ID
from airflow_providers_wherobots.wherobots.models import Run, CreateRunPayload
from airflow_providers_wherobots.wherobots.models import (
Run,
CreateRunPayload,
LogsResponse,
)


class WherobotsAuth(AuthBase):
Expand Down Expand Up @@ -63,24 +67,32 @@ def _api_call(
method: str,
endpoint: str,
payload: Optional[dict[str, Any]] = None,
) -> Any:
params: Optional[dict[str, Any]] = None,
) -> Response:
auth = WherobotsAuth(self.conn.password)
url = "https://" + self.conn.host.rstrip("/") + endpoint
resp = self.session.request(url=url, method=method, json=payload, auth=auth)
resp = self.session.request(
url=url, method=method, json=payload, auth=auth, params=params
)
resp.raise_for_status()
return resp.json()
return resp

def get_run(self, run_id: str) -> Run:
resp_json = self._api_call("GET", f"/runs/{run_id}")
resp_json = self._api_call("GET", f"/runs/{run_id}").json()
return Run.model_validate(resp_json)

def create_run(self, payload: CreateRunPayload) -> Run:
resp_json = self._api_call(
"POST",
"/runs",
payload=payload.model_dump(mode="json"),
)
).json()
return Run.model_validate(resp_json)

def cancel_run(self, run_id: str) -> None:
    """
    Request cancellation of a run.

    Issues a POST to ``/runs/{run_id}/cancel``; the response is discarded.
    """
    cancel_endpoint = f"/runs/{run_id}/cancel"
    self._api_call("POST", cancel_endpoint)

def get_run_logs(self, run_id: str, start: int, size: int = 500) -> LogsResponse:
    """
    Fetch one page of logs for a run.

    Sends a GET to ``/runs/{run_id}/logs`` with ``start`` as the ``cursor``
    query parameter and ``size`` as the ``size`` query parameter, then
    validates the JSON body into a ``LogsResponse``.
    """
    response = self._api_call(
        "GET",
        f"/runs/{run_id}/logs",
        params={"cursor": start, "size": size},
    )
    return LogsResponse.model_validate(response.json())
86 changes: 78 additions & 8 deletions airflow_providers_wherobots/operators/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
JavaRunPayload,
CreateRunPayload,
RUN_NAME_ALPHABET,
RunStatus,
Run,
)


Expand All @@ -36,8 +38,9 @@ def __init__(
runtime: Optional[Runtime] = Runtime.SEDONA,
python: Optional[PythonRunPayload] = None,
java: Optional[JavaRunPayload] = None,
polling_interval: int = 30,
polling_interval: int = 20,
wherobots_conn_id: str = DEFAULT_CONN_ID,
poll_logs: bool = False,
xcom_push: bool = True,
**kwargs,
):
Expand All @@ -49,10 +52,12 @@ def __init__(
python=python,
java=java,
)
self.polling_interval = polling_interval
self._polling_interval = polling_interval
self.wherobots_conn_id = wherobots_conn_id
self.xcom_push = xcom_push
self.run_id: Optional[str] = None
self.poll_logs = poll_logs
self._logs_available = False

@property
def default_run_name(self) -> str:
Expand All @@ -67,6 +72,67 @@ def default_run_name(self) -> str:
+ ".{{ ts_nodash }}"
)

def poll_and_display_logs(
    self, hook: WherobotsRestAPIHook, run: Run, start: int
) -> int:
    """
    Fetch one page of run logs starting at ``start`` and write it to the task log.

    :param hook: REST API hook used to request the logs
    :param run: the run whose logs are requested; its ``ext_id`` and ``status`` are read
    :param start: cursor passed to the logs endpoint
    :return: the cursor to use for the next poll: ``current_page`` when the page
        is empty, otherwise ``next_page`` when the response carries one, otherwise
        the timestamp of the last item in the page
    """
    log_resp = hook.get_run_logs(run.ext_id, start)
    items = log_resp.items
    if not self._logs_available:
        if items:
            # First non-empty page: print the start banner exactly once.
            self._logs_available = True
            self.log.info("=== Logs for Run %s Start:", run.ext_id)
        else:
            self.log.info(
                "Run %s status: %s, logs will start to stream once available",
                run.ext_id,
                run.status.value,
            )
    if not items:
        return log_resp.current_page
    # An item whose timestamp equals the cursor is not printed again.
    # Slice into a new list rather than rewriting the response model in place.
    fresh_items = items[1:] if items[0].timestamp == start else items
    for log_item in fresh_items:
        self.log.info("Log: %s", log_item.raw)
    # Compare against None explicitly so that a next_page of 0 is still honored.
    if log_resp.next_page is not None:
        return log_resp.next_page
    return items[-1].timestamp

def _log_run_status(self, run: Run):
self.log.info(f"Run {run.ext_id} status: {run.status}")

def _wait_run_poll_logs(self, hook: WherobotsRestAPIHook, run: Run) -> Run:
    """
    Wait until the run is neither PENDING nor RUNNING, printing its logs meanwhile.

    While the run is PENDING only its status is refreshed and logged. While it
    is RUNNING, each iteration prints one page of logs and then refreshes the
    run. Afterwards the remaining logs are drained until the cursor stops
    advancing.

    :param hook: REST API hook used to refresh the run and fetch its logs
    :param run: the run to wait for
    :return: the most recently fetched state of the run
    """
    # Pause between drain requests once the run has stopped running.
    drain_interval_seconds = 3
    logs_cursor: int = 0
    while run.status == RunStatus.PENDING:
        sleep(self._polling_interval)
        run = hook.get_run(run.ext_id)
        self._log_run_status(run)
    while run.status == RunStatus.RUNNING:
        # Pull the run logs
        logs_cursor = self.poll_and_display_logs(hook, run, logs_cursor)
        sleep(self._polling_interval)
        run = hook.get_run(run.ext_id)
    # The run has left RUNNING but more log pages may remain. Keep pulling
    # until a poll returns the cursor it was given, meaning nothing new arrived.
    while True:
        # Sleep to avoid too frequent polling
        sleep(drain_interval_seconds)
        next_cursor = self.poll_and_display_logs(hook, run, logs_cursor)
        if next_cursor == logs_cursor:
            break
        logs_cursor = next_cursor
    self.log.info("=== Logs for Run %s End", run.ext_id)
    return run

def _wait_run_simple(self, hook: WherobotsRestAPIHook, run: Run) -> Run:
while run.status.is_active():
sleep(self._polling_interval)
run = hook.get_run(run.ext_id)
self._log_run_status(run)
return run

def execute(self, context) -> Any:
"""
Trigger the Wherobots Run and keep polling for status until the Run ends
Expand All @@ -76,15 +142,19 @@ def execute(self, context) -> Any:
run = rest_api_hook.create_run(self.run_payload)
if self.xcom_push and context:
context["ti"].xcom_push(key=XComKey.run_id, value=run.ext_id)

self.run_id = run.ext_id
self.log.info(f"Run {run.ext_id} created")
while run.status.is_active():
sleep(self.polling_interval)
run = rest_api_hook.get_run(run.ext_id)
self.log.info(f"Run {run.ext_id} status: {run.status}")
# wait for the run to end
if self.poll_logs:
run = self._wait_run_poll_logs(rest_api_hook, run)
else:
run = self._wait_run_simple(rest_api_hook, run)
# loop end, means run is in terminal state
self.log.info(f"Run {run.ext_id} is {run.status}")
self._log_run_status(run)
if run.status in (RunStatus.FAILED, RunStatus.CANCELLED):
raise RuntimeError(
f"Run {run.ext_id} failed or cancelled by another party"
)

def on_kill(self) -> None:
"""
Expand Down
13 changes: 12 additions & 1 deletion airflow_providers_wherobots/wherobots/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import string
from datetime import datetime
from enum import auto
from typing import Optional, Sequence
from typing import Optional, Sequence, List

from pydantic import BaseModel, Field, ConfigDict, computed_field
from strenum import StrEnum
Expand Down Expand Up @@ -112,3 +112,14 @@ def create(
java=java,
timeoutSeconds=timeout_seconds,
)


class LogItem(BaseModel):
    """
    A single log entry of a run, as carried in ``LogsResponse.items``.
    """

    # The run operator compares this against its polling cursor so that an
    # entry it has already printed is not printed again.
    timestamp: int
    # Text of the entry; the run operator writes it to the task log with a
    # "Log: " prefix.
    raw: str


class LogsResponse(BaseModel):
    """
    One page of log entries, validated from the JSON body of the run logs endpoint.
    """

    # Entries contained in this page; callers handle the empty case.
    items: List[LogItem]
    # Returned by the run operator as the next cursor when ``items`` is empty.
    current_page: int
    # Used by the run operator as the next cursor when set; otherwise it
    # falls back to the timestamp of the last item.
    next_page: Optional[int] = None
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "airflow-providers-wherobots"
version = "0.1.7"
version = "0.1.8"
description = "Airflow extension for communicating with Wherobots Cloud"
authors = ["zongsi.zhang <[email protected]>"]
readme = "README.md"
Expand Down
4 changes: 0 additions & 4 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os

import pytest
from airflow import DAG
from airflow.models import Connection
from pytest_mock import MockerFixture

Expand All @@ -24,6 +23,3 @@ def test_default_conn(mocker: MockerFixture):
@pytest.fixture(scope="function")
def clean_airflow_db():
os.system("airflow db reset --yes")



File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import os

import pytest
from airflow import DAG
from airflow.models import Connection
from pytest_mock import MockerFixture

Expand All @@ -16,6 +15,13 @@ def staging_conn(mocker: MockerFixture):
staging_api_token = os.getenv("STAGING_API_TOKEN")
if not staging_api_token:
raise ValueError("STAGING_API_TOKEN is not set")
staging_conn = Connection(conn_id="wherobots_staging_conn", conn_type="http", host=staging_host, password=staging_api_token)
mocker.patch.dict("os.environ", AIRFLOW_CONN_WHEROBOTS_STAGING_CONN=staging_conn.get_uri())
staging_conn = Connection(
conn_id="wherobots_staging_conn",
conn_type="http",
host=staging_host,
password=staging_api_token,
)
mocker.patch.dict(
"os.environ", AIRFLOW_CONN_WHEROBOTS_STAGING_CONN=staging_conn.get_uri()
)
return staging_conn
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from airflow_providers_wherobots.wherobots.models import (
PythonRunPayload,
)
from tests.unit.operator.test_run import dag, execute_dag
from tests.unit_tests.operator.test_run import execute_dag

DEFAULT_START = pendulum.datetime(2021, 9, 13, tz="UTC")
DEFAULT_END = DEFAULT_START + datetime.timedelta(days=1)
Expand All @@ -33,4 +33,4 @@ def test_staging_run_success(staging_conn: Connection, dag: DAG) -> None:
),
dag=dag,
)
execute_dag(dag, task_id=operator.task_id)
execute_dag(dag, task_id=operator.task_id)
99 changes: 0 additions & 99 deletions tests/unit/operator/test_run.py

This file was deleted.

File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@
"""

from airflow_providers_wherobots.wherobots.models import Run
from tests.unit.helpers.model_factories import RunFactory
from tests.unit_tests.helpers.model_factories import RunFactory

run_factory = RunFactory.create_factory(Run)
File renamed without changes.
File renamed without changes.
Loading

0 comments on commit 638aa71

Please sign in to comment.