Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

expand ax.scatter kwargs that can be used #2445

Merged
merged 9 commits into from
Oct 31, 2024
16 changes: 8 additions & 8 deletions docs/overview.md
Original file line number Diff line number Diff line change
Expand Up @@ -177,20 +177,20 @@ def agent_portrayal(agent):

model_params = {
"N": {
"type": "SliderInt",
"value": 50,
"label": "Number of agents:",
"min": 10,
"max": 100,
"step": 1,
"type": "SliderInt",
"value": 50,
"label": "Number of agents:",
"min": 10,
"max": 100,
"step": 1,
}
}

page = SolaraViz(
MyModel,
[
make_space_component(agent_portrayal),
make_plot_component("mean_age")
make_space_component(agent_portrayal),
make_plot_component("mean_age")
],
model_params=model_params
)
Expand Down
24 changes: 12 additions & 12 deletions mesa/examples/basic/boltzmann_wealth_model/app.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,10 @@
from mesa.examples.basic.boltzmann_wealth_model.model import BoltzmannWealthModel
from mesa.visualization import (
SolaraViz,
make_plot_component,
make_space_component,
)
from mesa.visualization import SolaraViz, make_plot_component, make_space_component

Check warning on line 2 in mesa/examples/basic/boltzmann_wealth_model/app.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/boltzmann_wealth_model/app.py#L2

Added line #L2 was not covered by tests


def agent_portrayal(agent):
size = 10
color = "tab:red"
if agent.wealth > 0:
size = 50
color = "tab:blue"
return {"size": size, "color": color}
color = agent.wealth # we are using a colormap to translate wealth to color
return {"color": color}

Check warning on line 7 in mesa/examples/basic/boltzmann_wealth_model/app.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/boltzmann_wealth_model/app.py#L6-L7

Added lines #L6 - L7 were not covered by tests


model_params = {
Expand All @@ -28,6 +20,11 @@
"height": 10,
}


def post_process(ax):
ax.get_figure().colorbar(ax.collections[0], label="wealth", ax=ax)

Check warning on line 25 in mesa/examples/basic/boltzmann_wealth_model/app.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/boltzmann_wealth_model/app.py#L24-L25

Added lines #L24 - L25 were not covered by tests


# Create initial model instance
model1 = BoltzmannWealthModel(50, 10, 10)

Expand All @@ -36,7 +33,10 @@
# Under the hood these are just classes that receive the model instance.
# You can also author your own visualization elements, which can also be functions
# that receive the model instance and return a valid solara component.
SpaceGraph = make_space_component(agent_portrayal)

SpaceGraph = make_space_component(

Check warning on line 37 in mesa/examples/basic/boltzmann_wealth_model/app.py

View check run for this annotation

Codecov / codecov/patch

mesa/examples/basic/boltzmann_wealth_model/app.py#L37

Added line #L37 was not covered by tests
agent_portrayal, cmap="viridis", vmin=0, vmax=10, post_process=post_process
)
GiniPlot = make_plot_component("Gini")

# Create the SolaraViz page. This will automatically create a server and display the
Expand Down
34 changes: 25 additions & 9 deletions mesa/visualization/components/matplotlib.py
Original file line number Diff line number Diff line change
Expand Up @@ -309,6 +309,7 @@ def draw_orthogonal_grid(
agent_portrayal: Callable,
ax: Axes | None = None,
draw_grid: bool = True,
**kwargs,
):
"""Visualize a orthogonal grid.

Expand All @@ -317,6 +318,7 @@ def draw_orthogonal_grid(
agent_portrayal: a callable that is called with the agent and returns a dict
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
draw_grid: whether to draw the grid
kwargs: additional keyword arguments passed to ax.scatter

Returns:
Returns the Axes object with the plot drawn onto it.
Expand All @@ -333,7 +335,7 @@ def draw_orthogonal_grid(
arguments = collect_agent_data(space, agent_portrayal, size=s_default)

# plot the agents
_scatter(ax, arguments)
_scatter(ax, arguments, **kwargs)

# further styling
ax.set_xlim(-0.5, space.width - 0.5)
Expand All @@ -354,6 +356,7 @@ def draw_hex_grid(
agent_portrayal: Callable,
ax: Axes | None = None,
draw_grid: bool = True,
**kwargs,
):
"""Visualize a hex grid.

