Skip to content

Commit

Permalink
Merge pull request #875 from InfuseAI/feature/sc-32064/support-the-ba…
Browse files Browse the repository at this point in the history
…sic-dbt-1-6-metric-restriction

[Feature] Add restriction to dbt 1.6 metric support
  • Loading branch information
popcornylu authored Sep 5, 2023
2 parents 6fa0e6b + a6b1e35 commit 3fb8ee4
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 22 deletions.
85 changes: 65 additions & 20 deletions piperider_cli/dbtutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Dict, Optional, Union

import inquirer
from jinja2 import UndefinedError
from rich.console import Console
from rich.table import Table
from ruamel import yaml
Expand Down Expand Up @@ -353,9 +354,16 @@ def is_chosen(key, metric):

def is_dbt_schema_version_16(manifest: Dict):
    """Tell whether a dbt manifest uses manifest schema v10 or newer.

    The schema number is read from ``manifest['metadata']['dbt_schema_version']``,
    a URL such as ``https://schemas.getdbt.com/dbt/manifest/v10.json``.
    Schema v10 is the threshold this function treats as "dbt 1.6"
    (per the function name).

    :param manifest: parsed dbt manifest; must contain a ``metadata`` mapping.
    :return: ``True`` when the schema number is >= 10; ``False`` when it is
        lower, or when the version URL is missing or not in the expected form.
    :raises KeyError: if the manifest has no ``metadata`` entry.
    """
    # Function-scope import: keeps this block self-contained, matching the
    # local import the nested helper used previously.
    import re

    # dbt_schema_version: 'https://schemas.getdbt.com/dbt/manifest/v10.json'
    schema_version_url = manifest['metadata'].get('dbt_schema_version')
    if not isinstance(schema_version_url, str):
        # A missing (None) or non-string value cannot be parsed; treat it as
        # "not 1.6" instead of letting re.match raise TypeError.
        return False

    match = re.match(r"(.*?)\/v(\d+)\.json", schema_version_url)
    if match is None:
        # Unrecognised URL shape: comparing None >= 10 would raise TypeError.
        return False

    return int(match.group(2)) >= 10


def load_metric_jinja_string_template(value: str):
Expand All @@ -379,6 +387,23 @@ def get_support_time_grains(grain: str):
return [x for x in support_time_grains if x in available_time_grains]


def get_metric_filter(metric_name, raw_filter):
    """Render a metric filter template and break it into its three parts.

    The ``where_sql_template`` entry of ``raw_filter`` is loaded through
    ``load_metric_jinja_string_template`` and rendered. The rendered text is
    split on single spaces and the first three tokens are returned as
    ``field``, ``operator`` and ``value``.

    :param metric_name: metric name used in the "Skip" console message.
    :param raw_filter: mapping that provides ``where_sql_template``.
    :return: dict with keys ``field``, ``operator`` and ``value``, or ``None``
        when rendering raises ``UndefinedError`` (a skip notice is printed
        to the console in that case).
    """
    try:
        template = load_metric_jinja_string_template(raw_filter.get('where_sql_template'))
        sql_filter = template.render()
    except UndefinedError as e:
        # The first word of the error message is reported as the name of the
        # unsupported Jinja function.
        func_name = e.message.split(' ')[0]
        console.print(
            f"[[bold yellow]Skip[/bold yellow]] Metric '{metric_name}'. "
            f"Jinja function {func_name} of filter is not supported.")
        return None
    else:
        # Split once and pick the tokens by position; a rendered filter with
        # fewer than three tokens still raises IndexError.
        tokens = sql_filter.split(' ')
        field, operator, value = tokens[0], tokens[1], tokens[2]
        return {
            'field': field,
            'operator': operator,
            'value': value,
        }


def find_derived_time_grains(manifest: Dict, metric: Dict):
nodes = metric.get('depends_on').get('nodes', [])
depends_on = nodes[0]
Expand Down Expand Up @@ -432,23 +457,29 @@ def is_chosen(key, metric):
for key, metric in manifest.get('metrics').items():
metric_map[metric.get('name')] = metric

def _create_metric(name, filter=None, alias=None):
def _create_metric(name, filter=None, alias=None, root_name=None):
root_name = name if root_name is None else root_name
statistics = Statistics()
metric = metric_map.get(name)

if metric.get('type') == 'simple':
primary_entity = None
metric_filter = []
if metric.get('filter') is not None:
sql_filter = load_metric_jinja_string_template(metric.get('filter').get('where_sql_template')).render()
metric_filter.append({'field': sql_filter.split(' ')[0],
'operator': sql_filter.split(' ')[1],
'value': sql_filter.split(' ')[2]})
f = get_metric_filter(root_name, metric.get('filter'))
if f is not None:
metric_filter.append(f)
else:
statistics.add_field_one('nosupport')
return None

