diff --git a/src/meds/__init__.py b/src/meds/__init__.py index 9cde9d7..a96083c 100644 --- a/src/meds/__init__.py +++ b/src/meds/__init__.py @@ -5,6 +5,7 @@ DatasetMetadata, Label, birth_code, + code_dtype, code_field, code_metadata_schema, data_schema, @@ -12,9 +13,12 @@ death_code, held_out_split, label_schema, + numeric_value_dtype, + numeric_value_field, subject_id_dtype, subject_id_field, subject_split_schema, + time_dtype, time_field, train_split, tuning_split, @@ -35,6 +39,10 @@ "DatasetMetadata": DatasetMetadata, "birth_code": birth_code, "death_code": death_code, + "numeric_value_dtype": numeric_value_dtype, + "numeric_value_field": numeric_value_field, + "time_dtype": time_dtype, + "code_dtype": code_dtype, "subject_id_field": subject_id_field, "time_field": time_field, "code_field": code_field, diff --git a/src/meds/schema.py b/src/meds/schema.py index 6fdafee..9402076 100644 --- a/src/meds/schema.py +++ b/src/meds/schema.py @@ -29,17 +29,21 @@ subject_id_field = "subject_id" time_field = "time" code_field = "code" +numeric_value_field = "numeric_value" subject_id_dtype = pa.int64() +time_dtype = pa.timestamp("us") +code_dtype = pa.string() +numeric_value_dtype = pa.float32() def data_schema(custom_properties=[]): return pa.schema( [ (subject_id_field, subject_id_dtype), - (time_field, pa.timestamp("us")), # Static events will have a null timestamp - (code_field, pa.string()), - ("numeric_value", pa.float32()), + (time_field, time_dtype), # Static events will have a null timestamp + (code_field, code_dtype), + (numeric_value_field, numeric_value_dtype), ] + custom_properties )