"""Mesa visualization module for creating interactive model visualizations.
This module provides components to create browser- and Jupyter notebook-based visualizations of
Mesa models, allowing users to watch models run step-by-step and interact with model parameters.
Key features:
- SolaraViz: Main component for creating visualizations, supporting grid displays and plots
- ModelController: Handles model execution controls (step, play, pause, reset)
- UserInputs: Generates UI elements for adjusting model parameters
The module uses Solara for rendering in Jupyter notebooks or as standalone web applications.
It supports various types of visualizations including matplotlib plots, agent grids, and
custom visualization components.
Usage:
1. Define an agent_portrayal function to specify how agents should be displayed
2. Set up model_params to define adjustable parameters
3. Create a SolaraViz instance with your model, parameters, and desired measures
4. Display the visualization in a Jupyter notebook or run as a Solara app
See the Visualization Tutorial and example models for more details.
"""
from __future__ import annotations
import asyncio
import inspect
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal
import reacton.core
import solara
import mesa.visualization.components.altair_components as components_altair
from mesa.experimental.devs.simulator import Simulator
from mesa.mesa_logging import create_module_logger, function_logger
from mesa.visualization.user_param import Slider
from mesa.visualization.utils import force_update, update_counter
if TYPE_CHECKING:
    # Needed only for type annotations; TYPE_CHECKING is False at runtime,
    # so this import is skipped when the module is actually executed.
    from mesa.model import Model

# Module-level logger; used below to log model re-creation on reset.
_mesa_logger = create_module_logger()
@solara.component
@function_logger(__name__)
def SolaraViz(
    model: Model | solara.Reactive[Model],
    components: list[reacton.core.Component]
    | list[Callable[[Model], reacton.core.Component]]
    | Literal["default"] = "default",
    *,
    play_interval: int = 100,
    render_interval: int = 1,
    simulator: Simulator | None = None,
    model_params=None,
    name: str | None = None,
):
    """Top-level Solara page for visualizing a model.

    Renders an app bar with a title, a sidebar holding the execution controls,
    the model-parameter inputs and a step counter, and a draggable grid with
    the visualization components.

    Args:
        model (Model | solara.Reactive[Model]): The model to visualize. A plain
            (non-reactive) instance is wrapped with ``solara.use_reactive``.
        components (list | Literal["default"], optional): Solara components, or
            callables that take the model and return a component. The value
            ``"default"`` is replaced by a single Altair space component created
            with ``make_altair_space`` and no portrayal functions.
        play_interval (int, optional): Initial value of the "Play Interval (ms)"
            slider, i.e. the pause between automatic steps. Defaults to 100.
        render_interval (int, optional): Initial value of the
            "Render Interval (steps)" slider, i.e. how many steps are taken
            between two updates of the visualization. Defaults to 1.
        simulator (Simulator | None, optional): When a ``Simulator`` instance is
            given, the model is driven through ``SimulatorController``;
            otherwise ``ModelController`` is used.
        model_params (dict, optional): Parameters for (re-)instantiating the
            model; may mix fixed values and user-adjustable inputs.
            ``None`` is treated as an empty dict.
        name (str | None, optional): Title shown in the app bar. Falls back to
            the class name of the model when empty or ``None``.

    Example:
        >>> model = MyModel()
        >>> page = SolaraViz(model)
        >>> page
    """
    if components == "default":
        default_space = components_altair.make_altair_space(
            agent_portrayal=None, propertylayer_portrayal=None, post_process=None
        )
        components = [default_space]

    model_params = {} if model_params is None else model_params

    # Wrap a plain model so the rest of the page can rely on ``model.value``
    if not isinstance(model, solara.Reactive):
        model = solara.use_reactive(model)  # noqa: SH102, RUF100

    # Reactive state shared between the sliders, the controller and ModelCreator
    reactive_model_parameters = solara.use_reactive({})
    reactive_play_interval = solara.use_reactive(play_interval)
    reactive_render_interval = solara.use_reactive(render_interval)

    # Both controller flavours take the same keyword arguments
    controller_kwargs = {
        "model_parameters": reactive_model_parameters,
        "play_interval": reactive_play_interval,
        "render_interval": reactive_render_interval,
    }

    with solara.AppBar():
        solara.AppBarTitle(name or model.value.__class__.__name__)

    with solara.Sidebar(), solara.Column():
        with solara.Card("Controls"):
            solara.SliderInt(
                label="Play Interval (ms)",
                value=reactive_play_interval,
                on_value=lambda new_value: reactive_play_interval.set(new_value),
                min=1,
                max=500,
                step=10,
            )
            solara.SliderInt(
                label="Render Interval (steps)",
                value=reactive_render_interval,
                on_value=lambda new_value: reactive_render_interval.set(new_value),
                min=1,
                max=100,
                step=2,
            )
            if isinstance(simulator, Simulator):
                SimulatorController(model, simulator, **controller_kwargs)
            else:
                ModelController(model, **controller_kwargs)
        with solara.Card("Model Parameters"):
            ModelCreator(
                model, model_params, model_parameters=reactive_model_parameters
            )
        with solara.Card("Information"):
            ShowSteps(model.value)

    ComponentsView(components, model.value)
