Skip to content

Commit

Permalink
feat: support seaborn bar and count plot (#12)
Browse files Browse the repository at this point in the history
  • Loading branch information
SaaiVenkat authored Feb 20, 2024
1 parent 4b5387e commit fd622bd
Show file tree
Hide file tree
Showing 12 changed files with 957 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -161,6 +161,7 @@ cython_debug/

# Vscode config files
# .vscode/
.vscode/launch.json

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@
let maidr = {
"type": "bar",
"title": "The Number of Tips by Day",
"selector": "TODO: Enter your bar plot selector here",
"selector": "path[clip-path='url(#p6c5d07f9e0)']",
"axes": {
"x": {
"label": "Day",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def plot():
def main():
bar_plot = plot()
bar_maidr = maidr.bar(bar_plot)
bar_maidr.save(get_filepath("example_bar_plot.html"))
bar_maidr.save(get_filepath("example_mpl_bar_plot.html"))


if __name__ == "__main__":
Expand Down
397 changes: 397 additions & 0 deletions example/bar/seaborn/example_sns_bar_plot.html

Large diffs are not rendered by default.

38 changes: 38 additions & 0 deletions example/bar/seaborn/example_sns_bar_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import os

import matplotlib.pyplot as plt
import maidr
import seaborn as sns


def get_filepath(filename: str) -> str:
current_file_path = os.path.abspath(__file__)
directory = os.path.dirname(current_file_path)
return os.path.join(directory, filename)


def plot():
# Load the penguins dataset
penguins = sns.load_dataset("penguins")

# Create a bar plot showing the average body mass of penguins by species
plt.figure(figsize=(10, 6))
b_plot = sns.barplot(
x="species", y="body_mass_g", data=penguins, errorbar="sd", palette="Blues_d"
)
b_plot.bar
plt.title("Average Body Mass of Penguins by Species")
plt.xlabel("Species")
plt.ylabel("Body Mass (g)")

return b_plot


def main():
bar_plot = plot()
bar_maidr = maidr.bar(bar_plot)
bar_maidr.save(get_filepath("example_sns_bar_plot.html"))


if __name__ == "__main__":
main()
364 changes: 364 additions & 0 deletions example/count/seaborn/example_sns_count_plot.html

Large diffs are not rendered by default.

34 changes: 34 additions & 0 deletions example/count/seaborn/example_sns_count_plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os

import matplotlib.pyplot as plt
import maidr
import seaborn as sns


def get_filepath(filename: str) -> str:
current_file_path = os.path.abspath(__file__)
directory = os.path.dirname(current_file_path)
return os.path.join(directory, filename)


def plot():
# Load the Titanic dataset
titanic = sns.load_dataset("titanic")

# Create a countplot
count_plot = sns.countplot(x="class", data=titanic)

# Set the title and show the plot
plt.title("Passenger Class Distribution on the Titanic")

return count_plot


def main():
count_plot = plot()
count_maidr = maidr.count(count_plot)
count_maidr.save(get_filepath("example_sns_count_plot.html"))


if __name__ == "__main__":
main()
4 changes: 2 additions & 2 deletions maidr/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# __version__ will be automatically updated by python-semantic-release
__version__ = "0.0.1"

from .maidr import bar
from .maidr import bar, count

__all__ = ["bar"]
__all__ = ["bar", "count"]
2 changes: 1 addition & 1 deletion maidr/core/maidr.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def __create_html(self) -> str:

def __unflatten_maidr(self) -> dict | list[dict]:
"""
Unflattens the MAIDR data into a dictionary format.
Unflatten the MAIDR data into a dictionary format.
Returns
-------
Expand Down
71 changes: 59 additions & 12 deletions maidr/core/plot/bar_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,22 +109,69 @@ def __extract_data(self) -> list:
Raises
------
TypeError
ExtractionError
If the plot object is incompatible for data extraction.
"""
plot = self.plot
data = None

if isinstance(plot, BarContainer) and isinstance(plot.datavalues, Iterable):
bc_data = []
for value in plot.datavalues:
if isinstance(value, np.integer):
bc_data.append(int(value))
elif isinstance(value, np.floating):
bc_data.append(float(value))
else:
bc_data.append(value)
data = bc_data
else:
if isinstance(plot, Axes):
plot = BarPlotData.__extract_bar_container(plot)
if isinstance(plot, BarContainer):
data = BarPlotData.__extract_bar_container_data(plot)

if data is None:
raise ExtractionError(self.type, self.plot)

return data

@staticmethod
def __extract_bar_container_data(plot: BarContainer) -> list | None:
"""
Extracts numerical data from the specified BarContainer object if possible.
Parameters
----------
plot : BarContainer
The BarContainer from which to extract the data.
Returns
-------
list | None
A list containing the numerical data extracted from the BarContainer, or None
if the plot does not contain valid data values or is not a BarContainer.
"""
if not isinstance(plot.datavalues, Iterable):
return None

data = []
for value in plot.datavalues:
if isinstance(value, np.integer):
data.append(int(value))
elif isinstance(value, np.floating):
data.append(float(value))
else:
data.append(value)
return data

@staticmethod
def __extract_bar_container(plot: Axes) -> BarContainer | None:
"""
Extracts the BarContainer from the given Axes object if possible.
Parameters
----------
plot : Axes
The Axes object to search for a BarContainer.
Returns
-------
BarContainer | None
The first BarContainer found within the given Axes object, or None if no
BarContainer is present.
"""
# TODO
# If the Axes contains multiple bar plots, track and extract the correct one.
for container in plot.containers:
if isinstance(container, BarContainer):
return container
53 changes: 50 additions & 3 deletions maidr/maidr.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,21 @@
from __future__ import annotations

from matplotlib.axes import Axes
from matplotlib.container import BarContainer

from maidr.core.enum.plot_type import PlotType
from maidr.core.maidr import Maidr
from maidr.utils.figure_manager import FigureManager


def bar(plot: BarContainer) -> Maidr:
def bar(plot: Axes | BarContainer) -> Maidr:
"""
Create a Maidr object for a bar plot.
Parameters
----------
plot : BarContainer
The bar plot for which a Maidr object is to be created.
plot : Axes | BarContainer
The bar plot for which a Maidr object is to be created.
Returns
-------
Expand Down Expand Up @@ -46,5 +47,51 @@ def bar(plot: BarContainer) -> Maidr:
return FigureManager.create_maidr(fig, plot, plot_type)


def count(plot: Axes | BarContainer) -> Maidr:
"""
Create a Maidr object for a count plot.
Parameters
----------
plot : Axes | BarContainer
The count plot for which a Maidr object is to be created.
Returns
-------
Maidr
The created Maidr object representing the count plot.
Raises
------
ValueError
If the input `plot` is missing the `matplotlib.figure.Figure` and
`matplotlib.figure.Axes`.
TypeError
If the input `plot` is not a valid count plot.
See Also
--------
Maidr : The core class encapsulating the plot with its MAIDR structure.
bar : Function to create a Maidr object for matplotlib bar plots, usable as an
alternative to `count()`.
Note
----
Since a count plot is a specific case of a bar plot, this function internally uses
the `bar()` function to process the plot. The `count()` function is provided as a
convenience to align with the `seaborn.countplot()` method.
Examples
--------
>>> import seaborn as sns
>>> import maidr
>>> data = sns.load_dataset("titanic") # Load the dataset
>>> count_plot = sns.countplot(x="class", data=data) # Generate a count plot
>>> count_maidr = maidr.count(count_plot) # Convert the plot to a Maidr object
>>> count_maidr.save("maidr_count_plot.html") # Save the plot to an HTML file
"""
return bar(plot)


def close() -> None:
pass
13 changes: 9 additions & 4 deletions maidr/utils/figure_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from typing import Any

from matplotlib.axes import Axes
from matplotlib.container import BarContainer
from matplotlib.figure import Figure

Expand Down Expand Up @@ -76,13 +77,13 @@ def create_maidr(fig: Figure | None, plot: Any, plot_type: list[PlotType]) -> Ma
return Maidr(fig, maidr_data)

@staticmethod
def get_figure(artist: BarContainer | None) -> Figure | None:
def get_figure(artist: Axes | BarContainer | None) -> Figure | None:
"""
Retrieves the `Figure` object associated with a given matplotlib `Artist`.
Parameters
----------
artist : BarContainer | None
artist : Axes | BarContainer | None
The artist for which to retrieve the figure.
Returns
Expand All @@ -95,8 +96,12 @@ def get_figure(artist: BarContainer | None) -> Figure | None:
return None

fig = None
# bar container - get figure from the first occurrence of any artist
if isinstance(artist, BarContainer):
# axes - get figure from the artist
if isinstance(artist, Axes):
fig = artist.get_figure()

# bar container - get figure from the first occurrence of any child artist
elif isinstance(artist, BarContainer):
fig = next(
(
child_artist.get_figure()
Expand Down

0 comments on commit fd622bd

Please sign in to comment.