Skip to content

Commit

Permalink
Simplify asset decorator implementation (#44344)
Browse files Browse the repository at this point in the history
  • Loading branch information
uranusjr authored Nov 26, 2024
1 parent a832c41 commit f8a61cb
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 49 deletions.
71 changes: 52 additions & 19 deletions task_sdk/src/airflow/sdk/definitions/asset/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@
from __future__ import annotations

import logging
import operator
import os
import urllib.parse
import warnings
from collections.abc import Iterable, Iterator
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -40,6 +40,7 @@
from airflow.utils.session import NEW_SESSION, provide_session

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator
from urllib.parse import SplitResult

from sqlalchemy.orm.session import Session
Expand Down Expand Up @@ -221,11 +222,24 @@ def iter_dag_dependencies(self, *, source: str, target: str) -> Iterator[DagDepe
class Asset(os.PathLike, BaseAsset):
"""A representation of data asset dependencies between workflows."""

name: str
uri: str
group: str
extra: dict[str, Any]
watchers: list[BaseTrigger]
name: str = attrs.field(
validator=[_validate_asset_name],
)
uri: str = attrs.field(
validator=[_validate_non_empty_identifier],
converter=_sanitize_uri,
)
group: str = attrs.field(
default=attrs.Factory(operator.attrgetter("asset_type"), takes_self=True),
validator=[_validate_identifier],
)
extra: dict[str, Any] = attrs.field(
factory=dict,
converter=_set_extra_default,
)
watchers: list[BaseTrigger] = attrs.field(
factory=list,
)

asset_type: ClassVar[str] = "asset"
__version__: ClassVar[int] = 1
Expand All @@ -236,9 +250,9 @@ def __init__(
name: str,
uri: str,
*,
group: str = "",
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
watchers: list[BaseTrigger] = ...,
) -> None:
"""Canonical; both name and uri are provided."""

Expand All @@ -247,9 +261,9 @@ def __init__(
self,
name: str,
*,
group: str = "",
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
watchers: list[BaseTrigger] = ...,
) -> None:
"""It's possible to only provide the name, either by keyword or as the only positional argument."""

Expand All @@ -258,9 +272,9 @@ def __init__(
self,
*,
uri: str,
group: str = "",
group: str = ...,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
watchers: list[BaseTrigger] = ...,
) -> None:
"""It's possible to only provide the URI as a keyword argument."""

Expand All @@ -269,7 +283,7 @@ def __init__(
name: str | None = None,
uri: str | None = None,
*,
group: str = "",
group: str | None = None,
extra: dict | None = None,
watchers: list[BaseTrigger] | None = None,
) -> None:
Expand All @@ -279,16 +293,35 @@ def __init__(
name = uri
elif uri is None:
uri = name
fields = attrs.fields_dict(Asset)
self.name = _validate_asset_name(self, fields["name"], name)
self.uri = _sanitize_uri(_validate_non_empty_identifier(self, fields["uri"], uri))
self.group = _validate_identifier(self, fields["group"], group) if group else self.asset_type
self.extra = _set_extra_default(extra)
self.watchers = watchers or []

if TYPE_CHECKING:
assert name is not None
assert uri is not None

# attrs default (and factory) does not kick in if any value is given to
# the argument. We need to exclude defaults from the custom __init__.
kwargs: dict[str, Any] = {}
if group is not None:
kwargs["group"] = group
if extra is not None:
kwargs["extra"] = extra
if watchers is not None:
kwargs["watchers"] = watchers

self.__attrs_init__(name=name, uri=uri, **kwargs)

def __fspath__(self) -> str:
return self.uri

def __eq__(self, other: Any) -> bool:
    """
    Compare two assets using only the fields declared on ``Asset`` itself.

    ``Asset`` can be subclassed, and fields introduced by a subclass must not
    affect equality, so both operands are reduced to the ``Asset``-level
    fields before they are compared.

    :param other: The object to compare against.
    :return: ``True`` when every ``Asset``-level field matches, ``False``
        otherwise; ``NotImplemented`` when *other* is not an ``Asset``.
    """
    if not isinstance(other, Asset):
        return NotImplemented
    # Keep only the attributes defined by Asset; anything a subclass adds is
    # deliberately left out of the comparison.
    asset_fields_only = attrs.filters.include(*attrs.fields_dict(Asset))
    own_values = attrs.asdict(self, filter=asset_fields_only)
    other_values = attrs.asdict(other, filter=asset_fields_only)
    return own_values == other_values

@property
def normalized_uri(self) -> str | None:
"""
Expand Down
44 changes: 15 additions & 29 deletions task_sdk/src/airflow/sdk/definitions/asset/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,19 @@
from __future__ import annotations

import inspect
from collections.abc import Iterator, Mapping
from typing import (
TYPE_CHECKING,
Any,
Callable,
)
from typing import TYPE_CHECKING, Any, Callable

import attrs

from airflow.models.asset import _fetch_active_assets_by_name
from airflow.models.dag import DAG, ScheduleArg
from airflow.providers.standard.operators.python import PythonOperator
from airflow.sdk.definitions.asset import Asset, AssetRef
from airflow.utils.session import create_session

if TYPE_CHECKING:
from collections.abc import Iterator, Mapping

from airflow.io.path import ObjectStoragePath
from airflow.models.dag import ScheduleArg
from airflow.triggers.base import BaseTrigger


Expand All @@ -58,14 +54,17 @@ def _iter_kwargs(
yield key, value

def determine_kwargs(self, context: Mapping[str, Any]) -> Mapping[str, Any]:
    """
    Build the keyword arguments passed to the decorated callable.

    Names are collected from every ``AssetRef`` inlet, plus this definition's
    own name when the callable declares a ``self`` parameter. The matching
    active assets are then fetched by name and handed to ``_iter_kwargs``
    together with *context*.

    :param context: The mapping forwarded to ``_iter_kwargs``.
    :return: A dict built from the pairs yielded by ``_iter_kwargs``.
    """
    # Imported locally rather than at module level.
    from airflow.models.asset import _fetch_active_assets_by_name

    names_to_fetch = [inlet.name for inlet in self.inlets if isinstance(inlet, AssetRef)]
    callable_params = inspect.signature(self.python_callable).parameters
    if "self" in callable_params:
        names_to_fetch.append(self._definition_name)

    # Only open a session when there is something to look up.
    active_assets: dict[str, Asset] = {}
    if names_to_fetch:
        with create_session() as session:
            active_assets = _fetch_active_assets_by_name(names_to_fetch, session)
    return dict(self._iter_kwargs(context, active_assets))


Expand All @@ -81,50 +80,37 @@ class AssetDefinition(Asset):
schedule: ScheduleArg

def __attrs_post_init__(self) -> None:
parameters = inspect.signature(self.function).parameters
from airflow.models.dag import DAG

with DAG(dag_id=self.name, schedule=self.schedule, auto_register=True):
_AssetMainOperator(
task_id="__main__",
inlets=[
AssetRef(name=inlet_asset_name)
for inlet_asset_name in parameters
for inlet_asset_name in inspect.signature(self.function).parameters
if inlet_asset_name not in ("self", "context")
],
outlets=[self.to_asset()],
outlets=[self],
python_callable=self.function,
definition_name=self.name,
uri=self.uri,
)

def to_asset(self) -> Asset:
    """
    Build a plain ``Asset`` from this definition.

    Only ``name``, ``uri``, ``group`` and ``extra`` are carried over; no
    other attribute of the definition is passed to the new ``Asset``.
    """
    carried_over = ("name", "uri", "group", "extra")
    init_kwargs = {field: getattr(self, field) for field in carried_over}
    return Asset(**init_kwargs)

def serialize(self):
    """
    Return this definition's identifying fields as a plain dict.

    The result has the keys ``uri``, ``name``, ``group`` and ``extra`` (in
    that order), each mapped to the attribute of the same name. Values are
    referenced, not copied.
    """
    serialized_keys = ("uri", "name", "group", "extra")
    return {key: getattr(self, key) for key in serialized_keys}


@attrs.define(kw_only=True)
class asset:
"""Create an asset by decorating a materialization function."""

schedule: ScheduleArg
uri: str | ObjectStoragePath | None = None
group: str = ""
group: str = Asset.asset_type
extra: dict[str, Any] = attrs.field(factory=dict)
watchers: list[BaseTrigger] = attrs.field(factory=list)

def __call__(self, f: Callable) -> AssetDefinition:
if self.schedule is not None:
raise NotImplementedError("asset scheduling not implemented yet")

if (name := f.__name__) != f.__qualname__:
raise ValueError("nested function not supported")

Expand Down
2 changes: 1 addition & 1 deletion task_sdk/tests/defintions/test_asset_decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ def test__attrs_post_init__(
AssetRef(name="inlet_asset_1"),
AssetRef(name="inlet_asset_2"),
],
outlets=[asset_definition.to_asset()],
outlets=[asset_definition],
python_callable=ANY,
definition_name="example_asset_func",
uri="s3://bucket/object",
Expand Down

0 comments on commit f8a61cb

Please sign in to comment.