Fs 80/use https for webagent search #21

Merged: 15 commits, Nov 22, 2024
10 changes: 6 additions & 4 deletions backend/src/agents/web_agent.py
@@ -11,7 +11,7 @@
summarise_pdf_content,
find_info,
create_search_term,
answer_user_ques
answer_user_question
)
from .validator_agent import ValidatorAgent
import aiohttp
@@ -29,7 +29,7 @@
async def web_general_search_core(search_query, llm, model) -> str:
try:
# Step 1: Generate the search term from the user's query
answer_to_user = await answer_user_ques(search_query, llm, model)
answer_to_user = await answer_user_question(search_query, llm, model)
answer_result = json.loads(answer_to_user)
if answer_result["status"] == "error":
response = {
@@ -38,8 +38,8 @@ async def web_general_search_core(search_query, llm, model) -> str:
}
return json.dumps(response, indent=4)
logger.info(f'Answer found successfully {answer_result}')
valid_answer = json.loads(answer_result["response"]).get("is_valid", "")
if valid_answer:
should_perform_web_search = json.loads(answer_result["response"]).get("should_perform_web_search", "")
if not should_perform_web_search:
final_answer = json.loads(answer_result["response"]).get("answer", "")
if not final_answer:
return "No answer found."
@@ -239,6 +239,8 @@ async def perform_search(search_query: str, num_results: int) -> Dict[str, Any]:

async def perform_scrape(url: str) -> str:
try:
if not str(url).startswith("https"):
return ""
scrape_result_json = await scrape_content(url)
scrape_result = json.loads(scrape_result_json)
if scrape_result["status"] == "error":
@@ -9,24 +9,18 @@ User's question is:
Once you generate an answer:
- **Check** if the answer completely and accurately addresses the user's question.
- **Determine** if the answer is valid, based on the content provided.
- **Provide** a reason for the validity or invalidity of the answer.
- **Indicate** if a more general web search is required then indicate that.

Reply only in JSON format with the following structure:

```json
{
"answer": "The answer to the user's question, based on the content provided",
"is_valid": true or false,
"validation_reason": "A sentence explaining whether the answer is valid or not, and why"
"should_perform_web_search": true or false,
"perform_web_search_reason": "A sentence explaining whether the answer is valid or not, and why"
}



### **Explanation:**

1. **Answer**: The LLM generates an answer based on the user’s question and the provided content.
2. **Validity Check**: The LLM checks if its generated answer is complete and correct. This could be based on factual accuracy, coverage of the query, or relevance to the user's question.
3. **Validation Reason**: The LLM explains why the answer is valid or invalid.

### **Example of Usage:**

#### **User’s Question:**
@@ -40,15 +34,14 @@ Reply only in JSON format with the following structure:
```json
{
"answer": "Tesla's revenue since its creation is: 2008: $15 million, 2009: $30 million, ..., 2023: $81 billion.",
"is_valid": true,
"validation_reason": "The answer includes Tesla's revenue for every year since its creation, based on the data provided."
"should_perform_web_search": false,
"perform_web_search_reason": "The answer includes Tesla's revenue for every year since its creation, based on the data provided. No further search is required."
}

{
"answer": "Tesla's revenue for 2010 to 2023 is available, but data for the earlier years is missing.",
"is_valid": false,
"validation_reason": "The answer is incomplete because data for Tesla's early years is missing."
"should_perform_web_search": true,
"perform_web_search_reason": "The answer is incomplete because data for Tesla's early years is missing. A more general web search is required."
}


Important: If the question is related to real time data, the LLM should provide is_valid is false.
Important: If the question is related to real time data, the LLM should provide should_perform_web_search is true.
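
Taken together with the agent change above, this flag now gates whether the web-search path runs at all. A minimal, self-contained sketch of the intended branching (illustrative only; `decide_next_step` and the sample payloads below are not part of the repository):

```python
import json

def decide_next_step(llm_response: str) -> str:
    """Return the direct answer, or a marker that the HTTPS web-search path should run."""
    inner = json.loads(llm_response)
    if not inner.get("should_perform_web_search", False):
        # The model is confident: return its answer without searching the web.
        return inner.get("answer", "No answer found.")
    return "PERFORM_WEB_SEARCH"

# Example payloads shaped like the updated prompt's JSON contract.
confident = json.dumps({
    "answer": "Tesla's revenue in 2023 was about $81 billion.",
    "should_perform_web_search": False,
    "perform_web_search_reason": "The provided content already answers the question.",
})
realtime = json.dumps({
    "answer": "",
    "should_perform_web_search": True,
    "perform_web_search_reason": "The question concerns real-time data, so a web search is required.",
})

assert decide_next_step(confident).startswith("Tesla's revenue")
assert decide_next_step(realtime) == "PERFORM_WEB_SEARCH"
```
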
10 changes: 4 additions & 6 deletions backend/src/utils/web_utils.py
@@ -15,14 +15,12 @@

async def search_urls(search_query, num_results=10) -> str:
logger.info(f"Searching the web for: {search_query}")
urls = []
try:
for url in search(search_query, num_results=num_results):
urls.append(url)
https_urls = [str(url) for url in search(search_query, num_results=num_results) if str(url).startswith("https")]
Collaborator: Is it possible to test this by mocking out the search library?

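One way to do that, as a hedged sketch assuming the project's pytest-asyncio setup (the `test_https_urls` case added later in this diff takes essentially this approach):

```python
import json
from unittest.mock import patch

import pytest

from src.utils.web_utils import search_urls


@pytest.mark.asyncio
@patch("src.utils.web_utils.search")
async def test_search_urls_keeps_only_https(mock_search):
    # No real network call: the `search` function is replaced with fixed URLs.
    mock_search.return_value = ["https://a.com", "http://b.com"]

    result = json.loads(await search_urls("query", num_results=2))

    assert result["status"] == "success"
    assert result["urls"] == ["https://a.com"]
```
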
return json.dumps(
{
"status": "success",
"urls": urls,
"urls": https_urls,
"error": None,
}
)
@@ -83,9 +81,9 @@ async def create_search_term(search_query, llm, model) -> str:
}
)

async def answer_user_ques(search_query, llm, model) -> str:
async def answer_user_question(search_query, llm, model) -> str:
try:
summariser_prompt = engine.load_prompt("answer-user-ques", question=search_query)
summariser_prompt = engine.load_prompt("answer-user-question", question=search_query)
response = await llm.chat(model, summariser_prompt, "", return_json=True)
return json.dumps(
{
@@ -7,9 +7,9 @@ Examples:
|prompt |expected_response |
|Check the database and tell me the average ESG score (Environmental) for the WhiteRock ETF fund |The average ESG score (Environmental) for the WhiteRock ETF fund is approximately 69.67|
|Check the database and tell me the fund with the highest ESG social score |Dynamic Industries with a score of 91|
|Check the database and tell me the fund with the lowest Governance ESG score |Dynamic Industries, which has a score of 60|
# |Check the database and tell me the fund with the lowest ESG score |Dynamic Industries with a score of 50|
# |Check the database and tell me the largest fund |The largest fund is the Global Energy Fund, which has a size of 1,500|
|Check the database and tell me the fund with the lowest Governance ESG score |Dynamics Industries, Silvermans Global ETF, WhiteRocks ETF, which has a score of 60|
|Check the database and tell me the fund with the lowest ESG score |Dynamic Industries, Silverman Global ETF, WhiteRock ETF, with a score of 50|
|Check the database and tell me the largest fund |The largest fund is the Global Energy Fund, which has a size of 1,500|
# |Check the database and tell me which funds contain Shell |Funds containing Shell are European Growth Fund, Global Energy Fund, Silverman Global ETF and WhiteRock ETF|


72 changes: 38 additions & 34 deletions backend/tests/BDD/step_defs/test_prompts.py
@@ -39,44 +39,48 @@ def check_response_includes_expected_response(context, prompt, expected_response
response = send_prompt(prompt)
actual_response = response.json()

try:
expected_value = Decimal(str(expected_response).strip())
actual_value = Decimal(str(actual_response).strip())

tolerance = Decimal("0.01")
is_equal = abs(expected_value - actual_value) <= tolerance

if not is_equal:
pytest.fail(f"\nNumeric values don't match!\n" f"Expected: {expected_value}\n" f"Actual: {actual_value}")

except (ValueError, decimal.InvalidOperation):
expected_str = str(expected_response).strip()
actual_str = str(actual_response).strip()

logger.info(f"Expected : {expected_str} \nActual: {actual_str}")

if actual_str.find(expected_str) == -1:
result = correctness_evaluator.evaluate_strings(
input=prompt,
prediction=expected_str,
reference=actual_str,
)

if result["value"] == "N":
logger.error(
f"\nTest failed!\n"
f"Expected: {expected_str}\n"
f"Actual: {actual_str}\n"
f"Reasoning: {result.get('reasoning', 'No reasoning provided')}"
)

assert result["value"] == "Y", (
# Allow `expected_response` to be a list of possible valid responses
possible_responses = [resp.strip() for resp in expected_response.split(",")]
Collaborator:
It's good that this now supports multiple possible answers.
I guess the implication of splitting on `,` and then doing a contains check for each substring, though, is that the resulting check is less strict. E.g. for a question like `Check the database and tell me the fund with the lowest Governance ESG score` with an expected response of `Dynamics Industries, Silvermans Global ETF, WhiteRocks ETF, which has a score of 60`, this would pass for any answer that contains one of those fund names or which contains `which has a score of 60`.

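To illustrate the concern (hypothetical values, not an actual test case):

```python
expected_response = "Dynamics Industries, Silvermans Global ETF, WhiteRocks ETF, which has a score of 60"
possible_responses = [resp.strip() for resp in expected_response.split(",")]
# -> ['Dynamics Industries', 'Silvermans Global ETF', 'WhiteRocks ETF', 'which has a score of 60']

# An answer that names only one fund (or only the score phrase) already satisfies
# the substring check, even though it omits the other funds:
actual_response = "The fund with the lowest score is Dynamics Industries."
assert any(resp in actual_response for resp in possible_responses)  # passes
```
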
Collaborator (author):
@mic-smith - I get your point, so what do you suggest we should do?

mic-smith (Collaborator), Nov 20, 2024:
I guess we could repeat the whole string in each option, e.g. `Dynamics Industries which has a score of 60, Silvermans Global ETF which has a score of 60, WhiteRocks ETF which has a score of 60`, maybe with a different separator than `,` so it's more obvious? Though I'm not sure how well that would work with the case where we need to use an LLM because the response isn't a substring.

We're also commenting out these datasource-based tests in #23. So I wonder if it's worth pulling these test changes into a separate branch so they don't block this PR, and then addressing them once we have the new flow to upload data and a different data set available?

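A rough sketch of that suggestion, assuming a `|` separator and a full statement per option (both are illustrative choices, not decided in this thread):

```python
expected_response = (
    "Dynamics Industries which has a score of 60"
    "|Silvermans Global ETF which has a score of 60"
    "|WhiteRocks ETF which has a score of 60"
)
possible_responses = [resp.strip() for resp in expected_response.split("|")]

actual_response = "The lowest Governance ESG score is 60, held by Dynamics Industries."
match_found = any(resp in actual_response for resp in possible_responses)
# match_found is False here: a partial mention no longer passes, so the test would
# fall back to the LLM-based correctness evaluator instead of a loose substring hit.
```
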

match_found = False
for expected_resp in possible_responses:
try:
expected_value = Decimal(expected_resp)
actual_value = Decimal(str(actual_response).strip())

tolerance = Decimal("0.01")
if abs(expected_value - actual_value) <= tolerance:
match_found = True
break # Exit loop if a match is found

except (ValueError, decimal.InvalidOperation):
if expected_resp in str(actual_response).strip():
match_found = True
break

if not match_found:
# Fallback to the correctness evaluator if none of the options matched
result = correctness_evaluator.evaluate_strings(
input=prompt,
prediction=expected_response,
reference=actual_response,
)

if result["value"] == "N":
logger.error(
f"\nTest failed!\n"
f"Expected: {expected_str}\n"
f"Actual: {actual_str}\n"
f"Expected one of: {possible_responses}\n"
f"Actual: {actual_response}\n"
f"Reasoning: {result.get('reasoning', 'No reasoning provided')}"
)

assert result["value"] == "Y", (
f"\nTest failed!\n"
f"Expected one of: {possible_responses}\n"
f"Actual: {actual_response}\n"
f"Reasoning: {result.get('reasoning', 'No reasoning provided')}"
)


@then(parsers.parse("the response to this '{prompt}' should give a confident answer"))
def check_bot_response_confidence(prompt):
45 changes: 39 additions & 6 deletions backend/tests/agents/web_agent_test.py
@@ -1,10 +1,11 @@
import pytest
from unittest.mock import patch, AsyncMock
import json
from src.agents.web_agent import web_general_search_core
from src.agents.web_agent import web_general_search_core, perform_scrape
from src.utils.web_utils import search_urls

@pytest.mark.asyncio
@patch("src.agents.web_agent.answer_user_ques", new_callable=AsyncMock)
@patch("src.agents.web_agent.answer_user_question", new_callable=AsyncMock)
@patch("src.agents.web_agent.create_search_term", new_callable=AsyncMock)
@patch("src.agents.web_agent.perform_search", new_callable=AsyncMock)
@patch("src.agents.web_agent.perform_scrape", new_callable=AsyncMock)
@@ -16,13 +17,13 @@ async def test_web_general_search_core(
mock_perform_scrape,
mock_perform_search,
mock_create_search_term,
mock_answer_user_ques
mock_answer_user_question
):
llm = AsyncMock()
model = "mock_model"

# Mocking answer_user_ques to return a valid answer
mock_answer_user_ques.return_value = json.dumps({
# Mocking answer_user_question to return a valid answer
mock_answer_user_question.return_value = json.dumps({
"status": "success",
"response": json.dumps({"is_valid": True, "answer": "Example summary."})
})
@@ -70,7 +71,7 @@ async def test_web_general_search_core_invalid_summary(
):
llm = AsyncMock()
model = "mock_model"
mock_perform_search.return_value = {"status": "success", "urls": ["http://example.com"]}
mock_perform_search.return_value = {"status": "success", "urls": ["https://example.com"], "error": None}
mock_perform_scrape.return_value = "Example scraped content."
mock_perform_summarization.return_value = json.dumps({"summary": "Example invalid summary."})
mock_is_valid_answer.return_value = False
@@ -81,3 +82,35 @@ }
}
assert json.loads(result) == expected_response

@pytest.mark.asyncio
@patch("src.utils.web_utils.search")
async def test_https_urls(mock_search):
Collaborator: 👍

mock_search.return_value = [
"https://example.com",
"http://nonsecure.com",
"https://another-secure-site.com"
]

result = await search_urls("query", num_results=5)
expected_result = {
"status": "success",
"urls": ["https://example.com", "https://another-secure-site.com"],
"error": None
}
assert json.loads(result) == expected_result

@pytest.mark.asyncio
@patch("src.agents.web_agent.scrape_content", new_callable=AsyncMock)
async def test_perform_scrape_http_url(mock_scrape_content):
mock_scrape_content.return_value = json.dumps({"status": "success", "content": "Scraped content."})

result = await perform_scrape("http://nonsecure.com")
assert result == ""

@pytest.mark.asyncio
@patch("src.agents.web_agent.scrape_content", new_callable=AsyncMock)
async def test_perform_scrape_https_url(mock_scrape_content):
mock_scrape_content.return_value = json.dumps({"status": "success", "content": "Scraped content."})

result = await perform_scrape("https://secure.com")
assert result == "Scraped content."
73 changes: 0 additions & 73 deletions backend/tests/websockets/user_confirmer_test.py

This file was deleted.
