Skip to content

Commit

Permalink
owstepwisecoxregression: make use of cox regression learners
Browse files Browse the repository at this point in the history
  • Loading branch information
JakaKokosar committed Oct 28, 2021
1 parent f023f5a commit 42057c1
Showing 1 changed file with 45 additions and 38 deletions.
83 changes: 45 additions & 38 deletions orangecontrib/survival_analysis/widgets/owstepwisecoxregression.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

from AnyQt.QtGui import QColor
from AnyQt.QtCore import Qt, QPointF, pyqtSignal as Signal
from lifelines import CoxPHFitter

from Orange.widgets import gui
from Orange.widgets.settings import ContextSetting, DomainContextHandler, Setting, SettingProvider
Expand All @@ -16,6 +15,8 @@
from Orange.data import Table, Domain, DiscreteVariable, ContinuousVariable
from Orange.data.pandas_compat import table_to_frame

from orangecontrib.survival_analysis.widgets.owcoxregression import CoxRegressionLearner, CoxRegressionModel


class CustomInfiniteLine(pg.InfiniteLine):
def __init__(self, parent, *args, **kwargs):
Expand Down Expand Up @@ -79,22 +80,19 @@ def get_viewbox_y_range(self):

class Result(NamedTuple):
log2p: int
covariate_to_coef: dict
model: CoxRegressionModel
removed_covariates: list


def worker(df: pd.DataFrame, initial_covariates: set, time_var: str, event_var: str, state: TaskState):
def worker(df: pd.DataFrame, learner, initial_covariates: set, time_var: str, event_var: str, state: TaskState):
progress_steps = iter(np.linspace(0, 100, len(initial_covariates)))

def fit_cox_models(remaining_covariates: set, combinations_to_check: List[Tuple[str, ...]]):
results = []
for covariates in combinations_to_check:
cph = CoxPHFitter().fit(
df[[time_var, event_var] + list(covariates)], duration_col=time_var, event_col=event_var
)
covariate_to_coef = cph.summary.to_dict('dict')['coef']
log2p = -np.log2(cph.log_likelihood_ratio_test().p_value)
result = Result(log2p, covariate_to_coef, [cov for cov in remaining_covariates - set(covariates)])
cph_model = learner(df[[time_var, event_var] + list(covariates)], time_var, event_var)
log2p = cph_model.ll_ratio_log2p()
result = Result(log2p, cph_model, [cov for cov in remaining_covariates - set(covariates)])
results.append(result)
return results

