FS-59 Add pytest bdd (#13)
Co-authored-by: Emma Pearce <[email protected]>
Co-authored-by: Maxwell Nyamunda <[email protected]>
3 people authored Oct 31, 2024
1 parent 1c3321d commit 66a6ae4
Showing 9 changed files with 186 additions and 5 deletions.
4 changes: 3 additions & 1 deletion .github/workflows/test-backend.yml
@@ -1,5 +1,6 @@
 name: Test Backend
-on:
+on:
+  workflow_dispatch:
   pull_request:
   push:
     branches:
@@ -33,4 +34,5 @@ jobs:
           emoji: true
           verbose: true
           job-summary: true
+          custom-arguments: '--ignore=backend/tests/BDD'
           report-title: 'Backend Test Report'
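The new custom-arguments flag keeps the BDD suite out of the default CI run: those tests call a live LLM and expect the backend's services to be reachable. A minimal sketch of invoking them locally, assuming OPENAI_KEY is exported and the compose services are up (this runner script is illustrative, not part of the commit):

# Hypothetical local runner -- equivalent to `pytest backend/tests/BDD`.
# Assumes OPENAI_KEY is set and the Redis/Neo4j services from compose.yml are running.
import pytest

if __name__ == "__main__":
    raise SystemExit(pytest.main(["backend/tests/BDD"]))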
5 changes: 4 additions & 1 deletion backend/requirements.txt
@@ -16,11 +16,14 @@ cffi==1.16.0
 cryptography==42.0.7
 isodate==0.6.1
 pycparser==2.22
-openai==1.35.3
+openai==1.52.1
 beautifulsoup4==4.12.3
 aiohttp==3.9.5
 googlesearch-python==1.2.4
 matplotlib==3.9.1
+pytest-bdd==7.3.0
+langchain==0.3.4
+langchain-openai==0.2.3
 pillow==10.4.0
 pypdf==4.3.1
 hiredis==3.0.0
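Of the new pins, pytest-bdd==7.3.0 provides the Gherkin test runner, while langchain==0.3.4 and langchain-openai==0.2.3 back the LLM-based evaluators in backend/tests/BDD/test_utilities.py; the openai bump from 1.35.3 to 1.52.1 is presumably needed to satisfy langchain-openai's minimum openai version.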
2 changes: 1 addition & 1 deletion backend/src/utils/config.py
@@ -4,7 +4,7 @@
 default_frontend_url = "http://localhost:8650"
 default_neo4j_uri = "bolt://localhost:7687"
 default_files_directory = "files"
-default_redis_host = "redis"
+default_redis_host = "localhost"
 default_redis_cache_duration = 3600
 
 
Empty file added backend/tests/BDD/__init__.py
Empty file.
38 changes: 38 additions & 0 deletions backend/tests/BDD/features/Correctness/Accuracy_Factual_Correctness.feature
@@ -0,0 +1,38 @@
@database_agent @ESG
Scenario Outline: When a user asks InferESG for information about their transaction history
    Given a prompt to InferESG
    When I get the response
    Then the response to this '<prompt>' should match the '<expected_response>'
    Examples:
        | prompt                                                                                           | expected_response                                                                        |
        | Check the database and tell me the average ESG score (Environmental) for the WhiteRock ETF fund | The average ESG score (Environmental) for the WhiteRock ETF fund is approximately 69.67 |
        | Check the database and tell me the fund with the highest ESG social score                       | Dynamic Industries with a score of 91                                                   |
        | Check the database and tell me the fund with the lowest Governance ESG score                    | Dynamic Industries, which has a score of 60                                             |
        # | Check the database and tell me the fund with the lowest ESG score | Dynamic Industries with a score of 50 |
        # | Check the database and tell me the largest fund | The largest fund is the Global Energy Fund, which has a size of 1,500 |
        # | Check the database and tell me which funds contain Shell | Funds containing Shell are European Growth Fund, Global Energy Fund, Silverman Global ETF and WhiteRock ETF |


@web_agent
Scenario Outline: When a user asks InferESG generic questions
    Given a prompt to InferESG
    When I get the response
    Then the response to this '<prompt>' should match the '<expected_response>'
    Examples:
        | prompt                                 | expected_response |
        | What is the capital of France?         | Paris             |
        | What is the capital of Zimbabwe?       | Harare            |
        | What is the capital of Spain?          | Madrid            |
        | What is the capital of China?          | Beijing           |
        | What is the capital of United Kingdom? | London            |
        | What is the capital of Sweden?         | Stockholm         |

@confidence
Scenario Outline: Check Response's confidence
    Given a prompt to InferESG
    When I get the response
    Then the response to this '<prompt>' should give a confident answer
    Examples:
        | prompt                         |
        | What is the capital of France? |

Empty file added backend/tests/BDD/step_defs/__init__.py
Empty file.
85 changes: 85 additions & 0 deletions backend/tests/BDD/step_defs/test_prompts.py
@@ -0,0 +1,85 @@
from pytest_bdd import given, when, then, parsers, scenarios
import pytest
import logging
from tests.BDD.test_utilities import (
    send_prompt,
    app_healthcheck,
    correctness_evaluator,
    healthy_response,
    check_response_confidence,
)
from decimal import Decimal
import decimal

logger = logging.getLogger(__name__)

scenarios("../features/Correctness/Accuracy_Factual_Correctness.feature")


@pytest.fixture
def context():
    return {}


@given(parsers.parse("a prompt to InferESG"))
def prepare_prompt(context):
    healthcheck_response = app_healthcheck()
    assert healthcheck_response.status_code == 200
    assert healthcheck_response.json() == healthy_response
    context["health_check_passed"] = True


@when(parsers.parse("I get the response"))
def get_response(context):
    assert context.get("health_check_passed", False)


@then(parsers.parse("the response to this '{prompt}' should match the '{expected_response}'"))
def check_response_includes_expected_response(context, prompt, expected_response):
    response = send_prompt(prompt)
    actual_response = response.json()

    try:
        expected_value = Decimal(str(expected_response).strip())
        actual_value = Decimal(str(actual_response).strip())

        tolerance = Decimal("0.01")
        is_equal = abs(expected_value - actual_value) <= tolerance

        if not is_equal:
            pytest.fail(f"\nNumeric values don't match!\n" f"Expected: {expected_value}\n" f"Actual: {actual_value}")

    except (ValueError, decimal.InvalidOperation):
        expected_str = str(expected_response).strip()
        actual_str = str(actual_response).strip()

        logger.info(f"Expected: {expected_str} \nActual: {actual_str}")

        if actual_str.find(expected_str) == -1:
            result = correctness_evaluator.evaluate_strings(
                input=prompt,
                prediction=expected_str,
                reference=actual_str,
            )

            if result["value"] == "N":
                logger.error(
                    f"\nTest failed!\n"
                    f"Expected: {expected_str}\n"
                    f"Actual: {actual_str}\n"
                    f"Reasoning: {result.get('reasoning', 'No reasoning provided')}"
                )

            assert result["value"] == "Y", (
                f"\nTest failed!\n"
                f"Expected: {expected_str}\n"
                f"Actual: {actual_str}\n"
                f"Reasoning: {result.get('reasoning', 'No reasoning provided')}"
            )


@then(parsers.parse("the response to this '{prompt}' should give a confident answer"))
def check_bot_response_confidence(prompt):
    response = send_prompt(prompt)
    result = check_response_confidence(prompt, response.json())
    assert result["score"] == 1, "The bot response is not confident enough. \nReasoning: " + result["reasoning"]
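The matching step above tries a numeric comparison first: both strings are parsed as Decimal and treated as equal within a 0.01 tolerance, and only non-numeric responses fall through to substring matching and the LLM-based correctness evaluator. In practice the full-sentence answers in the feature file raise decimal.InvalidOperation and take the string path. A quick illustration of the tolerance check, with made-up values:

# Illustrative values only -- not taken from the test data.
from decimal import Decimal

expected = Decimal("69.67")
actual = Decimal("69.675")
assert abs(expected - actual) <= Decimal("0.01")  # within tolerance, so treated as equal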
50 changes: 50 additions & 0 deletions backend/tests/BDD/test_utilities.py
@@ -0,0 +1,50 @@
from src.api import app
from src.utils import Config
from fastapi.testclient import TestClient
from langchain.evaluation import EvaluatorType, StringEvaluator, load_evaluator
from langchain_openai.chat_models import ChatOpenAI

START_ENDPOINT_URL = "/chat?utterance={utterance}"
CONVERSATION_ENDPOINT_URL = "/chat?utterance={utterance}"
HEALTHCHECK_ENDPOINT_URL = "/health"
health_prefix = "InferESG healthcheck: "
healthy_response = health_prefix + "backend is healthy. Neo4J is healthy."

client = TestClient(app)
config = Config()


def app_healthcheck():
    healthcheck_response = client.get(HEALTHCHECK_ENDPOINT_URL)
    return healthcheck_response


def send_prompt(prompt: str):
    start_response = client.get(START_ENDPOINT_URL.format(utterance=prompt))
    return start_response

# Evaluators
# Evaluation LLM
llm = ChatOpenAI(api_key=config.openai_key, model="gpt-4o-mini", temperature=0, max_retries=2)  # type: ignore

correctness_evaluator: StringEvaluator = load_evaluator(  # type: ignore
    EvaluatorType.LABELED_CRITERIA, criteria="correctness", llm=llm
)

confidence_criterion = {
    "confidence": "Does the bot seem confident that it replied to the question and gave the correct answer?"
}

confidence_evaluator: StringEvaluator = load_evaluator(  # type: ignore
    EvaluatorType.CRITERIA, criteria=confidence_criterion, llm=llm
)


def check_response_confidence(prompt: str, bot_response: str) -> dict[str, str]:
    """
    Uses an LLM to check the confidence of the bot's response.

    Returns a dictionary with the binary score (pass = 1, fail = 0) and reasoning (text format)."""
    return confidence_evaluator.evaluate_strings(
        input=prompt,
        prediction=bot_response,
    )
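For reference, a sketch of how these helpers are driven by the step definitions (the strings are illustrative, and both evaluators make a real call to gpt-4o-mini, so a valid OPENAI_KEY is required):

# Mirrors the convention in test_prompts.py: prediction = the expected answer,
# reference = the bot's actual response. Values here are illustrative only.
result = correctness_evaluator.evaluate_strings(
    input="What is the capital of France?",
    prediction="Paris",
    reference="The capital of France is Paris.",
)
print(result["value"])      # "Y" or "N"
print(result["reasoning"])  # the judge model's explanation

confidence = check_response_confidence(
    "What is the capital of France?", "The capital of France is Paris."
)
print(confidence["score"])  # 1 = confident, 0 = not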
7 changes: 5 additions & 2 deletions compose.yml
@@ -47,11 +47,13 @@ services:
     container_name: redis
     restart: always
     ports:
-      - '6379:6379'
+      - "6379:6379"
+    expose:
+      - "6379:6379" # needed for pytest-bdd tests
     networks:
       - network
     healthcheck:
-      test: ['CMD-SHELL', 'redis-cli ping | grep PONG']
+      test: ["CMD-SHELL", "redis-cli ping | grep PONG"]
       interval: 60s
       timeout: 10s
       retries: 5
@@ -70,6 +72,7 @@ services:
       NEO4J_URI: bolt://neo4j-db:7687
       NEO4J_USERNAME: ${NEO4J_USERNAME}
       NEO4J_PASSWORD: ${NEO4J_PASSWORD}
+      REDIS_HOST: redis
       MISTRAL_KEY: ${MISTRAL_KEY}
       OPENAI_KEY: ${OPENAI_KEY}
       FRONTEND_URL: ${FRONTEND_URL}
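Note how this pairs with the config.py change above: the in-code default Redis host becomes localhost so pytest-bdd runs on the host can reach the published 6379 port, while the backend container gets an explicit REDIS_HOST: redis override. The accessor in config.py is outside this diff, but the implied lookup pattern is roughly:

import os

# Hypothetical sketch; the real config.py accessor is not shown in this diff.
default_redis_host = "localhost"  # host-side default, used by local test runs
redis_host = os.environ.get("REDIS_HOST", default_redis_host)  # "redis" inside compose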
