# Source code for batchrunner

"""batchrunner for running a factorial experiment design over a model.

To take advantage of parallel execution of experiments, ``batch_run`` uses
multiprocessing if ``number_processes`` is larger than 1 or ``None`` (use all
CPUs). It is strongly advised to only run in parallel from a regular Python
script (so don't try to do it in a Jupyter notebook). This is because Jupyter
notebooks have a different execution model that can cause issues with Python's
multiprocessing module, especially on Windows. The main problems include the
lack of a traditional ``__main__`` entry point, serialization issues, and
potential deadlocks.

Moreover, best practice when using multiprocessing is to put the code inside an
``if __name__ == '__main__':`` code block, as shown below::

    from mesa.batchrunner import batch_run

    params = {"width": 10, "height": 10, "N": range(10, 500, 10)}

    if __name__ == '__main__':
        results = batch_run(
            MoneyModel,
            parameters=params,
            iterations=5,
            max_steps=100,
            number_processes=None,
            data_collection_period=1,
            display_progress=True,
        )

"""

import bisect
import inspect
import itertools
import multiprocessing
import warnings
from collections.abc import Iterable, Mapping, Sequence
from functools import partial
from multiprocessing import Pool
from typing import Any

import numpy as np
from tqdm.auto import tqdm

from mesa.model import Model

# Type of a single seed value accepted by ``batch_run``'s ``rng`` argument
# (which also accepts an iterable of these, one per batch of runs).
SeedLike = int | np.integer | Sequence[int] | np.random.SeedSequence


def batch_run(
    model_cls: type[Model],
    parameters: Mapping[str, Any | Iterable[Any]],
    # We still retain the Optional[int] because users may set it to None (i.e. use all CPUs)
    number_processes: int | None = 1,
    iterations: int | None = None,
    data_collection_period: int = -1,
    max_steps: int = 1000,
    display_progress: bool = True,
    rng: SeedLike | Iterable[SeedLike] | None = None,
) -> list[dict[str, Any]]:
    """Batch run a mesa model with a set of parameter values.

    Every combination of ``parameters`` is run once per seed value in ``rng``.

    Args:
        model_cls (Type[Model]): The model class to batch-run
        parameters (Mapping[str, Union[Any, Iterable[Any]]]): Dictionary with model
            parameters over which to run the model. You can either pass single values
            or iterables. The parameter grid is expanded once and reused for every
            seed, so one-shot iterators (e.g. generators) are safe to pass.
        number_processes (int, optional): Number of processes used, by default 1.
            Set this to None if you want to use all CPUs.
        iterations (int, optional): Deprecated, use ``rng`` instead.
            ``iterations=n`` is equivalent to ``rng=[None] * n``. Cannot be combined
            with ``rng``.
        data_collection_period (int, optional): Number of steps after which data gets
            collected, by default -1 (end of episode)
        max_steps (int, optional): Maximum number of model steps after which the model
            halts, by default 1000
        display_progress (bool, optional): Display batch run process, by default True
        rng (SeedLike | Iterable[SeedLike] | None, optional): A seed value, or an
            iterable of seed values, for the model's random number generator. Each
            value yields one run per parameter combination. A non-iterable value
            (including the default None) is treated as a single seed; note that a
            sequence of ints is iterable and is therefore treated as several seeds.
            The value is passed to the model as ``seed`` if its constructor has a
            ``seed`` parameter, otherwise as ``rng``.

    Returns:
        List[Dict[str, Any]]: The collected data rows of all runs. When more than one
        process is used, runs are gathered in completion order, so rows are not
        necessarily ordered by ``RunId``.

    Raises:
        ValueError: If both ``iterations`` and ``rng`` are given, or if a parameter
            is given an empty iterable of values.

    Notes:
        batch_run assumes the model has a `datacollector` attribute that has a
        DataCollector object initialized.

        This function calls ``multiprocessing.set_start_method("spawn", force=True)``,
        which changes the start method for the whole Python process, even when
        ``number_processes`` is 1.
    """
    multiprocessing.set_start_method("spawn", force=True)

    if iterations is not None and rng is not None:
        raise ValueError(
            "you cannot use both iterations and rng at the same time. Please only use rng."
        )
    if iterations is not None:
        warnings.warn(
            "The `iterations` keyword argument is deprecated. "
            "Use `rng` instead (e.g. `iterations=5` is equivalent to `rng=[None] * 5`). "
            "See https://mesa.readthedocs.io/latest/migration_guide.html#batch-run",
            DeprecationWarning,
            stacklevel=2,
        )
        rng = [None] * iterations
    if not isinstance(rng, Iterable):
        rng = [rng]

    # Models whose constructor still takes ``seed`` receive the value under that
    # name; everything else gets it as ``rng``.
    model_parameters = inspect.signature(model_cls).parameters
    rng_kwarg_name = "seed" if "seed" in model_parameters else "rng"

    # Expand the parameter grid ONCE. Re-expanding it per seed would exhaust
    # one-shot iterators (e.g. generators) after the first seed, silently
    # dropping every later seed's runs.
    base_kwargs = _make_model_kwargs(parameters)

    runs_list = []
    run_id = 0
    for iteration, rng_i in enumerate(rng):
        for kwargs in base_kwargs:
            # Fresh dict per run so the seed of one run never leaks into another.
            runs_list.append((run_id, iteration, {**kwargs, rng_kwarg_name: rng_i}))
            run_id += 1

    process_func = partial(
        _model_run_func,
        model_cls,
        max_steps=max_steps,
        data_collection_period=data_collection_period,
    )

    results: list[dict[str, Any]] = []

    with tqdm(total=len(runs_list), disable=not display_progress) as pbar:
        if number_processes == 1:
            for run in runs_list:
                results.extend(process_func(run))
                pbar.update()
        else:
            with Pool(number_processes) as pool:
                # imap_unordered yields each run's rows as soon as it finishes.
                for data in pool.imap_unordered(process_func, runs_list):
                    results.extend(data)
                    pbar.update()

    return results
