diff --git a/chord_metadata_service/discovery/fields.py b/chord_metadata_service/discovery/fields.py index 5f36eb84a..f21017139 100644 --- a/chord_metadata_service/discovery/fields.py +++ b/chord_metadata_service/discovery/fields.py @@ -232,6 +232,11 @@ async def get_categorical_stats(field_props: DiscoveryFieldProps, low_counts_cen model, field_name = get_model_and_field(field_props["mapping"]) + # Collect stats for the field, censoring low cell counts along the way + # - We cannot append 0-counts for derived labels, since that indicates there is a non-0 count for this label in the + # database - i.e., if the label is pulled from the values in the database, someone could otherwise learn + # 1 <= this field <= threshold given it being present at all. + # - stats_for_field(...) handles this! stats: Mapping[str, int] = await stats_for_field(model, field_name, low_counts_censored, add_missing=True) # Enforce values order from config and apply policies @@ -242,35 +247,21 @@ async def get_categorical_stats(field_props: DiscoveryFieldProps, low_counts_cen # dataset (enum is null in the public JSON). # - Here, apply lexical sort, and exclude the "missing" value which will # be appended at the end if it is set. - # - Note that in this situation, we explictly MUST remove rounded-down 0-counts - # (below the threshold) below, otherwise we LEAK that there is 1 <= x <= threshold - # matching entries in the DB. + # - Note that in this situation, we explicitly MUST HAVE already removed rounded-down 0-counts (below the threshold), + # otherwise we LEAK that there are 1 <= x <= threshold matching entries in the DB. However, since + # stats_for_field(...) has already handled not adding these keys, these labels don't make it into this list. 
if derived_labels: labels = sorted( [k for k in stats.keys() if k != "missing"], key=lambda x: x.lower() ) - bins: list[BinWithValue] = [] - - for category in labels: - # Censor small counts by rounding them to 0 - v: int = thresholded_count(stats.get(category, 0), low_counts_censored) - - if v == 0 and derived_labels: - # We cannot append 0-counts for derived labels, since that indicates - # there is a non-0 count for this label in the database - i.e., if the label is pulled - # from the values in the database, someone could otherwise learn 1 <= this field <= threshold - # given it being present at all. - continue - - # Otherwise (pre-made labels, so we aren't leaking anything), keep the 0-count. - - bins.append({"label": category, "value": v}) - - bins.append({"label": "missing", "value": stats["missing"]}) - - return bins + # Create bin structures for each label, and add an extra `missing` bin for items missing a value for this field. + return [ + # Don't need to re-censor counts - we've already censored them in stats_for_field(...): + *({"label": category, "value": stats.get(category, 0)} for category in labels), + {"label": "missing", "value": stats["missing"]}, + ] async def get_date_stats(field_props: DiscoveryFieldProps, low_counts_censored: bool = True) -> list[BinWithValue]: diff --git a/chord_metadata_service/discovery/tests/test_fields.py b/chord_metadata_service/discovery/tests/test_fields.py index c8d268c8b..2bd552acb 100644 --- a/chord_metadata_service/discovery/tests/test_fields.py +++ b/chord_metadata_service/discovery/tests/test_fields.py @@ -2,12 +2,16 @@ from django.test import TransactionTestCase, override_settings from rest_framework.test import APITestCase +from chord_metadata_service.patients import models as pa_m +from chord_metadata_service.phenopackets.tests import constants as ph_c + from .constants import CONFIG_PUBLIC_TEST from ..fields import ( get_model_and_field, get_field_options, + get_categorical_stats, get_date_stats, - 
class TestGetCategoricalStats(TransactionTestCase):
    """Exercise get_categorical_stats(...) with low-count censoring disabled and enabled.

    A single individual is created in setUp, so the only non-missing category has a count of 1.
    """

    # Discovery field definition for the individual's sex. "enum" is None, i.e. no list of
    # labels is pre-configured for this field.
    # NOTE(review): presumably this makes labels get derived from the values present in the
    # database - confirm against get_categorical_stats(...).
    f_sex = {
        "mapping": "individual/sex",
        "datatype": "string",
        "title": "Sex",
        "description": "Sex",
        "config": {
            "enum": None,
        },
    }

    def setUp(self):
        # One individual only, so every category count in these tests is at most 1.
        self.individual_1 = pa_m.Individual.objects.create(**ph_c.VALID_INDIVIDUAL_1)

    async def test_categorical_stats_lcf(self):
        """Without censoring, the single individual's category is reported with its true count."""
        res = await get_categorical_stats(self.f_sex, low_counts_censored=False)
        self.assertListEqual(res, [{"label": "MALE", "value": 1}, {"label": "missing", "value": 0}])

    @override_settings(CONFIG_PUBLIC=CONFIG_PUBLIC_TEST)
    async def test_categorical_stats_lct(self):
        """With censoring, the count-of-1 category must not appear at all; only "missing" remains."""
        res = await get_categorical_stats(self.f_sex, low_counts_censored=True)
        self.assertListEqual(res, [{"label": "missing", "value": 0}])