Source code for mesa.discrete_space.grid

"""Grid-based cell space implementations with different connection patterns.

Provides several grid types for organizing cells:
- OrthogonalMooreGrid: 8 neighbors in 2D, (3^n)-1 in nD
- OrthogonalVonNeumannGrid: 4 neighbors in 2D, 2n in nD
- HexGrid: 6 neighbors in hexagonal pattern (2D only)

Each grid type supports optional wrapping (torus) and cell capacity limits.
Choose based on how movement and connectivity should work in your model -
Moore for unrestricted movement, Von Neumann for orthogonal-only movement,
or Hex for more uniform distances.
"""

from __future__ import annotations

import copyreg
import math
from collections.abc import Sequence
from itertools import product
from random import Random
from typing import Any, TypeVar

import numpy as np
from scipy.spatial import KDTree

from mesa.discrete_space import Cell, DiscreteSpace
from mesa.discrete_space.property_layer import (
    HasPropertyLayers,
    create_property_accessors,
)

T = TypeVar("T", bound=Cell)


def pickle_gridcell(obj):
    """Reduce a dynamically created GridCell instance for pickling.

    GridCell classes are created at runtime with ``type()`` and are not
    reachable as module attributes. The instance is therefore reduced to its
    parent class plus its state, and ``unpickle_gridcell`` rebuilds it on load.

    Args:
        obj: the GridCell instance being pickled

    Returns:
        A ``(callable, args)`` pair in the form expected by ``copyreg.pickle``.
    """
    parent_class = obj.__class__.__bases__[0]
    cell_state = obj.__getstate__()
    return unpickle_gridcell, (parent_class, cell_state)
def unpickle_gridcell(parent, fields):
    """Rebuild a GridCell instance from the data produced by ``pickle_gridcell``.

    Args:
        parent: the base class the original GridCell class was derived from
        fields: the state returned by the cell's ``__getstate__``; only its
            second item (a mapping of attribute name to value) is used

    Returns:
        An instance of a freshly created GridCell subclass of ``parent`` with
        the saved attributes restored.
    """
    # The GridCell class was created dynamically, so build an equivalent one.
    namespace = {"_mesa_properties": set(), "__slots__": ()}
    klass = type("GridCell", (parent,), namespace)

    # Construct with a placeholder coordinate; the real one is restored below.
    instance = klass((0, 0))

    # __getstate__ yields a (dict, slots) pair. Following the original
    # implementation, only the slot mapping is applied.
    # NOTE(review): this assumes the slot mapping already carries everything in
    # the dict part, which depends on the parent's __getstate__ -- confirm.
    slot_values = fields[1]
    for name, value in slot_values.items():
        setattr(instance, name, value)
    return instance
