From aceceac67bde69697baf3fd192eb5cb4afbd8f17 Mon Sep 17 00:00:00 2001
From: Morgan Diverrez
Date: Wed, 4 Sep 2024 12:39:19 +0200
Subject: [PATCH] finalisation of rag_testing_tool.py

---
 .../rag_testing_tool.py | 78 ++++++++++---------
 1 file changed, 41 insertions(+), 37 deletions(-)

diff --git a/gen-ai/orchestrator-server/src/main/python/tock-llm-indexing-tools/rag_testing_tool.py b/gen-ai/orchestrator-server/src/main/python/tock-llm-indexing-tools/rag_testing_tool.py
index 206734f220..0bc8da1607 100644
--- a/gen-ai/orchestrator-server/src/main/python/tock-llm-indexing-tools/rag_testing_tool.py
+++ b/gen-ai/orchestrator-server/src/main/python/tock-llm-indexing-tools/rag_testing_tool.py
@@ -57,6 +57,10 @@ from gen_ai_orchestrator.services.langchain.rag_chain import create_rag_chain
 from langfuse import Langfuse
 from langsmith import Client
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_random_exponential,
+)
 
 
 def test_rag(args):
@@ -90,6 +94,41 @@ def _construct_chain():
             'chat_history': lambda x: x['chat_history'] if 'chat_history' in x else [],
         } | create_rag_chain(RagQuery(**rag_query))
 
+    @retry(wait=wait_random_exponential(min=10, max=60), stop=stop_after_attempt(5))
+    def run_dataset(run_name_dataset):
+        if args[''].lower() == 'langsmith':
+            client = Client()
+            client.run_on_dataset(
+                dataset_name=args[''],
+                llm_or_chain_factory=_construct_chain,
+                project_name=run_name_dataset,
+                project_metadata={
+                    'index_session_id': index_session_id,
+                    'k': k,
+                },
+                concurrency_level=concurrency_level,
+            )
+        elif args[''].lower() == 'langfuse':
+            client = Langfuse()
+            dataset = client.get_dataset(args[''])
+            for item in dataset.items:
+                callback_handlers = []
+                handler = item.get_langchain_handler(
+                    run_name=run_name_dataset,
+                    run_metadata={
+                        'index_session_id': index_session_id,
+                        'k': k,
+                    },
+                )
+                callback_handlers.append(handler)
+                _construct_chain().invoke(
+                    item.input, config={'callbacks': callback_handlers}
+                )
+            client.flush()
+
     search_params = rag_query['document_search_params']
     index_session_id = search_params['filter'][0]['term'][
         'metadata.index_session_id.keyword'
     ]
@@ -102,36 +141,8 @@ def _construct_chain():
     # one at a time
     if args['']:
         concurrency_level = 1
-    if args[''].lower() == 'langsmith':
-        client = Client()
-        client.run_on_dataset(
-            dataset_name=args[''],
-            llm_or_chain_factory=_construct_chain,
-            project_name=args[''] + '-' + str(uuid4())[:8],
-            project_metadata={
-                'index_session_id': index_session_id,
-                'k': k,
-            },
-            concurrency_level=concurrency_level,
-        )
-    elif args[''].lower() == 'langfuse':
-        client = Langfuse()
-        dataset = client.get_dataset(args[''])
-        run_name_dataset = args[''] + '-' + str(uuid4())[:8]
-        for item in dataset.items:
-            callback_handlers = []
-            handler = item.get_langchain_handler(
-                run_name=run_name_dataset,
-                run_metadata={
-                    'index_session_id': index_session_id,
-                    'k': k,
-                },
-            )
-            callback_handlers.append(handler)
-            _construct_chain().invoke(
-                item.input, config={'callbacks': callback_handlers}
-            )
-        client.flush()
+    run_name_dataset = args[''] + '-' + str(uuid4())[:8]
+    run_dataset(run_name_dataset)
 
     duration = datetime.now() - start_time
     hours, remainder = divmod(duration.seconds, 3600)
@@ -155,13 +166,6 @@ def _construct_chain():
     load_dotenv()
     if cli_args[''].lower() == 'langsmith':
         # Check env (LangSmith)
-        langchain_endpoint = os.getenv('LANGCHAIN_ENDPOINT')
-        if not langchain_endpoint:
-            logging.error(
-                'Cannot proceed: LANGCHAIN_ENDPOINT env variable is not defined (define it in a .env file)'
-            )
-            sys.exit(1)
-
         langchain_apikey = os.getenv('LANGCHAIN_API_KEY')
         if not langchain_apikey:
             logging.error(