def _wrap_component(
    component: reacton.core.Component | Callable[[Model], reacton.core.Component],
) -> reacton.core.Component:
    """Return ``component`` as a Solara component.

    An object that already is a ``reacton.core.Component`` is returned
    unchanged. Any other callable is wrapped in a new Solara component that
    reads ``update_counter`` before delegating to the callable.
    """
    if not isinstance(component, reacton.core.Component):

        @solara.component
        def WrappedComponent(model):
            # Read the shared counter, then delegate to the wrapped callable
            update_counter.get()
            return component(model)

        return WrappedComponent

    return component
@solara.component
def ComponentsView(
    components: list[reacton.core.Component]
    | list[Callable[[Model], reacton.core.Component]],
    model: Model,
):
    """Render the given components in a draggable, resizable grid.

    Args:
        components: Components (or callables returning components) to show.
        model: Model instance handed to every component.
    """
    # Normalise every entry to a Solara component, then instantiate it
    as_components = list(map(_wrap_component, components))
    elements = [make_element(model) for make_element in as_components]

    # The layout lives in component state so user rearrangements are kept
    initial_layout = make_initial_grid_layout(num_components=len(elements))
    layout, update_layout = solara.use_state(initial_layout)

    solara.GridDraggable(
        items=elements,
        grid_layout=layout,
        resizable=True,
        draggable=True,
        on_grid_layout=update_layout,
    )
