Skip to content

Commit

Permalink
owkaplanmeier, owcohorts: replace get_column_view with get_column
Browse files Browse the repository at this point in the history
  • Loading branch information
JakaKokosar committed Dec 6, 2022
1 parent 35c2a26 commit 03ad8c6
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 21 deletions.
4 changes: 2 additions & 2 deletions orangecontrib/survival_analysis/widgets/owchorts.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,8 @@ def check_unique_values(split_by, values):
cutoff = partial(check_unique_values, np.mean)
elif splitting_criteria == SplittingCriteria.LogRankTest:
time_var, event_var = get_survival_endpoints(data.domain)
durations, _ = data.get_column_view(time_var)
events, _ = data.get_column_view(event_var)
durations = data.get_column(time_var)
events = data.get_column(event_var)
cutoff = partial(cutoff_by_log_rank_optimization, durations, events, callback)
else:
raise ValueError('Unknown splitting criteria')
Expand Down
10 changes: 5 additions & 5 deletions orangecontrib/survival_analysis/widgets/owkaplanmeier.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,11 +658,11 @@ def generate_plot_curves(self) -> List[EstimatedFunctionCurve]:
return []

data = self.data
time, _ = data.get_column_view(self.time_var)
events, _ = data.get_column_view(self.event_var)
time = data.get_column(self.time_var)
events = data.get_column(self.event_var)

if self.group_var:
groups, _ = data.get_column_view(self.group_var.name)
groups = data.get_column(self.group_var.name)
group_indexes = [index for index, _ in enumerate(self.group_var.values)]
colors = [self._get_discrete_var_color(index) for index in group_indexes]
masks = groups == np.reshape(group_indexes, (-1, 1))
Expand All @@ -688,7 +688,7 @@ def commit(self):

data = self.data

time, _ = data.get_column_view(self.time_var)
time = data.get_column(self.time_var)
if self.group_var is None:
time_interval = self.graph.selection[0].x
start, end = time_interval[0], time_interval[-1]
Expand All @@ -697,7 +697,7 @@ def commit(self):
)
else:
selection = []
group, _ = data.get_column_view(self.group_var.name)
group = data.get_column(self.group_var.name)
for group_id, time_interval in self.graph.selection.items():
start, end = time_interval.x[0], time_interval.x[-1]
selection += (
Expand Down
24 changes: 10 additions & 14 deletions orangecontrib/survival_analysis/widgets/tests/test_owkaplanmeier.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,8 @@ def setUp(self) -> None:
vb.resize(200, 200)

def simulate_mouse_drag(self, start: tuple, end: tuple):
start = self.widget.graph.view_box.mapViewToScene(
pg.Point(start[0], start[1])
).toPoint()
end = self.widget.graph.view_box.mapViewToScene(
pg.Point(end[0], end[1])
).toPoint()
start = self.widget.graph.view_box.mapViewToScene(pg.Point(start[0], start[1]))
end = self.widget.graph.view_box.mapViewToScene(pg.Point(end[0], end[1]))

mouse_move(self.widget.graph, start)
# this is somehow not respected in KaplanMeierViewBox.mouseDragEvent
Expand Down Expand Up @@ -166,13 +162,13 @@ def test_curve_highlight(self):
self.widget.group_var = self.widget.data.domain['Group']
self.widget.on_group_changed()

pos = self.widget.graph.view_box.mapViewToScene(pg.Point(1.5, 0.5)).toPoint()
pos = self.widget.graph.view_box.mapViewToScene(pg.Point(1.5, 0.5))
mouse_move(self.widget.graph, pos)
# We need to wait for events to process
QTest.qWait(100)
self.assertTrue(self.widget.graph.highlighted_curve == 1)

pos = self.widget.graph.view_box.mapViewToScene(pg.Point(1.5, 0.85)).toPoint()
pos = self.widget.graph.view_box.mapViewToScene(pg.Point(1.5, 0.85))
mouse_move(self.widget.graph, pos)
QTest.qWait(100)
self.assertTrue(self.widget.graph.highlighted_curve == 0)
Expand All @@ -197,9 +193,9 @@ def test_selection(self):

# check output data
selected_data = self.get_output(self.widget.Outputs.selected_data)
selected_groups = selected_data.get_column_view('Group')[0]
selected_groups = selected_data.get_column('Group')
self.assertEqual(12, selected_groups.size)
selected_groups = set(selected_data.get_column_view('Group')[0])
selected_groups = set(selected_data.get_column('Group'))
self.assertEqual(2, len(selected_groups))
self.assertIn(0, selected_groups)
self.assertIn(1, selected_groups)
Expand All @@ -223,9 +219,9 @@ def test_selection(self):

# check output data
selected_data = self.get_output(self.widget.Outputs.selected_data)
selected_groups = selected_data.get_column_view('Group')[0]
selected_groups = selected_data.get_column('Group')
self.assertEqual(6, selected_groups.size)
selected_groups = set(selected_data.get_column_view('Group')[0])
selected_groups = set(selected_data.get_column('Group'))
self.assertEqual(1, len(selected_groups))
self.assertIn(0, selected_groups)
self.assertNotIn(1, selected_groups)
Expand All @@ -249,9 +245,9 @@ def test_selection(self):

# check output data
selected_data = self.get_output(self.widget.Outputs.selected_data)
selected_groups = selected_data.get_column_view('Group')[0]
selected_groups = selected_data.get_column('Group')
self.assertEqual(6, selected_groups.size)
selected_groups = set(selected_data.get_column_view('Group')[0])
selected_groups = set(selected_data.get_column('Group'))
self.assertEqual(1, len(selected_groups))
self.assertIn(1, selected_groups)
self.assertNotIn(0, selected_groups)
Expand Down

0 comments on commit 03ad8c6

Please sign in to comment.