diff --git a/src/main/java/org/highmed/numportal/service/AqlService.java b/src/main/java/org/highmed/numportal/service/AqlService.java index 2b2f0512..fdbca8ac 100644 --- a/src/main/java/org/highmed/numportal/service/AqlService.java +++ b/src/main/java/org/highmed/numportal/service/AqlService.java @@ -250,19 +250,19 @@ public long getAqlSize(SlimAqlDto aql, String userId) { validateQuery(aql.getQuery()); - Set ehrIds; + int numberOfPatients; try { - ehrIds = - ehrBaseService.retrieveEligiblePatientIds(Aql.builder().query(aql.getQuery()).build()); + numberOfPatients = + ehrBaseService.retrieveNumberOfPatients(Aql.builder().query(aql.getQuery()).build()); } catch (AqlParseException e) { throw new BadRequestException(AqlParseException.class, e.getLocalizedMessage(), e.getMessage()); } - if (ehrIds.size() < privacyProperties.getMinHits()) { + if (numberOfPatients < privacyProperties.getMinHits()) { log.warn(TOO_FEW_MATCHES_RESULTS_WITHHELD_FOR_PRIVACY_REASONS); throw new PrivacyException(AqlService.class, TOO_FEW_MATCHES_RESULTS_WITHHELD_FOR_PRIVACY_REASONS); } - return ehrIds.size(); + return numberOfPatients; } public List getAqlCategories() { diff --git a/src/main/java/org/highmed/numportal/service/ehrbase/EhrBaseService.java b/src/main/java/org/highmed/numportal/service/ehrbase/EhrBaseService.java index 8e8f79c2..1db0a363 100644 --- a/src/main/java/org/highmed/numportal/service/ehrbase/EhrBaseService.java +++ b/src/main/java/org/highmed/numportal/service/ehrbase/EhrBaseService.java @@ -11,6 +11,7 @@ import org.apache.commons.lang3.StringUtils; import org.ehrbase.openehr.sdk.aql.dto.AqlQuery; import org.ehrbase.openehr.sdk.aql.dto.containment.ContainmentClassExpression; +import org.ehrbase.openehr.sdk.aql.dto.operand.CountDistinctAggregateFunction; import org.ehrbase.openehr.sdk.aql.dto.operand.IdentifiedPath; import org.ehrbase.openehr.sdk.aql.dto.path.AqlObjectPath; import org.ehrbase.openehr.sdk.aql.dto.select.SelectExpression; @@ -118,6 +119,51 @@ public Set retrieveEligiblePatientIds(String query) { } } + /** + * Retrieves the number of patients for the given aql + * + * @param aql The aql to retrieve patient ids for + * @return number of patients + * @throws WrongStatusCodeException in case if a malformed aql + */ + public int retrieveNumberOfPatients(Aql aql) { + return retrieveNumberOfPatients(aql.getQuery()); + } + + public int retrieveNumberOfPatients(String query) { + log.debug("EhrBase retrieve number of patients for query: {} ", query); + AqlQuery dto = AqlQueryParser.parse(query); + SelectExpression selectExpression = new SelectExpression(); + + var count = new CountDistinctAggregateFunction(); + selectExpression.setColumnExpression(count); + + IdentifiedPath ehrIdPath = new IdentifiedPath(); + ehrIdPath.setPath(AqlObjectPath.parse(AqlQueryConstants.EHR_ID_PATH)); + + ContainmentClassExpression containmentClassExpression = new ContainmentClassExpression(); + containmentClassExpression.setType(AqlQueryConstants.EHR_TYPE); + containmentClassExpression.setIdentifier(AqlQueryConstants.EHR_CONTAINMENT_IDENTIFIER); + ehrIdPath.setRoot(containmentClassExpression); + + count.setIdentifiedPath(ehrIdPath); + + dto.getSelect().setStatement(List.of(selectExpression)); + log.info("Generated query for retrieveNumberOfPatients {} ", AqlRenderer.render(dto)); + + try { + List> results = restClient.aqlEndpoint().execute(Query.buildNativeQuery(AqlRenderer.render(dto), Integer.class)); + return results.get(0).value1(); + } catch (WrongStatusCodeException e) { + log.error(INVALID_AQL_QUERY, e.getMessage(), e); + throw new WrongStatusCodeException("EhrBaseService.class", 93, 1); + } catch (ClientException e) { + log.error(ERROR_MESSAGE, e.getMessage(), e); + throw new SystemException(EhrBaseService.class, AN_ERROR_HAS_OCCURRED_CANNOT_EXECUTE_AQL, + String.format(AN_ERROR_HAS_OCCURRED_CANNOT_EXECUTE_AQL, e.getMessage())); + } + } + /** * Executes a raw aql query * diff --git a/src/test/java/org/highmed/numportal/service/AqlServiceTest.java b/src/test/java/org/highmed/numportal/service/AqlServiceTest.java index f02d609f..3f68df87 100644 --- a/src/test/java/org/highmed/numportal/service/AqlServiceTest.java +++ b/src/test/java/org/highmed/numportal/service/AqlServiceTest.java @@ -171,7 +171,7 @@ public void getAqlSizeTest() { SlimAqlDto aqlDto = SlimAqlDto.builder() .query("select * from dummy_table") .build(); - Mockito.when(ehrBaseService.retrieveEligiblePatientIds(Mockito.any(Aql.class))).thenReturn(new HashSet<>(Arrays.asList("id1", "id2", "id3", "id4"))); + Mockito.when(ehrBaseService.retrieveNumberOfPatients(Mockito.any(Aql.class))).thenReturn(4); aqlService.getAqlSize(aqlDto, "4"); } @@ -180,7 +180,7 @@ public void shouldHandlePrivacyExceptionWhenGetAqlSize() { SlimAqlDto aqlDto = SlimAqlDto.builder() .query("select * from dummy_table") .build(); - Mockito.when(ehrBaseService.retrieveEligiblePatientIds(Mockito.any(Aql.class))).thenReturn(new HashSet<>(Arrays.asList("id1"))); + Mockito.when(ehrBaseService.retrieveNumberOfPatients(Mockito.any(Aql.class))).thenReturn(1); aqlService.getAqlSize(aqlDto, "4"); }