Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify asset decorator implementation #44344

Merged
merged 4 commits into from
Nov 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 52 additions & 19 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from __future__ import annotations

import logging
import operator
import os
import urllib.parse
import warnings
from collections.abc import Iterable, Iterator
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -40,6 +40,7 @@
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
from urllib.parse import SplitResult

from sqlalchemy.orm.session import Session
Expand Down Expand Up @@ -221,11 +222,24 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
class Asset(os.PathLike, BaseAsset):
"""A representation of data asset dependencies between workflows."""

name: str
uri: str
group: str
extra: dict[str, Any]
watchers: list[BaseTrigger]
name: str = attrs.field(
validator=[_validate_asset_name],
)
uri: str = attrs.field(
validator=[_validate_non_empty_identifier],
converter=_sanitize_uri,
)
group: str = attrs.field(
default=attrs.Factory(operator.attrgetter("asset_type"), takes_self=True),
validator=[_validate_identifier],
)
extra: dict[str, Any] = attrs.field(
factory=dict,
converter=_set_extra_default,
)
watchers: list[BaseTrigger] = attrs.field(
factory=list,
)

asset_type: ClassVar[str] = "asset"
__version__: ClassVar[int] = 1
Expand All @@ -236,9 +250,9 @@ def __init__(
name: str,
uri: str,
*,
group: str = "",
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
watchers: list[BaseTrigger] = ...,
) -> None:
"""Canonical; both name and uri are provided."""

Expand All @@ -247,9 +261,9 @@ def __init__(
self,
name: str,
*,
group: str = "",
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
watchers: list[BaseTrigger] = ...,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

Expand All @@ -258,9 +272,9 @@ def __init__(
self,
*,
uri: str,
group: str = "",
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
watchers: list[BaseTrigger] = ...,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

Expand All @@ -269,7 +283,7 @@ def __init__(
name: str | None = None,
uri: str | None = None,
*,
group: str = "",
group: str | None = None,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
Expand All @@ -279,16 +293,35 @@ def __init__(
name = uri
elif uri is None:
uri = name
fields = attrs.fields_dict(Asset)
self.name = _validate_asset_name(self, fields["name"], name)
self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
self.extra = _set_extra_default(extra)
self.watchers = watchers or []

if TYPE_CHECKING:
assert name is not None
assert uri is not None

# attrs default (and factory) does not kick in if any value is given to
# the argument. We need to exclude defaults from the custom ___init___.
kwargs: dict[str, Any] = {}
if group is not None:
kwargs["group"] = group
if extra is not None:
kwargs["extra"] = extra
if watchers is not None:
kwargs["watchers"] = watchers

self.__attrs_init__(name=name, uri=uri, **kwargs)

def __fspath__(self) -> str:
return self.uri

def __eq__(self, other: Any) -> bool:
# The Asset class can be subclassed, and we don't want fields added by a
# subclass to break equality. This explicitly filters out only fields
# defined by the Asset class for comparison.
if not isinstance(other, Asset):
return NotImplemented
f = attrs.filters.include(*attrs.fields_dict(Asset))
return attrs.asdict(self, filter=f) == attrs.asdict(other, filter=f)

@property
def normalized_uri(self) -> str | None:
"""
Expand Down
44 changes: 15 additions & 29 deletions task_sdk/src/airflow/sdk/definitions/asset/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,19 @@
from __future__ import annotations

import inspect
from collections.abc import Iterator, Mapping
from typing import (
TYPE_CHECKING,
Any,
Callable,
)
from typing import TYPE_CHECKING, Any, Callable

import attrs

from airflow.models.asset import _fetch_active_assets_by_name
from airflow.models.dag import DAG, ScheduleArg
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetRef
from airflow.utils.session import create_session

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping

from airflow.io.path import ObjectStoragePath
from airflow.models.dag import ScheduleArg
from airflow.triggers.base import BaseTrigger


Expand All @@ -58,14 +54,17 @@ def _iter_kwargs(
yield key, value

def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
active_assets: dict[str, Asset] = {}
from airflow.models.asset import _fetch_active_assets_by_name

asset_names = [asset_ref.name for asset_ref in self.inlets if isinstance(asset_ref, AssetRef)]
if "self" in inspect.signature(self.python_callable).parameters:
asset_names.append(self._definition_name)

if asset_names:
with create_session() as session:
active_assets = _fetch_active_assets_by_name(asset_names, session)
else:
active_assets = {}
return dict(self._iter_kwargs(context, active_assets))


Expand All @@ -81,50 +80,37 @@ class AssetDefinition(Asset):
schedule: ScheduleArg

def __attrs_post_init__(self) -> None:
parameters = inspect.signature(self.function).parameters
from airflow.models.dag import DAG

with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
_AssetMainOperator(
task_id="__main__",
inlets=[
AssetRef(name=inlet_asset_name)
for inlet_asset_name in parameters
for inlet_asset_name in inspect.signature(self.function).parameters
if inlet_asset_name not in ("self", "context")
],
outlets=[self.to_asset()],
outlets=[self],
python_callable=self.function,
definition_name=self.name,
uri=self.uri,
)

def to_asset(self) -> Asset:
return Asset(
name=self.name,
uri=self.uri,
group=self.group,
extra=self.extra,
)

def serialize(self):
return {
"uri": self.uri,
"name": self.name,
"group": self.group,
"extra": self.extra,
}


@attrs.define(kw_only=True)
class asset:
"""Create an asset by decorating a materialization function."""

schedule: ScheduleArg
uri: str | ObjectStoragePath | None = None
group: str = ""
group: str = Asset.asset_type
extra: dict[str, Any] = attrs.field(factory=dict)
watchers: list[BaseTrigger] = attrs.field(factory=list)

def __call__(self, f: Callable) -> AssetDefinition:
if self.schedule is not None:
raise NotImplementedError("asset scheduling not implemented yet")

if (name := f.__name__) != f.__qualname__:
raise ValueError("nested function not supported")

Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/defintions/test_asset_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test__attrs_post_init__(
AssetRef(name="inlet_asset_1"),
AssetRef(name="inlet_asset_2"),
],
outlets=[asset_definition.to_asset()],
outlets=[asset_definition],
python_callable=ANY,
definition_name="example_asset_func",
uri="s3://bucket/object",
Expand Down