diff --git a/database_client/database_client.py b/database_client/database_client.py index 627890e2..131d4d52 100644 --- a/database_client/database_client.py +++ b/database_client/database_client.py @@ -934,12 +934,14 @@ def _select_single_entry_from_relation( columns: list[str], where_mappings: Optional[Union[list[WhereMapping], dict]] = [True], subquery_parameters: Optional[list[SubqueryParameters]] = [], + **kwargs, ) -> Any: results = self._select_from_relation( relation_name=relation_name, columns=columns, where_mappings=where_mappings, subquery_parameters=subquery_parameters, + **kwargs, ) if len(results) == 0: return None @@ -1151,9 +1153,6 @@ def get_linked_rows( build_metadata=False, subquery_parameters: Optional[list[SubqueryParameters]] = [], ): - LinkTable = SQL_ALCHEMY_TABLE_REFERENCE[link_table.value] - LinkedRelation = SQL_ALCHEMY_TABLE_REFERENCE[linked_relation.value] - # Get ids via linked table link_results = self._select_from_relation( relation_name=link_table.value, @@ -1214,6 +1213,7 @@ def _build_column_references( linked_relation_linking_column="id", columns_to_retrieve=["state_name", "county_name", "locality_name", "id"], build_metadata=True, + alias_mappings={"id": "location_id"}, ) DataRequestIssueInfo = namedtuple( @@ -1484,4 +1484,5 @@ def get_location_by_id(self, location_id: int): "id", ], where_mappings={"id": location_id}, + alias_mappings={"id": "location_id"}, ) diff --git a/database_client/dynamic_query_constructor.py b/database_client/dynamic_query_constructor.py index 06783692..2f3e35d5 100644 --- a/database_client/dynamic_query_constructor.py +++ b/database_client/dynamic_query_constructor.py @@ -5,7 +5,7 @@ from psycopg import sql from sqlalchemy import select -from sqlalchemy.orm import load_only +from sqlalchemy.orm import load_only, InstrumentedAttribute, aliased from sqlalchemy.schema import Column from database_client.constants import ( @@ -431,7 +431,9 @@ def create_selection_query( primary_relation_columns = columns if alias_mappings is not None: - DynamicQueryConstructor.apply_alias_mappings(columns, alias_mappings) + primary_relation_columns = DynamicQueryConstructor.apply_alias_mappings( + columns, alias_mappings + ) base_query = ( lambda: select(*primary_relation_columns) @@ -556,5 +558,17 @@ def get_distinct_source_urls_query(url: str) -> sql.Composed: return query @staticmethod - def apply_alias_mappings(columns, alias_mappings): - pass + def apply_alias_mappings( + columns: list[InstrumentedAttribute], alias_mappings: dict[str, str] + ): + aliased_columns = [] + + for column in columns: + # Alias column if it exists in the alias mappings + key = column.key + if key in alias_mappings: + aliased_columns.append(column.label(alias_mappings[key])) + else: + aliased_columns.append(column) + + return aliased_columns diff --git a/database_client/subquery_logic.py b/database_client/subquery_logic.py index f6bffcac..a5858e3d 100644 --- a/database_client/subquery_logic.py +++ b/database_client/subquery_logic.py @@ -46,10 +46,16 @@ class SubqueryParameterManager: @staticmethod def get_subquery_params( - relation: Relations, linking_column: str, columns: list[str] = None + relation: Relations, + linking_column: str, + columns: list[str] = None, + alias_mappings: Optional[dict[str, str]] = None, ) -> SubqueryParameters: return SubqueryParameters( - relation_name=relation.value, linking_column=linking_column, columns=columns + relation_name=relation.value, + linking_column=linking_column, + columns=columns, + alias_mappings=alias_mappings, ) agencies = partialmethod( @@ -98,4 +104,5 @@ def locations(): "locality_name", "display_name", ], + alias_mappings={"id": "location_id"}, ) diff --git a/middleware/dynamic_request_logic/get_related_resource_logic.py b/middleware/dynamic_request_logic/get_related_resource_logic.py index 5bfe1664..e6792cdc 100644 --- a/middleware/dynamic_request_logic/get_related_resource_logic.py +++ b/middleware/dynamic_request_logic/get_related_resource_logic.py @@ -31,6 +31,7 @@ class GetRelatedResourcesParameters: def get_related_resource( get_related_resources_parameters: GetRelatedResourcesParameters, permitted_columns: Optional[list] = None, + alias_mappings: Optional[dict] = None, ) -> Response: # Technically, it'd make more sense as "grrp", # but "gerp" rolls off the tongue better @@ -54,6 +55,7 @@ def get_related_resource( relation=gerp.related_relation, linking_column=gerp.linking_column, columns=permitted_columns, + alias_mappings=alias_mappings, ) ] where_mappings = [ diff --git a/middleware/primary_resource_logic/data_requests.py b/middleware/primary_resource_logic/data_requests.py index 06a2908d..2766d5ed 100644 --- a/middleware/primary_resource_logic/data_requests.py +++ b/middleware/primary_resource_logic/data_requests.py @@ -396,6 +396,7 @@ def get_data_request_related_locations( "locality_name", "type", ], + alias_mappings={"id": "location_id"}, ) diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/locations_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/locations_schemas.py index 70d27b0f..568bb3a4 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/locations_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/locations_schemas.py @@ -72,7 +72,7 @@ class LocationInfoResponseSchema(Schema): required=True, metadata=get_json_metadata(description="The display name for the location"), ) - id = LOCATION_ID_FIELD + location_id = LOCATION_ID_FIELD class LocationInfoSchema(Schema): @@ -80,7 +80,7 @@ class LocationInfoSchema(Schema): state_iso = STATE_ISO_FIELD county_fips = COUNTY_FIPS_FIELD locality_name = LOCALITY_NAME_FIELD - id = LOCATION_ID_FIELD + location_id = LOCATION_ID_FIELD @validates_schema def validate_location_fields(self, data, **kwargs): diff --git a/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py b/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py index 45997252..7d086a94 100644 --- a/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py +++ b/middleware/schema_and_dto_logic/primary_resource_schemas/search_schemas.py @@ -175,7 +175,7 @@ class FollowSearchResponseSchema(Schema): "The locality of the search. If empty, all localities for the given county will be searched." ), ) - id = fields.Int( + location_id = fields.Int( required=True, metadata=get_json_metadata("The location ID of the search."), ) diff --git a/tests/integration/test_data_requests.py b/tests/integration/test_data_requests.py index 2dfbcfbc..b6852a59 100644 --- a/tests/integration/test_data_requests.py +++ b/tests/integration/test_data_requests.py @@ -642,7 +642,7 @@ def post_location_association( data = get_locations() assert data == [ { - "id": location_id, + "location_id": location_id, "state_name": "Pennsylvania", "state_iso": "PA", "county_name": "Allegheny", diff --git a/tests/integration/test_locations.py b/tests/integration/test_locations.py index 9ba435cc..87af7c0c 100644 --- a/tests/integration/test_locations.py +++ b/tests/integration/test_locations.py @@ -28,7 +28,7 @@ def locations_test_setup(test_data_creator_flask: TestDataCreatorFlask): "county_name": "Allegheny", "county_fips": "42003", "locality_name": locality_name, - "id": location_id, + "location_id": location_id, } return LocationsTestSetup(tdc=tdc, location_info=loc_info) @@ -39,7 +39,7 @@ def test_locations_get_by_id(locations_test_setup: LocationsTestSetup): # Get location, confirm information matches data = tdc.request_validator.get_location_by_id( - location_id=lts.location_info["id"], + location_id=lts.location_info["location_id"], headers=tdc.get_admin_tus().jwt_authorization_header, expected_json_content=lts.location_info, ) @@ -48,15 +48,16 @@ def test_locations_get_by_id(locations_test_setup: LocationsTestSetup): def test_locations_related_data_requests(locations_test_setup: LocationsTestSetup): lts = locations_test_setup tdc = lts.tdc + location_id = lts.location_info["location_id"] # Add two data requests to location - dr_1 = tdc.data_request(location_ids=[lts.location_info["id"]]).id - dr_2 = tdc.data_request(location_ids=[lts.location_info["id"]]).id + dr_1 = tdc.data_request(location_ids=[location_id]).id + dr_2 = tdc.data_request(location_ids=[location_id]).id # Get data requests tus = tdc.standard_user() data = tdc.request_validator.get_location_related_data_requests( - location_id=lts.location_info["id"], + location_id=location_id, headers=tus.api_authorization_header, ) @@ -65,8 +66,8 @@ def test_locations_related_data_requests(locations_test_setup: LocationsTestSetu # Confirm also works with jwt data = tdc.request_validator.get_location_related_data_requests( - location_id=lts.location_info["id"], + location_id=location_id, headers=tus.jwt_authorization_header, )["data"] assert data[0]["locations"] == data[1]["locations"] - assert data[0]["locations"][0]["id"] == lts.location_info["id"] + assert data[0]["locations"][0]["location_id"] == location_id diff --git a/tests/integration/test_search.py b/tests/integration/test_search.py index 07134308..b91d9ad5 100644 --- a/tests/integration/test_search.py +++ b/tests/integration/test_search.py @@ -262,7 +262,7 @@ def follow_extant_location( "state_name": TEST_STATE, "county_name": TEST_COUNTY, "locality_name": TEST_LOCALITY, - "id": sts.location_id, + "location_id": sts.location_id, } ], "message": "Followed searches found.",