# Source code for mesa.visualization.space_renderer

"""Space rendering module for Mesa visualizations.

This module provides functionality to render Mesa model spaces with different
backends, supporting various space types and visualization components.
"""

from __future__ import annotations

import warnings
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal

if TYPE_CHECKING:
    from mesa.visualization.components import PropertyLayerStyle

import altair as alt
import numpy as np
import pandas as pd

import mesa
from mesa.discrete_space import (
    OrthogonalMooreGrid,
    OrthogonalVonNeumannGrid,
    VoronoiGrid,
)
from mesa.space import (
    ContinuousSpace,
    HexMultiGrid,
    HexSingleGrid,
    MultiGrid,
    NetworkGrid,
    SingleGrid,
    _HexGrid,
)
from mesa.visualization.backends import AltairBackend, MatplotlibBackend
from mesa.visualization.space_drawers import (
    ContinuousSpaceDrawer,
    HexSpaceDrawer,
    NetworkSpaceDrawer,
    OrthogonalSpaceDrawer,
    VoronoiSpaceDrawer,
)

# Union aliases used for isinstance() dispatch in SpaceRenderer. Each alias
# groups the classes imported from mesa.space with their counterparts from
# mesa.discrete_space, so one check covers every class of that family.

# Square-cell grids; SpaceRenderer pairs these with OrthogonalSpaceDrawer.
OrthogonalGrid = SingleGrid | MultiGrid | OrthogonalMooreGrid | OrthogonalVonNeumannGrid
# Hexagonal grids; SpaceRenderer pairs these with HexSpaceDrawer.
HexGrid = HexSingleGrid | HexMultiGrid | mesa.discrete_space.HexGrid
# Network (graph) spaces; SpaceRenderer pairs these with NetworkSpaceDrawer.
Network = NetworkGrid | mesa.discrete_space.Network


