Skip to content

Commit

Permalink
owassurvivaldata: properly handle widget input
Browse files Browse the repository at this point in the history
Widget now respects the information about survival variables stored in the domain.
  • Loading branch information
JakaKokosar committed Dec 4, 2023
1 parent 7d424c3 commit fee9871
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions orangecontrib/survival_analysis/widgets/owassurvivaldata.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
TIME_VAR,
EVENT_VAR,
TIME_TO_EVENT_VAR,
get_survival_endpoints,
)


Expand All @@ -36,9 +37,9 @@ class Outputs:
data = Output('Data', Table)

settingsHandler = DomainContextHandler()
time_var = ContextSetting(None, schema_only=True)
event_var = ContextSetting(None, schema_only=True)
auto_commit: bool = Setting(True, schema_only=True)
time_var = ContextSetting(None)
event_var = ContextSetting(None)
auto_commit: bool = Setting(True)

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
Expand Down Expand Up @@ -82,22 +83,35 @@ def set_data(self, data: Table) -> None:
self._data = data.transform(data.domain)
self._data.attributes = data.attributes.copy()
# look for survival data in meta and class vars only.
vars_ = [
metas = [
var
for var in data.domain.metas + data.domain.class_vars
for var in data.domain.metas
if not isinstance(var, (TimeVariable, StringVariable))
]
class_vars = [
var
for var in data.domain.class_vars
if not isinstance(var, (TimeVariable, StringVariable))
]
domain = Domain(vars_)

domain = Domain([], metas=metas, class_vars=class_vars)

self.controls.time_var.model().set_domain(domain)
self.controls.event_var.model().set_domain(domain)

time_var_model = self.controls.time_var.model()
event_var_model = self.controls.event_var.model()

self.time_var = time_var_model[0] if len(time_var_model) else None
self.event_var = event_var_model[0] if len(event_var_model) else None
# If not found in the domain then default to the first var in model.
_time_var, _event_var = get_survival_endpoints(domain)

if len(time_var_model):
self.time_var = time_var_model[0] if _time_var is None else _time_var

if len(event_var_model):
self.event_var = event_var_model[0] if _event_var is None else _event_var

# Lastly, respect saved domain context
if self.time_var is not None and self.event_var is not None:
self.openContext(domain)

Expand Down

0 comments on commit fee9871

Please sign in to comment.