From 4e92bf6630c5f64f560b793979a15a6a6bea1e9f Mon Sep 17 00:00:00 2001
From: Matthew McDermott
Date: Sun, 18 Aug 2024 22:36:00 -0400
Subject: [PATCH] started some over-complicated and not yet fully validated
 code for providing schema validators and more consistent handling of optional
 types.

---
 src/meds/__init__.py         |   7 +-
 src/meds/data_schema.py      |  89 +++++++++++++++++++++++
 src/meds/polars_support.py   | 135 +++++++++++++++++++++++++++++++++++
 src/meds/schema.py           |   7 --
 src/meds/shared_constants.py |  16 +++++
 src/meds/utils.py            | 125 ++++++++++++++++++++++++++++++++
 6 files changed, 370 insertions(+), 9 deletions(-)
 create mode 100644 src/meds/data_schema.py
 create mode 100644 src/meds/polars_support.py
 create mode 100644 src/meds/shared_constants.py
 create mode 100644 src/meds/utils.py

diff --git a/src/meds/__init__.py b/src/meds/__init__.py
index 9cde9d7..206f014 100644
--- a/src/meds/__init__.py
+++ b/src/meds/__init__.py
@@ -1,15 +1,18 @@
 from meds._version import __version__  # noqa
 
+from .shared_constants import (
+    subject_id_field, time_field, code_field, numeric_value_field, subject_id_dtype, time_dtype, code_dtype,
+    numeric_value_dtype, birth_code, death_code
+)
+
 from .schema import (
     CodeMetadata,
     DatasetMetadata,
     Label,
-    birth_code,
     code_field,
     code_metadata_schema,
     data_schema,
     dataset_metadata_schema,
-    death_code,
     held_out_split,
     label_schema,
     subject_id_dtype,
diff --git a/src/meds/data_schema.py b/src/meds/data_schema.py
new file mode 100644
index 0000000..375c1cc
--- /dev/null
+++ b/src/meds/data_schema.py
@@ -0,0 +1,89 @@
+"""The data schema for the MEDS format.
+
+Please see the README for more information, including expected file organization on disk, more details on
+what this schema should capture, etc.
+
+MEDS data also must satisfy two important properties:
+
+1. Data about a single subject cannot be split across parquet files.
+   If a subject is in a dataset, it must be in one and only one parquet file.
+2. Data about a single subject must be contiguous within a particular parquet file and sorted by time.
+
+Both of these restrictions allow the stream rolling processing
+(https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.rolling.html)
+which vastly simplifies many data analysis pipelines.
+
+No Python type is provided because Python tools for processing MEDS data will often provide their own
+types. See https://github.com/EthanSteinberg/meds_reader/blob/0.0.6/src/meds_reader/__init__.pyi#L55 for an
+example.
+"""
+
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Callable
+
+import pyarrow as pa
+
+from .shared_constants import (
+    subject_id_field, time_field, code_field, subject_id_dtype, time_dtype, code_dtype, numeric_value_field,
+    numeric_value_dtype
+)
+from .utils import DF_T, CustomizableSchemaFntr, convert_and_validate_schema
+
+if TYPE_CHECKING:
+    import polars as pl
+    from omegaconf import DictConfig
+
+MANDATORY_FIELDS = {
+    subject_id_field: subject_id_dtype,
+    time_field: time_dtype,
+    code_field: code_dtype,
+    numeric_value_field: numeric_value_dtype,
+}
+
+OPTIONAL_FIELDS = {
+    "categorical_value": pa.string(),
+    "text_value": pa.string(),
+}
+
+# This creates a callable that will build a data schema with the mandatory fields and any custom fields you
+# specify. Types are guaranteed to match optional field types if the names align.
+data_schema = CustomizableSchemaFntr(MANDATORY_FIELDS, OPTIONAL_FIELDS)
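+#
+# A minimal usage sketch (the "unit" column is a hypothetical custom property, not part of the MEDS spec;
+# optional fields such as "text_value" may also be requested, but must then keep the types declared above):
+#
+#     schema = data_schema([("unit", pa.string())])
+#     schema.names                                         # ["subject_id", "time", "code", "numeric_value", "unit"]
+#     schema.field("numeric_value").type == pa.float32()   # True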
+
+
+def convert_and_validate_schema_fntr(
+    do_cast_types: bool | dict[str, bool] = True,
+    do_add_missing_fields: bool = True,
+) -> Callable[[DF_T], pa.Table]:
+    """Returns a function that converts and validates a dataframe against the MEDS data schema.
+
+    This is a thin wrapper around ``meds.utils.convert_and_validate_schema`` with the schema fixed to the
+    ``data_schema`` factory above. TODO: this wiring is a first draft and is not yet fully validated.
+    """
+
+    def converter(df: DF_T) -> pa.Table:
+        return convert_and_validate_schema(
+            df,
+            schema=data_schema,
+            do_cast_types=do_cast_types,
+            do_add_missing_mandatory_fields=do_add_missing_fields,
+        )
+
+    return converter
+
+
+def get_and_validate_data_schema(df: pl.LazyFrame, stage_cfg: DictConfig) -> pa.Table:
+    """Validates (and, if permitted, retypes) a polars LazyFrame against the MEDS data schema.
+
+    Requires polars; it is imported lazily so the rest of this module does not depend on it. ``stage_cfg``
+    is any mapping (e.g., an ``omegaconf.DictConfig``) with an optional boolean ``do_retype`` key.
+    """
+    import polars as pl
+
+    # Polars analogues of ``MANDATORY_FIELDS``; kept in sync with the pyarrow types above.
+    MEDS_DATA_MANDATORY_TYPES = {
+        subject_id_field: pl.Int64,
+        time_field: pl.Datetime("us"),
+        code_field: pl.String,
+        numeric_value_field: pl.Float32,
+    }
+
+    do_retype = stage_cfg.get("do_retype", True)
+    schema = df.collect_schema()
+    errors = []
+    for col, dtype in MEDS_DATA_MANDATORY_TYPES.items():
+        if col in schema and schema[col] != dtype:
+            if do_retype:
+                df = df.with_columns(pl.col(col).cast(dtype, strict=False))
+            else:
+                errors.append(f"MEDS Data '{col}' column must be of type {dtype}. Got {schema[col]}.")
+        elif col not in schema:
+            if col in (numeric_value_field, time_field) and do_retype:
+                df = df.with_columns(pl.lit(None, dtype=dtype).alias(col))
+            else:
+                errors.append(f"MEDS Data DataFrame must have a '{col}' column of type {dtype}.")
+
+    if errors:
+        raise ValueError("\n".join(errors))
+
+    additional_cols = [col for col in schema if col not in MEDS_DATA_MANDATORY_TYPES]
+
+    if additional_cols:
+        extra_schema = df.head(0).select(additional_cols).collect().to_arrow().schema
+        measurement_properties = list(zip(extra_schema.names, extra_schema.types))
+        df = df.select(*MEDS_DATA_MANDATORY_TYPES.keys(), *additional_cols)
+    else:
+        df = df.select(*MEDS_DATA_MANDATORY_TYPES.keys())
+        measurement_properties = []
+
+    validated_schema = data_schema(measurement_properties)
+    return df.collect().to_arrow().cast(validated_schema)
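+
+
+# Rough usage sketch for ``get_and_validate_data_schema`` (illustrative only; assumes polars and omegaconf
+# are installed, and the toy frame below is hypothetical):
+#
+#     import polars as pl
+#     from omegaconf import DictConfig
+#
+#     lf = pl.LazyFrame({"subject_id": [1], "code": ["MEDS_BIRTH"]})
+#     tbl = get_and_validate_data_schema(lf, DictConfig({"do_retype": True}))
+#     tbl.schema.names  # ["subject_id", "time", "code", "numeric_value"] (missing columns added as nulls)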
diff --git a/src/meds/polars_support.py b/src/meds/polars_support.py
new file mode 100644
index 0000000..7e225f9
--- /dev/null
+++ b/src/meds/polars_support.py
@@ -0,0 +1,135 @@
+from typing import Any, Callable
+
+import pyarrow as pa
+
+try:
+    import polars as pl
+
+    if pl.__version__ < "1.0.0":
+        raise ImportError("polars version must be >= 1.0.0 for these utilities")
+
+    DF_TYPES = (pl.DataFrame, pl.LazyFrame)
+
+    def _convert_and_validate_schema(
+        df: pl.DataFrame | pl.LazyFrame,
+        mandatory_columns: dict[str, pa.DataType],
+        optional_columns: dict[str, pa.DataType],
+        do_allow_extra_columns: bool,
+        do_cast_types: dict[str, bool],
+        do_add_missing_mandatory_fields: bool,
+        do_reorder_columns: bool,
+        schema_validator: Callable[[list[tuple[str, pa.DataType]]], pa.Schema] | None = None,
+    ) -> pa.Table:
+        """Converts a polars DataFrame or LazyFrame to an Arrow Table and validates its schema.
+
+        Args:
+            df: The polars DataFrame or LazyFrame to convert and validate.
+            mandatory_columns: A dictionary of mandatory columns and their types.
+            optional_columns: A dictionary of optional columns and their types. Optional columns need not
+                be present in the DataFrame but, if they are present, they must have the correct type.
+            do_allow_extra_columns: Whether to allow extra columns in the DataFrame.
+            do_cast_types: Whether it is permissible to cast individual columns to the correct types. This
+                parameter must be specified as a dictionary mapping column name to whether it is
+                permissible to cast that column.
+            do_add_missing_mandatory_fields: Whether it is permissible to add missing mandatory fields to
+                the DataFrame with null values. If `False`, any missing mandatory field results in an
+                error.
+            do_reorder_columns: Whether it is permissible to reorder the DataFrame's columns to match the
+                target schema order.
+            schema_validator: A function that takes a list of tuples of all additional (beyond the
+                mandatory) column names and types and returns a PyArrow Schema object for the table, if a
+                valid schema exists with the passed columns.
+        """
+
+        # If it is not a LazyFrame already, make it one.
+        df = df.lazy()
+
+        schema = df.collect_schema()
+
+        # Build an empty Arrow table with the target types so we can read off their polars equivalents.
+        typed_pa_df = pa.Table.from_pylist(
+            [], schema=pa.schema(list(mandatory_columns.items()) + list(optional_columns.items()))
+        )
+        target_pl_schema = pl.from_arrow(typed_pa_df).collect_schema()
+
+        errors = []
+        for col in mandatory_columns:
+            target_dtype = target_pl_schema[col]
+            if col in schema:
+                if target_dtype != schema[col]:
+                    if do_cast_types[col]:
+                        df = df.with_columns(pl.col(col).cast(target_dtype))
+                    else:
+                        errors.append(
+                            f"Column '{col}' must be of type {target_dtype}. Got {schema[col]} instead."
+                        )
+            elif do_add_missing_mandatory_fields:
+                df = df.with_columns(pl.lit(None, dtype=target_dtype).alias(col))
+            else:
+                errors.append(f"Missing mandatory column '{col}' of type {target_dtype}.")
+
+        for col in optional_columns:
+            if col not in schema:
+                continue
+
+            target_dtype = target_pl_schema[col]
+            if target_dtype != schema[col]:
+                if do_cast_types[col]:
+                    df = df.with_columns(pl.col(col).cast(target_dtype))
+                else:
+                    errors.append(
+                        f"Optional column '{col}' must be of type {target_dtype} if included. "
+                        f"Got {schema[col]} instead."
+                    )
+
+        type_specific_columns = set(mandatory_columns.keys()) | set(optional_columns.keys())
+        additional_cols = [col for col in schema if col not in type_specific_columns]
+
+        if additional_cols and not do_allow_extra_columns:
+            errors.append(f"Found unexpected columns: {additional_cols}")
+
+        if errors:
+            raise ValueError("\n".join(errors))
+
+        default_pa_schema = df.head(0).collect().to_arrow().schema
+
+        optional_properties = []
+        for col in schema:
+            if col in mandatory_columns:
+                continue
+
+            if col in optional_columns:
+                optional_properties.append((col, optional_columns[col]))
+            else:
+                optional_properties.append((col, default_pa_schema.field(col).type))
+
+        if schema_validator is not None:
+            validated_schema = schema_validator(optional_properties)
+        else:
+            validated_schema = pa.schema(list(mandatory_columns.items()) + optional_properties)
+
+        schema_order = validated_schema.names
+
+        df_columns = df.collect_schema().names()
+        extra_cols = set(df_columns) - set(schema_order)
+        if extra_cols:
+            raise ValueError(f"Found unexpected columns: {extra_cols}")
+
+        if schema_order != df_columns:
+            if do_reorder_columns:
+                df = df.select(schema_order)
+            else:
+                raise ValueError(f"Column order must be {schema_order}. Got {df_columns} instead.")
+
+        return df.collect().to_arrow().cast(validated_schema)
+
+
+except ImportError:
+    DF_TYPES = tuple()
+
+    def _convert_and_validate_schema(
+        df: Any,
+        mandatory_columns: dict[str, pa.DataType],
+        optional_columns: dict[str, pa.DataType],
+        do_allow_extra_columns: bool,
+        do_cast_types: dict[str, bool],
+        do_add_missing_mandatory_fields: bool,
+        do_reorder_columns: bool,
+        schema_validator: Callable[[list[tuple[str, pa.DataType]]], pa.Schema] | None = None,
+    ) -> pa.Table:
+        raise NotImplementedError("polars is not installed")
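+
+
+# Example call (a sketch; this helper is normally reached through ``meds.utils.convert_and_validate_schema``
+# rather than being called directly, and the "unit" column is purely illustrative):
+#
+#     table = _convert_and_validate_schema(
+#         pl.DataFrame({"subject_id": [1], "code": ["x"], "unit": ["mg"]}),
+#         mandatory_columns={"subject_id": pa.int64(), "code": pa.string()},
+#         optional_columns={},
+#         do_allow_extra_columns=True,
+#         do_cast_types={"subject_id": True, "code": True},
+#         do_add_missing_mandatory_fields=True,
+#         do_reorder_columns=True,
+#     )
+#     table.schema.names  # ["subject_id", "code", "unit"]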
diff --git a/src/meds/schema.py b/src/meds/schema.py
index 6fdafee..053a428 100644
--- a/src/meds/schema.py
+++ b/src/meds/schema.py
@@ -23,14 +23,7 @@
 # which vastly simplifies many data analysis pipelines.
 
 # We define some codes for particularly important events
-birth_code = "MEDS_BIRTH"
-death_code = "MEDS_DEATH"
-
-subject_id_field = "subject_id"
-time_field = "time"
-code_field = "code"
-
-subject_id_dtype = pa.int64()
+# These now live in ``shared_constants``; re-export them here so that existing imports such as
+# ``from meds.schema import code_field`` (including those in ``__init__.py``) keep working.
+from .shared_constants import (  # noqa: F401
+    birth_code, code_field, death_code, subject_id_dtype, subject_id_field, time_field,
+)
 
 
 def data_schema(custom_properties=[]):
diff --git a/src/meds/shared_constants.py b/src/meds/shared_constants.py
new file mode 100644
index 0000000..afe353a
--- /dev/null
+++ b/src/meds/shared_constants.py
@@ -0,0 +1,16 @@
+"""Shared constants for the MEDS schema."""
+
+import pyarrow as pa
+
+# Field names and types that are shared across the MEDS schema.
+subject_id_field = "subject_id"
+time_field = "time"
+code_field = "code"
+numeric_value_field = "numeric_value"
+
+subject_id_dtype = pa.int64()
+time_dtype = pa.timestamp("us")
+code_dtype = pa.string()
+numeric_value_dtype = pa.float32()
+
+# Canonical codes for select events.
+birth_code = "MEDS_BIRTH"
+death_code = "MEDS_DEATH"
diff --git a/src/meds/utils.py b/src/meds/utils.py
new file mode 100644
index 0000000..179bdb0
--- /dev/null
+++ b/src/meds/utils.py
@@ -0,0 +1,125 @@
+"""Utilities for specifying and working with MEDS schemas."""
+
+from __future__ import annotations
+
+from typing import Callable, TypeVar
+
+import pyarrow as pa
+
+from .polars_support import (
+    DF_TYPES as PL_DF_TYPES,
+    _convert_and_validate_schema as pl_convert_and_validate_schema,
+)
+
+try:
+    from .pyarrow_support import (
+        DF_TYPES as PA_DF_TYPES,
+        _convert_and_validate_schema as pa_convert_and_validate_schema,
+    )
+except ImportError:
+    # TODO: ``pyarrow_support`` is not included in this patch yet; fall back to a stub so this module
+    # remains importable.
+    PA_DF_TYPES = (pa.Table,)
+
+    def pa_convert_and_validate_schema(df, **kwargs):
+        raise NotImplementedError("Validation of raw pyarrow Tables is not implemented yet.")
+
+SCHEMA_DICT_T = dict[str, pa.DataType]
+
+__DF_VALIDATORS = [
+    (PL_DF_TYPES, pl_convert_and_validate_schema),
+    # (PD_DF_TYPES, pd_convert_and_validate_schema),
+]
+
+# ``TypeVar`` requires at least two constraints, so fall back to an unconstrained variable when polars is
+# not installed and only pyarrow Tables are supported.
+_ALL_DF_TYPES = PA_DF_TYPES + PL_DF_TYPES
+DF_T = TypeVar("DF_T", *_ALL_DF_TYPES) if len(_ALL_DF_TYPES) > 1 else TypeVar("DF_T")
+
+
+def convert_and_validate_schema(
+    df: DF_T,
+    schema: pa.Schema | CustomizableSchemaFntr,
+    do_cast_types: bool | dict[str, bool] = True,
+    do_add_missing_mandatory_fields: bool = True,
+    do_reorder_columns: bool = True,
+) -> pa.Table:
+    """Converts ``df`` to a pyarrow Table and validates it against ``schema``.
+
+    TODO: flesh out this docstring with full argument descriptions and examples.
+    """
+
+    match schema:
+        case CustomizableSchemaFntr():
+            mandatory_columns = schema.mandatory_fields
+            optional_columns = schema.optional_fields
+            do_allow_extra_columns = True
+        case pa.Schema():
+            mandatory_columns = {k: dt for k, dt in zip(schema.names, schema.types)}
+            optional_columns = {}
+            do_allow_extra_columns = False
+        case _:
+            raise ValueError("Schema must be a CustomizableSchemaFntr or a pa.Schema.")
+
+    all_columns = set(mandatory_columns.keys()) | set(optional_columns.keys())
+    if isinstance(do_cast_types, bool):
+        do_cast_types = {col: do_cast_types for col in all_columns}
+    elif not isinstance(do_cast_types, dict):
+        raise ValueError("do_cast_types must be a bool or a dict.")
+    elif not all(col in do_cast_types for col in all_columns):
+        missing_cols = sorted(all_columns - set(do_cast_types))
+        raise ValueError(
+            "If it is a dict, do_cast_types must have a key for every column in the schema. Missing "
+            f"columns: {', '.join(missing_cols)}."
+        )
+    elif not all(type(v) is bool for v in do_cast_types.values()):
+        invalid_types = {col: v for col, v in do_cast_types.items() if type(v) is not bool}
+        raise ValueError(f"do_cast_types values must be bools. Got invalid types: {invalid_types}.")
+
+    kwargs = {
+        "mandatory_columns": mandatory_columns,
+        "optional_columns": optional_columns,
+        "do_allow_extra_columns": do_allow_extra_columns,
+        "do_cast_types": do_cast_types,
+        "do_add_missing_mandatory_fields": do_add_missing_mandatory_fields,
+        "do_reorder_columns": do_reorder_columns,
+    }
+
+    for df_types, validator_fn in __DF_VALIDATORS:
+        if isinstance(df, df_types):
+            return validator_fn(df, **kwargs)
+
+    if not isinstance(df, PA_DF_TYPES):
+        raise ValueError(f"DataFrame must be one of the allowed types. Got {type(df)}.")
+
+    return pa_convert_and_validate_schema(df, **kwargs)
+
+
+class CustomizableSchemaFntr(object):
+    def __init__(self, mandatory_fields: SCHEMA_DICT_T, optional_fields: SCHEMA_DICT_T):
+        """Returns a member object that can be called to create a schema with custom properties.
+
+        The returned schema will contain, in entry order, all the fields specified in `mandatory_fields`
+        followed by any custom properties specified in the function call. Any custom properties must not
+        share a name with a mandatory field, and must have the correct type if they share a name with an
+        optional field.
+
+        Args:
+            mandatory_fields: The mandatory fields for the schema.
+            optional_fields: The optional fields for the schema.
+
+        Returns:
+            A function that can be used to create a schema with custom properties.
+
+        Raises:
+            ValueError: If a custom property conflicts with a mandatory field or is an optional field but
+                has the wrong type.
+
+        Examples:
+            >>> import pyarrow as pa
+            >>> fntr = CustomizableSchemaFntr({"subject_id": pa.int64()}, {"text_value": pa.string()})
+            >>> fntr([("prompt", pa.string())]).names
+            ['subject_id', 'prompt']
+            >>> fntr([("subject_id", pa.string())])
+            Traceback (most recent call last):
+                ...
+            ValueError: Custom property subject_id conflicts with a mandatory field.
+        """
+        self.mandatory_fields = mandatory_fields
+        self.optional_fields = optional_fields
+
+    def __call__(
+        self, custom_properties: list[tuple[str, pa.DataType]] | SCHEMA_DICT_T | None = None
+    ) -> pa.Schema:
+        """Returns the final, customized schema for the specified format."""
+
+        if custom_properties is None:
+            custom_properties = []
+        elif isinstance(custom_properties, dict):
+            custom_properties = list(custom_properties.items())
+
+        for field, dtype in custom_properties:
+            if field in self.mandatory_fields:
+                raise ValueError(f"Custom property {field} conflicts with a mandatory field.")
+            if field in self.optional_fields and dtype != self.optional_fields[field]:
+                raise ValueError(f"Custom property {field} must be of type {self.optional_fields[field]}.")
+
+        return pa.schema(list(self.mandatory_fields.items()) + custom_properties)
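+
+
+# End-to-end sketch of the intended top-level API (illustrative; ``some_df`` is a placeholder for a pyarrow
+# Table or a polars (Lazy)Frame):
+#
+#     from meds.data_schema import data_schema
+#     from meds.utils import convert_and_validate_schema
+#
+#     table = convert_and_validate_schema(some_df, data_schema, do_cast_types=True)
+#     # The result is a pyarrow Table whose leading columns are the mandatory MEDS fields.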