"""Matplotlib based solara components for visualization MESA spaces and plots."""
from __future__ import annotations
import warnings
from collections.abc import Callable
import matplotlib.pyplot as plt
import solara
from matplotlib.figure import Figure
from mesa.visualization.mpl_space_drawing import draw_space
from mesa.visualization.utils import update_counter
[docs]
def make_space_matplotlib(*args, **kwargs):
    """Deprecated alias of ``make_mpl_space_component``.

    Emits a ``DeprecationWarning`` and then forwards every positional and
    keyword argument unchanged to ``make_mpl_space_component``, returning
    whatever that call returns.
    """
    message = "make_space_matplotlib has been renamed to make_mpl_space_component"
    # stacklevel=2 attributes the warning to the caller rather than this shim.
    warnings.warn(message, category=DeprecationWarning, stacklevel=2)
    component = make_mpl_space_component(*args, **kwargs)
    return component
[docs]
def make_mpl_space_component(
    agent_portrayal: Callable | None = None,
    propertylayer_portrayal: dict | None = None,
    post_process: Callable | None = None,
    **space_drawing_kwargs,
) -> Callable:
    """Create a Matplotlib-based space visualization component.

    Args:
        agent_portrayal: Function to portray agents. When ``None``, a default
            portrayal that returns an empty dict for every agent is used.
        propertylayer_portrayal: Dictionary of PropertyLayer portrayal
            specifications.
        post_process: A callable that will be called with the Axes instance.
            Allows for fine tuning plots (e.g., control ticks).
        space_drawing_kwargs: Additional keyword arguments to be passed on to
            the underlying space drawer function. See the functions for
            drawing the various spaces for further details.

    ``agent_portrayal`` is called with an agent and should return a dict. Valid
    fields in this dict are "color", "size", "marker", "zorder", "alpha",
    "linewidths", and "edgecolors". Other fields are ignored and will result in
    a user warning.

    Returns:
        function: A function that takes a model and creates a SpaceMatplotlib
        component for it.
    """
    if agent_portrayal is None:
        # Fall back to an empty portrayal dict for every agent.
        def agent_portrayal(a):
            return {}

    def MakeSpaceMatplotlib(model):
        # The configuration captured above is bound here; only the model varies.
        return SpaceMatplotlib(
            model,
            agent_portrayal,
            propertylayer_portrayal,
            post_process=post_process,
            **space_drawing_kwargs,
        )

    return MakeSpaceMatplotlib
@solara.component
def SpaceMatplotlib(
    model,
    agent_portrayal,
    propertylayer_portrayal,
    dependencies: list | None = None,
    post_process: Callable | None = None,
    **space_drawing_kwargs,
):
    """Create a Matplotlib-based space visualization component.

    Args:
        model: The model instance. Its space is read from ``model.grid`` and,
            if that attribute is missing or ``None``, from ``model.space``.
        agent_portrayal: Callable handed to ``draw_space`` to portray agents.
        propertylayer_portrayal: PropertyLayer portrayal specification handed
            to ``draw_space``.
        dependencies: Optional dependencies forwarded to
            ``solara.FigureMatplotlib``.
        post_process: Optional callable invoked with the Axes instance after
            the space has been drawn.
        space_drawing_kwargs: Additional keyword arguments forwarded to
            ``draw_space``.
    """
    # NOTE(review): presumably reading the counter makes the component
    # re-render on model updates - confirm against mesa.visualization.utils.
    update_counter.get()

    # Prefer ``model.grid``; fall back to ``model.space`` when it is absent/None.
    space = getattr(model, "grid", None)
    if space is None:
        space = getattr(model, "space", None)

    fig = Figure()
    ax = fig.add_subplot()

    draw_space(
        space,
        agent_portrayal,
        propertylayer_portrayal=propertylayer_portrayal,
        ax=ax,
        **space_drawing_kwargs,
    )

    # Give the user the last word on the axes before the figure is rendered.
    if post_process is not None:
        post_process(ax)

    solara.FigureMatplotlib(
        fig, format="png", bbox_inches="tight", dependencies=dependencies
    )
[docs]
def make_plot_measure(*args, **kwargs):
    """Deprecated alias of ``make_mpl_plot_component``.

    Emits a ``DeprecationWarning`` and forwards all positional and keyword
    arguments unchanged to ``make_mpl_plot_component``, returning its result.
    """
    # The message names the function this shim actually delegates to, matching
    # the wording used by the ``make_space_matplotlib`` shim.
    warnings.warn(
        "make_plot_measure has been renamed to make_mpl_plot_component",
        DeprecationWarning,
        stacklevel=2,
    )
    return make_mpl_plot_component(*args, **kwargs)
[docs]
def make_mpl_plot_component(
    measure: str | dict[str, str] | list[str] | tuple[str, ...],
    post_process: Callable | None = None,
    save_format: str = "png",
) -> Callable:
    """Create a plotting function for a specified measure.

    Args:
        measure: Measure(s) to plot. A single name, a dict mapping measure
            names to line colors, or a list/tuple of measure names.
        post_process: A user-specified callable to do post-processing, called
            with the Axes instance.
        save_format: Save format of the figure in the solara backend.

    Returns:
        function: A function that takes a model and creates a PlotMatplotlib
        component for it.
    """

    def MakePlotMatplotlib(model):
        # The configuration captured above is bound here; only the model varies.
        return PlotMatplotlib(
            model, measure, post_process=post_process, save_format=save_format
        )

    return MakePlotMatplotlib
@solara.component
def PlotMatplotlib(
    model,
    measure,
    dependencies: list | None = None,
    post_process: Callable | None = None,
    save_format="png",
):
    """Create a Matplotlib-based plot for a measure or measures.

    Args:
        model (mesa.Model): The model instance. Its
            ``datacollector.get_model_vars_dataframe()`` supplies the data.
        measure (str | dict[str, str] | list[str] | tuple[str, ...]): Measure(s)
            to plot. A str plots one column and uses it as the y label; a dict
            maps column names to line colors; a list or tuple plots each named
            column. The dict and list/tuple forms add a legend.
        dependencies (list | None): Optional dependencies for the plot,
            forwarded to ``solara.FigureMatplotlib``.
        post_process: A user-specified callable to do post-processing, called
            with the Axes instance after all defaults have been applied.
        save_format: Format used for saving the figure.

    The figure is rendered through ``solara.FigureMatplotlib``; nothing is
    returned explicitly.
    """
    # NOTE(review): presumably reading the counter makes the component
    # re-render on model updates - confirm against mesa.visualization.utils.
    update_counter.get()

    fig = Figure()
    ax = fig.subplots()
    df = model.datacollector.get_model_vars_dataframe()

    if isinstance(measure, str):
        ax.plot(df.loc[:, measure])
        ax.set_ylabel(measure)
    elif isinstance(measure, dict):
        for m, color in measure.items():
            ax.plot(df.loc[:, m], label=m, color=color)
        ax.legend(loc="best")
    elif isinstance(measure, list | tuple):
        for m in measure:
            ax.plot(df.loc[:, m], label=m)
        ax.legend(loc="best")

    # Apply the default x-axis styling first ...
    ax.set_xlabel("Step")
    # Set integer x axis
    ax.xaxis.set_major_locator(plt.MaxNLocator(integer=True))

    # ... and run the user hook last, so customisation of the x label or tick
    # locator is not overwritten by the defaults above.
    if post_process is not None:
        post_process(ax)

    solara.FigureMatplotlib(
        fig, format=save_format, bbox_inches="tight", dependencies=dependencies
    )