diff --git a/src/meds/__init__.py b/src/meds/__init__.py index 8853647..3d8d36c 100644 --- a/src/meds/__init__.py +++ b/src/meds/__init__.py @@ -1,22 +1,22 @@ from meds._version import __version__ # noqa from .schema import ( - data, label, Label, train_split, tuning_split, held_out_split, patient_split, code_metadata, - dataset_metadata, CodeMetadata, DatasetMetadata, birth_code, death_code + data_schema, label_schema, Label, train_split, tuning_split, held_out_split, patient_split_schema, + code_metadata_schema, dataset_metadata_schema, CodeMetadata, DatasetMetadata, birth_code, death_code ) # List all objects that we want to export _exported_objects = { - 'data': data, - 'label': label, + 'data_schema': data_schema, + 'label_schema': label_schema, 'Label': Label, 'train_split': train_split, 'tuning_split': tuning_split, 'held_out_split': held_out_split, - 'patient_split': patient_split, - 'code_metadata': code_metadata, - 'dataset_metadata': dataset_metadata, + 'patient_split_schema': patient_split_schema, + 'code_metadata_schema': code_metadata_schema, + 'dataset_metadata_schema': dataset_metadata_schema, 'CodeMetadata': CodeMetadata, 'DatasetMetadata': DatasetMetadata, 'birth_code': birth_code, diff --git a/src/meds/schema.py b/src/meds/schema.py index a7b67b4..5b263f4 100644 --- a/src/meds/schema.py +++ b/src/meds/schema.py @@ -26,7 +26,7 @@ birth_code = "MEDS_BIRTH" death_code = "MEDS_DEATH" -def data(custom_properties=[]): +def data_schema(custom_properties=[]): return pa.schema( [ ("patient_id", pa.int64()), @@ -45,7 +45,7 @@ def data(custom_properties=[]): # including the prediction time. Exclusive prediction times are not currently supported, but if you have a use # case for them please add a GitHub issue. -label = pa.schema( +label_schema = pa.schema( [ ("patient_id", pa.int64()), # The patient who is being labeled. @@ -83,7 +83,7 @@ def data(custom_properties=[]): tuning_split = "tuning" # For ML hyperparameter tuning. Also often called "validation" or "dev". held_out_split = "held_out" # For final ML evaluation. Also often called "test". -patient_split = pa.schema( +patient_split_schema = pa.schema( [ ("patient_id", pa.int64()), ("split", pa.string()), @@ -96,7 +96,7 @@ def data(custom_properties=[]): # This is a JSON schema. -dataset_metadata = { +dataset_metadata_schema = { "type": "object", "properties": { "dataset_name": {"type": "string"}, @@ -126,7 +126,7 @@ def data(custom_properties=[]): # The code metadata schema. # This is a parquet schema. -def code_metadata(custom_per_code_properties=[]): +def code_metadata_schema(custom_per_code_properties=[]): return pa.schema( [ ("code", pa.string()), diff --git a/tests/test_schema.py b/tests/test_schema.py index 168c4e9..b945909 100644 --- a/tests/test_schema.py +++ b/tests/test_schema.py @@ -5,7 +5,8 @@ import pytest from meds import ( - data, label, dataset_metadata, patient_split, code_metadata, train_split, tuning_split, held_out_split + data_schema, label_schema, dataset_metadata_schema, patient_split_schema, code_metadata_schema, + train_split, tuning_split, held_out_split ) def test_data_schema(): @@ -23,7 +24,7 @@ def test_data_schema(): } ] - schema = data([("text_value", pa.string())]) + schema = data_schema([("text_value", pa.string())]) table = pa.Table.from_pylist(raw_data, schema=schema) assert table.schema.equals(schema), "Patient schema does not match" @@ -41,7 +42,7 @@ def test_code_metadata_schema(): } ] - schema = code_metadata() + schema = code_metadata_schema() table = pa.Table.from_pylist(code_metadata, schema=schema) assert table.schema.equals(schema), "Code metadata schema does not match" @@ -58,8 +59,8 @@ def test_patient_split_schema(): {"patient_id": 123, "split": "special"}, ] - table = pa.Table.from_pylist(patient_split_data, schema=patient_split) - assert table.schema.equals(patient_split), "Patient split schema does not match" + table = pa.Table.from_pylist(patient_split_data, schema=patient_split_schema) + assert table.schema.equals(patient_split_schema), "Patient split schema does not match" def test_label_schema(): """ @@ -73,8 +74,8 @@ def test_label_schema(): "boolean_value": True } ] - label_table = pa.Table.from_pylist(label_data, schema=label) - assert label_table.schema.equals(label), "Label schema does not match" + label_table = pa.Table.from_pylist(label_data, schema=label_schema) + assert label_table.schema.equals(label_schema), "Label schema does not match" label_data = [ { @@ -83,8 +84,8 @@ def test_label_schema(): "integer_value": 4 } ] - label_table = pa.Table.from_pylist(label_data, schema=label) - assert label_table.schema.equals(label), "Label schema does not match" + label_table = pa.Table.from_pylist(label_data, schema=label_schema) + assert label_table.schema.equals(label_schema), "Label schema does not match" label_data = [ { @@ -93,8 +94,8 @@ def test_label_schema(): "float_value": 0.4 } ] - label_table = pa.Table.from_pylist(label_data, schema=label) - assert label_table.schema.equals(label), "Label schema does not match" + label_table = pa.Table.from_pylist(label_data, schema=label_schema) + assert label_table.schema.equals(label_schema), "Label schema does not match" label_data = [ { @@ -103,8 +104,8 @@ def test_label_schema(): "categorical_value": "text" } ] - label_table = pa.Table.from_pylist(label_data, schema=label) - assert label_table.schema.equals(label), "Label schema does not match" + label_table = pa.Table.from_pylist(label_data, schema=label_schema) + assert label_table.schema.equals(label_schema), "Label schema does not match" def test_dataset_metadata_schema(): """ @@ -117,5 +118,5 @@ def test_dataset_metadata_schema(): "etl_version": "1.0", } - jsonschema.validate(instance=metadata, schema=dataset_metadata) + jsonschema.validate(instance=metadata, schema=dataset_metadata_schema) assert True, "Dataset metadata schema validation failed"