Commit
started some over-complicated and not yet fully validated code for providing schema validators and more consistent handling of optional types.
1 parent d4be859 · commit 4e92bf6
Showing 6 changed files with 370 additions and 9 deletions.
@@ -0,0 +1,89 @@
"""The data schema for the MEDS format. | ||
Please see the README for more information, including expected file organization on disk, more details on what | ||
this schema should capture, etc. | ||
The data schema. | ||
MEDS data also must satisfy two important properties: | ||
1. Data about a single subject cannot be split across parquet files. | ||
If a subject is in a dataset it must be in one and only one parquet file. | ||
2. Data about a single subject must be contiguous within a particular parquet file and sorted by time. | ||
Both of these restrictions allow the stream rolling processing | ||
(https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.rolling.html) | ||
which vastly simplifies many data analysis pipelines. | ||
No python type is provided because Python tools for processing MEDS data will often provide their own types. | ||
See https://github.com/EthanSteinberg/meds_reader/blob/0.0.6/src/meds_reader/__init__.pyi#L55 for example. | ||
""" | ||

from collections.abc import Callable

import polars as pl
import pyarrow as pa
from omegaconf import DictConfig

from .shared_codes import (
    code_dtype,
    code_field,
    numeric_value_dtype,
    numeric_value_field,
    subject_id_dtype,
    subject_id_field,
    time_dtype,
    time_field,
)
from .utils import CustomizableSchemaFntr

MANDATORY_FIELDS = {
    subject_id_field: subject_id_dtype,
    time_field: time_dtype,
    code_field: code_dtype,
    numeric_value_field: numeric_value_dtype,
}

OPTIONAL_FIELDS = {
    "categorical_value": pa.string(),
    "text_value": pa.string(),
}

# This returns a function that will create a data schema with the mandatory fields and any custom fields you
# specify. Types are guaranteed to match optional field types if the names align.
data_schema = CustomizableSchemaFntr(MANDATORY_FIELDS, OPTIONAL_FIELDS)

def convert_and_validate_schema_fntr(
    do_cast_types: bool | dict[str, bool] = True,
    do_add_missing_fields: bool = True,
) -> Callable[[pa.Table | pl.DataFrame | pl.LazyFrame, pa.Schema], pa.Table]:
    # Work in progress: returns a closure that converts `df` to an Arrow Table matching `schema`.
    def convert_and_validate(df: pa.Table | pl.DataFrame | pl.LazyFrame, schema: pa.Schema) -> pa.Table:
        if isinstance(df, pa.Table):
            # handle pa.Table
            ...

    return convert_and_validate

# Assumed definition: polars analogues of MANDATORY_FIELDS, as referenced by the validator below
# (the dict is used but not defined elsewhere in this diff).
MEDS_DATA_MANDATORY_TYPES = {
    subject_id_field: pl.Int64,
    time_field: pl.Datetime("us"),
    code_field: pl.String,
    numeric_value_field: pl.Float32,
}


def get_and_validate_data_schema(df: pl.LazyFrame, stage_cfg: DictConfig) -> pa.Table:
    do_retype = stage_cfg.get("do_retype", True)
    schema = df.collect_schema()
    errors = []
    for col, dtype in MEDS_DATA_MANDATORY_TYPES.items():
        if col in schema and schema[col] != dtype:
            if do_retype:
                df = df.with_columns(pl.col(col).cast(dtype, strict=False))
            else:
                errors.append(f"MEDS Data '{col}' column must be of type {dtype}. Got {schema[col]}.")
        elif col not in schema:
            if col in ("numeric_value", "time") and do_retype:
                df = df.with_columns(pl.lit(None, dtype=dtype).alias(col))
            else:
                errors.append(f"MEDS Data DataFrame must have a '{col}' column of type {dtype}.")

    if errors:
        raise ValueError("\n".join(errors))

    additional_cols = [col for col in schema if col not in MEDS_DATA_MANDATORY_TYPES]

    if additional_cols:
        extra_schema = df.head(0).select(additional_cols).collect().to_arrow().schema
        measurement_properties = list(zip(extra_schema.names, extra_schema.types))
        df = df.select(*MEDS_DATA_MANDATORY_TYPES.keys(), *additional_cols)
    else:
        df = df.select(*MEDS_DATA_MANDATORY_TYPES.keys())
        measurement_properties = []

    validated_schema = data_schema(measurement_properties)
    return df.collect().to_arrow().cast(validated_schema)
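
A minimal usage sketch of the factory above, assuming `data_schema` accepts a list of (name, dtype) tuples for the non-mandatory columns (matching the `data_schema(measurement_properties)` call in `get_and_validate_data_schema`) and returns a `pa.Schema`; the `site_id` column is hypothetical:

import pyarrow as pa

# Mandatory MEDS columns plus one optional and one hypothetical custom column.
# "text_value" matches an OPTIONAL_FIELDS entry, so its type must be pa.string().
schema = data_schema([("text_value", pa.string()), ("site_id", pa.int32())])

# pa.Table.from_pylist fills fields missing from a row (here: time, numeric_value, text_value) with nulls.
table = pa.Table.from_pylist(
    [{"subject_id": 1, "code": "MEDS_BIRTH", "site_id": 3}],
    schema=schema,
)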
@@ -0,0 +1,135 @@
# Postpone annotation evaluation so the fallback stub below can reference `pa.DataType` and
# `Callable` in annotations even when the imports in the `try` block fail.
from __future__ import annotations

from collections.abc import Callable
from typing import Any

try:
    import polars as pl

    # Compare numeric version components; a plain string comparison can misorder versions.
    if tuple(map(int, pl.__version__.split(".")[:2])) < (1, 0):
        raise ImportError("polars version must be >= 1.0.0 for these utilities")

    import pyarrow as pa

    DF_TYPES = (pl.DataFrame, pl.LazyFrame)

    def _convert_and_validate_schema(
        df: pl.DataFrame | pl.LazyFrame,
        mandatory_columns: dict[str, pa.DataType],
        optional_columns: dict[str, pa.DataType],
        do_allow_extra_columns: bool,
        do_cast_types: dict[str, bool],
        do_add_missing_mandatory_fields: bool,
        do_reorder_columns: bool,
        schema_validator: Callable[[list[tuple[str, pa.DataType]]], pa.Schema] | None = None,
    ) -> pa.Table:
""" | ||
This function converts a DataFrame to an Arrow Table and validates that it has the correct schema. | ||
Args: | ||
df: The polars DataFrame or LazyFrame to convert and validate. | ||
mandatory_columns: A dictionary of mandatory columns and their types. | ||
optional_columns: A dictionary of optional columns and their types. Optional columns need not be | ||
present in the DataFrame but, if they are present, they must have the correct type. | ||
do_allow_extra_columns: Whether to allow extra columns in the DataFrame. | ||
do_cast_types: Whether it is permissible to cast individual columns to the correct types. This | ||
parameter must be specified as a dictionary mapping column name to whether it is permissible | ||
to cast that column. | ||
do_add_missing_mandatory_fields: Whether it is permissible to add missing mandatory fields to the | ||
DataFrame with null values. If `False`, any missing values will result in an error. | ||
do_reorder_columns: Whether it is permissible | ||
schema_validator: A function that takes a list of tuples of all additional (beyond the mandatory) | ||
column names and types and returns a PyArrow Schema object for the table, if a valid schema | ||
exists with the passed columns. | ||
""" | ||
|
||
        # If it is not a lazyframe, make it one.
        df = df.lazy()

        schema = df.collect_schema()

        typed_pa_df = pa.Table.from_pylist(
            [], schema=pa.schema(list(mandatory_columns.items()) + list(optional_columns.items()))
        )
        target_pl_schema = pl.from_arrow(typed_pa_df).collect_schema()

        errors = []
        for col in mandatory_columns:
            target_dtype = target_pl_schema[col]
            if col in schema:
                if target_dtype != schema[col]:
                    if do_cast_types[col]:
                        df = df.with_columns(pl.col(col).cast(target_dtype))
                    else:
                        errors.append(
                            f"Column '{col}' must be of type {target_dtype}. Got {schema[col]} instead."
                        )
            elif do_add_missing_mandatory_fields:
                df = df.with_columns(pl.lit(None, dtype=target_dtype).alias(col))
            else:
                errors.append(f"Missing mandatory column '{col}' of type {target_dtype}.")

        for col in optional_columns:
            if col not in schema:
                continue

            target_dtype = target_pl_schema[col]
            if target_dtype != schema[col]:
                if do_cast_types[col]:
                    df = df.with_columns(pl.col(col).cast(target_dtype))
                else:
                    errors.append(
                        f"Optional column '{col}' must be of type {target_dtype} if included. "
                        f"Got {schema[col]} instead."
                    )

        type_specific_columns = set(mandatory_columns.keys()) | set(optional_columns.keys())
        additional_cols = [col for col in schema if col not in type_specific_columns]

        if additional_cols and not do_allow_extra_columns:
            errors.append(f"Found unexpected columns: {additional_cols}")

        if errors:
            raise ValueError("\n".join(errors))

        default_pa_schema = df.head(0).collect().to_arrow().schema

        optional_properties = []
        for col in schema:
            if col in mandatory_columns:
                continue

            if col in optional_columns:
                optional_properties.append((col, optional_columns[col]))
            else:
                optional_properties.append((col, default_pa_schema[col]))

        if schema_validator is not None:
            validated_schema = schema_validator(optional_properties)
        else:
            validated_schema = pa.schema(list(mandatory_columns.items()) + optional_properties)

        schema_order = validated_schema.names

        extra_cols = set(df.columns) - set(schema_order)
        if extra_cols:
            raise ValueError(f"Found unexpected columns: {extra_cols}")

        if schema_order != df.columns:
            if do_reorder_columns:
                df = df.select(schema_order)
            else:
                raise ValueError(f"Column order must be {schema_order}. Got {df.columns} instead.")

        return df.collect().to_arrow().cast(validated_schema)

except ImportError:
    DF_TYPES = tuple()

    def _convert_and_validate_schema(
        df: Any,
        mandatory_columns: dict[str, pa.DataType],
        optional_columns: dict[str, pa.DataType],
        do_allow_extra_columns: bool,
        do_cast_types: dict[str, bool],
        do_add_missing_mandatory_fields: bool,
        do_reorder_columns: bool,
        schema_validator: Callable[[list[tuple[str, pa.DataType]]], pa.Schema] | None = None,
    ) -> pa.Table:
        # Safe even though `pa` may be unbound here: with `from __future__ import annotations`,
        # the annotations above are never evaluated at definition time.
        raise NotImplementedError("polars is not installed")
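
A hedged example of the happy path through `_convert_and_validate_schema`; the column names and the per-column cast policy are illustrative, not part of this commit:

import polars as pl
import pyarrow as pa

df = pl.DataFrame({"subject_id": [1, 2], "code": ["A", "B"], "text_value": [None, "note"]})

# All mandatory columns are present with matching types, so no casting or null-filling occurs;
# the optional "text_value" column is kept with the type given in optional_columns.
table = _convert_and_validate_schema(
    df,
    mandatory_columns={"subject_id": pa.int64(), "code": pa.string()},
    optional_columns={"text_value": pa.string()},
    do_allow_extra_columns=False,
    do_cast_types={"subject_id": True, "code": True, "text_value": True},
    do_add_missing_mandatory_fields=True,
    do_reorder_columns=True,
)
# Resulting column order follows the validated schema: subject_id, code, text_value.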
@@ -0,0 +1,16 @@
"""Shared constants for the MEDS schema.""" | ||
|
||
# Field names and types that are shared across the MEDS schema. | ||
subject_id_field = "subject_id" | ||
time_field = "time" | ||
code_field = "code" | ||
numeric_value_field = "numeric_value" | ||
|
||
subject_id_dtype = pa.int64() | ||
time_dtype = pa.timestamp("us") | ||
code_dtype = pa.string() | ||
numeric_value_dtype = pa.float32() | ||
|
||
# Canonical codes for select events. | ||
birth_code = "MEDS_BIRTH" | ||
death_code = "MEDS_DEATH" |
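
A short sketch showing how these shared constants compose into the core Arrow schema and a canonical birth event; the subject id and date are made up:

import datetime

import pyarrow as pa

# The four core MEDS columns, assembled from the shared names and dtypes above.
core_schema = pa.schema(
    [
        (subject_id_field, subject_id_dtype),
        (time_field, time_dtype),
        (code_field, code_dtype),
        (numeric_value_field, numeric_value_dtype),
    ]
)

# A birth event; numeric_value is omitted and therefore null.
row = {subject_id_field: 1, time_field: datetime.datetime(1990, 1, 1), code_field: birth_code}
table = pa.Table.from_pylist([row], schema=core_schema)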