# Source code for utopya.eval.plots.abm

"""Implements plot functions for agent-based models, e.g. of agents and their
movement in a domain."""

import copy
import itertools
import logging
from collections import deque
from math import sqrt
from typing import Dict, Optional, Sequence, Tuple, Union

import numpy as np
import xarray as xr
from dantro.base import BaseDataContainer
from dantro.plot import ColorManager
from dantro.tools import recursive_update
from matplotlib.collections import LineCollection, PathCollection
from matplotlib.legend import Legend
from matplotlib.legend_handler import HandlerBase, HandlerPathCollection
from matplotlib.markers import MarkerStyle
from matplotlib.patches import Patch
from matplotlib.path import Path
from matplotlib.transforms import AffineDeltaTransform

from ...tools import ensure_dict, is_iterable
from .. import PlotHelper, is_plot_func
from ._mpl import adjust_figsize_to_aspect

log = logging.getLogger(__name__)

# -----------------------------------------------------------------------------

# Registry of custom marker shapes, looked up by name.
# All shapes are defined such that they point along the positive x-axis.
MARKERS: Dict[str, Path] = {
    # Closed triangle with its tip at (0.3, 0)
    "wedge": Path(
        [(0.0, -0.1), (0.3, 0.0), (0.0, 0.1), (0.0, 0.0)],
        codes=[Path.MOVETO, Path.LINETO, Path.LINETO, Path.CLOSEPOLY],
    ),
    # Outline made of two cubic Bézier curves: the first runs from (-5, 0)
    # to (+3, 0) through positive y values, the second one returns to the
    # starting point through negative y values.
    "fish": Path(
        [
            (-5, 0),
            (-2, +0.2),
            (+3, +2),
            (+3, 0),
            #
            (+3, -2),
            (-2, -0.2),
            (-5, 0),
        ],
        codes=[
            Path.MOVETO,
            Path.CURVE4,
            Path.CURVE4,
            Path.CURVE4,
            #
            Path.CURVE4,
            Path.CURVE4,
            Path.CURVE4,
        ],
    ),
    # Same kind of body outline as "fish" (slightly wider), followed by two
    # additional sub-paths of straight line segments for the tail and fins.
    "fish2": Path(
        [
            (-5, 0),
            (-2, +0.2),
            (+3, +2.5),
            (+3, 0),
            #
            (+3, -2.5),
            (-2, -0.2),
            (-5, 0),
            #
            # tail
            (-3, 0),
            (-5.5, -1.5),
            (-5, 0),
            (-5.5, +1.5),
            #
            # fins left & right
            (+2, 0),
            (-2, -1.3),
            (-1, 0),
            (-2, +1.3),
        ],
        codes=[
            # body: two cubic Bézier curves
            Path.MOVETO,
            Path.CURVE4,
            Path.CURVE4,
            Path.CURVE4,
            #
            Path.CURVE4,
            Path.CURVE4,
            Path.CURVE4,
            #
            # tail: new sub-path of straight lines
            Path.MOVETO,
            Path.LINETO,
            Path.LINETO,
            Path.LINETO,
            #
            # fins: new sub-path of straight lines
            Path.MOVETO,
            Path.LINETO,
            Path.LINETO,
            Path.LINETO,
        ],
    ),
}
"""Custom markers that can be used in :py:class:`.AgentCollection`, an
essential part of :py:func:`.abmplot`.

Available markers:

- ``wedge``: A triangle that can be used to denote orientations.
- ``fish``: A simple fish-like shape.
- ``fish2``: A more elaborate fish-like shape.

.. image:: ../_static/_gen/abmplot/markers.pdf
    :target: ../_static/_gen/abmplot/markers.pdf
    :width: 100%

.. hint::

    To have consistent orientations, note that markers should point along the
    positive x-axis. With increasing orientation values, they are rotated in
    counter-clockwise direction.
    Angles are measured in radians.

.. note::

    Suggestions for additional markers are
    `welcome <https://gitlab.com/utopia-project/utopya/-/issues>`_.
"""
# TODO add ant-like shape
# TODO tweak shapes to consist only of a single path, not two overlapping ones


