Commit

Merge branch 'main' into edit-variable

shubhamraj-git authored Dec 27, 2024
2 parents d1710a6 + 61412b3 commit da83640
Showing 21 changed files with 779 additions and 188 deletions.
36 changes: 7 additions & 29 deletions airflow/sensors/base.py
@@ -46,11 +9,9 @@
from airflow.models.taskreschedule import TaskReschedule
from airflow.ti_deps.deps.ready_to_reschedule import ReadyToRescheduleDep
from airflow.utils import timezone
from airflow.utils.session import NEW_SESSION, create_session, provide_session
from airflow.utils.session import create_session

if TYPE_CHECKING:
from sqlalchemy.orm.session import Session

from airflow.typing_compat import Self
from airflow.utils.context import Context

@@ -84,30 +82,6 @@ def __bool__(self) -> bool:
        return self.is_done


@provide_session
def _orig_start_date(
    dag_id: str, task_id: str, run_id: str, map_index: int, try_number: int, session: Session = NEW_SESSION
):
    """
    Get the original start_date for a rescheduled task.
    :meta private:
    """
    return session.scalar(
        select(TaskReschedule)
        .where(
            TaskReschedule.dag_id == dag_id,
            TaskReschedule.task_id == task_id,
            TaskReschedule.run_id == run_id,
            TaskReschedule.map_index == map_index,
            TaskReschedule.try_number == try_number,
        )
        .order_by(TaskReschedule.id.asc())
        .with_only_columns(TaskReschedule.start_date)
        .limit(1)
    )


class BaseSensorOperator(BaseOperator, SkipMixin):
    """
    Sensor operators are derived from this class and inherit these attributes.
@@ -246,8 +220,12 @@ def execute(self, context: Context) -> Any:
            ti = context["ti"]
            max_tries: int = ti.max_tries or 0
            retries: int = self.retries or 0

            # If reschedule, use the start date of the first try (first try can be either the very
            # first execution of the task, or the first execution after the task was cleared.)
            # first execution of the task, or the first execution after the task was cleared).
            # If the first try's record was not saved because an exception occurred and the
            # transaction was rolled back, use the next available attempt instead
            # to prevent endless rescheduling.
            first_try_number = max_tries - retries + 1
            with create_session() as session:
                start_date = session.scalar(
@@ -257,7 +235,7 @@
                        TaskReschedule.task_id == ti.task_id,
                        TaskReschedule.run_id == ti.run_id,
                        TaskReschedule.map_index == ti.map_index,
                        TaskReschedule.try_number == first_try_number,
                        TaskReschedule.try_number >= first_try_number,
                    )
                    .order_by(TaskReschedule.id.asc())
                    .with_only_columns(TaskReschedule.start_date)
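The effect of changing == to >= above is easiest to see with a worked example. The sketch below is illustrative only: the numbers are hypothetical and a plain list of dicts stands in for the TaskReschedule query.

# Hypothetical scenario: a sensor with retries=2 that has been cleared once, so max_tries grew to 6.
max_tries, retries = 6, 2
first_try_number = max_tries - retries + 1  # -> 5

# Assume the reschedule row for try 5 was lost because its transaction rolled back,
# and only the row for try 6 was recorded (other columns omitted for brevity).
reschedules = [{"id": 9, "try_number": 6, "start_date": "2024-12-27T10:05:00+00:00"}]

# Old behaviour: an exact match on first_try_number finds nothing, so no start_date is
# available, the timeout can never trigger, and the sensor reschedules forever.
exact = [r for r in reschedules if r["try_number"] == first_try_number]

# New behaviour: ">=" falls back to the earliest recorded attempt at or after the first try.
fallback = [r for r in reschedules if r["try_number"] >= first_try_number]
start_date = min(fallback, key=lambda r: r["id"])["start_date"] if fallback else None

print(exact)       # []
print(start_date)  # 2024-12-27T10:05:00+00:00
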
7 changes: 7 additions & 0 deletions contributing-docs/testing/unit_tests.rst
@@ -359,6 +359,13 @@ For selected test types (example - the tests will run for Providers/API/CLI code
breeze testing providers-tests --skip-db-tests --parallel-test-types "Providers[google] Providers[amazon]"
You can also enter an interactive shell with the ``--skip-db-tests`` flag and run the tests iteratively

.. code-block:: bash

    breeze shell --skip-db-tests
    > pytest tests/your_test.py

How to make your test not depend on DB
......................................
8 changes: 8 additions & 0 deletions providers/tests/microsoft/azure/fs/test_adls.py
@@ -17,13 +17,21 @@

from __future__ import annotations

import os
from unittest import mock

import pytest

from airflow.models import Connection
from airflow.providers.microsoft.azure.fs.adls import get_fs

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
    # Allow pytest to collect this module when DB tests are skipped
    Connection = mock.MagicMock()  # type: ignore[misc]


pytestmark = pytest.mark.db_test


@pytest.fixture
def mocked_blob_file_system():
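The same module-level guard is repeated in the remaining Azure test modules below. As a rough standalone sketch (the fixture and test here are hypothetical), the idea is that pytest must still be able to import the module when _AIRFLOW_SKIP_DB_TESTS is set, while the db_test mark keeps the tests themselves out of non-DB runs:

import os
from unittest import mock

import pytest

from airflow.models import Connection

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
    # Substituting a MagicMock keeps module-level uses of Connection importable
    # during collection; the db_test mark below deselects the actual tests.
    Connection = mock.MagicMock()  # type: ignore[misc]

pytestmark = pytest.mark.db_test


@pytest.fixture
def sample_connection():
    # Hypothetical fixture shaped like the ones used in these Azure hook tests.
    return Connection(conn_id="sample_default", conn_type="wasb", login="client_id")


def test_connection_login(sample_connection):
    assert sample_connection.login == "client_id"
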
5 changes: 5 additions & 0 deletions providers/tests/microsoft/azure/hooks/test_adx.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import os
from unittest import mock

import pytest
@@ -31,6 +32,10 @@

pytestmark = pytest.mark.db_test

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
    # Allow pytest to collect this module when DB tests are skipped
    Connection = mock.MagicMock()  # type: ignore[misc]


class TestAzureDataExplorerHook:
@pytest.mark.parametrize(
7 changes: 6 additions & 1 deletion providers/tests/microsoft/azure/hooks/test_base_azure.py
@@ -16,7 +16,8 @@
# under the License.
from __future__ import annotations

from unittest.mock import Mock, patch
import os
from unittest.mock import MagicMock, Mock, patch

import pytest

@@ -25,6 +26,10 @@

pytestmark = pytest.mark.db_test

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
    # Allow pytest to collect this module when DB tests are skipped
    Connection = MagicMock()  # type: ignore[misc]

MODULE = "airflow.providers.microsoft.azure.hooks.base_azure"


providers/tests/microsoft/azure/hooks/test_container_registry.py
@@ -17,13 +17,20 @@
# under the License.
from __future__ import annotations

import os
from unittest import mock

import pytest

from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.container_registry import AzureContainerRegistryHook

pytestmark = pytest.mark.db_test

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
    # Allow pytest to collect this module when DB tests are skipped
    Connection = mock.MagicMock()  # type: ignore[misc]


class TestAzureContainerRegistryHook:
@pytest.mark.parametrize(
providers/tests/microsoft/azure/hooks/test_container_volume.py
@@ -17,13 +17,20 @@
# under the License.
from __future__ import annotations

import os
from unittest import mock

import pytest

from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.container_volume import AzureContainerVolumeHook

pytestmark = pytest.mark.db_test

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
    # Allow pytest to collect this module when DB tests are skipped
    Connection = mock.MagicMock()  # type: ignore[misc]


class TestAzureContainerVolumeHook:
@pytest.mark.parametrize(
7 changes: 7 additions & 0 deletions providers/tests/microsoft/azure/hooks/test_cosmos.py
@@ -18,6 +18,7 @@
from __future__ import annotations

import logging
import os
import uuid
from unittest import mock
from unittest.mock import PropertyMock
@@ -30,8 +31,14 @@
from airflow.models import Connection
from airflow.providers.microsoft.azure.hooks.cosmos import AzureCosmosDBHook

pytestmark = pytest.mark.db_test

MODULE = "airflow.providers.microsoft.azure.hooks.cosmos"

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
    # Allow pytest to collect this module when DB tests are skipped
    Connection = mock.MagicMock()  # type: ignore[misc]


class TestAzureCosmosDbHook:
# Set up an environment to test with
6 changes: 6 additions & 0 deletions providers/tests/microsoft/azure/hooks/test_data_factory.py
@@ -59,6 +59,12 @@
# TODO: FIXME: the tests here have tricky issues with typing and need a bit more thought to fix them
# mypy: disable-error-code="union-attr,call-overload"

pytestmark = pytest.mark.db_test

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
    # Allow pytest to collect this module when DB tests are skipped
    Connection = mock.MagicMock()  # type: ignore[misc]


@pytest.fixture(autouse=True)
def setup_connections(create_mock_connections):
4 changes: 4 additions & 0 deletions providers/tests/microsoft/azure/hooks/test_wasb.py
@@ -17,6 +17,7 @@
# under the License.
from __future__ import annotations

import os
import re
from unittest import mock

@@ -31,6 +32,9 @@

pytestmark = pytest.mark.db_test

if os.environ.get("_AIRFLOW_SKIP_DB_TESTS") == "true":
    # Allow pytest to collect this module when DB tests are skipped
    Connection = mock.MagicMock()  # type: ignore[misc]

# connection_string has a format
CONN_STRING = (
1 change: 1 addition & 0 deletions pyproject.toml
@@ -372,6 +372,7 @@ testing = ["dev", "providers.tests", "task_sdk.tests", "tests_common", "tests"]

# Test compat imports banned imports to allow testing against older airflow versions
"tests_common/test_utils/compat.py" = ["TID251", "F401"]
"tests_common/pytest_plugin.py" = ["F811"]

[tool.ruff.lint.flake8-tidy-imports]
# Disallow all relative imports.
1 change: 1 addition & 0 deletions task_sdk/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
"msgspec>=0.18.6",
"psutil>=6.1.0",
"structlog>=24.4.0",
"retryhttp>=1.2.0",
]
classifiers = [
"Framework :: Apache Airflow",
28 changes: 28 additions & 0 deletions task_sdk/src/airflow/sdk/api/client.py
@@ -17,6 +17,8 @@

from __future__ import annotations

import logging
import os
import sys
import uuid
from http import HTTPStatus
@@ -26,6 +28,8 @@
import msgspec
import structlog
from pydantic import BaseModel
from retryhttp import retry, wait_retry_after
from tenacity import before_log, wait_random_exponential
from uuid6 import uuid7

from airflow.sdk import __version__
@@ -268,6 +272,15 @@ def noop_handler(request: httpx.Request) -> httpx.Response:
    return httpx.Response(200, json={"text": "Hello, world!"})


# Config options for how the SDK should handle retries on HTTP requests.
# Note: the given defaults make attempts after 1, 3, 7, 15, 31 seconds, 1:03, 2:07 and 3:37,
# and fail after 5:07 min in total.
# So far there is no other config facility in the SDK, so we use env variables for the moment.
# TODO: Consider these env variables while handling airflow confs in task sdk
API_RETRIES = int(os.getenv("AIRFLOW__WORKERS__API_RETRIES", 10))
API_RETRY_WAIT_MIN = float(os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MIN", 1.0))
API_RETRY_WAIT_MAX = float(os.getenv("AIRFLOW__WORKERS__API_RETRY_WAIT_MAX", 90.0))


class Client(httpx.Client):
    def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, **kwargs: Any):
        if (not base_url) ^ dry_run:
@@ -289,6 +302,21 @@ def __init__(self, *, base_url: str | None, dry_run: bool = False, token: str, *
            **kwargs,
        )

    _default_wait = wait_random_exponential(min=API_RETRY_WAIT_MIN, max=API_RETRY_WAIT_MAX)

    @retry(
        reraise=True,
        max_attempt_number=API_RETRIES,
        wait_server_errors=_default_wait,
        wait_network_errors=_default_wait,
        wait_timeouts=_default_wait,
        wait_rate_limited=wait_retry_after(fallback=_default_wait),  # No infinite timeout on HTTP 429
        before_sleep=before_log(log, logging.WARNING),
    )
    def request(self, *args, **kwargs):
        """Implement a convenience for httpx.Client.request with a retry layer."""
        return super().request(*args, **kwargs)

    # We "group" or "namespace" operations by what they operate on, rather than a flat namespace with all
    # methods on one object prefixed with the object type (`.task_instances.update` rather than
    # `task_instance_update` etc.)
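A short usage sketch of the retry layer added above. The base URL, token, and endpoint are placeholders, and since the module reads the AIRFLOW__WORKERS__* variables at import time, any overrides have to be set before airflow.sdk.api.client is first imported:

import os

# Hypothetical tuning, read once when the client module is imported.
os.environ["AIRFLOW__WORKERS__API_RETRIES"] = "5"
os.environ["AIRFLOW__WORKERS__API_RETRY_WAIT_MIN"] = "0.5"
os.environ["AIRFLOW__WORKERS__API_RETRY_WAIT_MAX"] = "10"

from airflow.sdk.api.client import Client

client = Client(base_url="http://localhost:9091/execution", token="dummy-token")

# Every call funnels through the overridden request(): network errors, timeouts,
# 5xx responses and HTTP 429 (honouring Retry-After) are retried with random
# exponential backoff, and the last exception is re-raised after API_RETRIES attempts.
response = client.get("/health")  # hypothetical endpoint, shown only to illustrate the call path
print(response.status_code)
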
51 changes: 46 additions & 5 deletions task_sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -21,7 +21,7 @@

import os
import sys
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from datetime import datetime, timezone
from io import FileIO
from typing import TYPE_CHECKING, Annotated, Any, Generic, TextIO, TypeVar
@@ -195,9 +195,12 @@ def xcom_pull(
        if TYPE_CHECKING:
            assert isinstance(msg, XComResult)

        value = msg.value
        if value is not None:
            return value
        if msg.value is not None:
            from airflow.models.xcom import XCom

            # TODO: Move XCom serialization & deserialization to Task SDK
            # https://github.com/apache/airflow/issues/45231
            return XCom.deserialize_value(msg)  # type: ignore[arg-type]
        return default

def xcom_push(self, key: str, value: Any):
@@ -207,6 +210,12 @@ def xcom_push(self, key: str, value: Any):
        :param key: Key to store the value under.
        :param value: Value to store. Only JSON-serializable values may be used.
        """
        from airflow.models.xcom import XCom

        # TODO: Move XCom serialization & deserialization to Task SDK
        # https://github.com/apache/airflow/issues/45231
        value = XCom.serialize_value(value)

        log = structlog.get_logger(logger_name="task")
        SUPERVISOR_COMMS.send_request(
            log=log,
@@ -381,7 +390,9 @@ def run(ti: RuntimeTaskInstance, log: Logger):
        # - Update RTIF
        # - Pre Execute
        # etc
        ti.task.execute(context)  # type: ignore[attr-defined]
        result = ti.task.execute(context)  # type: ignore[attr-defined]
        _push_xcom_if_needed(result, ti)

        msg = TaskState(state=TerminalTIState.SUCCESS, end_date=datetime.now(tz=timezone.utc))
    except TaskDeferred as defer:
        classpath, trigger_kwargs = defer.trigger.serialize()
@@ -436,6 +447,36 @@ def run(ti: RuntimeTaskInstance, log: Logger):
SUPERVISOR_COMMS.send_request(msg=msg, log=log)


def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance):
    """Push XCom values when the task has ``do_xcom_push`` set to ``True`` and returns a result."""
    if ti.task.do_xcom_push:
        xcom_value = result
    else:
        xcom_value = None

    # If the task returns a result, push an XCom containing it.
    if xcom_value is None:
        return

    # If the task has multiple outputs, push each output as a separate XCom.
    if ti.task.multiple_outputs:
        if not isinstance(xcom_value, Mapping):
            raise TypeError(
                f"Returned output was type {type(xcom_value)} expected dictionary for multiple_outputs"
            )
        for key in xcom_value.keys():
            if not isinstance(key, str):
                raise TypeError(
                    "Returned dictionary keys must be strings when using "
                    f"multiple_outputs, found {key} ({type(key)}) instead"
                )
        for k, v in result.items():
            ti.xcom_push(k, v)

    # TODO: Use constant for XCom return key & use serialize_value from Task SDK
    ti.xcom_push("return_value", result)


def finalize(log: Logger): ...


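The branching in _push_xcom_if_needed above can be summarised with a small helper; push_plan is hypothetical and only returns the (key, value) pairs the real function would send through ti.xcom_push:

from collections.abc import Mapping
from typing import Any


def push_plan(result: Any, do_xcom_push: bool, multiple_outputs: bool) -> list[tuple[str, Any]]:
    """Mirror the _push_xcom_if_needed branching and return the planned (key, value) pushes."""
    if not do_xcom_push or result is None:
        return []
    pushes: list[tuple[str, Any]] = []
    if multiple_outputs:
        if not isinstance(result, Mapping):
            raise TypeError(f"Returned output was type {type(result)}, expected a dictionary for multiple_outputs")
        for key, value in result.items():
            if not isinstance(key, str):
                raise TypeError(f"Returned dictionary keys must be strings when using multiple_outputs, found {key}")
            pushes.append((key, value))
    pushes.append(("return_value", result))
    return pushes


# One push per key plus the full mapping under "return_value":
print(push_plan({"rows": 42, "path": "s3://bucket/out.parquet"}, True, True))
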
