
Commit

finalisation of rag_testing_tool.py
Morgan Diverrez committed Sep 4, 2024
1 parent eaea66a commit af75604
Showing 1 changed file with 41 additions and 37 deletions.
@@ -57,6 +57,10 @@
from gen_ai_orchestrator.services.langchain.rag_chain import create_rag_chain
from langfuse import Langfuse
from langsmith import Client
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)


def test_rag(args):
@@ -90,6 +94,41 @@ def _construct_chain():
            'chat_history': lambda x: x['chat_history'] if 'chat_history' in x else [],
        } | create_rag_chain(RagQuery(**rag_query))

    @retry(wait=wait_random_exponential(min=10, max=60), stop=stop_after_attempt(5))
    def run_dataset(run_name_dataset):

        if args['<dataset_provider>'].lower() == 'langsmith':
            client = Client()
            client.run_on_dataset(

                dataset_name=args['<dataset_name>'],
                llm_or_chain_factory=_construct_chain,
                project_name=run_name_dataset,
                project_metadata={
                    'index_session_id': index_session_id,
                    'k': k,
                },
                concurrency_level=concurrency_level,
            )
        elif args['<dataset_provider>'].lower() == 'langfuse':
            client = Langfuse()
            dataset = client.get_dataset(args['<dataset_name>'])

            for item in dataset.items:
                callback_handlers = []
                handler = item.get_langchain_handler(
                    run_name=run_name_dataset,
                    run_metadata={
                        'index_session_id': index_session_id,
                        'k': k,
                    },
                )
                callback_handlers.append(handler)
                _construct_chain().invoke(
                    item.input, config={'callbacks': callback_handlers}
                )
            client.flush()

    search_params = rag_query['document_search_params']
    index_session_id = search_params['filter'][0]['term'][
        'metadata.index_session_id.keyword'
@@ -102,36 +141,8 @@ def _construct_chain():
    # one at a time
    if args['<delay>']:
        concurrency_level = 1
    if args['<dataset_provider>'].lower() == 'langsmith':
        client = Client()
        client.run_on_dataset(
            dataset_name=args['<dataset_name>'],
            llm_or_chain_factory=_construct_chain,
            project_name=args['<test_name>'] + '-' + str(uuid4())[:8],
            project_metadata={
                'index_session_id': index_session_id,
                'k': k,
            },
            concurrency_level=concurrency_level,
        )
    elif args['<dataset_provider>'].lower() == 'langfuse':
        client = Langfuse()
        dataset = client.get_dataset(args['<dataset_name>'])
        run_name_dataset = args['<test_name>'] + '-' + str(uuid4())[:8]
        for item in dataset.items:
            callback_handlers = []
            handler = item.get_langchain_handler(
                run_name=run_name_dataset,
                run_metadata={
                    'index_session_id': index_session_id,
                    'k': k,
                },
            )
            callback_handlers.append(handler)
            _construct_chain().invoke(
                item.input, config={'callbacks': callback_handlers}
            )
        client.flush()
    run_name_dataset = args['<test_name>'] + '-' + str(uuid4())[:8]
    run_dataset(run_name_dataset)

    duration = datetime.now() - start_time
    hours, remainder = divmod(duration.seconds, 3600)
@@ -155,13 +166,6 @@ def _construct_chain():
    load_dotenv()
    if cli_args['<dataset_provider>'].lower() == 'langsmith':
        # Check env (LangSmith)
        langchain_endpoint = os.getenv('LANGCHAIN_ENDPOINT')
        if not langchain_endpoint:
            logging.error(
                'Cannot proceed: LANGCHAIN_ENDPOINT env variable is not defined (define it in a .env file)'
            )
            sys.exit(1)

        langchain_apikey = os.getenv('LANGCHAIN_API_KEY')
        if not langchain_apikey:
            logging.error(
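The substance of this commit is that the LangSmith/Langfuse dataset run, previously inlined in test_rag, is extracted into a nested run_dataset function wrapped with tenacity's @retry decorator, so a transient failure (for example a rate-limited LLM or vector-store call) re-runs the dataset with randomized exponential backoff capped at 60 seconds, giving up after 5 attempts. Below is a minimal standalone sketch of that retry pattern, using the same decorator arguments as the committed code; the flaky_evaluation function and its call counter are hypothetical stand-ins, not part of the repository.

from tenacity import retry, stop_after_attempt, wait_random_exponential

_calls = {'count': 0}

# Same policy as run_dataset: randomized exponential backoff capped at 60 s,
# at most 5 attempts before tenacity gives up (raising RetryError by default).
@retry(wait=wait_random_exponential(min=10, max=60), stop=stop_after_attempt(5))
def flaky_evaluation():
    _calls['count'] += 1
    if _calls['count'] < 3:
        # Simulate a transient error such as an HTTP 429 from the LLM provider.
        raise RuntimeError('temporary failure')
    return f"succeeded on attempt {_calls['count']}"


if __name__ == '__main__':
    # The first two calls raise; tenacity sleeps between attempts and retries,
    # so this eventually prints "succeeded on attempt 3" (expect waits of up
    # to 60 s between attempts with this policy).
    print(flaky_evaluation())

Note that because the decorator wraps the whole of run_dataset, a failure near the end of a Langfuse run restarts the loop from the first dataset item; that is the trade-off of retrying at the run level rather than per item.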
