Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ele 1696 upgrade pydantic #1143

Merged
merged 10 commits into from
Sep 18, 2023
4 changes: 2 additions & 2 deletions elementary/clients/dbt/slim_dbt_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from dbt.parser.manifest import ManifestLoader
from dbt.tracking import disable_tracking
from dbt.version import __version__ as dbt_version_string
from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator

from elementary.clients.dbt.base_dbt_runner import BaseDbtRunner
from elementary.utils.log import get_logger
Expand Down Expand Up @@ -65,7 +65,7 @@ class ConfigArgs(BaseModel):
threads: Optional[int] = 1
vars: Optional[Union[str, Dict[str, Any]]] = DEFAULT_VARS

@validator("vars", pre=True)
@field_validator("vars", mode="before")
def validate_vars(cls, vars):
if not vars:
return DEFAULT_VARS
Expand Down
4 changes: 2 additions & 2 deletions elementary/clients/slack/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional

from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator

from elementary.utils.log import get_logger

Expand All @@ -17,7 +17,7 @@ class SlackMessageSchema(BaseModel):
attachments: Optional[list] = None
blocks: Optional[list] = None

@validator("attachments", pre=True)
@field_validator("attachments", mode="before")
def validate_attachments(cls, attachments):
if (
isinstance(attachments, list)
Expand Down
6 changes: 5 additions & 1 deletion elementary/monitor/api/filters/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
from typing import List, Optional

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class FilterSchema(BaseModel):
model_config = ConfigDict(protected_namespaces=())

name: str
display_name: str
model_unique_ids: List[Optional[str]] = []
Expand All @@ -14,6 +16,8 @@ def add_model_unique_id(self, model_unique_id: Optional[str]):


class FiltersSchema(BaseModel):
model_config = ConfigDict(protected_namespaces=())

test_results: List[FilterSchema] = list()
test_runs: List[FilterSchema] = list()
model_runs: List[FilterSchema] = list()
4 changes: 2 additions & 2 deletions elementary/monitor/api/groups/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@


class GroupItemSchema(BaseModel):
node_id: Optional[str]
resource_type: Optional[str]
node_id: Optional[str] = None
resource_type: Optional[str] = None


DbtGroupSchema = Dict[str, dict]
Expand Down
9 changes: 4 additions & 5 deletions elementary/monitor/api/lineage/schema.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from typing import List, Optional, Tuple
from typing import List, Literal, Optional, Tuple

import networkx as nx
from pydantic import BaseModel, validator
from pydantic.typing import Literal
from pydantic import BaseModel, field_validator

NodeUniqueIdType = str
NodeType = Literal["model", "source", "exposure"]
Expand All @@ -17,11 +16,11 @@ class LineageSchema(BaseModel):
nodes: Optional[List[LineageNodeSchema]] = None
edges: Optional[List[Tuple[NodeUniqueIdType, NodeUniqueIdType]]] = None

@validator("nodes", pre=True, always=True)
@field_validator("nodes", mode="before")
def set_nodes(cls, nodes):
return nodes or []

@validator("edges", pre=True, always=True)
@field_validator("edges", mode="before")
def set_edges(cls, edges):
return edges or []

Expand Down
3 changes: 1 addition & 2 deletions elementary/monitor/api/models/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import os
import statistics
from collections import defaultdict
Expand Down Expand Up @@ -203,7 +202,7 @@ def _normalize_dbt_artifact_dict(
SourceSchema: NormalizedSourceSchema,
}
artifact_name = artifact.name
normalized_artifact = json.loads(artifact.json())
normalized_artifact = artifact.model_dump()
normalized_artifact["model_name"] = artifact_name
normalized_artifact["normalized_full_path"] = self._normalize_artifact_path(
artifact
Expand Down
24 changes: 13 additions & 11 deletions elementary/monitor/api/models/schema.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
import posixpath
from typing import Dict, List, Optional
from typing import Dict, List, Literal, Optional

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from elementary.monitor.api.totals_schema import TotalsSchema
from elementary.monitor.fetchers.models.schema import (
Expand All @@ -15,6 +15,8 @@


class NormalizedArtifactSchema(ExtendedBaseModel):
model_config = ConfigDict(protected_namespaces=())

owners: Optional[List[str]] = []
tags: Optional[List[str]] = []
# Should be changed to artifact_name.
Expand All @@ -23,32 +25,32 @@ class NormalizedArtifactSchema(ExtendedBaseModel):
normalized_full_path: str
fqn: str

@validator("tags", pre=True)
@field_validator("tags", mode="before")
def load_tags(cls, tags):
return cls._load_var_to_list(tags)

@validator("owners", pre=True)
@field_validator("owners", mode="before")
def load_owners(cls, owners):
return cls._load_var_to_list(owners)

@validator("normalized_full_path", pre=True)
@field_validator("normalized_full_path", mode="before")
def format_normalized_full_path_sep(cls, normalized_full_path: str) -> str:
return posixpath.sep.join(normalized_full_path.split(os.path.sep))


# NormalizedArtifactSchema must be first in the inheritance order
class NormalizedModelSchema(NormalizedArtifactSchema, ModelSchema):
artifact_type: str = Field("model", const=True)
artifact_type: Literal["model"] = "model"


# NormalizedArtifactSchema must be first in the inheritance order
class NormalizedSourceSchema(NormalizedArtifactSchema, SourceSchema):
artifact_type: str = Field("source", const=True)
artifact_type: Literal["source"] = "source"


# NormalizedArtifactSchema must be first in the inheritance order
class NormalizedExposureSchema(NormalizedArtifactSchema, ExposureSchema):
artifact_type: str = Field("exposure", const=True)
artifact_type: Literal["exposure"] = "exposure"


class ModelCoverageSchema(BaseModel):
Expand All @@ -60,11 +62,11 @@ class ModelRunSchema(BaseModel):
id: str
time_utc: str
status: str
full_refresh: Optional[bool]
materialization: Optional[str]
full_refresh: Optional[bool] = None
materialization: Optional[str] = None
execution_time: float

@validator("time_utc", pre=True)
@field_validator("time_utc", mode="before")
def format_time_utc(cls, time_utc):
return convert_partial_iso_format_to_full_iso_format(time_utc)

Expand Down
18 changes: 9 additions & 9 deletions elementary/monitor/api/report/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,10 +71,10 @@ def get_report_data(
test_results.totals, test_runs.totals, models, sources, models_runs.runs
)

serializable_groups = groups.dict()
serializable_groups = groups.model_dump()
serializable_models = self._serialize_models(models, sources, exposures)
serializable_model_runs = self._serialize_models_runs(models_runs.runs)
serializable_model_runs_totals = models_runs.dict(include={"totals"})[
serializable_model_runs_totals = models_runs.model_dump(include={"totals"})[
"totals"
]
serializable_models_coverages = self._serialize_coverages(coverages)
Expand All @@ -86,9 +86,9 @@ def get_report_data(
)
serializable_test_runs = self._serialize_test_runs(test_runs.runs)
serializable_test_runs_totals = self._serialize_totals(test_runs.totals)
serializable_invocation = test_results.invocation.dict()
serializable_filters = filters.dict()
serializable_lineage = lineage.dict()
serializable_invocation = test_results.invocation.model_dump()
serializable_filters = filters.model_dump()
serializable_lineage = lineage.model_dump()

models_latest_invocation = invocations_api.get_models_latest_invocation()
invocations = invocations_api.get_models_latest_invocations_data()
Expand Down Expand Up @@ -143,15 +143,15 @@ def _serialize_coverages(
return {model_id: dict(coverage) for model_id, coverage in coverages.items()}

def _serialize_models_runs(self, models_runs: List[ModelRunsSchema]) -> List[dict]:
return [model_runs.dict(by_alias=True) for model_runs in models_runs]
return [model_runs.model_dump(by_alias=True) for model_runs in models_runs]

def _serialize_test_results(
self, test_results: Dict[Optional[str], List[TestResultSchema]]
) -> Dict[Optional[str], List[dict]]:
serializable_test_results = defaultdict(list)
for model_unique_id, test_result in test_results.items():
serializable_test_results[model_unique_id].extend(
[result.dict() for result in test_result]
[result.model_dump() for result in test_result]
)
return serializable_test_results

Expand All @@ -161,7 +161,7 @@ def _serialize_test_runs(
serializable_test_runs = defaultdict(list)
for model_unique_id, test_run in test_runs.items():
serializable_test_runs[model_unique_id].extend(
[run.dict() for run in test_run]
[run.model_dump() for run in test_run]
)
return serializable_test_runs

Expand All @@ -170,5 +170,5 @@ def _serialize_totals(
) -> Dict[Optional[str], dict]:
serialized_totals = dict()
for model_unique_id, total in totals.items():
serialized_totals[model_unique_id] = total.dict()
serialized_totals[model_unique_id] = total.model_dump()
return serialized_totals
4 changes: 3 additions & 1 deletion elementary/monitor/api/report/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict


class ReportDataEnvSchema(BaseModel):
Expand All @@ -10,6 +10,8 @@ class ReportDataEnvSchema(BaseModel):


class ReportDataSchema(BaseModel):
model_config = ConfigDict(protected_namespaces=())

creation_time: Optional[str] = None
days_back: Optional[int] = None
models: dict = dict()
Expand Down
2 changes: 2 additions & 0 deletions elementary/monitor/api/test_management/test_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@


class TestManagementAPI(APIClient):
__test__ = False

def __init__(
self,
dbt_runner: BaseDbtRunner,
Expand Down
16 changes: 6 additions & 10 deletions elementary/monitor/api/tests/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Dict, List, Optional, Union

from pydantic import BaseModel, Field, validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from elementary.monitor.api.totals_schema import TotalsSchema
from elementary.monitor.fetchers.invocations.schema import DbtInvocationSchema
Expand All @@ -12,9 +12,6 @@ class ElementaryTestResultSchema(BaseModel):
metrics: Optional[Union[list, dict]] = None
result_description: Optional[str] = None

class Config:
smart_union = True


class DbtTestResultSchema(BaseModel):
display_name: Optional[str] = None
Expand All @@ -24,12 +21,12 @@ class DbtTestResultSchema(BaseModel):


class InvocationSchema(BaseModel):
affected_rows: Optional[int]
affected_rows: Optional[int] = None
time_utc: str
id: str
status: str

@validator("time_utc", pre=True)
@field_validator("time_utc", mode="before")
def format_time_utc(cls, time_utc):
return convert_partial_iso_format_to_full_iso_format(time_utc)

Expand All @@ -42,6 +39,8 @@ class InvocationsSchema(BaseModel):


class TestMetadataSchema(BaseModel):
model_config = ConfigDict(protected_namespaces=())

test_unique_id: str
elementary_unique_id: str
database_name: Optional[str] = None
Expand Down Expand Up @@ -69,9 +68,6 @@ class TestResultSchema(BaseModel):
metadata: TestMetadataSchema
test_results: Union[DbtTestResultSchema, ElementaryTestResultSchema]

class Config:
smart_union = True


class TestResultsWithTotalsSchema(BaseModel):
results: Dict[Optional[str], List[TestResultSchema]] = dict()
Expand All @@ -81,7 +77,7 @@ class TestResultsWithTotalsSchema(BaseModel):

class TestRunSchema(BaseModel):
metadata: TestMetadataSchema
test_runs: Optional[InvocationsSchema]
test_runs: Optional[InvocationsSchema] = None


class TestRunsWithTotalsSchema(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion elementary/monitor/data_monitoring/data_monitoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def __init__(
tracking.register_group(
"warehouse",
self.warehouse_info.id,
self.warehouse_info.dict(),
self.warehouse_info.model_dump(),
)
tracking.set_env("target_name", latest_invocation.get("target_name"))
tracking.set_env("dbt_orchestrator", latest_invocation.get("orchestrator"))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def get_report_data(
)
self.success = False

report_data_dict = report_data.dict()
report_data_dict = report_data.model_dump()
return report_data_dict

def _add_report_tracking(
Expand Down
4 changes: 2 additions & 2 deletions elementary/monitor/data_monitoring/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import Dict, List, Optional

from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator

from elementary.monitor.alerts.model import ModelAlert
from elementary.monitor.alerts.source_freshness import SourceFreshnessAlert
Expand Down Expand Up @@ -62,7 +62,7 @@ class SelectorFilterSchema(BaseModel):
resource_types: Optional[List[ResourceType]] = None
node_names: Optional[List[str]] = None

@validator("invocation_time", pre=True)
@field_validator("invocation_time", mode="before")
def format_invocation_time(cls, invocation_time):
if invocation_time:
try:
Expand Down
6 changes: 3 additions & 3 deletions elementary/monitor/fetchers/invocations/schema.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Optional

from pydantic import BaseModel, validator
from pydantic import BaseModel, field_validator

from elementary.utils.json_utils import try_load_json
from elementary.utils.time import convert_partial_iso_format_to_full_iso_format
Expand All @@ -17,11 +17,11 @@ class DbtInvocationSchema(BaseModel):
job_id: Optional[str] = None
orchestrator: Optional[str] = None

@validator("detected_at", pre=True)
@field_validator("detected_at", mode="before")
def format_detected_at(cls, detected_at):
return convert_partial_iso_format_to_full_iso_format(detected_at)

@validator("selected", pre=True)
@field_validator("selected", mode="before")
def format_selected(cls, selected):
selected_list = try_load_json(selected) or []
return " ".join(selected_list)
Loading
Loading