Skip to content

Commit

Permalink
Replace state/county/location search parameters with location_id as s…
Browse files Browse the repository at this point in the history
…earch parameter
  • Loading branch information
maxachis committed Dec 13, 2024
1 parent 7c6cdbe commit f420040
Show file tree
Hide file tree
Showing 11 changed files with 129 additions and 233 deletions.
8 changes: 2 additions & 6 deletions database_client/database_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,10 +524,8 @@ def get_typeahead_agencies(self, search_term: str) -> dict:
@cursor_manager()
def search_with_location_and_record_type(
self,
state: str,
location_id: int,
record_categories: Optional[list[RecordCategories]] = None,
county: Optional[str] = None,
locality: Optional[str] = None,
) -> List[dict]:
"""
Searches for data sources in the database.
Expand All @@ -540,10 +538,8 @@ def search_with_location_and_record_type(
"""
optional_kwargs = {}
query = DynamicQueryConstructor.create_search_query(
state=state,
location_id=location_id,
record_categories=record_categories,
county=county,
locality=locality,
)
self.cursor.execute(query)
return self.cursor.fetchall()
Expand Down
26 changes: 6 additions & 20 deletions database_client/dynamic_query_constructor.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,15 +223,13 @@ def generate_new_typeahead_agencies_query(search_term: str):

@staticmethod
def create_search_query(
state: str,
location_id: int,
record_categories: Optional[list[RecordCategories]] = None,
county: Optional[str] = None,
locality: Optional[str] = None,
) -> sql.Composed:

base_query = sql.SQL(
"""
SELECT
SELECT DISTINCT
data_sources.id,
data_sources.name AS data_source_name,
data_sources.description,
Expand All @@ -255,14 +253,16 @@ def create_search_query(
locations_expanded on agencies.location_id = locations_expanded.id
INNER JOIN
record_types on record_types.id = data_sources.record_type_id
LEFT JOIN
DEPENDENT_LOCATIONS DL ON DL.DEPENDENT_LOCATION_ID = LOCATIONS_EXPANDED.ID
"""
)

join_conditions = []
where_subclauses = [
sql.SQL(
"LOWER(locations_expanded.state_name) = LOWER({state_name})"
).format(state_name=sql.Literal(state)),
"(locations_expanded.id = {location_id} OR DL.PARENT_LOCATION_ID = {location_id}) "
).format(location_id=sql.Literal(location_id)),
sql.SQL("data_sources.approval_status = 'approved'"),
sql.SQL("data_sources.url_status NOT IN ('broken', 'none found')"),
]
Expand All @@ -286,20 +286,6 @@ def create_search_query(
)
)

if county is not None:
where_subclauses.append(
sql.SQL(
"LOWER(locations_expanded.county_name) = LOWER({county_name})"
).format(county_name=sql.Literal(county))
)

if locality is not None:
where_subclauses.append(
sql.SQL(
"LOWER(locations_expanded.locality_name) = LOWER({locality})"
).format(locality=sql.Literal(locality))
)

