diff --git a/src/meds/schema.py b/src/meds/schema.py index 1c9608c..3689941 100644 --- a/src/meds/schema.py +++ b/src/meds/schema.py @@ -1,5 +1,5 @@ import datetime -from typing import Any, List, Mapping +from typing import Any, List, Mapping, Optional import pyarrow as pa from typing_extensions import NotRequired, TypedDict @@ -73,12 +73,22 @@ def patient_schema(per_event_metadata_schema=pa.null()): ("patient_id", pa.int64()), ("prediction_time", pa.timestamp("us")), ("boolean_value", pa.bool_()), + ("integer_value", pa.int64()), + ("float_value", pa.float64()), + ("categorical_value", pa.string()), ] ) # Python types for the above schema -Label = TypedDict("Label", {"patient_id": int, "prediction_time": datetime.datetime, "boolean_value": bool}) +Label = TypedDict("Label", { + "patient_id": int, + "prediction_time": datetime.datetime, + "boolean_value": Optional[bool], + "integer_value" : Optional[int], + "float_value" : Optional[float], + "categorical_value" : Optional[str], +}) ############################################################ diff --git a/tests/test_schema.py b/tests/test_schema.py index 2b9559b..ce27f9b 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -50,7 +50,36 @@ def test_label_schema(): "boolean_value": True } ] + label_table = pa.Table.from_pylist(label_data, schema=label) + assert label_table.schema.equals(label), "Label schema does not match" + label_data = [ + { + "patient_id": 123, + "prediction_time": datetime.datetime(2020, 1, 1, 12, 0, 0), + "integer_value": 4 + } + ] + label_table = pa.Table.from_pylist(label_data, schema=label) + assert label_table.schema.equals(label), "Label schema does not match" + + label_data = [ + { + "patient_id": 123, + "prediction_time": datetime.datetime(2020, 1, 1, 12, 0, 0), + "float_value": 0.4 + } + ] + label_table = pa.Table.from_pylist(label_data, schema=label) + assert label_table.schema.equals(label), "Label schema does not match" + + label_data = [ + { + "patient_id": 123, + "prediction_time": datetime.datetime(2020, 1, 1, 12, 0, 0), + "categorical_value": "text" + } + ] label_table = pa.Table.from_pylist(label_data, schema=label) assert label_table.schema.equals(label), "Label schema does not match"