diff --git a/metadata-ingestion/docs/dev_guides/add_stateful_ingestion_to_source.md b/metadata-ingestion/docs/dev_guides/add_stateful_ingestion_to_source.md
index 0e3ead9a7adf8..6a1204fb0f2b3 100644
--- a/metadata-ingestion/docs/dev_guides/add_stateful_ingestion_to_source.md
+++ b/metadata-ingestion/docs/dev_guides/add_stateful_ingestion_to_source.md
@@ -252,7 +252,7 @@ Example code:
     def get_workunits(self) -> Iterable[MetadataWorkUnit]:
         # Skip a redundant run
         if self.redundant_run_skip_handler.should_skip_this_run(
-            cur_start_time_millis=datetime_to_ts_millis(self.config.start_time)
+            cur_start_time_millis=self.config.start_time
         ):
             return
 
@@ -260,7 +260,7 @@ Example code:
         #
         # Update checkpoint state for this run.
         self.redundant_run_skip_handler.update_state(
-            start_time_millis=datetime_to_ts_millis(self.config.start_time),
-            end_time_millis=datetime_to_ts_millis(self.config.end_time),
+            start_time_millis=self.config.start_time,
+            end_time_millis=self.config.end_time,
         )
 ```
\ No newline at end of file
diff --git a/metadata-ingestion/src/datahub/configuration/time_window_config.py b/metadata-ingestion/src/datahub/configuration/time_window_config.py
index 1bf992952759b..15de7470e4d82 100644
--- a/metadata-ingestion/src/datahub/configuration/time_window_config.py
+++ b/metadata-ingestion/src/datahub/configuration/time_window_config.py
@@ -65,11 +65,15 @@ def default_start_time(
         assert delta < timedelta(
             0
         ), "Relative start time should start with minus sign (-) e.g. '-2 days'."
-        assert abs(delta) > get_bucket_duration_delta(
+        assert abs(delta) >= get_bucket_duration_delta(
            values["bucket_duration"]
         ), "Relative start time should be in terms of configured bucket duration. e.g '-2 days' or '-2 hours'."
-        return values["end_time"] + delta
+        return get_time_bucket(
+            values["end_time"] + delta, values["bucket_duration"]
+        )
    except humanfriendly.InvalidTimespan:
+        # We do not floor start_time to the bucket start time if absolute start time is specified.
+        # If user has specified absolute start time in recipe, it's most likely that he means it.
return parse_absolute_time(v) return v diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py index 1446812c29216..7690723837165 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery.py @@ -74,7 +74,8 @@ ) from datahub.ingestion.source.state.profiling_state_handler import ProfilingHandler from datahub.ingestion.source.state.redundant_run_skip_handler import ( - RedundantRunSkipHandler, + RedundantLineageRunSkipHandler, + RedundantUsageRunSkipHandler, ) from datahub.ingestion.source.state.stale_entity_removal_handler import ( StaleEntityRemovalHandler, @@ -82,6 +83,11 @@ from datahub.ingestion.source.state.stateful_ingestion_base import ( StatefulIngestionSourceBase, ) +from datahub.ingestion.source_report.ingestion_stage import ( + LINEAGE_EXTRACTION, + METADATA_EXTRACTION, + PROFILING, +) from datahub.metadata.com.linkedin.pegasus2avro.common import ( Status, SubTypes, @@ -122,7 +128,6 @@ from datahub.utilities.perf_timer import PerfTimer from datahub.utilities.registries.domain_registry import DomainRegistry from datahub.utilities.sqlglot_lineage import SchemaResolver, sqlglot_lineage -from datahub.utilities.time import datetime_to_ts_millis logger: logging.Logger = logging.getLogger(__name__) @@ -228,10 +233,36 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config): set_dataset_urn_to_lower(self.config.convert_urns_to_lowercase) + self.redundant_lineage_run_skip_handler: Optional[ + RedundantLineageRunSkipHandler + ] = None + if self.config.enable_stateful_lineage_ingestion: + self.redundant_lineage_run_skip_handler = RedundantLineageRunSkipHandler( + source=self, + config=self.config, + pipeline_name=self.ctx.pipeline_name, + run_id=self.ctx.run_id, + ) + # For database, schema, tables, views, etc - self.lineage_extractor = BigqueryLineageExtractor(config, self.report) + self.lineage_extractor = BigqueryLineageExtractor( + config, self.report, self.redundant_lineage_run_skip_handler + ) + + redundant_usage_run_skip_handler: Optional[RedundantUsageRunSkipHandler] = None + if self.config.enable_stateful_usage_ingestion: + redundant_usage_run_skip_handler = RedundantUsageRunSkipHandler( + source=self, + config=self.config, + pipeline_name=self.ctx.pipeline_name, + run_id=self.ctx.run_id, + ) + self.usage_extractor = BigQueryUsageExtractor( - config, self.report, dataset_urn_builder=self.gen_dataset_urn_from_ref + config, + self.report, + dataset_urn_builder=self.gen_dataset_urn_from_ref, + redundant_run_skip_handler=redundant_usage_run_skip_handler, ) self.domain_registry: Optional[DomainRegistry] = None @@ -240,15 +271,8 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config): cached_domains=[k for k in self.config.domain], graph=self.ctx.graph ) - self.redundant_run_skip_handler = RedundantRunSkipHandler( - source=self, - config=self.config, - pipeline_name=self.ctx.pipeline_name, - run_id=self.ctx.run_id, - ) - self.profiling_state_handler: Optional[ProfilingHandler] = None - if self.config.store_last_profiling_timestamps: + if self.config.enable_stateful_profiling: self.profiling_state_handler = ProfilingHandler( source=self, config=self.config, @@ -271,7 +295,7 @@ def __init__(self, ctx: PipelineContext, config: BigQueryV2Config): self.sql_parser_schema_resolver = SchemaResolver( platform=self.platform, env=self.config.env ) - + 
self.add_config_to_report() atexit.register(cleanup, config) @classmethod @@ -502,68 +526,50 @@ def get_workunit_processors(self) -> List[Optional[MetadataWorkUnitProcessor]]: def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: conn: bigquery.Client = get_bigquery_client(self.config) - self.add_config_to_report() projects = self._get_projects(conn) if not projects: return for project_id in projects: - self.report.set_ingestion_stage(project_id.id, "Metadata Extraction") + self.report.set_ingestion_stage(project_id.id, METADATA_EXTRACTION) logger.info(f"Processing project: {project_id.id}") yield from self._process_project(conn, project_id) - if self._should_ingest_usage(): + if self.config.include_usage_statistics: yield from self.usage_extractor.get_usage_workunits( [p.id for p in projects], self.table_refs ) if self._should_ingest_lineage(): for project in projects: - self.report.set_ingestion_stage(project.id, "Lineage Extraction") + self.report.set_ingestion_stage(project.id, LINEAGE_EXTRACTION) yield from self.generate_lineage(project.id) - def _should_ingest_usage(self) -> bool: - if not self.config.include_usage_statistics: - return False - - if self.config.store_last_usage_extraction_timestamp: - if self.redundant_run_skip_handler.should_skip_this_run( - cur_start_time_millis=datetime_to_ts_millis(self.config.start_time) - ): - self.report.report_warning( - "usage-extraction", - f"Skip this run as there was a run later than the current start time: {self.config.start_time}", - ) - return False - else: + if self.redundant_lineage_run_skip_handler: # Update the checkpoint state for this run. - self.redundant_run_skip_handler.update_state( - start_time_millis=datetime_to_ts_millis(self.config.start_time), - end_time_millis=datetime_to_ts_millis(self.config.end_time), + self.redundant_lineage_run_skip_handler.update_state( + self.config.start_time, self.config.end_time ) - return True def _should_ingest_lineage(self) -> bool: if not self.config.include_table_lineage: return False - if self.config.store_last_lineage_extraction_timestamp: - if self.redundant_run_skip_handler.should_skip_this_run( - cur_start_time_millis=datetime_to_ts_millis(self.config.start_time) - ): - # Skip this run - self.report.report_warning( - "lineage-extraction", - f"Skip this run as there was a run later than the current start time: {self.config.start_time}", - ) - return False - else: - # Update the checkpoint state for this run. 
- self.redundant_run_skip_handler.update_state( - start_time_millis=datetime_to_ts_millis(self.config.start_time), - end_time_millis=datetime_to_ts_millis(self.config.end_time), - ) + if ( + self.redundant_lineage_run_skip_handler + and self.redundant_lineage_run_skip_handler.should_skip_this_run( + cur_start_time=self.config.start_time, + cur_end_time=self.config.end_time, + ) + ): + # Skip this run + self.report.report_warning( + "lineage-extraction", + "Skip this run as there was already a run for current ingestion window.", + ) + return False + return True def _get_projects(self, conn: bigquery.Client) -> List[BigqueryProject]: @@ -664,7 +670,7 @@ def _process_project( if self.config.is_profiling_enabled(): logger.info(f"Starting profiling project {project_id}") - self.report.set_ingestion_stage(project_id, "Profiling") + self.report.set_ingestion_stage(project_id, PROFILING) yield from self.profiler.get_workunits( project_id=project_id, tables=db_tables, @@ -1328,3 +1334,13 @@ def add_config_to_report(self): self.report.use_exported_bigquery_audit_metadata = ( self.config.use_exported_bigquery_audit_metadata ) + self.report.stateful_lineage_ingestion_enabled = ( + self.config.enable_stateful_lineage_ingestion + ) + self.report.stateful_usage_ingestion_enabled = ( + self.config.enable_stateful_usage_ingestion + ) + self.report.window_start_time, self.report.window_end_time = ( + self.config.start_time, + self.config.end_time, + ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py index b57e691411f75..8c46d8f675259 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/bigquery_report.py @@ -2,21 +2,22 @@ import dataclasses import logging from dataclasses import dataclass, field -from datetime import datetime, timezone +from datetime import datetime from typing import Counter, Dict, List, Optional import pydantic from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport +from datahub.ingestion.source_report.ingestion_stage import IngestionStageReport +from datahub.ingestion.source_report.time_window import BaseTimeWindowReport from datahub.utilities.lossy_collections import LossyDict, LossyList -from datahub.utilities.perf_timer import PerfTimer from datahub.utilities.stats_collections import TopKDict, int_top_k_dict logger: logging.Logger = logging.getLogger(__name__) @dataclass -class BigQueryV2Report(ProfilingSqlReport): +class BigQueryV2Report(ProfilingSqlReport, IngestionStageReport, BaseTimeWindowReport): num_total_lineage_entries: TopKDict[str, int] = field(default_factory=TopKDict) num_skipped_lineage_entries_missing_data: TopKDict[str, int] = field( default_factory=int_top_k_dict @@ -52,7 +53,6 @@ class BigQueryV2Report(ProfilingSqlReport): use_date_sharded_audit_log_tables: Optional[bool] = None log_page_size: Optional[pydantic.PositiveInt] = None use_exported_bigquery_audit_metadata: Optional[bool] = None - end_time: Optional[datetime] = None log_entry_start_time: Optional[str] = None log_entry_end_time: Optional[str] = None audit_start_time: Optional[str] = None @@ -88,23 +88,14 @@ class BigQueryV2Report(ProfilingSqlReport): default_factory=collections.Counter ) usage_state_size: Optional[str] = None - ingestion_stage: Optional[str] = None - ingestion_stage_durations: TopKDict[str, float] = field(default_factory=TopKDict) - _timer: 
Optional[PerfTimer] = field( - default=None, init=False, repr=False, compare=False - ) + lineage_start_time: Optional[datetime] = None + lineage_end_time: Optional[datetime] = None + stateful_lineage_ingestion_enabled: bool = False - def set_ingestion_stage(self, project: str, stage: str) -> None: - if self._timer: - elapsed = round(self._timer.elapsed_seconds(), 2) - logger.info( - f"Time spent in stage <{self.ingestion_stage}>: {elapsed} seconds" - ) - if self.ingestion_stage: - self.ingestion_stage_durations[self.ingestion_stage] = elapsed - else: - self._timer = PerfTimer() + usage_start_time: Optional[datetime] = None + usage_end_time: Optional[datetime] = None + stateful_usage_ingestion_enabled: bool = False - self.ingestion_stage = f"{project}: {stage} at {datetime.now(timezone.utc)}" - self._timer.start() + def set_ingestion_stage(self, project_id: str, stage: str) -> None: + self.report_ingestion_stage_start(f"{project_id}: {stage}") diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/lineage.py index 255a673026252..842e3d2144600 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/lineage.py @@ -4,7 +4,18 @@ import textwrap from dataclasses import dataclass from datetime import datetime, timezone -from typing import Any, Callable, Dict, FrozenSet, Iterable, List, Optional, Set, Union +from typing import ( + Any, + Callable, + Dict, + FrozenSet, + Iterable, + List, + Optional, + Set, + Tuple, + Union, +) import humanfriendly from google.cloud.bigquery import Client as BigQueryClient @@ -29,6 +40,9 @@ _make_gcp_logging_client, get_bigquery_client, ) +from datahub.ingestion.source.state.redundant_run_skip_handler import ( + RedundantLineageRunSkipHandler, +) from datahub.metadata.schema_classes import ( AuditStampClass, DatasetLineageTypeClass, @@ -133,7 +147,6 @@ def _follow_column_lineage( def make_lineage_edges_from_parsing_result( sql_lineage: SqlParsingResult, audit_stamp: datetime, lineage_type: str ) -> List[LineageEdge]: - # Note: This ignores the out_tables section of the sql parsing result. audit_stamp = datetime.now(timezone.utc) @@ -215,10 +228,29 @@ class BigqueryLineageExtractor: timestamp < "{end_time}" """.strip() - def __init__(self, config: BigQueryV2Config, report: BigQueryV2Report): + def __init__( + self, + config: BigQueryV2Config, + report: BigQueryV2Report, + redundant_run_skip_handler: Optional[RedundantLineageRunSkipHandler] = None, + ): self.config = config self.report = report + self.redundant_run_skip_handler = redundant_run_skip_handler + self.start_time, self.end_time = ( + self.report.lineage_start_time, + self.report.lineage_end_time, + ) = self.get_time_window() + + def get_time_window(self) -> Tuple[datetime, datetime]: + if self.redundant_run_skip_handler: + return self.redundant_run_skip_handler.suggest_run_time_window( + self.config.start_time, self.config.end_time + ) + else: + return self.config.start_time, self.config.end_time + def error(self, log: logging.Logger, key: str, reason: str) -> None: self.report.report_warning(key, reason) log.error(f"{key} => {reason}") @@ -406,7 +438,7 @@ def _get_bigquery_log_entries( ) -> Iterable[AuditLogEntry]: self.report.num_total_log_entries[client.project] = 0 # Add a buffer to start and end time to account for delays in logging events. 
- start_time = (self.config.start_time - self.config.max_query_duration).strftime( + start_time = (self.start_time - self.config.max_query_duration).strftime( BQ_DATETIME_FORMAT ) self.report.log_entry_start_time = start_time @@ -462,12 +494,12 @@ def _get_exported_bigquery_audit_metadata( self.report.bigquery_audit_metadata_datasets_missing = True return - corrected_start_time = self.config.start_time - self.config.max_query_duration + corrected_start_time = self.start_time - self.config.max_query_duration start_time = corrected_start_time.strftime(BQ_DATETIME_FORMAT) start_date = corrected_start_time.strftime(BQ_DATE_SHARD_FORMAT) self.report.audit_start_time = start_time - corrected_end_time = self.config.end_time + self.config.max_query_duration + corrected_end_time = self.end_time + self.config.max_query_duration end_time = corrected_end_time.strftime(BQ_DATETIME_FORMAT) end_date = corrected_end_time.strftime(BQ_DATE_SHARD_FORMAT) self.report.audit_end_time = end_time @@ -663,6 +695,7 @@ def _compute_bigquery_lineage( "lineage", f"{project_id}: {e}", ) + self.report_status(f"{project_id}-lineage", False) lineage_metadata = {} self.report.lineage_mem_size[project_id] = humanfriendly.format_size( @@ -832,3 +865,7 @@ def test_capability(self, project_id: str) -> None: ) for entry in self._get_bigquery_log_entries(gcp_logging_client, limit=1): logger.debug(f"Connection test got one audit metadata entry {entry}") + + def report_status(self, step: str, status: bool) -> None: + if self.redundant_run_skip_handler: + self.redundant_run_skip_handler.report_current_run_status(step, status) diff --git a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/usage.py b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/usage.py index 1081dd8eec1ec..fe7ab8c49c79a 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/bigquery_v2/usage.py @@ -25,7 +25,10 @@ from google.cloud.logging_v2.client import Client as GCPLoggingClient from ratelimiter import RateLimiter -from datahub.configuration.time_window_config import get_time_bucket +from datahub.configuration.time_window_config import ( + BaseTimeWindowConfig, + get_time_bucket, +) from datahub.emitter.mce_builder import make_user_urn from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.closeable import Closeable @@ -50,10 +53,18 @@ _make_gcp_logging_client, get_bigquery_client, ) +from datahub.ingestion.source.state.redundant_run_skip_handler import ( + RedundantUsageRunSkipHandler, +) from datahub.ingestion.source.usage.usage_common import ( TOTAL_BUDGET_FOR_QUERY_LIST, make_usage_workunit, ) +from datahub.ingestion.source_report.ingestion_stage import ( + USAGE_EXTRACTION_INGESTION, + USAGE_EXTRACTION_OPERATIONAL_STATS, + USAGE_EXTRACTION_USAGE_AGGREGATION, +) from datahub.metadata.schema_classes import OperationClass, OperationTypeClass from datahub.utilities.bigquery_sql_parser import BigQuerySQLParser from datahub.utilities.file_backed_collections import ConnectionWrapper, FileBackedDict @@ -377,6 +388,7 @@ def __init__( config: BigQueryV2Config, report: BigQueryV2Report, dataset_urn_builder: Callable[[BigQueryTableRef], str], + redundant_run_skip_handler: Optional[RedundantUsageRunSkipHandler] = None, ): self.config: BigQueryV2Config = config self.report: BigQueryV2Report = report @@ -384,6 +396,20 @@ def __init__( # Replace hash of query with uuid if there are hash conflicts self.uuid_to_query: Dict[str, str] = {} 
+ self.redundant_run_skip_handler = redundant_run_skip_handler + self.start_time, self.end_time = ( + self.report.usage_start_time, + self.report.usage_end_time, + ) = self.get_time_window() + + def get_time_window(self) -> Tuple[datetime, datetime]: + if self.redundant_run_skip_handler: + return self.redundant_run_skip_handler.suggest_run_time_window( + self.config.start_time, self.config.end_time + ) + else: + return self.config.start_time, self.config.end_time + def _is_table_allowed(self, table_ref: Optional[BigQueryTableRef]) -> bool: return ( table_ref is not None @@ -391,12 +417,39 @@ def _is_table_allowed(self, table_ref: Optional[BigQueryTableRef]) -> bool: and self.config.table_pattern.allowed(table_ref.table_identifier.table) ) + def _should_ingest_usage(self) -> bool: + if ( + self.redundant_run_skip_handler + and self.redundant_run_skip_handler.should_skip_this_run( + cur_start_time=self.config.start_time, + cur_end_time=self.config.end_time, + ) + ): + # Skip this run + self.report.report_warning( + "usage-extraction", + "Skip this run as there was already a run for current ingestion window.", + ) + return False + + return True + def get_usage_workunits( self, projects: Iterable[str], table_refs: Collection[str] ) -> Iterable[MetadataWorkUnit]: + if not self._should_ingest_usage(): + return events = self._get_usage_events(projects) yield from self._get_workunits_internal(events, table_refs) + if self.redundant_run_skip_handler: + # Update the checkpoint state for this run. + self.redundant_run_skip_handler.update_state( + self.config.start_time, + self.config.end_time, + self.config.bucket_duration, + ) + def _get_workunits_internal( self, events: Iterable[AuditEvent], table_refs: Collection[str] ) -> Iterable[MetadataWorkUnit]: @@ -413,7 +466,11 @@ def _get_workunits_internal( yield from auto_empty_dataset_usage_statistics( self._generate_usage_workunits(usage_state), - config=self.config, + config=BaseTimeWindowConfig( + start_time=self.start_time, + end_time=self.end_time, + bucket_duration=self.config.bucket_duration, + ), dataset_urns={ self.dataset_urn_builder(BigQueryTableRef.from_string_name(ref)) for ref in table_refs @@ -423,6 +480,7 @@ def _get_workunits_internal( except Exception as e: logger.error("Error processing usage", exc_info=True) self.report.report_warning("usage-ingestion", str(e)) + self.report_status("usage-ingestion", False) def generate_read_events_from_query( self, query_event_on_view: QueryEvent @@ -496,7 +554,7 @@ def _ingest_events( def _generate_operational_workunits( self, usage_state: BigQueryUsageState, table_refs: Collection[str] ) -> Iterable[MetadataWorkUnit]: - self.report.set_ingestion_stage("*", "Usage Extraction Operational Stats") + self.report.set_ingestion_stage("*", USAGE_EXTRACTION_OPERATIONAL_STATS) for audit_event in usage_state.standalone_events(): try: operational_wu = self._create_operation_workunit( @@ -515,7 +573,7 @@ def _generate_operational_workunits( def _generate_usage_workunits( self, usage_state: BigQueryUsageState ) -> Iterable[MetadataWorkUnit]: - self.report.set_ingestion_stage("*", "Usage Extraction Usage Aggregation") + self.report.set_ingestion_stage("*", USAGE_EXTRACTION_USAGE_AGGREGATION) top_n = ( self.config.usage.top_n_queries if self.config.usage.include_top_n_queries @@ -560,7 +618,7 @@ def _get_usage_events(self, projects: Iterable[str]) -> Iterable[AuditEvent]: with PerfTimer() as timer: try: self.report.set_ingestion_stage( - project_id, "Usage Extraction Ingestion" + project_id, 
USAGE_EXTRACTION_INGESTION ) yield from self._get_parsed_bigquery_log_events(project_id) except Exception as e: @@ -570,6 +628,7 @@ def _get_usage_events(self, projects: Iterable[str]) -> Iterable[AuditEvent]: ) self.report.usage_failed_extraction.append(project_id) self.report.report_warning(f"usage-extraction-{project_id}", str(e)) + self.report_status(f"usage-extraction-{project_id}", False) self.report.usage_extraction_sec[project_id] = round( timer.elapsed_seconds(), 2 @@ -583,7 +642,7 @@ def _store_usage_event( ) -> bool: """Stores a usage event in `usage_state` and returns if an event was successfully processed.""" if event.read_event and ( - self.config.start_time <= event.read_event.timestamp < self.config.end_time + self.start_time <= event.read_event.timestamp < self.end_time ): resource = event.read_event.resource if str(resource) not in table_refs: @@ -623,14 +682,15 @@ def _get_exported_bigquery_audit_metadata( limit: Optional[int] = None, ) -> Iterable[BigQueryAuditMetadata]: if self.config.bigquery_audit_metadata_datasets is None: + self.report.bigquery_audit_metadata_datasets_missing = True return - corrected_start_time = self.config.start_time - self.config.max_query_duration + corrected_start_time = self.start_time - self.config.max_query_duration start_time = corrected_start_time.strftime(BQ_DATETIME_FORMAT) start_date = corrected_start_time.strftime(BQ_DATE_SHARD_FORMAT) self.report.audit_start_time = start_time - corrected_end_time = self.config.end_time + self.config.max_query_duration + corrected_end_time = self.end_time + self.config.max_query_duration end_time = corrected_end_time.strftime(BQ_DATETIME_FORMAT) end_date = corrected_end_time.strftime(BQ_DATE_SHARD_FORMAT) self.report.audit_end_time = end_time @@ -664,7 +724,6 @@ def _get_exported_bigquery_audit_metadata( def _get_bigquery_log_entries_via_gcp_logging( self, client: GCPLoggingClient, limit: Optional[int] = None ) -> Iterable[AuditLogEntry]: - filter = self._generate_filter(BQ_AUDIT_V2) logger.debug(filter) @@ -707,11 +766,11 @@ def _generate_filter(self, audit_templates: Dict[str, str]) -> str: # handle the case where the read happens within our time range but the query # completion event is delayed and happens after the configured end time. 
- start_time = (self.config.start_time - self.config.max_query_duration).strftime( + start_time = (self.start_time - self.config.max_query_duration).strftime( BQ_DATETIME_FORMAT ) self.report.log_entry_start_time = start_time - end_time = (self.config.end_time + self.config.max_query_duration).strftime( + end_time = (self.end_time + self.config.max_query_duration).strftime( BQ_DATETIME_FORMAT ) self.report.log_entry_end_time = end_time @@ -1046,3 +1105,7 @@ def test_capability(self, project_id: str) -> None: for entry in self._get_parsed_bigquery_log_events(project_id, limit=1): logger.debug(f"Connection test got one {entry}") return + + def report_status(self, step: str, status: bool) -> None: + if self.redundant_run_skip_handler: + self.redundant_run_skip_handler.report_current_run_status(step, status) diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py index 268de5832559a..c8623798f6937 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/lineage.py @@ -2,6 +2,7 @@ import traceback from collections import defaultdict from dataclasses import dataclass, field +from datetime import datetime from enum import Enum from typing import Dict, List, Optional, Set, Tuple, Union from urllib.parse import urlparse @@ -24,6 +25,9 @@ RedshiftView, ) from datahub.ingestion.source.redshift.report import RedshiftReport +from datahub.ingestion.source.state.redundant_run_skip_handler import ( + RedundantLineageRunSkipHandler, +) from datahub.metadata.com.linkedin.pegasus2avro.dataset import UpstreamLineage from datahub.metadata.schema_classes import ( DatasetLineageTypeClass, @@ -79,11 +83,27 @@ def __init__( self, config: RedshiftConfig, report: RedshiftReport, + redundant_run_skip_handler: Optional[RedundantLineageRunSkipHandler] = None, ): self.config = config self.report = report self._lineage_map: Dict[str, LineageItem] = defaultdict() + self.redundant_run_skip_handler = redundant_run_skip_handler + self.start_time, self.end_time = ( + self.report.lineage_start_time, + self.report.lineage_end_time, + ) = self.get_time_window() + + def get_time_window(self) -> Tuple[datetime, datetime]: + if self.redundant_run_skip_handler: + self.report.stateful_lineage_ingestion_enabled = True + return self.redundant_run_skip_handler.suggest_run_time_window( + self.config.start_time, self.config.end_time + ) + else: + return self.config.start_time, self.config.end_time + def warn(self, log: logging.Logger, key: str, reason: str) -> None: self.report.report_warning(key, reason) log.warning(f"{key} => {reason}") @@ -263,6 +283,7 @@ def _populate_lineage_map( f"extract-{lineage_type.name}", f"Error was {e}, {traceback.format_exc()}", ) + self.report_status(f"extract-{lineage_type.name}", False) def _get_target_lineage( self, @@ -352,24 +373,24 @@ def populate_lineage( # Populate table level lineage by parsing table creating sqls query = RedshiftQuery.list_insert_create_queries_sql( db_name=database, - start_time=self.config.start_time, - end_time=self.config.end_time, + start_time=self.start_time, + end_time=self.end_time, ) populate_calls.append((query, LineageCollectorType.QUERY_SQL_PARSER)) elif self.config.table_lineage_mode == LineageMode.MIXED: # Populate table level lineage by parsing table creating sqls query = RedshiftQuery.list_insert_create_queries_sql( db_name=database, - start_time=self.config.start_time, - 
end_time=self.config.end_time, + start_time=self.start_time, + end_time=self.end_time, ) populate_calls.append((query, LineageCollectorType.QUERY_SQL_PARSER)) # Populate table level lineage by getting upstream tables from stl_scan redshift table query = RedshiftQuery.stl_scan_based_lineage_query( db_name=database, - start_time=self.config.start_time, - end_time=self.config.end_time, + start_time=self.start_time, + end_time=self.end_time, ) populate_calls.append((query, LineageCollectorType.QUERY_SCAN)) @@ -385,16 +406,16 @@ def populate_lineage( if self.config.include_copy_lineage: query = RedshiftQuery.list_copy_commands_sql( db_name=database, - start_time=self.config.start_time, - end_time=self.config.end_time, + start_time=self.start_time, + end_time=self.end_time, ) populate_calls.append((query, LineageCollectorType.COPY)) if self.config.include_unload_lineage: query = RedshiftQuery.list_unload_commands_sql( db_name=database, - start_time=self.config.start_time, - end_time=self.config.end_time, + start_time=self.start_time, + end_time=self.end_time, ) populate_calls.append((query, LineageCollectorType.UNLOAD)) @@ -469,3 +490,7 @@ def get_lineage( return None return UpstreamLineage(upstreams=upstream_lineage), {} + + def report_status(self, step: str, status: bool) -> None: + if self.redundant_run_skip_handler: + self.redundant_run_skip_handler.report_current_run_status(step, status) diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py index 29f0808a6ca7d..e8a8ff976afa6 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/redshift.py @@ -63,7 +63,8 @@ ) from datahub.ingestion.source.state.profiling_state_handler import ProfilingHandler from datahub.ingestion.source.state.redundant_run_skip_handler import ( - RedundantRunSkipHandler, + RedundantLineageRunSkipHandler, + RedundantUsageRunSkipHandler, ) from datahub.ingestion.source.state.stale_entity_removal_handler import ( StaleEntityRemovalHandler, @@ -71,6 +72,11 @@ from datahub.ingestion.source.state.stateful_ingestion_base import ( StatefulIngestionSourceBase, ) +from datahub.ingestion.source_report.ingestion_stage import ( + LINEAGE_EXTRACTION, + METADATA_EXTRACTION, + PROFILING, +) from datahub.metadata.com.linkedin.pegasus2avro.common import SubTypes, TimeStamp from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( DatasetProperties, @@ -95,7 +101,6 @@ from datahub.utilities.mapping import Constants from datahub.utilities.perf_timer import PerfTimer from datahub.utilities.registries.domain_registry import DomainRegistry -from datahub.utilities.time import datetime_to_ts_millis logger: logging.Logger = logging.getLogger(__name__) @@ -297,15 +302,19 @@ def __init__(self, config: RedshiftConfig, ctx: PipelineContext): cached_domains=list(self.config.domain.keys()), graph=self.ctx.graph ) - self.redundant_run_skip_handler = RedundantRunSkipHandler( - source=self, - config=self.config, - pipeline_name=self.ctx.pipeline_name, - run_id=self.ctx.run_id, - ) + self.redundant_lineage_run_skip_handler: Optional[ + RedundantLineageRunSkipHandler + ] = None + if self.config.enable_stateful_lineage_ingestion: + self.redundant_lineage_run_skip_handler = RedundantLineageRunSkipHandler( + source=self, + config=self.config, + pipeline_name=self.ctx.pipeline_name, + run_id=self.ctx.run_id, + ) self.profiling_state_handler: Optional[ProfilingHandler] = 
None - if self.config.store_last_profiling_timestamps: + if self.config.enable_stateful_profiling: self.profiling_state_handler = ProfilingHandler( source=self, config=self.config, @@ -317,6 +326,8 @@ def __init__(self, config: RedshiftConfig, ctx: PipelineContext): self.db_views: Dict[str, Dict[str, List[RedshiftView]]] = {} self.db_schemas: Dict[str, Dict[str, RedshiftSchema]] = {} + self.add_config_to_report() + @classmethod def create(cls, config_dict, ctx): config = RedshiftConfig.parse_obj(config_dict) @@ -367,7 +378,7 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit connection = RedshiftSource.get_redshift_connection(self.config) database = get_db_name(self.config) logger.info(f"Processing db {self.config.database} with name {database}") - # self.add_config_to_report() + self.report.report_ingestion_stage_start(METADATA_EXTRACTION) self.db_tables[database] = defaultdict() self.db_views[database] = defaultdict() self.db_schemas.setdefault(database, {}) @@ -388,17 +399,8 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit all_tables = self.get_all_tables() - if ( - self.config.store_last_lineage_extraction_timestamp - or self.config.store_last_usage_extraction_timestamp - ): - # Update the checkpoint state for this run. - self.redundant_run_skip_handler.update_state( - start_time_millis=datetime_to_ts_millis(self.config.start_time), - end_time_millis=datetime_to_ts_millis(self.config.end_time), - ) - if self.config.include_table_lineage or self.config.include_copy_lineage: + self.report.report_ingestion_stage_start(LINEAGE_EXTRACTION) yield from self.extract_lineage( connection=connection, all_tables=all_tables, database=database ) @@ -409,6 +411,7 @@ def get_workunits_internal(self) -> Iterable[Union[MetadataWorkUnit, SqlWorkUnit ) if self.config.is_profiling_enabled(): + self.report.report_ingestion_stage_start(PROFILING) profiler = RedshiftProfiler( config=self.config, report=self.report, @@ -841,26 +844,26 @@ def extract_usage( database: str, all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]], ) -> Iterable[MetadataWorkUnit]: - if ( - self.config.store_last_usage_extraction_timestamp - and self.redundant_run_skip_handler.should_skip_this_run( - cur_start_time_millis=datetime_to_ts_millis(self.config.start_time) - ) - ): - # Skip this run - self.report.report_warning( - "usage-extraction", - f"Skip this run as there was a run later than the current start time: {self.config.start_time}", - ) - return - with PerfTimer() as timer: - yield from RedshiftUsageExtractor( + redundant_usage_run_skip_handler: Optional[ + RedundantUsageRunSkipHandler + ] = None + if self.config.enable_stateful_usage_ingestion: + redundant_usage_run_skip_handler = RedundantUsageRunSkipHandler( + source=self, + config=self.config, + pipeline_name=self.ctx.pipeline_name, + run_id=self.ctx.run_id, + ) + usage_extractor = RedshiftUsageExtractor( config=self.config, connection=connection, report=self.report, dataset_urn_builder=self.gen_dataset_urn, - ).get_usage_workunits(all_tables=all_tables) + redundant_run_skip_handler=redundant_usage_run_skip_handler, + ) + + yield from usage_extractor.get_usage_workunits(all_tables=all_tables) self.report.usage_extraction_sec[database] = round( timer.elapsed_seconds(), 2 @@ -872,22 +875,13 @@ def extract_lineage( database: str, all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]], ) -> Iterable[MetadataWorkUnit]: - if ( - 
self.config.store_last_lineage_extraction_timestamp - and self.redundant_run_skip_handler.should_skip_this_run( - cur_start_time_millis=datetime_to_ts_millis(self.config.start_time) - ) - ): - # Skip this run - self.report.report_warning( - "lineage-extraction", - f"Skip this run as there was a run later than the current start time: {self.config.start_time}", - ) + if not self._should_ingest_lineage(): return self.lineage_extractor = RedshiftLineageExtractor( config=self.config, report=self.report, + redundant_run_skip_handler=self.redundant_lineage_run_skip_handler, ) with PerfTimer() as timer: @@ -900,6 +894,29 @@ def extract_lineage( ) yield from self.generate_lineage(database) + if self.redundant_lineage_run_skip_handler: + # Update the checkpoint state for this run. + self.redundant_lineage_run_skip_handler.update_state( + self.config.start_time, self.config.end_time + ) + + def _should_ingest_lineage(self) -> bool: + if ( + self.redundant_lineage_run_skip_handler + and self.redundant_lineage_run_skip_handler.should_skip_this_run( + cur_start_time=self.config.start_time, + cur_end_time=self.config.end_time, + ) + ): + # Skip this run + self.report.report_warning( + "lineage-extraction", + "Skip this run as there was already a run for current ingestion window.", + ) + return False + + return True + def generate_lineage(self, database: str) -> Iterable[MetadataWorkUnit]: assert self.lineage_extractor @@ -940,3 +957,15 @@ def generate_lineage(self, database: str) -> Iterable[MetadataWorkUnit]: yield from gen_lineage( dataset_urn, lineage_info, self.config.incremental_lineage ) + + def add_config_to_report(self): + self.report.stateful_lineage_ingestion_enabled = ( + self.config.enable_stateful_lineage_ingestion + ) + self.report.stateful_usage_ingestion_enabled = ( + self.config.enable_stateful_usage_ingestion + ) + self.report.window_start_time, self.report.window_end_time = ( + self.config.start_time, + self.config.end_time, + ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py index 319a731a14cef..b845580f35939 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/report.py @@ -1,13 +1,16 @@ from dataclasses import dataclass, field +from datetime import datetime from typing import Dict, Optional from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport +from datahub.ingestion.source_report.ingestion_stage import IngestionStageReport +from datahub.ingestion.source_report.time_window import BaseTimeWindowReport from datahub.utilities.lossy_collections import LossyDict from datahub.utilities.stats_collections import TopKDict @dataclass -class RedshiftReport(ProfilingSqlReport): +class RedshiftReport(ProfilingSqlReport, IngestionStageReport, BaseTimeWindowReport): num_usage_workunits_emitted: Optional[int] = None num_operational_stats_workunits_emitted: Optional[int] = None upstream_lineage: LossyDict = field(default_factory=LossyDict) @@ -32,5 +35,13 @@ class RedshiftReport(ProfilingSqlReport): num_lineage_dropped_query_parser: int = 0 num_lineage_dropped_not_support_copy_path: int = 0 + lineage_start_time: Optional[datetime] = None + lineage_end_time: Optional[datetime] = None + stateful_lineage_ingestion_enabled: bool = False + + usage_start_time: Optional[datetime] = None + usage_end_time: Optional[datetime] = None + stateful_usage_ingestion_enabled: bool = False + def 
report_dropped(self, key: str) -> None: self.filtered.append(key) diff --git a/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py b/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py index 653b41d690e48..953f0edd7c2bb 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py +++ b/metadata-ingestion/src/datahub/ingestion/source/redshift/usage.py @@ -2,7 +2,7 @@ import logging import time from datetime import datetime -from typing import Callable, Dict, Iterable, List, Optional, Union +from typing import Callable, Dict, Iterable, List, Optional, Tuple, Union import pydantic.error_wrappers import redshift_connector @@ -10,7 +10,10 @@ from pydantic.main import BaseModel import datahub.emitter.mce_builder as builder -from datahub.configuration.time_window_config import get_time_bucket +from datahub.configuration.time_window_config import ( + BaseTimeWindowConfig, + get_time_bucket, +) from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.source_helpers import auto_empty_dataset_usage_statistics from datahub.ingestion.api.workunit import MetadataWorkUnit @@ -20,7 +23,14 @@ RedshiftView, ) from datahub.ingestion.source.redshift.report import RedshiftReport +from datahub.ingestion.source.state.redundant_run_skip_handler import ( + RedundantUsageRunSkipHandler, +) from datahub.ingestion.source.usage.usage_common import GenericAggregatedDataset +from datahub.ingestion.source_report.ingestion_stage import ( + USAGE_EXTRACTION_OPERATIONAL_STATS, + USAGE_EXTRACTION_USAGE_AGGREGATION, +) from datahub.metadata.schema_classes import OperationClass, OperationTypeClass from datahub.utilities.perf_timer import PerfTimer @@ -170,18 +180,56 @@ def __init__( connection: redshift_connector.Connection, report: RedshiftReport, dataset_urn_builder: Callable[[str], str], + redundant_run_skip_handler: Optional[RedundantUsageRunSkipHandler] = None, ): self.config = config self.report = report self.connection = connection self.dataset_urn_builder = dataset_urn_builder + self.redundant_run_skip_handler = redundant_run_skip_handler + self.start_time, self.end_time = ( + self.report.usage_start_time, + self.report.usage_end_time, + ) = self.get_time_window() + + def get_time_window(self) -> Tuple[datetime, datetime]: + if self.redundant_run_skip_handler: + return self.redundant_run_skip_handler.suggest_run_time_window( + self.config.start_time, self.config.end_time + ) + else: + return self.config.start_time, self.config.end_time + + def _should_ingest_usage(self): + if ( + self.redundant_run_skip_handler + and self.redundant_run_skip_handler.should_skip_this_run( + cur_start_time=self.config.start_time, + cur_end_time=self.config.end_time, + ) + ): + # Skip this run + self.report.report_warning( + "usage-extraction", + "Skip this run as there was already a run for current ingestion window.", + ) + return False + + return True + def get_usage_workunits( self, all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]] ) -> Iterable[MetadataWorkUnit]: + if not self._should_ingest_usage(): + return yield from auto_empty_dataset_usage_statistics( self._get_workunits_internal(all_tables), - config=self.config, + config=BaseTimeWindowConfig( + start_time=self.start_time, + end_time=self.end_time, + bucket_duration=self.config.bucket_duration, + ), dataset_urns={ self.dataset_urn_builder(f"{database}.{schema}.{table.name}") for database in all_tables @@ -190,6 +238,14 @@ def get_usage_workunits( }, ) + if 
self.redundant_run_skip_handler: + # Update the checkpoint state for this run. + self.redundant_run_skip_handler.update_state( + self.config.start_time, + self.config.end_time, + self.config.bucket_duration, + ) + def _get_workunits_internal( self, all_tables: Dict[str, Dict[str, List[Union[RedshiftView, RedshiftTable]]]] ) -> Iterable[MetadataWorkUnit]: @@ -198,6 +254,7 @@ def _get_workunits_internal( self.report.num_operational_stats_skipped = 0 if self.config.include_operational_stats: + self.report.report_ingestion_stage_start(USAGE_EXTRACTION_OPERATIONAL_STATS) with PerfTimer() as timer: # Generate operation aspect workunits yield from self._gen_operation_aspect_workunits( @@ -208,9 +265,10 @@ def _get_workunits_internal( ] = round(timer.elapsed_seconds(), 2) # Generate aggregate events + self.report.report_ingestion_stage_start(USAGE_EXTRACTION_USAGE_AGGREGATION) query: str = REDSHIFT_USAGE_QUERY_TEMPLATE.format( - start_time=self.config.start_time.strftime(REDSHIFT_DATETIME_FORMAT), - end_time=self.config.end_time.strftime(REDSHIFT_DATETIME_FORMAT), + start_time=self.start_time.strftime(REDSHIFT_DATETIME_FORMAT), + end_time=self.end_time.strftime(REDSHIFT_DATETIME_FORMAT), database=self.config.database, ) access_events_iterable: Iterable[ @@ -236,8 +294,8 @@ def _gen_operation_aspect_workunits( ) -> Iterable[MetadataWorkUnit]: # Generate access events query: str = REDSHIFT_OPERATION_ASPECT_QUERY_TEMPLATE.format( - start_time=self.config.start_time.strftime(REDSHIFT_DATETIME_FORMAT), - end_time=self.config.end_time.strftime(REDSHIFT_DATETIME_FORMAT), + start_time=self.start_time.strftime(REDSHIFT_DATETIME_FORMAT), + end_time=self.end_time.strftime(REDSHIFT_DATETIME_FORMAT), ) access_events_iterable: Iterable[ RedshiftAccessEvent @@ -392,3 +450,7 @@ def _make_usage_stat(self, agg: AggregatedDataset) -> MetadataWorkUnit: self.config.format_sql_queries, self.config.include_top_n_queries, ) + + def report_status(self, step: str, status: bool) -> None: + if self.redundant_run_skip_handler: + self.redundant_run_skip_handler.report_current_run_status(step, status) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py index a7d946e99d806..af99faf6e6396 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_config.py @@ -14,6 +14,7 @@ ClassificationSourceConfigMixin, ) from datahub.ingestion.source.state.stateful_ingestion_base import ( + StatefulLineageConfigMixin, StatefulProfilingConfigMixin, StatefulUsageConfigMixin, ) @@ -72,6 +73,7 @@ def source_database(self) -> DatabaseId: class SnowflakeV2Config( SnowflakeConfig, SnowflakeUsageConfig, + StatefulLineageConfigMixin, StatefulUsageConfigMixin, StatefulProfilingConfigMixin, ClassificationSourceConfigMixin, diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py index c338c427aefbf..cee3a2926520f 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_lineage_v2.py @@ -2,6 +2,7 @@ import logging from collections import defaultdict from dataclasses import dataclass +from datetime import datetime from typing import ( Callable, Collection, @@ -35,6 +36,9 @@ SnowflakePermissionError, 
SnowflakeQueryMixin, ) +from datahub.ingestion.source.state.redundant_run_skip_handler import ( + RedundantLineageRunSkipHandler, +) from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( FineGrainedLineage, FineGrainedLineageDownstreamType, @@ -48,10 +52,15 @@ SqlParsingResult, sqlglot_lineage, ) +from datahub.utilities.time import ts_millis_to_datetime from datahub.utilities.urns.dataset_urn import DatasetUrn logger: logging.Logger = logging.getLogger(__name__) +EXTERNAL_LINEAGE = "external_lineage" +TABLE_LINEAGE = "table_lineage" +VIEW_LINEAGE = "view_lineage" + @dataclass(frozen=True) class SnowflakeColumnId: @@ -81,6 +90,7 @@ def __init__( config: SnowflakeV2Config, report: SnowflakeV2Report, dataset_urn_builder: Callable[[str], str], + redundant_run_skip_handler: Optional[RedundantLineageRunSkipHandler], ) -> None: self._external_lineage_map: Dict[str, Set[str]] = defaultdict(set) self.config = config @@ -89,6 +99,28 @@ def __init__( self.dataset_urn_builder = dataset_urn_builder self.connection: Optional[SnowflakeConnection] = None + self.redundant_run_skip_handler = redundant_run_skip_handler + self.start_time, self.end_time = ( + self.report.lineage_start_time, + self.report.lineage_end_time, + ) = self.get_time_window() + + def get_time_window(self) -> Tuple[datetime, datetime]: + if self.redundant_run_skip_handler: + return self.redundant_run_skip_handler.suggest_run_time_window( + self.config.start_time + if not self.config.ignore_start_time_lineage + else ts_millis_to_datetime(0), + self.config.end_time, + ) + else: + return ( + self.config.start_time + if not self.config.ignore_start_time_lineage + else ts_millis_to_datetime(0), + self.config.end_time, + ) + def get_workunits( self, discovered_tables: List[str], @@ -96,6 +128,9 @@ def get_workunits( schema_resolver: SchemaResolver, view_definitions: MutableMapping[str, str], ) -> Iterable[MetadataWorkUnit]: + if not self._should_ingest_lineage(): + return + self.connection = self.create_connection() if self.connection is None: return @@ -117,6 +152,15 @@ def get_workunits( if self._external_lineage_map: # Some external lineage is yet to be emitted yield from self.get_table_external_upstream_workunits() + if self.redundant_run_skip_handler: + # Update the checkpoint state for this run. 
+ self.redundant_run_skip_handler.update_state( + self.config.start_time + if not self.config.ignore_start_time_lineage + else ts_millis_to_datetime(0), + self.config.end_time, + ) + def get_table_external_upstream_workunits(self) -> Iterable[MetadataWorkUnit]: for ( dataset_name, @@ -140,12 +184,14 @@ def get_table_upstream_workunits( else: with PerfTimer() as timer: results = self._fetch_upstream_lineages_for_tables() - self.report.table_lineage_query_secs = timer.elapsed_seconds() - if not results: - return + if not results: + return - yield from self._gen_workunits_from_query_result(discovered_tables, results) + yield from self._gen_workunits_from_query_result( + discovered_tables, results + ) + self.report.table_lineage_query_secs = timer.elapsed_seconds() logger.info( f"Upstream lineage detected for {self.report.num_tables_with_upstreams} tables.", ) @@ -212,12 +258,14 @@ def get_view_upstream_workunits( with PerfTimer() as timer: results = self._fetch_upstream_lineages_for_views() - self.report.view_upstream_lineage_query_secs = timer.elapsed_seconds() - if results: - yield from self._gen_workunits_from_query_result( - set(discovered_views) - views_processed, results, upstream_for_view=True - ) + if results: + yield from self._gen_workunits_from_query_result( + set(discovered_views) - views_processed, + results, + upstream_for_view=True, + ) + self.report.view_upstream_lineage_query_secs = timer.elapsed_seconds() logger.info( f"Upstream lineage detected for {self.report.num_views_with_upstreams} views.", ) @@ -377,6 +425,7 @@ def _populate_external_lineage_from_show_query(self, discovered_tables): "external_lineage", f"Populating external table lineage from Snowflake failed due to error {e}.", ) + self.report_status(EXTERNAL_LINEAGE, False) # Handles the case where a table is populated from an external stage/s3 location via copy. 
# Eg: copy into category_english from @external_s3_stage; @@ -386,10 +435,8 @@ def _populate_external_lineage_from_copy_history( self, discovered_tables: List[str] ) -> None: query: str = SnowflakeQuery.copy_lineage_history( - start_time_millis=int(self.config.start_time.timestamp() * 1000) - if not self.config.ignore_start_time_lineage - else 0, - end_time_millis=int(self.config.end_time.timestamp() * 1000), + start_time_millis=int(self.start_time.timestamp() * 1000), + end_time_millis=int(self.end_time.timestamp() * 1000), downstreams_deny_pattern=self.config.temporary_tables_pattern, ) @@ -406,6 +453,7 @@ def _populate_external_lineage_from_copy_history( "external_lineage", f"Populating table external lineage from Snowflake failed due to error {e}.", ) + self.report_status(EXTERNAL_LINEAGE, False) def _process_external_lineage_result_row(self, db_row, discovered_tables): # key is the down-stream table name @@ -429,10 +477,8 @@ def _process_external_lineage_result_row(self, db_row, discovered_tables): def _fetch_upstream_lineages_for_tables(self): query: str = SnowflakeQuery.table_to_table_lineage_history_v2( - start_time_millis=int(self.config.start_time.timestamp() * 1000) - if not self.config.ignore_start_time_lineage - else 0, - end_time_millis=int(self.config.end_time.timestamp() * 1000), + start_time_millis=int(self.start_time.timestamp() * 1000), + end_time_millis=int(self.end_time.timestamp() * 1000), upstreams_deny_pattern=self.config.temporary_tables_pattern, include_view_lineage=self.config.include_view_lineage, include_column_lineage=self.config.include_column_lineage, @@ -450,6 +496,7 @@ def _fetch_upstream_lineages_for_tables(self): "table-upstream-lineage", f"Extracting lineage from Snowflake failed due to error {e}.", ) + self.report_status(TABLE_LINEAGE, False) def map_query_result_upstreams(self, upstream_tables): if not upstream_tables: @@ -535,6 +582,7 @@ def _fetch_upstream_lineages_for_views(self): "view-upstream-lineage", f"Extracting the upstream view lineage from Snowflake failed due to error {e}.", ) + self.report_status(VIEW_LINEAGE, False) def build_finegrained_lineage( self, @@ -596,3 +644,25 @@ def get_external_upstreams(self, external_lineage: Set[str]) -> List[UpstreamCla ) external_upstreams.append(external_upstream_table) return external_upstreams + + def _should_ingest_lineage(self) -> bool: + if ( + self.redundant_run_skip_handler + and self.redundant_run_skip_handler.should_skip_this_run( + cur_start_time=self.config.start_time + if not self.config.ignore_start_time_lineage + else ts_millis_to_datetime(0), + cur_end_time=self.config.end_time, + ) + ): + # Skip this run + self.report.report_warning( + "lineage-extraction", + "Skip this run as there was already a run for current ingestion window.", + ) + return False + return True + + def report_status(self, step: str, status: bool) -> None: + if self.redundant_run_skip_handler: + self.redundant_run_skip_handler.report_current_run_status(step, status) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py index 8003de8286288..f67b359dedb11 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_report.py @@ -1,16 +1,71 @@ from dataclasses import dataclass, field -from typing import Dict, MutableSet, Optional +from datetime import datetime +from typing import Dict, List, MutableSet, 
Optional from datahub.ingestion.glossary.classification_mixin import ClassificationReportMixin from datahub.ingestion.source.snowflake.constants import SnowflakeEdition from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport -from datahub.ingestion.source_report.sql.snowflake import SnowflakeReport -from datahub.ingestion.source_report.usage.snowflake_usage import SnowflakeUsageReport +from datahub.ingestion.source.state.stateful_ingestion_base import ( + StatefulIngestionReport, +) +from datahub.ingestion.source_report.ingestion_stage import IngestionStageReport +from datahub.ingestion.source_report.time_window import BaseTimeWindowReport + + +@dataclass +class SnowflakeUsageReport: + min_access_history_time: Optional[datetime] = None + max_access_history_time: Optional[datetime] = None + access_history_range_query_secs: float = -1 + access_history_query_secs: float = -1 + + rows_processed: int = 0 + rows_missing_query_text: int = 0 + rows_zero_base_objects_accessed: int = 0 + rows_zero_direct_objects_accessed: int = 0 + rows_missing_email: int = 0 + rows_parsing_error: int = 0 + + usage_start_time: Optional[datetime] = None + usage_end_time: Optional[datetime] = None + stateful_usage_ingestion_enabled: bool = False + + +@dataclass +class SnowflakeReport(ProfilingSqlReport, BaseTimeWindowReport): + num_table_to_table_edges_scanned: int = 0 + num_table_to_view_edges_scanned: int = 0 + num_view_to_table_edges_scanned: int = 0 + num_external_table_edges_scanned: int = 0 + ignore_start_time_lineage: Optional[bool] = None + upstream_lineage_in_report: Optional[bool] = None + upstream_lineage: Dict[str, List[str]] = field(default_factory=dict) + + lineage_start_time: Optional[datetime] = None + lineage_end_time: Optional[datetime] = None + stateful_lineage_ingestion_enabled: bool = False + + cleaned_account_id: str = "" + run_ingestion: bool = False + + # https://community.snowflake.com/s/topic/0TO0Z000000Unu5WAC/releases + saas_version: Optional[str] = None + default_warehouse: Optional[str] = None + default_db: Optional[str] = None + default_schema: Optional[str] = None + role: str = "" + + profile_if_updated_since: Optional[datetime] = None + profile_candidates: Dict[str, List[str]] = field(default_factory=dict) @dataclass class SnowflakeV2Report( - SnowflakeReport, SnowflakeUsageReport, ProfilingSqlReport, ClassificationReportMixin + SnowflakeReport, + SnowflakeUsageReport, + StatefulIngestionReport, + ClassificationReportMixin, + IngestionStageReport, ): account_locator: Optional[str] = None region: Optional[str] = None @@ -94,3 +149,6 @@ def _is_tag_scanned(self, tag_name: str) -> bool: def report_tag_processed(self, tag_name: str) -> None: self._processed_tags.add(tag_name) + + def set_ingestion_stage(self, database: str, stage: str) -> None: + self.report_ingestion_stage_start(f"{database}: {stage}") diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py index f8dfa612952d8..f79be7174dbd9 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_usage_v2.py @@ -2,11 +2,12 @@ import logging import time from datetime import datetime, timezone -from typing import Any, Callable, Dict, Iterable, List, Optional +from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple import pydantic from snowflake.connector import 
SnowflakeConnection +from datahub.configuration.time_window_config import BaseTimeWindowConfig from datahub.emitter.mce_builder import make_user_urn from datahub.emitter.mcp import MetadataChangeProposalWrapper from datahub.ingestion.api.source_helpers import auto_empty_dataset_usage_statistics @@ -21,7 +22,14 @@ SnowflakePermissionError, SnowflakeQueryMixin, ) +from datahub.ingestion.source.state.redundant_run_skip_handler import ( + RedundantUsageRunSkipHandler, +) from datahub.ingestion.source.usage.usage_common import TOTAL_BUDGET_FOR_QUERY_LIST +from datahub.ingestion.source_report.ingestion_stage import ( + USAGE_EXTRACTION_OPERATIONAL_STATS, + USAGE_EXTRACTION_USAGE_AGGREGATION, +) from datahub.metadata.com.linkedin.pegasus2avro.dataset import ( DatasetFieldUsageCounts, DatasetUsageStatistics, @@ -107,6 +115,7 @@ def __init__( config: SnowflakeV2Config, report: SnowflakeV2Report, dataset_urn_builder: Callable[[str], str], + redundant_run_skip_handler: Optional[RedundantUsageRunSkipHandler], ) -> None: self.config: SnowflakeV2Config = config self.report: SnowflakeV2Report = report @@ -114,9 +123,28 @@ def __init__( self.logger = logger self.connection: Optional[SnowflakeConnection] = None + self.redundant_run_skip_handler = redundant_run_skip_handler + self.start_time, self.end_time = ( + self.report.usage_start_time, + self.report.usage_end_time, + ) = self.get_time_window() + + def get_time_window(self) -> Tuple[datetime, datetime]: + if self.redundant_run_skip_handler: + return self.redundant_run_skip_handler.suggest_run_time_window( + self.config.start_time, self.config.end_time + ) + else: + return self.config.start_time, self.config.end_time + def get_usage_workunits( self, discovered_datasets: List[str] ) -> Iterable[MetadataWorkUnit]: + if not self._should_ingest_usage(): + return + + self.report.set_ingestion_stage("*", USAGE_EXTRACTION_USAGE_AGGREGATION) + self.connection = self.create_connection() if self.connection is None: return @@ -144,13 +172,19 @@ def get_usage_workunits( if self.config.include_usage_stats: yield from auto_empty_dataset_usage_statistics( self._get_workunits_internal(discovered_datasets), - config=self.config, + config=BaseTimeWindowConfig( + start_time=self.start_time, + end_time=self.end_time, + bucket_duration=self.config.bucket_duration, + ), dataset_urns={ self.dataset_urn_builder(dataset_identifier) for dataset_identifier in discovered_datasets }, ) + self.report.set_ingestion_stage("*", USAGE_EXTRACTION_OPERATIONAL_STATS) + if self.config.include_operational_stats: # Generate the operation workunits. access_events = self._get_snowflake_history() @@ -159,6 +193,14 @@ def get_usage_workunits( event, discovered_datasets ) + if self.redundant_run_skip_handler: + # Update the checkpoint state for this run. 
+ self.redundant_run_skip_handler.update_state( + self.config.start_time, + self.config.end_time, + self.config.bucket_duration, + ) + def _get_workunits_internal( self, discovered_datasets: List[str] ) -> Iterable[MetadataWorkUnit]: @@ -167,10 +209,8 @@ def _get_workunits_internal( try: results = self.query( SnowflakeQuery.usage_per_object_per_time_bucket_for_time_window( - start_time_millis=int( - self.config.start_time.timestamp() * 1000 - ), - end_time_millis=int(self.config.end_time.timestamp() * 1000), + start_time_millis=int(self.start_time.timestamp() * 1000), + end_time_millis=int(self.end_time.timestamp() * 1000), time_bucket_size=self.config.bucket_duration, use_base_objects=self.config.apply_view_usage_to_tables, top_n_queries=self.config.top_n_queries, @@ -179,11 +219,13 @@ def _get_workunits_internal( ) except Exception as e: logger.debug(e, exc_info=e) - self.report_warning( + self.warn_if_stateful_else_error( "usage-statistics", f"Populating table usage statistics from Snowflake failed due to error {e}.", ) + self.report_status(USAGE_EXTRACTION_USAGE_AGGREGATION, False) return + self.report.usage_aggregation_query_secs = timer.elapsed_seconds() for row in results: @@ -300,10 +342,11 @@ def _get_snowflake_history(self) -> Iterable[SnowflakeJoinedAccessEvent]: results = self.query(query) except Exception as e: logger.debug(e, exc_info=e) - self.report_warning( + self.warn_if_stateful_else_error( "operation", f"Populating table operation history from Snowflake failed due to error {e}.", ) + self.report_status(USAGE_EXTRACTION_OPERATIONAL_STATS, False) return self.report.access_history_query_secs = round(timer.elapsed_seconds(), 2) @@ -311,8 +354,8 @@ def _get_snowflake_history(self) -> Iterable[SnowflakeJoinedAccessEvent]: yield from self._process_snowflake_history_row(row) def _make_operations_query(self) -> str: - start_time = int(self.config.start_time.timestamp() * 1000) - end_time = int(self.config.end_time.timestamp() * 1000) + start_time = int(self.start_time.timestamp() * 1000) + end_time = int(self.end_time.timestamp() * 1000) return SnowflakeQuery.operational_data_for_time_window(start_time, end_time) def _check_usage_date_ranges(self) -> Any: @@ -331,6 +374,7 @@ def _check_usage_date_ranges(self) -> Any: "usage", f"Extracting the date range for usage data from Snowflake failed due to error {e}.", ) + self.report_status("date-range-check", False) else: for db_row in results: if ( @@ -493,3 +537,24 @@ def _is_object_valid(self, obj: Dict[str, Any]) -> bool: ): return False return True + + def _should_ingest_usage(self) -> bool: + if ( + self.redundant_run_skip_handler + and self.redundant_run_skip_handler.should_skip_this_run( + cur_start_time=self.config.start_time, + cur_end_time=self.config.end_time, + ) + ): + # Skip this run + self.report.report_warning( + "usage-extraction", + "Skip this run as there was already a run for current ingestion window.", + ) + return False + + return True + + def report_status(self, step: str, status: bool) -> None: + if self.redundant_run_skip_handler: + self.redundant_run_skip_handler.report_current_run_status(step, status) diff --git a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py index 2cb4b37fdd696..90b751c875add 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py +++ b/metadata-ingestion/src/datahub/ingestion/source/snowflake/snowflake_v2.py @@ -90,7 +90,8 @@ ) from 
datahub.ingestion.source.state.profiling_state_handler import ProfilingHandler from datahub.ingestion.source.state.redundant_run_skip_handler import ( - RedundantRunSkipHandler, + RedundantLineageRunSkipHandler, + RedundantUsageRunSkipHandler, ) from datahub.ingestion.source.state.stale_entity_removal_handler import ( StaleEntityRemovalHandler, @@ -98,6 +99,11 @@ from datahub.ingestion.source.state.stateful_ingestion_base import ( StatefulIngestionSourceBase, ) +from datahub.ingestion.source_report.ingestion_stage import ( + LINEAGE_EXTRACTION, + METADATA_EXTRACTION, + PROFILING, +) from datahub.metadata.com.linkedin.pegasus2avro.common import ( GlobalTags, Status, @@ -130,7 +136,6 @@ from datahub.utilities.perf_timer import PerfTimer from datahub.utilities.registries.domain_registry import DomainRegistry from datahub.utilities.sqlglot_lineage import SchemaResolver -from datahub.utilities.time import datetime_to_ts_millis logger: logging.Logger = logging.getLogger(__name__) @@ -222,13 +227,6 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): self.snowsight_base_url: Optional[str] = None self.connection: Optional[SnowflakeConnection] = None - self.redundant_run_skip_handler = RedundantRunSkipHandler( - source=self, - config=self.config, - pipeline_name=self.ctx.pipeline_name, - run_id=self.ctx.run_id, - ) - self.domain_registry: Optional[DomainRegistry] = None if self.config.domain: self.domain_registry = DomainRegistry( @@ -238,14 +236,42 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): # For database, schema, tables, views, etc self.data_dictionary = SnowflakeDataDictionary() - if config.include_table_lineage: + self.lineage_extractor: Optional[SnowflakeLineageExtractor] = None + if self.config.include_table_lineage: + redundant_lineage_run_skip_handler: Optional[ + RedundantLineageRunSkipHandler + ] = None + if self.config.enable_stateful_lineage_ingestion: + redundant_lineage_run_skip_handler = RedundantLineageRunSkipHandler( + source=self, + config=self.config, + pipeline_name=self.ctx.pipeline_name, + run_id=self.ctx.run_id, + ) self.lineage_extractor = SnowflakeLineageExtractor( - config, self.report, dataset_urn_builder=self.gen_dataset_urn + config, + self.report, + dataset_urn_builder=self.gen_dataset_urn, + redundant_run_skip_handler=redundant_lineage_run_skip_handler, ) - if config.include_usage_stats or config.include_operational_stats: + self.usage_extractor: Optional[SnowflakeUsageExtractor] = None + if self.config.include_usage_stats or self.config.include_operational_stats: + redundant_usage_run_skip_handler: Optional[ + RedundantUsageRunSkipHandler + ] = None + if self.config.enable_stateful_usage_ingestion: + redundant_usage_run_skip_handler = RedundantUsageRunSkipHandler( + source=self, + config=self.config, + pipeline_name=self.ctx.pipeline_name, + run_id=self.ctx.run_id, + ) self.usage_extractor = SnowflakeUsageExtractor( - config, self.report, dataset_urn_builder=self.gen_dataset_urn + config, + self.report, + dataset_urn_builder=self.gen_dataset_urn, + redundant_run_skip_handler=redundant_usage_run_skip_handler, ) self.tag_extractor = SnowflakeTagExtractor( @@ -253,7 +279,7 @@ def __init__(self, ctx: PipelineContext, config: SnowflakeV2Config): ) self.profiling_state_handler: Optional[ProfilingHandler] = None - if self.config.store_last_profiling_timestamps: + if self.config.enable_stateful_profiling: self.profiling_state_handler = ProfilingHandler( source=self, config=self.config, @@ -281,6 +307,7 @@ def __init__(self, 
ctx: PipelineContext, config: SnowflakeV2Config): env=self.config.env, ) self.view_definitions: FileBackedDict[str] = FileBackedDict() + self.add_config_to_report() @classmethod def create(cls, config_dict: dict, ctx: PipelineContext) -> "Source": @@ -481,7 +508,6 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: if self.connection is None: return - self.add_config_to_report() self.inspect_session_metadata() if self.config.include_external_url: @@ -506,6 +532,7 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: for snowflake_db in databases: try: + self.report.set_ingestion_stage(snowflake_db.name, METADATA_EXTRACTION) yield from self._process_database(snowflake_db) except SnowflakePermissionError as e: @@ -555,7 +582,8 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: discovered_datasets = discovered_tables + discovered_views - if self.config.include_table_lineage: + if self.config.include_table_lineage and self.lineage_extractor: + self.report.set_ingestion_stage("*", LINEAGE_EXTRACTION) yield from self.lineage_extractor.get_workunits( discovered_tables=discovered_tables, discovered_views=discovered_views, @@ -563,27 +591,9 @@ def get_workunits_internal(self) -> Iterable[MetadataWorkUnit]: view_definitions=self.view_definitions, ) - if self.config.include_usage_stats or self.config.include_operational_stats: - if ( - self.config.store_last_usage_extraction_timestamp - and self.redundant_run_skip_handler.should_skip_this_run( - cur_start_time_millis=datetime_to_ts_millis(self.config.start_time) - ) - ): - # Skip this run - self.report.report_warning( - "usage-extraction", - f"Skip this run as there was a run later than the current start time: {self.config.start_time}", - ) - return - - if self.config.store_last_usage_extraction_timestamp: - # Update the checkpoint state for this run. 
- self.redundant_run_skip_handler.update_state( - start_time_millis=datetime_to_ts_millis(self.config.start_time), - end_time_millis=datetime_to_ts_millis(self.config.end_time), - ) - + if ( + self.config.include_usage_stats or self.config.include_operational_stats + ) and self.usage_extractor: yield from self.usage_extractor.get_usage_workunits(discovered_datasets) def report_warehouse_failure(self): @@ -690,6 +700,7 @@ def _process_database( yield from self._process_schema(snowflake_schema, db_name) if self.config.is_profiling_enabled() and self.db_tables: + self.report.set_ingestion_stage(snowflake_db.name, PROFILING) yield from self.profiler.get_workunits(snowflake_db, self.db_tables) def fetch_schemas_for_database( @@ -1420,16 +1431,20 @@ def add_config_to_report(self): self.report.cleaned_account_id = self.config.get_account() self.report.ignore_start_time_lineage = self.config.ignore_start_time_lineage self.report.upstream_lineage_in_report = self.config.upstream_lineage_in_report - if not self.report.ignore_start_time_lineage: - self.report.lineage_start_time = self.config.start_time - self.report.lineage_end_time = self.config.end_time self.report.include_technical_schema = self.config.include_technical_schema self.report.include_usage_stats = self.config.include_usage_stats self.report.include_operational_stats = self.config.include_operational_stats self.report.include_column_lineage = self.config.include_column_lineage - if self.report.include_usage_stats or self.config.include_operational_stats: - self.report.window_start_time = self.config.start_time - self.report.window_end_time = self.config.end_time + self.report.stateful_lineage_ingestion_enabled = ( + self.config.enable_stateful_lineage_ingestion + ) + self.report.stateful_usage_ingestion_enabled = ( + self.config.enable_stateful_usage_ingestion + ) + self.report.window_start_time, self.report.window_end_time = ( + self.config.start_time, + self.config.end_time, + ) def inspect_session_metadata(self) -> None: try: @@ -1611,7 +1626,7 @@ def close(self) -> None: StatefulIngestionSourceBase.close(self) self.view_definitions.close() self.sql_parser_schema_resolver.close() - if hasattr(self, "lineage_extractor"): + if self.lineage_extractor: self.lineage_extractor.close() - if hasattr(self, "usage_extractor"): + if self.usage_extractor: self.usage_extractor.close() diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py b/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py index 459dbe0ce0af7..a2e078f233f1d 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/redundant_run_skip_handler.py @@ -1,8 +1,11 @@ import logging -from typing import Optional, cast +from abc import ABCMeta, abstractmethod +from datetime import datetime +from typing import Dict, Optional, Tuple, cast import pydantic +from datahub.configuration.time_window_config import BucketDuration, get_time_bucket from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import JobId from datahub.ingestion.source.state.checkpoint import Checkpoint from datahub.ingestion.source.state.stateful_ingestion_base import ( @@ -10,26 +13,24 @@ StatefulIngestionConfigBase, StatefulIngestionSourceBase, ) -from datahub.ingestion.source.state.usage_common_state import BaseUsageCheckpointState +from datahub.ingestion.source.state.usage_common_state import ( + BaseTimeWindowCheckpointState, +) from 
datahub.ingestion.source.state.use_case_handler import ( StatefulIngestionUsecaseHandlerBase, ) -from datahub.utilities.time import get_datetime_from_ts_millis_in_utc +from datahub.utilities.time import ( + TimeWindow, + datetime_to_ts_millis, + ts_millis_to_datetime, +) logger: logging.Logger = logging.getLogger(__name__) -class StatefulRedundantRunSkipConfig(StatefulIngestionConfig): - """ - Base specialized config of Stateful Ingestion to skip redundant runs. - """ - - # Defines the alias 'force_rerun' for ignore_old_state field. - ignore_old_state = pydantic.Field(False, alias="force_rerun") - - class RedundantRunSkipHandler( - StatefulIngestionUsecaseHandlerBase[BaseUsageCheckpointState] + StatefulIngestionUsecaseHandlerBase[BaseTimeWindowCheckpointState], + metaclass=ABCMeta, ): """ The stateful ingestion helper class that handles skipping redundant runs. @@ -41,38 +42,28 @@ class RedundantRunSkipHandler( def __init__( self, source: StatefulIngestionSourceBase, - config: StatefulIngestionConfigBase[StatefulRedundantRunSkipConfig], + config: StatefulIngestionConfigBase[StatefulIngestionConfig], pipeline_name: Optional[str], run_id: str, ): self.source = source self.state_provider = source.state_provider self.stateful_ingestion_config: Optional[ - StatefulRedundantRunSkipConfig + StatefulIngestionConfig ] = config.stateful_ingestion self.pipeline_name = pipeline_name self.run_id = run_id - self.checkpointing_enabled: bool = ( - self.state_provider.is_stateful_ingestion_configured() - ) self._job_id = self._init_job_id() self.state_provider.register_stateful_ingestion_usecase_handler(self) - def _ignore_old_state(self) -> bool: - if ( - self.stateful_ingestion_config is not None - and self.stateful_ingestion_config.ignore_old_state - ): - return True - return False + # step -> step status + self.status: Dict[str, bool] = {} def _ignore_new_state(self) -> bool: - if ( + return ( self.stateful_ingestion_config is not None and self.stateful_ingestion_config.ignore_new_state - ): - return True - return False + ) def _init_job_id(self) -> JobId: platform: Optional[str] = None @@ -80,22 +71,26 @@ def _init_job_id(self) -> JobId: if hasattr(source_class, "get_platform_name"): platform = source_class.get_platform_name() # type: ignore - # Handle backward-compatibility for existing sources. 
- if platform == "Snowflake": - return JobId("snowflake_usage_ingestion") - # Default name for everything else - job_name_suffix = "skip_redundant_run" - return JobId(f"{platform}_{job_name_suffix}" if platform else job_name_suffix) + job_name_suffix = self.get_job_name_suffix() + return JobId( + f"{platform}_skip_redundant_run{job_name_suffix}" + if platform + else f"skip_redundant_run{job_name_suffix}" + ) + + @abstractmethod + def get_job_name_suffix(self): + raise NotImplementedError("Sub-classes must override this method.") @property def job_id(self) -> JobId: return self._job_id def is_checkpointing_enabled(self) -> bool: - return self.checkpointing_enabled + return self.state_provider.is_stateful_ingestion_configured() - def create_checkpoint(self) -> Optional[Checkpoint[BaseUsageCheckpointState]]: + def create_checkpoint(self) -> Optional[Checkpoint[BaseTimeWindowCheckpointState]]: if not self.is_checkpointing_enabled() or self._ignore_new_state(): return None @@ -104,46 +99,150 @@ def create_checkpoint(self) -> Optional[Checkpoint[BaseUsageCheckpointState]]: job_name=self.job_id, pipeline_name=self.pipeline_name, run_id=self.run_id, - state=BaseUsageCheckpointState( + state=BaseTimeWindowCheckpointState( begin_timestamp_millis=self.INVALID_TIMESTAMP_VALUE, end_timestamp_millis=self.INVALID_TIMESTAMP_VALUE, ), ) - def update_state( + def report_current_run_status(self, step: str, status: bool) -> None: + """ + A helper to track status of all steps of current run. + This will be used to decide overall status of the run. + Checkpoint state will not be updated/committed for current run if there are any failures. + """ + self.status[step] = status + + def is_current_run_successful(self) -> bool: + return all(self.status.values()) + + def get_current_checkpoint( self, - start_time_millis: pydantic.PositiveInt, - end_time_millis: pydantic.PositiveInt, - ) -> None: - if not self.is_checkpointing_enabled() or self._ignore_new_state(): - return + ) -> Optional[Checkpoint]: + if ( + not self.is_checkpointing_enabled() + or self._ignore_new_state() + or not self.is_current_run_successful() + ): + return None cur_checkpoint = self.state_provider.get_current_checkpoint(self.job_id) assert cur_checkpoint is not None - cur_state = cast(BaseUsageCheckpointState, cur_checkpoint.state) - cur_state.begin_timestamp_millis = start_time_millis - cur_state.end_timestamp_millis = end_time_millis - - def should_skip_this_run(self, cur_start_time_millis: int) -> bool: - if not self.is_checkpointing_enabled() or self._ignore_old_state(): - return False - # Determine from the last check point state - last_successful_pipeline_run_end_time_millis: Optional[int] = None + return cur_checkpoint + + def should_skip_this_run( + self, cur_start_time: datetime, cur_end_time: datetime + ) -> bool: + skip: bool = False + last_checkpoint = self.state_provider.get_last_checkpoint( - self.job_id, BaseUsageCheckpointState + self.job_id, BaseTimeWindowCheckpointState ) - if last_checkpoint and last_checkpoint.state: - state = cast(BaseUsageCheckpointState, last_checkpoint.state) - last_successful_pipeline_run_end_time_millis = state.end_timestamp_millis - if ( - last_successful_pipeline_run_end_time_millis is not None - and cur_start_time_millis <= last_successful_pipeline_run_end_time_millis + if last_checkpoint: + last_run_time_window = TimeWindow( + ts_millis_to_datetime(last_checkpoint.state.begin_timestamp_millis), + ts_millis_to_datetime(last_checkpoint.state.end_timestamp_millis), + ) + + logger.debug( + 
f"{self.job_id} : Last run start, end times:" + f"({last_run_time_window})" + ) + + # If current run's time window is subset of last run's time window, then skip. + # Else there is at least some part in current time window that was not covered in past run's time window + if last_run_time_window.contains(TimeWindow(cur_start_time, cur_end_time)): + skip = True + + return skip + + def suggest_run_time_window( + self, + cur_start_time: datetime, + cur_end_time: datetime, + allow_reduce: int = True, + allow_expand: int = False, + ) -> Tuple[datetime, datetime]: + # If required in future, allow_reduce, allow_expand can be accepted as user input + # as part of stateful ingestion configuration. It is likely that they may cause + # more confusion than help to most users hence not added to start with. + last_checkpoint = self.state_provider.get_last_checkpoint( + self.job_id, BaseTimeWindowCheckpointState + ) + if (last_checkpoint is None) or self.should_skip_this_run( + cur_start_time, cur_end_time ): - warn_msg = ( - f"Skippig this run, since the last run's bucket duration end: " - f"{get_datetime_from_ts_millis_in_utc(last_successful_pipeline_run_end_time_millis)}" - f" is later than the current start_time: {get_datetime_from_ts_millis_in_utc(cur_start_time_millis)}" + return cur_start_time, cur_end_time + + suggested_start_time, suggested_end_time = cur_start_time, cur_end_time + + last_run = last_checkpoint.state.to_time_interval() + self.log(f"Last run start, end times:{last_run}") + cur_run = TimeWindow(cur_start_time, cur_end_time) + + if cur_run.starts_after(last_run): + # scenario of time gap between past successful run window and current run window - maybe due to failed past run + # Should we keep some configurable limits here to decide how much increase in time window is fine ? + if allow_expand: + suggested_start_time = last_run.end_time + self.log( + f"Expanding time window. Updating start time to {suggested_start_time}." + ) + else: + self.log( + f"Observed gap in last run end time({last_run.end_time}) and current run start time({cur_start_time})." + ) + elif allow_reduce and cur_run.left_intersects(last_run): + # scenario of scheduled ingestions with default start, end times + suggested_start_time = last_run.end_time + self.log( + f"Reducing time window. Updating start time to {suggested_start_time}." + ) + elif allow_reduce and cur_run.right_intersects(last_run): + # a manual backdated run + suggested_end_time = last_run.start_time + self.log( + f"Reducing time window. Updating end time to {suggested_end_time}." 
) - logger.warning(warn_msg) - return True - return False + + # make sure to consider complete time bucket for usage + if last_checkpoint.state.bucket_duration: + suggested_start_time = get_time_bucket( + suggested_start_time, last_checkpoint.state.bucket_duration + ) + + self.log( + "Adjusted start, end times: " + f"({suggested_start_time}, {suggested_end_time})" + ) + return (suggested_start_time, suggested_end_time) + + def log(self, msg: str) -> None: + logger.info(f"{self.job_id} : {msg}") + + +class RedundantLineageRunSkipHandler(RedundantRunSkipHandler): + def get_job_name_suffix(self): + return "_lineage" + + def update_state(self, start_time: datetime, end_time: datetime) -> None: + cur_checkpoint = self.get_current_checkpoint() + if cur_checkpoint: + cur_state = cast(BaseTimeWindowCheckpointState, cur_checkpoint.state) + cur_state.begin_timestamp_millis = datetime_to_ts_millis(start_time) + cur_state.end_timestamp_millis = datetime_to_ts_millis(end_time) + + +class RedundantUsageRunSkipHandler(RedundantRunSkipHandler): + def get_job_name_suffix(self): + return "_usage" + + def update_state( + self, start_time: datetime, end_time: datetime, bucket_duration: BucketDuration + ) -> None: + cur_checkpoint = self.get_current_checkpoint() + if cur_checkpoint: + cur_state = cast(BaseTimeWindowCheckpointState, cur_checkpoint.state) + cur_state.begin_timestamp_millis = datetime_to_ts_millis(start_time) + cur_state.end_timestamp_millis = datetime_to_ts_millis(end_time) + cur_state.bucket_duration = bucket_duration diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py index 9dd6d27d56ea9..be97e9380f1f5 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/stateful_ingestion_base.py @@ -14,6 +14,7 @@ LineageConfig, ) from datahub.configuration.time_window_config import BaseTimeWindowConfig +from datahub.configuration.validate_field_rename import pydantic_renamed_field from datahub.ingestion.api.common import PipelineContext from datahub.ingestion.api.ingestion_job_checkpointing_provider_base import ( IngestionCheckpointingProviderBase, @@ -100,57 +101,75 @@ class StatefulIngestionConfigBase(GenericModel, Generic[CustomConfig]): class StatefulLineageConfigMixin(LineageConfig): - store_last_lineage_extraction_timestamp: bool = Field( - default=False, - description="Enable checking last lineage extraction date in store.", + enable_stateful_lineage_ingestion: bool = Field( + default=True, + description="Enable stateful lineage ingestion." + " This will store lineage window timestamps after successful lineage ingestion. " + "and will not run lineage ingestion for same timestamps in subsequent run. 
", + ) + + _store_last_lineage_extraction_timestamp = pydantic_renamed_field( + "store_last_lineage_extraction_timestamp", "enable_stateful_lineage_ingestion" ) @root_validator(pre=False) def lineage_stateful_option_validator(cls, values: Dict) -> Dict: sti = values.get("stateful_ingestion") if not sti or not sti.enabled: - if values.get("store_last_lineage_extraction_timestamp"): + if values.get("enable_stateful_lineage_ingestion"): logger.warning( - "Stateful ingestion is disabled, disabling store_last_lineage_extraction_timestamp config option as well" + "Stateful ingestion is disabled, disabling enable_stateful_lineage_ingestion config option as well" ) - values["store_last_lineage_extraction_timestamp"] = False + values["enable_stateful_lineage_ingestion"] = False return values class StatefulProfilingConfigMixin(ConfigModel): - store_last_profiling_timestamps: bool = Field( - default=False, - description="Enable storing last profile timestamp in store.", + enable_stateful_profiling: bool = Field( + default=True, + description="Enable stateful profiling." + " This will store profiling timestamps per dataset after successful profiling. " + "and will not run profiling again in subsequent run if table has not been updated. ", + ) + + _store_last_profiling_timestamps = pydantic_renamed_field( + "store_last_profiling_timestamps", "enable_stateful_profiling" ) @root_validator(pre=False) def profiling_stateful_option_validator(cls, values: Dict) -> Dict: sti = values.get("stateful_ingestion") if not sti or not sti.enabled: - if values.get("store_last_profiling_timestamps"): + if values.get("enable_stateful_profiling"): logger.warning( - "Stateful ingestion is disabled, disabling store_last_profiling_timestamps config option as well" + "Stateful ingestion is disabled, disabling enable_stateful_profiling config option as well" ) - values["store_last_profiling_timestamps"] = False + values["enable_stateful_profiling"] = False return values class StatefulUsageConfigMixin(BaseTimeWindowConfig): - store_last_usage_extraction_timestamp: bool = Field( + enable_stateful_usage_ingestion: bool = Field( default=True, - description="Enable checking last usage timestamp in store.", + description="Enable stateful lineage ingestion." + " This will store usage window timestamps after successful usage ingestion. " + "and will not run usage ingestion for same timestamps in subsequent run. 
", + ) + + _store_last_usage_extraction_timestamp = pydantic_renamed_field( + "store_last_usage_extraction_timestamp", "enable_stateful_usage_ingestion" ) @root_validator(pre=False) def last_usage_extraction_stateful_option_validator(cls, values: Dict) -> Dict: sti = values.get("stateful_ingestion") if not sti or not sti.enabled: - if values.get("store_last_usage_extraction_timestamp"): + if values.get("enable_stateful_usage_ingestion"): logger.warning( - "Stateful ingestion is disabled, disabling store_last_usage_extraction_timestamp config option as well" + "Stateful ingestion is disabled, disabling enable_stateful_usage_ingestion config option as well" ) - values["store_last_usage_extraction_timestamp"] = False + values["enable_stateful_usage_ingestion"] = False return values diff --git a/metadata-ingestion/src/datahub/ingestion/source/state/usage_common_state.py b/metadata-ingestion/src/datahub/ingestion/source/state/usage_common_state.py index 5ecd9946d3602..b8d44796e4b69 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/state/usage_common_state.py +++ b/metadata-ingestion/src/datahub/ingestion/source/state/usage_common_state.py @@ -1,14 +1,27 @@ +from typing import Optional + import pydantic +from datahub.configuration.time_window_config import BucketDuration from datahub.ingestion.source.state.checkpoint import CheckpointStateBase +from datahub.utilities.time import TimeWindow, ts_millis_to_datetime -class BaseUsageCheckpointState(CheckpointStateBase): +class BaseTimeWindowCheckpointState(CheckpointStateBase): """ - Base class for representing the checkpoint state for all usage based sources. + Base class for representing the checkpoint state for all time window based ingestion stages. Stores the last successful run's begin and end timestamps. Subclasses can define additional state as appropriate. """ - begin_timestamp_millis: pydantic.PositiveInt - end_timestamp_millis: pydantic.PositiveInt + begin_timestamp_millis: pydantic.NonNegativeInt + end_timestamp_millis: pydantic.NonNegativeInt + + # Required for time bucket based aggregations - e.g. Usage + bucket_duration: Optional[BucketDuration] = None + + def to_time_interval(self) -> TimeWindow: + return TimeWindow( + ts_millis_to_datetime(self.begin_timestamp_millis), + ts_millis_to_datetime(self.end_timestamp_millis), + ) diff --git a/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py b/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py index 8d4ac37f49213..92f8223f34d14 100644 --- a/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py +++ b/metadata-ingestion/src/datahub/ingestion/source/usage/usage_common.py @@ -213,19 +213,6 @@ def ensure_top_n_queries_is_not_too_big(cls, v: int) -> int: ) return v - @pydantic.validator("start_time") - def ensure_start_time_aligns_with_bucket_start_time( - cls, v: datetime, values: dict - ) -> datetime: - if get_time_bucket(v, values["bucket_duration"]) != v: - new_start_time = get_time_bucket(v, values["bucket_duration"]) - logger.warning( - f"`start_time` will be changed to {new_start_time}, although the input `start_time` is {v}." - "This is necessary to record correct usage for the configured bucket duration." 
- ) - return new_start_time - return v - class UsageAggregator(Generic[ResourceType]): # TODO: Move over other connectors to use this class diff --git a/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py b/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py new file mode 100644 index 0000000000000..e7da7eb6e701a --- /dev/null +++ b/metadata-ingestion/src/datahub/ingestion/source_report/ingestion_stage.py @@ -0,0 +1,41 @@ +import logging +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Optional + +from datahub.utilities.perf_timer import PerfTimer +from datahub.utilities.stats_collections import TopKDict + +logger: logging.Logger = logging.getLogger(__name__) + + +METADATA_EXTRACTION = "Metadata Extraction" +LINEAGE_EXTRACTION = "Lineage Extraction" +USAGE_EXTRACTION_INGESTION = "Usage Extraction Ingestion" +USAGE_EXTRACTION_OPERATIONAL_STATS = "Usage Extraction Operational Stats" +USAGE_EXTRACTION_USAGE_AGGREGATION = "Usage Extraction Usage Aggregation" +PROFILING = "Profiling" + + +@dataclass +class IngestionStageReport: + ingestion_stage: Optional[str] = None + ingestion_stage_durations: TopKDict[str, float] = field(default_factory=TopKDict) + + _timer: Optional[PerfTimer] = field( + default=None, init=False, repr=False, compare=False + ) + + def report_ingestion_stage_start(self, stage: str) -> None: + if self._timer: + elapsed = round(self._timer.elapsed_seconds(), 2) + logger.info( + f"Time spent in stage <{self.ingestion_stage}>: {elapsed} seconds" + ) + if self.ingestion_stage: + self.ingestion_stage_durations[self.ingestion_stage] = elapsed + else: + self._timer = PerfTimer() + + self.ingestion_stage = f"{stage} at {datetime.now(timezone.utc)}" + self._timer.start() diff --git a/metadata-ingestion/src/datahub/ingestion/source_report/sql/__init__.py b/metadata-ingestion/src/datahub/ingestion/source_report/sql/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/metadata-ingestion/src/datahub/ingestion/source_report/sql/snowflake.py b/metadata-ingestion/src/datahub/ingestion/source_report/sql/snowflake.py deleted file mode 100644 index 8ad583686f061..0000000000000 --- a/metadata-ingestion/src/datahub/ingestion/source_report/sql/snowflake.py +++ /dev/null @@ -1,37 +0,0 @@ -from dataclasses import dataclass, field -from datetime import datetime -from typing import Dict, List, Optional - -from datahub.ingestion.source.sql.sql_generic_profiler import ProfilingSqlReport -from datahub.ingestion.source_report.time_window import BaseTimeWindowReport - - -@dataclass -class BaseSnowflakeReport(BaseTimeWindowReport): - pass - - -@dataclass -class SnowflakeReport(BaseSnowflakeReport, ProfilingSqlReport): - num_table_to_table_edges_scanned: int = 0 - num_table_to_view_edges_scanned: int = 0 - num_view_to_table_edges_scanned: int = 0 - num_external_table_edges_scanned: int = 0 - ignore_start_time_lineage: Optional[bool] = None - upstream_lineage_in_report: Optional[bool] = None - upstream_lineage: Dict[str, List[str]] = field(default_factory=dict) - lineage_start_time: Optional[datetime] = None - lineage_end_time: Optional[datetime] = None - - cleaned_account_id: str = "" - run_ingestion: bool = False - - # https://community.snowflake.com/s/topic/0TO0Z000000Unu5WAC/releases - saas_version: Optional[str] = None - default_warehouse: Optional[str] = None - default_db: Optional[str] = None - default_schema: Optional[str] = None - role: str = "" - - 
profile_if_updated_since: Optional[datetime] = None - profile_candidates: Dict[str, List[str]] = field(default_factory=dict) diff --git a/metadata-ingestion/src/datahub/ingestion/source_report/usage/__init__.py b/metadata-ingestion/src/datahub/ingestion/source_report/usage/__init__.py deleted file mode 100644 index e69de29bb2d1d..0000000000000 diff --git a/metadata-ingestion/src/datahub/ingestion/source_report/usage/snowflake_usage.py b/metadata-ingestion/src/datahub/ingestion/source_report/usage/snowflake_usage.py deleted file mode 100644 index 5f7962fc36710..0000000000000 --- a/metadata-ingestion/src/datahub/ingestion/source_report/usage/snowflake_usage.py +++ /dev/null @@ -1,23 +0,0 @@ -from dataclasses import dataclass -from datetime import datetime -from typing import Optional - -from datahub.ingestion.source.state.stateful_ingestion_base import ( - StatefulIngestionReport, -) -from datahub.ingestion.source_report.sql.snowflake import BaseSnowflakeReport - - -@dataclass -class SnowflakeUsageReport(BaseSnowflakeReport, StatefulIngestionReport): - min_access_history_time: Optional[datetime] = None - max_access_history_time: Optional[datetime] = None - access_history_range_query_secs: float = -1 - access_history_query_secs: float = -1 - - rows_processed: int = 0 - rows_missing_query_text: int = 0 - rows_zero_base_objects_accessed: int = 0 - rows_zero_direct_objects_accessed: int = 0 - rows_missing_email: int = 0 - rows_parsing_error: int = 0 diff --git a/metadata-ingestion/src/datahub/utilities/time.py b/metadata-ingestion/src/datahub/utilities/time.py index d9e643b6bccc2..0df7afb19935f 100644 --- a/metadata-ingestion/src/datahub/utilities/time.py +++ b/metadata-ingestion/src/datahub/utilities/time.py @@ -1,4 +1,5 @@ import time +from dataclasses import dataclass from datetime import datetime, timezone @@ -6,9 +7,37 @@ def get_current_time_in_seconds() -> int: return int(time.time()) -def get_datetime_from_ts_millis_in_utc(ts_millis: int) -> datetime: +def ts_millis_to_datetime(ts_millis: int) -> datetime: + """Converts input timestamp in milliseconds to a datetime object with UTC timezone""" return datetime.fromtimestamp(ts_millis / 1000, tz=timezone.utc) def datetime_to_ts_millis(dt: datetime) -> int: + """Converts a datetime object to timestamp in milliseconds""" return int(round(dt.timestamp() * 1000)) + + +@dataclass +class TimeWindow: + start_time: datetime + end_time: datetime + + def contains(self, other: "TimeWindow") -> bool: + """Whether current window contains other window completely""" + return self.start_time <= other.start_time <= other.end_time <= self.end_time + + def left_intersects(self, other: "TimeWindow") -> bool: + """Whether only left part of current window overlaps other window.""" + return other.start_time <= self.start_time < other.end_time < self.end_time + + def right_intersects(self, other: "TimeWindow") -> bool: + """Whether only right part of current window overlaps other window.""" + return self.start_time < other.start_time < self.end_time <= other.end_time + + def starts_after(self, other: "TimeWindow") -> bool: + """Whether current window starts after other window ends""" + return other.start_time <= other.end_time < self.start_time + + def ends_after(self, other: "TimeWindow") -> bool: + """Whether current window ends after other window ends.""" + return self.end_time > other.end_time diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake.py index 
6135b0b3b3274..dec50aefd19f0 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake.py @@ -124,7 +124,7 @@ def test_snowflake_basic(pytestconfig, tmp_path, mock_time, mock_datahub_graph): validate_upstreams_against_patterns=False, include_operational_stats=True, email_as_user_identifier=True, - start_time=datetime(2022, 6, 6, 7, 17, 0, 0).replace( + start_time=datetime(2022, 6, 6, 0, 0, 0, 0).replace( tzinfo=timezone.utc ), end_time=datetime(2022, 6, 7, 7, 17, 0, 0).replace( @@ -214,7 +214,7 @@ def test_snowflake_private_link(pytestconfig, tmp_path, mock_time, mock_datahub_ include_view_lineage=False, include_usage_stats=False, include_operational_stats=False, - start_time=datetime(2022, 6, 6, 7, 17, 0, 0).replace( + start_time=datetime(2022, 6, 6, 0, 0, 0, 0).replace( tzinfo=timezone.utc ), end_time=datetime(2022, 6, 7, 7, 17, 0, 0).replace( diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py index 4963e71ae4d96..bba53c1e97a47 100644 --- a/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_failures.py @@ -55,7 +55,7 @@ def snowflake_pipeline_config(tmp_path): schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]), include_view_lineage=False, include_usage_stats=False, - start_time=datetime(2022, 6, 6, 7, 17, 0, 0).replace( + start_time=datetime(2022, 6, 6, 0, 0, 0, 0).replace( tzinfo=timezone.utc ), end_time=datetime(2022, 6, 7, 7, 17, 0, 0).replace(tzinfo=timezone.utc), diff --git a/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py b/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py new file mode 100644 index 0000000000000..f72bd5b72d2cd --- /dev/null +++ b/metadata-ingestion/tests/integration/snowflake/test_snowflake_stateful.py @@ -0,0 +1,119 @@ +from unittest import mock + +from freezegun import freeze_time + +from datahub.configuration.common import AllowDenyPattern, DynamicTypedConfig +from datahub.ingestion.run.pipeline import Pipeline +from datahub.ingestion.run.pipeline_config import PipelineConfig, SourceConfig +from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config +from datahub.ingestion.source.state.stale_entity_removal_handler import ( + StatefulStaleMetadataRemovalConfig, +) +from tests.integration.snowflake.common import FROZEN_TIME, default_query_results +from tests.test_helpers.state_helpers import ( + get_current_checkpoint_from_pipeline, + validate_all_providers_have_committed_successfully, +) + +GMS_PORT = 8080 +GMS_SERVER = f"http://localhost:{GMS_PORT}" + + +def stateful_pipeline_config(include_tables: bool) -> PipelineConfig: + return PipelineConfig( + pipeline_name="test_snowflake", + source=SourceConfig( + type="snowflake", + config=SnowflakeV2Config( + account_id="ABC12345.ap-south-1.aws", + username="TST_USR", + password="TST_PWD", + match_fully_qualified_names=True, + schema_pattern=AllowDenyPattern(allow=["test_db.test_schema"]), + include_tables=include_tables, + stateful_ingestion=StatefulStaleMetadataRemovalConfig.parse_obj( + { + "enabled": True, + "remove_stale_metadata": True, + "fail_safe_threshold": 100.0, + "state_provider": { + "type": "datahub", + "config": {"datahub_api": {"server": GMS_SERVER}}, + }, + } + ), + ), + ), + sink=DynamicTypedConfig(type="blackhole"), + ) + + 
+@freeze_time(FROZEN_TIME)
+def test_snowflake_stateful(mock_datahub_graph):
+    with mock.patch(
+        "datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider.DataHubGraph",
+        mock_datahub_graph,
+    ) as mock_checkpoint, mock.patch("snowflake.connector.connect") as mock_connect:
+        sf_connection = mock.MagicMock()
+        sf_cursor = mock.MagicMock()
+        mock_connect.return_value = sf_connection
+        sf_connection.cursor.return_value = sf_cursor
+
+        sf_cursor.execute.side_effect = default_query_results
+        mock_checkpoint.return_value = mock_datahub_graph
+        pipeline_run1 = Pipeline(config=stateful_pipeline_config(True))
+        pipeline_run1.run()
+        pipeline_run1.raise_from_status()
+        checkpoint1 = get_current_checkpoint_from_pipeline(pipeline_run1)
+
+        assert checkpoint1
+        assert checkpoint1.state
+
+    with mock.patch(
+        "datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider.DataHubGraph",
+        mock_datahub_graph,
+    ) as mock_checkpoint, mock.patch("snowflake.connector.connect") as mock_connect:
+        sf_connection = mock.MagicMock()
+        sf_cursor = mock.MagicMock()
+        mock_connect.return_value = sf_connection
+        sf_connection.cursor.return_value = sf_cursor
+
+        sf_cursor.execute.side_effect = default_query_results
+
+        mock_checkpoint.return_value = mock_datahub_graph
+        pipeline_run2 = Pipeline(config=stateful_pipeline_config(False))
+        pipeline_run2.run()
+        pipeline_run2.raise_from_status()
+        checkpoint2 = get_current_checkpoint_from_pipeline(pipeline_run2)
+
+        assert checkpoint2
+        assert checkpoint2.state
+
+    # Validate that all providers have committed successfully.
+    validate_all_providers_have_committed_successfully(
+        pipeline=pipeline_run1, expected_providers=1
+    )
+    validate_all_providers_have_committed_successfully(
+        pipeline=pipeline_run2, expected_providers=1
+    )
+
+    # Perform all assertions on the states. 
The deleted table should not be + # part of the second state + state1 = checkpoint1.state + state2 = checkpoint2.state + + difference_dataset_urns = list( + state1.get_urns_not_in(type="dataset", other_checkpoint_state=state2) + ) + assert sorted(difference_dataset_urns) == [ + "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_1,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_10,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_2,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_3,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_4,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_5,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_6,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_7,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_8,PROD)", + "urn:li:dataset:(urn:li:dataPlatform:snowflake,test_db.test_schema.table_9,PROD)", + ] diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_datahub_ingestion_checkpointing_provider.py b/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_datahub_ingestion_checkpointing_provider.py index 65a348026e852..600985266043b 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_datahub_ingestion_checkpointing_provider.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/provider/test_datahub_ingestion_checkpointing_provider.py @@ -15,7 +15,9 @@ from datahub.ingestion.source.state.sql_common_state import ( BaseSQLAlchemyCheckpointState, ) -from datahub.ingestion.source.state.usage_common_state import BaseUsageCheckpointState +from datahub.ingestion.source.state.usage_common_state import ( + BaseTimeWindowCheckpointState, +) from datahub.ingestion.source.state_provider.datahub_ingestion_checkpointing_provider import ( DatahubIngestionCheckpointingProvider, ) @@ -113,8 +115,8 @@ def test_provider(self): run_id=self.run_id, state=job1_state_obj, ) - # Job2 - Checkpoint with a BaseUsageCheckpointState state - job2_state_obj = BaseUsageCheckpointState( + # Job2 - Checkpoint with a BaseTimeWindowCheckpointState state + job2_state_obj = BaseTimeWindowCheckpointState( begin_timestamp_millis=10, end_timestamp_millis=100 ) job2_checkpoint = Checkpoint( diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py index 51e2b0795819a..532ab69d1c6b1 100644 --- a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_checkpoint.py @@ -9,7 +9,9 @@ from datahub.ingestion.source.state.sql_common_state import ( BaseSQLAlchemyCheckpointState, ) -from datahub.ingestion.source.state.usage_common_state import BaseUsageCheckpointState +from datahub.ingestion.source.state.usage_common_state import ( + BaseTimeWindowCheckpointState, +) from datahub.metadata.schema_classes import ( DatahubIngestionCheckpointClass, IngestionCheckpointStateClass, @@ -67,8 +69,8 @@ def _make_sql_alchemy_checkpoint_state() -> BaseSQLAlchemyCheckpointState: return base_sql_alchemy_checkpoint_state_obj -def _make_usage_checkpoint_state() -> BaseUsageCheckpointState: - base_usage_checkpoint_state_obj = BaseUsageCheckpointState( +def _make_usage_checkpoint_state() -> 
BaseTimeWindowCheckpointState: + base_usage_checkpoint_state_obj = BaseTimeWindowCheckpointState( version="2.0", begin_timestamp_millis=1, end_timestamp_millis=100 ) return base_usage_checkpoint_state_obj @@ -77,8 +79,8 @@ def _make_usage_checkpoint_state() -> BaseUsageCheckpointState: _checkpoint_aspect_test_cases: Dict[str, CheckpointStateBase] = { # An instance of BaseSQLAlchemyCheckpointState. "BaseSQLAlchemyCheckpointState": _make_sql_alchemy_checkpoint_state(), - # An instance of BaseUsageCheckpointState. - "BaseUsageCheckpointState": _make_usage_checkpoint_state(), + # An instance of BaseTimeWindowCheckpointState. + "BaseTimeWindowCheckpointState": _make_usage_checkpoint_state(), } @@ -141,7 +143,7 @@ def test_supported_encodings(): """ Tests utf-8 and base85-bz2-json encodings """ - test_state = BaseUsageCheckpointState( + test_state = BaseTimeWindowCheckpointState( version="1.0", begin_timestamp_millis=1, end_timestamp_millis=100 ) diff --git a/metadata-ingestion/tests/unit/stateful_ingestion/state/test_redundant_run_skip_handler.py b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_redundant_run_skip_handler.py new file mode 100644 index 0000000000000..0400bd6a72aa5 --- /dev/null +++ b/metadata-ingestion/tests/unit/stateful_ingestion/state/test_redundant_run_skip_handler.py @@ -0,0 +1,273 @@ +from datetime import datetime, timezone +from unittest import mock + +import pytest + +from datahub.configuration.time_window_config import BucketDuration, get_time_bucket +from datahub.ingestion.api.common import PipelineContext +from datahub.ingestion.graph.client import DataHubGraph +from datahub.ingestion.source.snowflake.snowflake_config import SnowflakeV2Config +from datahub.ingestion.source.snowflake.snowflake_v2 import SnowflakeV2Source +from datahub.ingestion.source.state.stale_entity_removal_handler import ( + StatefulStaleMetadataRemovalConfig, +) +from datahub.ingestion.source.state.stateful_ingestion_base import ( + DynamicTypedStateProviderConfig, +) +from datahub.ingestion.source.state.usage_common_state import ( + BaseTimeWindowCheckpointState, +) +from datahub.utilities.time import datetime_to_ts_millis + +GMS_PORT = 8080 +GMS_SERVER = f"http://localhost:{GMS_PORT}" + + +@pytest.fixture +def stateful_source(mock_datahub_graph: DataHubGraph) -> SnowflakeV2Source: + pipeline_name = "test_redundant_run_lineage" + run_id = "test_redundant_run" + ctx = PipelineContext( + pipeline_name=pipeline_name, + run_id=run_id, + graph=mock_datahub_graph, + ) + config = SnowflakeV2Config( + account_id="ABC12345.ap-south-1", + username="TST_USR", + password="TST_PWD", + stateful_ingestion=StatefulStaleMetadataRemovalConfig( + enabled=True, + state_provider=DynamicTypedStateProviderConfig( + type="datahub", config={"datahub_api": {"server": GMS_SERVER}} + ), + ), + ) + source = SnowflakeV2Source(ctx=ctx, config=config) + return source + + +def test_redundant_run_job_ids(stateful_source: SnowflakeV2Source) -> None: + assert stateful_source.lineage_extractor is not None + assert stateful_source.lineage_extractor.redundant_run_skip_handler is not None + assert ( + stateful_source.lineage_extractor.redundant_run_skip_handler.job_id + == "Snowflake_skip_redundant_run_lineage" + ) + + assert stateful_source.usage_extractor is not None + assert stateful_source.usage_extractor.redundant_run_skip_handler is not None + assert ( + stateful_source.usage_extractor.redundant_run_skip_handler.job_id + == "Snowflake_skip_redundant_run_usage" + ) + + +# last run +last_run_start_time = datetime(2023, 
7, 2, tzinfo=timezone.utc) +last_run_end_time = datetime(2023, 7, 3, 12, tzinfo=timezone.utc) + + +@pytest.mark.parametrize( + "start_time,end_time,should_skip,suggested_start_time,suggested_end_time", + [ + # Case = current run time window is same as of last run time window + [ + datetime(2023, 7, 2, tzinfo=timezone.utc), + datetime(2023, 7, 3, 12, tzinfo=timezone.utc), + True, + None, + None, + ], + # Case = current run time window is starts at same time as of last run time window but ends later + [ + datetime(2023, 7, 2, tzinfo=timezone.utc), + datetime(2023, 7, 3, 18, tzinfo=timezone.utc), + False, + datetime(2023, 7, 3, 12, tzinfo=timezone.utc), + datetime(2023, 7, 3, 18, tzinfo=timezone.utc), + ], + # Case = current run time window is subset of last run time window + [ + datetime(2023, 7, 2, tzinfo=timezone.utc), + datetime(2023, 7, 3, tzinfo=timezone.utc), + True, + None, + None, + ], + # Case = current run time window is after last run time window but has some overlap with last run + # Scenario for next day's run for scheduled daily ingestions + [ + datetime(2023, 7, 3, tzinfo=timezone.utc), + datetime(2023, 7, 4, 12, tzinfo=timezone.utc), + False, + datetime(2023, 7, 3, 12, tzinfo=timezone.utc), + datetime(2023, 7, 4, 12, tzinfo=timezone.utc), + ], + # Case = current run time window is after last run time window and has no overlap with last run + [ + datetime(2023, 7, 5, tzinfo=timezone.utc), + datetime(2023, 7, 7, 12, tzinfo=timezone.utc), + False, + datetime(2023, 7, 5, tzinfo=timezone.utc), + datetime(2023, 7, 7, 12, tzinfo=timezone.utc), + ], + # Case = current run time window is before last run time window but has some overlap with last run + # Scenario for manual run for past dates + [ + datetime(2023, 6, 30, tzinfo=timezone.utc), + datetime(2023, 7, 2, 12, tzinfo=timezone.utc), + False, + datetime(2023, 6, 30, tzinfo=timezone.utc), + datetime(2023, 7, 2, tzinfo=timezone.utc), + ], + # Case = current run time window starts before last run time window and ends exactly on last run end time + # Scenario for manual run for past dates + [ + datetime(2023, 6, 30, tzinfo=timezone.utc), + datetime(2023, 7, 3, 12, tzinfo=timezone.utc), + False, + datetime(2023, 6, 30, tzinfo=timezone.utc), + datetime(2023, 7, 2, tzinfo=timezone.utc), + ], + # Case = current run time window is before last run time window and has no overlap with last run + # Scenario for manual run for past dates + [ + datetime(2023, 6, 20, tzinfo=timezone.utc), + datetime(2023, 6, 30, tzinfo=timezone.utc), + False, + datetime(2023, 6, 20, tzinfo=timezone.utc), + datetime(2023, 6, 30, tzinfo=timezone.utc), + ], + # Case = current run time window subsumes last run time window and extends on both sides + # Scenario for manual run + [ + datetime(2023, 6, 20, tzinfo=timezone.utc), + datetime(2023, 7, 20, tzinfo=timezone.utc), + False, + datetime(2023, 6, 20, tzinfo=timezone.utc), + datetime(2023, 7, 20, tzinfo=timezone.utc), + ], + ], +) +def test_redundant_run_skip_handler( + stateful_source: SnowflakeV2Source, + start_time: datetime, + end_time: datetime, + should_skip: bool, + suggested_start_time: datetime, + suggested_end_time: datetime, +) -> None: + # mock_datahub_graph + + # mocked_source = mock.MagicMock() + # mocked_config = mock.MagicMock() + + with mock.patch( + "datahub.ingestion.source.state.stateful_ingestion_base.StateProviderWrapper.get_last_checkpoint" + ) as mocked_fn: + set_mock_last_run_time_window( + mocked_fn, + last_run_start_time, + last_run_end_time, + ) + + # Redundant Lineage Skip Handler + 
assert stateful_source.lineage_extractor is not None + assert stateful_source.lineage_extractor.redundant_run_skip_handler is not None + assert ( + stateful_source.lineage_extractor.redundant_run_skip_handler.should_skip_this_run( + start_time, end_time + ) + == should_skip + ) + + if not should_skip: + suggested_time_window = stateful_source.lineage_extractor.redundant_run_skip_handler.suggest_run_time_window( + start_time, end_time + ) + assert suggested_time_window == (suggested_start_time, suggested_end_time) + + set_mock_last_run_time_window_usage( + mocked_fn, last_run_start_time, last_run_end_time + ) + # Redundant Usage Skip Handler + assert stateful_source.usage_extractor is not None + assert stateful_source.usage_extractor.redundant_run_skip_handler is not None + assert ( + stateful_source.usage_extractor.redundant_run_skip_handler.should_skip_this_run( + start_time, end_time + ) + == should_skip + ) + + if not should_skip: + suggested_time_window = stateful_source.usage_extractor.redundant_run_skip_handler.suggest_run_time_window( + start_time, end_time + ) + assert suggested_time_window == ( + get_time_bucket(suggested_start_time, BucketDuration.DAY), + suggested_end_time, + ) + + +def set_mock_last_run_time_window(mocked_fn, start_time, end_time): + mock_checkpoint = mock.MagicMock() + mock_checkpoint.state = BaseTimeWindowCheckpointState( + begin_timestamp_millis=datetime_to_ts_millis(start_time), + end_timestamp_millis=datetime_to_ts_millis(end_time), + ) + mocked_fn.return_value = mock_checkpoint + + +def set_mock_last_run_time_window_usage(mocked_fn, start_time, end_time): + mock_checkpoint = mock.MagicMock() + mock_checkpoint.state = BaseTimeWindowCheckpointState( + begin_timestamp_millis=datetime_to_ts_millis(start_time), + end_timestamp_millis=datetime_to_ts_millis(end_time), + bucket_duration=BucketDuration.DAY, + ) + mocked_fn.return_value = mock_checkpoint + + +def test_successful_run_creates_checkpoint(stateful_source: SnowflakeV2Source) -> None: + assert stateful_source.lineage_extractor is not None + assert stateful_source.lineage_extractor.redundant_run_skip_handler is not None + with mock.patch( + "datahub.ingestion.source.state.stateful_ingestion_base.StateProviderWrapper.create_checkpoint" + ) as mocked_create_checkpoint_fn, mock.patch( + "datahub.ingestion.source.state.stateful_ingestion_base.StateProviderWrapper.get_last_checkpoint" + ) as mocked_fn: + set_mock_last_run_time_window( + mocked_fn, + last_run_start_time, + last_run_end_time, + ) + stateful_source.lineage_extractor.redundant_run_skip_handler.update_state( + datetime.now(tz=timezone.utc), datetime.now(tz=timezone.utc) + ) + mocked_create_checkpoint_fn.assert_called_once() + + +def test_failed_run_does_not_create_checkpoint( + stateful_source: SnowflakeV2Source, +) -> None: + assert stateful_source.lineage_extractor is not None + assert stateful_source.lineage_extractor.redundant_run_skip_handler is not None + stateful_source.lineage_extractor.redundant_run_skip_handler.report_current_run_status( + "some_step", False + ) + with mock.patch( + "datahub.ingestion.source.state.stateful_ingestion_base.StateProviderWrapper.create_checkpoint" + ) as mocked_create_checkpoint_fn, mock.patch( + "datahub.ingestion.source.state.stateful_ingestion_base.StateProviderWrapper.get_last_checkpoint" + ) as mocked_fn: + set_mock_last_run_time_window( + mocked_fn, + last_run_start_time, + last_run_end_time, + ) + stateful_source.lineage_extractor.redundant_run_skip_handler.update_state( + 
datetime.now(tz=timezone.utc), datetime.now(tz=timezone.utc) + ) + mocked_create_checkpoint_fn.assert_not_called() diff --git a/metadata-ingestion/tests/unit/test_base_usage_config.py b/metadata-ingestion/tests/unit/test_base_usage_config.py deleted file mode 100644 index 008dcf25e38e4..0000000000000 --- a/metadata-ingestion/tests/unit/test_base_usage_config.py +++ /dev/null @@ -1,34 +0,0 @@ -from datetime import datetime, timezone - -from freezegun import freeze_time - -from datahub.ingestion.source.usage.usage_common import BaseUsageConfig - -FROZEN_TIME = "2023-08-03 09:00:00" -FROZEN_TIME2 = "2023-08-03 09:10:00" - - -@freeze_time(FROZEN_TIME) -def test_relative_start_time_aligns_with_bucket_start_time(): - config = BaseUsageConfig.parse_obj( - {"start_time": "-2 days", "end_time": "2023-07-07T09:00:00Z"} - ) - assert config.start_time == datetime(2023, 7, 5, 0, tzinfo=timezone.utc) - assert config.end_time == datetime(2023, 7, 7, 9, tzinfo=timezone.utc) - - config = BaseUsageConfig.parse_obj( - {"start_time": "-2 days", "end_time": "2023-07-07T09:00:00Z"} - ) - assert config.start_time == datetime(2023, 7, 5, 0, tzinfo=timezone.utc) - assert config.end_time == datetime(2023, 7, 7, 9, tzinfo=timezone.utc) - - -@freeze_time(FROZEN_TIME) -def test_absolute_start_time_aligns_with_bucket_start_time(): - config = BaseUsageConfig.parse_obj({"start_time": "2023-07-01T00:00:00Z"}) - assert config.start_time == datetime(2023, 7, 1, 0, tzinfo=timezone.utc) - assert config.end_time == datetime(2023, 8, 3, 9, tzinfo=timezone.utc) - - config = BaseUsageConfig.parse_obj({"start_time": "2023-07-01T09:00:00Z"}) - assert config.start_time == datetime(2023, 7, 1, 0, tzinfo=timezone.utc) - assert config.end_time == datetime(2023, 8, 3, 9, tzinfo=timezone.utc) diff --git a/metadata-ingestion/tests/unit/test_time_window_config.py b/metadata-ingestion/tests/unit/test_time_window_config.py index 127dc179c21e7..847bda2511a0c 100644 --- a/metadata-ingestion/tests/unit/test_time_window_config.py +++ b/metadata-ingestion/tests/unit/test_time_window_config.py @@ -26,23 +26,23 @@ def test_default_start_end_time_hour_bucket_duration(): @freeze_time(FROZEN_TIME) def test_relative_start_time(): config = BaseTimeWindowConfig.parse_obj({"start_time": "-2 days"}) - assert config.start_time == datetime(2023, 8, 1, 9, tzinfo=timezone.utc) + assert config.start_time == datetime(2023, 8, 1, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 8, 3, 9, tzinfo=timezone.utc) config = BaseTimeWindowConfig.parse_obj({"start_time": "-2d"}) - assert config.start_time == datetime(2023, 8, 1, 9, tzinfo=timezone.utc) + assert config.start_time == datetime(2023, 8, 1, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 8, 3, 9, tzinfo=timezone.utc) config = BaseTimeWindowConfig.parse_obj( {"start_time": "-2 days", "end_time": "2023-07-07T09:00:00Z"} ) - assert config.start_time == datetime(2023, 7, 5, 9, tzinfo=timezone.utc) + assert config.start_time == datetime(2023, 7, 5, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 7, 7, 9, tzinfo=timezone.utc) config = BaseTimeWindowConfig.parse_obj( {"start_time": "-2 days", "end_time": "2023-07-07T09:00:00Z"} ) - assert config.start_time == datetime(2023, 7, 5, 9, tzinfo=timezone.utc) + assert config.start_time == datetime(2023, 7, 5, 0, tzinfo=timezone.utc) assert config.end_time == datetime(2023, 7, 7, 9, tzinfo=timezone.utc)