Skip to content

Commit

Permalink
Add schema version check and improve time grains logic
Browse files Browse the repository at this point in the history
Signed-off-by: Wei-Chun, Chang <[email protected]>
  • Loading branch information
wcchang1115 committed Aug 30, 2023
1 parent 624a448 commit fa0709b
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 5 deletions.
59 changes: 54 additions & 5 deletions piperider_cli/dbtutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,9 +351,54 @@ def is_chosen(key, metric):
return metrics


def is_dbt_schema_version_16(manifest: Dict):
    """Return True when the manifest schema version is v10 or newer.

    The version is read from ``manifest['metadata']['dbt_schema_version']``,
    a URL such as ``'https://schemas.getdbt.com/dbt/manifest/v10.json'``.
    NOTE(review): the function name suggests manifest v10 corresponds to
    dbt 1.6 -- confirm against the dbt release notes.

    :param manifest: parsed dbt ``manifest.json`` content.
    :return: ``True`` if the schema version number is >= 10. ``False`` if it
        is older, or if the version is missing or cannot be parsed (such a
        manifest is treated as unsupported instead of raising).
    """
    metadata = (manifest or {}).get('metadata') or {}
    schema_url = metadata.get('dbt_schema_version')
    if not isinstance(schema_url, str):
        return False

    # '.../manifest/v10.json' -> 'v10.json' -> 'v10'
    version = schema_url.rsplit('/', 1)[-1].split('.')[0]
    try:
        # Drop the leading 'v' and compare numerically.
        return int(version[1:]) >= 10
    except ValueError:
        return False


def get_support_time_grains(grain: str):
    """Return the supported time grains that are not finer than *grain*.

    Only ``'day'``, ``'month'`` and ``'year'`` are supported. A metric whose
    time dimension has granularity *grain* can only be rolled up to grains at
    least as coarse, so anything finer than *grain* is dropped.

    :param grain: base granularity of the time dimension, e.g. ``'day'``,
        ``'week'`` or ``'quarter'``. Matching is case-insensitive. Sub-day
        granularities (``'hour'``, ``'second'``, ...) are accepted and yield
        every supported grain.
    :return: subset of ``['day', 'month', 'year']`` in that order; an empty
        list when *grain* is ``None`` or not a recognised granularity.
    """
    # Ordered from finest to coarsest; slicing from the position of *grain*
    # gives every granularity it can be aggregated to.
    all_time_grains = [
        'nanosecond', 'microsecond', 'millisecond', 'second', 'minute', 'hour',
        'day', 'week', 'month', 'quarter', 'year',
    ]
    support_time_grains = ['day', 'month', 'year']

    if not isinstance(grain, str):
        return []
    try:
        start = all_time_grains.index(grain.strip().lower())
    except ValueError:
        # Unknown granularity: nothing can be derived safely.
        return []

    available_time_grains = set(all_time_grains[start:])
    return [x for x in support_time_grains if x in available_time_grains]


def find_derived_time_grains(manifest: Dict, metric: Dict):
    """Find the supported time grains of the measure behind *metric*.

    Looks at the first node *metric* depends on. When it is a semantic model,
    the metric's measure is located in that model, its aggregation time
    dimension is resolved (falling back to the model's default), and the
    granularity of that dimension is mapped through
    ``get_support_time_grains``.

    :param manifest: parsed dbt manifest; ``semantic_models`` is read from it.
    :param metric: metric definition from the manifest; ``depends_on.nodes``
        and ``type_params.measure.name`` are read from it.
    :return: list of supported time grains, or ``None`` when they cannot be
        determined (no dependency, dependency is not a semantic model,
        semantic model not in the manifest, or no matching time dimension).
    """
    nodes = (metric.get('depends_on') or {}).get('nodes') or []
    if not nodes:
        return None

    depends_on = nodes[0]
    if not depends_on.startswith('semantic_model.'):
        return None

    semantic_model = (manifest.get('semantic_models') or {}).get(depends_on)
    if semantic_model is None:
        return None

    # A metric without a direct measure (e.g. one built on other metrics)
    # has no measure name to look up; fall through to the model default.
    measure_ref = (metric.get('type_params') or {}).get('measure') or {}
    measure_name = measure_ref.get('name')
    measure = next(
        (obj for obj in semantic_model.get('measures') or [] if obj.get('name') == measure_name),
        None,
    )

    # The measure-level setting wins; otherwise use the model-wide default.
    agg_time_dimension = measure.get('agg_time_dimension') if measure else None
    if not agg_time_dimension:
        agg_time_dimension = (semantic_model.get('defaults') or {}).get('agg_time_dimension')
    if agg_time_dimension is None:
        return None

    # find dimension definition - time
    for obj in semantic_model.get('dimensions') or []:
        if obj.get('name') == agg_time_dimension:
            grain = (obj.get('type_params') or {}).get('time_granularity')
            return get_support_time_grains(grain)

    return None


