Skip to content

Commit

Permalink
[AIP-44] Introduce Pydantic model for LogTemplate (apache#36004)
Browse files Browse the repository at this point in the history
  • Loading branch information
mhenc authored Dec 1, 2023
1 parent 8f2cf41 commit c26aa12
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 5 deletions.
1 change: 1 addition & 0 deletions airflow/api_internal/endpoints/rpc_api_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def _initialize_map() -> dict[str, Callable]:
DagRun.get_previous_dagrun,
DagRun.get_previous_scheduled_dagrun,
DagRun.fetch_task_instance,
DagRun._get_log_template,
SerializedDagModel.get_serialized_dag,
TaskInstance._check_and_change_state_before_execution,
TaskInstance.get_task_instance,
Expand Down
18 changes: 14 additions & 4 deletions airflow/models/dagrun.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
from airflow.models.operator import Operator
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
from airflow.typing_compat import Literal
from airflow.utils.types import ArgNotSet

Expand Down Expand Up @@ -1460,14 +1461,23 @@ def schedule_tis(
return count

@provide_session
def get_log_template(self, *, session: Session = NEW_SESSION) -> LogTemplate | LogTemplatePydantic:
    """Return the log template used by this DagRun.

    Delegates to the internal-API-capable staticmethod so the lookup can be
    routed over the internal API (AIP-44) when this code runs without direct
    database access.

    :param session: SQLAlchemy ORM session (injected by ``@provide_session``).
    :return: The matching ``LogTemplate`` row, or its Pydantic counterpart
        when the call crosses the internal API boundary.
    """
    # NOTE: the diff source interleaved the pre-commit signature/body here;
    # this is the post-commit behavior: a thin delegation to _get_log_template.
    return DagRun._get_log_template(log_template_id=self.log_template_id, session=session)

@staticmethod
@internal_api_call
@provide_session
def _get_log_template(
    log_template_id: int | None, session: Session = NEW_SESSION
) -> LogTemplate | LogTemplatePydantic:
    """Fetch the log template identified by *log_template_id*.

    A ``None`` id means the DagRun predates the LogTemplate table, in which
    case the earliest template row is used as the fallback.

    :param log_template_id: Primary key of the template, or ``None``.
    :param session: SQLAlchemy ORM session (injected by ``@provide_session``).
    :raises AirflowException: when no matching row exists in the metadatabase.
    """
    found: LogTemplate | None
    if log_template_id is not None:
        found = session.get(LogTemplate, log_template_id)
    else:
        # DagRun created before LogTemplate introduction: take the oldest row.
        oldest_first = select(LogTemplate).order_by(LogTemplate.id).limit(1)
        found = session.scalar(oldest_first)
    if found is None:
        raise AirflowException(
            f"No log_template entry found for ID {log_template_id!r}. "
            f"Please make sure you set up the metadatabase correctly."
        )
    return found
Expand Down
1 change: 1 addition & 0 deletions airflow/serialization/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,6 @@ class DagAttributeTypes(str, Enum):
DAG_RUN = "dag_run"
DAG_MODEL = "dag_model"
DATA_SET = "data_set"
LOG_TEMPLATE = "log_template"
CONNECTION = "connection"
ARG_NOT_SET = "arg_not_set"
1 change: 1 addition & 0 deletions airflow/serialization/pydantic/dag_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class DagRunPydantic(BaseModelPydantic):
updated_at: Optional[datetime]
dag: Optional[PydanticDag]
consumed_dataset_events: List[DatasetEventPydantic] # noqa
log_template_id: Optional[int]

model_config = ConfigDict(from_attributes=True, arbitrary_types_allowed=True)

Expand Down
30 changes: 30 additions & 0 deletions airflow/serialization/pydantic/tasklog.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from datetime import datetime

from pydantic import BaseModel as BaseModelPydantic, ConfigDict


class LogTemplatePydantic(BaseModelPydantic):
    """Serializable version of the LogTemplate ORM SqlAlchemyModel used by internal API."""

    # Fields mirror the LogTemplate ORM model's columns so instances can be
    # built directly from ORM rows (see model_config below).
    id: int
    filename: str
    elasticsearch_id: str
    created_at: datetime

    # from_attributes=True lets pydantic populate the model from an ORM
    # object's attributes rather than requiring a dict.
    model_config = ConfigDict(from_attributes=True)
7 changes: 6 additions & 1 deletion airflow/serialization/serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
from airflow.models.mappedoperator import MappedOperator
from airflow.models.param import Param, ParamsDict
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.tasklog import LogTemplate
from airflow.models.xcom_arg import XComArg, deserialize_xcom_arg, serialize_xcom_arg
from airflow.providers_manager import ProvidersManager
from airflow.serialization.enums import DagAttributeTypes as DAT, Encoding
Expand All @@ -57,6 +58,7 @@
from airflow.serialization.pydantic.dataset import DatasetPydantic
from airflow.serialization.pydantic.job import JobPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
from airflow.settings import _ENABLE_AIP_44, DAGS_FOLDER, json
from airflow.utils.code_utils import get_python_source
from airflow.utils.docs import get_docs_url
Expand Down Expand Up @@ -514,7 +516,8 @@ def _pydantic_model_dump(model_cls: type[BaseModel], var: Any) -> dict[str, Any]
return cls._encode(_pydantic_model_dump(DatasetPydantic, var), type_=DAT.DATA_SET)
elif isinstance(var, DagModel):
return cls._encode(_pydantic_model_dump(DagModelPydantic, var), type_=DAT.DAG_MODEL)

elif isinstance(var, LogTemplate):
return cls._encode(_pydantic_model_dump(LogTemplatePydantic, var), type_=DAT.LOG_TEMPLATE)
else:
return cls.default_serialization(strict, var)
elif isinstance(var, ArgNotSet):
Expand Down Expand Up @@ -596,6 +599,8 @@ def deserialize(cls, encoded_var: Any, use_pydantic_models=False) -> Any:
return DagModelPydantic.parse_obj(var)
elif type_ == DAT.DATA_SET:
return DatasetPydantic.parse_obj(var)
elif type_ == DAT.LOG_TEMPLATE:
return LogTemplatePydantic.parse_obj(var)
elif type_ == DAT.ARG_NOT_SET:
return NOTSET
else:
Expand Down
8 changes: 8 additions & 0 deletions tests/serialization/test_serialized_objects.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
from airflow.models.dagrun import DagRun
from airflow.models.param import Param
from airflow.models.taskinstance import SimpleTaskInstance, TaskInstance
from airflow.models.tasklog import LogTemplate
from airflow.models.xcom_arg import XComArg
from airflow.operators.empty import EmptyOperator
from airflow.operators.python import PythonOperator
Expand All @@ -41,6 +42,7 @@
from airflow.serialization.pydantic.dag_run import DagRunPydantic
from airflow.serialization.pydantic.job import JobPydantic
from airflow.serialization.pydantic.taskinstance import TaskInstancePydantic
from airflow.serialization.pydantic.tasklog import LogTemplatePydantic
from airflow.settings import _ENABLE_AIP_44
from airflow.utils.operator_resources import Resources
from airflow.utils.state import DagRunState, State
Expand Down Expand Up @@ -278,6 +280,12 @@ def test_backcompat_deserialize_connection(conn_uri):
DAT.DAG_MODEL,
lambda a, b: a.fileloc == b.fileloc and a.schedule_interval == b.schedule_interval,
),
(
LogTemplate(id=1, filename="test_file", elasticsearch_id="test_id", created_at=datetime.now()),
LogTemplatePydantic,
DAT.LOG_TEMPLATE,
lambda a, b: a.id == b.id and a.filename == b.filename and equal_time(a.created_at, b.created_at),
),
],
)
def test_serialize_deserialize_pydantic(input, pydantic_class, encoded_type, cmp_func):
Expand Down

0 comments on commit c26aa12

Please sign in to comment.