This repository has been archived by the owner on May 1, 2023. It is now read-only.

Commit

Merge pull request #220 from Aarhus-Psychiatry-Research/HLasse/Plot-bar-chart-in-matplotlib

Migrate from Altair to matplotlib
  • Loading branch information
2 parents 4723533 + 37f95ca commit 7e016c1
Showing 12 changed files with 427 additions and 241 deletions.
3 changes: 1 addition & 2 deletions pyproject.toml
@@ -16,15 +16,14 @@ wandb = ">=0.12.21,<0.13.0"
 psycopmlutils = {git = "https://github.com/Aarhus-Psychiatry-Research/psycop-ml-utils.git", rev = "main"}
 tabulate = ">=0.8.10,<0.9.0"
 optuna = ">=2.10.1,<2.11.0"
-altair = ">=4.2.0,<4.3.0"
-altair-saver = ">=0.5.0,<0.6.0"
 hydra-optuna-sweeper = ">=1.2.0,<1.3.0"
 hydra-joblib-launcher = ">=1.2.0, <1.3.0"
 selenium = ">=4.2.0,<4.6.0"

 # pandoc = "^2.2" not compatible with PEP517, which is required for poetry install. Disabled for now: https://github.com/boisgera/pandoc/pull/49
 # interpret = ">=0.2.7,<0.3.0" Disabled because of errors with github actions.
 # See https://github.com/Aarhus-Psychiatry-Research/psycop-t2d/pull/194 for thoughts on root cause
+seaborn = ">=0.12.0, <0.12.1"


 [tool.poetry.dev-dependencies]
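Note that matplotlib itself is not pinned in this hunk; it is expected to resolve transitively through the new seaborn dependency, which depends on matplotlib. A quick post-install sanity check (a sketch, not part of the commit):

# Sanity check after `poetry install`: seaborn pulls in matplotlib
# transitively, so both imports should resolve.
import matplotlib
import seaborn

print(matplotlib.__version__, seaborn.__version__)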
27 changes: 18 additions & 9 deletions src/psycopt2d/evaluation.py
@@ -1,7 +1,6 @@
"""Functions for evaluating a model's prredictions."""
from collections.abc import Iterable

import altair as alt
import numpy as np
import pandas as pd
from sklearn.metrics import f1_score, roc_auc_score
@@ -12,15 +11,19 @@
 from psycopt2d.tables.performance_by_threshold import (
     generate_performance_by_positive_rate_table,
 )
-from psycopt2d.utils import AUC_LOGGING_FILE_PATH, positive_rate_to_pred_probs
+from psycopt2d.utils import (
+    AUC_LOGGING_FILE_PATH,
+    PROJECT_ROOT,
+    positive_rate_to_pred_probs,
+)
 from psycopt2d.visualization import (
     plot_auc_by_time_from_first_visit,
     plot_feature_importances,
     plot_metric_by_time_until_diagnosis,
     plot_performance_by_calendar_time,
 )
-from psycopt2d.visualization.altair_utils import log_altair_to_wandb
 from psycopt2d.visualization.sens_over_time import plot_sensitivity_by_time_to_outcome
+from psycopt2d.visualization.utils import log_image_to_wandb


 def evaluate_model(
@@ -51,7 +54,10 @@ def evaluate_model(
         y_hat_prob_col_name (str): Column name containing pred_proba output
         run (wandb_run): WandB run to log to.
     """
-    y = eval_dataset[y_col_name]
+    SAVE_DIR = PROJECT_ROOT / ".tmp"  # pylint: disable=invalid-name
+    if not SAVE_DIR.exists():
+        SAVE_DIR.mkdir()
+    y = eval_dataset[y_col_name]  # pylint: disable=invalid-name
     y_hat_probs = eval_dataset[y_hat_prob_col_name]
     auc = round(roc_auc_score(y, y_hat_probs), 3)
     outcome_timestamps = eval_dataset[cfg.data.outcome_timestamp_col_name]
@@ -66,8 +72,6 @@ def evaluate_model(
         positive_rate_thresholds=cfg.evaluation.positive_rate_thresholds,
     )

-    alt.data_transformers.disable_max_rows()
-
     print(f"AUC: {auc}")

     # Log to wandb
@@ -116,6 +120,7 @@ def evaluate_model(
         column_names=feature_names,
         feature_importances=pipe["model"].feature_importances_,
         top_n_feature_importances=cfg.evaluation.top_n_feature_importances,
+        save_path=SAVE_DIR / "feature_importances.png",
     )
     plots.update(
         {"feature_importance": feature_importances_plot},
@@ -136,6 +141,7 @@ def evaluate_model(
             pred_proba_thresholds=pred_proba_thresholds,
             outcome_timestamps=outcome_timestamps,
             prediction_timestamps=pred_timestamps,
+            save_path=SAVE_DIR / "sensitivity_by_time_by_threshold.png",
         ),
         "auc_by_calendar_time": plot_performance_by_calendar_time(
             labels=y,
@@ -144,12 +150,14 @@ def evaluate_model(
             bin_period="M",
             metric_fn=roc_auc_score,
             y_title="AUC",
+            save_path=SAVE_DIR / "auc_by_calendar_time.png",
         ),
         "auc_by_time_from_first_visit": plot_auc_by_time_from_first_visit(
             labels=y,
             y_hat_probs=y_hat_probs,
             first_visit_timestamps=first_visit_timestamp,
             prediction_timestamps=pred_timestamps,
+            save_path=SAVE_DIR / "auc_by_time_from_first_visit.png",
         ),
         "f1_by_time_until_diagnosis": plot_metric_by_time_until_diagnosis(
             labels=y,
@@ -158,10 +166,11 @@ def evaluate_model(
             prediction_timestamps=pred_timestamps,
             metric_fn=f1_score,
             y_title="F1",
+            save_path=SAVE_DIR / "f1_by_time_until_diagnosis.png",
         ),
     },
     )

-    # Log all the figures to wandb
-    for chart_name, chart_obj in plots.items():
-        log_altair_to_wandb(chart=chart_obj, chart_name=chart_name, run=run)
+    ## Log all the figures to wandb
+    for chart_name, chart_path in plots.items():
+        log_image_to_wandb(chart_path=chart_path, chart_name=chart_name, run=run)
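The diff for src/psycopt2d/visualization/utils.py, which defines the new log_image_to_wandb helper, is not expanded on this page. Judging from the call site above, a minimal sketch could look like this; only the signature is confirmed by the call site, while the body is an assumption built on the public wandb.Image API:

# Hypothetical sketch of log_image_to_wandb; not the commit's actual code.
from pathlib import Path

import wandb


def log_image_to_wandb(chart_path: Path, chart_name: str, run) -> None:
    """Log a saved figure file to the given WandB run as an image."""
    run.log({chart_name: wandb.Image(str(chart_path))})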
7 changes: 5 additions & 2 deletions src/psycopt2d/model_performance/model_performance.py
@@ -7,12 +7,12 @@
 import pandas as pd
 from sklearn.metrics import (
     accuracy_score,
+    balanced_accuracy_score,
     confusion_matrix,
     f1_score,
     precision_score,
     recall_score,
     roc_auc_score,
-    balanced_accuracy_score
 )

 from psycopt2d.model_performance.utils import (
@@ -368,7 +368,10 @@ def compute_metrics(
     performance = {}

     performance["acc-overall"] = accuracy_score(labels, predicted)
-    performance["balanced_accuracy-overall"] = balanced_accuracy_score(labels, predicted)
+    performance["balanced_accuracy-overall"] = balanced_accuracy_score(
+        labels,
+        predicted,
+    )
     performance["f1_macro-overall"] = f1_score(labels, predicted, average="macro")
     performance["f1_micro-overall"] = f1_score(labels, predicted, average="micro")
     performance["precision_macro-overall"] = precision_score(
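Balanced accuracy is the mean of per-class recall, so unlike plain accuracy it is not inflated by a dominant negative class. A toy illustration of the two sklearn metrics (example data, not from the commit):

from sklearn.metrics import accuracy_score, balanced_accuracy_score

# Toy imbalanced labels: predicting the majority class everywhere scores
# 0.8 on accuracy but only 0.5 on balanced accuracy (mean per-class recall).
labels = [0, 0, 0, 0, 0, 0, 0, 0, 1, 1]
predicted = [0] * 10

print(accuracy_score(labels, predicted))           # 0.8
print(balanced_accuracy_score(labels, predicted))  # 0.5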
72 changes: 0 additions & 72 deletions src/psycopt2d/visualization/altair_utils.py

This file was deleted.

70 changes: 40 additions & 30 deletions src/psycopt2d/visualization/base_charts.py
@@ -1,58 +1,68 @@
"""Base charts."""
from collections.abc import Iterable
from typing import Optional
from pathlib import Path
from typing import Optional, Union

import altair as alt
import matplotlib.pyplot as plt
import pandas as pd


def plot_bar_chart(
def plot_basic_chart(
x_values: Iterable,
y_values: Iterable,
x_title: str,
y_title: str,
plot_type: Optional[Union[list[str], str]],
sort_x: Optional[Iterable[int]] = None,
sort_y: Optional[Iterable[int]] = None,
) -> alt.Chart:
"""Plot a basic bar chart with Altair.
fig_size: Optional[tuple] = (10, 10),
save_path: Optional[Path] = None,
) -> Union[None, Path]:
"""Plot a simple chart using matplotlib. Options for sorting the x and y
axis are available.
Args:
x_values (Iterable): The x values of the bar chart.
y_values (Iterable): The y values of the bar chart.
x_title (str): title of x axis
y_title (str): title of y axis
plot_type (Optional[Union[List[str], str]], optional): type of plots.
Options are combinations of ["bar", "hbar", "line", "scatter"] Defaults to "bar".
sort_x (Optional[Iterable[int]], optional): order of values on the x-axis. Defaults to None.
sort_y (Optional[Iterable[int]], optional): order of values on the y-axis. Defaults to None.
fig_size (Optional[tuple], optional): figure size. Defaults to None.
save_path (Optional[Path], optional): path to save figure. Defaults to None.
Returns:
alt.Chart: Altair chart
Union[None, Path]: None if save_path is None, else path to saved figure
"""
if isinstance(plot_type, str):
plot_type = [plot_type]

df = pd.DataFrame({"x": x_values, "y": y_values, "sort": sort_x})
df = pd.DataFrame(
{"x": x_values, "y": y_values, "sort_x": sort_x, "sort_y": sort_y},
)

if sort_x is not None:
x_axis = alt.X(
"x",
axis=alt.Axis(title=x_title),
sort=alt.SortField(field="sort"),
)
else:
x_axis = alt.X("x", axis=alt.Axis(title=x_title))
df = df.sort_values(by=["sort_x"])

if sort_y is not None:
y_axis = alt.Y(
"y",
axis=alt.Axis(title=y_title),
sort=alt.SortField(field="sort"),
)
else:
y_axis = alt.Y("y", axis=alt.Axis(title=y_title))

return (
alt.Chart(df)
.mark_bar()
.encode(
x=x_axis,
y=y_axis,
)
)
df = df.sort_values(by=["sort_y"])

plt.figure(figsize=fig_size)
if "bar" in plot_type:
plt.bar(df["x"], df["y"])
if "hbar" in plot_type:
plt.barh(df["x"], df["y"])
if "line" in plot_type:
plt.plot(df["x"], df["y"])
if "scatter" in plot_type:
plt.scatter(df["x"], df["y"])

plt.xlabel(x_title)
plt.ylabel(y_title)

if save_path is not None:
plt.savefig(save_path)
plt.close()
return save_path
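As a usage sketch of the new function (example values are hypothetical, not from the commit): plot_type accepts a single string or a list, so several geometries can be layered in one figure:

from pathlib import Path

from psycopt2d.visualization.base_charts import plot_basic_chart

Path(".tmp").mkdir(exist_ok=True)  # plt.savefig needs an existing directory

# Layer a line over bars; the figure is written to disk and its path returned.
saved = plot_basic_chart(
    x_values=["jan", "feb", "mar"],
    y_values=[0.71, 0.74, 0.69],
    x_title="Month",
    y_title="AUC",
    plot_type=["bar", "line"],
    save_path=Path(".tmp/auc_by_month.png"),
)
print(saved)  # .tmp/auc_by_month.png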
22 changes: 13 additions & 9 deletions src/psycopt2d/visualization/feature_importance.py
@@ -1,19 +1,20 @@
"""Generate feature importances chart."""

from collections.abc import Iterable
from typing import Union
from pathlib import Path
from typing import Optional, Union

import altair as alt
import numpy as np

from psycopt2d.visualization.base_charts import plot_bar_chart
from psycopt2d.visualization.base_charts import plot_basic_chart


def plot_feature_importances(
column_names: Iterable[str],
feature_importances: Union[list[float], np.ndarray],
top_n_feature_importances: int,
) -> alt.Chart:
save_path: Optional[Path] = None,
) -> Union[None, Path]:
"""Plots feature importances.
Sklearn's standard feature importance metric is "gain"/information gain,
@@ -26,9 +27,10 @@ def plot_feature_importances(
         column_names (Iterable[str]): Column/feature names
         feature_importances (Iterable[str]): Feature importances
         top_n_feature_importances (int): Top n features to plot
+        save_path (Optional[Path], optional): Path to save the plot. Defaults to None.

     Returns:
-        alt.Chart: Horizontal barchart of feature importances
+        Union[None, Path]: Path to the saved plot if save_path is not None, else None
     """

     feature_importances = np.array(feature_importances)
@@ -38,10 +40,12 @@ def plot_feature_importances(
     feature_names = np.array(column_names)[sorted_idx][:top_n_feature_importances]
     feature_importances = feature_importances[sorted_idx][:top_n_feature_importances]

-    return plot_bar_chart(
-        x_values=feature_importances,
-        y_values=feature_names,
+    return plot_basic_chart(
+        x_values=feature_names,
+        y_values=feature_importances,
         x_title="Feature importance (gain)",
         y_title="Feature name",
-        sort_y=np.arange(len(feature_importances)),
+        sort_x=np.arange(len(feature_importances)),
+        plot_type="hbar",
+        save_path=save_path,
     )
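And a usage sketch for the updated signature (toy importances, not from the commit; in evaluate_model above, the real values come from pipe["model"].feature_importances_):

from pathlib import Path

import numpy as np

from psycopt2d.visualization.feature_importance import plot_feature_importances

# Toy importances; keeps the top 2 features and saves a horizontal bar chart.
# Assumes the .tmp directory already exists.
plot_feature_importances(
    column_names=["age", "hba1c_mean", "n_visits"],
    feature_importances=np.array([0.2, 0.5, 0.3]),
    top_n_feature_importances=2,
    save_path=Path(".tmp/feature_importances.png"),
)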