query = sql.Composed(
[
base_query,
Expand Down
21 changes: 16 additions & 5 deletions middleware/dynamic_request_logic/post_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,22 @@ def call_database_client_method(self):
self.id_val = self.mp.db_client_method(
self.mp.db_client, column_value_mappings=self.entry
)
except sqlalchemy.exc.IntegrityError:
FlaskResponseManager.abort(
code=HTTPStatus.CONFLICT,
message=f"{self.mp.entry_name} already exists.",
)
except sqlalchemy.exc.IntegrityError as e:
if e.orig.sqlstate == "23505":
FlaskResponseManager.abort(
code=HTTPStatus.CONFLICT,
message=f"{self.mp.entry_name} already exists.",
)
elif e.orig.sqlstate == "23503":
FlaskResponseManager.abort(
code=HTTPStatus.BAD_REQUEST,
message=f"{self.mp.entry_name} not found.",
)
else:
FlaskResponseManager.abort(
code=HTTPStatus.INTERNAL_SERVER_ERROR,
message=f"Error creating {self.mp.entry_name}.",
)

def make_response(self) -> Response:
return created_id_response(
Expand Down
36 changes: 13 additions & 23 deletions middleware/primary_resource_logic/search_logic.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from middleware.enums import JurisdictionSimplified, Relations, OutputFormatEnum
from middleware.flask_response_manager import FlaskResponseManager
from middleware.schema_and_dto_logic.primary_resource_schemas.search_schemas import (
SearchRequests,
SearchRequestsDTO,
)
from middleware.common_response_formatting import message_response
from middleware.util import get_datetime_now, write_to_csv, find_root_directory
Expand Down Expand Up @@ -100,16 +100,14 @@ def format_as_csv(ld: list[dict]) -> BytesIO:
def search_wrapper(
db_client: DatabaseClient,
access_info: AccessInfoPrimary,
dto: SearchRequests,
dto: SearchRequestsDTO,
) -> Response:
create_search_record(access_info, db_client, dto)
explicit_record_categories = get_explicit_record_categories(dto.record_categories)
search_results = db_client.search_with_location_and_record_type(
location_id=dto.location_id,
record_categories=explicit_record_categories,
state=dto.state,
# Pass modified record categories, which breaks down ALL into individual categories
county=dto.county,
locality=dto.locality,
)
return send_search_results(
search_results=search_results,
Expand All @@ -118,13 +116,9 @@ def search_wrapper(


def create_search_record(access_info, db_client, dto):
location_id = try_getting_location_id_and_raise_error_if_not_found(
db_client=db_client,
dto=dto,
)
db_client.create_search_record(
user_id=access_info.get_user_id(),
location_id=location_id,
location_id=dto.location_id,
# Pass originally provided record categories
record_categories=dto.record_categories,
)
Expand Down Expand Up @@ -170,7 +164,7 @@ def get_explicit_record_categories(

def try_getting_location_id_and_raise_error_if_not_found(
db_client: DatabaseClient,
dto: SearchRequests,
dto: SearchRequestsDTO,
) -> int:
where_mappings = WhereMapping.from_dict(
{
Expand Down Expand Up @@ -228,7 +222,7 @@ def make_response(self) -> Response:


def get_link_id_and_raise_error_if_not_found(
db_client: DatabaseClient, access_info: AccessInfoPrimary, dto: SearchRequests
db_client: DatabaseClient, access_info: AccessInfoPrimary, dto: SearchRequestsDTO
):
location_id = try_getting_location_id_and_raise_error_if_not_found(
db_client=db_client,
Expand All @@ -244,18 +238,14 @@ def get_link_id_and_raise_error_if_not_found(
def get_location_link_and_raise_error_if_not_found(
db_client: DatabaseClient,
access_info: AccessInfoPrimary,
dto: SearchRequests,
dto: SearchRequestsDTO,
):
location_id = try_getting_location_id_and_raise_error_if_not_found(
db_client=db_client,
dto=dto,
)
link_id = get_user_followed_search_link(
db_client=db_client,
access_info=access_info,
location_id=location_id,
location_id=dto.location_id,
)
return LocationLink(link_id=link_id, location_id=location_id)
return LocationLink(link_id=link_id, location_id=dto.location_id)


class LocationLink(BaseModel):
Expand All @@ -266,7 +256,7 @@ class LocationLink(BaseModel):
def create_followed_search(
db_client: DatabaseClient,
access_info: AccessInfoPrimary,
dto: SearchRequests,
dto: SearchRequestsDTO,
) -> Response:
# Get location id. If not found, not a valid location. Raise error
location_link = get_location_link_and_raise_error_if_not_found(
Expand All @@ -279,7 +269,7 @@ def create_followed_search(

return post_entry(
middleware_parameters=MiddlewareParameters(
entry_name="followed search",
entry_name="Location for followed search",
relation=Relations.LINK_USER_FOLLOWED_LOCATION.value,
db_client_method=DatabaseClient.create_followed_search,
access_info=access_info,
Expand All @@ -296,7 +286,7 @@ def create_followed_search(
def delete_followed_search(
db_client: DatabaseClient,
access_info: AccessInfoPrimary,
dto: SearchRequests,
dto: SearchRequestsDTO,
) -> Response:
# Get location id. If not found, not a valid location. Raise error
location_link = get_location_link_and_raise_error_if_not_found(
Expand All @@ -310,7 +300,7 @@ def delete_followed_search(

return delete_entry(
middleware_parameters=MiddlewareParameters(
entry_name="Followed search",
entry_name="Location for followed search",
relation=Relations.LINK_USER_FOLLOWED_LOCATION.value,
db_client_method=DatabaseClient.delete_followed_search,
access_info=access_info,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from middleware.enums import OutputFormatEnum
from middleware.schema_and_dto_logic.schema_helpers import create_get_many_schema
from middleware.schema_and_dto_logic.util import get_json_metadata
from middleware.schema_and_dto_logic.util import get_json_metadata, get_query_metadata
from utilities.common import get_enums_from_string
from utilities.enums import RecordCategories, SourceMappingEnum, ParserLocation

Expand All @@ -18,13 +18,9 @@ def transform_record_categories(value: str) -> Optional[list[RecordCategories]]:


class SearchRequestSchema(Schema):
state = fields.Str(
required=True,
metadata={
"description": "The state of the search.",
"source": SourceMappingEnum.QUERY_ARGS,
"location": ParserLocation.QUERY.value,
},
location_id = fields.Int(
required=False,
metadata=get_query_metadata("The location ID of the search."),
)
record_categories = fields.Str(
required=False,
Expand All @@ -40,22 +36,6 @@ class SearchRequestSchema(Schema):
"location": ParserLocation.QUERY.value,
},
)
county = fields.Str(
required=False,
metadata={
"description": "The county of the search. If empty, all counties for the given state will be searched.",
"source": SourceMappingEnum.QUERY_ARGS,
"location": ParserLocation.QUERY.value,
},
)
locality = fields.Str(
required=False,
metadata={
"description": "The locality of the search. If empty, all localities for the given county will be searched.",
"source": SourceMappingEnum.QUERY_ARGS,
"location": ParserLocation.QUERY.value,
},
)
output_format = fields.Enum(
required=False,
enum=OutputFormatEnum,
Expand All @@ -68,12 +48,6 @@ class SearchRequestSchema(Schema):
},
)

@validates_schema
def validate_location_info(self, data, **kwargs):
if data.get("locality") and not data.get("county"):
raise ValidationError(
"If locality is provided, county must also be provided."
)


class SearchResultsInnerSchema(Schema):
Expand Down Expand Up @@ -210,9 +184,7 @@ class FollowSearchResponseSchema(Schema):
)


class SearchRequests(BaseModel):
state: str
class SearchRequestsDTO(BaseModel):

Check warning on line 187 in middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py

View workflow job for this annotation

GitHub Actions / flake8

[flake8] middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py#L187 <101>

Missing docstring in public class
Raw output
./middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py:187:1: D101 Missing docstring in public class
location_id: int
record_categories: Optional[list[RecordCategories]] = None
county: Optional[str] = None
locality: Optional[str] = None
output_format: Optional[OutputFormatEnum] = None
8 changes: 4 additions & 4 deletions resources/endpoint_schema_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@
from middleware.schema_and_dto_logic.primary_resource_schemas.search_schemas import (
SearchRequestSchema,
GetUserFollowedSearchesSchema,
SearchRequests,
SearchRequestsDTO,
SearchResponseSchema,
)
from middleware.schema_and_dto_logic.common_schemas_and_dtos import (
Expand Down Expand Up @@ -278,9 +278,9 @@ def schema_config_with_message_output(
)
SEARCH_FOLLOW_UPDATE = EndpointSchemaConfig(
input_schema=SearchRequestSchema(
exclude=["record_categories"],
exclude=["record_categories", "output_format"],
),
input_dto_class=SearchRequests,
input_dto_class=SearchRequestsDTO,
)


Expand Down Expand Up @@ -387,7 +387,7 @@ class SchemaConfigs(Enum):
SEARCH_LOCATION_AND_RECORD_TYPE_GET = EndpointSchemaConfig(
input_schema=SearchRequestSchema(),
primary_output_schema=SearchResponseSchema(),
input_dto_class=SearchRequests,
input_dto_class=SearchRequestsDTO,
)
SEARCH_FOLLOW_GET = EndpointSchemaConfig(
primary_output_schema=GetUserFollowedSearchesSchema(),
Expand Down
27 changes: 6 additions & 21 deletions tests/helper_scripts/helper_classes/RequestValidator.py
Original file line number Diff line number Diff line change
Expand Up @@ -268,18 +268,14 @@ def update_permissions(
def search(
self,
headers: dict,
state: str,
location_id: int,
record_categories: Optional[list[RecordCategories]] = None,
county: Optional[str] = None,
locality: Optional[str] = None,
format: Optional[OutputFormatEnum] = OutputFormatEnum.JSON,
):
endpoint_base = "/search/search-location-and-record-type"
query_params = self._get_search_query_params(
county=county,
locality=locality,
location_id=location_id,
record_categories=record_categories,
state=state,
)
query_params.update({} if format is None else {"output_format": format.value})
endpoint = add_query_params(
Expand All @@ -295,39 +291,28 @@ def search(
)

@staticmethod
def _get_search_query_params(county, locality, record_categories, state):
def _get_search_query_params(record_categories, location_id: int):
query_params = {
"state": state,
"location_id": location_id,
}
if record_categories is not None:
query_params["record_categories"] = ",".join(
[rc.value for rc in record_categories]
)
update_if_not_none(
dict_to_update=query_params,
secondary_dict={
"county": county,
"locality": locality,
},
)
return query_params

def follow_search(
self,
headers: dict,
state: str,
location_id: int,
record_categories: Optional[list[RecordCategories]] = None,
county: Optional[str] = None,
locality: Optional[str] = None,
expected_json_content: Optional[dict] = None,
expected_response_status: HTTPStatus = HTTPStatus.OK,
):
endpoint_base = "/api/search/follow"
query_params = self._get_search_query_params(
county=county,
locality=locality,
location_id=location_id,
record_categories=record_categories,
state=state,
)
endpoint = add_query_params(
url=endpoint_base,
Expand Down
Loading

0 comments on commit f420040

Please sign in to comment.