Skip to content

Commit

Permalink
Bring over supervisor.py changes to check valid answer
Browse files Browse the repository at this point in the history
  • Loading branch information
CLeopard99 committed Oct 31, 2024
1 parent 24e0555 commit ce6e9e7
Showing 1 changed file with 21 additions and 12 deletions.
33 changes: 21 additions & 12 deletions backend/src/supervisors/supervisor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,13 @@ async def solve_all(intent_json) -> None:

for question in questions:
try:
(agent_name, answer, status) = await solve_task(question, get_scratchpad())
(agent_name, answer) = await solve_task(question, get_scratchpad())
update_scratchpad(agent_name, question, answer)
if status == "error":
raise Exception(answer)
except Exception as error:
update_scratchpad(error=error)


async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]:
async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str]:
if attempt == 5:
raise Exception(unsolvable_response)

Expand All @@ -38,15 +36,26 @@ async def solve_task(task, scratchpad, attempt=0) -> Tuple[str, str, str]:
logger.info(f"Agent selected: {agent.name}")
logger.info(f"Task is {task}")
answer = await agent.invoke(task)
parsed_json = json.loads(answer)
status = parsed_json.get('status', 'success')
ignore_validation = parsed_json.get('ignore_validation', '')
answer_content = parsed_json.get('content', '')
if(ignore_validation == 'true') or await is_valid_answer(answer_content, task):
return (agent.name, answer_content, status)
if await is_valid_answer(agent, answer, task):
return (agent.name, answer)

return await solve_task(task, scratchpad, attempt + 1)


async def is_valid_answer(answer, task) -> bool:
is_valid = (await get_validator_agent().invoke(f"Task: {task} Answer: {answer}")).lower() == "true"
async def is_valid_answer(agent, answer, task) -> bool:
is_valid_result = await get_validator_agent().invoke(f"Task: {task} Answer: {answer}")
is_valid_result_json = json.loads(is_valid_result)
is_valid = is_valid_result_json["is_valid"]
if not is_valid:
logger.warning(f"Answer: {answer} for query: '{
task['query']}' is not valid")

if not is_valid and agent.name == "DatastoreAgent":
if answer == 'No database query':
update_scratchpad(agent_name=agent.name, result=is_valid, error=f'The task "{
task["query"]}" failed to generate cypher query, next time DO NOT use agent_name {agent.name} again')

Check failure on line 56 in backend/src/supervisors/supervisor.py

View workflow job for this annotation

GitHub Actions / Linting Backend

Ruff (E501)

backend/src/supervisors/supervisor.py:56:121: E501 Line too long (131 > 120)
else:
update_scratchpad(agent_name=agent.name, result=is_valid, error=f'The task "{
task["query"]}" generated cypher query but resulted in an invalid answer, next time try to use agent_name {agent.name} again')

Check failure on line 59 in backend/src/supervisors/supervisor.py

View workflow job for this annotation

GitHub Actions / Linting Backend

Ruff (E501)

backend/src/supervisors/supervisor.py:59:121: E501 Line too long (156 > 120)

return is_valid

0 comments on commit ce6e9e7

Please sign in to comment.