diff --git a/airflow_providers_wherobots/hooks/rest_api.py b/airflow_providers_wherobots/hooks/rest_api.py index b92b5c5..30ab6e4 100644 --- a/airflow_providers_wherobots/hooks/rest_api.py +++ b/airflow_providers_wherobots/hooks/rest_api.py @@ -15,7 +15,6 @@ from airflow_providers_wherobots.hooks.base import DEFAULT_CONN_ID from airflow_providers_wherobots.wherobots.models import ( Run, - CreateRunPayload, LogsResponse, ) @@ -81,11 +80,11 @@ def get_run(self, run_id: str) -> Run: resp_json = self._api_call("GET", f"/runs/{run_id}").json() return Run.model_validate(resp_json) - def create_run(self, payload: CreateRunPayload) -> Run: + def create_run(self, payload: dict[str, Any]) -> Run: resp_json = self._api_call( "POST", "/runs", - payload=payload.model_dump(mode="json"), + payload=payload, ).json() return Run.model_validate(resp_json) diff --git a/airflow_providers_wherobots/operators/run.py b/airflow_providers_wherobots/operators/run.py index 4fb2bc7..f817696 100644 --- a/airflow_providers_wherobots/operators/run.py +++ b/airflow_providers_wherobots/operators/run.py @@ -8,14 +8,10 @@ from airflow.models import BaseOperator from strenum import StrEnum -from wherobots.db import Runtime from airflow_providers_wherobots.hooks.base import DEFAULT_CONN_ID from airflow_providers_wherobots.hooks.rest_api import WherobotsRestAPIHook from airflow_providers_wherobots.wherobots.models import ( - PythonRunPayload, - JavaRunPayload, - CreateRunPayload, RUN_NAME_ALPHABET, RunStatus, Run, @@ -36,9 +32,10 @@ class WherobotsRunOperator(BaseOperator): def __init__( self, name: Optional[str] = None, - runtime: Optional[Runtime] = Runtime.SEDONA, - python: Optional[PythonRunPayload] = None, - java: Optional[JavaRunPayload] = None, + runtime: str = "TINY", + run_python: Optional[dict[str, Any]] = None, + run_jar: Optional[dict[str, Any]] = None, + environment: Optional[dict[str, Any]] = None, polling_interval: int = 20, wherobots_conn_id: str = DEFAULT_CONN_ID, poll_logs: bool = False, @@ -47,12 +44,16 @@ def __init__( ): super().__init__(**kwargs) # If the user specifies the name, we will use it and rely on the server to validate the name - self.run_payload = CreateRunPayload( - runtime=runtime, - name=name or self.default_run_name, - python=python, - java=java, - ) + self.run_payload: dict[str, Any] = { + "runtime": runtime, + "name": name or self.default_run_name, + } + if run_python: + self.run_payload["runPython"] = run_python + if run_jar: + self.run_payload["runJar"] = run_jar + if environment: + self.run_payload["environment"] = environment self._polling_interval = polling_interval self.wherobots_conn_id = wherobots_conn_id self.xcom_push = xcom_push diff --git a/airflow_providers_wherobots/wherobots/models.py b/airflow_providers_wherobots/wherobots/models.py index d16df0a..239ad7f 100644 --- a/airflow_providers_wherobots/wherobots/models.py +++ b/airflow_providers_wherobots/wherobots/models.py @@ -5,11 +5,10 @@ import string from datetime import datetime from enum import auto -from typing import Optional, Sequence, List +from typing import Optional, List -from pydantic import BaseModel, Field, ConfigDict, computed_field +from pydantic import BaseModel, Field, ConfigDict from strenum import StrEnum -from wherobots.db import Runtime RUN_NAME_ALPHABET = string.ascii_letters + string.digits + "-_." @@ -39,81 +38,6 @@ class Run(WherobotsModel): end_time: Optional[datetime] = Field(default=None, alias="completeTime") -class PythonRunPayload(BaseModel): - """ - Model for the payload of Run with type == "python" - """ - - # For airflow to render the template fields - template_fields: Sequence[str] = Field( - ("uri", "args", "entrypoint"), exclude=True, init=False - ) - - uri: str - args: list[str] = [] - entrypoint: Optional[str] = None - - @classmethod - def create(cls, uri: str, args: list[str], entrypoint: Optional[str] = None): - return cls(uri=uri, args=args, entrypoint=entrypoint) - - -class JavaRunPayload(BaseModel): - """ - Model for the payload of Run with type == "python" - """ - - # For airflow to render the template fields - template_fields: Sequence[str] = Field(("uri", "args", "main_class"), exclude=True) - - uri: str - args: list[str] = [] - main_class: Optional[str] = Field(None, alias="mainClass") - - @classmethod - def create(cls, uri: str, args: list[str], main_class: Optional[str] = None): - return cls(uri=uri, args=args, mainClass=main_class) - - -class RunType(StrEnum): - python = auto() - java = auto() - - -class CreateRunPayload(BaseModel): - # For airflow to render the template fields - template_fields: Sequence[str] = Field(("name", "python", "java"), exclude=True) - - runtime: Runtime - name: Optional[str] = None - python: Optional[PythonRunPayload] = None - java: Optional[JavaRunPayload] = None - timeout_seconds: int = Field(3600, alias="timeoutSeconds") - - @computed_field - def type(self) -> RunType: - run_type = RunType.python if self.python else RunType.java - assert isinstance(run_type, RunType) - return run_type - - @classmethod - def create( - cls, - runtime: Runtime, - name: str, - python: Optional[PythonRunPayload] = None, - java: Optional[JavaRunPayload] = None, - timeout_seconds: int = 3600, - ): - return cls( - runtime=runtime, - name=name, - python=python, - java=java, - timeoutSeconds=timeout_seconds, - ) - - class LogItem(BaseModel): timestamp: int raw: str diff --git a/pyproject.toml b/pyproject.toml index c4e8134..3125684 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "airflow-providers-wherobots" -version = "0.1.9" +version = "0.1.10" description = "Airflow extension for communicating with Wherobots Cloud" authors = ["zongsi.zhang "] readme = "README.md" diff --git a/tests/integration_tests/operator/test_run.py b/tests/integration_tests/operator/test_run.py index 6a53ccf..5200e2b 100644 --- a/tests/integration_tests/operator/test_run.py +++ b/tests/integration_tests/operator/test_run.py @@ -10,9 +10,6 @@ from airflow.models import Connection from airflow_providers_wherobots.operators.run import WherobotsRunOperator -from airflow_providers_wherobots.wherobots.models import ( - PythonRunPayload, -) from tests.unit_tests.operator.test_run import execute_dag DEFAULT_START = pendulum.datetime(2021, 9, 13, tz="UTC") @@ -28,9 +25,9 @@ def test_staging_run_success(staging_conn: Connection, dag: DAG) -> None: wherobots_conn_id=staging_conn.conn_id, task_id="test_run_smoke", name="airflow_operator_test_run_{{ ts_nodash }}", - python=PythonRunPayload( - uri="s3://wbts-wbc-rcv7vl73oy/hao9o6y8ci/data/customer-z4asgjn7clrcbz/very_simple_job.py" - ), + run_python={ + "uri": "s3://wbts-wbc-rcv7vl73oy/hao9o6y8ci/data/customer-z4asgjn7clrcbz/very_simple_job.py" + }, dag=dag, ) execute_dag(dag, task_id=operator.task_id) diff --git a/tests/unit_tests/hooks/test_rest_api.py b/tests/unit_tests/hooks/test_rest_api.py index 4d4639d..41ea530 100644 --- a/tests/unit_tests/hooks/test_rest_api.py +++ b/tests/unit_tests/hooks/test_rest_api.py @@ -18,8 +18,6 @@ ) from airflow_providers_wherobots.wherobots.models import ( Run, - CreateRunPayload, - PythonRunPayload, LogsResponse, ) from tests.unit_tests import helpers @@ -108,22 +106,20 @@ def test_create_run(self, test_default_conn) -> None: """ test_run: Run = helpers.run_factory.build() url = f"https://{test_default_conn.host}/runs" - create_payload = CreateRunPayload.create( - name=test_run.name, - runtime=Runtime.SEDONA, - python=PythonRunPayload( - uri="s3://bucket/test.py", - args=["arg1", "arg2"], - entrypoint="src.main", - ), - ) + create_payload = { + "name": test_run.name, + "runtime": Runtime.SEDONA.value, + "python": { + "uri": "s3://bucket/test.py", + "args": ["arg1", "arg2"], + "entrypoint": "src.main", + }, + } responses.add( responses.POST, url, json=test_run.model_dump(mode="json"), - match=[ - matchers.json_params_matcher(create_payload.model_dump(mode="json")) - ], + match=[matchers.json_params_matcher(create_payload)], status=HTTPStatus.OK, ) with WherobotsRestAPIHook() as hook: diff --git a/tests/unit_tests/operator/test_run.py b/tests/unit_tests/operator/test_run.py index 27449b8..1c4af8f 100644 --- a/tests/unit_tests/operator/test_run.py +++ b/tests/unit_tests/operator/test_run.py @@ -18,9 +18,7 @@ from airflow_providers_wherobots.operators.run import WherobotsRunOperator from airflow_providers_wherobots.wherobots.models import ( - PythonRunPayload, RunStatus, - CreateRunPayload, LogsResponse, Run, LogItem, @@ -74,22 +72,22 @@ def test_render_template(self, mocker: MockerFixture, dag: DAG): operator = WherobotsRunOperator( task_id="test_render_template_python", name="test_run_{{ ds }}", - python=PythonRunPayload( - uri="s3://bucket/test-{{ ds }}.py", - args=["{{ ds }}"], - entrypoint="src.main_{{ ds }}", - ), + run_python={ + "uri": "s3://bucket/test-{{ ds }}.py", + "args": ["{{ ds }}"], + }, dag=dag, ) execute_dag(dag, task_id=operator.task_id) assert create_run.call_count == 1 rendered_payload = create_run.call_args.args[0] - assert isinstance(rendered_payload, CreateRunPayload) + assert isinstance(rendered_payload, dict) expected_ds = data_interval_start.format("YYYY-MM-DD") - assert rendered_payload.name == f"test_run_{expected_ds}" - assert rendered_payload.python.uri == f"s3://bucket/test-{expected_ds}.py" - assert rendered_payload.python.args == [expected_ds] - assert rendered_payload.python.entrypoint == f"src.main_{expected_ds}" + assert rendered_payload["name"] == f"test_run_{expected_ds}" + assert ( + rendered_payload["runPython"]["uri"] == f"s3://bucket/test-{expected_ds}.py" + ) + assert rendered_payload["runPython"]["args"] == [expected_ds] @pytest.mark.usefixtures("clean_airflow_db") def test_default_name(self, mocker: MockerFixture, dag: DAG): @@ -100,13 +98,13 @@ def test_default_name(self, mocker: MockerFixture, dag: DAG): ) operator = WherobotsRunOperator( task_id="test_default_name", - python=PythonRunPayload(uri=""), + run_python={"uri": ""}, dag=dag, ) execute_dag(dag, task_id=operator.task_id) rendered_payload = create_run.call_args.args[0] - assert isinstance(rendered_payload, CreateRunPayload) - assert rendered_payload.name == operator.default_run_name.replace( + assert isinstance(rendered_payload, dict) + assert rendered_payload["name"] == operator.default_run_name.replace( "{{ ts_nodash }}", data_interval_start.strftime("%Y%m%dT%H%M%S") ) @@ -160,7 +158,7 @@ def test_execute_handle_states( ) operator = WherobotsRunOperator( task_id=f"test_execute_{uuid.uuid4()}", - python=PythonRunPayload(uri=""), + run_python={"uri": ""}, dag=dag, polling_interval=0, poll_logs=poll_logs, @@ -186,11 +184,11 @@ def test_on_kill( operator = WherobotsRunOperator( task_id="test_render_template_python", name="test_run_{{ ds }}", - python=PythonRunPayload( - uri="s3://bucket/test-{{ ds }}.py", - args=["{{ ds }}"], - entrypoint="src.main_{{ ds }}", - ), + run_python={ + "uri": "s3://bucket/test-{{ ds }}.py", + "args": ["{{ ds }}"], + "entrypoint": "src.main_{{ ds }}", + }, dag=dag, ) operator.on_kill() @@ -209,7 +207,7 @@ def test_poll_and_display_logs(self, mocker: MockerFixture): ) operator = WherobotsRunOperator( task_id="test_poll_and_display_logs", - python=PythonRunPayload(uri=""), + run_python={"uri": ""}, dag=DAG("test_poll_and_display_logs"), ) assert operator.poll_and_display_logs(hook, test_run, 0) == 2 diff --git a/tests/unit_tests/wherobots/test_models.py b/tests/unit_tests/wherobots/test_models.py deleted file mode 100644 index 4ce6c6f..0000000 --- a/tests/unit_tests/wherobots/test_models.py +++ /dev/null @@ -1,31 +0,0 @@ -""" -Test module wherobots.models -""" - -from wherobots.db import Runtime - -from airflow_providers_wherobots.wherobots.models import ( - CreateRunPayload, - PythonRunPayload, - JavaRunPayload, -) - - -def test_create_run_payload(): - payload = CreateRunPayload.create( - runtime=Runtime.SEDONA, - name="test", - python=PythonRunPayload.create( - uri="s3://path/to/python", - args=["--arg1", "value1"], - entrypoint="main", - ), - java=JavaRunPayload.create( - uri="s3://path/to/java", - args=["--arg1", "value1"], - main_class="main", - ), - timeout_seconds=1200, - ) - assert payload.runtime == Runtime.SEDONA - assert payload.timeout_seconds == 1200