class Grid(DiscreteSpace[T], HasPropertyLayers):
    """Base class for all grid classes.

    Attributes:
        dimensions (Sequence[int]): the dimensions of the grid
        torus (bool): whether the grid is a torus
        capacity (int): the capacity of a grid cell
        random (Random): the random number generator
        _try_random (bool): whether to look for an empty cell by first
            repeatedly sampling random cells

    Notes:
        width and height are accessible via properties, higher dimensions can be retrieved via dimensions

    """

    @property
    def width(self) -> int:
        """Convenience access to the width of the grid."""
        return self.dimensions[0]

    @property
    def height(self) -> int:
        """Convenience access to the height of the grid."""
        return self.dimensions[1]

    def __init__(
        self,
        dimensions: Sequence[int],
        torus: bool = False,
        capacity: float | None = None,
        random: Random | None = None,
        cell_klass: type[T] = Cell,
    ) -> None:
        """Initialise the grid class.

        Args:
            dimensions: the dimensions of the space
            torus: whether the space wraps
            capacity: capacity of the grid cell
            random: a random number generator
            cell_klass: the base class to use for the cells
        """
        super().__init__(capacity=capacity, random=random, cell_klass=cell_klass)
        self.torus = torus
        self.dimensions = dimensions
        self._try_random = True
        self._ndims = len(dimensions)
        self._validate_parameters()

        # Each grid gets its own GridCell subclass; __setstate__ attaches the
        # property-layer accessors to the cell classes.
        self.cell_klass = type(
            "GridCell",
            (self.cell_klass,),
            {"_mesa_properties": set(), "__slots__": ()},
        )

        # The dynamically created class is not a module attribute, so register
        # a reducer that lets its instances be pickled.
        copyreg.pickle(self.cell_klass, pickle_gridcell)

        coordinates = product(*(range(dim) for dim in self.dimensions))
        self._cells = {
            coord: self.cell_klass(coord, capacity=capacity, random=self.random)
            for coord in coordinates
        }
        self._celllist = list(self._cells.values())
        self._connect_cells()
        self.create_property_layer("empty", default_value=True, dtype=bool)

    def find_nearest_cell(self, position: np.ndarray) -> T:
        """Find the cell containing the given position.

        Args:
            position: Physical position with one component per grid
                dimension, e.g. [x, y] for a 2D grid

        Returns:
            Cell: The cell containing the position

        Raises:
            ValueError: If position does not have exactly one component per
                grid dimension, or if it lies outside the grid bounds on a
                non-torus grid.
        """
        # Floor to get the integer cell coordinate.
        floored = np.floor(np.asarray(position)).astype(int)
        # zip() would silently truncate a mismatched position, so check first.
        if floored.shape != (self._ndims,):
            raise ValueError(
                f"Position {position} must have exactly {self._ndims} components "
                f"to match grid dimensions {self.dimensions}."
            )
        # Use plain Python ints so coordinates match the keys of self._cells.
        coord = tuple(int(c) for c in floored)

        if self.torus:
            # Wrap around each dimension.
            coord = tuple(c % d for c, d in zip(coord, self.dimensions))
        elif not all(0 <= c < d for c, d in zip(coord, self.dimensions)):
            raise ValueError(
                f"Position {position} is outside grid bounds. "
                f"Dimensions: {self.dimensions}"
            )

        return self._cells[coord]

    def _connect_cells(self) -> None:
        if self._ndims == 2:
            self._connect_cells_2d()
        else:
            self._connect_cells_nd()

    def _connect_cells_2d(self) -> None: ...

    def _connect_cells_nd(self) -> None: ...

    def _validate_parameters(self):
        if not all(isinstance(dim, int) and dim > 0 for dim in self.dimensions):
            raise ValueError("Dimensions must be a list of positive integers.")
        if not isinstance(self.torus, bool):
            raise ValueError("Torus must be a boolean.")
        if self.capacity is not None and not isinstance(self.capacity, float | int):
            raise ValueError("Capacity must be a number or None.")

    def select_random_empty_cell(self) -> T:
        """Select a random empty cell.

        If ``_try_random`` is set, up to 50 randomly chosen cells are checked
        first, which is cheap when the grid is sparsely occupied. If none of
        them is empty, the cell is chosen uniformly among all cells marked
        empty in the ``empty`` property layer.

        Returns:
            A randomly selected empty cell.

        Raises:
            IndexError: If the grid has no empty cells.
        """
        # The sample-first heuristic is based on Agents.jl's random_empty(), see
        # https://github.com/JuliaDynamics/Agents.jl/pull/541. For the discussion, see
        # https://github.com/mesa/mesa/issues/1052 and
        # https://github.com/mesa/mesa/pull/1565.
        # NOTE: the fallback rebuilds the list of empty coordinates on every call,
        # so on nearly full grids skipping the sampling phase may be faster;
        # earlier benchmarks of this trade-off do not apply to this implementation.
        rng = self.random
        cells = self._celllist

        if self._try_random:
            # Bounded number of attempts so a (nearly) full grid cannot loop forever.
            for _ in range(50):
                cell = rng.choice(cells)
                if cell.is_empty:
                    return cell

        empty_coords = np.argwhere(self.empty.data)
        if len(empty_coords) == 0:
            raise IndexError("No empty cells available in the grid.")
        random_coord = rng.choice(empty_coords)
        return self._cells[tuple(int(c) for c in random_coord)]

    def _connect_single_cell_nd(self, cell: T, offsets: list[tuple[int, ...]]) -> None:
        coord = cell.coordinate

        for d_coord in offsets:
            n_coord = tuple(c + dc for c, dc in zip(coord, d_coord))
            if self.torus:
                n_coord = tuple(nc % d for nc, d in zip(n_coord, self.dimensions))
            if all(0 <= nc < d for nc, d in zip(n_coord, self.dimensions)):
                cell.connect(self._cells[n_coord], d_coord)

    def _connect_single_cell_2d(self, cell: T, offsets: list[tuple[int, int]]) -> None:
        i, j = cell.coordinate
        # size along the first / second coordinate (i.e. width / height)
        size_i, size_j = self.dimensions

        for di, dj in offsets:
            ni, nj = i + di, j + dj
            if self.torus:
                ni, nj = ni % size_i, nj % size_j
            if 0 <= ni < size_i and 0 <= nj < size_j:
                cell.connect(self._cells[ni, nj], (di, dj))

    def __getstate__(self) -> dict[str, Any]:
        """Custom __getstate__ for handling dynamic GridCell class and PropertyDescriptors."""
        state = super().__getstate__()
        # The dynamically created cell class cannot be pickled by reference.
        state = {k: v for k, v in state.items() if k != "cell_klass"}
        return state

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Custom __setstate__ for handling dynamic GridCell class and PropertyDescriptors.

        Restores the base-class state, recovers ``cell_klass`` (which
        ``__getstate__`` strips), and re-attaches the property-layer accessors
        to the classes of the restored cells.
        """
        super().__setstate__(state)

        # __getstate__ removes "cell_klass" from the state; unless the base
        # class restored one, recover it from the restored cells.
        cell_klass = self.__dict__.get("cell_klass")
        if cell_klass is None and self._cells:
            cell_klass = type(next(iter(self._cells.values())))
            self.cell_klass = cell_klass

        # unpickle_gridcell creates a new GridCell class for every restored
        # cell, so the accessors must go onto every class actually in use.
        cell_classes = {type(cell) for cell in self._cells.values()}
        if cell_klass is not None:
            cell_classes.add(cell_klass)

        accessors = {
            layer.name: create_property_accessors(
                layer.data, docstring=f"accessor for {layer.name}"
            )
            for layer in self._mesa_property_layers.values()
        }
        for klass in cell_classes:
            for name, accessor in accessors.items():
                setattr(klass, name, accessor)
class OrthogonalMooreGrid(Grid[T]):
    """Grid where cells are connected to their 8 neighbors.

    In n dimensions a cell is connected to (up to) the (3^n) - 1 cells whose
    coordinates differ from its own by at most one in every dimension.

    Example for two dimensions (also the connection order)::

        directions = [
            (-1, -1), (-1, 0), (-1, 1),
            ( 0, -1),          ( 0, 1),
            ( 1, -1), ( 1, 0), ( 1, 1),
        ]
    """

    def _connect_cells_2d(self) -> None:
        # Cartesian product of {-1, 0, 1} with itself, skipping the cell itself;
        # this yields exactly the order listed in the class docstring.
        offsets = [
            (di, dj)
            for di, dj in product((-1, 0, 1), repeat=2)
            if (di, dj) != (0, 0)
        ]
        for cell in self.all_cells:
            self._connect_single_cell_2d(cell, offsets)

    def _connect_cells_nd(self) -> None:
        n_dims = len(self.dimensions)
        origin = (0,) * n_dims
        offsets = [
            offset
            for offset in product((-1, 0, 1), repeat=n_dims)
            if offset != origin
        ]
        for cell in self.all_cells:
            self._connect_single_cell_nd(cell, offsets)
class OrthogonalVonNeumannGrid(Grid[T]):
    """Grid where cells are connected to their 4 neighbors.

    In n dimensions a cell is connected to (up to) the 2n cells that are one
    step away along exactly one dimension.

    Example for two dimensions (in connection order)::

        directions = [
                     (-1, 0),
            (0, -1),          (0, 1),
                     ( 1, 0),
        ]
    """

    def _connect_cells_2d(self) -> None:
        offsets = [(-1, 0), (0, -1), (0, 1), (1, 0)]
        for cell in self.all_cells:
            self._connect_single_cell_2d(cell, offsets)

    def _connect_cells_nd(self) -> None:
        n_dims = len(self.dimensions)
        # For each dimension in turn: one step backwards, then one step forwards.
        offsets: list[tuple[int, ...]] = [
            tuple(delta if axis == dim else 0 for axis in range(n_dims))
            for dim in range(n_dims)
            for delta in (-1, 1)
        ]
        for cell in self.all_cells:
            self._connect_single_cell_nd(cell, offsets)
class HexGrid(Grid[T]):
    """A Grid with hexagonal tiling of the space.

    Cell ``(col, row)`` is centred at ``x = sqrt(3) * (col + 0.5 * (row % 2))``,
    ``y = 1.5 * row``. Odd rows are therefore shifted half a cell to the right
    ("odd-r" offset layout), and the neighbor offsets used in
    ``_connect_cells_2d`` follow the same convention.

    Note:
        When torus=True, both width and height must be even.

    Raises:
        ValueError: If torus=True and either width or height is odd.
    """

    def __init__(
        self,
        dimensions: Sequence[int],
        torus: bool = False,
        capacity: float | None = None,
        random: Random | None = None,
        cell_klass: type[T] = Cell,
    ) -> None:
        """Initialize the hex grid.

        Args:
            dimensions: the dimensions of the space
            torus: whether the space wraps
            capacity: capacity of the grid cell
            random: a random number generator
            cell_klass: the base class to use for the cells
        """
        super().__init__(
            dimensions=dimensions,
            torus=torus,
            capacity=capacity,
            random=random,
            cell_klass=cell_klass,
        )
        self._init_hex_geometry()

    def _hex_period(self) -> tuple[float, float]:
        """Return the (x, y) extent of the whole layout, i.e. its wrap-around period."""
        return self.dimensions[0] * math.sqrt(3), self.dimensions[1] * 1.5

    def _init_hex_geometry(self) -> None:
        """Calculate physical positions for all cells and build KD-Tree.

        Refer https://www.redblobgames.com/grids/hexagons/#hex-to-pixel for more detail
        """
        positions = []
        self._kdtree_coords = []
        size = 1.0

        for coord, cell in self._cells.items():
            col, row = coord
            x = size * math.sqrt(3) * (col + 0.5 * (row % 2))
            y = size * 1.5 * row
            position = np.array([x, y])
            cell.position = position
            positions.append(position)
            self._kdtree_coords.append(coord)

        # On a torus, give the KD-tree periodic boundaries so nearest-cell
        # queries also see cells across the wrap-around seam. All centres lie
        # in [0, period) in each axis, as the periodic KD-tree requires.
        boxsize = self._hex_period() if self.torus else None
        self._kdtree = KDTree(np.array(positions), boxsize=boxsize)

    def find_nearest_cell(self, position: np.ndarray) -> T:
        """Find the hex cell at the given position.

        Args:
            position: Physical position [x, y] in the layout described in the
                class docstring.

        Returns:
            The cell whose centre is nearest to ``position``. On a non-torus
            grid, positions outside the grid map to the nearest cell rather
            than raising.
        """
        position = np.asarray(position)

        if self.torus:
            width_pixels, height_pixels = self._hex_period()
            position = np.array(
                [position[0] % width_pixels, position[1] % height_pixels]
            )

        _, index = self._kdtree.query(position)
        coord = self._kdtree_coords[index]
        return self._cells[coord]

    def _connect_cells_2d(self) -> None:
        # Offsets are (d_col, d_row).
        # fmt: off
        # Unshifted (even) rows: neighbours in adjacent rows at col-1 and col.
        even_offsets = [
            (-1, -1), (0, -1),
            ( -1, 0), ( 1, 0),
            ( -1, 1), (0, 1),
        ]
        # Shifted (odd) rows: neighbours in adjacent rows at col and col+1.
        odd_offsets = [
            (0, -1), (1, -1),
            ( -1, 0), ( 1, 0),
            ( 0, 1), ( 1, 1),
        ]
        # fmt: on

        for cell in self.all_cells:
            row = cell.coordinate[1]
            offsets = even_offsets if row % 2 == 0 else odd_offsets
            self._connect_single_cell_2d(cell, offsets=offsets)

    def _connect_cells_nd(self) -> None:
        raise NotImplementedError("HexGrids are only defined for 2 dimensions")

    def _validate_parameters(self):
        super()._validate_parameters()
        if len(self.dimensions) != 2:
            raise ValueError("HexGrid must have exactly 2 dimensions.")
        # The row-parity shift only lines up across the seam when the height
        # is even.
        if self.torus and (self.width % 2 != 0 or self.height % 2 != 0):
            raise ValueError(
                "HexGrid with torus=True requires both width and height to be even."
            )