class SpaceRenderer:
    """Renders Mesa spaces using different visualization backends.

    Supports multiple space types and backends for flexible visualization
    of agent-based models.

    The renderer caches the result of each draw call in ``space_mesh``,
    ``agent_mesh`` and ``propertylayer_mesh``. The ``setup_*`` methods store
    the drawing configuration and reset the matching cache entry so that the
    next ``render()`` call redraws that layer.
    """

    def __init__(
        self,
        model: mesa.Model,
        backend: Literal["matplotlib", "altair"] | None = "matplotlib",
    ):
        """Initialize the space renderer.

        Args:
            model (mesa.Model): The Mesa model to render. Its space is read
                from ``model.grid`` or, failing that, ``model.space``.
            backend (Literal["matplotlib", "altair"] | None): The visualization
                backend to use.

        Raises:
            ValueError: If the model's space type or the backend is not
                supported.
        """
        self.space = getattr(model, "grid", getattr(model, "space", None))
        self.space_drawer = self._get_space_drawer()

        # Cached results of the most recent draw_* calls.
        self.space_mesh = None
        self.agent_mesh = None
        self.propertylayer_mesh = None

        # Drawing configuration, normally populated through setup_*().
        self.draw_agent_kwargs = {}
        self.draw_space_kwargs = {}
        self.agent_portrayal = None
        self.propertylayer_portrayal = None

        self.post_process_func = None
        # Keep track of whether post-processing has been applied
        # to avoid multiple applications on the same axis.
        self._post_process_applied = False

        self.backend = backend

        if backend == "matplotlib":
            self.backend_renderer = MatplotlibBackend(
                self.space_drawer,
            )
        elif backend == "altair":
            self.backend_renderer = AltairBackend(
                self.space_drawer,
            )
        else:
            raise ValueError(f"Unsupported backend: {backend}")

        self.backend_renderer.initialize_canvas()

    def _get_space_drawer(self):
        """Get appropriate space drawer based on space type.

        Returns:
            Space drawer instance for the model's space type.

        Raises:
            ValueError: If the space type is not supported.
        """
        if isinstance(self.space, HexGrid | _HexGrid):
            return HexSpaceDrawer(self.space)
        elif isinstance(self.space, OrthogonalGrid):
            return OrthogonalSpaceDrawer(self.space)
        elif isinstance(
            self.space,
            ContinuousSpace | mesa.experimental.continuous_space.ContinuousSpace,
        ):
            return ContinuousSpaceDrawer(self.space)
        elif isinstance(self.space, VoronoiGrid):
            return VoronoiSpaceDrawer(self.space)
        elif isinstance(self.space, Network):
            return NetworkSpaceDrawer(self.space)

        raise ValueError(
            f"Unsupported space type: {type(self.space).__name__}. "
            "Supported types are OrthogonalGrid, HexGrid, ContinuousSpace, VoronoiGrid, and Network."
        )

    def _map_coordinates(self, arguments):
        """Map agent coordinates to appropriate space coordinates.

        Args:
            arguments (dict): Dictionary containing agent data with coordinates.
                Only the ``"loc"`` entry (a NumPy array) is read.

        Returns:
            dict: Shallow copy of ``arguments`` whose ``"loc"`` entry holds the
            coordinates mapped for the space type. For space types not handled
            below, the copy is returned unchanged.
        """
        mapped_arguments = arguments.copy()

        if isinstance(self.space, OrthogonalGrid | VoronoiGrid | ContinuousSpace):
            # Use the coordinates directly for Orthogonal grids, Voronoi grids
            # and Continuous spaces
            mapped_arguments["loc"] = arguments["loc"].astype(float)

        elif isinstance(self.space, HexGrid):
            # Map rectangular coordinates to hexagonal grid coordinates
            loc = arguments["loc"].astype(float)
            if loc.size > 0:
                # Calculate hexagon centers. The x update reads the still
                # unscaled row index in loc[:, 1], so it must run before the
                # y update.
                loc[:, 0] = loc[:, 0] * self.space_drawer.x_spacing + (
                    (loc[:, 1] - 1) % 2
                ) * (self.space_drawer.x_spacing / 2)
                loc[:, 1] = loc[:, 1] * self.space_drawer.y_spacing
            mapped_arguments["loc"] = loc

        elif isinstance(self.space, Network):
            # Map network node IDs to positions using vectorized dictionary lookup
            loc = arguments["loc"].astype(float)

            if loc.size == 0:
                # No agents: an empty array may be 1-D, in which case the
                # column indexing below would raise IndexError.
                mapped_arguments["loc"] = np.empty((0, 2))
                return mapped_arguments

            pos_dict = self.space_drawer.pos
            node_ids = loc[:, 0].astype(int)

            # Process unique node IDs once, then broadcast to all agents
            unique_ids, inverse_indices = np.unique(node_ids, return_inverse=True)
            # Nodes without a known position keep NaN coordinates.
            unique_positions = np.full((len(unique_ids), 2), np.nan)
            missing_nodes = []
            for idx, node_id in enumerate(unique_ids):
                if node_id in pos_dict:
                    unique_positions[idx] = pos_dict[node_id]
                else:
                    missing_nodes.append(node_id)

            mapped_locs = unique_positions[inverse_indices]

            # Warn if significant nodes missing (likely layout issue, not race condition)
            if missing_nodes and len(missing_nodes) > len(pos_dict) / 10:
                sample = missing_nodes[: min(5, len(missing_nodes))]
                warnings.warn(
                    f"Many nodes {sample}{'...' if len(missing_nodes) > 5 else ''} not found "
                    f"in position layout ({len(missing_nodes)}/{len(node_ids)} agents). "
                    f"This may indicate the network layout needs to be updated or regenerated.",
                    UserWarning,
                    stacklevel=2,
                )

            mapped_arguments["loc"] = mapped_locs

        return mapped_arguments

    def setup_structure(self, **kwargs) -> SpaceRenderer:
        """Setup the space structure without drawing.

        Args:
            **kwargs: Additional keyword arguments for the setup function.
                Checkout respective `SpaceDrawer` class on details how to pass **kwargs.

        Returns:
            SpaceRenderer: The current instance for method chaining.
        """
        self.draw_space_kwargs = kwargs
        # Drop the cached structure so the next render() redraws it.
        self.space_mesh = None
        return self

    def setup_agents(self, agent_portrayal: Callable, **kwargs) -> SpaceRenderer:
        """Setup agents on the space without drawing.

        Args:
            agent_portrayal (Callable): Function that takes an agent and returns AgentPortrayalStyle.
            **kwargs: Additional keyword arguments for the setup function.
                Checkout respective `SpaceDrawer` class on details how to pass **kwargs.

        Returns:
            SpaceRenderer: The current instance for method chaining.
        """
        self.agent_portrayal = agent_portrayal
        self.draw_agent_kwargs = kwargs
        # Drop the cached agents so the next render() redraws them.
        self.agent_mesh = None
        return self

    def setup_propertylayer(
        self, propertylayer_portrayal: Callable | dict | PropertyLayerStyle
    ) -> SpaceRenderer:
        """Setup property layers on the space without drawing.

        Args:
            propertylayer_portrayal (Callable | dict | PropertyLayerStyle): A PropertyLayerStyle,
                a function that produces a PropertyLayerStyle instance, or a dictionary
                specifying portrayal parameters.

        Returns:
            SpaceRenderer: The current instance for method chaining.
        """
        self.propertylayer_portrayal = propertylayer_portrayal
        # Drop the cached layers so the next render() redraws them.
        self.propertylayer_mesh = None
        return self

    def draw_structure(self, **kwargs):
        """Draw the space structure.

        Args:
            **kwargs: (Deprecated) Additional keyword arguments for drawing.
                Use setup_structure() instead.

        Returns:
            The visual representation of the space structure.
        """
        if kwargs:
            warnings.warn(
                "Passing kwargs to draw_structure() is deprecated. "
                "Use setup_structure(**kwargs) before calling draw_structure().",
                PendingDeprecationWarning,
                stacklevel=2,
            )
            self.draw_space_kwargs.update(kwargs)

        self.space_mesh = self.backend_renderer.draw_structure(**self.draw_space_kwargs)
        return self.space_mesh

    def draw_agents(self, agent_portrayal=None, **kwargs):
        """Draw agents on the space.

        Args:
            agent_portrayal: (Deprecated) Function that takes an agent and returns AgentPortrayalStyle.
                Use setup_agents() instead.
            **kwargs: (Deprecated) Additional keyword arguments for drawing.

        Returns:
            The visual representation of the agents.
        """
        if agent_portrayal is not None:
            warnings.warn(
                "Passing agent_portrayal to draw_agents() is deprecated and will be removed in Mesa 4.0. "
                "Use setup_agents(agent_portrayal, **kwargs) before calling draw_agents(). "
                "See https://mesa.readthedocs.io/latest/migration_guide.html#passing-portrayal-arguments-to-draw-methods",
                FutureWarning,
                stacklevel=2,
            )
            self.agent_portrayal = agent_portrayal
        if kwargs:
            warnings.warn(
                "Passing kwargs to draw_agents() is deprecated. "
                "Use setup_agents(**kwargs) before calling draw_agents().",
                PendingDeprecationWarning,
                stacklevel=2,
            )
            self.draw_agent_kwargs.update(kwargs)

        # Prepare data for agent plotting
        arguments = self.backend_renderer.collect_agent_data(
            self.space, self.agent_portrayal, default_size=self.space_drawer.s_default
        )
        arguments = self._map_coordinates(arguments)

        self.agent_mesh = self.backend_renderer.draw_agents(
            arguments, **self.draw_agent_kwargs
        )
        return self.agent_mesh

    def draw_propertylayer(self, propertylayer_portrayal=None):
        """Draw property layers on the space.

        Args:
            propertylayer_portrayal: (Deprecated) A PropertyLayerStyle, a function that produces
                a PropertyLayerStyle instance, or a dictionary specifying portrayal parameters.
                Use setup_propertylayer() instead.

        Returns:
            The visual representation of the property layers.

        Raises:
            Exception: If no property layers are found on the space.
        """
        if propertylayer_portrayal is not None:
            warnings.warn(
                "Passing propertylayer_portrayal to draw_propertylayer() is deprecated and will be removed in Mesa 4.0. "
                "Use setup_propertylayer(propertylayer_portrayal) before calling draw_propertylayer(). "
                "See https://mesa.readthedocs.io/latest/migration_guide.html#passing-portrayal-arguments-to-draw-methods",
                FutureWarning,
                stacklevel=2,
            )
            self.propertylayer_portrayal = propertylayer_portrayal

        # Import here to avoid circular imports
        from mesa.visualization.components import PropertyLayerStyle  # noqa: PLC0415

        def _dict_to_callable(portrayal_dict):
            """Convert legacy dict portrayal to callable.

            Args:
                portrayal_dict (dict): Dictionary with portrayal parameters.

            Returns:
                Callable: Function that returns PropertyLayerStyle.
            """

            def style_callable(layer_object):
                layer_name = layer_object.name
                params = portrayal_dict.get(layer_name)

                warnings.warn(
                    (
                        "The propertylayer_portrayal dict is deprecated and will be removed in Mesa 4.0. "
                        "Please use a callable that returns a PropertyLayerStyle instance instead. "
                        "For more information, refer to the migration guide: "
                        "https://mesa.readthedocs.io/latest/migration_guide.html#defining-portrayal-components"
                    ),
                    FutureWarning,
                    stacklevel=2,
                )

                # Layers without an entry in the dict get no style.
                if params is None:
                    return None

                return PropertyLayerStyle(
                    color=params.get("color"),
                    colormap=params.get("colormap"),
                    alpha=params.get("alpha", PropertyLayerStyle.alpha),
                    vmin=params.get("vmin"),
                    vmax=params.get("vmax"),
                    colorbar=params.get("colorbar", PropertyLayerStyle.colorbar),
                )

            return style_callable

        # Get property layers
        try:
            # old style spaces
            property_layers = self.space.properties
        except AttributeError:
            # new style spaces
            property_layers = self.space._mesa_property_layers

        # Convert portrayal to callable if needed
        if isinstance(self.propertylayer_portrayal, dict):
            self.propertylayer_portrayal = _dict_to_callable(
                self.propertylayer_portrayal
            )
        elif isinstance(self.propertylayer_portrayal, PropertyLayerStyle):
            # Capture the style instance to avoid circular reference
            style = self.propertylayer_portrayal
            self.propertylayer_portrayal = lambda _: style
        # else: already a callable, use as-is

        # The "empty" entry is not counted as a drawable layer.
        number_of_propertylayers = sum(
            1 for layer in property_layers if layer != "empty"
        )
        if number_of_propertylayers < 1:
            raise Exception("No property layers were found on the space.")

        self.propertylayer_mesh = self.backend_renderer.draw_propertylayer(
            self.space, property_layers, self.propertylayer_portrayal
        )
        return self.propertylayer_mesh

    def render(self, agent_portrayal=None, propertylayer_portrayal=None, **kwargs):
        """Render the complete space with structure, agents, and property layers.

        Only layers whose cached mesh is ``None`` are drawn; agents and
        property layers are additionally skipped when no portrayal is set.

        Args:
            agent_portrayal: (Deprecated) Function for agent portrayal. Use setup_agents() instead.
            propertylayer_portrayal: (Deprecated) Function for property layer portrayal.
                Use setup_propertylayer() instead.
            **kwargs: (Deprecated) Additional keyword arguments.

        Returns:
            SpaceRenderer: The current instance for method chaining.
        """
        if agent_portrayal is not None or propertylayer_portrayal is not None or kwargs:
            warnings.warn(
                "Passing parameters to render() is deprecated. "
                "Use setup_structure(), setup_agents(), and setup_propertylayer() before calling render().",
                PendingDeprecationWarning,
                stacklevel=2,
            )
            if agent_portrayal is not None:
                self.agent_portrayal = agent_portrayal
            if propertylayer_portrayal is not None:
                self.propertylayer_portrayal = propertylayer_portrayal

            # "space_kwargs" / "agent_kwargs" are routed to their own
            # dictionaries instead of being forwarded as drawing options.
            deprecated_kwargs_map = {
                "space_kwargs": self.draw_space_kwargs,
                "agent_kwargs": self.draw_agent_kwargs,
            }
            for key, target_dict in deprecated_kwargs_map.items():
                if key in kwargs:
                    value = kwargs.pop(key)
                    if isinstance(value, dict):
                        target_dict.update(value)

            # Update with any remaining kwargs (now that the dangerous ones are removed)
            self.draw_space_kwargs.update(kwargs)

        if self.space_mesh is None:
            self.draw_structure()
        if self.agent_mesh is None and self.agent_portrayal is not None:
            self.draw_agents()
        if self.propertylayer_mesh is None and self.propertylayer_portrayal is not None:
            self.draw_propertylayer()

        return self

    @property
    def canvas(self):
        """Get the current canvas object.

        For the matplotlib backend the canvas is created on demand when the
        backend has no axes yet. For the altair backend every layer that has
        already been drawn is redrawn and the results are combined into a
        single chart.

        Returns:
            The backend-specific canvas object.
        """
        if self.backend == "matplotlib":
            if self.backend_renderer.ax is None:
                self.backend_renderer.initialize_canvas()
            # Re-read the attribute: a reference taken before
            # initialize_canvas() would still be None.
            return self.backend_renderer.ax

        elif self.backend == "altair":
            # Only layers that were drawn before are redrawn.
            structure = self.draw_structure() if self.space_mesh else None
            agents = self.draw_agents() if self.agent_mesh else None
            if self.propertylayer_mesh:
                prop_base, prop_cbar = self.draw_propertylayer()
            else:
                prop_base, prop_cbar = None, None

            spatial_charts_list = [
                chart for chart in [structure, prop_base, agents] if chart
            ]

            main_spatial = None
            if spatial_charts_list:
                main_spatial = (
                    spatial_charts_list[0]
                    if len(spatial_charts_list) == 1
                    else alt.layer(*spatial_charts_list)
                )

            # Determine final chart by combining with color bar if present
            final_chart = None
            if main_spatial and prop_cbar:
                final_chart = alt.vconcat(main_spatial, prop_cbar).configure_view(
                    stroke=None
                )
            elif main_spatial:  # Only main_spatial, no prop_cbar
                final_chart = main_spatial
            elif prop_cbar:  # Only prop_cbar, no main_spatial
                final_chart = prop_cbar
                final_chart = final_chart.configure_view(grid=False)

            if final_chart is None:
                # If no charts are available, return an empty chart
                final_chart = (
                    alt.Chart(pd.DataFrame())
                    .mark_point()
                    .properties(width=450, height=350)
                )

            final_chart = final_chart.configure_view(stroke="black", strokeWidth=1.5)
            return final_chart

    @property
    def post_process(self):
        """Get the current post-processing function.

        Returns:
            Callable | None: The post-processing function, or None if not set.
        """
        return self.post_process_func

    @post_process.setter
    def post_process(self, func: Callable | None):
        """Set the post-processing function.

        Args:
            func (Callable | None): Function to apply post-processing to the canvas.
                Should accept the canvas object as its first argument.
        """
        self.post_process_func = func