# JupyterViz is a second name bound to the same component object as SolaraViz.
JupyterViz = SolaraViz
@solara.component
def ModelController(
    model: solara.Reactive[Model],
    *,
    model_parameters: dict | solara.Reactive[dict] | None = None,
    play_interval: int | solara.Reactive[int] = 100,
    render_interval: int | solara.Reactive[int] = 1,
):
    """Create controls for model execution (step, play, pause, reset).

    Args:
        model: Reactive model instance.
        model_parameters: Parameters (plain dict or reactive) used to
            re-instantiate the model on reset. ``None`` means no parameters.
        play_interval: Pause between automatic steps in milliseconds, as a
            plain int or a reactive int.
        render_interval: Number of model steps taken per "Step" click or play
            tick, as a plain int or a reactive int. Higher values reduce how
            often the visualization is updated.
    """
    playing = solara.use_reactive(False)
    running = solara.use_reactive(True)
    if model_parameters is None:
        model_parameters = {}
    model_parameters = solara.use_reactive(model_parameters)
    # The signature allows plain ints (and defaults to them), but the code
    # below reads ``.value``; normalise both to reactives the same way as
    # ``model_parameters`` so either form works.
    play_interval = solara.use_reactive(play_interval)
    render_interval = solara.use_reactive(render_interval)

    async def step():
        # Keep stepping until the user pauses or the model stops running
        while playing.value and running.value:
            await asyncio.sleep(play_interval.value / 1000)
            do_step()

    # Restart the task whenever the play/run state changes
    solara.lab.use_task(
        step, dependencies=[playing.value, running.value], prefer_threaded=False
    )

    @function_logger(__name__)
    def do_step():
        """Advance the model by the number of steps specified by the render_interval slider."""
        for _ in range(render_interval.value):
            model.value.step()
        running.value = model.value.running
        force_update()

    @function_logger(__name__)
    def do_reset():
        """Reset the model to its initial state."""
        playing.value = False
        running.value = True
        # Level 10 is DEBUG in the standard logging module
        _mesa_logger.log(
            10,
            f"creating new {model.value.__class__} instance with {model_parameters.value}",
        )
        model.value = model.value.__class__(**model_parameters.value)

    @function_logger(__name__)
    def do_play_pause():
        """Toggle play/pause."""
        playing.value = not playing.value

    with solara.Row(justify="space-between"):
        solara.Button(label="Reset", color="primary", on_click=do_reset)
        solara.Button(
            label="▶" if not playing.value else "❚❚",
            color="primary",
            on_click=do_play_pause,
            disabled=not running.value,
        )
        solara.Button(
            label="Step",
            color="primary",
            on_click=do_step,
            disabled=playing.value or not running.value,
        )
@solara.component
def SimulatorController(
    model: solara.Reactive[Model],
    simulator,
    *,
    model_parameters: dict | solara.Reactive[dict] | None = None,
    play_interval: int | solara.Reactive[int] = 100,
    render_interval: int | solara.Reactive[int] = 1,
):
    """Create controls for simulator-driven model execution (step, play, pause, reset).

    Args:
        model: Reactive model instance.
        simulator: Simulator instance that advances the model.
        model_parameters: Parameters (plain dict or reactive) used to
            re-instantiate the model on reset. ``None`` means no parameters.
        play_interval: Pause between automatic steps in milliseconds, as a
            plain int or a reactive int.
        render_interval: Amount passed to ``simulator.run_for`` per "Step"
            click or play tick, as a plain int or a reactive int. Higher values
            reduce how often the visualization is updated.

    Notes:
        The `step button` increments the step by the value specified in the `render_interval` slider.
        This behavior ensures synchronization between simulation steps and plot updates.
    """
    playing = solara.use_reactive(False)
    running = solara.use_reactive(True)
    if model_parameters is None:
        model_parameters = {}
    model_parameters = solara.use_reactive(model_parameters)
    # The signature allows plain ints (and defaults to them), but the code
    # below reads ``.value``; normalise both to reactives the same way as
    # ``model_parameters`` so either form works.
    play_interval = solara.use_reactive(play_interval)
    render_interval = solara.use_reactive(render_interval)

    async def step():
        # Keep stepping until the user pauses or the model stops running
        while playing.value and running.value:
            await asyncio.sleep(play_interval.value / 1000)
            do_step()

    # Restart the task whenever the play/run state changes
    solara.lab.use_task(
        step, dependencies=[playing.value, running.value], prefer_threaded=False
    )

    def do_step():
        """Advance the model by the number of steps specified by the render_interval slider."""
        simulator.run_for(render_interval.value)
        running.value = model.value.running
        force_update()

    def do_reset():
        """Reset the model to its initial state."""
        playing.value = False
        running.value = True
        simulator.reset()
        # The new model instance is handed the same simulator object
        model.value = model.value.__class__(
            simulator=simulator, **model_parameters.value
        )

    def do_play_pause():
        """Toggle play/pause."""
        playing.value = not playing.value

    with solara.Row(justify="space-between"):
        solara.Button(label="Reset", color="primary", on_click=do_reset)
        solara.Button(
            label="▶" if not playing.value else "❚❚",
            color="primary",
            on_click=do_play_pause,
            disabled=not running.value,
        )
        solara.Button(
            label="Step",
            color="primary",
            on_click=do_step,
            disabled=playing.value or not running.value,
        )