class AgentCollection(PathCollection):
    """A collection of agent markers and (optionally) their position history
    in the form of a tail.

    Combines :py:class:`~matplotlib.collections.PathCollection` with a
    :py:class:`~matplotlib.collections.LineCollection` for the tails.
    """

    def __init__(
        self,
        xy: Union[np.ndarray, Sequence[Union[Path, Tuple[float, float]]]],
        *,
        marker: Union[Path, Patch, str, dict] = "auto",
        default_marker: Union[Path, Patch, str, dict] = "o",
        default_marker_oriented: Union[Path, Patch, str, dict] = "wedge",
        sizes: Union[float, Sequence[float]] = None,
        size_scale: float = 1.0,
        normalize_marker_sizes: bool = True,
        orientations: Union[float, Sequence[float]] = None,
        base_orientation: float = 0.0,
        tail_length: int = 0,
        tail_kwargs: dict = None,
        tail_decay: float = 0.0,
        tail_max_segment_length: float = None,
        use_separate_paths: bool = True,
        **kwargs,
    ):
        """Set up an AgentCollection, visualizing agents and (optionally) their
        position history using a tail.

        .. hint::

            If you need a higher performance for this plot, consider not using
            the *decaying* tails, for which the line segments need to be drawn
            individually.

        Args:
            xy (Sequence[Tuple[float, float]]): The positions of the agents.
                The number of agents is extracted from the length of this
                sequence of ``(x, y)`` positions or ``(N, 2)`` array.
            marker (Union[matplotlib.path.Path, matplotlib.patches.Patch, str, dict], optional):
                The marker to use to denote agent positions. If ``auto``,
                will use ``default_marker`` or ``default_marker_oriented``,
                depending on whether ``orientations`` were given.
                Otherwise, the following input types are possible:

                    - A :py:class:`~matplotlib.path.Path` defines the marker
                      shape directly (a copy of it is used).
                    - A :py:class:`dict` is used as keyword arguments to the
                      :py:class:`~matplotlib.path.Path`, where ``vertices``
                      can be a sequence of 2-tuples that defines the vertices
                      of the path.
                    - A :py:class:`~matplotlib.patches.Patch` already defines
                      a path, which is extracted and used as marker.
                      All other properties are ignored.
                    - :py:class:`str` is looked up in :py:data:`.MARKERS`;
                      if no such entry is available, the remaining options
                      are evaluated.
                    - All other input is used to construct a
                      :py:class:`~matplotlib.markers.MarkerStyle` from which
                      the path is extracted.

            default_marker (Union[matplotlib.path.Path, matplotlib.patches.Patch, str, dict], optional):
                The marker to use in case that ``marker`` is ``auto`` and
                agents are *not* initialized with an orientation.
            default_marker_oriented (Union[matplotlib.path.Path, matplotlib.patches.Patch, str, dict], optional):
                The marker to use in case that ``marker`` is ``auto`` and
                agents *are* initialized with an orientation.
            sizes (Union[float, Sequence[float]], optional): Agent sizes in
                units of area (of bounding box) squared. If a scalar value is
                given, all agents will have the same size.
            size_scale (float, optional): A scaling factor for the sizes.
            normalize_marker_sizes (bool, optional): If True, will normalize
                the size of the given ``marker``. This is done by computing
                the area of the bounding box and scaling the path such that
                the bounding box has an area of one.
            orientations (Union[float, Sequence[float]], optional): Single or
                individual orientations of the agents. Zero corresponds to
                pointing along the positive x-axis, with increasing
                orientation values leading to a counter-clockwise rotation.
                Angles need to be given in radians.
            base_orientation (float, optional): Rotates the given marker by
                this value such that all ``orientations`` have a constant
                offset.
            tail_length (int, optional): For non-zero integers, will draw a
                tail behind the agents that traces their position history.
                This is only relevant if the collection is used as part of an
                animation where the position history can be aggregated upon
                each call to :py:meth:`.set_offsets`.
            tail_kwargs (dict, optional): Keyword arguments that are passed to
                the :py:class:`~matplotlib.collections.LineCollection` which
                is used to represent the tails.
            tail_decay (float, optional): A decay factor. If non-zero, will
                multiply each consecutive tail segments ``alpha`` value with
                ``(1 - tail_decay)``.
            tail_max_segment_length (float, optional): If set, will *not* draw
                tail segments that are longer than this value in either the
                x- or the y-direction.
                This can be useful if agents move in a periodic space where
                positions may jump.
            use_separate_paths (bool, optional): If False, only a single path
                is used to represent all agents in the underlying
                :py:class:`~matplotlib.collections.PathCollection`.

                .. warning::

                    Setting this to False should *only* be done if all agents
                    can be thought of as identical, i.e. have the same size,
                    color, marker, and orientation — otherwise this leads to
                    subtle and seemingly random breakage in the form of
                    markers not being drawn.

            **kwargs: Passed on to
                :py:class:`~matplotlib.collections.PathCollection`

        Raises:
            ValueError: If ``xy`` is not of shape ``(N, 2)``.
        """
        xy = np.array(xy)
        if xy.ndim != 2 or xy.shape[1] != 2:
            raise ValueError(
                "Agent positions need to be of shape `(N, 2)` but were "
                f"{xy.shape}! Given data was:\n{str(xy)}"
            )

        # Attributes
        self._normalize_marker_sizes = normalize_marker_sizes
        self._size_scale = size_scale
        self._size_factors = None

        self._has_orientation = orientations is not None
        self._R = np.empty((2, 2))  # re-usable rotation matrix
        self._base_vertices = None
        self._base_orientation = base_orientation
        self._orientations = None

        self._tail_length = tail_length
        self._tails = None
        self._tail_kwargs = ensure_dict(tail_kwargs)
        self._tail_decay = tail_decay
        self._tail_max_segment_length = tail_max_segment_length
        self._offsets_history = None

        # Decide whether they need to be oriented, in which case we need many
        # separate Path objects. Otherwise, a single Path suffices.
        # NOTE Only compare strings to "auto"; other marker types may not
        #      support a plain equality comparison.
        if isinstance(marker, str) and marker == "auto":
            marker = (
                default_marker_oriented
                if self.has_orientation
                else default_marker
            )

        self._markerpath = self._prepare_marker_path(marker)
        self._has_markers = bool(marker)

        if use_separate_paths or self.has_orientation:
            paths = [copy.deepcopy(self.markerpath) for _ in range(len(xy))]

        else:
            # A bit simpler, only need a single marker
            # FIXME This causes issues in some scenarios, e.g. by mapping
            #       the color only to the first path or not drawing
            #       markers at all!
            #       Perhaps always use separate markerpaths?!
            paths = [self.markerpath]

        # Compute size factors to have reasonably sized markers (and not
        # markers in data units, which typically does not make a lot of sense)
        self._size_factors = self._compute_size_factors(paths)

        # Now setup the underlying PathCollection
        super().__init__(paths, offsets=xy, sizes=self._size_factors, **kwargs)

        # Optionally, set sizes
        if sizes is not None:
            self.set_markersizes(sizes)

        # Optionally, set orientations
        if self.has_orientation:
            # Store copies of the base vertices which will be used for
            # rotating the markers. These base vertices are used for the
            # rotation in order to not accumulate rounding errors etc.
            self._base_vertices = [p.vertices.copy() for p in self.get_paths()]

            # Base vertices should not be writeable, but the patches' vertices
            # need to be writeable. Take care of that here, reduces repetitions
            for _bv in self._base_vertices:
                _bv.setflags(write=False)  # for safety

            # Can apply orientations now
            self.set_orientations(orientations)

        # Optionally, prepare for drawing tails
        if self.tail_length:
            self._offsets_history = deque([], self.tail_length + 1)
            self._add_to_offsets_history(self.get_offsets())

            if not self._tail_decay:
                self._segments = np.empty((len(self), self.tail_length + 1, 2))
            else:
                self._segments = np.empty((len(self) * self.tail_length, 2, 2))
            self._segments.fill(np.nan)

            self.draw_tails()

    def __len__(self) -> int:
        """Number of agents associated with this AgentCollection"""
        return len(self.get_offsets())

    @property
    def has_orientation(self) -> bool:
        """Whether agent orientation is represented"""
        return self._has_orientation

    @property
    def has_markers(self) -> bool:
        """Whether a marker will be shown; if disabled, can skip setting array
        data and thus increase performance somewhat ..."""
        return self._has_markers

    @property
    def tail_length(self) -> int:
        """Length of the tail of each agent, i.e. the number of *historic*
        offset positions that are connected into a line."""
        return self._tail_length

    @property
    def tails(self) -> Optional[LineCollection]:
        """The :py:class:`~matplotlib.collections.LineCollection` that
        represents the tails of the agents — or None if tails are not
        activated.

        .. note::

            This collection still needs to be registered *once* with the axes
            via :py:meth:`~matplotlib.axes.Axes.add_collection`.
        """
        return self._tails

    @property
    def markerpath(self) -> Optional[Path]:
        """Returns the used markerpath"""
        return self._markerpath

    def draw_tails(self, **kwargs) -> LineCollection:
        """Draws the tails *from scratch*, removing the currently added
        collection if there is one.

        .. note::

            If calling this, the collection representing the tails needs to be
            re-added to the desired axes.

        To update tail positions according to the offsets history, use
        :py:meth:`.update_tails`.

        Args:
            **kwargs: Recursively update (a copy of) the ``tail_kwargs`` that
                were given during initialization; the result is passed to
                :py:class:`~matplotlib.collections.LineCollection`.

        Returns:
            LineCollection: The newly created tails collection

        Raises:
            ValueError: If the collection was initialized without tails.
        """
        if not self.tail_length:
            raise ValueError(
                "Cannot add tails to AgentCollection after initialization!"
            )

        if kwargs:
            kwargs = recursive_update(copy.deepcopy(self._tail_kwargs), kwargs)
        else:
            kwargs = self._tail_kwargs

        if self._tails is not None:
            # Removal is best-effort: it fails for example if the old tails
            # were never added to an axes. Do not let that abort the redraw,
            # but also do not swallow KeyboardInterrupt & Co.
            try:
                self._tails.remove()
            except Exception as exc:
                log.debug(
                    "Could not remove previous tails collection (%s: %s).",
                    type(exc).__name__,
                    exc,
                )

        self._tails = LineCollection(self._build_line_segments(), **kwargs)

        if self._tail_decay != 0:
            alpha = self.tails.get_alpha()
            if not isinstance(alpha, float):
                alpha = 1.0

            # One alpha value per segment, decaying along each agent's tail;
            # the pattern is the same for every agent and thus repeated.
            decaying_alpha = itertools.accumulate(
                [alpha] * self.tail_length,
                lambda a, _: a * (1.0 - self._tail_decay),
            )
            self.tails.set_alpha(list(decaying_alpha) * len(self))

        return self.tails

    def update_tails(self) -> LineCollection:
        """Updates the collection that holds the tails using the offsets
        history.

        Raises:
            ValueError: If the collection was initialized without tails.
        """
        if not self.tail_length:
            raise ValueError(
                "Cannot add tails to AgentCollection after initialization!"
            )

        self.tails.set_segments(self._build_line_segments())
        return self.tails

    def set(self, *args, **kwargs):
        """Sets AgentCollection artist properties"""
        return super().set(*args, **kwargs)

    def set_markersizes(self, sizes):
        """Sets the marker sizes, taking into account the size scaling.

        .. note::

            This behaves differently from the inherited ``set_sizes`` method
            and thus has a different name. The difference is that
            :py:meth:`~matplotlib.collections.PathCollection.set_sizes` is
            also called during saving of a plot and would thus lead to
            repeated application of the size scaling, which is not desired.
        """
        if not is_iterable(sizes):
            sizes = np.array([sizes] * len(self))

        super().set_sizes(self._size_factors * sizes)

    def set_offsets(self, offsets) -> None:
        """Sets the offsets, i.e. path positions of all agents.

        Invokes :py:meth:`~matplotlib.collections.PathCollection.set_offsets`
        and afterwards stores the offset history, using it for
        :py:meth:`.update_tails`.
        """
        super().set_offsets(offsets)

        if self.tail_length:
            self._add_to_offsets_history(offsets)
            self.update_tails()

    def set_orientations(
        self, orientations: Union[float, Sequence[float]]
    ) -> None:
        """Set the orientations of all agents. If a scalar is given, will set
        all orientations to that value. If any kind of sequence is given, will
        associate individual orientations with agents.

        .. note::

            This can only be used if the collection was *initialized* with
            ``orientations``, see :py:meth:`.__init__`.

        Raises:
            ValueError: If the collection was initialized without
                orientations or if the number of orientations does not match
                the number of paths.
        """
        if not self.has_orientation:
            raise ValueError(
                "AgentCollection was not initialized with orientations; "
                "cannot change them after initialization!"
            )

        if not is_iterable(orientations):
            orientations = [orientations] * len(self._base_vertices)

        paths = self.get_paths()

        if not (len(paths) == len(orientations) == len(self._base_vertices)):
            raise ValueError(
                f"There was a length mismatch between the number of paths in "
                f"this collection ({len(paths)}), the given values for the "
                f"new orientations ({len(orientations)}), and the number of "
                f"stored base vertices ({len(self._base_vertices)})!"
            )

        for path, _o, _bv in zip(paths, orientations, self._base_vertices):
            self._rotate_path(path, orientation=_o, base_vertices=_bv)

        self._orientations = orientations

    def get_orientations(self) -> Sequence[float]:
        """Returns the current orientations of all agents.

        .. note::

            Changing the returned object does not rotate the agents!
            Use :py:meth:`.set_orientations` to achieve that.

        Raises:
            ValueError: If the collection was initialized without
                orientations.
        """
        if not self.has_orientation:
            raise ValueError(
                "AgentCollection was not initialized with orientations!"
            )
        return self._orientations

    # .........................................................................

    def _prepare_marker_path(
        self, marker: Union[Path, Patch, str, list, dict]
    ) -> Path:
        """Generates the :py:class:`~matplotlib.path.Path` needed for the
        agent marker using a variety of setup methods, depending on argument
        type:

            - :py:class:`~matplotlib.path.Path` defines the marker shape
              directly
            - :py:class:`dict` is used to call
              :py:class:`~matplotlib.path.Path`
            - :py:class:`~matplotlib.patches.Patch` already defines a path,
              which is extracted and used.
            - :py:class:`str` is looked up in :py:data:`.MARKERS`;
              if no such entry is available, the remaining options are
              evaluated.
            - Other types construct a
              :py:class:`~matplotlib.markers.MarkerStyle` from which the path
              is extracted.

        After constructing the path, applies the base orientation by rotating
        it by that value.

        .. note::

            The rotation is applied in place. To neither alter a path object
            that was passed in by the caller nor one that matplotlib may share
            between markers, this method always works on a (deep) copy.
        """
        if isinstance(marker, Path):
            path = copy.deepcopy(marker)

        elif isinstance(marker, dict):
            path = Path(**marker)

        elif isinstance(marker, Patch):
            path = copy.deepcopy(marker.get_path())

        elif isinstance(marker, str) and marker in MARKERS:
            path = copy.deepcopy(MARKERS[marker])

        else:
            path = copy.deepcopy(MarkerStyle(marker).get_path())

        # The vertices are written to in-place upon rotation
        path.vertices.setflags(write=True)

        self._rotate_path(
            path,
            base_vertices=path.vertices.copy(),
            orientation=self._base_orientation,
        )
        return path

    def _compute_size_factors(self, paths: Sequence[Path]) -> np.ndarray:
        """Given a sequence of Paths, computes the corresponding size factors
        which are used in :py:meth:`.set_markersizes` to scale the individual
        marker sizes.

        The resulting sizes are multiplied with the given size scale.
        If marker size normalization is activated, size factors may vary
        between the markers.
        """

        def bbox_area(p) -> float:
            bbox = p.get_extents()
            return bbox.height * bbox.width

        if not self._normalize_marker_sizes:
            return self._size_scale * np.ones((len(paths),))

        return self._size_scale * np.array([1 / bbox_area(p) for p in paths])

    def _add_to_offsets_history(self, new_offsets: np.ndarray) -> None:
        """Adds the given offset to the left side of the deque, pushing old
        entries off the right side of the deque.

        If the tail segment length is limited, will mask the offset positions
        for individual agents if they differ by more than that value from
        their previous position.
        *Note* that in this case the history is no longer strictly correct;
        it is already evaluated for use in :py:meth:`.draw_tails` in order to
        increase performance and avoid evaluating this over and over again.
        """
        max_delta = self._tail_max_segment_length
        if max_delta is not None and len(self._offsets_history) >= 1:
            # Need to mask offset jumps larger than the maximum delta.
            # This will lead to tail segments not being drawn.
            # Work on a *float* copy: NaN cannot be stored in integer arrays
            # and plain sequences do not support boolean mask assignment.
            new_offsets = np.asanyarray(new_offsets).astype(float)
            abs_delta = np.abs(self._offsets_history[0] - new_offsets)
            new_offsets[abs_delta > max_delta] = np.nan

        self._offsets_history.appendleft(new_offsets)

    def _build_line_segments(self) -> np.ndarray:
        """Creates the line segments for the tails using the offsets history"""
        TL = self._tail_length

        if not self._tail_decay:
            # Can simply use the offsets history (with changed slicing)
            for i, _offsets in enumerate(self._offsets_history):
                self._segments[:, TL - i, :] = _offsets

        else:
            # Have a more complicated segments structure:
            #   axis 0:  [0, TL)    segments of first agent's tail,
            #            [TL, 2TL)  segments of second agent's tail,
            #            etc., with lower indices being for more recent history.
            #   axis 1:  start and end point
            #   axis 2:  x and y coordinates (of start and end point)
            hist = list(self._offsets_history)
            for i, (_p0, _p1) in enumerate(zip(hist[:-1], hist[1:])):
                # _p0 and _p1: offsets of *all* agents at adjacent times
                self._segments[i::TL, 0, :] = _p0
                self._segments[i::TL, 1, :] = _p1

        return self._segments

    def _rotate_path(
        self, path: Path, *, base_vertices: np.ndarray, orientation: float
    ) -> None:
        """Sets the given path's vertices by rotating (a copy of) the base
        vertices by the given ``orientation`` (in radians)."""
        _cos, _sin = np.cos(orientation), np.sin(orientation)
        self._R[0, 0] = _cos
        self._R[1, 0] = _sin
        self._R[0, 1] = -_sin
        self._R[1, 1] = _cos

        # Write in place such that the path object itself stays the same
        path._vertices[...] = np.dot(base_vertices, self._R.T)
        path._orientation = orientation