def get_dbt_state_metrics_16(dbt_state_dir: str, dbt_tag: str, dbt_resources: Optional[dict] = None):
manifest = _get_state_manifest(dbt_state_dir)

if not is_dbt_schema_version_16(manifest):
console.print('Metric is not supported for dbt version < 0.16')
return []

def is_chosen(key, metric):
statistics = Statistics()
if dbt_resources:
Expand Down Expand Up @@ -417,15 +462,14 @@ def _create_metric(name, filter=None, alias=None):
agg_time_dimension = measure.get('agg_time_dimension') if measure.get(
'agg_time_dimension') else semantic_model.get('defaults').get('agg_time_dimension')
timestamp = None
time_grain = None
time_grains = None
if agg_time_dimension is not None:
# find dimension definition - time
for obj in semantic_model.get('dimensions'):
if obj.get('name') == agg_time_dimension:
timestamp = obj.get('expr')
grain = obj.get('type_params').get('time_granularity')
time_grain = ['day', 'week', 'month', 'quarter', 'year']
time_grain = time_grain[time_grain.index(grain):]
time_grains = get_support_time_grains(grain)
break

if metric.get('filter') is not None:
Expand All @@ -442,7 +486,7 @@ def _create_metric(name, filter=None, alias=None):
model = SemanticModel(metric.get('name'), table, schema, database, expression, timestamp,
filters=metric_filter)

m = Metric(metric.get('name'), model=model, calculation_method=calculation_method, time_grains=time_grain,
m = Metric(metric.get('name'), model=model, calculation_method=calculation_method, time_grains=time_grains,
label=metric.get('label'), description=metric.get('description'),
ref_id=metric.get('unique_id'))

Expand All @@ -468,17 +512,22 @@ def _create_metric(name, filter=None, alias=None):
return m
elif metric.get('type') == 'derived':
ref_metrics = []
time_grains = ['day', 'month', 'year']
for ref_metric in metric.get('type_params').get('metrics'):
m2 = _create_metric(
ref_metric.get('name'),
filter=ref_metric.get('filter'),
alias=ref_metric.get('alias'))
ref_metrics.append(m2)

derived_time_grains = find_derived_time_grains(manifest, metric_map[ref_metric.get('name')])
if len(time_grains) < len(derived_time_grains):
time_grains = derived_time_grains

m = Metric(metric.get('name'),
calculation_method='derived',
expression=metric.get('type_params').get('expr'),
time_grains=['day', 'month', 'year'],
time_grains=time_grains,
label=metric.get('label'), description=metric.get('description'), ref_metrics=ref_metrics,
ref_id=metric.get('unique_id'))
return m
Expand Down
16 changes: 16 additions & 0 deletions tests/test_dbt_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,3 +265,19 @@ def test_load_dbt_resources(self, get_dbt_manifest):
resources = dbtutil.load_dbt_resources(target_path)
self.assertIn('models', resources)
self.assertIn('metrics', resources)

def test_get_support_time_grain(self):
    """Check which supported grains are returned for each base grain."""
    # (base grain, grains expected back) -- checked in this order.
    expectations = (
        ('day', ['day', 'month', 'year']),
        ('week', ['month', 'year']),
        ('month', ['month', 'year']),
        ('quarter', ['year']),
        ('year', ['year']),
    )

    for base_grain, expected in expectations:
        actual = dbtutil.get_support_time_grains(base_grain)
        self.assertListEqual(actual, expected)

0 comments on commit fa0709b

Please sign in to comment.