Expand All @@ -104,11 +102,11 @@ def fit_cox_models(remaining_covariates: set, combinations_to_check: List[Tuple[
covariates_to_eval = initial_covariates - removed_covariates

if len(covariates_to_eval) > 1:
gene_combinations = list(itertools.combinations(covariates_to_eval, len(covariates_to_eval) - 1))
combinations = list(itertools.combinations(covariates_to_eval, len(covariates_to_eval) - 1))
else:
gene_combinations = [tuple(covariates_to_eval)]
combinations = [tuple(covariates_to_eval)]

results = fit_cox_models(covariates_to_eval, gene_combinations)
results = fit_cox_models(covariates_to_eval, combinations)

best_result = max(results, key=lambda result: result.log2p)
if not best_result.removed_covariates:
Expand All @@ -135,6 +133,7 @@ class OWStepwiseCoxRegression(OWWidget, ConcurrentWidgetMixin):

class Inputs:
data = Input('Data', Table)
learner = Input('Cox Learner', CoxRegressionLearner)

class Outputs:
selected_data = Output('Data', Table)
Expand All @@ -143,10 +142,10 @@ def __init__(self):
OWWidget.__init__(self)
ConcurrentWidgetMixin.__init__(self)

self.learner: Optional[CoxRegressionLearner] = CoxRegressionLearner()
self.data: Optional[Table] = None
self.data_df: Optional[pd.DataFrame] = None
self.attr_name_to_variable: Optional[dict] = None

self.trace: Optional[List[Result]] = None

time_var_model = DomainModel(valid_types=(ContinuousVariable,), order=(4,))
Expand All @@ -163,6 +162,15 @@ def __init__(self):
gui.rubber(self.controlArea)
self.commit_button = gui.auto_commit(self.controlArea, self, 'auto_commit', '&Commit', box=False)

@Inputs.learner
def set_learner(self, learner: CoxRegressionLearner):
if learner:
self.learner = learner
else:
self.learner = CoxRegressionLearner()

self.on_controls_changed()

@Inputs.data
def set_data(self, data: Table):
self.closeContext()
Expand All @@ -178,31 +186,26 @@ def set_data(self, data: Table):
self.time_var = None
self.openContext(data)

if self.time_var:
self.start(
worker,
self.data_df,
set(self.attr_name_to_variable.keys()),
self.time_var.name,
self.data.domain.class_var.name,
)
self.on_controls_changed()

def on_controls_changed(self):
if self.time_var:
self.start(
worker,
self.data_df,
self.learner,
set(self.attr_name_to_variable.keys()),
self.time_var.name,
self.data.domain.class_var.name,
)

def on_selection_changed(self, selection_line):
self.current_x = selection_line.getXPos() - 1
self.current_x = selection_line.getXPos() # + 1
self.commit()

def commit(self):
self.Outputs.selected_data.send(self.stratify_data(self.data_df, self.trace[self.current_x]))
if self.current_x:
self.Outputs.selected_data.send(self.stratify_data(self.trace[self.current_x - 1]))

def on_done(self, trace):
# save results
Expand All @@ -222,27 +225,31 @@ def on_exception(self, ex):
def on_partial_result(self, result: Any) -> None:
pass

def stratify_data(
self,
df: pd.DataFrame,
result: Result,
) -> Table:
covariates = result.covariate_to_coef.keys()
def stratify_data(self, result: Result) -> Table:
model = result.model

domain = Domain(
[self.attr_name_to_variable[covariate] for covariate in model.covariates],
self.data.domain.class_var,
self.data.domain.metas,
)
data = self.data.transform(domain)

risk_score_label = 'Risk Score'
risk_group_label = 'Risk Group'
risk_score_var = ContinuousVariable(risk_score_label)
risk_group_var = DiscreteVariable(risk_group_label, values=['Low Risk', 'High Risk'])

df[risk_score_label] = df[covariates].dot([result.covariate_to_coef[covariate] for covariate in covariates])
df[risk_group_label] = (df[risk_score_label] >= df[risk_score_label].median()).astype(int)
risk_scores = model.predict(data.X)
risk_groups = (risk_scores > np.median(risk_scores)).astype(int)

attrs = [self.attr_name_to_variable[name] for name in covariates]
domain = Domain(attrs, self.data.domain.class_var, self.data.domain.metas + (risk_score_var, risk_group_var))
data = self.data.transform(domain)
data[:, risk_score_var] = np.reshape(df[risk_score_label].to_numpy(), (-1, 1))
data[:, risk_group_var] = np.reshape(df[risk_group_label].to_numpy(), (-1, 1))
return data
domain = Domain(
data.domain.attributes, data.domain.class_var, data.domain.metas + (risk_score_var, risk_group_var)
)
stratified_data = data.transform(domain)
stratified_data[:, risk_score_var] = np.reshape(risk_scores, (-1, 1))
stratified_data[:, risk_group_var] = np.reshape(risk_groups, (-1, 1))
return stratified_data

def send_report(self):
if self.data is None:
Expand All @@ -253,4 +260,4 @@ def send_report(self):
if __name__ == "__main__":
from orangewidget.utils.widgetpreview import WidgetPreview

WidgetPreview(OWStepwiseCoxRegression).run(Table('test_data_full.pkl'))
WidgetPreview(OWStepwiseCoxRegression).run(Table('test_data3.tab'))

0 comments on commit 42057c1

Please sign in to comment.