class HandlerAgentCollection(HandlerPathCollection):
    """Legend handler that knows how to represent :py:class:`.AgentCollection`
    instances in a legend."""

    def create_collection(
        self, orig_handle, sizes, offsets, offset_transform
    ) -> AgentCollection:
        """Builds the AgentCollection that is shown as legend entry, using the
        marker path of the original handle."""
        # TODO What about the colormap etc?
        return AgentCollection(
            offsets,
            marker=orig_handle.markerpath,
            sizes=sizes,  # FIXME too small?
            offset_transform=offset_transform,
        )

    def create_artists(
        self,
        legend,
        orig_handle,
        xdescent,
        ydescent,
        width,
        height,
        fontsize,
        trans,
    ):
        """Creates the legend artists for the given handle.

        For the original docstring see
        :py:meth:`matplotlib.legend_handler.HandlerBase.create_artists`."""
        # These arguments are shared by all the helper calls below
        geometry = (xdescent, ydescent, width, height, fontsize)

        _, marker_x = self.get_xdata(legend, *geometry)
        marker_y = self.get_ydata(legend, *geometry)
        marker_sizes = self.get_sizes(legend, orig_handle, *geometry)

        legend_coll = self.create_collection(
            orig_handle,
            marker_sizes,
            offsets=list(zip(marker_x, marker_y)),
            offset_transform=trans,
        )

        self.update_prop(legend_coll, orig_handle, legend)
        legend_coll.set_offset_transform(trans)

        return [legend_coll]