def split_model_params(model_params):
    """Partition model parameters by whether the user can adjust them.

    Args:
        model_params: Dictionary of all model parameters.

    Returns:
        tuple: ``(user_adjustable_params, fixed_params)``, two dictionaries
        that together hold every entry of ``model_params``.
    """
    adjustable, fixed = {}, {}
    for key, spec in model_params.items():
        # Route each entry into exactly one of the two result dicts
        target = fixed if check_param_is_fixed(spec) else adjustable
        target[key] = spec
    return adjustable, fixed
def check_param_is_fixed(param):
    """Check if a parameter is fixed (not user-adjustable).

    A parameter is user-adjustable when it is a ``Slider`` instance or a
    dictionary that contains a ``"type"`` key; everything else is fixed.

    Args:
        param: Parameter to check.

    Returns:
        bool: True if parameter is fixed, False otherwise.
    """
    # A dict carrying a "type" key is an input specification for UserInputs
    is_input_spec = isinstance(param, dict) and "type" in param
    # Always return a real bool (never fall through to an implicit None)
    return not (is_input_spec or isinstance(param, Slider))
@solara.component
def ModelCreator(
    model: solara.Reactive[Model],
    user_params: dict,
    *,
    model_parameters: dict | solara.Reactive[dict] | None = None,
):
    """Solara component that renders the inputs for a model's parameters.

    The parameters are split into fixed values and user-adjustable inputs.
    On first render ``model_parameters`` is set to the fixed values plus the
    initial value of every adjustable input; afterwards each input change
    updates the corresponding entry.

    Args:
        model: A reactive model instance. Its class' ``__init__`` signature is
            checked against the fixed parameters whenever the model changes.
        user_params: Parameters for (re-)instantiating a model. Keys are
            parameter names; values are either fixed values, ``Slider``
            objects, or dictionaries describing an input (``type``, ``value``
            and type-specific options such as ``min`` and ``max``).
        model_parameters: Reactive (or plain) dict that receives the current
            parameter values. ``None`` means start from an empty dict.

    Example:
        >>> model = solara.reactive(MyModel())
        >>> model_params = {
        >>>     "param1": {"type": "SliderInt", "value": 10, "min": 0, "max": 100},
        >>>     "param2": {"type": "SliderInt", "value": 5, "min": 1, "max": 10},
        >>> }
        >>> creator = ModelCreator(model, model_params)
        >>> creator
    """
    if model_parameters is None:
        model_parameters = {}
    model_parameters = solara.use_reactive(model_parameters)

    # Split first so the effect closures below only capture names that are
    # already bound when they are created.
    user_params, fixed_params = split_model_params(user_params)

    # NOTE(review): only the fixed parameters are validated here, so a required
    # constructor argument that is supplied solely through a user-adjustable
    # input would be reported as missing -- confirm this is intended.
    solara.use_effect(
        lambda: _check_model_params(model.value.__class__.__init__, fixed_params),
        [model.value],
    )

    # Empty dependency list: run the initialization only once
    solara.use_effect(
        # set model_parameters to the default values for all parameters
        lambda: model_parameters.set(
            {
                **fixed_params,
                **{k: v.get("value") for k, v in user_params.items()},
            }
        ),
        [],
    )

    @function_logger(__name__)
    def on_change(name, value):
        # Assign a new dict rather than mutating the current one in place
        model_parameters.value = {**model_parameters.value, name: value}

    UserInputs(user_params, on_change=on_change)
