From c6074a05346192dd5b86375c78251be18347b126 Mon Sep 17 00:00:00 2001 From: Casey Schneider-Mizell Date: Thu, 17 Oct 2024 22:28:57 +0100 Subject: [PATCH] Inequality table integration (#252) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * first take before testing * Added materialization version endpoint check to query entry functions. * ruff/lint corrections. * remove unused dicts * Bump version: 5.30.2 → 5.31.0 * Remove factory syntax for subclients (#222) * refactor * remove pcg chunkedgraph (#224) * remove factory (#226) * remove factory (#225) * Remove JSONService and L2CacheClient factories (#227) * remove factory * remove factory * remove client mappings * remove more * Remove MaterializationClient factory (#229) * remove factory * possible smarter versioning * attempt tables merge * attempt views merge * refactor live_live_query * Redo docs to acomodate factory refactor (#247) * many doc refactors * fix formatting * add node to changelog * add a date * minor fixes * Small doc fixes (#248) * more small fixes * formatting * Bump version: 5.31.0 → 6.0.0 * initial table manager * tweak bool compares * validation check on inequalities * update docs * add simple properties to class * fix unintentional file changes with rebase * fix dangling doc * do what ruff tells me to do --------- Co-authored-by: Keith Wiley Co-authored-by: Ben Pedigo Co-authored-by: github-actions[bot] --- caveclient/tools/table_manager.py | 250 ++++++++++++++++++++++++++++-- docs/tutorials/materialization.md | 35 +++++ 2 files changed, 270 insertions(+), 15 deletions(-) diff --git a/caveclient/tools/table_manager.py b/caveclient/tools/table_manager.py index bba2a025..385d10a2 100644 --- a/caveclient/tools/table_manager.py +++ b/caveclient/tools/table_manager.py @@ -1,6 +1,7 @@ import logging import re import warnings +from itertools import chain import attrs from cachetools import TTLCache, cached, keys @@ -9,11 +10,16 @@ # json schema column types 
that can act as potential columns for looking at tables ALLOW_COLUMN_TYPES = ["integer", "boolean", "string", "float"] +NUMERIC_COLUMN_TYPES = ["integer", "float"] SPATIAL_POINT_TYPES = ["SpatialPoint"] # Helper functions for turning schema field names ot column names +class InvalidInequalityException(Exception): + pass + + def bound_pt_position(pt): return f"{pt}_position" @@ -76,6 +82,17 @@ def get_all_view_metadata(client): return views, view_schema +def _parse_inequality(x, ineq): + if isinstance(x, dict): + return ineq in x + else: + return False + + +def is_equal_like(x): + return not is_list_like(x) + + def is_list_like(x): if isinstance(x, str): return False @@ -85,6 +102,26 @@ def is_list_like(x): return False +def is_isin_like(x): + return is_list_like(x) and not isinstance(x, dict) + + +def is_lessthan(x): + return _parse_inequality(x, "<") + + +def is_greaterthan(x): + return _parse_inequality(x, ">") + + +def is_lessthan_equal(x): + return _parse_inequality(x, "<=") + + +def is_greaterthan_equal(x): + return _parse_inequality(x, ">=") + + def update_spatial_dict(spatial_dict): new_dict = {} for k in spatial_dict: @@ -144,7 +181,10 @@ def get_col_info( add_fields=["id"], omit_fields=[], schema_definition=None, + numeric_types=NUMERIC_COLUMN_TYPES, ): + nonnumeric_types = [t for t in allow_types if t not in numeric_types] + numeric_types = [t for t in allow_types if t in numeric_types] if schema_definition is None: schema = client.schema.schema_definition(schema_name) else: @@ -156,6 +196,7 @@ def get_col_info( add_cols = [] pt_names = [] unbnd_pt_names = [] + numeric_cols = [] for k, v in schema["definitions"][sn]["properties"].items(): if v.get("$ref", "") == sp_name: pt_names.append(k) @@ -166,9 +207,11 @@ def get_col_info( if k in omit_fields: continue # Field type is format if exists, type otherwise - if v.get("format", v.get("type")) in allow_types: + if v.get("format", v.get("type")) in nonnumeric_types: add_cols.append(k) - return pt_names, 
add_fields + add_cols, unbnd_pt_names + if v.get("format", v.get("type")) in numeric_types: + numeric_cols.append(k) + return pt_names, add_fields + add_cols, numeric_cols, unbnd_pt_names _table_cache = TTLCache(maxsize=128, ttl=86_400) @@ -260,18 +303,19 @@ def get_table_info( schema = meta["schema"] ref_pts = [] ref_cols = [] + ref_numeric = [] ref_unbd_pts = [] name_base = tn name_ref = None else: schema = table_metadata(ref_table, client).get("schema") - ref_pts, ref_cols, ref_unbd_pts = get_col_info( + ref_pts, ref_cols, ref_numeric, ref_unbd_pts = get_col_info( meta["schema"], client, allow_types=allow_types, omit_fields=["target_id"] ) name_base = ref_table name_ref = tn - base_pts, base_cols, base_unbd_pts = get_col_info( + base_pts, base_cols, base_numeric, base_unbd_pts = get_col_info( schema, client, allow_types=allow_types ) @@ -279,13 +323,24 @@ def get_table_info( name_base, base_pts, name_ref, ref_pts, suffixes ) all_vals, val_map, rename_map_val = combine_names( - name_base, base_cols, name_ref, ref_cols, suffixes + name_base, base_cols + base_numeric, name_ref, ref_cols + ref_numeric, suffixes ) + + numeric_vals = [] + for k in all_vals: + if val_map[k] == name_base: + if k in base_numeric: + numeric_vals.append(k) + elif val_map[k] == name_ref: + if k in ref_numeric: + numeric_vals.append(k) + all_unbd_pts, unbd_pt_map, rename_map_unbd_pt = combine_names( name_base, base_unbd_pts, name_ref, ref_unbd_pts, suffixes ) rename_map = {**rename_map_pt, **rename_map_val, **rename_map_unbd_pt} column_map = {"id": name_base, **pt_map, **val_map, **unbd_pt_map} + return ( all_pts, all_vals, @@ -294,6 +349,7 @@ def get_table_info( rename_map, [name_base, name_ref], meta.get("description"), + numeric_vals, ) @@ -318,7 +374,14 @@ def table_metadata(table_name, client, meta=None): def make_class_vals( - pts, val_cols, unbd_pts, table_map, rename_map, table_list, raw_points=False + pts, + val_cols, + unbd_pts, + table_map, + rename_map, + table_list, + 
numeric_vals, + raw_points=False, ): class_vals = { "_reference_table": attrs.field( @@ -343,6 +406,7 @@ def make_class_vals( metadata={ "table": table_map[val], "original_name": rename_map.get(val, val), + "is_numeric": val in numeric_vals, }, ) for pt in pts + unbd_pts: @@ -393,10 +457,10 @@ def __attrs_post_init__(self): tn: filter_empty( attrs.asdict( self, - filter=lambda a, v: not is_list_like(v) + filter=lambda a, v: is_equal_like(v) and v is not None - and a.metadata.get("is_bbox", False) == False # noqa E712 - and a.metadata.get("is_meta", False) == False # noqa E712 + and a.metadata.get("is_bbox", False) is False # noqa E712 + and a.metadata.get("is_meta", False) is False # noqa E712 and a.metadata.get("table") == tn, ) ) @@ -408,10 +472,10 @@ def __attrs_post_init__(self): tn: filter_empty( attrs.asdict( self, - filter=lambda a, v: is_list_like(v) + filter=lambda a, v: is_isin_like(v) and v is not None - and a.metadata.get("is_bbox", False) == False # noqa E712 - and a.metadata.get("is_meta", False) == False # noqa E712 + and a.metadata.get("is_bbox", False) is False # noqa E712 + and a.metadata.get("is_meta", False) is False # noqa E712 and a.metadata.get("table") == tn, ) ) @@ -419,13 +483,81 @@ def __attrs_post_init__(self): } filter_in_dict = rename_fields(filter_in_dict, self) + filter_lt_dict = { + tn: filter_empty( + attrs.asdict( + self, + filter=lambda a, v: is_lessthan(v) + and v is not None + and a.metadata.get("is_bbox", False) is False + and a.metadata.get("is_meta", False) is False + and a.metadata.get("is_numeric", False) is True + and a.metadata.get("table") == tn, + value_serializer=lambda _, __, v: v.get("<"), + ) + ) + for tn in tables + } + filter_lt_dict = rename_fields(filter_lt_dict, self) + + filter_gt_dict = { + tn: filter_empty( + attrs.asdict( + self, + filter=lambda a, v: is_greaterthan(v) + and v is not None + and a.metadata.get("is_bbox", False) is False + and a.metadata.get("is_meta", False) is False + and 
a.metadata.get("is_numeric", False) is True + and a.metadata.get("table") == tn, + value_serializer=lambda _, __, v: v.get(">"), + ) + ) + for tn in tables + } + filter_gt_dict = rename_fields(filter_gt_dict, self) + + filter_geq_dict = { + tn: filter_empty( + attrs.asdict( + self, + filter=lambda a, v: is_greaterthan_equal(v) + and v is not None + and a.metadata.get("is_bbox", False) is False + and a.metadata.get("is_meta", False) is False + and a.metadata.get("is_numeric", False) is True + and a.metadata.get("table") == tn, + value_serializer=lambda _, __, v: v.get(">="), + ) + ) + for tn in tables + } + filter_geq_dict = rename_fields(filter_geq_dict, self) + + filter_leq_dict = { + tn: filter_empty( + attrs.asdict( + self, + filter=lambda a, v: is_lessthan_equal(v) + and v is not None + and a.metadata.get("is_bbox", False) is False + and a.metadata.get("is_meta", False) is False + and a.metadata.get("is_numeric", False) is True + and a.metadata.get("table") == tn, + value_serializer=lambda _, __, v: v.get("<="), + ) + ) + for tn in tables + } + filter_leq_dict = rename_fields(filter_leq_dict, self) + spatial_dict = { tn: update_spatial_dict( attrs.asdict( self, filter=lambda a, v: a.metadata.get("is_bbox", False) and v is not None - and a.metadata.get("is_meta", False) == False # noqa E712 + and a.metadata.get("is_meta", False) is False # noqa E712 and a.metadata.get("table") == tn, ) ) @@ -433,6 +565,32 @@ def __attrs_post_init__(self): } spatial_dict = rename_fields(spatial_dict, self) + invalid_inequality_queries = { + tn: filter_empty( + attrs.asdict( + self, + filter=lambda a, v: ( + is_greaterthan(v) + or is_greaterthan_equal(v) + or is_lessthan(v) + or is_lessthan_equal(v) + ) + and a.metadata.get("table") == tn + and a.metadata.get("is_numeric", False) is False, + ) + ) + for tn in tables + } + + if sum([len(v) for k, v in invalid_inequality_queries.items()]) > 0: + bad_fields = [ + list(v.keys()) + for k, v in invalid_inequality_queries.items() + if 
len(v) > 0 + ] + msg = f"Cannot use inequality for non-numeric fields: {list(chain.from_iterable(bad_fields))}" + raise InvalidInequalityException(msg) + self.filter_kwargs_live = { "filter_equal_dict": replace_empty_with_none( filter_empty(filter_equal_dict) @@ -441,6 +599,18 @@ def __attrs_post_init__(self): "filter_spatial_dict": replace_empty_with_none( filter_empty(spatial_dict) ), + "filter_greater_dict": replace_empty_with_none( + filter_empty(filter_gt_dict) + ), + "filter_less_dict": replace_empty_with_none( + filter_empty(filter_lt_dict) + ), + "filter_greater_equal_dict": replace_empty_with_none( + filter_empty(filter_geq_dict) + ), + "filter_less_equal_dict": replace_empty_with_none( + filter_empty(filter_leq_dict) + ), } if len(tables) == 2: self.filter_kwargs_mat = self.filter_kwargs_live @@ -453,6 +623,10 @@ def __attrs_post_init__(self): "filter_equal_dict", "filter_in_dict", "filter_spatial_dict", + "filter_greater_dict", + "filter_less_dict", + "filter_greater_equal_dict", + "filter_less_equal_dict", ] if self.filter_kwargs_live[k] is not None } @@ -640,14 +814,34 @@ def make_query_filter(table_name, meta, client): rename_map, table_list, desc, + numeric_vals, ) = get_table_info(table_name, meta, client) + class_vals = make_class_vals( - pts, val_cols, all_unbd_pts, table_map, rename_map, table_list + pts, val_cols, all_unbd_pts, table_map, rename_map, table_list, numeric_vals ) QueryFilter = attrs.make_class( table_name, class_vals, bases=(make_kwargs_mixin(client),) ) QueryFilter.__doc__ = desc + setattr(QueryFilter, "query", QueryFilter().query) + setattr(QueryFilter, "live_query", QueryFilter().live_query) + + fields = [ + x.name + for x in attrs.fields(QueryFilter) + if not x.metadata.get("is_meta", False) + ] + setattr(QueryFilter, "fields", fields) + numeric_fields = [ + x.name for x in attrs.fields(QueryFilter) if x.metadata.get("is_numeric", False) + ] + setattr(QueryFilter, "numeric_fields", numeric_fields) + spatial_fields = [ + x.name 
for x in attrs.fields(QueryFilter) if x.metadata.get("is_bbox", False) + ] + setattr(QueryFilter, "spatial_fields", spatial_fields) + return QueryFilter @@ -662,9 +856,13 @@ def make_query_filter_view(view_name, meta, schema, client): desc, live_compatible, ) = get_view_info(view_name, meta, schema) + + numeric_vals = [k for k, v in schema.items() if v["type"] in NUMERIC_COLUMN_TYPES] + class_vals = make_class_vals( - pts, val_cols, all_unbd_pts, table_map, rename_map, table_list + pts, val_cols, all_unbd_pts, table_map, rename_map, table_list, numeric_vals ) + ViewQueryFilter = attrs.make_class( view_name, class_vals, @@ -673,6 +871,28 @@ def make_query_filter_view(view_name, meta, schema, client): ), ) ViewQueryFilter.__doc__ = desc + + setattr(ViewQueryFilter, "query", ViewQueryFilter().query) + + fields = [ + x.name + for x in attrs.fields(ViewQueryFilter) + if not x.metadata.get("is_meta", False) + ] + setattr(ViewQueryFilter, "fields", fields) + numeric_fields = [ + x.name + for x in attrs.fields(ViewQueryFilter) + if x.metadata.get("is_numeric", False) + ] + setattr(ViewQueryFilter, "numeric_fields", numeric_fields) + spatial_fields = [ + x.name + for x in attrs.fields(ViewQueryFilter) + if x.metadata.get("is_bbox", False) + ] + setattr(ViewQueryFilter, "spatial_fields", spatial_fields) + return ViewQueryFilter diff --git a/docs/tutorials/materialization.md b/docs/tutorials/materialization.md index 9d1784c7..17184f1c 100644 --- a/docs/tutorials/materialization.md +++ b/docs/tutorials/materialization.md @@ -342,6 +342,21 @@ nuc_df = client.materialize.tables.nucleus_detection_v0( ).query() ``` +If you are not using any filters, you can omit the parenthesis and use the `query` +or `live_query` function directly. The first example could be rewritten as: + +```python +nuc_df = client.materialize.tables.nucleus_detection_v0.query() +``` + +If you want to list all available fields, you can use the `.fields` attribute. 
+Similarly, you can get all numeric fields with the `.numeric_fields` attribute +and all spatial fields (allowing bounding box queries) with `.spatial_fields`. + +```python +spatial_fields = client.materialize.tables.nucleus_detection_v0.spatial_fields +``` + If you need to specify the table programmatically, you can also use a dictionary-style approach to getting the table filtering function. For example, an equivalent version of the above line would be: @@ -370,6 +385,26 @@ nuc_df = client.materialize.tables.nucleus_detection_v0( ) ``` +Inequalities can also be used to filter numeric columns. +Here, you can pass a dictionary instead of a list of values, with the keys being +inequality operators (">", ">=", "<", and "<=") and the values being the numbers to compare against. +For example, to query for all nuclei with a volume greater than 1000: + +```python +client.materialize.tables.nucleus_detection_v0( + volume={">": 1000} +).query() +``` + +You can also use multiple inequalities in the same dictionary to filter within a range. +For example, to query for all nuclei with a volume strictly between 500 and 750: + +```python +client.materialize.tables.nucleus_detection_v0( + volume={">": 500, "<": 750} +).query() +``` + If you want to do a live query instead of a materialized query, the filtering remains identical but we use the `live_query` function instead. The one required argument for `live_query` is the timestamp.