-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merged PR 172: Give both device and patient information for all windows, and enforce that no window shares two devices or two patients.
This branch introduces a "device_patient_tuples" source type to the validated sources data structure for the iterator. It involves reading the device_patient table and looking for overlaps between it and the requested data from the definition. It also now brings that raw interval information into each batch request, ensuring that get_data is only called within the regions that are mapped one-to-one with devices and patients. Areas where we have data but it would belong to a different patient within a window are replaced with NaN values.
- Loading branch information
Showing
8 changed files
with
203 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
def list_difference(first, second):
    """Return the parts of the intervals in ``first`` not covered by ``second``.

    Both arguments are sequences of ``[start, end]`` pairs. The single-pass
    sweep assumes each sequence is sorted by start and non-overlapping within
    itself. Intervals that merely touch (``end == start``) are treated as
    non-overlapping.

    Neither input is modified: the partially consumed start of the current
    interval is tracked in a local variable instead of being written back
    into ``first``. Every returned interval is a new ``[start, end]`` list,
    so tuple intervals are accepted as input as well.

    Args:
        first: Intervals to subtract from.
        second: Intervals to remove.

    Returns:
        A list of ``[start, end]`` lists, in the order of ``first``.
    """
    result = []
    j = 0
    n_second = len(second)

    for interval in first:
        start, end = interval[0], interval[1]
        fully_removed = False

        while j < n_second:
            cut_start, cut_end = second[j][0], second[j][1]

            # The cut lies entirely after what is left of this interval.
            if end <= cut_start:
                break

            # The cut lies entirely before this interval; it cannot affect
            # any later interval either, so drop it for good.
            if cut_end <= start:
                j += 1
                continue

            # Overlap: keep whatever precedes the cut.
            if start < cut_start:
                result.append([start, cut_start])

            if end <= cut_end:
                # The cut reaches past the end of this interval. Keep the
                # same cut for the next interval, which it may also overlap.
                fully_removed = True
                break

            # The interval extends past the cut: continue from the cut's end.
            start = cut_end
            j += 1

        if not fully_removed:
            result.append([start, end])

    return result
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
from typing import List, Tuple | ||
|
||
from atriumdb.intervals.difference import list_difference | ||
from atriumdb.intervals.intersection import list_intersection | ||
from atriumdb.intervals.union import intervals_union_list | ||
|
||
|
||
def map_validated_sources(sources: dict, sdk) -> dict:
    """Split requested source time ranges into mapped and unmapped portions.

    For every id under ``sources["patient_ids"]`` and ``sources["device_ids"]``
    and each of its requested time ranges, the device-patient mapping is
    fetched with ``sdk.get_device_patient_data``. Portions of a requested
    range covered by a mapping are stored under ``"device_patient_tuples"``,
    keyed by ``(device_id, patient_id)``. Portions not covered by any mapping
    stay under the original source type, keyed by the requested id.

    Only the ``"patient_ids"`` and ``"device_ids"`` keys of ``sources`` are
    read, and ``sources`` itself is not modified.

    Args:
        sources: Dict whose ``"patient_ids"`` / ``"device_ids"`` entries map
            an id to an iterable of ``(start_time, end_time)`` pairs.
        sdk: Object providing ``get_device_patient_data(patient_id_list=...,
            device_id_list=..., start_time=..., end_time=...)``.

    Returns:
        A dict that always contains ``"device_patient_tuples"``. It also
        contains ``"patient_ids"`` and/or ``"device_ids"``, but only when some
        requested range has an unmapped portion.
    """
    mapped_sources = {"device_patient_tuples": {}}
    tuple_ranges = mapped_sources["device_patient_tuples"]

    def process_ids(ids_dict, id_type):
        # Handles one source type; id_type is 'patient_ids' or 'device_ids'.
        for src_id, time_ranges in ids_dict.items():
            covered_ranges = []
            for time_range in time_ranges:
                start_time, end_time = time_range
                # Query by patient or by device, depending on the source type.
                device_patient_data = sdk.get_device_patient_data(
                    patient_id_list=[src_id] if id_type == 'patient_ids' else None,
                    device_id_list=[src_id] if id_type == 'device_ids' else None,
                    start_time=start_time, end_time=end_time)

                aggregated_ranges = aggregate_time_ranges(device_patient_data)
                for key, ranges in aggregated_ranges.items():
                    # Clip the mapping to the requested range.
                    intersected_ranges = list_intersection(ranges, [time_range])
                    if intersected_ranges:
                        tuple_ranges.setdefault(key, []).extend(intersected_ranges)
                        covered_ranges.extend(intersected_ranges)

            # Whatever part of each requested range is not covered by a
            # device-patient mapping stays under the original source type.
            covered_ranges = intervals_union_list(covered_ranges).tolist()
            for time_range in time_ranges:
                start_time, end_time = time_range
                # Pass a copy so the caller's range can never be altered by
                # the difference computation.
                difference_ranges = list_difference([[start_time, end_time]], covered_ranges)
                if difference_ranges:
                    # Accumulate across all requested ranges of this id;
                    # assigning here would keep only the last range's
                    # leftovers.
                    unmapped = mapped_sources.setdefault(id_type, {})
                    unmapped.setdefault(src_id, []).extend(difference_ranges)

    process_ids(sources.get('patient_ids', {}), 'patient_ids')
    process_ids(sources.get('device_ids', {}), 'device_ids')

    return mapped_sources