def _check_model_params(init_func, model_params):
"""Check if model parameters are valid for the model's initialization function.
Args:
init_func: Model initialization function
model_params: Dictionary of model parameters
Raises:
ValueError: If a parameter is not valid for the model's initialization function
"""
model_parameters = inspect.signature(init_func).parameters
has_var_positional = any(
param.kind == inspect.Parameter.VAR_POSITIONAL
for param in model_parameters.values()
)
if has_var_positional:
raise ValueError(
"Mesa's visualization requires the use of keyword arguments to ensure the parameters are passed to Solara correctly. Please ensure all model parameters are of form param=value"
)
for name in model_parameters:
if (
model_parameters[name].default == inspect.Parameter.empty
and name not in model_params
and name != "self"
and name != "kwargs"
):
raise ValueError(f"Missing required model parameter: {name}")
for name in model_params:
if name not in model_parameters and "kwargs" not in model_parameters:
raise ValueError(f"Invalid model parameter: {name}")
@solara.component
def UserInputs(user_params, on_change=None):
    """Initialize user inputs for configurable model parameters.

    Currently supports :class:`solara.SliderInt`, :class:`solara.SliderFloat`,
    :class:`solara.Select`, :class:`solara.Checkbox` and
    :class:`solara.InputText`, plus ``Slider`` objects.

    Args:
        user_params: Dictionary mapping a parameter name to either a ``Slider``
            object or a dictionary with options for the input, including
            ``type``, ``value``, an optional ``label``, and other fields
            specific to the input type (``min``, ``max``, ``step``, ``values``).
        on_change: Function to be called with (name, value) when the value of
            an input changes. When ``None``, changes are not reported.

    Raises:
        ValueError: If a dictionary entry has an unsupported ``type``.
    """
    for name, options in user_params.items():

        def change_handler(value, name=name):
            # ``name`` is bound as a default so every widget reports its own
            # key instead of the last one seen by the loop.
            if on_change is not None:
                on_change(name, value)

        if isinstance(options, Slider):
            slider_class = (
                solara.SliderFloat if options.is_float_slider else solara.SliderInt
            )
            slider_class(
                options.label,
                value=options.value,
                on_value=change_handler,
                min=options.min,
                max=options.max,
                step=options.step,
            )
            continue

        # label for the input is "label" from options or name
        label = options.get("label", name)
        input_type = options.get("type")
        if input_type in ("SliderInt", "SliderFloat"):
            # Both slider flavours take exactly the same options
            slider_class = (
                solara.SliderInt if input_type == "SliderInt" else solara.SliderFloat
            )
            slider_class(
                label,
                value=options.get("value"),
                on_value=change_handler,
                min=options.get("min"),
                max=options.get("max"),
                step=options.get("step"),
            )
        elif input_type == "Select":
            solara.Select(
                label,
                value=options.get("value"),
                on_value=change_handler,
                values=options.get("values"),
            )
        elif input_type == "Checkbox":
            solara.Checkbox(
                label=label,
                on_value=change_handler,
                value=options.get("value"),
            )
        elif input_type == "InputText":
            solara.InputText(
                label=label,
                on_value=change_handler,
                value=options.get("value"),
            )
        else:
            raise ValueError(f"{input_type} is not a supported input type")
def make_initial_grid_layout(num_components):
    """Build the starting grid layout for the visualization components.

    Components are placed two per row: each entry is 6 units wide and 10 units
    high, with ``x`` alternating between 0 and 6.

    Args:
        num_components: Number of components to display.

    Returns:
        list: One layout dictionary per component, in component order.
    """
    layout = []
    for index in range(num_components):
        column = index % 2  # 0 = left half, 1 = right half
        layout.append(
            {
                "i": index,
                "w": 6,
                "h": 10,
                "moved": False,
                "x": 6 * column,
                "y": 16 * (index - column),
            }
        )
    return layout
@solara.component
def ShowSteps(model):
    """Render a text element showing ``model.steps``."""
    # Read the shared update counter before building the label
    update_counter.get()
    step_label = f"Step: {model.steps}"
    return solara.Text(step_label)