# Register the custom legend handler in matplotlib's default handler map such
# that AgentCollection instances are associated with HandlerAgentCollection.
# NOTE This is a module-level side effect, carried out upon import.
Legend.update_default_handler_map({AgentCollection: HandlerAgentCollection()})

# -----------------------------------------------------------------------------
def _get_data(d: Union[xr.DataArray, xr.Dataset], var: str) -> xr.DataArray:
    """Looks up the data variable ``var`` in ``d``, e.g. for hue, size, ...

    If item access fails and ``d`` is a data array whose ``name`` matches
    ``var``, the array itself is returned. All other failures are re-raised
    as ``ValueError`` with a hint on how to provide the data.
    """
    try:
        selected = d[var]

    except Exception as exc:
        is_array = isinstance(d, xr.DataArray)

        # A data array may *be* the requested variable
        if is_array and d.name == var:
            return d

        if is_array:
            raise ValueError(
                "If using xr.DataArray for agent data, make sure to have the "
                f"array's `name` attribute set (in this case to: '{var}')!\n"
                f"Got: {str(d)}"
            ) from exc

        raise ValueError(
            f"Failed retrieving agent data variable '{var}' from:\n{d}\n"
            f"Got a {type(exc).__name__}: {exc}\n"
            f"Make sure that the data variable '{var}' exists "
            f"or pass an xr.DataArray with name '{var}' instead."
        ) from exc

    return selected
