Skip to content

Commit

Permalink
Merge remote-tracking branch 'Data-Sources-App-V2/dev' into mc_replace_with_abort
Browse files Browse the repository at this point in the history

# Conflicts:
#	resources/QuickSearch.py
  • Loading branch information
maxachis committed Jun 16, 2024
2 parents 6a7467a + 2bdd9dc commit 9c40067
Show file tree
Hide file tree
Showing 17 changed files with 351 additions and 110 deletions.
149 changes: 102 additions & 47 deletions middleware/quick_search_query.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from collections import namedtuple

import psycopg2
import spacy
import json
import datetime

from flask import make_response, Response
from sqlalchemy.dialects.postgresql import psycopg2

from middleware.webhook_logic import post_to_webhook
from utilities.common import convert_dates_to_strings, format_arrays
from typing import List, Dict, Any, Optional
from psycopg2.extensions import connection as PgConnection, cursor as PgCursor
from typing import List, Dict, Any
from psycopg2.extensions import cursor as PgCursor

QUICK_SEARCH_COLUMNS = [
"airtable_uid",
Expand Down Expand Up @@ -57,7 +59,11 @@
"""

# Parameterized INSERT for the query log. Values are bound by the DB driver
# via %s placeholders (never interpolated into the SQL string), which
# prevents SQL injection — unlike the old str.format() version this replaces.
INSERT_LOG_QUERY = """
    INSERT INTO quick_search_query_logs
    (search, location, results, result_count)
    VALUES (%s, %s, %s, %s)
    """


def unaltered_search_query(
Expand Down Expand Up @@ -90,11 +96,7 @@ def spacy_search_query(
:return: A list of dictionaries representing the search results.
"""
# Depluralize search term to increase match potential
nlp = spacy.load("en_core_web_sm")
search = search.strip()
doc = nlp(search)
lemmatized_tokens = [token.lemma_ for token in doc]
depluralized_search_term = " ".join(lemmatized_tokens)
depluralized_search_term = depluralize(search)
location = location.strip()

print(f"Query parameters: '%{depluralized_search_term}%', '%{location}%'")
Expand All @@ -107,71 +109,124 @@ def spacy_search_query(
return results


# Lazily-initialized spaCy pipeline, shared across calls. spacy.load() reads
# the model from disk and is expensive, so load it once and reuse it instead
# of reloading on every call (the defect in the previous version).
_NLP = None


def depluralize(term: str) -> str:
    """
    Depluralizes a given term using lemmatization.

    :param term: The term to be depluralized.
    :return: The depluralized term (each whitespace-stripped token replaced
        by its lemma, re-joined with single spaces).
    """
    global _NLP
    if _NLP is None:
        _NLP = spacy.load("en_core_web_sm")
    doc = _NLP(term.strip())
    return " ".join(token.lemma_ for token in doc)


SearchParameters = namedtuple("SearchParameters", ["search", "location"])


def quick_search_query(
    search_parameters: SearchParameters,
    cursor: PgCursor = None,
) -> Dict[str, Any]:
    """
    Performs a quick search using both unaltered and lemmatized search terms,
    returning the more fruitful result set.

    :param search_parameters: The raw search and location terms.
    :param cursor: A psycopg2 cursor to the database.
    :return: A dictionary with the count of results and the data itself.
    """
    # Normalize "all" sentinels and strip quotes before querying.
    processed_search_parameters = process_search_parameters(search_parameters)

    data_source_matches = get_data_source_matches(cursor, processed_search_parameters)
    processed_data_source_matches = process_data_source_matches(data_source_matches)

    data_sources = {
        "count": len(processed_data_source_matches.converted),
        "data": processed_data_source_matches.converted,
    }

    # Record the search and its results in the query log table.
    log_query(
        cursor,
        data_sources["count"],
        processed_data_source_matches,
        processed_search_parameters,
    )

    return data_sources


def log_query(
    cursor,
    data_sources_count,
    processed_data_source_matches,
    processed_search_parameters,
):
    """
    Records a quick-search invocation in the quick_search_query_logs table.

    :param cursor: A database cursor used to run the insert.
    :param data_sources_count: Number of matching data sources found.
    :param processed_data_source_matches: Holds the matched ids (``.ids``).
    :param processed_search_parameters: The normalized search/location terms.
    """
    # Serialize the matched ids; single quotes are stripped from the JSON
    # before it is stored.
    serialized_ids = json.dumps(processed_data_source_matches.ids).replace("'", "")
    log_row = (
        processed_search_parameters.search,
        processed_search_parameters.location,
        serialized_ids,
        data_sources_count,
    )
    cursor.execute(INSERT_LOG_QUERY, log_row)


def process_search_parameters(raw_sp: SearchParameters) -> SearchParameters:
    """
    Normalizes raw search parameters: the sentinel value "all" becomes an
    empty string (match everything), and single quotes are stripped.

    :param raw_sp: The raw search/location terms as submitted.
    :return: A new SearchParameters with both fields normalized.
    """

    def _normalize(value: str) -> str:
        if value == "all":
            return ""
        return value.replace("'", "")

    return SearchParameters(
        search=_normalize(raw_sp.search),
        location=_normalize(raw_sp.location),
    )


DataSourceMatches = namedtuple("DataSourceMatches", ["converted", "ids"])


def process_data_source_matches(data_source_matches: List[dict]) -> DataSourceMatches:
    """
    Normalizes each raw match and collects the matched ids.

    Each row is passed through convert_dates_to_strings and then
    format_arrays; the row's "airtable_uid" is collected separately so the
    query logger can record which sources matched.

    :param data_source_matches: Raw match rows keyed by column name.
    :return: DataSourceMatches(converted rows, list of airtable_uid values).
    """
    converted_rows = []
    matched_ids = []
    for row in data_source_matches:
        normalized = convert_dates_to_strings(row)
        converted_rows.append(format_arrays(normalized))
        # Ids are kept for logging purposes.
        matched_ids.append(normalized["airtable_uid"])
    return DataSourceMatches(converted_rows, matched_ids)

data_sources = {
"count": len(data_source_matches_converted),
"data": data_source_matches_converted,
}

query_results = json.dumps(data_sources["data"]).replace("'", "")

cursor.execute(
INSERT_LOG_QUERY.format(search, location, query_results, data_sources["count"]),
def get_data_source_matches(
    cursor: PgCursor, sp: SearchParameters
) -> List[Dict[str, Any]]:
    """
    Runs the search with both the unaltered and the lemmatized (spaCy)
    search terms and keeps whichever query produced more rows.

    :param cursor: A psycopg2 cursor to the database.
    :param sp: Normalized search parameters (search term and location).
    :return: Result rows as dicts keyed by QUICK_SEARCH_COLUMNS.
    """
    unaltered_results = unaltered_search_query(cursor, sp.search, sp.location)
    spacy_results = spacy_search_query(cursor, sp.search, sp.location)
    # Compare altered search term results with unaltered search term results,
    # return the longer list
    results = (
        spacy_results
        if len(spacy_results) > len(unaltered_results)
        else unaltered_results
    )
    data_source_matches = [
        dict(zip(QUICK_SEARCH_COLUMNS, result)) for result in results
    ]
    return data_source_matches


def quick_search_query_wrapper(arg1, arg2, conn: PgConnection) -> Response:
def quick_search_query_wrapper(
arg1, arg2, cursor: psycopg2.extensions.cursor
) -> Response:
try:
data_sources = quick_search_query(search=arg1, location=arg2, conn=conn)
data_sources = quick_search_query(
SearchParameters(search=arg1, location=arg2), cursor=cursor
)

return make_response(data_sources, 200)

except Exception as e:
conn.rollback()
user_message = "There was an error during the search operation"
message = {
"content": user_message
Expand Down
1 change: 0 additions & 1 deletion middleware/security.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,6 @@ def validate_role(role: str, endpoint: str, method: str):
raise InvalidRoleError("You do not have permission to access this endpoint")



def get_role(api_key, cursor):
cursor.execute(f"select id, api_key, role from users where api_key = '{api_key}'")
user_results = cursor.fetchall()
Expand Down
8 changes: 5 additions & 3 deletions resources/QuickSearch.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from flask_restx import abort
from flask import Response

from middleware.security import api_required
from middleware.quick_search_query import quick_search_query_wrapper
Expand All @@ -17,7 +17,7 @@ class QuickSearch(PsycopgResource):
# api_required decorator requires the request"s header to include an "Authorization" key with the value formatted as "Bearer [api_key]"
# A user can get an API key by signing up and logging in (see User.py)
@api_required
def get(self, search: str, location: str) -> Dict[str, Any]:
def get(self, search: str, location: str) -> Response:
"""
Performs a quick search using the provided search terms and location. It attempts to find relevant
data sources in the database. If no results are found initially, it re-initializes the database
Expand All @@ -30,4 +30,6 @@ def get(self, search: str, location: str) -> Dict[str, Any]:
Returns:
- A dictionary containing a message about the search results and the data found, if any.
"""
return quick_search_query_wrapper(search, location, self.psycopg2_connection)
return quick_search_query_wrapper(
search, location, self.psycopg2_connection.cursor()
)
6 changes: 5 additions & 1 deletion resources/RefreshSession.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,11 @@
from flask_restx import abort

from middleware.custom_exceptions import TokenNotFoundError
from middleware.login_queries import get_session_token_user_data, create_session_token, delete_session_token
from middleware.login_queries import (
get_session_token_user_data,
create_session_token,
delete_session_token,
)
from typing import Dict, Any

from resources.PsycopgResource import PsycopgResource, handle_exceptions
Expand Down
4 changes: 3 additions & 1 deletion resources/SearchTokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,9 @@ def get(self) -> Dict[str, Any]:

def perform_endpoint_logic(self, arg1, arg2, endpoint):
if endpoint == "quick-search":
return quick_search_query_wrapper(arg1, arg2, self.psycopg2_connection)
return quick_search_query_wrapper(
arg1, arg2, self.psycopg2_connection.cursor()
)
if endpoint == "data-sources":
return get_approved_data_sources_wrapper(self.psycopg2_connection)
if endpoint == "data-sources-by-id":
Expand Down
6 changes: 5 additions & 1 deletion tests/fixtures.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,10 @@ def connection_with_test_data(
dev_db_connection.rollback()
return dev_db_connection


ClientWithMockDB = namedtuple("ClientWithMockDB", ["client", "mock_db"])


@pytest.fixture
def client_with_mock_db(mocker) -> ClientWithMockDB:
"""
Expand All @@ -94,6 +97,7 @@ def client_with_mock_db(mocker) -> ClientWithMockDB:
with app.test_client() as client:
yield ClientWithMockDB(client, mock_db)


@pytest.fixture
def client_with_db(dev_db_connection: psycopg2.extensions.connection):
"""
Expand All @@ -103,4 +107,4 @@ def client_with_db(dev_db_connection: psycopg2.extensions.connection):
"""
app = create_app(dev_db_connection)
with app.test_client() as client:
yield client
yield client
18 changes: 11 additions & 7 deletions tests/helper_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def create_test_user(


QuickSearchQueryLogResult = namedtuple(
"QuickSearchQueryLogResult", ["result_count", "updated_at"]
"QuickSearchQueryLogResult", ["result_count", "updated_at", "results"]
)


Expand All @@ -151,15 +151,17 @@ def get_most_recent_quick_search_query_log(
"""
cursor.execute(
"""
SELECT RESULT_COUNT, CREATED_AT FROM QUICK_SEARCH_QUERY_LOGS WHERE
SELECT RESULT_COUNT, CREATED_AT, RESULTS FROM QUICK_SEARCH_QUERY_LOGS WHERE
search = %s AND location = %s ORDER BY CREATED_AT DESC LIMIT 1
""",
(search, location),
)
result = cursor.fetchone()
if result is None:
return result
return QuickSearchQueryLogResult(result_count=result[0], updated_at=result[1])
return QuickSearchQueryLogResult(
result_count=result[0], updated_at=result[1], results=result[2]
)


def has_expected_keys(result_keys: list, expected_keys: list) -> bool:
Expand Down Expand Up @@ -271,11 +273,10 @@ def create_api_key(client_with_db, user_info):
api_key = response.json.get("api_key")
return api_key


def create_api_key_db(cursor, user_id: str):
    """
    Generate a fresh random API key, persist it on the user's row, and return it.

    :param cursor: A database cursor used to run the UPDATE.
    :param user_id: The id of the user to receive the new key.
    :return: The 32-character hex API key.
    """
    api_key = uuid.uuid4().hex
    # Parameterized query: values are bound by the driver, not interpolated.
    cursor.execute("UPDATE users SET api_key = %s WHERE id = %s", (api_key, user_id))
    return api_key


Expand Down Expand Up @@ -327,5 +328,8 @@ def give_user_admin_role(
(user_info.email,),
)


def check_response_status(response, status_code):
assert response.status_code == status_code, f"Expected status code {status_code}, got {response.status_code}: {response.text}"
assert (
response.status_code == status_code
), f"Expected status code {status_code}, got {response.status_code}: {response.text}"
2 changes: 1 addition & 1 deletion tests/integration/test_agencies.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Integration tests for /agencies endpoint"""

import psycopg2
import pytest
from tests.fixtures import connection_with_test_data, dev_db_connection, client_with_db
Expand All @@ -20,4 +21,3 @@ def test_agencies_get(
)
assert response.status_code == 200
assert len(response.json["data"]) > 0

4 changes: 1 addition & 3 deletions tests/integration/test_search_tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,7 @@ def test_search_tokens_get(
)
check_response_status(response, 200)
data = response.json.get("data")
assert (
len(data) == 1
), "Quick Search endpoint response should return only one entry"
assert len(data) == 1, "Quick Search endpoint response should return only one entry"
entry = data[0]
assert entry["agency_name"] == "Agency A"
assert entry["airtable_uid"] == "SOURCE_UID_1"
11 changes: 8 additions & 3 deletions tests/middleware/test_data_source_queries.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,9 @@ def mock_data_source_by_id_query(monkeypatch):
return mock


def test_data_source_by_id_wrapper_data_found(mock_data_source_by_id_query, mock_make_response):
def test_data_source_by_id_wrapper_data_found(
mock_data_source_by_id_query, mock_make_response
):
mock_data_source_by_id_query.return_value = {"agency_name": "Agency A"}
mock_conn = MagicMock()
data_source_by_id_wrapper(arg="SOURCE_UID_1", conn=mock_conn)
Expand All @@ -181,11 +183,14 @@ def test_data_source_by_id_wrapper_data_found(mock_data_source_by_id_query, mock
)
mock_make_response.assert_called_with({"agency_name": "Agency A"}, 200)

def test_data_source_by_id_wrapper_data_not_found(mock_data_source_by_id_query, mock_make_response):

def test_data_source_by_id_wrapper_data_not_found(
mock_data_source_by_id_query, mock_make_response
):
mock_data_source_by_id_query.return_value = None
mock_conn = MagicMock()
data_source_by_id_wrapper(arg="SOURCE_UID_1", conn=mock_conn)
mock_data_source_by_id_query.assert_called_with(
data_source_id="SOURCE_UID_1", conn=mock_conn
)
mock_make_response.assert_called_with({"message": "Data source not found."}, 200)
mock_make_response.assert_called_with({"message": "Data source not found."}, 200)
Loading

0 comments on commit 9c40067

Please sign in to comment.