Skip to content
This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Permalink
fix: bunch of plot fixes (#427)
Browse files Browse the repository at this point in the history
  • Loading branch information
MartinBernstorff authored Mar 15, 2023
2 parents 5281789 + 2d4054c commit aaba25a
Show file tree
Hide file tree
Showing 7 changed files with 104 additions and 60 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ def create_base_plot_artifacts(self) -> list[ArtifactContainer]:
label="recall_by_calendar_time",
artifact=plot_recall_by_calendar_time(
eval_dataset=self.eval_ds,
pred_proba_percentile=[0.95, 0.97, 0.99],
pos_rate=[0.95, 0.97, 0.99],
bins=self.cfg.eval.lookahead_bins,
y_limits=(0, 0.5),
save_path=self.save_dir / "recall_by_calendar_time.png",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ def plot_basic_chart(
flip_x_axis: bool = False,
flip_y_axis: bool = False,
bar_count_values: Optional[pd.Series] = None,
bar_count_y_axis_title: str = "Number of observations",
y_limits: Optional[tuple[float, float]] = None,
fig_size: Optional[tuple[float, float]] = (5, 5),
dpi: Optional[int] = 300,
Expand All @@ -41,6 +42,7 @@ def plot_basic_chart(
flip_x_axis: Whether to flip the x axis. Defaults to False.
flip_y_axis: Whether to flip the y axis. Defaults to False.
bar_count_values: Values to use for overlaid histogram of n in bins. Defaults to None.
bar_count_y_axis_title: Title of y axis of overlaid histogram. Defaults to "Number of observations".
y_limits: y-axis limits. Defaults to None.
fig_size: figure size. Defaults to (5, 5).
dpi: dpi of figure. Defaults to 300.
Expand Down Expand Up @@ -115,17 +117,14 @@ def plot_basic_chart(
if bar_count_values is not None:
# add additional y-axis for count
bar_overlay = plt.gca().twinx()
bar_overlay.bar(df["x"], bar_count_values, color="blue")
bar_overlay.set_ylabel("Number of observations")
bar_overlay.bar(df["x"], bar_count_values, color="gainsboro", alpha=0.5)
bar_overlay.set_ylabel(bar_count_y_axis_title)

# put bar plots behind other plots
axs.set_zorder(bar_overlay.get_zorder() + 1)
axs.set_facecolor("none")
bar_overlay.set_facecolor("none")

# add counts to bars
bar_overlay.bar_label(bar_overlay.bar(df["x"], bar_count_values))

plt.tight_layout()

if save_path is not None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,8 @@ def plot_performance_by_age(
y_title="AUC",
sort_x=sort_order,
y_limits=y_limits,
plot_type=["bar"],
plot_type=["scatter", "line"],
bar_count_values=df["n_in_bin"],
bar_count_y_axis_title="Number of visits",
save_path=save_path,
)
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from collections.abc import Callable, Iterable, Sequence
from pathlib import Path
from typing import Optional, Union
from typing import Literal, Optional, Union

import numpy as np
import pandas as pd
Expand All @@ -24,8 +24,9 @@

def plot_recall_by_calendar_time(
eval_dataset: EvalDataset,
pred_proba_percentile: Union[float, Iterable[float]],
pos_rate: Union[float, Iterable[float]],
bins: Iterable[float],
bin_unit: Literal["D", "W", "M", "Q", "Y"] = "D",
y_title: str = "Sensitivity (Recall)",
y_limits: Optional[tuple[float, float]] = None,
save_path: Optional[Union[Path, str]] = None,
Expand All @@ -34,43 +35,53 @@ def plot_recall_by_calendar_time(
Args:
eval_dataset (EvalDataset): EvalDataset object
pred_proba_percentile (Union[float, Iterable[float]]): Percentile of highest predicted probabilities to mark as positive in binary classification.
pos_rate (Union[float, Iterable[float]]): Percentile of highest predicted probabilities to mark as positive in binary classification.
bins (Iterable[float], optional): Bins to use for time to outcome.
bin_unit (Literal["D", "W", "M", "Q", "Y"], optional): Unit of time to bin by. Defaults to "D".
y_title (str): Title of y-axis. Defaults to "Sensitivity (Recall)".
save_path (str, optional): Path to save figure. Defaults to None.
y_limits (tuple[float, float], optional): Limits of y-axis. Defaults to None.
Returns:
Union[None, Path]: Path to saved figure or None if not saved.
"""
if not isinstance(pred_proba_percentile, Iterable):
pred_proba_percentile = [pred_proba_percentile]

pred_proba_percentile = list(pred_proba_percentile)
pred_proba_percentile_labels = [
str(percentile) for percentile in pred_proba_percentile
]
if not isinstance(pos_rate, Iterable):
pos_rate = [pos_rate]
pos_rate = list(pos_rate)

# Get percentiles from a series of predicted probabilities
pred_proba_percentiles = eval_dataset.y_hat_probs.rank(pct=True)
y_hat_percentiles = eval_dataset.y_hat_probs.rank(pct=True)

pos_rate_threshold = [1 - threshold for threshold in list(pos_rate)]
pos_rate_threshold_labels = [str(threshold) for threshold in list(pos_rate)]

dfs = [
create_sensitivity_by_time_to_outcome_df(
labels=eval_dataset.y,
y_hat_probs=pred_proba_percentiles,
y_hat_probs=y_hat_percentiles,
pred_proba_threshold=threshold,
outcome_timestamps=eval_dataset.outcome_timestamps,
prediction_timestamps=eval_dataset.pred_timestamps,
bins=bins,
bin_delta=bin_unit,
)
for threshold in pred_proba_percentile
for threshold in pos_rate_threshold
]

bin_delta_to_str = {
"D": "Day",
"W": "Week",
"M": "Month",
"Q": "Quarter",
"Y": "Year",
}

x_title_unit = bin_delta_to_str[bin_unit]
return plot_basic_chart(
x_values=dfs[0]["days_to_outcome_binned"],
y_values=[df["sens"] for df in dfs],
x_title="Days from event",
labels=pred_proba_percentile_labels,
x_title=f"{x_title_unit}s to event",
labels=pos_rate_threshold_labels,
y_title=y_title,
y_limits=y_limits,
flip_x_axis=True,
Expand Down Expand Up @@ -112,7 +123,7 @@ def create_performance_by_calendar_time_df(
def plot_metric_by_calendar_time(
eval_dataset: EvalDataset,
y_title: str = "AUC",
bin_period: str = "Y",
bin_period: Literal["D", "W", "M", "Q", "Y"] = "Y",
save_path: Optional[str] = None,
metric_fn: Callable = roc_auc_score,
y_limits: Optional[tuple[float, float]] = (0.5, 1.0),
Expand All @@ -139,17 +150,23 @@ def plot_metric_by_calendar_time(
)
sort_order = np.arange(len(df))

x_titles = {
"D": "Day",
"W": "Week",
"M": "Month",
"Q": "Quarter",
"Y": "Year",
}

return plot_basic_chart(
x_values=df["time_bin"],
y_values=df["metric"],
x_title="Month"
if bin_period == "M"
else "Quarter"
if bin_period == "Q"
else "Year",
x_title=x_titles[bin_period],
y_title=y_title,
sort_x=sort_order,
y_limits=y_limits,
bar_count_values=df["n_in_bin"],
bar_count_y_axis_title="Number of visits",
plot_type=["line", "scatter"],
save_path=save_path,
)
Expand Down Expand Up @@ -269,6 +286,8 @@ def plot_metric_by_cyclic_time(
y_title=y_title,
y_limits=y_limits,
plot_type=["line", "scatter"],
bar_count_values=df["n_in_bin"],
bar_count_y_axis_title="Number of visits",
save_path=save_path,
)

Expand All @@ -281,6 +300,7 @@ def create_performance_by_time_from_event_df(
metric_fn: Callable,
direction: str,
bins: Sequence[float],
bin_unit: Literal["D", "M", "Q", "Y"],
bin_continuous_input: Optional[bool] = True,
drop_na_events: Optional[bool] = True,
min_n_in_bin: int = 5,
Expand All @@ -297,12 +317,13 @@ def create_performance_by_time_from_event_df(
direction (str): Which direction to calculate time difference.
Can either be 'prediction-event' or 'event-prediction'.
bins (Iterable[float]): Bins to group by.
bin_unit (Literal["D", "M", "Q", "Y"]): Unit of time to use for bins.
bin_continuous_input (bool, optional): Whether to bin input. Defaults to True.
drop_na_events (bool, optional): Whether to drop rows where the event is NA. Defaults to True.
min_n_in_bin (int, optional): Minimum number of rows in a bin to include in output. Defaults to 5.
Returns:
pd.DataFrame: Dataframe ready for plotting
pd.DataFrame: Dataframe ready for plotting where each row represents a bin.
"""

df = pd.DataFrame(
Expand All @@ -319,19 +340,19 @@ def create_performance_by_time_from_event_df(

# Calculate difference between prediction and event, expressed in units of bin_unit
if direction == "event-prediction":
df["days_from_event"] = (
df["unit_from_event"] = (
df["event_timestamp"] - df["prediction_timestamp"]
) / np.timedelta64(
1,
"D",
bin_unit,
) # type: ignore

elif direction == "prediction-event":
df["days_from_event"] = (
df["unit_from_event"] = (
df["prediction_timestamp"] - df["event_timestamp"]
) / np.timedelta64(
1,
"D",
bin_unit,
) # type: ignore

else:
Expand All @@ -341,32 +362,26 @@ def create_performance_by_time_from_event_df(

# bin data
if bin_continuous_input:
# Convert df["days_from_event"] to int if possible
df["days_from_event_binned"], df["n_in_bin"] = bin_continuous_data(
df["days_from_event"],
# Convert df["unit_from_event"] to int if possible
df["unit_from_event_binned"], df["n_in_bin"] = bin_continuous_data(
df["unit_from_event"],
bins=bins,
min_n_in_bin=min_n_in_bin,
)
else:
df["days_from_event_binned"] = round_floats_to_edge(
df["days_from_event"],
df["unit_from_event_binned"] = round_floats_to_edge(
df["unit_from_event"],
bins=bins,
)

# Calc performance and prettify output
output_df = df.groupby(["days_from_event_binned"]).apply(
calc_performance,
metric_fn,
)

output_df = (
output_df.reset_index()
.rename({0: "metric"}, axis=1)
.merge(
df[["days_from_event_binned", "n_in_bin"]],
on="days_from_event_binned",
how="left",
df.groupby(["unit_from_event_binned"])
.apply(
calc_performance,
metric_fn,
)
.reset_index()
)

return output_df
Expand All @@ -375,6 +390,7 @@ def create_performance_by_time_from_event_df(
def plot_auc_by_time_from_first_visit(
eval_dataset: EvalDataset,
bins: tuple = (0, 28, 182, 365, 730, 1825),
bin_unit: Literal["D", "M", "Q", "Y"] = "D",
bin_continuous_input: Optional[bool] = True,
y_limits: Optional[tuple[float, float]] = (0.5, 1.0),
save_path: Optional[Path] = None,
Expand All @@ -384,6 +400,7 @@ def plot_auc_by_time_from_first_visit(
Args:
eval_dataset (EvalDataset): EvalDataset object
bins (list, optional): Bins to group by. Defaults to [0, 28, 182, 365, 730, 1825].
bin_unit (Literal["D", "M", "Q", "Y"], optional): Unit of time to bin by. Defaults to "D".
bin_continuous_input (bool, optional): Whether to bin input. Defaults to True.
y_limits (tuple[float, float], optional): Limits of y-axis. Defaults to (0.5, 1.0).
save_path (Path, optional): Path to save figure. Defaults to None.
Expand All @@ -404,20 +421,30 @@ def plot_auc_by_time_from_first_visit(
prediction_timestamps=eval_dataset.pred_timestamps,
direction="prediction-event",
bins=list(bins),
bin_unit=bin_unit,
bin_continuous_input=bin_continuous_input,
drop_na_events=False,
metric_fn=roc_auc_score,
)

bin_unit2str = {
"D": "Days",
"M": "Months",
"Q": "Quarters",
"Y": "Years",
}

sort_order = np.arange(len(df))
return plot_basic_chart(
x_values=df["days_from_event_binned"],
x_values=df["unit_from_event_binned"],
y_values=df["metric"],
x_title="Days from first visit",
x_title=f"{bin_unit2str[bin_unit]} from first visit",
y_title="AUC",
sort_x=sort_order,
y_limits=y_limits,
plot_type=["line", "scatter"],
bar_count_values=df["n_in_bin"],
bar_count_y_axis_title="Number of visits",
save_path=save_path,
)

Expand All @@ -432,6 +459,7 @@ def plot_metric_by_time_until_diagnosis(
-28,
-0,
),
bin_unit: Literal["D", "M", "Q", "Y"] = "D",
bin_continuous_input: bool = True,
metric_fn: Callable = f1_score,
y_title: str = "F1",
Expand All @@ -445,6 +473,7 @@ def plot_metric_by_time_until_diagnosis(
Args:
eval_dataset (EvalDataset): EvalDataset object
bins (list, optional): Bins to group by. Negative values indicate days after
diagnosis. Defaults to (-1825, -730, -365, -182, -28, -14, -7, -1, 0)
bin_unit (Literal["D", "M", "Q", "Y"], optional): Unit of time to bin by. Defaults to "D".
bin_continuous_input (bool, optional): Whether to bin input. Defaults to True.
metric_fn (Callable): Which performance metric function to use.
Expand All @@ -462,17 +491,25 @@ def plot_metric_by_time_until_diagnosis(
prediction_timestamps=eval_dataset.pred_timestamps,
direction="event-prediction",
bins=bins,
bin_unit=bin_unit,
bin_continuous_input=bin_continuous_input,
min_n_in_bin=0,
drop_na_events=True,
metric_fn=metric_fn,
)
sort_order = np.arange(len(df))

bin_unit2str = {
"D": "Days",
"M": "Months",
"Q": "Quarters",
"Y": "Years",
}

return plot_basic_chart(
x_values=df["days_from_event_binned"],
x_values=df["unit_from_event_binned"],
y_values=df["metric"],
x_title="Days to diagnosis",
x_title=f"{bin_unit2str[bin_unit]} to diagnosis",
y_title=y_title,
sort_x=sort_order,
bar_count_values=df["n_in_bin"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from collections.abc import Iterable
from functools import partial
from pathlib import Path
from typing import Optional, Union
from typing import Literal, Optional, Union

import matplotlib as mpl
import matplotlib.pyplot as plt
Expand All @@ -19,6 +19,7 @@ def create_sensitivity_by_time_to_outcome_df(
outcome_timestamps: Iterable[pd.Timestamp],
prediction_timestamps: Iterable[pd.Timestamp],
bins: Iterable = (0, 1, 7, 14, 28, 182, 365, 730, 1825),
bin_delta: Literal["D", "W", "M", "Q", "Y"] = "D",
) -> pd.DataFrame:
"""Calculate sensitivity by time to outcome.
Expand Down Expand Up @@ -61,7 +62,7 @@ def create_sensitivity_by_time_to_outcome_df(
df["outcome_timestamp"] - df["prediction_timestamp"]
) / np.timedelta64(
1,
"D",
bin_delta,
) # type: ignore

df["true_positive"] = (df["y"] == 1) & (df["y_hat"] == 1)
Expand Down
Loading

0 comments on commit aaba25a

Please sign in to comment.