def _make_model_kwargs( parameters: Mapping[str, Any | Iterable[Any]], ) -> list[dict[str, Any]]: """Create model kwargs from parameters dictionary. Parameters ---------- parameters : Mapping[str, Union[Any, Iterable[Any]]] Single or multiple values for each model parameter name. Allowed values for each parameter: - A single value (e.g., `32`, `"relu"`). - A non-empty iterable (e.g., `[0.01, 0.1]`, `["relu", "sigmoid"]`). Not allowed: - Empty lists or empty iterables (e.g., `[]`, `()`, etc.). These should be removed manually. Returns: ------- List[Dict[str, Any]] A list of all kwargs combinations. """ parameter_list = [] for param, values in parameters.items(): if isinstance(values, str): # The values is a single string, so we shouldn't iterate over it. all_values = [(param, values)] elif isinstance(values, list | tuple | set) and len(values) == 0: # If it's an empty iterable, raise an error raise ValueError( f"Parameter '{param}' contains an empty iterable, which is not allowed." ) else: try: all_values = [(param, value) for value in values] except TypeError: all_values = [(param, values)] parameter_list.append(all_values) all_kwargs = itertools.product(*parameter_list) kwargs_list = [dict(kwargs) for kwargs in all_kwargs] return kwargs_list def _model_run_func( model_cls: type[Model], run: tuple[int, int, dict[str, Any]], max_steps: int, data_collection_period: int, ) -> list[dict[str, Any]]: """Run a single model run and collect model and agent data. 
Parameters ---------- model_cls : Type[Model] The model class to batch-run run: Tuple[int, int, Dict[str, Any]] The run id, iteration number, and kwargs for this run max_steps : int Maximum number of model steps after which the model halts, by default 1000 data_collection_period : int Number of steps after which data gets collected Returns: ------- List[Dict[str, Any]] Return model_data, agent_data from the reporters """ run_id, iteration, kwargs = run model = model_cls(**kwargs) while model.running and model.steps < max_steps: model.step() data = [] # Use the DataCollector's actual history to capture ALL data (including sub-steps) try: recorded_steps = model.datacollector._collection_steps except AttributeError: # Fallback for legacy models without _collection_steps steps = list(range(0, model.steps, data_collection_period)) if not steps or steps[-1] != model.steps - 1: steps.append(model.steps - 1) else: match data_collection_period: case -1: steps = [recorded_steps[-1]] if recorded_steps else [] case 1: steps = recorded_steps case _: steps = recorded_steps[::data_collection_period] for step in steps: model_data, all_agents_data = _collect_data(model, step) # If there are agent_reporters, then create an entry for each agent if all_agents_data: stepdata = [ { "RunId": run_id, "iteration": iteration, "Step": step, **kwargs, **model_data, **agent_data, } for agent_data in all_agents_data ] # If there is only model data, then create a single entry for the step else: stepdata = [ { "RunId": run_id, "iteration": iteration, "Step": step, **kwargs, **model_data, } ] data.extend(stepdata) return data def _collect_data( model: Model, step: int, ) -> tuple[dict[str, Any], list[dict[str, Any]]]: """Collect model and agent data from a model using mesas datacollector.""" if not hasattr(model, "datacollector"): raise AttributeError( "The model does not have a datacollector attribute. Please add a DataCollector to your model." 
) dc = model.datacollector # Check if modern DataCollector with _collection_steps exists (handles time dilation) if hasattr(dc, "_collection_steps"): idx = bisect.bisect_right(dc._collection_steps, step) - 1 if ( idx >= 0 and idx < len(dc._collection_steps) and dc._collection_steps[idx] == step ): # Exact match found - use the index directly model_data = {param: values[idx] for param, values in dc.model_vars.items()} else: # Step not found in _collection_steps # Use sparse collection logic: find the nearest collected step if idx >= 0 and idx < len(dc._collection_steps): # Use the most recent collected data before this step model_data = { param: values[idx] for param, values in dc.model_vars.items() } else: # No data collected yet, use first available try: model_data = { param: values[0] for param, values in dc.model_vars.items() } except IndexError: model_data = {} else: # Legacy DataCollector without _collection_steps # Use sparse collection logic for models that collect data irregularly available_steps = sorted(dc._agent_records.keys()) if step not in available_steps: step = max((s for s in available_steps if s <= step), default=0) try: collection_index = available_steps.index(step) except ValueError: collection_index = 0 model_data = { param: values[collection_index] for param, values in dc.model_vars.items() } all_agents_data = [] # Collect agent_reporters data raw_agent_data = dc._agent_records.get(step, []) for data in raw_agent_data: agent_dict = {"AgentID": data[1]} agent_dict.update(zip(dc.agent_reporters, data[2:])) all_agents_data.append(agent_dict) # Collect agenttype_reporters data raw_agenttype_data = dc._agenttype_records.get(step, {}) for agent_type, agents_data in raw_agenttype_data.items(): for data in agents_data: agent_dict = {"AgentID": data[1]} agent_dict.update(zip(dc.agenttype_reporters[agent_type], data[2:])) all_agents_data.append(agent_dict) return model_data, all_agents_data