started some over-complicated and not yet fully validated code for providing schema validators and more consistent handling of optional types.
mmcdermott committed Aug 19, 2024
1 parent d4be859 commit 4e92bf6
Showing 6 changed files with 370 additions and 9 deletions.
7 changes: 5 additions & 2 deletions src/meds/__init__.py
@@ -1,15 +1,18 @@
from meds._version import __version__ # noqa

from .shared_constants import (
subject_id_field, time_field, code_field, numeric_value_field, subject_id_dtype, time_dtype, code_dtype,
numeric_value_dtype, birth_code, death_code
)

from .schema import (
CodeMetadata,
DatasetMetadata,
Label,
birth_code,
code_field,
code_metadata_schema,
data_schema,
dataset_metadata_schema,
death_code,
held_out_split,
label_schema,
subject_id_dtype,
89 changes: 89 additions & 0 deletions src/meds/data_schema.py
@@ -0,0 +1,89 @@
"""The data schema for the MEDS format.
Please see the README for more information, including expected file organization on disk, more details on what
this schema should capture, etc.
The data schema.
MEDS data also must satisfy two important properties:
1. Data about a single subject cannot be split across parquet files.
If a subject is in a dataset it must be in one and only one parquet file.
2. Data about a single subject must be contiguous within a particular parquet file and sorted by time.
Both of these restrictions allow the stream rolling processing
(https://docs.pola.rs/api/python/stable/reference/dataframe/api/polars.DataFrame.rolling.html)
which vastly simplifies many data analysis pipelines.
No python type is provided because Python tools for processing MEDS data will often provide their own types.
See https://github.com/EthanSteinberg/meds_reader/blob/0.0.6/src/meds_reader/__init__.pyi#L55 for example.
"""

from typing import Callable

import pyarrow as pa

from .shared_constants import (
    subject_id_field, time_field, code_field, subject_id_dtype, time_dtype, code_dtype,
    numeric_value_field, numeric_value_dtype
)

from .utils import CustomizableSchemaFntr

MANDATORY_FIELDS = {
subject_id_field: subject_id_dtype,
time_field: time_dtype,
code_field: code_dtype,
numeric_value_field: numeric_value_dtype,
}

OPTIONAL_FIELDS = {
"categorical_value": pa.string(),
"text_value": pa.string(),
}

# This returns a function that will create a data schema containing the mandatory fields plus any custom
# fields you specify. If a custom field's name matches one of the optional fields above, its type must
# match that optional field's type.
data_schema = CustomizableSchemaFntr(MANDATORY_FIELDS, OPTIONAL_FIELDS)
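# A hedged usage sketch: per its use later in this file, the resulting callable accepts a list of
# (name, pyarrow type) tuples for any custom, non-mandatory fields (the field name below is illustrative):
#
#     schema = data_schema([("discharge_location", pa.string())])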


def convert_and_validate_schema_fntr(
    do_cast_types: bool | dict[str, bool] = True,
    do_add_missing_fields: bool = True,
) -> Callable[[pa.Table, pa.Schema], pa.Table]:
    """Returns a converter that validates a dataframe against a target schema.

    NOTE: this helper is an unfinished sketch: `do_cast_types` and `do_add_missing_fields` are not yet
    wired into the returned converter, and only the `pa.Table` input path is handled.
    """

    def converter(df, schema: pa.Schema) -> pa.Table:
        if isinstance(df, pa.Table):
            # A pyarrow Table can be validated by casting it against the target schema directly.
            return df.cast(schema)
        # Other dataframe types (e.g., polars frames) are handled in polars_support.py.
        raise NotImplementedError(f"Unsupported dataframe type: {type(df)}")

    return converter

def get_and_validate_data_schema(df: pl.LazyFrame, stage_cfg: DictConfig) -> pa.Table:
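    """Validates a polars LazyFrame against the MEDS data schema and returns it as a pyarrow Table.

    If `stage_cfg.do_retype` is True (the default), columns with mismatched types are cast
    (non-strictly) and missing `time`/`numeric_value` columns are added as nulls; any other discrepancy
    in the mandatory columns raises a `ValueError`.

    NOTE: as written, this helper additionally assumes that polars (`pl`), omegaconf's `DictConfig`, and
    a `MEDS_DATA_MANDATORY_TYPES` mapping of mandatory column names to *polars* dtypes are all available;
    none of them are imported or defined in this module.
    """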
do_retype = stage_cfg.get("do_retype", True)
schema = df.collect_schema()
errors = []
for col, dtype in MEDS_DATA_MANDATORY_TYPES.items():
if col in schema and schema[col] != dtype:
if do_retype:
df = df.with_columns(pl.col(col).cast(dtype, strict=False))
else:
errors.append(f"MEDS Data '{col}' column must be of type {dtype}. Got {schema[col]}.")
elif col not in schema:
if col in ("numeric_value", "time") and do_retype:
df = df.with_columns(pl.lit(None, dtype=dtype).alias(col))
else:
errors.append(f"MEDS Data DataFrame must have a '{col}' column of type {dtype}.")

if errors:
raise ValueError("\n".join(errors))

additional_cols = [col for col in schema if col not in MEDS_DATA_MANDATORY_TYPES]

if additional_cols:
extra_schema = df.head(0).select(additional_cols).collect().to_arrow().schema
measurement_properties = list(zip(extra_schema.names, extra_schema.types))
df = df.select(*MEDS_DATA_MANDATORY_TYPES.keys(), *additional_cols)
else:
df = df.select(*MEDS_DATA_MANDATORY_TYPES.keys())
measurement_properties = []

validated_schema = data_schema(measurement_properties)
return df.collect().to_arrow().cast(validated_schema)
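# A hedged usage sketch of `get_and_validate_data_schema` (assumes the polars/omegaconf/
# `MEDS_DATA_MANDATORY_TYPES` prerequisites noted in the docstring above; literal values are illustrative):
#
#     from datetime import datetime
#
#     import polars as pl
#     from omegaconf import DictConfig
#
#     lf = pl.LazyFrame({
#         "subject_id": [1, 1],
#         "time": [None, datetime(2024, 1, 1)],
#         "code": ["MEDS_BIRTH", "LAB//HR"],
#         "numeric_value": [None, 72.0],
#     })
#     tbl = get_and_validate_data_schema(lf, DictConfig({"do_retype": True}))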
135 changes: 135 additions & 0 deletions src/meds/polars_support.py
@@ -0,0 +1,135 @@
from typing import Any, Callable

import pyarrow as pa

try:
    import polars as pl

    if pl.__version__ < "1.0.0":
        raise ImportError("polars version must be >= 1.0.0 for these utilities")

    DF_TYPES = (pl.DataFrame, pl.LazyFrame)

    def _convert_and_validate_schema(
        df: pl.DataFrame | pl.LazyFrame,
mandatory_columns: dict[str, pa.DataType],
optional_columns: dict[str, pa.DataType],
do_allow_extra_columns: bool,
do_cast_types: dict[str, bool],
do_add_missing_mandatory_fields: bool,
do_reorder_columns: bool,
schema_validator: Callable[[list[tuple[str, pa.DataType]]], pa.Schema] | None = None,
) -> pa.Table:
"""
This function converts a DataFrame to an Arrow Table and validates that it has the correct schema.
Args:
df: The polars DataFrame or LazyFrame to convert and validate.
mandatory_columns: A dictionary of mandatory columns and their types.
optional_columns: A dictionary of optional columns and their types. Optional columns need not be
present in the DataFrame but, if they are present, they must have the correct type.
do_allow_extra_columns: Whether to allow extra columns in the DataFrame.
do_cast_types: Whether it is permissible to cast individual columns to the correct types. This
parameter must be specified as a dictionary mapping column name to whether it is permissible
to cast that column.
do_add_missing_mandatory_fields: Whether it is permissible to add missing mandatory fields to the
DataFrame with null values. If `False`, any missing values will result in an error.
            do_reorder_columns: Whether it is permissible to reorder the DataFrame's columns to match the
                target schema's column order.
schema_validator: A function that takes a list of tuples of all additional (beyond the mandatory)
column names and types and returns a PyArrow Schema object for the table, if a valid schema
exists with the passed columns.
"""

# If it is not a lazyframe, make it one.
df = df.lazy()

schema = df.collect_schema()

typed_pa_df = pa.Table.from_pylist(
[], schema=pa.schema(list(mandatory_columns.items()) + list(optional_columns.items()))
)
target_pl_schema = pl.from_arrow(typed_pa_df).collect_schema()

errors = []
for col in mandatory_columns:
target_dtype = target_pl_schema[col]
if col in schema:
if target_dtype != schema[col]:
if do_cast_types[col]:
df = df.with_columns(pl.col(col).cast(target_dtype))
else:
errors.append(
f"Column '{col}' must be of type {target_dtype}. Got {schema[col]} instead."
)
elif do_add_missing_mandatory_fields:
df = df.with_columns(pl.lit(None, dtype=target_dtype).alias(col))
else:
errors.append(f"Missing mandatory column '{col}' of type {target_dtype}.")

for col in optional_columns:
if col not in schema:
continue

target_dtype = target_pl_schema[col]
if target_dtype != schema[col]:
if do_cast_types[col]:
df = df.with_columns(pl.col(col).cast(target_dtype))
else:
errors.append(
f"Optional column '{col}' must be of type {target_dtype} if included. "
f"Got {schema[col]} instead."
)

type_specific_columns = set(mandatory_columns.keys()) | set(optional_columns.keys())
additional_cols = [col for col in schema if col not in type_specific_columns]

if additional_cols and not do_allow_extra_columns:
errors.append(f"Found unexpected columns: {additional_cols}")

if errors:
raise ValueError("\n".join(errors))

default_pa_schema = df.head(0).collect().to_arrow().schema

optional_properties = []
for col in schema:
if col in mandatory_columns:
continue

if col in optional_columns:
optional_properties.append((col, optional_columns[col]))
else:
                optional_properties.append((col, default_pa_schema.field(col).type))

if schema_validator is not None:
validated_schema = schema_validator(optional_properties)
else:
validated_schema = pa.schema(list(mandatory_columns.items()) + optional_properties)

schema_order = validated_schema.names

extra_cols = set(df.columns) - set(schema_order)
if extra_cols:
raise ValueError(f"Found unexpected columns: {extra_cols}")

if schema_order != df.columns:
if do_reorder_columns:
df = df.select(schema_order)
else:
raise ValueError(f"Column order must be {schema_order}. Got {df.columns} instead.")

return df.collect().to_arrow().cast(validated_schema)


except ImportError:
DF_TYPES = tuple()

def _convert_and_validate_schema(
df: Any,
mandatory_columns: dict[str, pa.DataType],
optional_columns: dict[str, pa.DataType],
do_allow_extra_columns: bool,
do_cast_types: dict[str, bool],
do_add_missing_mandatory_fields: bool,
do_reorder_columns: bool,
schema_validator: Callable[[list[tuple[str, pa.DataType]]], pa.Schema] | None = None,
) -> pa.Table:
raise NotImplementedError("polars is not installed")
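# A hedged usage sketch of the polars-backed validator (columns, types, and flags below are illustrative
# only; assumes polars is installed so the `try` branch above is active):
#
#     import polars as pl
#     import pyarrow as pa
#
#     table = _convert_and_validate_schema(
#         pl.DataFrame({"subject_id": [1], "code": ["MEDS_BIRTH"]}),
#         mandatory_columns={"subject_id": pa.int64(), "code": pa.string()},
#         optional_columns={"numeric_value": pa.float32()},
#         do_allow_extra_columns=True,
#         do_cast_types={"subject_id": True, "code": True},
#         do_add_missing_mandatory_fields=True,
#         do_reorder_columns=True,
#     )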
7 changes: 0 additions & 7 deletions src/meds/schema.py
@@ -23,14 +23,7 @@
# which vastly simplifies many data analysis pipelines.

# We define some codes for particularly important events
birth_code = "MEDS_BIRTH"
death_code = "MEDS_DEATH"

subject_id_field = "subject_id"
time_field = "time"
code_field = "code"

subject_id_dtype = pa.int64()


def data_schema(custom_properties=[]):
16 changes: 16 additions & 0 deletions src/meds/shared_constants.py
@@ -0,0 +1,16 @@
"""Shared constants for the MEDS schema."""

# Field names and types that are shared across the MEDS schema.
subject_id_field = "subject_id"
time_field = "time"
code_field = "code"
numeric_value_field = "numeric_value"

subject_id_dtype = pa.int64()
time_dtype = pa.timestamp("us")
code_dtype = pa.string()
numeric_value_dtype = pa.float32()

# Canonical codes for select events.
birth_code = "MEDS_BIRTH"
death_code = "MEDS_DEATH"
