
Make run.py configurable with config file
AlexejPenner committed Oct 31, 2024
1 parent 63575b7 commit a7b30fc
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions llm-complete-guide/run.py
@@ -147,6 +147,12 @@
default=False,
help="Generate chunks for Hugging Face dataset",
)
@click.option(
"--config",
"config",
default=None,
help="Generate chunks for Hugging Face dataset",
)
def main(
rag: bool = False,
deploy: bool = False,
@@ -159,6 +165,7 @@ def main(
argilla: bool = False,
reranked: bool = False,
chunks: bool = False,
config: str = None,
):
"""Main entry point for the pipeline execution.
@@ -170,11 +177,11 @@
model (str): The model to use for the completion. Default is OPENAI_MODEL.
no_cache (bool): If `True`, cache will be disabled.
synthetic (bool): If `True`, the synthetic data pipeline will be run.
local (bool): If `True`, the local LLM via Ollama will be used.
embeddings (bool): If `True`, the embeddings will be fine-tuned.
argilla (bool): If `True`, the Argilla annotations will be used.
chunks (bool): If `True`, the chunks pipeline will be run.
reranked (bool): If `True`, rerankers will be used.
config (str): Path to the config file to use.
"""
pipeline_args = {"enable_cache": not no_cache}
embeddings_finetune_args = {
@@ -196,13 +203,21 @@
md = Markdown(response)
console.print(md)

print(f"Running Pipeline with pipeline args: {pipeline_args}")
config_path = None
if config:
config_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"configs",
config,
)

if rag:
if not config_path:
config_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"configs",
"rag_local_dev.yaml",
)
llm_basic_rag.with_options(config_path=config_path, **pipeline_args)()
@@ -211,28 +226,31 @@
if deploy:
rag_deployment.with_options(**pipeline_args)()
if evaluation:
if not config_path:
config_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"configs",
"rag_eval.yaml",
)
pipeline_args["enable_cache"] = False
llm_eval.with_options(config_path=config_path)()
if synthetic:
if not config_path:
config_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"configs",
"synthetic.yaml",
)
generate_synthetic_data.with_options(
config_path=config_path, **pipeline_args
)()
if embeddings:
if not config_path:
config_path = os.path.join(
os.path.dirname(os.path.realpath(__file__)),
"configs",
"embeddings.yaml",
)
finetune_embeddings.with_options(
config_path=config_path, **embeddings_finetune_args
)()
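The change above repeats one pattern for each pipeline: use the value passed via --config when it is set, otherwise fall back to a pipeline-specific YAML file under configs/. The sketch below shows that pattern in isolation. It is an illustration only, not code from this commit; the helper name resolve_config_path and the example filenames are assumptions made for the sketch.

import os
from typing import Optional


def resolve_config_path(config: Optional[str], default_filename: str) -> str:
    """Return the absolute path of the pipeline config to load.

    Prefer the filename passed on the command line via --config;
    otherwise fall back to the pipeline-specific default. Both are
    resolved against the local `configs/` directory next to run.py.
    """
    filename = config if config else default_filename
    return os.path.join(
        os.path.dirname(os.path.realpath(__file__)),
        "configs",
        filename,
    )


# Mirrors the --rag branch: no --config given, so the default is used.
rag_config_path = resolve_config_path(None, "rag_local_dev.yaml")

With this resolution in place, a run such as python run.py --rag --config my_config.yaml would load configs/my_config.yaml instead of the rag_local_dev.yaml default (my_config.yaml is a hypothetical filename; as in the diff, the --config value is treated as a filename inside the configs directory).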
