Skip to content

Commit

Permalink
Merge pull request #18 from Miking98/feat/label
Browse files Browse the repository at this point in the history
Add non-bool Label value options
  • Loading branch information
EthanSteinberg authored Apr 3, 2024
2 parents bc0de46 + cf7ffaf commit b5a62c6
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
14 changes: 12 additions & 2 deletions src/meds/schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import datetime
from typing import Any, List, Mapping
from typing import Any, List, Mapping, Optional

import pyarrow as pa
from typing_extensions import NotRequired, TypedDict
Expand Down Expand Up @@ -73,12 +73,22 @@ def patient_schema(per_event_metadata_schema=pa.null()):
("patient_id", pa.int64()),
("prediction_time", pa.timestamp("us")),
("boolean_value", pa.bool_()),
("integer_value", pa.int64()),
("float_value", pa.float64()),
("categorical_value", pa.string()),
]
)

# Python types for the above schema

Label = TypedDict("Label", {"patient_id": int, "prediction_time": datetime.datetime, "boolean_value": bool})
Label = TypedDict("Label", {
"patient_id": int,
"prediction_time": datetime.datetime,
"boolean_value": Optional[bool],
"integer_value" : Optional[int],
"float_value" : Optional[float],
"categorical_value" : Optional[str],
})

############################################################

Expand Down
29 changes: 29 additions & 0 deletions tests/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,36 @@ def test_label_schema():
"boolean_value": True
}
]
label_table = pa.Table.from_pylist(label_data, schema=label)
assert label_table.schema.equals(label), "Label schema does not match"

label_data = [
{
"patient_id": 123,
"prediction_time": datetime.datetime(2020, 1, 1, 12, 0, 0),
"integer_value": 4
}
]
label_table = pa.Table.from_pylist(label_data, schema=label)
assert label_table.schema.equals(label), "Label schema does not match"

label_data = [
{
"patient_id": 123,
"prediction_time": datetime.datetime(2020, 1, 1, 12, 0, 0),
"float_value": 0.4
}
]
label_table = pa.Table.from_pylist(label_data, schema=label)
assert label_table.schema.equals(label), "Label schema does not match"

label_data = [
{
"patient_id": 123,
"prediction_time": datetime.datetime(2020, 1, 1, 12, 0, 0),
"categorical_value": "text"
}
]
label_table = pa.Table.from_pylist(label_data, schema=label)
assert label_table.schema.equals(label), "Label schema does not match"

Expand Down

0 comments on commit b5a62c6

Please sign in to comment.