Skip to content

Commit

Permalink
Merged PR 172: Give both device and patient information for all windo…
Browse files Browse the repository at this point in the history
…ws. And enforce windows not sharing two devices or two patients.

This branch introduces a "device_patient_tuples" source type to the validated sources data structure for the iterator.

It involves reading the device_patient table and looking for overlaps between it and the requested data from the definition.

It also now brings that raw interval information into each batch request ensuring that get_data is only called within the regions that are mapped one-to-one with devices and patients. Areas where we have data but it would belong to a different patient within a window are replaced with NaN values.
  • Loading branch information
William Dixon authored and bgreer101 committed Oct 20, 2023
1 parent 9340c2c commit 92ae2fa
Show file tree
Hide file tree
Showing 8 changed files with 203 additions and 17 deletions.
28 changes: 28 additions & 0 deletions sdk/atriumdb/intervals/difference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
def list_difference(first, second):
result = []
i, j = 0, 0

while i < len(first) and j < len(second):
# Check for non-overlapping intervals and add to result
if first[i][1] <= second[j][0]:
result.append(first[i])
i += 1
continue
if second[j][1] <= first[i][0]:
j += 1
continue

# Find overlapping intervals and update first list
if first[i][0] < second[j][0]:
result.append([first[i][0], second[j][0]])
if first[i][1] <= second[j][1]:
i += 1
else:
first[i][0] = second[j][1]
j += 1

while i < len(first):
result.append(first[i])
i += 1

return result
2 changes: 1 addition & 1 deletion sdk/atriumdb/intervals/union.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def intervals_union(a, b):


def intervals_union_list(interval_list):
interval_list = [interval for interval in interval_list if interval.size > 0]
interval_list = [interval for interval in interval_list if len(interval) > 0]
if len(interval_list) == 0:
return np.array([], dtype=np.int64)

Expand Down
39 changes: 28 additions & 11 deletions sdk/atriumdb/windowing/dataset_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ def _extract_batch_info(self):
for source_id, time_ranges in sources.items():

# For all time ranges in that source
for cur_window_start_time, range_end_time in time_ranges:
for range_start_time, range_end_time in time_ranges:
cur_window_start_time = range_start_time
true_range_end_time = range_end_time

# Adjust the range_end_time to ensure it's aligned to the window_slide_ns
total_range_duration = max(range_end_time - cur_window_start_time, self.window_duration_ns)
Expand All @@ -169,10 +171,14 @@ def _extract_batch_info(self):
# If the batch reaches its maximum size, finalize this batch and start a new one
if self.max_batch_size and (current_index - batch_index_start) >= self.max_batch_size:
batch_info.append(
[source_type,
source_id,
batch_start_time,
cur_window_start_time + self.window_duration_ns])
[
source_type,
source_id,
batch_start_time,
cur_window_start_time + self.window_duration_ns,
range_start_time,
true_range_end_time,
])

batch_first_index.append(batch_index_start)

Expand All @@ -189,10 +195,14 @@ def _extract_batch_info(self):
# create a batch for them.
if (current_index - batch_index_start) > 0:
batch_info.append(
[source_type,
source_id,
batch_start_time,
cur_window_start_time + self.window_duration_ns])
[
source_type,
source_id,
batch_start_time,
cur_window_start_time + self.window_duration_ns,
range_start_time,
true_range_end_time,
])
batch_first_index.append(batch_index_start)

# Append the final window count to the batch_first_index list for future batch size math.
Expand All @@ -213,7 +223,8 @@ def _load_batch_matrix(self, idx: int):
batch_size = self.row_size + (batch_num_windows - 1) * self.slide_size

# Get the matrix
source_type, source_id, batch_start_time, batch_end_time = self.batch_info[batch_index]
source_type, source_id, batch_start_time, batch_end_time, range_start_time, range_end_time = \
self.batch_info[batch_index]

batch_matrix = np.full((len(self.measures), batch_size), np.nan)

Expand All @@ -230,6 +241,10 @@ def _load_batch_matrix(self, idx: int):
patient_id = source_id
self.current_patient_id = source_id
self.current_device_id = None
elif source_type == "device_patient_tuples":
device_id, patient_id = source_id
self.current_patient_id = patient_id
self.current_device_id = device_id
else:
raise ValueError(f"Source type must be either device_ids or patient_ids, not {source_type}")

Expand All @@ -249,8 +264,10 @@ def _load_batch_matrix(self, idx: int):
measure_filled_value_array = np.full(measure_filled_time_array.shape, np.nan)

# Fetch data for this measure and window from the SDK
data_start_time = max(range_start_time, batch_start_time)
data_end_time = min(range_end_time, batch_end_time)
_, measure_sdk_times, measure_sdk_values = self.sdk.get_data(
measure_id, batch_start_time, batch_end_time, device_id=device_id, patient_id=patient_id)
measure_id, data_start_time, data_end_time, device_id=device_id, patient_id=patient_id)

