diff --git a/opendbt/dbt/__init__.py b/opendbt/dbt/__init__.py index 3dc9ca5..bd566a9 100644 --- a/opendbt/dbt/__init__.py +++ b/opendbt/dbt/__init__.py @@ -22,3 +22,7 @@ def patch_dbt(): f"Unsupported dbt version {dbt_version}, please make sure dbt version is supported/integrated by opendbt") # shared code patches + import opendbt.dbt.shared.cli.main + dbt.cli.main.sqlfluff = opendbt.dbt.shared.cli.main.sqlfluff + dbt.cli.main.sqlfluff_lint = opendbt.dbt.shared.cli.main.sqlfluff_lint + dbt.cli.main.sqlfluff_fix = opendbt.dbt.shared.cli.main.sqlfluff_fix diff --git a/opendbt/dbt/shared/cli/__init__.py b/opendbt/dbt/shared/cli/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/opendbt/dbt/shared/cli/main.py b/opendbt/dbt/shared/cli/main.py new file mode 100644 index 0000000..a5c1bef --- /dev/null +++ b/opendbt/dbt/shared/cli/main.py @@ -0,0 +1,103 @@ +import click +from dbt.cli import requires, params as p +from dbt.cli.main import global_flags, cli + +from opendbt.dbt.shared.task.sqlfluff import SqlFluffTasks + + +# dbt docs +@cli.group() +@click.pass_context +@global_flags +def sqlfluff(ctx, **kwargs): + """Generate or serve the documentation website for your project""" + + +# dbt docs generate +@sqlfluff.command("lint") +@click.pass_context +@global_flags +@p.defer +@p.deprecated_defer +@p.exclude +@p.favor_state +@p.deprecated_favor_state +@p.full_refresh +@p.indirect_selection +@p.profile +@p.profiles_dir +@p.project_dir +@p.resource_type +@p.select +@p.selector +@p.show +@p.state +@p.defer_state +@p.deprecated_state +@p.store_failures +@p.target +@p.target_path +@p.threads +@p.vars +@requires.postflight +@requires.preflight +@requires.profile +@requires.project +@requires.runtime_config +@requires.manifest(write=False) +def sqlfluff_lint(ctx, **kwargs): + """Generate the documentation website for your project""" + task = SqlFluffTasks( + ctx.obj["flags"], + ctx.obj["runtime_config"], + ctx.obj["manifest"], + ) + + results = task.lint() + success = task.interpret_results(results) + return results, success + + +# dbt docs generate +@sqlfluff.command("fix") +@click.pass_context +@global_flags +@p.defer +@p.deprecated_defer +@p.exclude +@p.favor_state +@p.deprecated_favor_state +@p.full_refresh +@p.indirect_selection +@p.profile +@p.profiles_dir +@p.project_dir +@p.resource_type +@p.select +@p.selector +@p.show +@p.state +@p.defer_state +@p.deprecated_state +@p.store_failures +@p.target +@p.target_path +@p.threads +@p.vars +@requires.postflight +@requires.preflight +@requires.profile +@requires.project +@requires.runtime_config +@requires.manifest(write=False) +def sqlfluff_fix(ctx, **kwargs): + """Generate the documentation website for your project""" + task = SqlFluffTasks( + ctx.obj["flags"], + ctx.obj["runtime_config"], + ctx.obj["manifest"], + ) + + results = task.fix() + success = task.interpret_results(results) + return results, success diff --git a/opendbt/dbt/shared/task/__init__.py b/opendbt/dbt/shared/task/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/opendbt/dbt/shared/task/sqlfluff.py b/opendbt/dbt/shared/task/sqlfluff.py new file mode 100644 index 0000000..6aa9598 --- /dev/null +++ b/opendbt/dbt/shared/task/sqlfluff.py @@ -0,0 +1,87 @@ +import os +from datetime import datetime +from pathlib import Path +from typing import Optional + +from dbt.config import RuntimeConfig +from dbt.contracts.results import ( + CatalogResults, + CatalogArtifact, RunExecutionResult, +) +from dbt.task.compile import CompileTask +from sqlfluff.cli import commands +from sqlfluff.core import Linter, FluffConfig +from sqlfluff.core.linter import LintingResult +from sqlfluff_templater_dbt import DbtTemplater + + +class SqlFluffTasks(CompileTask): + + def __init__(self, args, config, manifest): + super().__init__(args, config, manifest) + + self.sqlfluff_config = FluffConfig.from_path(path=self.config.project_root) + + templater_obj = self.sqlfluff_config._configs["core"]["templater_obj"] + if isinstance(templater_obj, DbtTemplater): + templater_obj: DbtTemplater + self.config: RuntimeConfig + templater_obj.project_root = self.config.project_root + templater_obj.working_dir = self.config.project_root + self.linter = Linter(self.sqlfluff_config) + + def get_result(self, elapsed_time: float, violations: list, num_violations: int): + run_result = RunExecutionResult( + results=[], + elapsed_time=elapsed_time, + generated_at=datetime.now(), + # args=dbt.utils.args_to_dict(self.args), + args={}, + ) + result = CatalogArtifact.from_results( + nodes={}, + sources={}, + generated_at=datetime.now(), + errors=violations if violations else None, + compile_results=run_result, + ) + if num_violations > 0: + setattr(result, 'exception', Exception(f"Linting {num_violations} errors found!")) + result.exception = Exception(f"Linting {num_violations} errors found!") + + return result + + def lint(self) -> CatalogArtifact: + os.chdir(self.config.project_root) + lint_result: LintingResult = self.linter.lint_paths(paths=(self.config.project_root,)) + result = self.get_result(lint_result.total_time, lint_result.get_violations(), lint_result.num_violations()) + if lint_result.num_violations() > 0: + print(f"Linting {lint_result.num_violations()} errors found!") + for error in lint_result.as_records(): + filepath = Path(error['filepath']) + violations: list = error['violations'] + if violations: + print(f"File: {filepath.relative_to(self.config.project_root)}") + for violation in violations: + print(f" {violation}") + # print(f"Code:{violation['code']} Line:{violation['start_line_no']}, LinePos:{violation['start_line_pos']} {violation['description']}") + return result + + def fix(self) -> CatalogArtifact: + os.chdir(self.config.project_root) + lnt, formatter = commands.get_linter_and_formatter(cfg=self.sqlfluff_config) + lint_result: LintingResult = lnt.lint_paths( + paths=(self.config.project_root,), + fix=True, + apply_fixes=True + ) + result = self.get_result(lint_result.total_time, [], 0) + return result + + @classmethod + def interpret_results(self, results: Optional[CatalogResults]) -> bool: + if results is None: + return False + if hasattr(results, "errors") and results.errors: + return False + return True diff --git a/setup.py b/setup.py index 21f1896..3b894e5 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ include_package_data=True, license="Apache License 2.0", test_suite='tests', - install_requires=["dbt-duckdb>=1.6"], + install_requires=["dbt-duckdb>=1.6", "sqlfluff", "sqlfluff-templater-dbt"], extras_require={ "airflow": ["apache-airflow"], "test": ["testcontainers>=3.7,<4.9"], diff --git a/tests/resources/dbttest/.sqlfluff b/tests/resources/dbttest/.sqlfluff new file mode 100644 index 0000000..b6b3e3d --- /dev/null +++ b/tests/resources/dbttest/.sqlfluff @@ -0,0 +1,53 @@ +[sqlfluff] +templater = dbt +dialect = duckdb +# This change (from jinja to dbt templater) will make linting slower +# because linting will first compile dbt code into data warehouse code. +runaway_limit = 1000 +max_line_length = 180 +indent_unit = space + +[sqlfluff:indentation] +tab_space_size = 4 + +[sqlfluff:layout:type:comma] +spacing_before = touch +line_position = trailing + +# For rule specific configuration, use dots between the names exactly +# as you would in .sqlfluff. In the background, SQLFluff will unpack the +# configuration paths accordingly. +[tool.sqlfluff.rules.capitalisation.keywords] +capitalisation_policy = "upper" + +# The default configuration for capitalisation rules is "consistent" +# which will auto-detect the setting from the rest of the file. This +# is less desirable in a new project and you may find this (slightly +# more strict) setting more useful. +# Typically we find users rely on syntax highlighting rather than +# capitalisation to distinguish between keywords and identifiers. +# Clearly, if your organisation has already settled on uppercase +# formatting for any of these syntax elements then set them to "upper". +# See https://stackoverflow.com/questions/608196/why-should-i-capitalize-my-sql-keywords-is-there-a-good-reason +[sqlfluff:rules:capitalisation.keywords] +capitalisation_policy = upper +[sqlfluff:rules:capitalisation.identifiers] +capitalisation_policy = upper +[sqlfluff:rules:capitalisation.functions] +extended_capitalisation_policy = upper +# [sqlfluff:rules:capitalisation.literals] +# capitalisation_policy = lower +[sqlfluff:rules:capitalisation.types] +extended_capitalisation_policy = upper + +[sqlfluff:rules:aliasing.table] +aliasing = explicit + +[sqlfluff:rules:aliasing.column] +aliasing = explicit + +[sqlfluff:rules:aliasing.expression] +allow_scalar = False + +[sqlfluff:rules:ambiguous.column_references] # Number in group by +group_by_and_order_by_style = implicit \ No newline at end of file diff --git a/tests/test_dbt_sqlfluff.py b/tests/test_dbt_sqlfluff.py new file mode 100644 index 0000000..ebea3ee --- /dev/null +++ b/tests/test_dbt_sqlfluff.py @@ -0,0 +1,18 @@ +from pathlib import Path +from unittest import TestCase + +from opendbt import OpenDbtProject + + +class TestDbtSqlFluff(TestCase): + RESOURCES_DIR = Path(__file__).parent.joinpath("resources") + DBTTEST_DIR = RESOURCES_DIR.joinpath("dbttest") + + def test_run_sqlfluff_lint(self): + dp = OpenDbtProject(project_dir=self.DBTTEST_DIR, profiles_dir=self.DBTTEST_DIR) + dp.run(command="sqlfluff", args=['fix']) + dp.run(command="sqlfluff", args=['lint']) + + def test_run_sqlfluff_fix(self): + dp = OpenDbtProject(project_dir=self.DBTTEST_DIR, profiles_dir=self.DBTTEST_DIR) + dp.run(command="sqlfluff", args=['fix'])