"""Space rendering module for Mesa visualizations.
This module provides functionality to render Mesa model spaces with different
backends, supporting various space types and visualization components.
"""
from __future__ import annotations
import warnings
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal
if TYPE_CHECKING:
from mesa.visualization.components import PropertyLayerStyle
import altair as alt
import pandas as pd
import mesa
from mesa.discrete_space import (
OrthogonalMooreGrid,
OrthogonalVonNeumannGrid,
VoronoiGrid,
)
from mesa.experimental.continuous_space import ContinuousSpace
from mesa.visualization.backends import AltairBackend, MatplotlibBackend
from mesa.visualization.space_drawers import (
ContinuousSpaceDrawer,
HexSpaceDrawer,
NetworkSpaceDrawer,
OrthogonalSpaceDrawer,
VoronoiSpaceDrawer,
)
# Aliases used by SpaceRenderer for isinstance() dispatch on the model's space.
# OrthogonalGrid is a runtime union (types.UnionType), which isinstance()
# accepts directly on Python 3.10+.
OrthogonalGrid = OrthogonalMooreGrid | OrthogonalVonNeumannGrid
HexGrid = mesa.discrete_space.HexGrid
Network = mesa.discrete_space.Network
class SpaceRenderer:
    """Renders Mesa spaces using different visualization backends.

    Supports multiple space types and backends for flexible visualization
    of agent-based models. The usual workflow is to configure the renderer with
    ``setup_structure()``, ``setup_agents()`` and ``setup_property_layer()``,
    call ``render()``, and read the result from the ``canvas`` property.
    """

    def __init__(
        self,
        model: mesa.Model,
        backend: Literal["matplotlib", "altair"] | None = "matplotlib",
    ):
        """Initialize the space renderer.

        Args:
            model (mesa.Model): The Mesa model to render. The space is taken from
                ``model.grid`` if the model has that attribute, otherwise from
                ``model.space`` (``None`` if neither exists).
            backend (Literal["matplotlib", "altair"] | None): The visualization
                backend to use. ``None`` is accepted by the signature but is not
                a supported backend and raises ``ValueError``.

        Raises:
            ValueError: If the model's space type or the backend is unsupported.
        """
        self.space = getattr(model, "grid", getattr(model, "space", None))
        self.space_drawer = self._get_space_drawer()

        # Cached drawings. ``None`` means "not drawn yet"; the setup_* methods
        # reset them so that render() redraws with the new configuration.
        self.space_mesh = None
        self.agent_mesh = None
        self.property_layer_mesh = None

        # Configuration recorded by the setup_* methods and consumed by draw_*.
        self.draw_agent_kwargs = {}
        self.draw_space_kwargs = {}
        self.agent_portrayal = None
        self.property_layer_portrayal = None
        self.post_process_func = None
        # Keep track of whether post-processing has been applied
        # to avoid multiple applications on the same axis.
        self._post_process_applied = False

        self.backend = backend
        if backend == "matplotlib":
            self.backend_renderer = MatplotlibBackend(self.space_drawer)
        elif backend == "altair":
            self.backend_renderer = AltairBackend(self.space_drawer)
        else:
            raise ValueError(f"Unsupported backend: {backend}")
        self.backend_renderer.initialize_canvas()

    def _get_space_drawer(self):
        """Get the space drawer matching the type of ``self.space``.

        Returns:
            Space drawer instance for the model's space type.

        Raises:
            ValueError: If the space type is not supported.
        """
        if isinstance(self.space, HexGrid):
            return HexSpaceDrawer(self.space)
        elif isinstance(self.space, OrthogonalGrid):
            return OrthogonalSpaceDrawer(self.space)
        elif isinstance(self.space, ContinuousSpace):
            return ContinuousSpaceDrawer(self.space)
        elif isinstance(self.space, VoronoiGrid):
            return VoronoiSpaceDrawer(self.space)
        elif isinstance(self.space, Network):
            return NetworkSpaceDrawer(self.space)
        raise ValueError(
            f"Unsupported space type: {type(self.space).__name__}. "
            "Supported types are OrthogonalGrid, HexGrid, ContinuousSpace, VoronoiGrid, and Network."
        )

    def _map_coordinates(self, arguments):
        """Map agent coordinates to drawing coordinates for the current space.

        Args:
            arguments (dict): Agent data as collected by the backend. Its
                ``"loc"`` entry is indexed as a 2-D array: ``loc[:, 0]`` is x
                and ``loc[:, 1]`` is y.

        Returns:
            dict: A shallow copy of ``arguments``. For the supported space types
            its ``"loc"`` entry is a float array in drawing coordinates.
        """
        mapped_arguments = arguments.copy()
        if isinstance(
            self.space, OrthogonalGrid | VoronoiGrid | ContinuousSpace | Network
        ):
            # Orthogonal grids, Voronoi grids, continuous spaces and networks are
            # drawn in the agents' own coordinate system.
            mapped_arguments["loc"] = arguments["loc"].astype(float)
        elif isinstance(self.space, HexGrid):
            # Map grid coordinates to hexagon centers.
            loc = arguments["loc"].astype(float)
            if loc.size > 0:
                # x is computed first because the row offset uses the original
                # (unscaled) row index. Rows where (y - 1) % 2 == 1, i.e. even
                # rows, are shifted right by half a hexagon width.
                loc[:, 0] = loc[:, 0] * self.space_drawer.x_spacing + (
                    (loc[:, 1] - 1) % 2
                ) * (self.space_drawer.x_spacing / 2)
                loc[:, 1] = loc[:, 1] * self.space_drawer.y_spacing
            mapped_arguments["loc"] = loc
        return mapped_arguments

    def setup_structure(self, **kwargs) -> SpaceRenderer:
        """Set up the space structure without drawing.

        Args:
            **kwargs: Additional keyword arguments for the setup function. For ContinuousSpace,
                you may pass ``viz_dims=(i, j)`` to select which two dimensions are projected
                to x/y. See the respective ``SpaceDrawer`` class for details on how to
                pass **kwargs.

        Returns:
            SpaceRenderer: The current instance for method chaining.
        """
        self.draw_space_kwargs = kwargs
        # Invalidate the cached drawing so render() redraws with the new kwargs.
        self.space_mesh = None
        return self

    def setup_agents(self, agent_portrayal: Callable, **kwargs) -> SpaceRenderer:
        """Set up agents on the space without drawing.

        Args:
            agent_portrayal (Callable): Function that takes an agent and returns AgentPortrayalStyle.
            **kwargs: Additional keyword arguments for the setup function.
                See the respective ``SpaceDrawer`` class for details on how to pass **kwargs.

        Returns:
            SpaceRenderer: The current instance for method chaining.
        """
        self.agent_portrayal = agent_portrayal
        self.draw_agent_kwargs = kwargs
        # Invalidate the cached drawing so render() redraws the agents.
        self.agent_mesh = None
        return self

    def setup_property_layer(
        self, property_layer_portrayal: Callable | dict | PropertyLayerStyle
    ) -> SpaceRenderer:
        """Set up property layers on the space without drawing.

        Args:
            property_layer_portrayal (Callable | dict | PropertyLayerStyle): A PropertyLayerStyle,
                a function that produces a PropertyLayerStyle instance, or a dictionary
                specifying portrayal parameters (deprecated).

        Returns:
            SpaceRenderer: The current instance for method chaining.
        """
        self.property_layer_portrayal = property_layer_portrayal
        # Invalidate the cached drawing so render() redraws the property layers.
        self.property_layer_mesh = None
        return self

    def draw_structure(self, **kwargs):
        """Draw the space structure.

        Args:
            **kwargs: (Deprecated) Additional keyword arguments for drawing.
                Use setup_structure() instead.

        Returns:
            The visual representation of the space structure.
        """
        if kwargs:
            warnings.warn(
                "Passing kwargs to draw_structure() is deprecated. "
                "Use setup_structure(**kwargs) before calling draw_structure().",
                PendingDeprecationWarning,
                stacklevel=2,
            )
            self.draw_space_kwargs.update(kwargs)

        # Network-specific: the space instance is replaced on every model reset.
        # If the drawer still references the old space its layout positions belong
        # to the previous (now stale) graph. Rebuild before drawing edges so that
        # the structure is always consistent with the current space.
        if (
            isinstance(self.space, Network)
            and self.space_drawer.space is not self.space
        ):
            self.space_drawer = self._get_space_drawer()
            self.backend_renderer.space_drawer = self.space_drawer

        self.space_mesh = self.backend_renderer.draw_structure(**self.draw_space_kwargs)
        return self.space_mesh

    def draw_agents(self, agent_portrayal=None, **kwargs):
        """Draw agents on the space.

        Args:
            agent_portrayal: (Deprecated) Function that takes an agent and returns
                AgentPortrayalStyle. Use setup_agents() instead.
            **kwargs: (Deprecated) Additional keyword arguments for drawing.
                Use setup_agents() instead.

        Returns:
            The visual representation of the agents.
        """
        if agent_portrayal is not None:
            warnings.warn(
                "Passing agent_portrayal to draw_agents() is deprecated and will be removed in Mesa 4.0. "
                "Use setup_agents(agent_portrayal, **kwargs) before calling draw_agents(). "
                "See https://mesa.readthedocs.io/latest/migration_guide.html#passing-portrayal-arguments-to-draw-methods",
                FutureWarning,
                stacklevel=2,
            )
            self.agent_portrayal = agent_portrayal

        if kwargs:
            warnings.warn(
                "Passing kwargs to draw_agents() is deprecated. "
                "Use setup_agents(**kwargs) before calling draw_agents().",
                PendingDeprecationWarning,
                stacklevel=2,
            )
            self.draw_agent_kwargs.update(kwargs)

        # Collect per-agent data from the backend, then convert the agent
        # locations into the drawing coordinates of the current space type.
        arguments = self.backend_renderer.collect_agent_data(
            self.space, self.agent_portrayal, default_size=self.space_drawer.s_default
        )
        arguments = self._map_coordinates(arguments)

        self.agent_mesh = self.backend_renderer.draw_agents(
            arguments, **self.draw_agent_kwargs
        )
        return self.agent_mesh

    def draw_property_layer(self, property_layer_portrayal=None):
        """Draw property layers on the space.

        Args:
            property_layer_portrayal: (Deprecated) A PropertyLayerStyle, a function that produces
                a PropertyLayerStyle instance, or a dictionary specifying portrayal parameters.
                Use setup_property_layer() instead.

        Returns:
            The visual representation of the property layers, as returned by the
            backend (the altair ``canvas`` unpacks it as a ``(chart, colorbar)`` pair).

        Raises:
            ValueError: If no property layers (other than ``"empty"``) are found on
                the space. This is a subclass of ``Exception``, so existing
                ``except Exception`` handlers keep working.
        """
        if property_layer_portrayal is not None:
            warnings.warn(
                "Passing property_layer_portrayal to draw_property_layer() is deprecated and will be removed in Mesa 4.0. "
                "Use setup_property_layer(property_layer_portrayal) before calling draw_property_layer(). "
                "See https://mesa.readthedocs.io/latest/migration_guide.html#passing-portrayal-arguments-to-draw-methods",
                FutureWarning,
                stacklevel=2,
            )
            self.property_layer_portrayal = property_layer_portrayal

        # Import here to avoid circular imports
        from mesa.visualization.components import PropertyLayerStyle  # noqa: PLC0415

        def _dict_to_callable(portrayal_dict):
            """Convert a legacy dict portrayal to a callable.

            Emits the deprecation warning once, at conversion time, rather than
            on every style lookup.

            Args:
                portrayal_dict (dict): Mapping from layer key to a dict of
                    portrayal parameters.

            Returns:
                Callable: Function returning a PropertyLayerStyle for a layer key,
                or None when the key has no entry in ``portrayal_dict``.
            """
            warnings.warn(
                (
                    "The property_layer_portrayal dict is deprecated and will be removed in Mesa 4.0. "
                    "Please use a callable that returns a PropertyLayerStyle instance instead. "
                    "For more information, refer to the migration guide: "
                    "https://mesa.readthedocs.io/latest/migration_guide.html#defining-portrayal-components"
                ),
                FutureWarning,
                # 1 = here, 2 = draw_property_layer, 3 = its caller.
                stacklevel=3,
            )

            def style_callable(layer_name):
                params = portrayal_dict.get(layer_name)
                if params is None:
                    return None
                return PropertyLayerStyle(
                    color=params.get("color"),
                    colormap=params.get("colormap"),
                    alpha=params.get("alpha", PropertyLayerStyle.alpha),
                    vmin=params.get("vmin"),
                    vmax=params.get("vmax"),
                    colorbar=params.get("colorbar", PropertyLayerStyle.colorbar),
                )

            return style_callable

        property_layers = self.space.property_layers

        # Fail fast, before the stored portrayal is converted/mutated.
        if not any(layer != "empty" for layer in property_layers):
            raise ValueError("No property layers were found on the space.")

        # Normalize the portrayal to a callable returning PropertyLayerStyle | None.
        if isinstance(self.property_layer_portrayal, dict):
            self.property_layer_portrayal = _dict_to_callable(
                self.property_layer_portrayal
            )
        elif isinstance(self.property_layer_portrayal, PropertyLayerStyle):
            # Capture the style instance to avoid circular reference
            style = self.property_layer_portrayal
            self.property_layer_portrayal = lambda _: style
        # else: already a callable, use as-is

        self.property_layer_mesh = self.backend_renderer.draw_property_layer(
            self.space, property_layers, self.property_layer_portrayal
        )
        return self.property_layer_mesh

    def render(self, agent_portrayal=None, property_layer_portrayal=None, **kwargs):
        """Render the complete space with structure, agents, and property layers.

        Only components that have not been drawn yet are drawn. Agents and
        property layers are drawn only when a portrayal has been configured.

        Args:
            agent_portrayal: (Deprecated) Function for agent portrayal. Use setup_agents() instead.
            property_layer_portrayal: (Deprecated) Function for property layer portrayal.
                Use setup_property_layer() instead.
            **kwargs: (Deprecated) Additional keyword arguments. ``space_kwargs`` and
                ``agent_kwargs`` dicts are merged into the structure and agent drawing
                kwargs respectively; any other key is merged into the structure kwargs.

        Returns:
            SpaceRenderer: The current instance for method chaining.
        """
        if (
            agent_portrayal is not None
            or property_layer_portrayal is not None
            or kwargs
        ):
            warnings.warn(
                "Passing parameters to render() is deprecated. "
                "Use setup_structure(), setup_agents(), and setup_property_layer() before calling render().",
                PendingDeprecationWarning,
                stacklevel=2,
            )

        if agent_portrayal is not None:
            self.agent_portrayal = agent_portrayal
        if property_layer_portrayal is not None:
            self.property_layer_portrayal = property_layer_portrayal

        # Route the nested legacy kwargs dicts to their targets instead of
        # merging them wholesale into the structure kwargs.
        deprecated_kwargs_map = {
            "space_kwargs": self.draw_space_kwargs,
            "agent_kwargs": self.draw_agent_kwargs,
        }
        for key, target_dict in deprecated_kwargs_map.items():
            if key in kwargs:
                value = kwargs.pop(key)
                if isinstance(value, dict):
                    target_dict.update(value)

        # Update with any remaining kwargs (now that the dangerous ones are removed)
        self.draw_space_kwargs.update(kwargs)

        if self.space_mesh is None:
            self.draw_structure()
        if self.agent_mesh is None and self.agent_portrayal is not None:
            self.draw_agents()
        if (
            self.property_layer_mesh is None
            and self.property_layer_portrayal is not None
        ):
            self.draw_property_layer()
        return self

    @property
    def canvas(self):
        """Get the current canvas object.

        For the matplotlib backend this is the backend's ``ax`` object; it is
        initialized first if it does not exist yet. For the altair backend every
        component that has been drawn before is redrawn from the current model
        state and the pieces are combined into one chart.

        Returns:
            The backend-specific canvas object.
        """
        if self.backend == "matplotlib":
            if self.backend_renderer.ax is None:
                self.backend_renderer.initialize_canvas()
            # Re-read after initialization: a value captured before
            # initialize_canvas() would still be None.
            return self.backend_renderer.ax
        elif self.backend == "altair":
            # Redraw only what has been drawn before, in the original order.
            structure = self.draw_structure() if self.space_mesh else None
            agents = self.draw_agents() if self.agent_mesh else None
            if self.property_layer_mesh:
                prop_base, prop_cbar = self.draw_property_layer()
            else:
                prop_base, prop_cbar = None, None

            spatial_charts = [
                chart for chart in (structure, prop_base, agents) if chart
            ]
            if not spatial_charts:
                main_spatial = None
            elif len(spatial_charts) == 1:
                main_spatial = spatial_charts[0]
            else:
                main_spatial = alt.layer(*spatial_charts)

            # Combine the spatial chart with the color bar if present.
            if main_spatial and prop_cbar:
                final_chart = alt.vconcat(main_spatial, prop_cbar)
            elif main_spatial:
                final_chart = main_spatial
            elif prop_cbar:
                final_chart = prop_cbar
            else:
                # If no charts are available, return an empty chart
                final_chart = (
                    alt.Chart(pd.DataFrame())
                    .mark_point()
                    .properties(width=450, height=350)
                )

            # configure_view() replaces any previous view configuration, so this
            # single call fully determines the view style of the returned chart.
            return final_chart.configure_view(stroke="black", strokeWidth=1.5)

    @property
    def post_process(self):
        """Get the current post-processing function.

        Returns:
            Callable | None: The post-processing function, or None if not set.
        """
        return self.post_process_func

    @post_process.setter
    def post_process(self, func: Callable | None):
        """Set the post-processing function.

        Args:
            func (Callable | None): Function to apply post-processing to the canvas.
                Should accept the canvas object as its first argument.
        """
        self.post_process_func = func