"""This module provides plotting functions to visualize cellular automata."""
import copy
import logging
import warnings
from math import ceil, sqrt
from typing import Callable, Dict, Optional, Sequence, Tuple, Union
import matplotlib as mpl
import matplotlib.collections
import matplotlib.image
import matplotlib.patches
import matplotlib.transforms
import numpy as np
import xarray as xr
from dantro.abc import AbstractDataContainer
from dantro.plot import ColorManager
from dantro.plot.funcs.generic import make_facet_grid_plot
from matplotlib.colors import ListedColormap
from ...tools import ensure_dict, recursive_update
from .. import DataManager, UniverseGroup
from . import PlotHelper, UniversePlotCreator, is_plot_func
log = logging.getLogger(__name__)
# Increase log threshold for animation module
logging.getLogger("matplotlib.animation").setLevel(logging.WARNING)
# -----------------------------------------------------------------------------
[docs]def _prepare_hexgrid_data(
data: Union[np.ndarray, xr.DataArray], *, x: str = None, y: str = None
) -> Tuple[xr.DataArray, str, str]:
"""Prepares data for :py:func:`imshow_hexagonal` by checking the given data
and specified dimension names are consistent.
"""
if not isinstance(data, xr.DataArray):
data = xr.DataArray(data)
if data.ndim != 2:
raise ValueError(
"Need 2-dimensional data for hexagonal grid plot, but got "
f"{data.ndim}-dimensional data!\n{str(data)}"
)
# May not have dimension arguments given, e.g. because this was a data
# array. Assume that it's fine to just use the available two.
# For simplicity, we don't allow to give only one of them.
if not x and not y:
x, y = data.dims
elif x == y:
raise ValueError(
"Dimension names `x` and `y` need to be different, "
f"but are both '{x}'!"
)
elif (x is None) != (y is None):
raise ValueError(
"Need either both `x` and `y` dimension names or neither!"
)
# Make sure data dimensions are ordered correctly
data = data.transpose(x, y)
return data, x, y
[docs]def _flatten_hexgrid_data(data: xr.DataArray) -> np.ndarray:
"""Flattens hexgrid data in a specific way.
For consistency, this function should be used when calling
:py:meth:`~matplotlib.collections.PolyCollection.set_array` on collections
that represent hexagonal grid data.
"""
return data.data.T.flatten()
[docs]def _plot_ca_property(
prop_name: str,
*,
hlpr: PlotHelper,
data: xr.DataArray,
default_imshow_kwargs: dict,
imshow_hexagonal_extra_kwargs: dict = None,
default_cbar_kwargs: dict = None,
grid_structure: str = None,
limits: Tuple[float, float] = None,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
cmap: Union[str, dict] = None,
norm: Union[str, dict] = None,
add_colorbar: bool = True,
set_axis_off: bool = True,
title: str = None,
imshow_kwargs: dict = None,
cbar_labels: dict = None,
cbar_label_kwargs: dict = None,
cbar_tick_params: dict = None,
no_cbar_markings: bool = False,
**cbar_kwargs,
) -> mpl.image.AxesImage:
"""Helper function, used in :py:func:`caplot` and :py:func:`state` to plot
a property on the given axis. Returns the created axes image object.
.. note::
The arguments here are those within the individual entries of the
``to_plot`` argument for the above plotting functions.
Args:
prop_name (str): The property to plot
hlpr (PlotHelper): The plot helper
data (xarray.DataArray): The array-like data to plot as an image
default_imshow_kwargs (dict): Default arguments for the imshow call,
updated with individually-specified ``imshow_kwargs``.
imshow_hexagonal_extra_kwargs (dict): Default arguments for hexagonal
grids, ignored otherwise.
This updates the ``default_imshow_kwargs`` and is in turn updated
with individually-specified ``imshow_kwargs``.
default_cbar_kwargs (dict): Default arguments for the colorbar
creation, updated with ``cbar_kwargs``.
grid_structure (str, optional): Can be used to explicitly set the grid
structure in cases where ``data.attrs['grid_structure']`` is not
available or holds an invalid entry.
This decides whether to use :py:meth:`matplotlib.axes.Axes.imshow`
or :py:func:`imshow_hexagonal`.
Note that the ``grid_properties`` need to be passed via the
``imshow_kwargs`` argument below.
limits (Tuple[float, float], optional): The data limits to use in the
form ``(vmin, vmax)``. Individual entries can also be None.
*Deprecated!* Use ``vmin`` and ``vmax`` instead.
vmin (float, optional): The lower limit to use for the colorbar range.
vmax (float, optional): The upper limit to use for the colorbar range.
cmap (Union[str, dict], optional): The colormap to use. If a dict is
given, defines a (discrete) ``ListedColormap`` from the values.
Handled by :py:class:`~dantro.plot.utils.color_mngr.ColorManager`.
norm (Union[str, dict], optional): The normalization function to use.
Handled by :py:class:`~dantro.plot.utils.color_mngr.ColorManager`.
add_colorbar (bool, optional): If false, will not draw a colorbar.
Default is true.
set_axis_off (bool, optional): If true (default), will set the axis to
invisible.
title (str, optional): The subplot figure title
imshow_kwargs (dict, optional): Depending on grid structure, is passed
on either to :py:meth:`~matplotlib.axes.Axes.imshow` or to
:py:func:`.imshow_hexagonal`.
cbar_labels (dict, optional): Passed to
:py:class:`~dantro.plot.utils.color_mngr.ColorManager` to set up
the label names alongside the given ``cmap`` and ``norm``.
cbar_label_kwargs (dict, optional): Passed to
:py:meth:`~dantro.plot.utils.color_mngr.ColorManager.create_cbar`
for controlling the aesthetics of colorbar labels.
cbar_tick_params (dict, optional): Passed to
:py:meth:`~dantro.plot.utils.color_mngr.ColorManager.create_cbar`
for controlling the aesthetics of colorbar ticks.
no_cbar_markings (bool, optional): Whether to suppress colorbar
markings (ticks and tick labels).
**cbar_kwargs: Passed to
:py:meth:`~dantro.plot.utils.color_mngr.ColorManager.create_cbar`
Returns:
matplotlib.image.AxesImage:
The created axes image representing the CA property.
Raises:
ValueError: on invalid grid structure; supported structures are
``square`` and ``hexagonal``
"""
# Handle deprecations
if "draw_cbar" in cbar_kwargs:
cbar_kwargs.pop("draw_cbar")
_msg = (
"The `draw_cbar` argument is deprecated and will be removed. "
"Use `add_colorbar` instead."
)
warnings.warn(_msg, DeprecationWarning)
log.caution(_msg)
if limits is not None:
if vmin is None and vmax is None:
_msg = (
"The `limits` argument is deprecated and will be removed. "
"Use `vmin` and `vmax` instead."
)
warnings.warn(_msg, DeprecationWarning)
log.caution(_msg)
vmin, vmax = limits
else:
raise ValueError(
"Got the deprecated `limits` argument but also `vmin` and/or "
"`vmax`! Remove the `limits` argument and use only `vmin` and "
"`vmax` instead."
)
# Set up the ColorManager
cm = ColorManager(
cmap=cmap, norm=norm, vmin=vmin, vmax=vmax, labels=cbar_labels
)
# Determine grid structure
grid_structure = (
grid_structure
if grid_structure
else data.attrs.get("grid_structure", "square")
)
# Prepare imshow_kwargs, successively updating defaults.
# Also need to be able to pass custom arguments to imshow_hexagonal, which
# has a wider interface than regular imshow ...
_imshow_kwargs = ensure_dict(default_imshow_kwargs)
if grid_structure == "hexagonal":
_imshow_kwargs = recursive_update(
_imshow_kwargs, ensure_dict(imshow_hexagonal_extra_kwargs)
)
_imshow_kwargs = recursive_update(
_imshow_kwargs, ensure_dict(imshow_kwargs)
)
# Create imshow(-like) object on the currently selected axis
if grid_structure == "square" or grid_structure is None:
im = hlpr.ax.imshow(
data.T,
cmap=cm.cmap,
norm=cm.norm,
animated=True,
rasterized=True,
origin="lower",
aspect="equal",
**_imshow_kwargs,
)
elif grid_structure == "hexagonal":
im = imshow_hexagonal(
data=data,
ax=hlpr.ax,
cmap=cm.cmap,
norm=cm.norm,
animated=True,
rasterized=True,
**_imshow_kwargs,
)
else:
raise ValueError(
f"Unsupported grid structure '{grid_structure}'!\n"
"Choose from: square, hexagonal"
)
# Remove main axis labels and ticks and provide some default options
if set_axis_off:
hlpr.ax.axis("off")
hlpr.provide_defaults("set_title", title=(title if title else prop_name))
# .. Colorbar .............................................................
if not add_colorbar:
return im
# else: draw the colorbar
# Determine which artist to use; for hexagonal grids, need to attach the
# PolyCollection, because it holds the data array.
artist = im
if grid_structure == "hexagonal":
artist = im.hexagons
# Parse colorbar kwargs, setting some default values
default_cbar_kwargs = ensure_dict(default_cbar_kwargs)
cbar_kwargs = recursive_update(
copy.deepcopy(default_cbar_kwargs), cbar_kwargs
)
cbar_kwargs["fraction"] = cbar_kwargs.get("fraction", 0.05)
cbar_kwargs["pad"] = cbar_kwargs.get("pad", 0.02)
# Draw the colorbar, then store it in the AxesImage to have it accesible
cbar = cm.create_cbar(
artist,
fig=hlpr.fig,
ax=hlpr.ax,
label_kwargs=cbar_label_kwargs,
tick_params=cbar_tick_params,
**cbar_kwargs,
)
im.cbar = cbar
# May want to remove markings
if no_cbar_markings:
cbar.set_ticks([])
cbar.ax.set_xticklabels([])
cbar.ax.set_yticklabels([])
return im
# -----------------------------------------------------------------------------
# -- Plot functions -----------------------------------------------------------
# -----------------------------------------------------------------------------
[docs]def imshow_hexagonal(
data: Union[xr.DataArray, np.ndarray],
*,
ax: "matplotlib.axes.Axes" = None,
x: str = None,
y: str = None,
grid_properties: dict = {},
update_grid_properties: dict = {},
grid_properties_keys: dict = {},
extent: tuple = None,
scale: float = 1.01,
draw_centers: bool = False,
draw_center_radius: float = 0.1,
draw_center_kwargs: dict = {},
hide_ticks: bool = None,
cmap: str = None,
norm: str = None,
vmin: float = None,
vmax: float = None,
collection_kwargs: dict = {},
**im_kwargs,
) -> mpl.image.AxesImage:
"""Visualizes data using a grid of hexagons (⬢ or ⬣).
Owing to the many ways in which a hexagonal grid can be visualized,
this function requires more information than
:py:meth:`~matplotlib.axes.Axes.imshow`. These so called *grid properties*
need to be passed via the ``grid_properties`` argument or directly
alongside the data via ``data.attrs`` (for :py:class:`~xarray.DataArray`).
The following grid properties are available:
coordinate_mode (str):
In which way the data is stored. Currently only supports
``offset`` mode, i.e. with offset row and column coordinates.
Coordinates of individual cells need not be given, nor can they
be given: Assuming a regular hexagonal grid, all coordinates and
sizes are completely deduced from the shape of the given data and
the other grid parameters like ``pointy_top`` and ``offset_mode``.
pointy_top (bool):
Whether the hexagons have a pointy top (⬢) or a flat top (⬣).
More precisely, with a pointy top, there is only a single vertex
at the top and bottom of the hexagon (i.e. along the ``y``
dimension and the top/bottom of the resulting plot).
offset_mode (str):
Whether ``even`` or ``odd`` rows or columns are offset towards
higher values. In other words:
For pointy tops, offset every second *row* toward the right.
For flat tops, offset every second *column* towards the top.
Offset distance is half a cell's width and half a cell's height,
respectively.
space_size (Tuple[float, float], optional):
The size of the space in ``(width, height)`` that the hexagonal
grid cells cover.
If given, will make the assumption that the available number of
hexagons in each dimension reach from one end of the space to the
other, even if that means that the hexagons become distorted along
the dimensions.
If *not* given, will assume regular hexagons and use arbitrary
units to cover the space; these hexagons will *not* be distorted.
space_offset (Tuple[float, float], optional):
Translates the space by ``(offset_x, offset_y)``.
Also applies to the case where ``space_size`` was not given.
Effectively, this refers to the coordinates of the bottom left-hand
corner of the space.
space_boundary (str, optional):
Whether the space (regardless of explicitly given or deduced)
describes the ``outer`` or ``inner`` boundary of the hexagonal
grid. The ``outer`` boundary (default) goes through the outermost
vertices of the outermost cells.
The ``inner`` boundary goes through some hexagon center, cutting
off pointy tops and protruding parts of the hexagons such that the
whole space is covered by hexagonal cells.
For example, grid properties may look like this:
.. code-block:: yaml
grid_properties:
# -- Required:
coordinate_mode: offset
pointy_top: true
offset_mode: even
# -- Optional:
space_size: [8, 8]
space_offset: [-4, -4]
space_boundary: outer
With some 2D dummy grid data of shape ``(21, 24)``, the corresponding
output would be as follows.
The darker cells denote the boundary and the corners; the lighter cells
correspond to a "vertical" line in the third column of the grid.
.. image:: ../_static/_gen/imshow_hex/small_with_space_outer.pdf
:target: ../_static/_gen/imshow_hex/small_with_space_outer.pdf
:width: 100%
By setting the space boundary parameter to ``inner``, the domain size
remains the same, but the boundary cells are partially cut off:
.. image:: ../_static/_gen/imshow_hex/small_with_space_inner.pdf
:target: ../_static/_gen/imshow_hex/small_with_space_inner.pdf
:width: 100%
Removing all optional parameters, specifically the ``space_size``, the
size of the domain is arbitrary, so no labels are drawn.
In addition, we can also mark the hexagon centers:
.. image:: ../_static/_gen/imshow_hex/small_centers_marked.pdf
:target: ../_static/_gen/imshow_hex/small_centers_marked.pdf
:width: 100%
Changing to flat tops and ``odd`` offset mode results in a figure with a
different aspect ratio, while the hexagons remain regular.
.. image:: ../_static/_gen/imshow_hex/small_flat_top_odd.pdf
:target: ../_static/_gen/imshow_hex/small_flat_top_odd.pdf
:width: 100%
When specifying the domain size again, the hexagons need to be scaled
non-uniformly to cover the domain:
.. image:: ../_static/_gen/imshow_hex/small_flat_top_odd_with_space.pdf
:target: ../_static/_gen/imshow_hex/small_flat_top_odd_with_space.pdf
:width: 100%
.. hint::
For an excellent introduction to hexagonal grid representations, see
`this article <https://www.redblobgames.com/grids/hexagons/>`_.
.. admonition:: See also
* :py:func:`caplot` integrates this function.
* :ref:`plot_funcs_ca_hex` documents usage and shows more examples.
Args:
data (Union[xarray.DataArray, numpy.ndarray]): 2D array-like data that
holds the grid information that is to be plotted.
If the data is given as :py:class:`~xarray.DataArray`, its
``attrs`` are used to *update* the given ``grid_properties``.
ax (matplotlib.axes.Axes, optional): The axes to draw to; if not given,
will use the current axes.
x (str, optional): Name of the data dimension that is to be represented
on the x-axis of the plot. If not given, will use the first data
dimension.
y (str, optional): Name of the data dimension that is to be represented
on the x-axis of the plot. If not given, will use the second data
dimension.
grid_properties (dict, optional): The grid properties dict, which needs
to specify the above properties in order to determine how to
represent the data. This dict is first updated with potentially
available ``data.attrs`` and subsequently updated with the
``update_grid_properties``.
update_grid_properties (dict, optional): Updates the grid properties,
see above.
grid_properties_keys (dict, optional): A mapping that can be used if
the given data has grid properties given under different names.
For instance, ``{"space_size": "size"}`` would read the ``size``
entry instead of ``space_size``.
extent (tuple, optional): A custom space *extent*, denoting the edges
``(left, right, bottom, top)`` of the domain *in data units*.
scale (float, optional): A scaling factor for the size of the hexagons.
The default value is very slightly larger than 1 to reduce aliasing
artefacts on exactly overlapping hexagon edges. Scaling is uniform.
draw_centers (bool, optional): Whether to additionally draw the center
points of all hexagons.
draw_center_radius (float, optional): The relative radius of the
center points in units of ``min(cell_width, cell_height) / 2``.
draw_center_kwargs (dict, optional): Additional arguments that are
passed to the :py:class:`~matplotlib.collections.PatchCollection`
used to draw the center points.
hide_ticks (bool, optional): Whether to hide the ticks and tick labels.
If None, will hide ticks if no ``space_size`` grid property was
given (in which case the units are assumed irrelevant).
cmap (str, optional): The colormap to use
norm (str, optional): The normalization to use
vmin (float, optional): The minimum value of the color range to use
vmax (float, optional): The maximum value of the color range to use
collection_kwargs (dict, optional): Passed on to the
:py:class:`~matplotlib.collections.PolyCollection` that is used to
represent the hexagons.
**im_kwargs: Passed on to :py:class:`matplotlib.image.AxesImage` that
is created from the whole axis.
Can be used to set ``interpolation`` or similar options.
Returns:
matplotlib.image.AxesImage:
The imshow-like object containing the hexagonal grid.
"""
# .. constants ............................................................
sqrt_3 = np.sqrt(3) # 2 * sin(60°)
# Regular unit hexagon vertices in clockwise direction.
# Define one with a pointy and one with a flat top (easier than rotating).
unit_hexagon_pointy_top = np.array(
[
[0, 1],
[sqrt_3 / 2, +1 / 2],
[sqrt_3 / 2, -1 / 2],
[0, -1],
[-sqrt_3 / 2, -1 / 2],
[-sqrt_3 / 2, +1 / 2],
]
)
unit_hexagon_flat_top = np.array(
[
[1, 0],
[+1 / 2, sqrt_3 / 2],
[-1 / 2, sqrt_3 / 2],
[-1, 0],
[-1 / 2, -sqrt_3 / 2],
[+1 / 2, -sqrt_3 / 2],
]
)
# .. Prepare data .........................................................
# Bring data into a uniform shape: 2D xr.DataArray
data, x, y = _prepare_hexgrid_data(data, x=x, y=y)
# Aggregate grid properties
grid_properties = ensure_dict(copy.deepcopy(grid_properties))
grid_properties.update(data.attrs)
grid_properties.update(ensure_dict(copy.deepcopy(update_grid_properties)))
if not grid_properties:
raise ValueError(
"Could not determine grid properties! "
"Either pass them explicitly via the `grid_properties` or "
"`update_grid_properties` arguments or, if `data` is given as "
"an xr.DataArray, add them to `data.attrs`."
)
# .. Get hexgrid information ..............................................
GRID_PROP_KEYS = (
"coordinate_mode",
"pointy_top",
"offset_mode",
"space_size",
"space_offset",
"space_boundary",
)
_keys = {k: grid_properties_keys.get(k, k) for k in GRID_PROP_KEYS}
# Extract attribute values and give a useful error message if that fails
try:
coordinate_mode = grid_properties[_keys["coordinate_mode"]]
pointy_top = grid_properties[_keys["pointy_top"]]
offset_mode = grid_properties[_keys["offset_mode"]]
space_size = grid_properties.get(_keys["space_size"])
space_given = space_size is not None
space_offset = grid_properties.get(_keys["space_offset"], (0.0, 0.0))
boundary = grid_properties.get(_keys["space_boundary"], "outer")
except KeyError as err:
_gp = "\n".join(f" {k:18}: {v}" for k, v in grid_properties.items())
_km = "\n".join(f" {k:16} -> {v}" for k, v in _keys.items())
raise ValueError(
f"Missing grid property {err} for imshow_hexagonal! Make sure the "
"required metadata is available and the key mapping is correct.\n"
f"Data attributes:\n{data.attrs}"
f"\n\nAggregated grid properties (after updates):\n{_gp}"
f"\n\nKey mapping (old -> new):\n{_km}"
) from err
# May have an explicitly given extent, in which case the space size and
# offset given by the grid properties needs to be overwritten.
if extent:
_l, _r, _b, _t = extent
space_size = (abs(_r - _l), abs(_t - _b))
space_offset = (min(_l, _r), min(_b, _t))
# Check validity
COORDINATE_MODES = ("offset",)
if coordinate_mode not in COORDINATE_MODES:
raise ValueError(
f"Invalid coordinate mode '{coordinate_mode}'! "
"Hexagonal grid property `coordinate_mode` needs to be one of: "
f"{', '.join(COORDINATE_MODES)}"
)
if offset_mode not in ("even", "odd"):
raise ValueError(
f"Invalid offset mode '{offset_mode}'! Hexagonal grid property "
"`offset_mode` needs to be 'even' or 'odd'."
)
if boundary not in ("outer", "inner"):
raise ValueError(
f"Invalid space boundary '{boundary}'! Hexagonal grid property "
"`space_boundary` needs to be 'outer' or 'inner'."
)
# .. Calculations .........................................................
# Determine number of cells in x and y direction (of the final plot)
n_x = data.sizes[x]
n_y = data.sizes[y]
ids_x = np.arange(n_x)
ids_y = np.arange(n_y)
# Issue a warning if there are too few cells.
if n_x < 2 or n_y < 2:
warnings.warn(
"Plotting a hexagonal grid with fewer than two cells in any "
"dimension may lead to unexpected results!",
UserWarning,
)
# Decide on which unit hexagon to use
if pointy_top:
unit_hexagon = unit_hexagon_pointy_top
else:
unit_hexagon = unit_hexagon_flat_top
# Depending on whether the space is known, need to go different ways:
# - If known, deduce s and scale hexagon accordingly.
# - If not known, use s = 1 and compute the size of the space instead.
if space_given:
# With a space given, the size of the hexagon may be different in the
# x and y direction in order to cover the whole space with the
# available number of cells.
#
# However, a scaled hexagon no longer has a characteristic size `s`,
# but needs two parameters to describe the scaling: `lx` and `ly`.
# These are slightly different than `s`, but depending on which pointy
# end aligns with which dimension, one of them is equal to what would
# be `s` in a regular hexagon.
# To reduce pitfalls of using `s` further down, we deliberately do NOT
# define it here.
# First, compute the boundaries (lower, upper) along each dimension
space_x = (space_offset[0], space_offset[0] + space_size[0])
space_y = (space_offset[1], space_offset[1] + space_size[1])
space_width, space_height = space_size
# Now deduce the scaling factors, depending on pointyness and position
# of the boundary:
if pointy_top:
if boundary == "outer":
offs_corr = 1 if n_x > 1 else 0 # offset correction
lx = 2 * space_width / (sqrt_3 * (2 * n_x + offs_corr))
ly = 2 * space_height / (3 * n_y + 1)
else:
lx = 2 * space_width / (sqrt_3 * (2 * n_x - 1))
ly = 2 * space_height / (3 * n_y - 1)
cell_width = sqrt_3 * lx # sic! lx is a scaling factor
cell_height = 2 * ly
else:
# flat top
if boundary == "outer":
offs_corr = 1 if n_y > 1 else 0
lx = 2 * space_width / (3 * n_x + 1)
ly = 2 * space_height / (sqrt_3 * (2 * n_y + offs_corr))
else:
lx = 2 * space_width / (3 * n_x - 1)
ly = 2 * space_height / (sqrt_3 * (2 * n_y - 1))
cell_width = 2 * lx
cell_height = sqrt_3 * ly # sic! ly is a scaling factor
# Scale and transform the hexagon accordingly.
# This may lead to elongation along x or y dimensions.
# Note the use of lx and ly here, not cell_width and cell_height.
# This is because lx/ly == 1 if the hexagon is regular (the available
# space's aspect ratio is 1:sqrt(3)/2 ).
# In contrast, cell_width / cell_height will be sqrt(3)/2 == 1.15470…
# for a unit hexagon, thus not suitable as scaling factors.
hexagon = scale * np.array([lx, ly]) * unit_hexagon
else:
# Space was not given, can choose s and deduce space as we desire.
#
# Also, the hexagon will be regular: a _uniformly_ scaled unit hexagon.
# Thus, we do not need to worry about scaling factors like above, but
# the side length `s` alone suffices.
s = 1
hexagon = scale * s * unit_hexagon
if pointy_top:
cell_width = s * sqrt_3
cell_height = 2 * s
if boundary == "outer":
offs_corr = 0.5 if n_x > 1 else 0
space_x = (0.0, (n_x + offs_corr) * cell_width)
space_y = (0.0, 3 / 2 * n_y * s + s / 2)
else:
space_x = (cell_width / 2, n_x * cell_width)
space_y = (s / 2, 3 / 2 * n_y * s)
else:
# flat top
cell_width = 2 * s
cell_height = s * sqrt_3
if boundary == "outer":
offs_corr = 0.5 if n_y > 1 else 0
space_x = (0.0, 3 / 2 * n_x * s + s / 2)
space_y = (0.0, (n_y + offs_corr) * cell_height)
else:
space_x = (s / 2, 3 / 2 * n_x * s)
space_y = (cell_height / 2, n_y * cell_height)
space_size = (space_x[1], space_y[1])
space_width, space_height = space_size
space_x = (space_offset[0] + space_x[0], space_offset[0] + space_x[1])
space_y = (space_offset[1] + space_y[0], space_offset[1] + space_y[1])
# .. Compute cell positions . . . . . . . . . . . . . . . . . . . . . . . .
# Temporary position values -- without row/col offsets!
if pointy_top:
if boundary == "outer":
_pos_x = np.linspace(
space_x[0] + cell_width / 2,
space_x[1] - cell_width, # due to row offset
n_x,
)
_pos_y = np.linspace(
space_y[0] + cell_height / 2,
space_y[1] - cell_height / 2,
n_y,
)
else:
_pos_x = np.linspace(
space_x[0],
space_x[1] - cell_width / 2, # due to row offset
n_x,
)
_pos_y = np.linspace(
space_y[0] + cell_height / 4,
space_y[1] - cell_height / 4,
n_y,
)
else:
# flat top
if boundary == "outer":
_pos_x = np.linspace(
space_x[0] + cell_width / 2,
space_x[1] - cell_width / 2,
n_x,
)
_pos_y = np.linspace(
space_y[0] + cell_height / 2,
space_y[1] - cell_height, # due to col offset
n_y,
)
else:
_pos_x = np.linspace(
space_x[0] + cell_width / 4,
space_x[1] - cell_width / 4,
n_x,
)
_pos_y = np.linspace(
space_y[0],
space_y[1] - cell_height / 2, # due to col offset
n_y,
)
# Bring into the form that's required for imshow
x_offsets, y_offsets = np.meshgrid(_pos_x, _pos_y)
# Add the offset towards higher values
if pointy_top:
if n_x > 1:
offset = cell_width / 2
if offset_mode == "even":
x_offsets[ids_y % 2 == 0, ...] += offset
else:
x_offsets[ids_y % 2 == 1, ...] += offset
else:
if n_y > 1:
offset = cell_height / 2
if offset_mode == "even":
y_offsets[..., ids_x % 2 == 0] += offset
else:
y_offsets[..., ids_x % 2 == 1] += offset
# .. Create the PolyCollection ............................................
# At this point, need the following information to generate the collection
# of hexagons:
# - the appropriately transformed hexagon
# - x and y offsets (2D arrays, flattened and combined)
collection_kwargs = ensure_dict(collection_kwargs)
# Here we go ...
pcoll = mpl.collections.PolyCollection(
[hexagon],
offsets=np.transpose([x_offsets.flatten(), y_offsets.flatten()]),
transOffset=mpl.transforms.AffineDeltaTransform(ax.transData),
#
# Pass collection-related kwargs
linewidths=collection_kwargs.get("linewidths", 0),
**collection_kwargs,
)
# NOTE There also is a RegularPolyCollection, but that is a massive pain
# because it expects the sizes to be given in units of the canvas
# (area in *points squared*), which depends on the representation and
# not on the data.
# The PolyCollection does not have that issue because the polygons
# need to be drawn "by hand". This way, all information can be
# supplied in units of data space when using the data transformation
# of the offsets.
# Set the data (in a consistently flattened form)
pcoll.set_array(_flatten_hexgrid_data(data))
# Set cmap stuff, norm, limits
pcoll.set_cmap(cmap)
pcoll.set_norm(norm)
pcoll.set_clim(vmin, vmax)
# .. Add to axis ..........................................................
if ax is None:
import matplotlib.pyplot as plt
ax = plt.gca()
ax.add_collection(pcoll)
# Use same length scale in x and y and set space limits
ax.set_aspect("equal")
ax.set_xlim(*space_x)
ax.set_ylim(*space_y)
# Allow marking the center points of the hexagons
#
# Again need to use a manually created collection such that positions and
# sizes can be given in data units.
# The radius factor is a heuristic value for an "effective" `s` parameter
# for the shorter side of the hexagon.
if draw_centers:
circle = mpl.patches.Circle(
(0, 0),
radius=min(cell_width, cell_height) / 2 * draw_center_radius,
)
draw_center_kwargs = ensure_dict(draw_center_kwargs)
ccoll = mpl.collections.PatchCollection(
[circle],
offsets=np.transpose([x_offsets.flatten(), y_offsets.flatten()]),
transOffset=mpl.transforms.AffineDeltaTransform(ax.transData),
#
linewidths=draw_center_kwargs.pop("linewidths", 0),
**draw_center_kwargs,
)
ax.add_collection(ccoll)
# If space was not known, don't show axis labels
if hide_ticks or (not space_given and hide_ticks is None):
ax.tick_params(
axis="both",
left=False,
top=False,
right=False,
bottom=False,
labelleft=False,
labeltop=False,
labelright=False,
labelbottom=False,
)
# Create axes image to have the same result object as imshow does,
# including interpolation features etc
im = mpl.image.AxesImage(ax, **im_kwargs)
im.hexagons = pcoll
# Do some post-processing
im.set_cmap(cmap)
im.set_clim(vmin, vmax)
# Some callbacks
# def on_changed(collection):
# hbar.set_cmap(collection.get_cmap())
# hbar.set_clim(collection.get_clim())
# vbar.set_cmap(collection.get_cmap())
# vbar.set_clim(collection.get_clim())
# pcoll.callbacks.connect("changed", on_changed)
return im
@make_facet_grid_plot(
map_as="dataarray",
register_as_kind="imshow_hexagonal",
encodings=("x", "y"),
supported_hue_styles=(),
parse_cmap_and_norm_kwargs=True,
)
def imshow_hexagonal_facet_grid(
data: xr.DataArray,
*,
hlpr: PlotHelper,
_is_facetgrid: bool,
x: str = None,
y: str = None,
extend: str = None,
levels: int = None,
add_labels: bool = True,
add_colorbar: bool = True,
cbar_kwargs: dict = None,
**kwargs,
) -> mpl.image.AxesImage:
"""Wrapper around :py:func:`imshow_hexagonal` that makes it work as a
standalone, DAG-supporting and *faceting* plotting function.
Uses :py:class:`~dantro.plot.funcs.generic.make_facet_grid_plot` wrapper.
For more arguments, see the respective docstrings.
Args:
data (xr.DataArray): The to-be-plotted data as prepared by the wrapper.
hlpr (PlotHelper): The plot helper
_is_facetgrid (bool): *Internally used variable* that denotes whether
the invocation is part of a facet grid plot.
x (str, optional): Which data dimension to represent on the x-axis
y (str, optional): Which data dimension to represent on the y-axis
extend (str, optional): Whether to extend the colorbar
levels (int, optional): Number of discrete colormap levels to use;
*currently not supported!*
add_labels (bool, optional): Whether to add labels to the x and y axis.
add_colorbar (bool, optional): Whether to add a colorbar.
cbar_kwargs (dict, optional): Colorbar kwargs that are *only used if
no facet grid* is created.
**Note:** This interface is subject to change, aim being that the
arguments can be supplied in the same way for faceting and
non-faceting invocations of this function.
**kwargs: Passed on to :py:func:`imshow_hexagonal`
"""
if levels:
raise NotImplementedError("`levels` argument not yet supported")
im = imshow_hexagonal(
data,
ax=hlpr.ax,
x=x,
y=y,
**kwargs,
)
if not _is_facetgrid:
if add_labels:
hlpr.ax.set_xlabel(x)
hlpr.ax.set_ylabel(y)
if add_colorbar:
# TODO This should read information from the FacetGrid's
# cbar_kwargs, which are also parsed there... problem being
# that there is no FacetGrid object available here.
# However, the arguments should be parsed in the same way!
cbar = hlpr.fig.colorbar(
im,
ax=hlpr.ax,
extend=extend,
**ensure_dict(cbar_kwargs),
)
return im
# .............................................................................
[docs]@is_plot_func(use_dag=True, supports_animation=True)
def caplot(
*,
hlpr: PlotHelper,
data: dict,
to_plot: Dict[str, dict],
from_dataset: xr.Dataset = None,
frames: str = "time",
frames_isel: Union[int, Sequence] = None,
grid_structure: str = None,
aspect: float = 1.0,
aspect_pad: float = 0.1,
size: float = None,
col_wrap: Union[int, str, bool] = "auto",
imshow_hexagonal_extra_kwargs: dict = None,
default_imshow_kwargs: dict = None,
default_cbar_kwargs: dict = dict(fraction=0.04, aspect=20),
suptitle_fstr: str = "{} = {}",
suptitle_kwargs: dict = None,
):
"""Plots an animated series of one or many 2D Cellular Automata states.
The data used for plotting is assembled from ``data`` using the keys that
are specified in ``to_plot``. Alternatively, the ``from_dataset`` argument
can be used to pass a dataset which contains all the required data.
The keys in ``to_plot`` should match the names of data variables.
The values in ``to_plot`` specify the individual subplots' properties like
the color map that is to be used or the minimum or maximum values.
For plotting square grids, :py:meth:`matplotlib.axes.Axes.imshow` is used
and generates output like this:
.. image:: ../_static/_gen/caplot/snapshot_square.pdf
:target: ../_static/_gen/caplot/snapshot_square.pdf
:width: 100%
For a grid with hexagonal cells, :py:func:`imshow_hexagonal` is used; more
details on how the cells are mapped to the space can be found there.
The output (for the same dummy data as used above) may look like this:
.. image:: ../_static/_gen/caplot/snapshot_hex.pdf
:target: ../_static/_gen/caplot/snapshot_hex.pdf
:width: 100%
Finally, this plot function is specialized to generate animations along the
``frames`` dimension of the data, e.g. ``time``:
.. raw:: html
<video width="720" src="../_static/_gen/caplot/anim_square.mp4" controls></video>
.. raw:: html
<video width="720" src="../_static/_gen/caplot/anim_hex.mp4" controls></video>
.. admonition:: Requirements on the CA data
The selected data (keys in ``to_plot`` that correspond to DAG results
in ``data``) should have two *spatial* dimensions and one data
dimension that goes along the ``frames`` dimension.
All coordinates should be identical, otherwise the behavior is not
defined or alignment might fail.
**For hexagonal grid structure**, note the requirements given in
:py:func:`imshow_hexagonal`.
.. admonition:: See also
* :ref:`plot_funcs_ca`
* :ref:`plot_funcs_ca_hex`
Args:
hlpr (PlotHelper): The plot helper instance
data (dict): The selected data
to_plot (Dict[str, dict]): Which data to plot and how. The keys of
this dict refer to an item within the selected ``data`` or the
given dataset.
Each of these keys is expected to hold yet another dict,
supporting the following configuration options (all optional):
- ``title`` (str, optional):
The title for this sub-plot.
- ``cmap`` (Union[str, list, dict], optional):
Which colormap to use. This argument is handled by the
:py:class:`~dantro.plot.utils.color_mngr.ColorManager`,
providing many ways in which to define the colormap.
For instance, by passing mapping from labels to colors, a
discrete colormap is created: The keys will be the labels and
the values will be their colors. Association happens in the
order of entries, with values being inferred from ``limits``,
if given. For more information and examples, see the docstring
of the :py:class:`~dantro.plot.utils.color_mngr.ColorManager`.
- ``norm`` (Union[str, dict], optional):
The normalization function to use, also handled by the
:py:class:`~dantro.plot.utils.color_mngr.ColorManager`.
- ``vmin`` (float, optional):
The fixed lower data limit for this property; if not given,
uses auto-scaling, which may lead to jumps in the animation.
Can also be ``'min'`` in which case the *global* minimum of the
available data is used.
- ``vmax`` (float, optional):
Same as ``vmin``, but with the maximum and allowing ``'max'``
argument for choosing the global maximum.
- ``limits`` (Union[tuple, list], optional):
*Deprecated!* Use ``vmin`` and ``vmax`` instead.
- ``label`` (str, optional):
The *colorbar* label.
- ``imshow_kwargs`` (dict, optional):
Passed on to the imshow invocation, i.e. to
:py:meth:`~matplotlib.axes.Axes.imshow` or
:py:meth:`imshow_hexagonal`.
- ``**kwargs``:
Further arguments control the colorbar appearance, their
labels, ticks, and other plot specifics. For more detailed
information about available arguments, see
:py:func:`._plot_ca_property`, which takes care of plotting
the individual ``to_plot`` entries.
from_dataset (xarray.Dataset, optional): If given, will use this object
instead of assembling a dataset from ``data`` and ``to_plot`` keys.
frames (str, optional): Name of the animated dimension, typically the
time dimension.
frames_isel (Union[int, Sequence], optional): The index selector for
the frames dimension. Can be a single integer but also a range
expression.
grid_structure (str, optional): The underlying grid structure, can be
``square``, ``hexagonal``, or ``triangular`` (not implemented).
If None, will try to read it from the individual properties' data
attribute ``grid_structure``.
aspect (float, optional): The aspect ratio (width/height) of the
subplots; should match the aspect ratio of the data.
aspect_pad (float, optional): A factor that is added to the calcuation
of the subplots width. This can be used to create a horizontal
padding between subplots.
size (float, optional): The height in inches of a subplot.
Is used to determine the subplot size with the width being
calculated by ``size * (aspect + aspect_pad)``.
col_wrap (Union[int, str, bool], optional): Controls column wrapping.
If ``auto``, will compute a column wrapping that leads to an
(approximately) square subplots layout (not taking into account
the subplots aspect ratio, only the grid layout). This will start
producing a column wrapping with four or more properties to plot.
default_imshow_kwargs (dict, optional): The default parameters passed
to the underlying imshow plotting function. These are updated by
the values given via ``to_plot``.
default_cbar_kwargs (dict, optional): The default parameters for the
colorbar that is added to applicable subplots. These are updated
by the parameters given under ``to_plot``.
suptitle_fstr (str, optional): A format string used to create the
suptitle string. Passed arguments are ``frames`` and the currently
selected frames *coordinate* (not the index!).
If this evaluates to False, will not create a figure suptitle.
suptitle_kwargs (dict, optional): Passed on to ``fig.suptitle``.
"""
# Helper functions ........................................................
def get_grid_structure(d: xr.DataArray) -> str:
"""Retrieves the grid structure from data attributes"""
grid_structure = d.attrs.get("grid_structure")
if isinstance(grid_structure, np.ndarray):
grid_structure = grid_structure.item()
return grid_structure
def prepare_data(data: dict, *, prop_name: str) -> xr.DataArray:
"""Prepares data for (later) creating an :py:class:`xarray.Dataset`"""
d = data[prop_name]
if isinstance(d, AbstractDataContainer):
d = d.data
return d
def select_data(ds: xr.Dataset, name: str, isel: dict) -> xr.DataArray:
"""Selects a slice of data for plotting using
:py:meth:`~xarray.DataArray.isel` on the data variable ``name``."""
return ds[name].isel(isel)
def set_suptitle(data: xr.DataArray):
"""Sets the suptitle"""
if not suptitle_fstr:
return
hlpr.fig.suptitle(
suptitle_fstr.format(frames, data.coords[frames].item()),
**(suptitle_kwargs if suptitle_kwargs else {}),
)
# Prepare the data ........................................................
# Work on a copy of the plot spec
to_plot = copy.deepcopy(to_plot)
# Bring data into xr.Dataset form
log.note("Preparing data for CA plot ...")
if from_dataset:
log.remark("Using explicitly passed dataset ...")
ds = from_dataset
else:
log.remark("Constructing dataset ...")
ds = xr.Dataset({p: prepare_data(data, prop_name=p) for p in to_plot})
# Check that frames dimension is available
if not frames or frames not in ds.coords:
_avail = ", ".join(ds.coords.keys())
raise ValueError(
f"Invalid `frames` coordinate dimension '{frames}'! "
f"Available coordinates: {_avail}"
)
# Apply selection along frames dimension
if frames_isel is not None:
if isinstance(frames_isel, int):
frames_isel = [frames_isel]
_selector = {frames: frames_isel}
log.remark("Applying index selection %s ...", _selector)
ds = ds.isel(_selector, drop=False)
# TODO x-y mapping?
# TODO Automate aspect computation from x and y dimensions
# Depending on length of coordinate dimension, ensure that animation mode
# is enabled or disabled.
num_frames = ds.coords[frames].size
if num_frames > 1:
hlpr.enable_animation()
else:
hlpr.disable_animation()
# If not given, retrieve the structure from the data variable's attributes.
if not grid_structure:
structures = {p: get_grid_structure(ds[p]) for p in to_plot}
if len(set(structures.values())) != 1:
raise ValueError(
"Mismatch in grid structure; all grid structures need to be "
f"the same but were: {structures}\n"
"This may have resulted from data attributes being lost in a "
"data transformation. If so, one alternative to re-adding the "
"data attributes (via the `update_with_attrs_from` operation) "
"is to specify `grid_structure` explicitly. "
"For hexagonal grids, grid properties can also be passed via "
"`imshow_kwargs.grid_properties`."
)
grid_structure = next(iter(structures.values()))
# Evaluate limits argument for all properties
# NOTE That `limits` is deprecated in _plot_ca_property. Once it is
# removed from there, remove evaluation of `limits` here as well
for prop_name, spec in to_plot.items():
if spec.get("limits"):
vmin, vmax = spec["limits"]
if vmin == "min":
vmin = ds[prop_name].min().item()
if vmax == "max":
vmax = ds[prop_name].max().item()
spec["limits"] = (vmin, vmax)
if spec.get("vmin") == "min":
spec["vmin"] = ds[prop_name].min().item()
if spec.get("vmax") == "max":
spec["vmax"] = ds[prop_name].max().item()
# Inform about the data that is to be plotted
log.note(
"Performing CA plot for %d data variable%s ...",
len(to_plot),
"" if len(to_plot) == 1 else "s",
)
log.remark(" Data variables: %s", ", ".join(to_plot))
log.remark(
" Dimensions: %s",
", ".join(f"{k}: {s}" for k, s in ds.sizes.items()),
)
log.remark(" Grid structure: %s", grid_structure)
# Some final checks ...
if len(ds.sizes) != 3:
raise ValueError(
"Dataset shape needs to be 3-dimensional, but was: "
f"{dict(ds.sizes)}! Full dataset:\n{ds}"
)
# Prepare the figure ......................................................
# Evaluate column wrapping
if col_wrap and not (col_wrap == "auto" and len(to_plot) < 4):
if col_wrap == "auto":
col_wrap = ceil(sqrt(len(to_plot)))
log.remark(" Column wrapping: %s", col_wrap)
ncols = col_wrap
nrows = ceil(len(to_plot) / col_wrap)
axis_map = {
p: dict(col=i % col_wrap, row=i // col_wrap)
for i, p in enumerate(to_plot)
}
else:
ncols = len(to_plot)
nrows = 1
axis_map = {p: dict(col=col, row=0) for col, p in enumerate(to_plot)}
# Determine the figsize from the size argument (height) and data aspect.
if not size:
size = mpl.rcParams["figure.figsize"][1]
figsize = (size * (aspect + aspect_pad), size)
# Create the figure and set all axes as invisible. This is needed because
# col_wrap may lead to some subplots being completely empty.
hlpr.setup_figure(
figsize=figsize,
scale_figsize_with_subplots_shape=True,
ncols=ncols,
nrows=nrows,
)
for ax in hlpr.axes.flat:
ax.set_visible(False)
# Do the single plot for all data variables, looping through subfigures.
# This creates the first imshow objects, which are being kept track of
# such that they can later be updated.
# This also sets the axes as visible again, but only for those that have
# a property assigned to them.
ims = dict()
for i, (prop_name, props) in enumerate(to_plot.items()):
hlpr.select_axis(**axis_map[prop_name])
hlpr.ax.set_visible(True)
# Select the appropriate data, then plot the data variable
data = select_data(ds, prop_name, {frames: 0})
ims[prop_name] = _plot_ca_property(
prop_name,
hlpr=hlpr,
data=data,
grid_structure=grid_structure,
default_imshow_kwargs=default_imshow_kwargs,
imshow_hexagonal_extra_kwargs=imshow_hexagonal_extra_kwargs,
default_cbar_kwargs=default_cbar_kwargs,
**props,
)
# Use this data for setting the figure suptitle
if i == 0:
set_suptitle(data)
# End of single frame CA state plot function ..............................
# The above variables are all available below, but the update function is
# supposed to start plotting anew, starting from frame 0.
def update_data():
"""Updates the data of the imshow objects"""
def need_autoscale(
*, limits=None, vmin=None, vmax=None, cmap=None, **_
) -> bool:
"""Returns True if there no bounds (limits, vmin, vmax) are given
or if a discrete colormap is created from a dict"""
return (
vmin is None
and vmax is None
and not limits
and not isinstance(cmap, dict)
)
log.note("Plotting animation with %d frames ...", num_frames)
# Determine whether a autoscaling is needed for a property:
needs_autoscale = {
n: need_autoscale(**spec) for n, spec in to_plot.items()
}
# Frame iteration
for frame_idx in range(num_frames):
log.debug("Plotting frame %d ...", frame_idx)
for i, (prop_name, props) in enumerate(to_plot.items()):
hlpr.select_axis(**axis_map[prop_name])
frame_data = select_data(ds, prop_name, {frames: frame_idx})
# Depending on grid structure, update data or plot anew
if grid_structure == "hexagonal":
# Update imshow_hexagonal data according as would otherwise
# happen inside that function
frame_data, _, _ = _prepare_hexgrid_data(
frame_data, x=props.get("x"), y=props.get("y")
)
ims[prop_name].hexagons.set_array(
_flatten_hexgrid_data(frame_data)
)
if needs_autoscale[prop_name]:
ims[prop_name].hexagons.autoscale()
else:
# Update imshow data without creating a new object
ims[prop_name].set_data(frame_data.T)
if needs_autoscale[prop_name]:
ims[prop_name].autoscale()
# Use the first subplot's data for setting the figure suptitle
if i == 0:
set_suptitle(frame_data)
# Done with this frame; yield control to the animation framework.
yield
# Register this update method with the helper, which takes care of the rest
hlpr.register_animation_update(update_data)
# .............................................................................
# DEPRECATED CA state plot
[docs]@is_plot_func(creator="universe", supports_animation=True)
def state(
dm: DataManager,
*,
uni: UniverseGroup,
hlpr: PlotHelper,
model_name: str,
to_plot: dict,
time_idx: int,
default_imshow_kwargs: dict = None,
**_kwargs,
):
r"""Plots the state of the cellular automaton as a 2D heat map.
This plot function can be used for a single plot, but also supports
animation.
Which properties of the state to plot can be defined in ``to_plot``.
Args:
dm (DataManager): The DataManager that holds all loaded data
uni (UniverseGroup): The currently selected universe, parsed by the
:py:class:`~utopya.eval.plotcreators.UniversePlotCreator`.
hlpr (PlotHelper): The plot helper
model_name (str): The name of the model of which the data is to be
plotted
to_plot (dict): Which data to plot and how. The keys of this dict
refer to a path within the data and can include forward slashes to
navigate to data of submodels. Each of these keys is expected to
hold yet another dict, supporting the following configuration
options (all optional):
- cmap (str or dict): The colormap to use. If it is a dict, a
discrete colormap is assumed. The keys will be the labels
and the values the color. Association happens in the order
of entries.
- title (str): The title for this sub-plot
- limits (2-tuple, list): The fixed heat map limits of this
property; if not given, limits will be auto-scaled.
- \**imshow_kwargs: passed on to imshow invocation
time_idx (int): Which time index to plot the data of. Is ignored when
creating an animation.
default_imshow_kwargs (dict, optional): The default parameters passed
to the underlying imshow plotting function. These are updated by
the values given via ``to_plot``.
Raises:
ValueError: Shape mismatch of data selected by ``to_plot``
AttributeError: Got unsupported arguments (referring to the old data
transformation framework)
"""
if _kwargs:
raise AttributeError(
"This plot no longer supports preprocessing or transformation but "
f"got one of the following arguments: {_kwargs}\n"
"Use the new CA plot function (.ca.caplot) which supports the "
"data transformation framework."
)
log.warning(
"The .ca.state plot is deprecated and should no longer be used!\n"
"Please use the .ca.caplot function, which is almost identical in its "
"interface and uses the data transformation framework for data "
"selection and pre-processing."
)
# Helper functions ........................................................
def prepare_data(
prop_name: str, *, all_data: dict, time_idx: int
) -> np.ndarray:
"""Prepares the data for plotting"""
return all_data[prop_name][time_idx]
# Prepare the data ........................................................
# Get the group that all datasets are in
grp = uni["data"][model_name]
# Collect all data
all_data = {p: grp[p] for p in to_plot.keys()}
shapes = [d.shape for p, d in all_data.items()]
if any([shape != shapes[0] for shape in shapes]):
raise ValueError(
"Shape mismatch of properties {}: {}! Cannot plot."
"".format(", ".join(to_plot.keys()), shapes)
)
# Can now be sure they all have the same shape,
# so its fine to take the first shape to extract the number of steps
num_steps = shapes[0][0]
structure = prepare_data(
list(to_plot.keys())[0], all_data=all_data, time_idx=0
).attrs.get("grid_structure", "square")
if structure != "square":
raise ValueError(
"Legacy CA plot no longer supports non-square grid structure "
f"'{structure}'! Use the modern caplot (`.plot.ca`) instead."
)
# Prepare the figure ......................................................
# Prepare the figure to have as many columns as there are properties
hlpr.setup_figure(
ncols=len(to_plot), scale_figsize_with_subplots_shape=True
)
# Store the imshow objects such that only the data has to be updated in a
# following iteration step. Keys will be the property names.
ims = dict()
# Do the single plot for all properties, looping through subfigures
for col_no, (prop_name, props) in enumerate(to_plot.items()):
# Select the axis
hlpr.select_axis(col_no, 0)
# Get the data for this time step
data = prepare_data(prop_name, all_data=all_data, time_idx=time_idx)
# In the first time step create a new imshow object
ims[prop_name] = _plot_ca_property(
prop_name,
data=data,
hlpr=hlpr,
default_imshow_kwargs=default_imshow_kwargs,
**props,
)
# End of single frame CA state plot function ..............................
# The above variables are all available below, but the update function is
# supposed to start plotting anew starting from frame 0.
def update_data():
"""Updates the data of the imshow objects"""
log.info(
"Plotting animation with %d frames of %d %s each ...",
num_steps,
len(to_plot),
"property" if len(to_plot) == 1 else "properties",
)
for time_idx in range(num_steps):
log.debug("Plotting frame for time index %d ...", time_idx)
# Loop through the columns
for col_no, (prop_name, props) in enumerate(to_plot.items()):
hlpr.select_axis(col_no, 0)
data = prepare_data(
prop_name, all_data=all_data, time_idx=time_idx
)
# Update imshow data without creating a new object
ims[prop_name].set_data(data.T)
# If no limits are provided, autoscale the new limits in
# the case of continuous colormaps. A discrete colormap,
# that is provided as a dict, should never have to autoscale.
if not isinstance(props.get("cmap"), dict):
if not props.get("limits"):
ims[prop_name].autoscale()
# Done with this frame; yield control to the animation framework
# which will grab the frame...
yield
log.info("Animation finished.")
# Register this update method with the helper, which takes care of the rest
hlpr.register_animation_update(update_data)