if filter is not None:
sql_filter = load_metric_jinja_string_template(filter.get('where_sql_template')).render()
metric_filter.append({'field': sql_filter.split(' ')[0],
'operator': sql_filter.split(' ')[1],
'value': sql_filter.split(' ')[2]})
f = get_metric_filter(root_name, filter)
if f is not None:
metric_filter.append(f)
else:
statistics.add_field_one('nosupport')
return None

nodes = metric.get('depends_on').get('nodes', [])
depends_on = nodes[0]
Expand Down Expand Up @@ -509,14 +540,15 @@ def _create_metric(name, filter=None, alias=None):
f['field'] = f['field'].replace(f'{primary_entity}__', '')
else:
console.print(
f"[[bold yellow]Skip[/bold yellow]] Metric '{metric.get('name')}'. "
f"[[bold yellow]Skip[/bold yellow]] "
f"Metric '{root_name if root_name else metric.get('name')}'. "
f"Dimension of foreign entities is not supported.")
statistics.add_field_one('nosupport')
return None
if m.calculation_method == 'median':
if m.calculation_method in ['sum_boolean', 'median', 'percentile']:
console.print(
f"[[bold yellow]Skip[/bold yellow]] Metric '{metric.get('name')}'. "
f"Aggregation type 'median' is not supported.")
f"Aggregation type '{m.calculation_method}' is not supported.")
statistics.add_field_one('nosupport')
return None

Expand All @@ -535,7 +567,11 @@ def _create_metric(name, filter=None, alias=None):
m2 = _create_metric(
ref_metric.get('name'),
filter=ref_metric.get('filter'),
alias=ref_metric.get('alias'))
alias=ref_metric.get('alias'),
root_name=root_name
)
if m2 is None:
return None
ref_metrics.append(m2)

derived_time_grains = find_derived_time_grains(manifest, metric_map[ref_metric.get('name')])
Expand All @@ -559,7 +595,11 @@ def _create_metric(name, filter=None, alias=None):
m2 = _create_metric(
numerator.get('name'),
filter=numerator.get('filter'),
alias=numerator.get('alias'))
alias=numerator.get('alias'),
root_name=root_name
)
if m2 is None:
return None
ref_metrics.append(m2)
derived_time_grains = find_derived_time_grains(manifest, metric_map[numerator.get('name')])
if len(time_grains) < len(derived_time_grains):
Expand All @@ -569,15 +609,20 @@ def _create_metric(name, filter=None, alias=None):
m2 = _create_metric(
denominator.get('name'),
filter=denominator.get('filter'),
alias=denominator.get('alias'))
alias=denominator.get('alias'),
root_name=root_name
)
if m2 is None:
return None
ref_metrics.append(m2)
derived_time_grains = find_derived_time_grains(manifest, metric_map[denominator.get('name')])
if len(time_grains) < len(derived_time_grains):
time_grains = derived_time_grains

m = Metric(metric.get('name'),
calculation_method='derived',
expression=f"{numerator.get('name')} / {denominator.get('name')}",
calculation_method='ratio',
numerator=numerator.get('name'),
denominator=denominator.get('name'),
time_grains=time_grains,
label=metric.get('label'), description=metric.get('description'), ref_metrics=ref_metrics,
ref_id=metric.get('unique_id'))
Expand Down
11 changes: 9 additions & 2 deletions piperider_cli/metrics_engine/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ def __init__(
calculation_method=None,
time_grains=None,
expression: str = None,
numerator: str = None,
denominator: str = None,
label=None,
description=None,
ref_metrics=None,
Expand All @@ -63,6 +65,8 @@ def __init__(
self.calculation_method = calculation_method
self.time_grains = time_grains
self.expression = expression
self.numerator = numerator
self.denominator = denominator
self.label = label
self.description = description
self.ref_metrics: List[Metric] = ref_metrics or []
Expand Down Expand Up @@ -111,7 +115,7 @@ def _compose_query_name(grain: str, dimensions: List[str], label=False) -> str:
def _get_query_stmt(self, metric: Metric, grain: str, dimension: List[str], date_spine_model: CTE):
metric_column_name = metric.name

if metric.calculation_method == 'derived':
if metric.calculation_method == 'derived' or metric.calculation_method == 'ratio':
selectable = None

# Join all parent metrics
Expand All @@ -123,8 +127,11 @@ def _get_query_stmt(self, metric: Metric, grain: str, dimension: List[str], date
else:
selectable = join(selectable, cte, selectable.c.d == cte.c.d)

# a / b / c -> a / nullif(b, 0) / nullif(c, 0)
expression = metric.expression
if metric.calculation_method == 'ratio':
expression = f"{metric.numerator}/{metric.denominator}"

# a / b / c -> a / nullif(b, 0) / nullif(c, 0)
if '/' in expression:
expression_list = expression.split('/')
dividend = expression_list[0]
Expand Down

0 comments on commit 3fb8ee4

Please sign in to comment.