# Batch Matrix
# Convert times to indices on the matrix using vectorized operations
Expand Down
1 change: 1 addition & 0 deletions sdk/atriumdb/windowing/definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,7 @@ def add_measure(self, measure_tag, freq=None, units=None):
:param units: Units for the measure.
Only required if freq is provided.
:type units: str, optional
:raises ValueError: Raised when the measure tag is already present or when only one of freq and units is provided.
**Examples**:
Expand Down
69 changes: 69 additions & 0 deletions sdk/atriumdb/windowing/map_definition_sources.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
from typing import List, Tuple

from atriumdb.intervals.difference import list_difference
from atriumdb.intervals.intersection import list_intersection
from atriumdb.intervals.union import intervals_union_list


def map_validated_sources(sources: dict, sdk) -> dict:
# Initialize the new sources dictionary with a new key "device_patient_tuples"
mapped_sources = {"device_patient_tuples": {}}

# Extract patient_ids and device_ids dictionaries from the sources dictionary
patient_ids = sources.get('patient_ids', {})
device_ids = sources.get('device_ids', {})

# Function to process ids (either patient_ids or device_ids) and update the mapped_sources dictionary
def process_ids(ids_dict, id_type):
for src_id, time_ranges in ids_dict.items():
union_ranges = []
for time_range in time_ranges:
start_time, end_time = time_range
# Fetch device_patient_data based on id_type
device_patient_data = sdk.get_device_patient_data(
patient_id_list=[src_id] if id_type == 'patient_ids' else None,
device_id_list=[src_id] if id_type == 'device_ids' else None,
start_time=start_time, end_time=end_time)
# Aggregate the time ranges based on the device and patient IDs
aggregated_ranges = aggregate_time_ranges(device_patient_data)
for (device_id, patient_id), ranges in aggregated_ranges.items():
intersected_ranges = list_intersection(ranges, [time_range])
if intersected_ranges:
key = (device_id, patient_id)
if key not in mapped_sources["device_patient_tuples"]:
mapped_sources["device_patient_tuples"][key] = intersected_ranges
else:
mapped_sources["device_patient_tuples"][key].extend(intersected_ranges)
# Update the union_ranges list for the current src_id
union_ranges.extend(intersected_ranges)

# Calculate the union of ranges and update the mapped_sources dictionary with differences for the current src_id
union_ranges = intervals_union_list(union_ranges).tolist()
for time_range in time_ranges:
difference_ranges = list_difference([time_range], union_ranges)
if difference_ranges:
if id_type not in mapped_sources:
mapped_sources[id_type] = {src_id: difference_ranges}
else:
mapped_sources[id_type][src_id] = difference_ranges

# Process patient_ids and device_ids separately
process_ids(patient_ids, 'patient_ids')
process_ids(device_ids, 'device_ids')

return mapped_sources


def aggregate_time_ranges(device_patient_data: List[Tuple[int, int, int, int]]):
result = {}
for device_id, patient_id, start_time, end_time in device_patient_data:
key = (device_id, patient_id)
if key not in result:
result[key] = []
result[key].append([start_time, end_time])

# Sort the time ranges for each unique (device_id, patient_id) pair
for key in result:
result[key].sort(key=lambda x: x[0])

return result
7 changes: 6 additions & 1 deletion sdk/atriumdb/windowing/verify_definition.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,14 @@
# along with this program. If not, see <https://www.gnu.org/licenses/>.

import warnings
from typing import List, Tuple, Dict

import yaml
import numpy as np

from atriumdb.intervals.union import intervals_union_list
from atriumdb.windowing.definition import DatasetDefinition
from atriumdb.windowing.map_definition_sources import map_validated_sources


def verify_definition(definition, sdk, gap_tolerance=None):
Expand All @@ -39,7 +42,9 @@ def verify_definition(definition, sdk, gap_tolerance=None):
# Validate sources
validated_sources = _validate_sources(definition, sdk, validated_measure_list, gap_tolerance=gap_tolerance)

return validated_measure_list, validated_sources
mapped_sources = map_validated_sources(validated_sources, sdk)

return validated_measure_list, mapped_sources


def _validate_measures(definition: DatasetDefinition, sdk):
Expand Down
55 changes: 51 additions & 4 deletions sdk/tests/test_iterator.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,16 +32,56 @@ def _test_iterator(db_type, dataset_location, connection_params):
sdk = AtriumSDK.create_dataset(
dataset_location=dataset_location, database_type=db_type, connection_params=connection_params)

