feat(ingest/datahub): Add way to filter soft deleted entities #11738

Merged 6 commits on Oct 30, 2024
@@ -1,5 +1,5 @@
import os
from typing import Optional
from typing import Optional, Set

from pydantic import Field, root_validator

@@ -35,6 +35,19 @@ class DataHubSourceConfig(StatefulIngestionConfigBase):
),
)

include_soft_deleted_entities: bool = Field(
default=True,
description=(
"If enabled, include entities that have been soft deleted. "
"Otherwise, only include entities that have not been soft deleted. "
),
)

exclude_aspects: Set[str] = Field(
default_factory=set,
description="Set of aspect names to exclude from ingestion",
)
Collaborator:
I assume we use a set here instead of an AllowDenyPattern so that we can push down the filters to sql?

Contributor Author:
exactly


database_query_batch_size: int = Field(
default=DEFAULT_DATABASE_BATCH_SIZE,
description="Number of records to fetch from the database at a time",
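
The review exchange above is worth spelling out: exclude_aspects is kept as a plain set of aspect names (rather than an AllowDenyPattern) precisely so the filter can be pushed down to the database as a parameterized NOT IN clause, which the reader query later in this diff does. A minimal sketch of that push-down, assuming an illustrative table name and a DB-API driver (such as pymysql) that expands a list parameter into a parenthesized value list:

from typing import Any, Dict, Iterable, Set

def fetch_aspect_rows(cursor: Any, exclude_aspects: Set[str]) -> Iterable[Dict[str, Any]]:
    # Sketch only: push the exclude set down into SQL instead of filtering rows in Python.
    query = "SELECT urn, aspect, metadata FROM metadata_aspect_v2 WHERE version = 0"
    params: Dict[str, Any] = {}
    if exclude_aspects:
        # The guard matters: an empty set would render as "NOT IN ()", a SQL syntax
        # error, which is why the query added in this PR applies the same conditional.
        query += " AND aspect NOT IN %(exclude_aspects)s"
        params["exclude_aspects"] = list(exclude_aspects)
    cursor.execute(query, params or None)
    columns = [desc[0] for desc in cursor.description]
    for row in cursor.fetchall():
        yield dict(zip(columns, row))
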
@@ -26,11 +26,17 @@ def __init__(
self.report = report
self.graph = graph

def get_aspects(self) -> Iterable[MetadataChangeProposalWrapper]:
def get_urns(self) -> Iterable[str]:
urns = self.graph.get_urns_by_filter(
status=RemovedStatusFilter.ALL,
status=RemovedStatusFilter.ALL
if self.config.include_soft_deleted_entities
else RemovedStatusFilter.NOT_SOFT_DELETED,
batch_size=self.config.database_query_batch_size,
)
return urns

def get_aspects(self) -> Iterable[MetadataChangeProposalWrapper]:
urns = self.get_urns()
tasks: List[futures.Future[Iterable[MetadataChangeProposalWrapper]]] = []
with futures.ThreadPoolExecutor(
max_workers=self.config.max_workers
@@ -43,6 +49,9 @@ def get_aspects(self) -> Iterable[MetadataChangeProposalWrapper]:
def _get_aspects_for_urn(self, urn: str) -> Iterable[MetadataChangeProposalWrapper]:
aspects: Dict[str, _Aspect] = self.graph.get_entity_semityped(urn) # type: ignore
for aspect in aspects.values():
if aspect.get_aspect_name().lower() in self.config.exclude_aspects:
continue

yield MetadataChangeProposalWrapper(
entityUrn=urn,
aspect=aspect,
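
For orientation, the API reader change above amounts to: list URNs first (honouring the soft-delete filter), then fan out per-URN aspect fetches and drop any aspect whose name is in exclude_aspects. A condensed sketch of that flow, assuming a graph client exposing get_entity_semityped as used in the diff; the helper name and executor sizing are illustrative, not the exact implementation:

from concurrent import futures
from typing import Any, Dict, Iterable, List, Set

def fetch_filtered_aspects(
    graph: Any, urns: Iterable[str], exclude_aspects: Set[str], max_workers: int = 10
) -> Iterable[Any]:
    # Sketch only: fetch aspects per URN in a thread pool, skipping excluded aspects.
    def for_urn(urn: str) -> List[Any]:
        aspects: Dict[str, Any] = graph.get_entity_semityped(urn)
        return [a for name, a in aspects.items() if name.lower() not in exclude_aspects]

    with futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
        tasks = [executor.submit(for_urn, urn) for urn in urns]
        for task in tasks:  # results yielded in submission order
            yield from task.result()
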
@@ -1,11 +1,10 @@
import contextlib
import json
import logging
from datetime import datetime
from typing import Any, Generic, Iterable, List, Optional, Tuple, TypeVar
from typing import Any, Dict, Generic, Iterable, List, Optional, Tuple, TypeVar

from sqlalchemy import create_engine
from sqlalchemy.engine import Row
from typing_extensions import Protocol

from datahub.emitter.aspect import ASPECT_MAP
from datahub.emitter.mcp import MetadataChangeProposalWrapper
@@ -21,13 +20,7 @@
# Should work for at least mysql, mariadb, postgres
DATETIME_FORMAT = "%Y-%m-%d %H:%M:%S.%f"


class VersionOrderable(Protocol):
createdon: Any # Should restrict to only orderable types
version: int


ROW = TypeVar("ROW", bound=VersionOrderable)
ROW = TypeVar("ROW", bound=Dict[str, Any])


class VersionOrderer(Generic[ROW]):
@@ -54,22 +47,22 @@ def _process_row(self, row: ROW) -> Iterable[ROW]:
return

yield from self._attempt_queue_flush(row)
if row.version == 0:
if row["version"] == 0:
self._add_to_queue(row)
else:
yield row

def _add_to_queue(self, row: ROW) -> None:
if self.queue is None:
self.queue = (row.createdon, [row])
self.queue = (row["createdon"], [row])
else:
self.queue[1].append(row)

def _attempt_queue_flush(self, row: ROW) -> Iterable[ROW]:
if self.queue is None:
return

if row.createdon > self.queue[0]:
if row["createdon"] > self.queue[0]:
yield from self._flush_queue()

def _flush_queue(self) -> Iterable[ROW]:
@@ -92,6 +85,21 @@ def __init__(
**connection_config.options,
)

@property
def soft_deleted_urns_query(self) -> str:
return f"""
SELECT DISTINCT mav.urn
FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)} as mav
JOIN (
SELECT *,
JSON_EXTRACT(metadata, '$.removed') as removed
FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)}
WHERE aspect = 'status' AND version = 0
) as sd ON sd.urn = mav.urn
WHERE sd.removed = true
ORDER BY mav.urn
"""

@property
def query(self) -> str:
# May repeat rows for the same date
@@ -101,66 +109,117 @@ def query(self) -> str:
# Relies on createdon order to reflect version order
# Ordering of entries with the same createdon is handled by VersionOrderer
return f"""
SELECT urn, aspect, metadata, systemmetadata, createdon, version
FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)}
WHERE createdon >= %(since_createdon)s
{"" if self.config.include_all_versions else "AND version = 0"}
ORDER BY createdon, urn, aspect, version
LIMIT %(limit)s
OFFSET %(offset)s
SELECT *
FROM (
SELECT
mav.urn,
mav.aspect,
mav.metadata,
mav.systemmetadata,
mav.createdon,
mav.version,
removed
FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)} as mav
LEFT JOIN (
SELECT
*,
JSON_EXTRACT(metadata, '$.removed') as removed
FROM {self.engine.dialect.identifier_preparer.quote(self.config.database_table_name)}
WHERE aspect = 'status'
AND version = 0
) as sd ON sd.urn = mav.urn
WHERE 1 = 1
{"" if self.config.include_all_versions else "AND mav.version = 0"}
{"" if not self.config.exclude_aspects else "AND mav.aspect NOT IN %(exclude_aspects)s"}
AND mav.createdon >= %(since_createdon)s
ORDER BY
createdon,
urn,
aspect,
version
) as t
WHERE 1=1
{"" if self.config.include_soft_deleted_entities else "AND (removed = false or removed is NULL)"}
ORDER BY
createdon,
urn,
aspect,
version
"""

def get_aspects(
self, from_createdon: datetime, stop_time: datetime
) -> Iterable[Tuple[MetadataChangeProposalWrapper, datetime]]:
orderer = VersionOrderer[Row](enabled=self.config.include_all_versions)
orderer = VersionOrderer[Dict[str, Any]](
enabled=self.config.include_all_versions
)
rows = self._get_rows(from_createdon=from_createdon, stop_time=stop_time)
for row in orderer(rows):
mcp = self._parse_row(row)
if mcp:
yield mcp, row.createdon
yield mcp, row["createdon"]

def _get_rows(self, from_createdon: datetime, stop_time: datetime) -> Iterable[Row]:
def _get_rows(
self, from_createdon: datetime, stop_time: datetime
) -> Iterable[Dict[str, Any]]:
with self.engine.connect() as conn:
ts = from_createdon
offset = 0
while ts.timestamp() <= stop_time.timestamp():
logger.debug(f"Polling database aspects from {ts}")
rows = conn.execute(
with contextlib.closing(conn.connection.cursor()) as cursor:
cursor.execute(
self.query,
since_createdon=ts.strftime(DATETIME_FORMAT),
limit=self.config.database_query_batch_size,
offset=offset,
{
"exclude_aspects": list(self.config.exclude_aspects),
"since_createdon": from_createdon.strftime(DATETIME_FORMAT),
},
)
if not rows.rowcount:
return

for i, row in enumerate(rows):
yield row
columns = [desc[0] for desc in cursor.description]
while True:
rows = cursor.fetchmany(self.config.database_query_batch_size)
if not rows:
return
for row in rows:
yield dict(zip(columns, row))

if ts == row.createdon:
offset += i + 1
else:
ts = row.createdon
offset = 0
def get_soft_deleted_rows(self) -> Iterable[Dict[str, Any]]:
"""
Fetches all soft-deleted entities from the database.

def _parse_row(self, row: Row) -> Optional[MetadataChangeProposalWrapper]:
Yields:
Row dicts (column name to value) containing URNs of soft-deleted entities
"""
with self.engine.connect() as conn:
with contextlib.closing(conn.connection.cursor()) as cursor:
logger.debug("Polling soft-deleted urns from database")
cursor.execute(self.soft_deleted_urns_query)
columns = [desc[0] for desc in cursor.description]
while True:
rows = cursor.fetchmany(self.config.database_query_batch_size)
if not rows:
return
for row in rows:
yield dict(zip(columns, row))

def _parse_row(
self, row: Dict[str, Any]
) -> Optional[MetadataChangeProposalWrapper]:
try:
json_aspect = post_json_transform(json.loads(row.metadata))
json_metadata = post_json_transform(json.loads(row.systemmetadata or "{}"))
json_aspect = post_json_transform(json.loads(row["metadata"]))
json_metadata = post_json_transform(
json.loads(row["systemmetadata"] or "{}")
)
system_metadata = SystemMetadataClass.from_obj(json_metadata)
return MetadataChangeProposalWrapper(
entityUrn=row.urn,
aspect=ASPECT_MAP[row.aspect].from_obj(json_aspect),
entityUrn=row["urn"],
aspect=ASPECT_MAP[row["aspect"]].from_obj(json_aspect),
systemMetadata=system_metadata,
changeType=ChangeTypeClass.UPSERT,
)
except Exception as e:
logger.warning(
f"Failed to parse metadata for {row.urn}: {e}", exc_info=True
f'Failed to parse metadata for {row["urn"]}: {e}', exc_info=True
)
self.report.num_database_parse_errors += 1
self.report.database_parse_errors.setdefault(
str(e), LossyDict()
).setdefault(row.aspect, LossyList()).append(row.urn)
).setdefault(row["aspect"], LossyList()).append(row["urn"])
return None
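
A detail of this file that is easy to miss: rows are now plain dicts, and VersionOrderer still guarantees that, for rows sharing a createdon timestamp, the version-0 row (which holds the latest state of an aspect) is emitted after the explicitly versioned rows, since the SQL only orders by createdon. A small usage sketch, under the assumptions that the class lives in datahub.ingestion.source.datahub.datahub_database_reader and that it flushes its internal queue once the input is exhausted:

from datetime import datetime
from typing import Any, Dict

# Assumed import path for the class shown in this diff.
from datahub.ingestion.source.datahub.datahub_database_reader import VersionOrderer

rows = [
    # Same createdon: the version-0 row is the latest state of the aspect,
    # so the orderer defers it until after the explicitly versioned row.
    {"urn": "urn:li:corpuser:jdoe", "aspect": "corpUserInfo",
     "createdon": datetime(2024, 10, 1, 12, 0), "version": 0},
    {"urn": "urn:li:corpuser:jdoe", "aspect": "corpUserInfo",
     "createdon": datetime(2024, 10, 1, 12, 0), "version": 1},
]

orderer = VersionOrderer[Dict[str, Any]](enabled=True)
print([row["version"] for row in orderer(rows)])  # expected: [1, 0]
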
@@ -36,6 +36,7 @@ def __init__(
self.connection_config = connection_config
self.report = report
self.group_id = f"{KAFKA_GROUP_PREFIX}-{ctx.pipeline_name}"
self.ctx = ctx

def __enter__(self) -> "DataHubKafkaReader":
self.consumer = DeserializingConsumer(
@@ -95,6 +96,10 @@ def _poll_partition(
)
break

if mcl.aspectName and mcl.aspectName in self.config.exclude_aspects:
self.report.num_kafka_excluded_aspects += 1
continue

# TODO: Consider storing state in kafka instead, via consumer.commit()
yield mcl, PartitionOffset(partition=msg.partition(), offset=msg.offset())

@@ -62,13 +62,18 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
self.report.stop_time = datetime.now(tz=timezone.utc)
logger.info(f"Ingesting DataHub metadata up until {self.report.stop_time}")
state = self.stateful_ingestion_handler.get_last_run_state()
database_reader: Optional[DataHubDatabaseReader] = None

if self.config.pull_from_datahub_api:
yield from self._get_api_workunits()

if self.config.database_connection is not None:
database_reader = DataHubDatabaseReader(
self.config, self.config.database_connection, self.report
)

yield from self._get_database_workunits(
from_createdon=state.database_createdon_datetime
from_createdon=state.database_createdon_datetime, reader=database_reader
)
self._commit_progress()
else:
@@ -77,23 +82,29 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]:
)

if self.config.kafka_connection is not None:
yield from self._get_kafka_workunits(from_offsets=state.kafka_offsets)
soft_deleted_urns = []
if not self.config.include_soft_deleted_entities:
if database_reader is None:
raise ValueError(
"Cannot exclude soft deleted entities without a database connection"
)
soft_deleted_urns = [
row["urn"] for row in database_reader.get_soft_deleted_rows()
]

yield from self._get_kafka_workunits(
from_offsets=state.kafka_offsets, soft_deleted_urns=soft_deleted_urns
)
self._commit_progress()
else:
logger.info(
"Skipping ingestion of timeseries aspects as no kafka_connection provided"
)

def _get_database_workunits(
self, from_createdon: datetime
self, from_createdon: datetime, reader: DataHubDatabaseReader
) -> Iterable[MetadataWorkUnit]:
if self.config.database_connection is None:
return

logger.info(f"Fetching database aspects starting from {from_createdon}")
reader = DataHubDatabaseReader(
self.config, self.config.database_connection, self.report
)
mcps = reader.get_aspects(from_createdon, self.report.stop_time)
for i, (mcp, createdon) in enumerate(mcps):

@@ -113,20 +124,29 @@ def _get_database_workunits(
self._commit_progress(i)

def _get_kafka_workunits(
self, from_offsets: Dict[int, int]
self, from_offsets: Dict[int, int], soft_deleted_urns: List[str] = []
) -> Iterable[MetadataWorkUnit]:
if self.config.kafka_connection is None:
return

logger.info("Fetching timeseries aspects from kafka")
with DataHubKafkaReader(
self.config, self.config.kafka_connection, self.report, self.ctx
self.config,
self.config.kafka_connection,
self.report,
self.ctx,
) as reader:
mcls = reader.get_mcls(
from_offsets=from_offsets, stop_time=self.report.stop_time
)
for i, (mcl, offset) in enumerate(mcls):
mcp = MetadataChangeProposalWrapper.try_from_mcl(mcl)
if mcp.entityUrn in soft_deleted_urns:
self.report.num_timeseries_soft_deleted_aspects_dropped += 1
logger.debug(
f"Dropping soft-deleted aspect of {mcp.aspectName} on {mcp.entityUrn}"
)
continue
if mcp.changeType == ChangeTypeClass.DELETE:
self.report.num_timeseries_deletions_dropped += 1
logger.debug(
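
Putting the new options together, the relevant part of a "datahub" source recipe config might look like the dict below. This is a hypothetical illustration: only the fields introduced or exercised by this PR are spelled out, the aspect names are examples, and the connection settings are placeholders. Note that, per the ValueError added in get_workunits_internal, excluding soft-deleted entities requires a database connection when a kafka_connection is configured, since the Kafka reader relies on the database to learn which URNs are soft deleted.

# Hypothetical source.config for the "datahub" source, expressed as a Python dict.
datahub_source_config = {
    "include_soft_deleted_entities": False,   # drop soft-deleted URNs from both the DB and Kafka paths
    "exclude_aspects": ["datasetProfile", "operation"],  # illustrative aspect names
    "database_connection": {"...": "..."},    # placeholder; needed to resolve soft-deleted URNs
    "kafka_connection": {"...": "..."},       # placeholder
}
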