Skip to content

Commit

Permalink
seed database and get threads
Browse files Browse the repository at this point in the history
  • Loading branch information
azliu0 committed Apr 16, 2024
1 parent 082dffc commit 6c440cb
Show file tree
Hide file tree
Showing 6 changed files with 97 additions and 35 deletions.
37 changes: 3 additions & 34 deletions server/cli.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
"""Flask CLI commands."""

import datetime
import json

from flask import Blueprint

Expand All @@ -11,6 +10,7 @@
increment_response_count,
thread_emails_to_openai_messages,
)
from server.fake_data import generate_test_documents
from server.models.document import Document
from server.models.email import Email
from server.models.response import Response
Expand All @@ -20,37 +20,6 @@

seed = Blueprint("seed", __name__)

FLASK_SEED_CORPUS = "server/nlp/corpus_flask_seed.json"


def _generate_test_documents():
"""Generate test documents."""
with open(FLASK_SEED_CORPUS) as f:
corpus = json.load(f)

documents = []
for doc in corpus:
document = Document(
question=doc["question"],
label=doc["question"],
source=doc["source"],
content=doc["content"],
)
db.session.add(document)
db.session.commit()
documents.append(document)

test_documents = [
{
"question": doc.question,
"source": doc.source,
"content": doc.content,
"sql_id": doc.id,
}
for doc in documents
]
return test_documents


def _embed_existing_documents(documents: list[Document]):
"""Embed existing documents."""
Expand All @@ -69,7 +38,7 @@ def _embed_existing_documents(documents: list[Document]):
@seed.cli.command()
def corpus():
"""Add test documents to the corpus."""
test_documents = _generate_test_documents()
test_documents = generate_test_documents()
embed_corpus(test_documents)


Expand All @@ -82,7 +51,7 @@ def email():
# responses, so the only way this command succeeds is if the corpus is
# already populated
print("No documents in the database. Generating test documents...")
test_documents = _generate_test_documents()
test_documents = generate_test_documents()
embed_corpus(test_documents)
else:
print("Embedding existing documents...")
Expand Down
37 changes: 37 additions & 0 deletions server/fake_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Fake data for seed cli and testing."""

import json

from server import db
from server.models.document import Document

FLASK_SEED_CORPUS = "server/nlp/corpus_flask_seed.json"


def generate_test_documents():
"""Generate test documents."""
with open(FLASK_SEED_CORPUS) as f:
corpus = json.load(f)

documents = []
for doc in corpus:
document = Document(
question=doc["question"],
label=doc["question"],
source=doc["source"],
content=doc["content"],
)
db.session.add(document)
db.session.commit()
documents.append(document)

test_documents = [
{
"question": doc.question,
"source": doc.source,
"content": doc.content,
"sql_id": doc.id,
}
for doc in documents
]
return test_documents
1 change: 0 additions & 1 deletion server/nlp/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,6 @@ def generate_context(
docs = {}

results = query_all(3, questions)
print("results", results)
message = "Here is some context to help you answer this email: \n"
for result in results:
confidence = 0
Expand Down
4 changes: 4 additions & 0 deletions server_tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from server import create_app, db
from server.config import LOCAL, VECTOR_DIMENSION
from server_tests.utils import seed_database


@pytest.fixture(scope="session")
Expand Down Expand Up @@ -75,6 +76,9 @@ def app(db_url: str, redis_host: str):
}
)

with app.app_context():
seed_database(db)

yield app


Expand Down
19 changes: 19 additions & 0 deletions server_tests/test_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,3 +8,22 @@ def test_get_threads(app: APIFlask, client: FlaskClient):
"""Test fetching threads."""
response = client.get("/api/emails/get_threads")
assert_status(response, 200)

threads = response.json
assert threads is not None
assert len(threads) == 1

thread = threads[0]
assert thread["id"] == 1
assert not thread["resolved"] # resolved is false
assert len(thread["emailList"]) == 5

for email in thread["emailList"]:
# checking that .map() is intact
assert email["id"] is not None
assert email["body"] is not None
assert email["subject"] is not None
assert email["sender"] is not None
assert email["message_id"] is not None
assert email["is_reply"] is not None
assert email["thread_id"] is not None
34 changes: 34 additions & 0 deletions server_tests/utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""Utils for testing."""

import datetime
import logging

from werkzeug.test import TestResponse

from server import ProperlyTypedSQLAlchemy
from server.fake_data import generate_test_documents
from server.models.email import Email
from server.models.thread import Thread


def assert_status(response: TestResponse, status: int):
"""Asserts a response's status code, logging the response if it fails."""
Expand All @@ -15,3 +21,31 @@ def assert_status(response: TestResponse, status: int):
f"Response body: {response.data.decode()}"
)
raise

def seed_database(db: ProperlyTypedSQLAlchemy):
"""Seeds the database with some fake data."""

# add some documents to the database
generate_test_documents()

# create fake thread with 5 emails
thread = Thread()
db.session.add(thread)
db.session.commit()

emails = []
for i in range(5):
# every other email is a reply sent from pigeon
is_reply = i % 2 == 1
test_email = Email(
date=datetime.datetime.now(datetime.timezone.utc),
sender="[email protected]",
subject="Test Subject",
body="Test Body",
message_id=f"test-message-id-{i}",
is_reply=is_reply,
thread_id=thread.id
)
emails.append(test_email)
db.session.add(test_email)
db.session.commit()

0 comments on commit 6c440cb

Please sign in to comment.