[docs] def _parse_vmin_vmax( specs: dict, d: Union[xr.DataArray, xr.Dataset], *, vmin_key: str = "vmin", vmax_key: str = "vmax", ) -> dict: """Adapts ``specs`` dict entries ``vmin_key`` and ``vmax_key``, replacing their *values* with the minimum or maximum value of given data ``d``.""" if specs.get(vmin_key) == "min": specs[vmin_key] = d.min().item() if specs.get(vmax_key) == "max": specs[vmax_key] = d.max().item() return specs
[docs] def _set_domain( *, ax: "matplotlib.axes.Axes", layers: dict, mode: str = "auto", extent: Tuple[float, float, float, float] = None, height: float = None, aspect: Union[float, int, str] = "auto", update_only: bool = False, pad: float = 0.05, ) -> Optional[Tuple[float, float]]: """Sets the axis limits and computes the area of the domain Args: ax (matplotlib.axes.Axes): The axes to adjust layers (dict): Layer information prepared by :py:func:`.abmplot` mode (str, optional): Domain mode, can be ``auto``, ``manual``, ``fixed`` and ``follow``. In fixed mode, all available data in ``layers`` is inspected to derive domain bounds. In ``follow`` mode, the currently used collection positions are used to determine the boundaries. In ``auto`` mode, will use a ``fixed`` domain if *no* ``extent`` was given, otherwise ``manual`` mode is used. extent (Tuple[float, float, float, float], optional): A manually set extent in form ``(left, right, bottom, top)``. Alternatively, a 2-tuple will be interpreted as ``(0, right, 0, top)``. height (float, optional): The height of the domain in data units aspect (Union[float, int, str], optional): The aspect ratio of the domain, the width being computed via ``height * aspect``. If ``auto``, will update_only (bool, optional): If true, will only make changes to the domain bounds if necessary, e.g. in ``follow`` mode. pad (float, optional): Relative padding added to each side. Returns: Optional[Tuple[float, float]]: The domain area (in data units squared) and the aspect ratio, i.e. ``width/height``. The return value will be None if ``update_only`` was set. 
""" # Determine mode if mode == "auto": mode = "fixed" if extent is None else "manual" # May already have a new extent specified if extent is not None: try: _l, _r, _b, _t = extent except: try: _l, _b = 0, 0 _r, _t = extent except: raise else: # Get the "previous" values, just to have something set _l, _r = ax.get_xlim() _b, _t = ax.get_ylim() # Depending on mode, the values may further change if mode is None or mode == "manual": if update_only: return elif isinstance(mode, str): # Decide which data source to use colls = [lyr["coll"] for lyr in layers.values()] have_colls = all([c is not None for c in colls]) if mode == "fixed" and update_only: # No need to update return if mode == "fixed" or (mode == "follow" and not have_colls): # Get all the data # NOTE Will need adaption if plotting on multiple axes xdata = [lyr["data"][lyr["x"]] for lyr in layers.values()] ydata = [lyr["data"][lyr["y"]] for lyr in layers.values()] elif mode == "follow": # Get the existing offset data from the collection xdata = [coll.get_offsets()[:, 0] for coll in colls] ydata = [coll.get_offsets()[:, 1] for coll in colls] else: raise ValueError( f"Got invalid `mode` argument with value '{mode}'! " "Possible modes: fixed, follow, manual, auto" ) # Determine bounds _l = min([x.min().item() for x in xdata]) _r = max([x.max().item() for x in xdata]) _b = min([y.min().item() for y in ydata]) _t = max([y.max().item() for y in ydata]) else: raise TypeError( f"Got invalid `mode` argument of type {type(mode)}, expected str!" 
) # Ensure the domain has a fixed aspect ratio if aspect == "auto": if mode == "follow": aspect = 1.0 else: aspect = abs(_r - _l) / abs(_t - _b) if aspect is not None: current_width = abs(_r - _l) current_height = abs(_t - _b) if height: target_width = height * aspect target_height = height else: target_width = max(current_width, current_height * aspect) target_height = max(current_height, current_width / aspect) if current_width < target_width: d = target_width - current_width _l -= d / 2 _r += d / 2 if current_height < target_height: d = target_height - current_height _b -= d / 2 _t += d / 2 # Apply padding if pad is not None: _w = abs(_r - _l) _h = abs(_t - _b) _l -= _w * pad _r += _w * pad _b -= _h * pad _t += _h * pad # Finally, compute domain area and actually set the limits domain_area = abs(_r - _l) * abs(_t - _b) actual_aspect = abs(_r - _l) / abs(_t - _b) ax.set_xlim(_l, _r) ax.set_ylim(_b, _t) return domain_area, actual_aspect
# -----------------------------------------------------------------------------
def draw_agents(
    d: Union[xr.Dataset, xr.DataArray],
    *,
    x: str,
    y: str,
    ax: "matplotlib.axes.Axes" = None,
    _collection: Optional[AgentCollection] = None,
    _domain_area: float = 1,
    marker: Union[Path, Patch, str, list, dict] = "auto",
    hue: str = None,
    orientation: Union[str, float, Sequence[float]] = None,
    size: Union[str, float, Sequence[float]] = None,
    size_norm: Union[str, dict, "matplotlib.colors.Normalize"] = None,
    size_vmin: float = None,
    size_vmax: float = None,
    size_scale: float = 0.0005,
    label: str = None,
    cmap: Union[str, dict] = None,
    norm: Union[str, dict] = None,
    vmin: Union[float, str] = None,
    vmax: Union[float, str] = None,
    cbar_label: str = None,
    cbar_labels: dict = None,
    add_colorbar: bool = None,
    cbar_kwargs: dict = None,
    **coll_kwargs,
) -> AgentCollection:
    """Draws agent positions onto the given axis.

    Data can be a :py:class:`~xarray.DataArray` or :py:class:`~xarray.Dataset`
    with ``x``, ``y`` and the other encodings ``hue``, ``orientation`` and
    ``size`` being strings that denote the name of the data variable or
    dimension from which the respective values are read.

    Args:
        d (Union[xarray.Dataset, xarray.DataArray]): The agent data.
            The recommended format is that of datasets, because it can hold
            more data variables. If using data *arrays*, multiple coordinates
            can be used to set the data. Also, the ``name`` attribute needs
            to be set (to avoid accidentally using the wrong data).
        x (str): Name of the data dimension or variable that holds x positions.
        y (str): Name of the data dimension or variable that holds y positions.
        ax (matplotlib.axes.Axes, optional): The axis to plot to; will use the
            current axis if not given.
        _collection (Optional[AgentCollection], optional): If a collection
            exists, will update this collection instead of creating a new one.
        _domain_area (float, optional): The domain area, used to normalize the
            marker sizes for a consistent visual appearance.
            This is only needed when initializing the collection and will set
            its ``size_scale`` parameter to ``size_scale * _domain_area``.
        marker (Union[matplotlib.path.Path, matplotlib.patches.Patch, str, list, dict], optional):
            Which marker to use for the agents.
            See :py:class:`.AgentCollection` for available arguments.
        hue (str, optional): The name of the data dimension or variable that
            holds the values from which marker colors are determined.
        orientation (Union[str, float, Sequence[float]], optional): The name of
            the data variable or dimension that holds the orientation data
            (in radians). If scalar or sequence, will pass those explicit
            values through; this includes a scalar value of zero.
            Set to None (or False) to not represent orientations.
        size (Union[str, float, Sequence[float]], optional): The name of the
            data variable or dimension that holds the values from which the
            marker size is determined. If a (non-zero) scalar or a sequence,
            those values are passed on as marker sizes.
        size_norm (Union[str, dict, matplotlib.colors.Normalize], optional):
            A normalization for the sizes. The
            :py:class:`~dantro.plot.utils.color_mngr.ColorManager` is used to
            set up this normalization, thus offering all the corresponding
            capabilities.
            Actual sizes are computed by ``xarray.plot.utils._parse_size``.
        size_vmin (float, optional): The value that is mapped to the smallest
            size of the markers.
        size_vmax (float, optional): The value that is mapped to the largest
            size of the markers.
        size_scale (float, optional): A scaling factor that is applied to the
            size of all markers; use this to control overall size.
        label (str, optional): How to label the collection that is created.
            Will be used as default colorbar label and for the legend.
        cmap (Union[str, dict], optional): The colormap to use for the markers.
            Supports :py:class:`~dantro.plot.utils.color_mngr.ColorManager`.
        norm (Union[str, dict], optional): Norm to use for the marker colors.
            Supports :py:class:`~dantro.plot.utils.color_mngr.ColorManager`.
        vmin (Union[float, str], optional): The value of the ``hue`` dimension
            that is mapped to the beginning of the colormap.
        vmax (Union[float, str], optional): The value of the ``hue`` dimension
            that is mapped to the end of the colormap.
        cbar_label (str, optional): Passed as ``label`` argument to
            :py:meth:`~dantro.plot.utils.color_mngr.ColorManager.create_cbar`.
        cbar_labels (dict, optional): Colorbar labels, parsed by the
            :py:class:`~dantro.plot.utils.color_mngr.ColorManager`.
        add_colorbar (bool, optional): Whether to add a colorbar or not.
            Only used upon creating the collection. If None, a colorbar is
            added if both ``hue`` and ``marker`` are set.
        cbar_kwargs (dict, optional): Passed on to
            :py:meth:`~dantro.plot.utils.color_mngr.ColorManager.create_cbar`
        **coll_kwargs: Passed on to :py:class:`.AgentCollection`; see there
            for further available arguments.

    Returns:
        AgentCollection: The newly created or the updated collection

    Raises:
        TypeError: If ``d`` is not an xarray object (or a container thereof)
        ValueError: If both ``size`` and ``sizes`` were given
    """
    import matplotlib.pyplot as plt

    def is_specified(val, *, zero_counts: bool) -> bool:
        """Whether an encoding argument was specified.

        None, False and empty strings denote an unset encoding. For numeric
        scalars, ``zero_counts`` decides whether zero is a valid value.
        Sequences and arrays count as specified if they are not empty; this
        avoids evaluating the (ambiguous) truth value of arrays.
        """
        if val is None or isinstance(val, (bool, str)):
            return bool(val)
        if np.ndim(val) == 0:
            return zero_counts or bool(val)
        return np.size(val) > 0

    def parse_sizes(d: xr.DataArray, size_norm) -> np.ndarray:
        from xarray.plot.utils import _MARKERSIZE_RANGE, _parse_size

        # Use ColorManager to evaluate the size_norm argument
        size_norm = ColorManager(
            norm=size_norm, vmin=size_vmin, vmax=size_vmax
        ).norm

        # Compute and apply the size mapping.
        # Need a correction factor to let mapped sizes appear approximately
        # equally large as unmapped ones. With _parse_sizes returning values
        # within _MARKERSIZE_RANGE (typically [18, 72]), the correction is
        # simply to divide by a value of *approximately* that magnitude.
        size_mapping = _parse_size(d, size_norm)
        corr_factor = 1 / np.array(_MARKERSIZE_RANGE).mean()
        return corr_factor * size_mapping.loc[d.values.ravel()].values

    def get_orientations(d: Union[xr.DataArray, xr.Dataset]):
        """If ``orientation`` is a string, looks up the orientation data,
        otherwise simply passes the value through, letting the
        AgentCollection handle the argument."""
        if isinstance(orientation, str):
            return _get_data(d, orientation)
        return orientation

    def get_xy(d: Union[xr.DataArray, xr.Dataset]) -> np.ndarray:
        """Retrieve and aggregate x and y position data"""
        xpos = _get_data(d, x)
        ypos = _get_data(d, y)
        stacked = np.stack([xpos, ypos], -1)
        return stacked
        # equivalent to np.dstack([xpos, ypos])[0], but that one is slower

    # .........................................................................

    if isinstance(d, BaseDataContainer):
        d = d.data

    if not isinstance(d, (xr.Dataset, xr.DataArray)):
        raise TypeError(f"Expected xr.Dataset or xr.DataArray, got {type(d)}!")

    if ax is None:
        ax = plt.gca()

    # Evaluate once which of the optional encodings are in use.
    # NOTE A scalar orientation of zero is a valid value (pointing along the
    #      positive x-axis) and must not be mistaken for "not set".
    has_orientation = is_specified(orientation, zero_counts=True)
    has_size = is_specified(size, zero_counts=False)

    # Depending on the collection existing or not, update it or create it.
    coll = _collection

    if coll is not None:
        # Only need to update the existing collection's offsets, orientations,
        # array data, sizes ...
        coll.set_offsets(get_xy(d))

        if has_orientation:
            coll.set_orientations(get_orientations(d))

        if hue and coll.has_markers:
            coll.set_array(_get_data(d, hue))
            if vmin is None or vmax is None:
                coll.autoscale()
            coll.set_clim(vmin, vmax)

        if has_size and isinstance(size, str):
            # TODO Make this more efficient?
            coll.set_markersizes(parse_sizes(_get_data(d, size), size_norm))

        return coll

    # else: Still need to create everything ...................................
    cm = ColorManager(
        cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, labels=cbar_labels
    )

    # Prepare size data ...
    if has_size:
        if "sizes" in coll_kwargs:
            raise ValueError(
                "Cannot pass both `sizes` and `size` arguments! "
                "Use `size` to pass marker sizes via a data dimension, a "
                "sequence of sizes, or a scalar size (same for all agents)."
            )

        if isinstance(size, str):
            coll_kwargs["sizes"] = parse_sizes(_get_data(d, size), size_norm)
        else:
            # Let the collection take care of this
            coll_kwargs["sizes"] = size

    # Create collection
    # Need to pass orientations here already, because it determines whether
    # this will be a heterogeneous AgentCollection (with individual paths) or
    # a homogeneous one (with a single path, copied for each agent).
    coll = AgentCollection(
        get_xy(d),
        offset_transform=AffineDeltaTransform(ax.transData),
        orientations=get_orientations(d) if has_orientation else None,
        label=label,
        cmap=cm.cmap,
        norm=cm.norm,
        size_scale=(size_scale * _domain_area),
        marker=marker,
        **coll_kwargs,
    )

    if hue and coll.has_markers:
        coll.set_array(_get_data(d, hue))
        coll.autoscale()
        coll.set_clim(vmin, vmax)

    # Add collections to the current axis and ensure equal aspect ratio
    ax.set_aspect("equal")
    ax.add_collection(coll)

    if coll.tails is not None:
        ax.add_collection(coll.tails)

    # Create the colorbar
    if add_colorbar or (add_colorbar is None and (hue and marker)):
        cm.create_cbar(
            coll,
            fig=ax.get_figure(),
            ax=ax,
            label=cbar_label,
            **ensure_dict(cbar_kwargs),
        )
        # NOTE matplotlib already adds it as `coll.colorbar` attribute

    return coll
