This repository has been archived by the owner on May 1, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #220 from Aarhus-Psychiatry-Research/HLasse/Plot-b…
…ar-chart-in-matplotlib Migrate from Altair to matplotlib
- Loading branch information
Showing
12 changed files
with
427 additions
and
241 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,58 +1,68 @@ | ||
"""Base charts.""" | ||
from collections.abc import Iterable | ||
from typing import Optional | ||
from pathlib import Path | ||
from typing import Optional, Union | ||
|
||
import altair as alt | ||
import matplotlib.pyplot as plt | ||
import pandas as pd | ||
|
||
|
||
def plot_bar_chart( | ||
def plot_basic_chart( | ||
x_values: Iterable, | ||
y_values: Iterable, | ||
x_title: str, | ||
y_title: str, | ||
plot_type: Optional[Union[list[str], str]], | ||
sort_x: Optional[Iterable[int]] = None, | ||
sort_y: Optional[Iterable[int]] = None, | ||
) -> alt.Chart: | ||
"""Plot a basic bar chart with Altair. | ||
fig_size: Optional[tuple] = (10, 10), | ||
save_path: Optional[Path] = None, | ||
) -> Union[None, Path]: | ||
"""Plot a simple chart using matplotlib. Options for sorting the x and y | ||
axis are available. | ||
Args: | ||
x_values (Iterable): The x values of the bar chart. | ||
y_values (Iterable): The y values of the bar chart. | ||
x_title (str): title of x axis | ||
y_title (str): title of y axis | ||
plot_type (Optional[Union[List[str], str]], optional): type of plots. | ||
Options are combinations of ["bar", "hbar", "line", "scatter"] Defaults to "bar". | ||
sort_x (Optional[Iterable[int]], optional): order of values on the x-axis. Defaults to None. | ||
sort_y (Optional[Iterable[int]], optional): order of values on the y-axis. Defaults to None. | ||
fig_size (Optional[tuple], optional): figure size. Defaults to None. | ||
save_path (Optional[Path], optional): path to save figure. Defaults to None. | ||
Returns: | ||
alt.Chart: Altair chart | ||
Union[None, Path]: None if save_path is None, else path to saved figure | ||
""" | ||
if isinstance(plot_type, str): | ||
plot_type = [plot_type] | ||
|
||
df = pd.DataFrame({"x": x_values, "y": y_values, "sort": sort_x}) | ||
df = pd.DataFrame( | ||
{"x": x_values, "y": y_values, "sort_x": sort_x, "sort_y": sort_y}, | ||
) | ||
|
||
if sort_x is not None: | ||
x_axis = alt.X( | ||
"x", | ||
axis=alt.Axis(title=x_title), | ||
sort=alt.SortField(field="sort"), | ||
) | ||
else: | ||
x_axis = alt.X("x", axis=alt.Axis(title=x_title)) | ||
df = df.sort_values(by=["sort_x"]) | ||
|
||
if sort_y is not None: | ||
y_axis = alt.Y( | ||
"y", | ||
axis=alt.Axis(title=y_title), | ||
sort=alt.SortField(field="sort"), | ||
) | ||
else: | ||
y_axis = alt.Y("y", axis=alt.Axis(title=y_title)) | ||
|
||
return ( | ||
alt.Chart(df) | ||
.mark_bar() | ||
.encode( | ||
x=x_axis, | ||
y=y_axis, | ||
) | ||
) | ||
df = df.sort_values(by=["sort_y"]) | ||
|
||
plt.figure(figsize=fig_size) | ||
if "bar" in plot_type: | ||
plt.bar(df["x"], df["y"]) | ||
if "hbar" in plot_type: | ||
plt.barh(df["x"], df["y"]) | ||
if "line" in plot_type: | ||
plt.plot(df["x"], df["y"]) | ||
if "scatter" in plot_type: | ||
plt.scatter(df["x"], df["y"]) | ||
|
||
plt.xlabel(x_title) | ||
plt.ylabel(y_title) | ||
|
||
if save_path is not None: | ||
plt.savefig(save_path) | ||
plt.close() | ||
return save_path |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.