diff --git a/middleware/quick_search_query.py b/middleware/quick_search_query.py index dc09d385..e47a1add 100644 --- a/middleware/quick_search_query.py +++ b/middleware/quick_search_query.py @@ -1,14 +1,16 @@ +from collections import namedtuple + +import psycopg2 import spacy import json import datetime from flask import make_response, Response -from sqlalchemy.dialects.postgresql import psycopg2 from middleware.webhook_logic import post_to_webhook from utilities.common import convert_dates_to_strings, format_arrays -from typing import List, Dict, Any, Optional -from psycopg2.extensions import connection as PgConnection, cursor as PgCursor +from typing import List, Dict, Any +from psycopg2.extensions import cursor as PgCursor QUICK_SEARCH_COLUMNS = [ "airtable_uid", @@ -57,7 +59,11 @@ """ -INSERT_LOG_QUERY = "INSERT INTO quick_search_query_logs (search, location, results, result_count) VALUES ('{0}', '{1}', '{2}', '{3}')" +INSERT_LOG_QUERY = """ + INSERT INTO quick_search_query_logs + (search, location, results, result_count) + VALUES (%s, %s, %s, %s) + """ def unaltered_search_query( @@ -90,11 +96,7 @@ def spacy_search_query( :return: A list of dictionaries representing the search results. """ # Depluralize search term to increase match potential - nlp = spacy.load("en_core_web_sm") - search = search.strip() - doc = nlp(search) - lemmatized_tokens = [token.lemma_ for token in doc] - depluralized_search_term = " ".join(lemmatized_tokens) + depluralized_search_term = depluralize(search) location = location.strip() print(f"Query parameters: '%{depluralized_search_term}%', '%{location}%'") @@ -107,71 +109,124 @@ def spacy_search_query( return results +def depluralize(term: str): + """ + Depluralizes a given term using lemmatization. + + :param term: The term to be depluralized. + :return: The depluralized term. 
+ """ + nlp = spacy.load("en_core_web_sm") + term = term.strip() + doc = nlp(term) + lemmatized_tokens = [token.lemma_ for token in doc] + depluralized_search_term = " ".join(lemmatized_tokens) + return depluralized_search_term + + +SearchParameters = namedtuple("SearchParameters", ["search", "location"]) + + def quick_search_query( - search: str = "", - location: str = "", - conn: Optional[PgConnection] = None, + search_parameters: SearchParameters, + cursor: PgCursor = None, ) -> Dict[str, Any]: """ Performs a quick search using both unaltered and lemmatized search terms, returning the more fruitful result set. - :param search: The search term. - :param location: The location term. - :param conn: A psycopg2 connection to the database. + :param search_parameters: + + :param cursor: A psycopg2 cursor to the database. :return: A dictionary with the count of results and the data itself. """ - data_sources = {"count": 0, "data": []} - if type(conn) == dict and "data" in conn: - return data_sources - search = "" if search == "all" else search.replace("'", "") - location = "" if location == "all" else location.replace("'", "") + processed_search_parameters = process_search_parameters(search_parameters) - if conn: - cursor = conn.cursor() + data_source_matches = get_data_source_matches(cursor, processed_search_parameters) + processed_data_source_matches = process_data_source_matches(data_source_matches) - unaltered_results = unaltered_search_query(cursor, search, location) - spacy_results = spacy_search_query(cursor, search, location) + data_sources = { + "count": len(processed_data_source_matches.converted), + "data": processed_data_source_matches.converted, + } - # Compare altered search term results with unaltered search term results, return the longer list - results = ( - spacy_results - if len(spacy_results) > len(unaltered_results) - else unaltered_results + log_query( + cursor, + data_sources["count"], + processed_data_source_matches, + 
processed_search_parameters, ) - data_source_matches = [ - dict(zip(QUICK_SEARCH_COLUMNS, result)) for result in results - ] + return data_sources + + +def log_query( + cursor, + data_sources_count, + processed_data_source_matches, + processed_search_parameters, +): + query_results = json.dumps(processed_data_source_matches.ids).replace("'", "") + cursor.execute( + INSERT_LOG_QUERY, + ( + processed_search_parameters.search, + processed_search_parameters.location, + query_results, + data_sources_count, + ), + ) + + +def process_search_parameters(raw_sp: SearchParameters) -> SearchParameters: + return SearchParameters( + search="" if raw_sp.search == "all" else raw_sp.search.replace("'", ""), + location="" if raw_sp.location == "all" else raw_sp.location.replace("'", ""), + ) + + +DataSourceMatches = namedtuple("DataSourceMatches", ["converted", "ids"]) + + +def process_data_source_matches(data_source_matches: List[dict]) -> DataSourceMatches: data_source_matches_converted = [] + data_source_matches_ids = [] for data_source_match in data_source_matches: data_source_match = convert_dates_to_strings(data_source_match) data_source_matches_converted.append(format_arrays(data_source_match)) + # Add ids to list for logging + data_source_matches_ids.append(data_source_match["airtable_uid"]) + return DataSourceMatches(data_source_matches_converted, data_source_matches_ids) - data_sources = { - "count": len(data_source_matches_converted), - "data": data_source_matches_converted, - } - - query_results = json.dumps(data_sources["data"]).replace("'", "") - cursor.execute( - INSERT_LOG_QUERY.format(search, location, query_results, data_sources["count"]), +def get_data_source_matches( + cursor: PgCursor, sp: SearchParameters +) -> List[Dict[str, Any]]: + unaltered_results = unaltered_search_query(cursor, sp.search, sp.location) + spacy_results = spacy_search_query(cursor, sp.search, sp.location) + # Compare altered search term results with unaltered search term results, return the 
longer list + results = ( + spacy_results + if len(spacy_results) > len(unaltered_results) + else unaltered_results ) - conn.commit() - cursor.close() - - return data_sources + data_source_matches = [ + dict(zip(QUICK_SEARCH_COLUMNS, result)) for result in results + ] + return data_source_matches -def quick_search_query_wrapper(arg1, arg2, conn: PgConnection) -> Response: +def quick_search_query_wrapper( + arg1, arg2, cursor: psycopg2.extensions.cursor +) -> Response: try: - data_sources = quick_search_query(search=arg1, location=arg2, conn=conn) + data_sources = quick_search_query( + SearchParameters(search=arg1, location=arg2), cursor=cursor + ) return make_response(data_sources, 200) except Exception as e: - conn.rollback() user_message = "There was an error during the search operation" message = { "content": user_message diff --git a/middleware/security.py b/middleware/security.py index 78ca2fac..a92939bc 100644 --- a/middleware/security.py +++ b/middleware/security.py @@ -77,7 +77,6 @@ def validate_role(role: str, endpoint: str, method: str): raise InvalidRoleError("You do not have permission to access this endpoint") - def get_role(api_key, cursor): cursor.execute(f"select id, api_key, role from users where api_key = '{api_key}'") user_results = cursor.fetchall() diff --git a/resources/QuickSearch.py b/resources/QuickSearch.py index 5b7c0f5a..972f214b 100644 --- a/resources/QuickSearch.py +++ b/resources/QuickSearch.py @@ -1,4 +1,4 @@ -from flask_restx import abort +from flask import Response from middleware.security import api_required from middleware.quick_search_query import quick_search_query_wrapper @@ -17,7 +17,7 @@ class QuickSearch(PsycopgResource): # api_required decorator requires the request"s header to include an "Authorization" key with the value formatted as "Bearer [api_key]" # A user can get an API key by signing up and logging in (see User.py) @api_required - def get(self, search: str, location: str) -> Dict[str, Any]: + def get(self, search: 
str, location: str) -> Response: """ Performs a quick search using the provided search terms and location. It attempts to find relevant data sources in the database. If no results are found initially, it re-initializes the database @@ -30,4 +30,6 @@ def get(self, search: str, location: str) -> Dict[str, Any]: Returns: - A dictionary containing a message about the search results and the data found, if any. """ - return quick_search_query_wrapper(search, location, self.psycopg2_connection) + return quick_search_query_wrapper( + search, location, self.psycopg2_connection.cursor() + ) diff --git a/resources/RefreshSession.py b/resources/RefreshSession.py index 95330e04..6c554038 100644 --- a/resources/RefreshSession.py +++ b/resources/RefreshSession.py @@ -3,7 +3,11 @@ from flask_restx import abort from middleware.custom_exceptions import TokenNotFoundError -from middleware.login_queries import get_session_token_user_data, create_session_token, delete_session_token +from middleware.login_queries import ( + get_session_token_user_data, + create_session_token, + delete_session_token, +) from typing import Dict, Any from resources.PsycopgResource import PsycopgResource, handle_exceptions diff --git a/resources/SearchTokens.py b/resources/SearchTokens.py index 43326ef5..738c9961 100644 --- a/resources/SearchTokens.py +++ b/resources/SearchTokens.py @@ -56,7 +56,9 @@ def get(self) -> Dict[str, Any]: def perform_endpoint_logic(self, arg1, arg2, endpoint): if endpoint == "quick-search": - return quick_search_query_wrapper(arg1, arg2, self.psycopg2_connection) + return quick_search_query_wrapper( + arg1, arg2, self.psycopg2_connection.cursor() + ) if endpoint == "data-sources": return get_approved_data_sources_wrapper(self.psycopg2_connection) if endpoint == "data-sources-by-id": diff --git a/tests/fixtures.py b/tests/fixtures.py index 77e2c157..5846eecb 100644 --- a/tests/fixtures.py +++ b/tests/fixtures.py @@ -81,7 +81,10 @@ def connection_with_test_data( 
dev_db_connection.rollback() return dev_db_connection + ClientWithMockDB = namedtuple("ClientWithMockDB", ["client", "mock_db"]) + + @pytest.fixture def client_with_mock_db(mocker) -> ClientWithMockDB: """ @@ -94,6 +97,7 @@ def client_with_mock_db(mocker) -> ClientWithMockDB: with app.test_client() as client: yield ClientWithMockDB(client, mock_db) + @pytest.fixture def client_with_db(dev_db_connection: psycopg2.extensions.connection): """ @@ -103,4 +107,4 @@ def client_with_db(dev_db_connection: psycopg2.extensions.connection): """ app = create_app(dev_db_connection) with app.test_client() as client: - yield client \ No newline at end of file + yield client diff --git a/tests/helper_functions.py b/tests/helper_functions.py index cc1cd281..9dfb28ae 100644 --- a/tests/helper_functions.py +++ b/tests/helper_functions.py @@ -133,7 +133,7 @@ def create_test_user( QuickSearchQueryLogResult = namedtuple( - "QuickSearchQueryLogResult", ["result_count", "updated_at"] + "QuickSearchQueryLogResult", ["result_count", "updated_at", "results"] ) @@ -151,7 +151,7 @@ def get_most_recent_quick_search_query_log( """ cursor.execute( """ - SELECT RESULT_COUNT, CREATED_AT FROM QUICK_SEARCH_QUERY_LOGS WHERE + SELECT RESULT_COUNT, CREATED_AT, RESULTS FROM QUICK_SEARCH_QUERY_LOGS WHERE search = %s AND location = %s ORDER BY CREATED_AT DESC LIMIT 1 """, (search, location), @@ -159,7 +159,9 @@ def get_most_recent_quick_search_query_log( result = cursor.fetchone() if result is None: return result - return QuickSearchQueryLogResult(result_count=result[0], updated_at=result[1]) + return QuickSearchQueryLogResult( + result_count=result[0], updated_at=result[1], results=result[2] + ) def has_expected_keys(result_keys: list, expected_keys: list) -> bool: @@ -271,11 +273,10 @@ def create_api_key(client_with_db, user_info): api_key = response.json.get("api_key") return api_key + def create_api_key_db(cursor, user_id: str): api_key = uuid.uuid4().hex - cursor.execute( - "UPDATE users SET api_key = 
%s WHERE id = %s", (api_key, user_id) - ) + cursor.execute("UPDATE users SET api_key = %s WHERE id = %s", (api_key, user_id)) return api_key @@ -327,5 +328,8 @@ def give_user_admin_role( (user_info.email,), ) + def check_response_status(response, status_code): - assert response.status_code == status_code, f"Expected status code {status_code}, got {response.status_code}: {response.text}" \ No newline at end of file + assert ( + response.status_code == status_code + ), f"Expected status code {status_code}, got {response.status_code}: {response.text}" diff --git a/tests/integration/test_agencies.py b/tests/integration/test_agencies.py index 01fb5797..09fbf3b7 100644 --- a/tests/integration/test_agencies.py +++ b/tests/integration/test_agencies.py @@ -1,4 +1,5 @@ """Integration tests for /agencies endpoint""" + import psycopg2 import pytest from tests.fixtures import connection_with_test_data, dev_db_connection, client_with_db @@ -20,4 +21,3 @@ def test_agencies_get( ) assert response.status_code == 200 assert len(response.json["data"]) > 0 - diff --git a/tests/integration/test_search_tokens.py b/tests/integration/test_search_tokens.py index e151b4ff..17952359 100644 --- a/tests/integration/test_search_tokens.py +++ b/tests/integration/test_search_tokens.py @@ -25,9 +25,7 @@ def test_search_tokens_get( ) check_response_status(response, 200) data = response.json.get("data") - assert ( - len(data) == 1 - ), "Quick Search endpoint response should return only one entry" + assert len(data) == 1, "Quick Search endpoint response should return only one entry" entry = data[0] assert entry["agency_name"] == "Agency A" assert entry["airtable_uid"] == "SOURCE_UID_1" diff --git a/tests/middleware/test_data_source_queries.py b/tests/middleware/test_data_source_queries.py index 201ca518..33f0ca50 100644 --- a/tests/middleware/test_data_source_queries.py +++ b/tests/middleware/test_data_source_queries.py @@ -172,7 +172,9 @@ def mock_data_source_by_id_query(monkeypatch): return mock 
-def test_data_source_by_id_wrapper_data_found(mock_data_source_by_id_query, mock_make_response): +def test_data_source_by_id_wrapper_data_found( + mock_data_source_by_id_query, mock_make_response +): mock_data_source_by_id_query.return_value = {"agency_name": "Agency A"} mock_conn = MagicMock() data_source_by_id_wrapper(arg="SOURCE_UID_1", conn=mock_conn) @@ -181,11 +183,14 @@ def test_data_source_by_id_wrapper_data_found(mock_data_source_by_id_query, mock ) mock_make_response.assert_called_with({"agency_name": "Agency A"}, 200) -def test_data_source_by_id_wrapper_data_not_found(mock_data_source_by_id_query, mock_make_response): + +def test_data_source_by_id_wrapper_data_not_found( + mock_data_source_by_id_query, mock_make_response +): mock_data_source_by_id_query.return_value = None mock_conn = MagicMock() data_source_by_id_wrapper(arg="SOURCE_UID_1", conn=mock_conn) mock_data_source_by_id_query.assert_called_with( data_source_id="SOURCE_UID_1", conn=mock_conn ) - mock_make_response.assert_called_with({"message": "Data source not found."}, 200) \ No newline at end of file + mock_make_response.assert_called_with({"message": "Data source not found."}, 200) diff --git a/tests/middleware/test_quick_search_query.py b/tests/middleware/test_quick_search_query.py index 71ead947..b46f1c0b 100644 --- a/tests/middleware/test_quick_search_query.py +++ b/tests/middleware/test_quick_search_query.py @@ -1,5 +1,7 @@ import json from unittest.mock import MagicMock +from datetime import datetime +from unittest.mock import patch import psycopg2 import pytest @@ -9,6 +11,9 @@ quick_search_query, QUICK_SEARCH_COLUMNS, quick_search_query_wrapper, + process_data_source_matches, + SearchParameters, + depluralize, ) from tests.helper_functions import ( has_expected_keys, @@ -48,14 +53,14 @@ def test_quick_search_query_logging( result = cursor.fetchone() test_datetime = result[0] - quick_search_query( - search="Source 1", location="City A", conn=connection_with_test_data - ) - - cursor = 
connection_with_test_data.cursor() - # Test that query inserted into log - result = get_most_recent_quick_search_query_log(cursor, "Source 1", "City A") + quick_search_query( + SearchParameters(search="Source 1", location="City A"), cursor=cursor + ) + # Test that query inserted into log + result = get_most_recent_quick_search_query_log(cursor, "Source 1", "City A") assert result.result_count == 1 + assert len(result.results) == 1 + assert result.results[0] == "SOURCE_UID_1" assert result.updated_at >= test_datetime @@ -68,19 +73,19 @@ def test_quick_search_query_results( :param connection_with_test_data: The connection to the test data database. :return: None """ - # TODO: Something about the quick_search_query might be mucking up the savepoints. Address once you fix quick_search's logic issues - results = quick_search_query( - search="Source 1", location="City A", conn=connection_with_test_data - ) - # Test that results include expected keys - assert has_expected_keys(results["data"][0].keys(), QUICK_SEARCH_COLUMNS) - assert len(results["data"]) == 1 - assert results["data"][0]["record_type"] == "Type A" - # "Source 3" was listed as pending and shouldn't show up - results = quick_search_query( - search="Source 3", location="City C", conn=connection_with_test_data - ) - assert len(results["data"]) == 0 + with connection_with_test_data.cursor() as cursor: + results = quick_search_query( + SearchParameters(search="Source 1", location="City A"), cursor=cursor + ) + # Test that results include expected keys + assert has_expected_keys(results["data"][0].keys(), QUICK_SEARCH_COLUMNS) + assert len(results["data"]) == 1 + assert results["data"][0]["record_type"] == "Type A" + # "Source 3" was listed as pending and shouldn't show up + results = quick_search_query( + SearchParameters(search="Source 3", location="City C"), cursor=cursor + ) + assert len(results["data"]) == 0 def test_quick_search_query_no_results( @@ -93,9 +98,11 @@ def test_quick_search_query_no_results( 
:return: None """ results = quick_search_query( - search="Nonexistent Source", - location="Nonexistent Location", - conn=connection_with_test_data, + SearchParameters( + search="Nonexistent Source", + location="Nonexistent Location", + ), + cursor=connection_with_test_data.cursor(), ) assert len(results["data"]) == 0 @@ -125,10 +132,10 @@ def test_quick_search_query_wrapper_happy_path( mock_quick_search_query, mock_make_response ): mock_quick_search_query.return_value = [{"record_type": "Type A"}] - mock_conn = MagicMock() - quick_search_query_wrapper(arg1="Source 1", arg2="City A", conn=mock_conn) + mock_cursor = MagicMock() + quick_search_query_wrapper(arg1="Source 1", arg2="City A", cursor=mock_cursor) mock_quick_search_query.assert_called_with( - search="Source 1", location="City A", conn=mock_conn + SearchParameters(search="Source 1", location="City A"), cursor=mock_cursor ) mock_make_response.assert_called_with([{"record_type": "Type A"}], 200) @@ -139,16 +146,83 @@ def test_quick_search_query_wrapper_exception( mock_quick_search_query.side_effect = Exception("Test Exception") arg1 = "Source 1" arg2 = "City A" - mock_conn = MagicMock() - quick_search_query_wrapper(arg1=arg1, arg2=arg2, conn=mock_conn) + mock_cursor = MagicMock() + quick_search_query_wrapper(arg1=arg1, arg2=arg2, cursor=mock_cursor) mock_quick_search_query.assert_called_with( - search=arg1, location=arg2, conn=mock_conn + SearchParameters(search=arg1, location=arg2), cursor=mock_cursor ) - mock_conn.rollback.assert_called_once() user_message = "There was an error during the search operation" mock_post_to_webhook.assert_called_with( - json.dumps({'content': 'There was an error during the search operation: Test Exception\nSearch term: Source 1\nLocation: City A'}) - ) - mock_make_response.assert_called_with( - {"count": 0, "message": user_message}, 500 + json.dumps( + { + "content": "There was an error during the search operation: Test Exception\nSearch term: Source 1\nLocation: City A" + } + ) 
) + mock_make_response.assert_called_with({"count": 0, "message": user_message}, 500) + + +# Test cases +@pytest.fixture +def sample_data_source_matches(): + return [ + { + "airtable_uid": "id1", + "field1": "value1", + "field_datetime": datetime(2020, 1, 1), + "field_array": '["abc","def"]', + }, + { + "airtable_uid": "id2", + "field2": "value2", + "field_datetime": datetime(2021, 2, 2), + "field_array": '["hello, world"]', + }, + ] + + +def test_process_data_source_matches(sample_data_source_matches): + expected_converted = [ + { + "airtable_uid": "id1", + "field1": "value1", + "field_datetime": "2020-01-01", + "field_array": ["abc", "def"], + }, + { + "airtable_uid": "id2", + "field2": "value2", + "field_datetime": "2021-02-02", + "field_array": ["hello, world"], + }, + ] + expected_ids = ["id1", "id2"] + + result = process_data_source_matches(sample_data_source_matches) + + assert result.converted == expected_converted + assert result.ids == expected_ids + + +def test_depluralize_with_plural_words(): + term = "apples oranges boxes" + expected = "apple orange box" + assert depluralize(term) == expected + + +def test_depluralize_with_singular_words(): + term = "apple orange box" + expected = "apple orange box" + assert depluralize(term) == expected + + +def test_depluralize_with_mixed_words(): + term = "apples orange box" + expected = "apple orange box" + assert depluralize(term) == expected + + +def test_depluralize_with_empty_string(): + term = "" + expected = "" + assert depluralize(term) == expected diff --git a/tests/middleware/test_security.py b/tests/middleware/test_security.py index 16b0c86f..70d0069f 100644 --- a/tests/middleware/test_security.py +++ b/tests/middleware/test_security.py @@ -134,6 +134,7 @@ def test_admin_only_action_with_admin_role(dev_db_connection): result = validate_api_key(api_key, "datasources", "PUT") assert result is None + @pytest.fixture def app() -> Flask: app = Flask(__name__) @@ -180,16 +181,23 @@ def 
test_api_required_happy_path( def test_api_required_api_key_expired( app, client, mock_request_headers, mock_validate_api_key, dummy_route: Callable ): - mock_validate_api_key.side_effect = ExpiredAPIKeyError("The provided API key has expired") + mock_validate_api_key.side_effect = ExpiredAPIKeyError( + "The provided API key has expired" + ) with app.test_request_context(headers={"Authorization": "Bearer valid_api_key"}): response = dummy_route() - assert response == ({"message": "The provided API key has expired"}, HTTPStatus.UNAUTHORIZED.value) + assert response == ( + {"message": "The provided API key has expired"}, + HTTPStatus.UNAUTHORIZED.value, + ) def test_api_required_expired_api_key( app, client, mock_request_headers, mock_validate_api_key, dummy_route: Callable ): - mock_validate_api_key.side_effect = ExpiredAPIKeyError("The provided API key has expired") + mock_validate_api_key.side_effect = ExpiredAPIKeyError( + "The provided API key has expired" + ) with app.test_request_context(headers={"Authorization": "Bearer expired_api_key"}): response = dummy_route() assert response == ( diff --git a/tests/resources/__init__.py b/tests/resources/__init__.py index b8dfa95d..be374293 100644 --- a/tests/resources/__init__.py +++ b/tests/resources/__init__.py @@ -1,4 +1,5 @@ # The below line is required to bypass the api_required decorator, # and must be positioned prior to other imports in order to work. from unittest.mock import patch, MagicMock -patch("middleware.security.api_required", lambda x: x).start() \ No newline at end of file + +patch("middleware.security.api_required", lambda x: x).start() diff --git a/tests/resources/test_DataSources.py b/tests/resources/test_DataSources.py index bfdc1628..a14207e4 100644 --- a/tests/resources/test_DataSources.py +++ b/tests/resources/test_DataSources.py @@ -1,12 +1,12 @@ # The below line is required to bypass the api_required decorator, # and must be positioned prior to other imports in order to work. 
from unittest.mock import patch, MagicMock + patch("middleware.security.api_required", lambda x: x).start() from tests.fixtures import client_with_mock_db -def test_put_data_source_by_id( - client_with_mock_db, monkeypatch -): + +def test_put_data_source_by_id(client_with_mock_db, monkeypatch): monkeypatch.setattr("resources.DataSources.request", MagicMock()) # mock_request.get_json.return_value = {"name": "Updated Data Source"} diff --git a/tests/resources/test_RefreshSession.py b/tests/resources/test_RefreshSession.py index 1346fa78..30eb7a57 100644 --- a/tests/resources/test_RefreshSession.py +++ b/tests/resources/test_RefreshSession.py @@ -114,7 +114,9 @@ def test_post_refresh_session_unexpected_error( :param client_with_mock_db: :return: """ - mock_get_session_token_user_data.side_effect = Exception("An unexpected error occurred") + mock_get_session_token_user_data.side_effect = Exception( + "An unexpected error occurred" + ) response = client_with_mock_db.client.post( "/refresh-session", json={ diff --git a/tests/utilities/test_managed_cursor.py b/tests/utilities/test_managed_cursor.py new file mode 100644 index 00000000..3e8302c7 --- /dev/null +++ b/tests/utilities/test_managed_cursor.py @@ -0,0 +1,56 @@ +from utilities.managed_cursor import managed_cursor +import uuid +from tests.fixtures import dev_db_connection + +SQL_TEST = """ + INSERT INTO test_table (name) VALUES (%s) +""" + + +class TestException(Exception): + pass + + +def test_managed_cursor_rollback(dev_db_connection): + """ + When an exception occurs, + the managed_cursor will rollback any changes made + and close the cursor + """ + name = str(uuid.uuid4()) + try: + with managed_cursor(dev_db_connection) as cursor: + cursor.execute(SQL_TEST, (name,)) + raise TestException + except TestException: + pass + assert cursor.closed == 1, "Cursor should be closed after exiting context manager" + cursor = dev_db_connection.cursor() + cursor.execute("SELECT * FROM test_table WHERE name = %s", (name,)) + 
result = cursor.fetchall()
+    cursor.close()
+    assert len(result) == 0, (
+        "Any transactions should be rolled back when an "
+        "exception is raised in the context of the "
+        "managed cursor"
+    )
+
+
+def test_managed_cursors_happy_path(dev_db_connection):
+    """
+    When no exception occurs,
+    the changes will be committed and the cursor will be closed
+    """
+    name = str(uuid.uuid4())
+    with managed_cursor(dev_db_connection) as cursor:
+        cursor.execute(SQL_TEST, (name,))
+    assert cursor.closed == 1, "Cursor should be closed after exiting context manager"
+    cursor = dev_db_connection.cursor()
+    cursor.execute("SELECT * FROM test_table WHERE name = %s", (name,))
+    result = cursor.fetchall()
+    cursor.close()
+    assert len(result) == 1, (
+        "Any transactions should persist in the absence of "
+        "an exception raised in the context of the "
+        "managed cursor"
+    )
diff --git a/utilities/managed_cursor.py b/utilities/managed_cursor.py
new file mode 100644
index 00000000..d15736c7
--- /dev/null
+++ b/utilities/managed_cursor.py
@@ -0,0 +1,27 @@
+from contextlib import contextmanager
+from typing import Iterator
+
+import psycopg2
+
+
+@contextmanager
+def managed_cursor(
+    connection: psycopg2.extensions.connection,
+) -> Iterator[psycopg2.extensions.cursor]:
+    """
+    Manage a cursor for a given database connection.
+
+    :param connection: The psycopg2 database connection.
+    :return: Iterator that yields the cursor
+    and automatically commits changes on successful completion,
+    or rolls back changes and closes the cursor on failure.
+    """
+    cursor = connection.cursor()
+    try:
+        yield cursor
+        connection.commit()
+    except Exception as e:
+        connection.rollback()
+        raise e
+    finally:
+        cursor.close()