# Check for the case of partial windows
partial_freq_nano = 1_000_000_000
partial_period_nano = (10 ** 18) // partial_freq_nano
partial_device_id = sdk.insert_device(device_tag="partial_device")
partial_measure_id = sdk.insert_measure(measure_tag="partial_measure", freq=partial_freq_nano, units="mV")

start_time = 1_000_000_000
num_values = 100
end_time = start_time + (num_values * partial_period_nano)
times = np.arange(start_time, end_time, partial_period_nano)
values = (np.sin(times) * 1000).astype(np.int64)
scale_m = 1 / 1000
scale_b = 0

sdk.write_data_easy(
partial_measure_id, partial_device_id, times, values, partial_freq_nano, scale_m=scale_m, scale_b=scale_b)

# Add a patient
patient_id = sdk.sql_handler.insert_patient()

# Only map half the data
half_time = int(times[(num_values // 2) + 1])
sdk.insert_device_patient_data([(partial_device_id, patient_id, start_time, half_time)])

# get definition
definition = DatasetDefinition(measures=["partial_measure"], device_ids={partial_device_id: "all"})

window_size_nano = partial_period_nano * 25
iterator = sdk.get_iterator(definition, window_size_nano, window_size_nano, num_windows_prefetch=None)

for window_i, window in enumerate(iterator):
for (measure_tag, measure_freq_nhz, measure_units), signal_dict in window.signals.items():
first_nan_idx = get_index_of_first_nan(signal_dict['values'])
first_nan_time = int(signal_dict['times'][first_nan_idx])
if first_nan_time - partial_period_nano < half_time:
assert window.patient_id == patient_id
else:
assert window.patient_id is None

# larger test
write_mit_bih_to_dataset(sdk, max_records=2, seed=42)
# Uncomment line below to recreate test files
# create_test_definition_files(sdk)

test_parameters = [
# definition, expected_device_id_type, expected_patient_id_type
(DatasetDefinition(filename="./example_data/mitbih_seed_42_all_devices.yaml"), int, type(None)),
(DatasetDefinition(filename="./example_data/mitbih_seed_42_all_patients.yaml"), type(None), int),
(DatasetDefinition(filename="./example_data/mitbih_seed_42_all_mrns.yaml"), type(None), int),
(DatasetDefinition(filename="./example_data/mitbih_seed_42_all_tags.yaml"), int, type(None)),
(DatasetDefinition(filename="./example_data/mitbih_seed_42_all_devices.yaml"), int, int),
(DatasetDefinition(filename="./example_data/mitbih_seed_42_all_patients.yaml"), int, int),
(DatasetDefinition(filename="./example_data/mitbih_seed_42_all_mrns.yaml"), int, int),
(DatasetDefinition(filename="./example_data/mitbih_seed_42_all_tags.yaml"), int, int),
]

window_size_nano = 1_024 * 1_000_000_000
Expand Down Expand Up @@ -94,3 +134,10 @@ def create_test_definition_files(sdk):
definition = DatasetDefinition(measures=measures, device_tags=device_tags)

definition.save("./example_data/mitbih_seed_42_all_tags.yaml", force=True)


def get_index_of_first_nan(arr):
nan_index = np.argmax(np.isnan(arr))
if nan_index == 0 and not np.isnan(arr[0]):
return len(arr) - 1
return nan_index
19 changes: 19 additions & 0 deletions sdk/tests/test_list_difference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import pytest

from atriumdb.intervals.difference import list_difference


def test_list_difference():
assert list_difference([[0, 2], [3, 8], [10, 20]], [[2, 5], [6, 12], [20, 22]]) == [[0, 2], [5, 6], [12, 20]]
assert list_difference([[0, 5]], [[2, 4]]) == [[0, 2], [4, 5]]
assert list_difference([[1, 3], [4, 6], [7, 9]], [[2, 5], [8, 10]]) == [[1, 2], [5, 6], [7, 8]]
assert list_difference([[0, 2]], [[3, 5]]) == [[0, 2]]
assert list_difference([[1, 2]], [[2, 3]]) == [[1, 2]]
assert list_difference([[0, 10]], [[1, 9]]) == [[0, 1], [9, 10]]
assert list_difference([], [[0, 1], [2, 3]]) == []
assert list_difference([[0, 1], [2, 3]], []) == [[0, 1], [2, 3]]
assert list_difference([[1, 2]], [[1, 2]]) == []
assert list_difference([[0, 1]], [[0, 1], [2, 3]]) == []
assert list_difference([[1, 2]], [[0, 1]]) == [[1, 2]]
assert list_difference([[0, 5]], [[1, 2], [3, 4]]) == [[0, 1], [2, 3], [4, 5]]
assert list_difference([[0, 1], [3, 4], [6, 7]], [[1, 3], [4, 6]]) == [[0, 1], [3, 4], [6, 7]]

0 comments on commit 92ae2fa

Please sign in to comment.