# Source code for mesa.visualization.backends.altair_backend

# noqa: D100
import warnings
from collections.abc import Callable
from dataclasses import fields

import altair as alt
import numpy as np
import pandas as pd
from matplotlib.colors import to_rgb

import mesa
from mesa.discrete_space import (
    OrthogonalMooreGrid,
    OrthogonalVonNeumannGrid,
)
from mesa.visualization.backends.abstract_renderer import AbstractRenderer

# Union alias covering both orthogonal grid classes imported above
# (Moore and Von Neumann variants).
OrthogonalGrid = OrthogonalMooreGrid | OrthogonalVonNeumannGrid
# Short module-level aliases for space classes from mesa.discrete_space.
HexGrid = mesa.discrete_space.HexGrid
Network = mesa.discrete_space.Network


class AltairBackend(AbstractRenderer):
    """Altair-based renderer for Mesa spaces.

    Renders Mesa model spaces, agents, and property layers as Altair charts
    with interactive charting capabilities.
    """

    def initialize_canvas(self) -> None:
        """Initialize the Altair canvas."""
        self._canvas = None

    def draw_structure(self, **kwargs) -> alt.Chart:
        """Draw the space structure using Altair.

        Args:
            **kwargs: Additional arguments passed to the space drawer.
                See the respective `SpaceDrawer` class for details on how to
                pass **kwargs.

        Returns:
            alt.Chart: The Altair chart representing the space structure.
        """
        return self.space_drawer.draw_altair(**kwargs)

    def collect_agent_data(
        self, space, agent_portrayal: Callable, default_size: float | None = None
    ) -> dict[str, np.ndarray]:
        """Collect plotting data for all agents in the space for Altair.

        The portrayal object returned by ``agent_portrayal`` is never modified.
        Coordinates it leaves as ``None`` are filled in from the agent's
        position in ``space``.

        Args:
            space: The Mesa space containing agents.
            agent_portrayal: Callable that returns AgentPortrayalStyle for each
                agent. Returning a dict is deprecated but still supported.
            default_size: Default marker size if not specified in portrayal.

        Returns:
            dict: Dictionary containing agent plotting data arrays, keyed by
            "loc", "size", "color", "shape", "order", "opacity", "stroke",
            "strokeWidth", "filled" and "tooltip".
        """
        # Import here to avoid circular import issues
        from mesa.visualization.components import AgentPortrayalStyle  # noqa: PLC0415

        style_fields = {f.name: f.default for f in fields(AgentPortrayalStyle)}
        class_default_size = style_fields.get("size")

        def with_default(value, field_name):
            """Return value, or the AgentPortrayalStyle default when it is None."""
            return value if value is not None else style_fields.get(field_name)

        # Portrayal keys that the deprecated dict form may carry.
        dict_style_keys = (
            "size",
            "color",
            "marker",
            "zorder",
            "alpha",
            "edgecolors",
            "linewidths",
        )

        # Marker mapping from Matplotlib to Altair
        marker_to_shape_map = {
            "o": "circle",
            "s": "square",
            "D": "diamond",
            "^": "triangle-up",
            "v": "triangle-down",
            "<": "triangle-left",
            ">": "triangle-right",
            "+": "cross",
            "x": "cross",  # Both '+' and 'x' map to cross in Altair
            ".": "circle",  # Small point becomes circle
            "1": "triangle-down",
            "2": "triangle-up",
            "3": "triangle-left",
            "4": "triangle-right",
        }

        # Fields that are returned as float arrays.
        float_fields = {"size", "order", "opacity", "strokeWidth"}

        # Initialize data collection arrays
        arguments = {
            "loc": [],
            "size": [],
            "color": [],
            "shape": [],
            "order": [],  # z-order
            "opacity": [],
            "stroke": [],  # Stroke color
            "strokeWidth": [],
            "filled": [],
            "tooltip": [],
        }

        for agent in space.agents:
            portray_input = agent_portrayal(agent)

            if isinstance(portray_input, dict):
                warnings.warn(
                    (
                        "Returning a dict from agent_portrayal is deprecated and will be removed in Mesa 4.0. "
                        "Please return an AgentPortrayalStyle instance instead. \n"
                        "For more information, refer to the migration guide: "
                        "https://mesa.readthedocs.io/latest/migration_guide.html#defining-portrayal-components"
                    ),
                    FutureWarning,
                    stacklevel=2,
                )
                # Work on a copy so the caller's dict is left untouched.
                dict_data = portray_input.copy()
                agent_x, agent_y = self._get_agent_pos(agent, space)
                style_kwargs = {
                    name: dict_data.pop(name, style_fields.get(name))
                    for name in dict_style_keys
                }
                aps = AgentPortrayalStyle(x=agent_x, y=agent_y, **style_kwargs)
                if dict_data:
                    ignored_keys = list(dict_data.keys())
                    warnings.warn(
                        f"The following keys were ignored from dict portrayal: {', '.join(ignored_keys)}",
                        UserWarning,
                        stacklevel=2,
                    )
            else:
                aps = portray_input

            # Resolve the position into locals instead of writing it back onto
            # the portrayal: the object belongs to the caller, and a style
            # instance reused across agents would otherwise pin every later
            # agent to the first agent's position.
            x, y = aps.x, aps.y
            if x is None or y is None:
                agent_x, agent_y = self._get_agent_pos(agent, space)
                x = agent_x if x is None else x
                y = agent_y if y is None else y
            arguments["loc"].append((x, y))

            # Size precedence: portrayal value, then default_size, then the
            # AgentPortrayalStyle default.
            size_to_collect = aps.size if aps.size is not None else default_size
            if size_to_collect is None:
                size_to_collect = class_default_size
            arguments["size"].append(size_to_collect)

            arguments["color"].append(with_default(aps.color, "color"))
            arguments["tooltip"].append(aps.tooltip)

            # Map marker to Altair shape if defined, else use raw marker.
            # Unmapped markers are passed through unchanged.
            raw_marker = with_default(aps.marker, "marker")
            shape_value = marker_to_shape_map.get(raw_marker, raw_marker)
            if shape_value is None:
                warnings.warn(
                    f"Marker '{raw_marker}' is not supported in Altair. \n"
                    "Using 'circle' as default.",
                    UserWarning,
                    stacklevel=2,
                )
                shape_value = "circle"
            arguments["shape"].append(shape_value)

            arguments["order"].append(with_default(aps.zorder, "zorder"))
            arguments["opacity"].append(with_default(aps.alpha, "alpha"))
            arguments["stroke"].append(aps.edgecolors)
            arguments["strokeWidth"].append(with_default(aps.linewidths, "linewidths"))

            # FIXME: Make filled user-controllable
            arguments["filled"].append(True)

        final_data = {}
        for key, values in arguments.items():
            if key == "shape":
                # Ensure shape is an object array
                shapes = np.empty(len(values), dtype=object)
                shapes[:] = values
                final_data[key] = shapes
            elif key in float_fields:
                final_data[key] = np.asarray(values, dtype=float)
            else:
                final_data[key] = np.asarray(values)

        return final_data

    def draw_agents(
        self, arguments, chart_width: int = 450, chart_height: int = 350, **kwargs
    ):
        """Draw agents using Altair backend.

        Args:
            arguments: Dictionary containing agent data arrays, as produced by
                `collect_agent_data`.
            chart_width: Width of the chart.
            chart_height: Height of the chart.
            **kwargs: Additional keyword arguments for customization. The keys
                read here are "title", "xlabel", "ylabel", "cmap", "vmin" and
                "vmax".

        Returns:
            alt.Chart: The Altair chart representing the agents, or None if no agents.
        """
        if arguments["loc"].size == 0:
            return None

        # The strokeWidth scale below uses the domain [0, 1], so the widths
        # are scaled down by a factor of 10 beforehand.
        stroke_width = np.asarray(arguments["strokeWidth"], dtype=float) / 10

        # Agent data preparation
        df = pd.DataFrame(
            {
                "x": arguments["loc"][:, 0],
                "y": arguments["loc"][:, 1],
                "size": arguments["size"],
                "shape": arguments["shape"],
                "opacity": arguments["opacity"],
                "strokeWidth": stroke_width,
                "original_color": arguments["color"],
                "is_filled": arguments["filled"],
                "original_stroke": arguments["stroke"],
            }
        )

        # To ensure distinct shapes according to agent portrayal
        unique_shape_names_in_data = df["shape"].unique().tolist()

        # Filled markers use the agent color as fill and the edge color (only
        # when it is a string) as stroke; unfilled markers are drawn with the
        # agent color as stroke and no fill.
        fill_colors = []
        stroke_colors = []
        for filled, main_color, stroke in zip(
            df["is_filled"], df["original_color"], df["original_stroke"], strict=True
        ):
            if filled:
                fill_colors.append(main_color)
                stroke_colors.append(stroke if isinstance(stroke, str) else None)
            else:
                fill_colors.append(None)
                stroke_colors.append(main_color)
        df["viz_fill_color"] = fill_colors
        df["viz_stroke_color"] = stroke_colors

        # Extract additional parameters from kwargs
        # FIXME: Add more parameters to kwargs
        title = kwargs.pop("title", "")
        xlabel = kwargs.pop("xlabel", "")
        ylabel = kwargs.pop("ylabel", "")

        # Tooltip list for interactivity: one column per key found in any
        # agent's tooltip, holding None for agents that lack that key.
        tooltips = arguments["tooltip"]
        all_tooltips_key = set()
        for tooltip in tooltips:
            if tooltip:
                all_tooltips_key.update(tooltip.keys())

        tooltip_list = []
        for key in sorted(all_tooltips_key):
            values = [tooltip.get(key) if tooltip else None for tooltip in tooltips]
            # A tooltip key such as "x" or "size" must not overwrite the
            # plotting column of the same name, so colliding keys are stored
            # under a prefixed column and shown under their original name.
            column = key
            while column in df.columns:
                column = f"tooltip_{column}"
            df[column] = values
            if column == key:
                tooltip_list.append(column)
            else:
                tooltip_list.append(alt.Tooltip(column, title=str(key)))

        # Handle custom colormapping
        cmap = kwargs.pop("cmap", "viridis")
        vmin = kwargs.pop("vmin", None)
        vmax = kwargs.pop("vmax", None)

        if pd.api.types.is_numeric_dtype(df["original_color"]):
            # Numeric colors are mapped through the color scheme; the domain
            # falls back to the data range when vmin/vmax are not given.
            color_min = vmin if vmin is not None else df["original_color"].min()
            color_max = vmax if vmax is not None else df["original_color"].max()
            fill_encoding = alt.Fill(
                "original_color:Q",
                scale=alt.Scale(scheme=cmap, domain=[color_min, color_max]),
            )
        else:
            # Non-numeric colors are used verbatim (scale=None).
            fill_encoding = alt.Fill(
                "viz_fill_color:N",
                scale=None,
                title="Color",
            )

        # Determine space dimensions
        xmin, xmax, ymin, ymax = self.space_drawer.get_viz_limits()

        chart = (
            alt.Chart(df)
            .mark_point()
            .encode(
                x=alt.X(
                    "x:Q",
                    title=xlabel,
                    scale=alt.Scale(type="linear", domain=[xmin, xmax]),
                    axis=None,
                ),
                y=alt.Y(
                    "y:Q",
                    title=ylabel,
                    scale=alt.Scale(type="linear", domain=[ymin, ymax]),
                    axis=None,
                ),
                size=alt.Size("size:Q", legend=None, scale=alt.Scale(domain=[0, 50])),
                shape=alt.Shape(
                    "shape:N",
                    scale=alt.Scale(
                        domain=unique_shape_names_in_data,
                        range=unique_shape_names_in_data,
                    ),
                    title="Shape",
                ),
                opacity=alt.Opacity(
                    "opacity:Q",
                    title="Opacity",
                    scale=alt.Scale(domain=[0, 1], range=[0, 1]),
                ),
                fill=fill_encoding,
                stroke=alt.Stroke("viz_stroke_color:N", scale=None),
                strokeWidth=alt.StrokeWidth(
                    "strokeWidth:Q", scale=alt.Scale(domain=[0, 1])
                ),
                tooltip=tooltip_list,
            )
            .properties(title=title, width=chart_width, height=chart_height)
        )

        return chart

    def draw_property_layer(
        self,
        space,
        property_layers: dict[str, np.ndarray],
        property_layer_portrayal: Callable,
        chart_width: int = 450,
        chart_height: int = 350,
    ):
        """Draw property layers using Altair backend.

        Layers named "empty", layers whose portrayal is None, and layers whose
        shape does not match the space (a warning is issued) are skipped.

        Args:
            space: The Mesa space object containing the property layers.
            property_layers: A dictionary mapping property_layer names to numpy arrays.
            property_layer_portrayal: A function that returns PropertyLayerStyle
                that contains the visualization parameters.
            chart_width: The width of the chart.
            chart_height: The height of the chart.

        Returns:
            A single layered Altair chart holding one rect chart per drawn
            property layer, with independent color scales.

        Raises:
            ValueError: If a layer's portrayal defines neither `color` nor
                `colormap`.
        """
        main_charts = []

        for layer_name, layer in property_layers.items():
            if layer_name == "empty":
                continue

            portrayal = property_layer_portrayal(layer_name)
            if portrayal is None:
                continue

            # Boolean layers are cast to float so they can be drawn on a
            # quantitative color scale.
            data = layer.astype(float) if layer.dtype == bool else layer

            # Check dimensions
            if (space.width, space.height) != data.shape:
                warnings.warn(
                    f"Property Layer {layer_name} dimensions ({data.shape}) "
                    f"don't match space dimensions ({space.width}, {space.height})",
                    UserWarning,
                    stacklevel=2,
                )
                continue

            # Get portrayal parameters
            color = portrayal.color
            colormap = portrayal.colormap
            alpha = portrayal.alpha
            vmin = portrayal.vmin if portrayal.vmin is not None else np.min(data)
            vmax = portrayal.vmax if portrayal.vmax is not None else np.max(data)

            n_cols, n_rows = data.shape[0], data.shape[1]
            df = pd.DataFrame(
                {
                    "x": np.repeat(np.arange(n_cols), n_rows),
                    # y runs from high to low within each x, matching the
                    # row-major order of data.flatten().
                    "y": np.tile(np.arange(n_rows - 1, -1, -1), n_cols),
                    "value": data.flatten(),
                }
            )

            if color:
                # For a single color gradient, we define the range from transparent to solid.
                r, g, b = (int(c * 255) for c in to_rgb(color))
                min_color = f"rgba({r},{g},{b},0)"
                max_color = f"rgba({r},{g},{b},{alpha})"
                # Alpha is already part of the color range, so the mark itself
                # stays fully opaque.
                opacity = 1
                color_scale = alt.Scale(
                    range=[min_color, max_color], domain=[vmin, vmax]
                )
            elif colormap:
                color_scale = alt.Scale(scheme=colormap, domain=[vmin, vmax])
                opacity = alpha
            else:
                raise ValueError(
                    f"Property Layer {layer_name} portrayal must include 'color' or 'colormap'."
                )

            legend = (
                alt.Legend(title=layer_name, orient="bottom")
                if portrayal.colorbar
                else None
            )
            main_charts.append(
                alt.Chart(df)
                .mark_rect(opacity=opacity)
                .encode(
                    x=alt.X("x:O", axis=None),
                    y=alt.Y("y:O", axis=None),
                    color=alt.Color(
                        "value:Q",
                        scale=color_scale,
                        title=layer_name,
                        legend=legend,
                    ),
                )
                .properties(width=chart_width, height=chart_height)
            )

        base = alt.layer(*main_charts).resolve_scale(color="independent")
        return base