# -----------------------------------------------------------------------------
@is_plot_func(use_dag=True, supports_animation=True)
def abmplot(
    *,
    data: Dict[str, Union[xr.Dataset, xr.DataArray]],
    hlpr: PlotHelper,
    to_plot: Dict[str, dict],
    frames: str = None,
    frames_isel: Union[int, slice] = None,
    domain: Union[str, dict, Tuple[float, float, float, float]] = "fixed",
    adjust_figsize_to_domain: bool = True,
    figsize_aspect_offset: float = 0.2,
    suptitle_fstr: str = "{dim:} = {val:5d}",
    suptitle_kwargs: dict = dict(fontfamily="monospace"),
    add_legend: bool = None,
    **shared_kwargs,
):
    """Plots agents in a domain, animated over time.

    Features:

    - Plot multiple "layers" of agents onto the same axis.
    - Encode ``hue``, ``size`` and ``orientation`` of all agents.
    - Keep track of position history through an animation and draw the
      history as (constant or decaying) tails.
    - Automatically determine the domain bounds, allowing to dynamically
      follow the agents through the domain.
    - Generate marker paths ad-hoc or use pre-defined markers that are
      relevant in the context of visualizing agent-based models, e.g.
      ``fish``.

    Example output:

    .. raw:: html

        <video width="720" src="../_static/_gen/abmplot/fish.mp4" controls></video>

    .. admonition:: Corresponding plot configuration
        :class: dropdown

        The following configuration was used to generate the above example
        animation.

        .. literalinclude:: ../../tests/cfg/plots/abm_plots.yml
            :language: yaml
            :start-after: ### START -- doc_fish
            :end-before: ### END ---- doc_fish

        The used dummy data (``circle_walk…``) is an
        :py:class:`xarray.Dataset` with data variables ``x``, ``y``,
        ``orientation``, each one spanning dimensions ``time`` and
        ``agents``.
        Data variables do not have coordinates in this case, but it would be
        possible to supply some.

    .. admonition:: Development roadmap

        The following extensions of this plot function are *planned*:

        - Subplot support, allowing to manually distribute `to_plot` entries
          onto individual subplots.
        - Faceting support, allowing to automatically distribute *all* data
          layers onto rows and columns and thus performing a dimensionality
          reduction on them.

    .. admonition:: See also

        - :py:func:`draw_agents`
        - :ref:`plot_funcs_abm`

    Args:
        data (Dict[str, Union[xarray.Dataset, xarray.DataArray]]): A dict
            holding the data that is to be plotted.
        hlpr (PlotHelper): The plot helper instance.
        to_plot (Dict[str, dict]): Individual plot specifications. Each
            entry refers to one "layer" of agents being plotted and its key
            needs to match the corresponding entry in ``data``.
            See :py:func:`.draw_agents` for available arguments.
            Note that the ``shared_kwargs`` also end up in the entries here.
        frames (str, optional): Which data dimension to use for the
            animation.
        frames_isel (Union[int, slice], optional): An index-selector that is
            applied on the ``frames`` dimension (*without* dropping empty
            dimensions resulting from this operation).
            Scalar values will be wrapped into a list of size one.
        domain (Union[str, dict, Tuple[float, float, float, float]], optional):
            Specification of the domain on which the agents are plotted.
            If a 4-tuple is given, it defines the position of the *fixed*
            ``(left, right, bottom, top)`` borders of the domain; a 2-tuple
            will be interpreted as ``(0, right, 0, top)``.
            If a string is given, it determines the *mode* by which the
            domain is determined. It can be ``fixed`` (determine bounds from
            all agent positions, then keep them constant) or ``follow``
            (domain dynamically follows agents, keeping them centered).
            If a dict is given, it is passed to :py:func:`._set_domain`, and
            offers more options like ``pad`` or ``aspect``.
        adjust_figsize_to_domain (bool, optional): If True, will adjust the
            figure size to bring it closer to the aspect ratio of the domain,
            thus reducing whitespace.
        figsize_aspect_offset (float, optional): A constant offset that is
            added to the aspect ratio of the domain to arrive at the target
            aspect ratio of the figure used for ``adjust_figsize_to_domain``.
        suptitle_fstr (str, optional): The format string to use for the
            figure's super-title. If empty or None, no suptitle is added.
            The format string can include the following keys:
            ``dim`` (frames dimension name), ``val`` (current coordinate
            value, will be an index if no coordinates are available; taken
            from the first layer that is iterated over),
            ``coords`` (a dict containing all layers' iterator states), and
            ``frame_no`` (enumerates the frame number).
        suptitle_kwargs (dict, optional): Passed to
            :py:meth:`~matplotlib.figure.Figure.suptitle`.
        add_legend (bool, optional): Whether to add a legend.
            ⚠️ This is currently *experimental* and does not work as
            expected.
        **shared_kwargs: Shared keyword arguments for all plot
            specifications. Whatever is specified here is used as a basis
            for each individual layer in ``to_plot``, i.e. *updated* by the
            entries in ``to_plot``.
            See :py:func:`.draw_agents` for available arguments.

    Raises:
        ValueError: If ``to_plot`` is empty, if a layer has no matching
            entry in ``data``, if the index selection fails, if the layers'
            frame iterators have mismatching (non-zero) lengths, or if the
            suptitle format string cannot be filled.
        TypeError: If ``domain`` is of an unsupported type.
    """

    def prepare_layer_specs(
        name: str, *, shared_kwargs: dict, **plot_kwargs
    ) -> dict:
        """Parses arguments for each layer into a dict"""
        lyr = dict(
            name=name,
            data=None,
            plot_kwargs=None,
            ax=None,
            coll=None,
            frames=None,
            frames_iter=None,
            num_frames=None,
        )

        # Prepare plot kwargs: layer-specific entries update the shared ones.
        # Deep copies avoid that layers share (and mutate) nested objects.
        plot_kwargs = recursive_update(
            copy.deepcopy(shared_kwargs), copy.deepcopy(plot_kwargs)
        )
        plot_kwargs["label"] = plot_kwargs.get("label", name)

        # Store it, exposing some parameters on the upper level
        lyr["plot_kwargs"] = plot_kwargs
        lyr["x"] = plot_kwargs["x"]
        lyr["y"] = plot_kwargs["y"]

        # Prepare data
        try:
            d = data[name]

        except KeyError as err:
            _avail = ", ".join(data.keys())
            raise ValueError(
                f"Missing data named '{name}'! Available: {_avail}\n"
                "Check that the name is correctly defined. If you are using "
                "the DAG framework for data selection, make sure the tag is "
                "computed by including it into `compute_only` or setting the "
                "`force_compute` flag for the tagged node."
            ) from err

        # Extract parameters that are not meant to be passed on to the plot
        frames = plot_kwargs.pop("frames", None)
        frames_isel = plot_kwargs.pop("frames_isel", None)

        # Prepare the frames iterator
        if frames_isel is not None:
            # Scalars are wrapped into a list such that the frames dimension
            # is not dropped by the selection.
            if isinstance(frames_isel, int):
                selector = {frames: [frames_isel]}
            else:
                selector = {frames: frames_isel}

            try:
                d = d.isel(selector, drop=False)

            except Exception as exc:
                raise ValueError(
                    f"Failed applying index-selection for layer '{name}'! "
                    "Did you intend to apply a selection to this layer's "
                    "data? If not, set `frames_isel` to None.\n"
                    f" frames: {frames}\n"
                    f" frames_isel: {frames_isel}\n"
                    f" → selector: {selector}\n"
                    f" data:\n{d}"
                ) from exc

        frames_iter = None
        if frames is not None:
            frames_iter = d.groupby(frames, squeeze=False)

        # ... and store it all to the layer specs
        lyr["frames"] = frames
        lyr["frames_iter"] = frames_iter
        lyr["num_frames"] = len(frames_iter) if frames_iter else 0
        lyr["data"] = d

        # (Mutably) evaluate vmin & vmax on hue and size dimensions
        if plot_kwargs.get("hue"):
            _parse_vmin_vmax(plot_kwargs, _get_data(d, plot_kwargs["hue"]))

        if plot_kwargs.get("size"):
            _parse_vmin_vmax(
                plot_kwargs,
                _get_data(d, plot_kwargs["size"]),
                vmin_key="size_vmin",
                vmax_key="size_vmax",
            )

        return lyr

    def set_suptitle(
        fstr: str, *, shared_dim: str, frame_coords: dict, **fstr_kwargs
    ) -> Tuple[str, float]:
        """Sets the suptitle, filling it with information about the currently
        plotted frame, e.g. its coordinate.

        The ``val`` format key is the coordinate of the first layer that is
        actually iterated over; layers without a frames iterator (coordinate
        ``None``) are ignored for that purpose.
        """
        coords = [v for v in frame_coords.values() if v is not None]
        shared_coord = coords[0] if coords else None
        # TODO check for equal values across layers?
        # TODO intelligently set fstr?

        if fstr and shared_dim:
            fstr_kwargs["dim"] = shared_dim
            fstr_kwargs["val"] = shared_coord
            fstr_kwargs["coords"] = frame_coords

            try:
                hlpr.fig.suptitle(
                    fstr.format(**fstr_kwargs),
                    **ensure_dict(suptitle_kwargs),
                )

            except Exception as exc:
                _fstr_kwargs = "\n".join(
                    f" {k}: {v}" for k, v in fstr_kwargs.items()
                )
                raise ValueError(
                    "Failed setting suptitle, probably due to a formatting "
                    "issue! Check the format string and the available "
                    "formatting values are compatible.\n"
                    f" suptitle_fstr: {repr(fstr)}\n"
                    f" fstr_kwargs:\n{_fstr_kwargs}"
                ) from exc

        return shared_dim, shared_coord

    def parse_frame_coords(frame_data: tuple, *, names: list) -> dict:
        """Extracts the current frame coordinate value of *every* layer from
        the state of the iterators.

        ``frame_data`` holds one entry per layer (in the order of ``names``):
        either a ``(coordinate, data)`` pair or ``None`` if that layer is not
        iterated over, in which case its coordinate is ``None`` as well.
        """
        return {
            name: (None if item is None else item[0])
            for name, item in zip(names, frame_data)
        }

    def parse_domain_kwargs(
        domain: Union[None, str, list, tuple, dict]
    ) -> dict:
        """Translates the ``domain`` argument into keyword arguments for
        ``_set_domain``."""
        if domain is None:
            return dict()

        elif isinstance(domain, str):
            return dict(mode=domain)

        elif isinstance(domain, (list, tuple)):
            return dict(extent=domain, pad=0)

        elif isinstance(domain, dict):
            return domain

        raise TypeError(
            f"Got invalid type {type(domain)} for `domain` argument! "
            "Valid types are: None, str, tuple, list, dict"
        )

    # .........................................................................
    # Work on a copy of the plot spec, to be on the safe side.
    to_plot = copy.deepcopy(to_plot)

    if not to_plot:
        raise ValueError(
            "Got an empty `to_plot` argument! Need at least one layer "
            "specification in order to create an ABM plot."
        )

    # Bring data into expected form
    log.note("Preparing data for ABM plot ...")

    shared_kwargs = dict(
        frames=frames,
        frames_isel=frames_isel,
        **shared_kwargs,
    )
    layers = {
        k: prepare_layer_specs(k, shared_kwargs=shared_kwargs, **spec)
        for k, spec in to_plot.items()
    }
    layer_names = list(layers)

    # Check frames iteration: every layer needs to have either no frames at
    # all or the same number of frames as all other iterated layers. The
    # reference is the largest count, such that the outcome does not depend
    # on whether the first layer happens to be an iterated one.
    frame_counts = [layers[name]["num_frames"] for name in layer_names]
    num_frames = max(frame_counts, default=0)

    if any(n not in (0, num_frames) for n in frame_counts):
        _info = "\n".join(
            f" {name:12}: {n:4d}"
            for name, n in zip(layer_names, frame_counts)
        )
        raise ValueError(
            "There was a mismatch in the length of the given frame iterators! "
            "Iterators need either be of the same length along the `frames` "
            "dimension or zero (if there is no `frames` dimension for that "
            f"layer).\nGiven iterator lengths were:\n{_info}"
        )

    # May need to switch in to / out of animation mode
    if num_frames > 1:
        hlpr.enable_animation()
    else:
        hlpr.disable_animation()

    # Construct the aggregated frames iterators object. Layers without an
    # iterator contribute `None` entries, at least one (snapshot case).
    _frits = [layers[name]["frames_iter"] for name in layer_names]
    frames_iters = zip(
        *[
            frit if frit is not None else [None] * max(1, num_frames)
            for frit in _frits
        ]
    )

    # Determine the name of a potentially existing shared frames dimension
    shared_frames_dim = None
    frame_dims = [layers[name]["frames"] for name in layer_names]
    unique_frame_dim_names = {f for f in frame_dims if f}
    if len(unique_frame_dim_names) == 1:
        shared_frames_dim = unique_frame_dim_names.pop()

    # Extract domain information and set the x- and y-axis limits.
    domain_area, _aspect = _set_domain(
        ax=hlpr.ax, layers=layers, **parse_domain_kwargs(domain)
    )

    # Inform about the plot ...
    log.remark(" Data variables: %s", ", ".join(to_plot))
    log.remark(
        " Frames dimension(s): %s",
        (
            shared_frames_dim
            if shared_frames_dim
            else ", ".join(f"{f}" for f in frame_dims)
        ),
    )
    log.remark(
        " Number of frames: %s",
        num_frames if num_frames > 1 else "(snapshot)",
    )
    log.remark(
        " Domain area: %.4g (aspect: %.3g)", domain_area, _aspect
    )

    if adjust_figsize_to_domain:
        adjust_figsize_to_aspect(_aspect + figsize_aspect_offset, fig=hlpr.fig)
        _fw, _fh = hlpr.fig.get_size_inches()
        log.remark(" Figure size adjusted: [%.3g, %.3g]", _fw, _fh)

    # .. Define single frame and animation update functions ...................
    # For simplicity, these use objects from the outer scope

    def plot_frame(frame_data: tuple, *, frame_no: int):
        """Plots a single frame from the given frame data"""
        frame_coords = parse_frame_coords(frame_data, names=layer_names)
        set_suptitle(
            suptitle_fstr,
            shared_dim=shared_frames_dim,
            frame_coords=frame_coords,
            frame_no=frame_no,
        )

        for lyr_no, coords_and_data in enumerate(frame_data):
            lyr_spec = layers[layer_names[lyr_no]]

            # Get the data
            if coords_and_data is None:
                # No iterator available; only need to draw once
                if lyr_spec["coll"] is not None:
                    continue
                xy = lyr_spec["data"]

            else:
                # always draw, using the data from the iterator
                _, xy = coords_and_data

            # May need to squeeze out size-1 dimensions that are created from
            # the .groupby iteration
            if lyr_spec["frames"]:
                xy = xy.squeeze(dim=lyr_spec["frames"], drop=True)

            # Now draw; an existing collection is updated rather than created
            lyr_spec["coll"] = draw_agents(
                xy,
                ax=hlpr.ax,
                **lyr_spec["plot_kwargs"],
                _collection=lyr_spec["coll"],
                _domain_area=domain_area,
            )

        # May want to update the domain
        _set_domain(
            ax=hlpr.ax,
            layers=layers,
            update_only=True,
            **parse_domain_kwargs(domain),
        )

        hlpr.invoke_enabled()  # FIXME should not have to do this here!?

    def update():
        """Animation update generator: yields once per finished frame"""
        # First frame was already plotted, grab it directly
        yield

        # Plot the rest. Initialise the counter such that it is defined even
        # if there are no further frames to iterate over.
        frame_no = 0
        for frame_no, frame_data in enumerate(frames_iters):
            plot_frame(frame_data, frame_no=frame_no + 1)

            # Done with this frame; yield control to the animation framework
            # which will grab the frame...
            yield

        if frame_no > 1:
            log.info("Animation finished.")

    hlpr.register_animation_update(update)

    # .. Actual plotting ......................................................
    # Plot only the first frame here; all further frames are plotted via the
    # helper's animation framework.
    # What is done here is meant for the case of the snapshot and advances the
    # iterator by one. The captured iterator in the update function above
    # will then continue from the second frame.
    plot_frame(next(frames_iters), frame_no=0)

    # Create the legend
    if add_legend:
        log.caution("Adding a legend via abmplot will probably be buggy!")
        hlpr.ax.legend()