Expand All @@ -362,6 +365,7 @@ def draw_hex_grid(
agent_portrayal: a callable that is called with the agent and returns a dict
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
draw_grid: whether to draw the grid
kwargs: additional keyword arguments passed to ax.scatter

Returns:
Returns the Axes object with the plot drawn onto it.
Expand Down Expand Up @@ -394,7 +398,7 @@ def draw_hex_grid(
arguments["loc"] = loc

# plot the agents
_scatter(ax, arguments)
_scatter(ax, arguments, **kwargs)

# further styling and adding of grid
ax.set_xlim(-1, space.width + 0.5)
Expand Down Expand Up @@ -443,6 +447,7 @@ def draw_network(
draw_grid: bool = True,
layout_alg=nx.spring_layout,
layout_kwargs=None,
**kwargs,
):
"""Visualize a network space.

Expand All @@ -453,6 +458,7 @@ def draw_network(
draw_grid: whether to draw the grid
layout_alg: a networkx layout algorithm or other callable with the same behavior
layout_kwargs: a dictionary of keyword arguments for the layout algorithm
kwargs: additional keyword arguments passed to ax.scatter

Returns:
Returns the Axes object with the plot drawn onto it.
Expand Down Expand Up @@ -488,7 +494,7 @@ def draw_network(
arguments["loc"] = pos[arguments["loc"]]

# plot the agents
_scatter(ax, arguments)
_scatter(ax, arguments, **kwargs)

# further styling
ax.set_axis_off()
Expand All @@ -506,14 +512,15 @@ def draw_network(


def draw_continuous_space(
space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None
space: ContinuousSpace, agent_portrayal: Callable, ax: Axes | None = None, **kwargs
):
"""Visualize a continuous space.

Args:
space: the space to visualize
agent_portrayal: a callable that is called with the agent and returns a dict
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
kwargs: additional keyword arguments passed to ax.scatter

Returns:
Returns the Axes object with the plot drawn onto it.
Expand All @@ -536,7 +543,7 @@ def draw_continuous_space(
arguments = collect_agent_data(space, agent_portrayal, size=s_default)

# plot the agents
_scatter(ax, arguments)
_scatter(ax, arguments, **kwargs)

# further visual styling
border_style = "solid" if not space.torus else (0, (5, 10))
Expand All @@ -552,14 +559,15 @@ def draw_continuous_space(


def draw_voroinoi_grid(
space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None
space: VoronoiGrid, agent_portrayal: Callable, ax: Axes | None = None, **kwargs
):
"""Visualize a voronoi grid.

Args:
space: the space to visualize
agent_portrayal: a callable that is called with the agent and returns a dict
ax: a Matplotlib Axes instance. If none is provided a new figure and ax will be created using plt.subplots
kwargs: additional keyword arguments passed to ax.scatter

Returns:
Returns the Axes object with the plot drawn onto it.
Expand Down Expand Up @@ -589,7 +597,7 @@ def draw_voroinoi_grid(
ax.set_xlim(x_min - x_padding, x_max + x_padding)
ax.set_ylim(y_min - y_padding, y_max + y_padding)

_scatter(ax, arguments)
_scatter(ax, arguments, **kwargs)

for cell in space.all_cells:
polygon = cell.properties["polygon"]
Expand All @@ -604,8 +612,15 @@ def draw_voroinoi_grid(
return ax


def _scatter(ax: Axes, arguments):
"""Helper function for plotting the agents."""
def _scatter(ax: Axes, arguments, **kwargs):
"""Helper function for plotting the agents.

Args:
ax: a Matplotlib Axes instance
arguments: the agents specific arguments for platting
kwargs: additional keyword arguments for ax.scatter

"""
loc = arguments.pop("loc")

x = loc[:, 0]
Expand All @@ -624,6 +639,7 @@ def _scatter(ax: Axes, arguments):
marker=mark,
zorder=z_order,
**{k: v[logical] for k, v in arguments.items()},
**kwargs,
)


Expand Down
Loading