From fa0709b9a164e66a2af74a1829ac0dc916a0205d Mon Sep 17 00:00:00 2001 From: "Wei-Chun, Chang" Date: Wed, 30 Aug 2023 12:36:40 +0800 Subject: [PATCH] Add schema version check and improve time grains logic Signed-off-by: Wei-Chun, Chang --- piperider_cli/dbtutil.py | 59 ++++++++++++++++++++++++++++++++++++---- tests/test_dbt_util.py | 16 +++++++++++ 2 files changed, 70 insertions(+), 5 deletions(-) diff --git a/piperider_cli/dbtutil.py b/piperider_cli/dbtutil.py index 27e58bb67..397b524e5 100644 --- a/piperider_cli/dbtutil.py +++ b/piperider_cli/dbtutil.py @@ -351,9 +351,54 @@ def is_chosen(key, metric): return metrics +def is_dbt_schema_version_16(manifest: Dict): + # dbt_schema_version: 'https://schemas.getdbt.com/dbt/manifest/v10.json' + schema_version = manifest['metadata'].get('dbt_schema_version').split('/')[-1] + version = schema_version.split('.')[0] + return int(version[1:]) >= 10 + + +def get_support_time_grains(grain: str): + all_time_grains = ['day', 'week', 'month', 'quarter', 'year'] + available_time_grains = all_time_grains[all_time_grains.index(grain):] + support_time_grains = ['day', 'month', 'year'] + + return [x for x in support_time_grains if x in available_time_grains] + + +def find_derived_time_grains(manifest: Dict, metric: Dict): + nodes = metric.get('depends_on').get('nodes', []) + depends_on = nodes[0] + if depends_on.startswith('semantic_model.'): + semantic_model = manifest.get('semantic_models').get(depends_on) + measure = None + for obj in semantic_model.get('measures'): + if obj.get('name') == metric.get('type_params').get('measure').get('name'): + measure = obj + break + + agg_time_dimension = measure.get('agg_time_dimension') if measure.get( + 'agg_time_dimension') else semantic_model.get('defaults').get('agg_time_dimension') + + time_grains = None + if agg_time_dimension is not None: + # find dimension definition - time + for obj in semantic_model.get('dimensions'): + if obj.get('name') == agg_time_dimension: + grain = obj.get('type_params').get('time_granularity') + time_grains = get_support_time_grains(grain) + break + + return time_grains + + def get_dbt_state_metrics_16(dbt_state_dir: str, dbt_tag: str, dbt_resources: Optional[dict] = None): manifest = _get_state_manifest(dbt_state_dir) + if not is_dbt_schema_version_16(manifest): + console.print('Metric is not supported for dbt version < 0.16') + return [] + def is_chosen(key, metric): statistics = Statistics() if dbt_resources: @@ -417,15 +462,14 @@ def _create_metric(name, filter=None, alias=None): agg_time_dimension = measure.get('agg_time_dimension') if measure.get( 'agg_time_dimension') else semantic_model.get('defaults').get('agg_time_dimension') timestamp = None - time_grain = None + time_grains = None if agg_time_dimension is not None: # find dimension definition - time for obj in semantic_model.get('dimensions'): if obj.get('name') == agg_time_dimension: timestamp = obj.get('expr') grain = obj.get('type_params').get('time_granularity') - time_grain = ['day', 'week', 'month', 'quarter', 'year'] - time_grain = time_grain[time_grain.index(grain):] + time_grains = get_support_time_grains(grain) break if metric.get('filter') is not None: @@ -442,7 +486,7 @@ def _create_metric(name, filter=None, alias=None): model = SemanticModel(metric.get('name'), table, schema, database, expression, timestamp, filters=metric_filter) - m = Metric(metric.get('name'), model=model, calculation_method=calculation_method, time_grains=time_grain, + m = Metric(metric.get('name'), model=model, calculation_method=calculation_method, time_grains=time_grains, label=metric.get('label'), description=metric.get('description'), ref_id=metric.get('unique_id')) @@ -468,6 +512,7 @@ def _create_metric(name, filter=None, alias=None): return m elif metric.get('type') == 'derived': ref_metrics = [] + time_grains = ['day', 'month', 'year'] for ref_metric in metric.get('type_params').get('metrics'): m2 = _create_metric( ref_metric.get('name'), @@ -475,10 +520,14 @@ def _create_metric(name, filter=None, alias=None): alias=ref_metric.get('alias')) ref_metrics.append(m2) + derived_time_grains = find_derived_time_grains(manifest, metric_map[ref_metric.get('name')]) + if len(time_grains) < len(derived_time_grains): + time_grains = derived_time_grains + m = Metric(metric.get('name'), calculation_method='derived', expression=metric.get('type_params').get('expr'), - time_grains=['day', 'month', 'year'], + time_grains=time_grains, label=metric.get('label'), description=metric.get('description'), ref_metrics=ref_metrics, ref_id=metric.get('unique_id')) return m diff --git a/tests/test_dbt_util.py b/tests/test_dbt_util.py index 5a5931ffe..8855a0834 100644 --- a/tests/test_dbt_util.py +++ b/tests/test_dbt_util.py @@ -265,3 +265,19 @@ def test_load_dbt_resources(self, get_dbt_manifest): resources = dbtutil.load_dbt_resources(target_path) self.assertIn('models', resources) self.assertIn('metrics', resources) + + def test_get_support_time_grain(self): + time_grains = dbtutil.get_support_time_grains('day') + self.assertListEqual(time_grains, ['day', 'month', 'year']) + + time_grains = dbtutil.get_support_time_grains('week') + self.assertListEqual(time_grains, ['month', 'year']) + + time_grains = dbtutil.get_support_time_grains('month') + self.assertListEqual(time_grains, ['month', 'year']) + + time_grains = dbtutil.get_support_time_grains('quarter') + self.assertListEqual(time_grains, ['year']) + + time_grains = dbtutil.get_support_time_grains('year') + self.assertListEqual(time_grains, ['year'])