|
||
|
||
def aggregate_time_ranges(device_patient_data: List[Tuple[int, int, int, int]]):
    """Group mapping rows into per-(device, patient) lists of time ranges.

    Args:
        device_patient_data: Rows of
            ``(device_id, patient_id, start_time, end_time)``.

    Returns:
        A dict keyed by ``(device_id, patient_id)`` in first-seen order. Each
        value is a list of ``[start_time, end_time]`` lists sorted by start
        time; rows with equal start times keep their input order.
    """
    grouped = {}
    for row in device_patient_data:
        device_id, patient_id, start_time, end_time = row
        grouped.setdefault((device_id, patient_id), []).append([start_time, end_time])

    # Order each pair's ranges by start time (stable sort).
    return {
        pair: sorted(spans, key=lambda span: span[0])
        for pair, spans in grouped.items()
    }
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import pytest | ||
|
||
from atriumdb.intervals.difference import list_difference | ||
|
||
|
||
def test_list_difference():
    """Check list_difference against a table of (first, second, expected)."""
    cases = [
        # Mixed partial overlaps across several intervals.
        ([[0, 2], [3, 8], [10, 20]], [[2, 5], [6, 12], [20, 22]], [[0, 2], [5, 6], [12, 20]]),
        ([[0, 5]], [[2, 4]], [[0, 2], [4, 5]]),
        ([[1, 3], [4, 6], [7, 9]], [[2, 5], [8, 10]], [[1, 2], [5, 6], [7, 8]]),
        # Disjoint and merely touching intervals are left intact.
        ([[0, 2]], [[3, 5]], [[0, 2]]),
        ([[1, 2]], [[2, 3]], [[1, 2]]),
        ([[0, 10]], [[1, 9]], [[0, 1], [9, 10]]),
        # Empty operands.
        ([], [[0, 1], [2, 3]], []),
        ([[0, 1], [2, 3]], [], [[0, 1], [2, 3]]),
        # Full removal.
        ([[1, 2]], [[1, 2]], []),
        ([[0, 1]], [[0, 1], [2, 3]], []),
        ([[1, 2]], [[0, 1]], [[1, 2]]),
        # Several cuts inside one interval, and cuts filling the gaps.
        ([[0, 5]], [[1, 2], [3, 4]], [[0, 1], [2, 3], [4, 5]]),
        ([[0, 1], [3, 4], [6, 7]], [[1, 3], [4, 6]], [[0, 1], [3, 4], [6, 7]]),
    ]

    for first, second, expected in cases:
        assert list_difference(first, second) == expected