From 10aba66a152d77fcef911131d6f217c30209e83e Mon Sep 17 00:00:00 2001 From: Matthew Chatham <18629176+MatthewChatham@users.noreply.github.com> Date: Tue, 24 Jan 2023 18:21:48 -0800 Subject: [PATCH] add logic for max_studies > 100 --- pytrials/client.py | 48 +++++++++++++++++++++++++++++++++++--------- tests/test_client.py | 19 ++++++++++++++++-- 2 files changed, 56 insertions(+), 11 deletions(-) diff --git a/pytrials/client.py b/pytrials/client.py index 0ff21a2..2590357 100644 --- a/pytrials/client.py +++ b/pytrials/client.py @@ -41,8 +41,8 @@ def __api_info(self): return api_version, last_updated - def get_full_studies(self, search_expr, max_studies=50): - """Returns all content for a maximum of 100 study records. + def get_full_studies(self, search_expr, max_studies=None): + """Returns all content for a maximum of `max_studies` study records. Retrieves information from the full studies endpoint, which gets all study fields. This endpoint can only output JSON (Or not-supported XML) format and does not allow @@ -51,21 +51,51 @@ def get_full_studies(self, search_expr, max_studies=50): Args: search_expr (str): A string containing a search expression as specified by `their documentation `_. - max_studies (int): An integer indicating the maximum number of studies to return. - Defaults to 50. + max_studies (int; optional): An integer indicating the maximum number of studies to return. + Defaults to None, resulting in all studies being returned. Returns: - dict: Object containing the information queried with the search expression. + list: List of responses containing the information queried with the search expression. Raises: ValueError: The number of studies can only be between 1 and 100 """ - if max_studies > 100 or max_studies < 1: - raise ValueError("The number of studies can only be between 1 and 100") + if max_studies is not None and max_studies < 1: + raise ValueError("The number of studies must be at least 1") + + min_rnk = 1 + max_rnk = 100 if max_studies is None else min(100, max_studies) + req = "full_studies?expr={}&min_rnk={}&max_rnk={}&{}" + + full_studies = list() - req = f"full_studies?expr={search_expr}&max_rnk={max_studies}&{self._JSON}" + reqf = req.format(search_expr, min_rnk, max_rnk, self._JSON) + full_studies.append(json_handler(f"{self._BASE_URL}{self._QUERY}{reqf}")) - full_studies = json_handler(f"{self._BASE_URL}{self._QUERY}{req}") + if max_studies is None or max_studies > 100: + + n_studies_found = full_studies[0]["FullStudiesResponse"]["NStudiesFound"] + + min_rnk += 100 + max_rnk += 100 + + while_stop = ( + n_studies_found + if max_studies is None + else min(max_studies, n_studies_found) + ) + max_rnk = min(n_studies_found, max_rnk, while_stop) + + print(while_stop) + while min_rnk <= while_stop: + print(max_rnk) + reqf = req.format(search_expr, min_rnk, max_rnk, self._JSON) + full_studies.append( + json_handler(f"{self._BASE_URL}{self._QUERY}{reqf}") + ) + min_rnk += 100 + max_rnk += 100 + max_rnk = min(while_stop, max_rnk) return full_studies diff --git a/tests/test_client.py b/tests/test_client.py index 02700da..a525488 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -10,6 +10,10 @@ def test_full_studies(): fifty_studies = ct.get_full_studies(search_expr="Coronavirus+COVID", max_studies=50) + assert len(fifty_studies) == 1 + + fifty_studies = fifty_studies[0] + assert [*fifty_studies["FullStudiesResponse"].keys()] == [ "APIVrs", "DataVrs", @@ -30,6 +34,10 @@ def test_full_studies_max(): search_expr="Coronavirus+COVID", max_studies=100 ) + assert len(hundred_studies) == 1 + + hundred_studies = hundred_studies[0] + assert [*hundred_studies["FullStudiesResponse"].keys()] == [ "APIVrs", "DataVrs", @@ -51,8 +59,15 @@ def test_full_studies_below(): def test_full_studies_above(): - with raises(ValueError): - ct.get_full_studies(search_expr="Coronavirus+COVID", max_studies=150) + hundred_fifty_studies = ct.get_full_studies( + search_expr="Coronavirus+COVID", max_studies=150 + ) + + n = 0 + for r in hundred_fifty_studies: + n += len(r["FullStudiesResponse"]["FullStudies"]) + + assert n == 150 def test_study_fields_csv():