Skip to content

Commit

Permalink
[Feature] sc-32422 Add option to select dbt profile and target
Browse files Browse the repository at this point in the history
Signed-off-by: Kent Huang <[email protected]>
  • Loading branch information
kentwelcome committed Oct 17, 2023
1 parent a1a9823 commit ca149ac
Show file tree
Hide file tree
Showing 7 changed files with 93 additions and 40 deletions.
21 changes: 17 additions & 4 deletions piperider_cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,15 @@ def dbt_select_option_builder():
dbt_related_options = [
click.option('--dbt-project-dir', type=click.Path(exists=True),
help='The path to the dbt project directory.'),
click.option('--dbt-target', type=click.STRING, default=None,
help='Specify which dbt target to load for the given dbt profile.'),
click.option('--dbt-profiles-dir', type=click.Path(exists=True), default=None,
help='Directory to search for dbt profiles.yml.'),
click.option('--dbt-profile', type=click.STRING, default=None,
help='Specify which dbt profile to load. Overrides setting in dbt_project.yml.'),
click.option('--no-auto-search', type=click.BOOL, default=False, is_flag=True,
help='Disable auto detection of dbt projects.'),

]

feature_flags = [
Expand Down Expand Up @@ -165,13 +170,18 @@ def init(**kwargs):
no_auto_search = kwargs.get('no_auto_search')
dbt_project_path = DbtUtil.get_dbt_project_path(dbt_project_dir, no_auto_search)
dbt_profiles_dir = kwargs.get('dbt_profiles_dir')
dbt_profile = kwargs.get('dbt_profile')
dbt_target = kwargs.get('dbt_target')
if dbt_project_path:
FileSystem.set_working_directory(dbt_project_path)

# TODO show the process and message to users
console.print(f'Initialize piperider to path {FileSystem.PIPERIDER_WORKSPACE_PATH}')

config = Initializer.exec(dbt_project_path=dbt_project_path, dbt_profiles_dir=dbt_profiles_dir)
config = Initializer.exec(dbt_project_path=dbt_project_path, dbt_profiles_dir=dbt_profiles_dir,
dbt_profile=dbt_profile,
dbt_target=dbt_target,
reload=False)
if kwargs.get('debug'):
for ds in config.dataSources:
console.rule('Configuration')
Expand All @@ -180,7 +190,7 @@ def init(**kwargs):
sys.exit(1)

# Show the content of config.yml
Initializer.show_config()
Initializer.show_config_file()


@cli.command(short_help='Check the configuraion and connection.', cls=TrackCommand)
Expand All @@ -196,10 +206,13 @@ def diagnose(**kwargs):
no_auto_search = kwargs.get('no_auto_search')
dbt_project_path = DbtUtil.get_dbt_project_path(dbt_project_dir, no_auto_search)
dbt_profiles_dir = kwargs.get('dbt_profiles_dir')
dbt_profile = kwargs.get('dbt_profile')
dbt_target = kwargs.get('dbt_target')
if dbt_project_path:
FileSystem.set_working_directory(dbt_project_path)
# Only run initializer when dbt project path is provided
Initializer.exec(dbt_project_path=dbt_project_path, dbt_profiles_dir=dbt_profiles_dir, interactive=False)
Initializer.exec(dbt_project_path=dbt_project_path, dbt_profiles_dir=dbt_profiles_dir, interactive=False,
dbt_profile=dbt_profile, dbt_target=dbt_target)
elif is_piperider_workspace_exist() is False:
raise DbtProjectNotFoundError()

Expand All @@ -208,7 +221,7 @@ def diagnose(**kwargs):
console.print(f'[bold dark_orange]PipeRider Version:[/bold dark_orange] {__version__}')

from piperider_cli.validator import Validator
if not Validator.diagnose():
if not Validator.diagnose(dbt_profile=dbt_profile, dbt_target=dbt_target):
sys.exit(1)


Expand Down
7 changes: 6 additions & 1 deletion piperider_cli/cli_utils/run_cmd.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,13 +57,16 @@ def run(**kwargs):
no_auto_search = kwargs.get('no_auto_search')
dbt_project_path = DbtUtil.get_dbt_project_path(dbt_project_dir, no_auto_search, recursive=False)
dbt_profiles_dir = kwargs.get('dbt_profiles_dir')
dbt_target = kwargs.get('dbt_target')
dbt_profile = kwargs.get('dbt_profile')
if dbt_project_path:
working_dir = os.path.dirname(dbt_project_path) if dbt_project_path.endswith('.yml') else dbt_project_path
FileSystem.set_working_directory(working_dir)
if dbt_profiles_dir:
FileSystem.set_dbt_profiles_dir(dbt_profiles_dir)
# Only run initializer when dbt project path is provided
Initializer.exec(dbt_project_path=dbt_project_path, dbt_profiles_dir=dbt_profiles_dir, interactive=False)
Initializer.exec(dbt_project_path=dbt_project_path, dbt_profiles_dir=dbt_profiles_dir, interactive=False,
dbt_profile=dbt_profile, dbt_target=dbt_target)
elif is_piperider_workspace_exist() is False:
raise DbtProjectNotFoundError()

Expand All @@ -85,6 +88,8 @@ def run(**kwargs):
skip_report=skip_report,
dbt_target_path=dbt_target_path,
dbt_resources=dbt_resources,
dbt_profile=dbt_profile,
dbt_target=dbt_target,
dbt_select=select,
dbt_state=state,
report_dir=kwargs.get('report_dir'),
Expand Down
1 change: 1 addition & 0 deletions piperider_cli/cloud/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(self):
except BaseException:
self.available = False
self.me = None
self.config: dict = {}

def update_config(self, options: dict):
self.service.update_config(options)
Expand Down
52 changes: 33 additions & 19 deletions piperider_cli/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,10 +377,12 @@ def update_config(key: str, update_values: Union[dict, str]):
_yml.dump(config, f)

@classmethod
def from_dbt_project(cls, dbt_project_path, dbt_profiles_dir=None):
def from_dbt_project(cls, dbt_project_path, dbt_profiles_dir=None, dbt_profile: str = None, dbt_target: str = None):
"""
build configuration from the existing dbt project
:param dbt_target:
:param dbt_profile:
:param dbt_project_path:
:param dbt_profiles_dir:
:return:
Expand All @@ -404,23 +406,23 @@ def from_dbt_project(cls, dbt_project_path, dbt_profiles_dir=None):
if not os.path.exists(os.path.expanduser(dbt_profile_path)):
raise DbtProfileNotFoundError(dbt_profile_path)

dbt_profile = DbtUtil.load_dbt_profile(os.path.expanduser(dbt_profile_path))
profile = DbtUtil.load_dbt_profile(os.path.expanduser(dbt_profile_path))

console = Console()
profile_name = dbt_project.get('profile', '')
profile_content = dbt_profile.get(profile_name, None)
profile_name = dbt_project.get('profile', '') if dbt_profile is None else dbt_profile
profile_content = profile.get(profile_name, None)

if profile_content is None:
console.print("[bold red]Error:[/bold red] "
f"Could not find profile named '{profile_name}'")
f"Could not find profile named \"{profile_name}\" in 'dbt_project.yml'.")
sys.exit(1)
target_name = profile_content.get('target', 'default')
if target_name not in list(dbt_profile.get(profile_name, {}).get('outputs', {}).keys()):
target_name = profile_content.get('target', 'default') if dbt_target is None else dbt_target
if target_name not in list(profile.get(profile_name, {}).get('outputs', {}).keys()):
console.print("[bold red]Error:[/bold red] "
f"The profile '{profile_name}' does not have a target named '{target_name}'.\n"
"Please check the dbt profile format.")
sys.exit(1)
credential = DbtUtil.load_credential_from_dbt_profile(dbt_profile, profile_name, target_name)
credential = DbtUtil.load_credential_from_dbt_profile(profile, profile_name, target_name)
type_name = credential.get('type')
dbt = {
'projectDir': os.path.relpath(os.path.dirname(dbt_project_path), FileSystem.WORKING_DIRECTORY),
Expand All @@ -429,6 +431,12 @@ def from_dbt_project(cls, dbt_project_path, dbt_profiles_dir=None):
if dbt_profiles_dir:
dbt['profilesDir'] = dbt_profiles_dir

if dbt_profile:
dbt['profile'] = dbt_profile

if dbt_target:
dbt['target'] = dbt_target

if type_name not in DATASOURCE_PROVIDERS:
console.print(f"[[bold yellow]WARNING[/bold yellow]] Unsupported data source type '{type_name}' is found")
# raise PipeRiderInvalidDataSourceError(type_name, dbt_profile_path)
Expand All @@ -445,15 +453,18 @@ def from_dbt_project(cls, dbt_project_path, dbt_profiles_dir=None):
return cls(dataSources=[datasource])

@classmethod
def instance(cls, piperider_config_path=None):
def instance(cls, piperider_config_path=None, dbt_profile: str = None, dbt_target: str = None, reload: bool = True):
piperider_working_directory = cls.search_piperider_project_path()
if piperider_working_directory:
FileSystem.set_working_directory(piperider_working_directory)
piperider_config_path = piperider_config_path or FileSystem.PIPERIDER_CONFIG_PATH
global configuration_instance
if configuration_instance is not None:
return configuration_instance
configuration_instance = cls._load(piperider_config_path)
if configuration_instance:
dbt = configuration_instance.dbt
if reload is False or dbt is None or (
dbt.get('profile') == dbt_profile and dbt.get('target') == dbt_target):
return configuration_instance
configuration_instance = cls._load(piperider_config_path, dbt_profile=dbt_profile, dbt_target=dbt_target)
return configuration_instance

@classmethod
Expand All @@ -470,7 +481,7 @@ def search_piperider_project_path(cls) -> str:
None)

@classmethod
def _load(cls, piperider_config_path=None):
def _load(cls, piperider_config_path=None, dbt_profile: str = None, dbt_target: str = None):
"""
load from the existing configuration
Expand Down Expand Up @@ -509,8 +520,8 @@ def _load(cls, piperider_config_path=None):
if '~' in profile_path:
profile_path = os.path.expanduser(profile_path)
profile = DbtUtil.load_dbt_profile(profile_path)
profile_name = dbt.get('profile')
target_name = dbt.get('target')
profile_name = dbt_profile if dbt_profile else dbt.get('profile')
target_name = dbt_target if dbt_target else dbt.get('target')
credential.update(DbtUtil.load_credential_from_dbt_profile(profile, profile_name, target_name))
# TODO: extract duplicate code from func 'from_dbt_project'
if credential.get('pass') and credential.get('password') is None:
Expand All @@ -533,7 +544,7 @@ def _load(cls, piperider_config_path=None):
if dbt:
project_dir = config.get('dbt').get('projectDir')
project = DbtUtil.load_dbt_project(project_dir)
profile_name = project.get('profile')
profile_name = dbt_profile if dbt_profile else dbt.get('profile', project.get('profile'))

# Precedence reference
# https://docs.getdbt.com/docs/get-started/connection-profiles#advanced-customizing-a-profile-directory
Expand All @@ -555,13 +566,16 @@ def _load(cls, piperider_config_path=None):
datasource_class = DATASOURCE_PROVIDERS[credential.get('type')]
data_source = datasource_class(
name=target,
dbt=dict(**dbt, profile=profile_name, target=target),
dbt=dbt.update(profile=profile_name, target=target),
credential=credential
)
data_sources.append(data_source)
# dbt behavior: dbt uses 'default' as target name if no target given in profiles.yml
dbt['target'] = profile.get(profile_name).get('target', 'default')
if dbt['target'] not in target_names:
dbt['target'] = dbt_target if dbt_target else dbt.get('target',
profile.get(profile_name).get('target',
'default'))

if dbt.get('target') not in target_names:
console = Console()
console.print("[bold red]Error:[/bold red] "
f"The profile '{profile_name}' does not have a target named '{dbt['target']}'.\n"
Expand Down
28 changes: 18 additions & 10 deletions piperider_cli/initializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,16 @@ def _ask_user_input_datasource(config: Configuration = None):
return config


def _inherit_datasource_from_dbt_project(dbt_project_path, dbt_profiles_dir=None, interactive=True):
def _inherit_datasource_from_dbt_project(dbt_project_path, dbt_profiles_dir=None, dbt_profile: str = None,
dbt_target: str = None, interactive=True):
config = safe_load_yaml(FileSystem.PIPERIDER_CONFIG_PATH)
if config and config.get('dataSources'):
if interactive is True:
console.print('[[bold yellow]Warning[/bold yellow]] Found existing configuration. Skip initialization.')
return config

dbt_config = Configuration.from_dbt_project(dbt_project_path, dbt_profiles_dir)
dbt_config = Configuration.from_dbt_project(dbt_project_path, dbt_profiles_dir, dbt_profile=dbt_profile,
dbt_target=dbt_target)
_generate_piperider_workspace()
dbt_config.dump(FileSystem.PIPERIDER_CONFIG_PATH)

Expand All @@ -88,7 +90,8 @@ def _inherit_datasource_from_dbt_project(dbt_project_path, dbt_profiles_dir=None
return dbt_config


def _generate_configuration(dbt_project_path=None, dbt_profiles_dir=None, interactive=True):
def _generate_configuration(dbt_project_path=None, dbt_profiles_dir=None, dbt_profile: str = None,
dbt_target: str = None, interactive=True, ):
"""
:param dbt_project_path:
:return: Configuration object
Expand All @@ -105,35 +108,40 @@ def _generate_configuration(dbt_project_path=None, dbt_profiles_dir=None, intera
# TODO: mark as deprecated in the future
console.rule('Deprecated', style='bold red')
console.print(
'Non-dbt project is deprecated and will be removed in the future. If you have a strong need for non-dbt project, please contact us by "piperider feedback".\n')
'Non-dbt project is deprecated and will be removed in the future. If you have a strong need for non-dbt '
'project, please contact us by "piperider feedback".\n')
return _ask_user_input_datasource(config=config)

if config is not None:
if interactive:
console.print('[[bold yellow]Warning[/bold yellow]] Found existing configuration. Skip initialization.')
return config

return _inherit_datasource_from_dbt_project(dbt_project_path, dbt_profiles_dir, interactive)
return _inherit_datasource_from_dbt_project(dbt_project_path, dbt_profiles_dir, interactive=interactive,
dbt_profile=dbt_profile,
dbt_target=dbt_target)


class Initializer():
class Initializer:
@staticmethod
def exec(working_dir=None, dbt_project_path=None, dbt_profiles_dir=None, interactive=True):
def exec(working_dir=None, dbt_project_path=None, dbt_profiles_dir=None, dbt_profile: str = None,
dbt_target: str = None, interactive=True, reload: bool = True):
if working_dir is None:
working_dir = FileSystem.PIPERIDER_WORKSPACE_PATH

if _is_piperider_workspace_exist(working_dir) and interactive is True:
console.print('[bold green]Piperider workspace already exist[/bold green] ')

# get Configuration object from dbt or user created configuration
_generate_configuration(dbt_project_path, dbt_profiles_dir, interactive)
configuration = Configuration.instance()
_generate_configuration(dbt_project_path, dbt_profiles_dir, interactive=interactive, dbt_profile=dbt_profile,
dbt_target=dbt_target)
configuration = Configuration.instance(dbt_profile=dbt_profile, dbt_target=dbt_target, reload=reload)
configuration.activate_report_directory()

return configuration

@staticmethod
def show_config():
def show_config_file():

# show config.yml
with open(FileSystem.PIPERIDER_CONFIG_PATH, 'r') as f:
Expand Down
11 changes: 9 additions & 2 deletions piperider_cli/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -621,13 +621,14 @@ def _exec(command_line: str):
class Runner:
@staticmethod
def exec(datasource=None, table=None, output=None, skip_report=False, dbt_target_path: str = None,
dbt_resources: Optional[dict] = None, dbt_select: tuple = None, dbt_state: str = None,
dbt_resources: Optional[dict] = None, dbt_profile: str = None, dbt_target: str = None,
dbt_select: tuple = None, dbt_state: str = None,
report_dir: str = None, skip_datasource_connection: bool = False, event_payload=RunEventPayload()):
console = Console()

raise_exception_when_directory_not_writable(output)

configuration = Configuration.instance()
configuration = Configuration.instance(dbt_profile=dbt_profile, dbt_target=dbt_target)
filesystem = configuration.activate_report_directory(report_dir=report_dir)
ds = configuration.get_datasource(datasource)
if ds is None:
Expand All @@ -643,6 +644,12 @@ def exec(datasource=None, table=None, output=None, skip_report=False, dbt_target
if skip_datasource_connection:
event_payload.skip_datasource = True

console.rule('DBT')
console.print('Profile: ', style='bold', end='')
console.print(f'{configuration.dbt.get("profile")}', style='cyan')
console.print('Target: ', style='bold', end='')
console.print(f'{configuration.dbt.get("target")}', style='cyan')

# Validating
console.rule('Validating')
event_payload.step = 'validate'
Expand Down
13 changes: 9 additions & 4 deletions piperider_cli/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,18 +25,23 @@ def check_function(self, configurator: Configuration) -> (bool, str):


class CheckingHandler(object):
def __init__(self):
def __init__(self, dbt_profile=None, dbt_target=None):
self.configurator = None
self.checker_chain = []
self.console = Console()
self.dbt = {
'profile': dbt_profile,
'target': dbt_target
}

def set_checker(self, name: str, checker: AbstractChecker):
self.checker_chain.append({'name': name, 'cls': checker()})

def execute(self):
if not self.configurator:
try:
self.configurator = Configuration.instance()
self.configurator = Configuration.instance(dbt_profile=self.dbt.get('profile'),
dbt_target=self.dbt.get('target'))
self.configurator.activate_report_directory()
except Exception:
pass
Expand Down Expand Up @@ -177,8 +182,8 @@ def check_function(self, configurator: Configuration) -> (bool, str):

class Validator():
@staticmethod
def diagnose():
handler = CheckingHandler()
def diagnose(dbt_profile: str = None, dbt_target: str = None):
handler = CheckingHandler(dbt_profile=dbt_profile, dbt_target=dbt_target)
handler.set_checker('config files', CheckConfiguration)
handler.set_checker('format of data sources', CheckDataSources)
handler.set_checker('connections', CheckConnections)
Expand Down

0 comments on commit ca149ac

Please sign in to comment.