"""batchrunner for running a factorial experiment design over a model.
To take advantage of parallel execution of experiments, `batch_run` uses
multiprocessing if ``number_processes`` is larger than 1. It is strongly advised
to only run in parallel using a normal python file (so don't try to do it in a
jupyter notebook). Moreover, best practice when using multiprocessing is to
put the code inside an ``if __name__ == '__main__':`` code block as shown below::

    from mesa.batchrunner import batch_run

    params = {"width": 10, "height": 10, "N": range(10, 500, 10)}

    if __name__ == '__main__':
        results = batch_run(
            MoneyModel,
            parameters=params,
            iterations=5,
            max_steps=100,
            number_processes=None,
            data_collection_period=1,
            display_progress=True,
        )
"""
import itertools
import multiprocessing
from collections.abc import Iterable, Mapping
from functools import partial
from multiprocessing import Pool
from typing import Any
from tqdm.auto import tqdm
from mesa.model import Model
# Force the "spawn" start method so worker processes behave identically on
# every platform (Linux defaults to "fork", Windows/macOS to "spawn").
# NOTE(review): this runs at import time and, with force=True, overrides any
# start method the embedding application may have chosen — confirm intended.
multiprocessing.set_start_method("spawn", force=True)
def batch_run(
    model_cls: type[Model],
    parameters: Mapping[str, Any | Iterable[Any]],
    # We still retain the Optional[int] because users may set it to None (i.e. use all CPUs)
    number_processes: int | None = 1,
    iterations: int = 1,
    data_collection_period: int = -1,
    max_steps: int = 1000,
    display_progress: bool = True,
) -> list[dict[str, Any]]:
    """Batch run a mesa model with a set of parameter values.

    Args:
        model_cls (Type[Model]): The model class to batch-run
        parameters (Mapping[str, Union[Any, Iterable[Any]]]): Dictionary with model parameters over which to run the model. You can either pass single values or iterables.
        number_processes (int, optional): Number of processes used, by default 1. Set this to None if you want to use all CPUs.
        iterations (int, optional): Number of iterations for each parameter combination, by default 1
        data_collection_period (int, optional): Number of steps after which data gets collected, by default -1 (end of episode)
        max_steps (int, optional): Maximum number of model steps after which the model halts, by default 1000
        display_progress (bool, optional): Display batch run process, by default True

    Returns:
        List[Dict[str, Any]]

    Notes:
        batch_run assumes the model has a `datacollector` attribute that has a DataCollector object initialized.
    """
    # The parameter combinations are identical for every iteration, so build
    # them once instead of recomputing the cartesian product per iteration.
    kwargs_list = _make_model_kwargs(parameters)
    runs_list = []
    run_id = 0
    for iteration in range(iterations):
        for kwargs in kwargs_list:
            runs_list.append((run_id, iteration, kwargs))
            run_id += 1
    process_func = partial(
        _model_run_func,
        model_cls,
        max_steps=max_steps,
        data_collection_period=data_collection_period,
    )
    results: list[dict[str, Any]] = []
    with tqdm(total=len(runs_list), disable=not display_progress) as pbar:
        if number_processes == 1:
            # Run sequentially in-process: no pickling / process startup cost.
            for run in runs_list:
                data = process_func(run)
                results.extend(data)
                pbar.update()
        else:
            # number_processes=None lets Pool default to os.cpu_count() workers.
            with Pool(number_processes) as p:
                for data in p.imap_unordered(process_func, runs_list):
                    results.extend(data)
                    pbar.update()
    return results
def _make_model_kwargs(
parameters: Mapping[str, Any | Iterable[Any]],
) -> list[dict[str, Any]]:
"""Create model kwargs from parameters dictionary.
Parameters
----------
parameters : Mapping[str, Union[Any, Iterable[Any]]]
Single or multiple values for each model parameter name.
Allowed values for each parameter:
- A single value (e.g., `32`, `"relu"`).
- A non-empty iterable (e.g., `[0.01, 0.1]`, `["relu", "sigmoid"]`).
Not allowed:
- Empty lists or empty iterables (e.g., `[]`, `()`, etc.). These should be removed manually.
Returns:
-------
List[Dict[str, Any]]
A list of all kwargs combinations.
"""
parameter_list = []
for param, values in parameters.items():
if isinstance(values, str):
# The values is a single string, so we shouldn't iterate over it.
all_values = [(param, values)]
elif isinstance(values, list | tuple | set) and len(values) == 0:
# If it's an empty iterable, raise an error
raise ValueError(
f"Parameter '{param}' contains an empty iterable, which is not allowed."
)
else:
try:
all_values = [(param, value) for value in values]
except TypeError:
all_values = [(param, values)]
parameter_list.append(all_values)
all_kwargs = itertools.product(*parameter_list)
kwargs_list = [dict(kwargs) for kwargs in all_kwargs]
return kwargs_list
def _model_run_func(
    model_cls: type[Model],
    run: tuple[int, int, dict[str, Any]],
    max_steps: int,
    data_collection_period: int,
) -> list[dict[str, Any]]:
    """Run a single model run and collect model and agent data.

    Parameters
    ----------
    model_cls : Type[Model]
        The model class to batch-run
    run: Tuple[int, int, Dict[str, Any]]
        The run id, iteration number, and kwargs for this run
    max_steps : int
        Maximum number of model steps after which the model halts, by default 1000
    data_collection_period : int
        Number of steps after which data gets collected

    Returns:
    -------
    List[Dict[str, Any]]
        Return model_data, agent_data from the reporters
    """
    run_id, iteration, kwargs = run

    # Instantiate and advance the model until it stops itself or hits max_steps.
    model = model_cls(**kwargs)
    while model.running and model.steps <= max_steps:
        model.step()

    # Steps at which data is sampled; the final step is always included
    # (data_collection_period = -1 yields an empty range, i.e. final step only).
    sample_steps = list(range(0, model.steps, data_collection_period))
    final_step = model.steps - 1
    if not sample_steps or sample_steps[-1] != final_step:
        sample_steps.append(final_step)

    rows: list[dict[str, Any]] = []
    for step in sample_steps:
        model_data, all_agents_data = _collect_data(model, step)
        # Fields shared by every row emitted for this step.
        base_row = {
            "RunId": run_id,
            "iteration": iteration,
            "Step": step,
            **kwargs,
            **model_data,
        }
        if all_agents_data:
            # With agent reporters present: one row per agent per step.
            rows.extend({**base_row, **agent_data} for agent_data in all_agents_data)
        else:
            # Model-level data only: a single row per step.
            rows.append(base_row)
    return rows
def _collect_data(
    model: Model,
    step: int,
) -> tuple[dict[str, Any], list[dict[str, Any]]]:
    """Collect model and agent data from a model using mesas datacollector."""
    if not hasattr(model, "datacollector"):
        raise AttributeError(
            "The model does not have a datacollector attribute. Please add a DataCollector to your model."
        )
    collector = model.datacollector

    # Model-level reporters: each entry is a per-step series; pick this step.
    model_data = {name: series[step] for name, series in collector.model_vars.items()}

    # Agent-level records for this step: each record carries the agent id at
    # index 1 and the reporter values from index 2 onward, in reporter order.
    agent_rows: list[dict[str, Any]] = []
    for record in collector._agent_records.get(step, []):
        row: dict[str, Any] = {"AgentID": record[1]}
        row.update(zip(collector.agent_reporters, record[2:]))
        agent_rows.append(row)
    return model_data, agent_rows