Source code for utopya.eval.plots.attractor

"""This module provides plotting functions to visualize the attractive set of
a dynamical system.

.. todo::

    Migrate these to the more generic DAG-based interface.
"""

import copy
import itertools
import logging
from typing import Callable, Dict, Sequence, Tuple, Union

import matplotlib as mpl
import numpy as np
import xarray as xr
from matplotlib.collections import PatchCollection
from matplotlib.patches import Circle, Rectangle
from scipy.signal import find_peaks

# FIXME use local import
import utopya.eval.plots._attractor as utdp

from ...tools import recursive_update
from .. import DataManager, UniverseGroup
from . import MultiversePlotCreator, PlotHelper, is_plot_func
from ._mpl import HandlerEllipse
from ._utils import calc_pxmap_rectangles

log = logging.getLogger(__name__)

# -----------------------------------------------------------------------------


@is_plot_func(creator="multiverse")
def bifurcation_diagram(
    dm: DataManager,
    *,
    hlpr: PlotHelper,
    mv_data: xr.Dataset,
    dim: str = None,
    dims: Tuple[str, str] = None,
    analysis_steps: Sequence[Union[str, Tuple[str, str]]],
    custom_analysis_funcs: Dict[str, Callable] = None,
    analysis_kwargs: dict = None,
    visualization_kwargs: dict = None,
    to_plot: dict = None,
    **kwargs,
) -> None:
    """Plots a bifurcation diagram for one or two parameter dimensions
    (arguments ``dim`` or ``dims``).

    Args:
        dm (DataManager): The data manager from which to retrieve the data
        hlpr (PlotHelper): The PlotHelper that instantiates the figure and
            takes care of plot aesthetics (labels, title, ...) and saving
        mv_data (xarray.Dataset): The extracted multidimensional dataset
        dim (str, optional): The required parameter dimension of the
            1d-bifurcation diagram.
        dims (Tuple[str, str], optional): The required parameter dimensions
            (x, y) of the 2d-bifurcation diagram.
        analysis_steps (Sequence): The analysis steps that are to be made
            until one is conclusive. Applied per universe.

            - If seq of str: The str will also be used as attractor key
              for plotting if the test is conclusive.
            - If seq of Tuple(str, str): The first str defines the
              attractor key for plotting, the second str is a key within
              ``custom_analysis_funcs``.

            Default analysis functions are:

            - ``endpoint``: ``utdp.find_endpoint``
            - ``fixpoint``: ``utdp.find_fixpoint``
            - ``multistability``: ``utdp.find_multistability``
            - ``oscillation``: ``utdp.find_oscillation``
            - ``scatter``: ``resolve_scatter`` (defined locally)

            where ``utdp`` is the ``utopya.eval.plots._attractor`` module.
        custom_analysis_funcs (dict): A collection of custom analysis
            functions that will overwrite the default analysis funcs
            (recursive update).
        analysis_kwargs (dict, optional): The entries need to match the
            analysis_steps. The subentry (dict) is passed on to the analysis
            function.
        visualization_kwargs (dict, optional): The entries need to match the
            analysis_steps. The subentry (dict) is used to configure a
            rectangle to visualize the conclusive analysis step. Is passed
            to matplotlib.patches.Rectangle. xy, width, height, and angle
            are ignored and set automatically. Required in 2d-bifurcation
            diagram.
        to_plot (dict, optional): The configuration for the data to plot.
            The entries of this key need to match the data_vars selected in
            mv_data. It is used to visualize the state of the attractor
            additionally to the visualization kwargs. Only for
            1d-bifurcation diagram. Sub-keys:

            - ``label`` (str, optional): label in plot
            - ``plot_kwargs`` (dict, optional): passed to scatter for every
              universe

                - color (str, recommended): unique color for every
                  data_variable across universes

        **kwargs: Collection of optional dicts passed to different functions

            - plot_coords_kwargs (dict): Passed to ax.scatter to mark the
              universe's center in the bifurcation diagram
            - rectangle_map_kwargs (dict): Passed to
              utopya.eval.plots._utils.calc_pxmap_rectangles
            - legend_kwargs (dict): Passed to ax.legend

    Raises:
        ValueError: If neither or both of ``dim`` and ``dims`` are given, if
            ``dims`` is not of length two, or if one of the helper steps
            fails to resolve an analysis function, a rectangle or an
            attractor key.
    """

    def resolve_analysis_steps(
        analysis_steps: Sequence[Union[str, Tuple[str, str]]]
    ) -> Sequence[Tuple[str, str]]:
        """Resolve instances of str to pairs of (str, str) in the sequence.

        A new list is built such that the sequence passed in by the caller
        is left untouched (and may also be an immutable sequence).

        Args:
            analysis_steps (Sequence[Union[str, Tuple[str, str]]]): The
                original sequence

        Returns:
            Sequence[Tuple[str, str]]: The sequence of attractor_key and
                analysis_func pairs.
        """
        return [
            (step, step) if isinstance(step, str) else step
            for step in analysis_steps
        ]

    def resolve_to_plot_kwargs(to_plot: dict) -> dict:
        """Resolves the to_plot dict, e.g. adding labels if not explicitly
        specified.

        Args:
            to_plot (dict): The to_plot dict to parse

        Returns:
            dict: A dict with the 'plot_kwargs' for every data_var in
                to_plot.
        """
        if not to_plot:
            return {}

        data_vars_plot_kwargs = {}
        for k, v in to_plot.items():
            # Tolerate empty entries (e.g. a bare key in the configuration)
            v = v or {}
            plot_kwargs = v.get("plot_kwargs") or {}

            if not plot_kwargs.get("label"):
                plot_kwargs["label"] = v.get("label", k)

            data_vars_plot_kwargs[k] = plot_kwargs

        return data_vars_plot_kwargs

    def create_legend_handles(
        *, visualization_kwargs: dict, data_vars_plot_kwargs: dict
    ):
        """Creates legend handles.

        Processes entries in data_vars_plot_kwargs (from to_plot) and
        visualization_kwargs. As a side effect, a default ``linewidth`` of
        ``0.0`` is set in the entries of both dicts; the patch collections
        and scatter calls further below rely on that default.

        Args:
            visualization_kwargs (dict): The visualization kwargs
            data_vars_plot_kwargs (dict): The resolved entries in to_plot

        Returns:
            Tuple[list, dict]: A two-element list of legend handles and
                legend labels as required by ``ax.legend(handles, labels)``
                and the updated ``data_vars_plot_kwargs`` as required by
                ``append_plot_attractor``.
        """
        # Some defaults
        circle_kwargs = dict(xy=(0.5, 0.5), radius=0.25, edgecolor="none")
        rect_kwargs = dict(
            xy=(0.0, 0.0), height=0.75, width=1.0, edgecolor="none"
        )

        # Lists to be populated for matplotlib legend
        legend_handles = []
        legend_labels = []

        for k, kwargs in data_vars_plot_kwargs.items():
            # The label is only used for the legend, not for ax.scatter;
            # fall back to the name of the data variable.
            label = kwargs.pop("label", k)
            kwargs.setdefault("linewidth", 0.0)

            # Determine color
            if "color" in kwargs:
                color = kwargs["color"]

            elif "cmap" in kwargs:
                # NOTE(review): mpl.cm.get_cmap is deprecated in recent
                #               matplotlib versions -- confirm against the
                #               supported matplotlib version range.
                cmap = mpl.cm.get_cmap(kwargs["cmap"])
                color = cmap(1.0)

            else:
                log.warning(f"No color defined for data_var '{k}'!")
                color = None

            # Create and add the handle and the label
            legend_handles.append(Circle(**circle_kwargs, facecolor=color))
            legend_labels.append(label)

        for k, kwargs in visualization_kwargs.items():
            # Entries without configuration are skipped everywhere else
            # as well; nothing to show in the legend for them.
            if kwargs is None:
                continue

            if "to_plot" in kwargs:
                for dvar_name, dvar_kwargs in kwargs["to_plot"].items():
                    # Make sure a linewidth is set
                    dvar_kwargs.setdefault("linewidth", 0.0)

                    # Create and add the handle and the label
                    legend_handles.append(
                        Rectangle(**rect_kwargs, **dvar_kwargs)
                    )
                    legend_labels.append(dvar_kwargs.get("label", dvar_name))

            else:
                # Only the visualization kwargs are updated here; they must
                # not leak into the per-data-variable scatter kwargs.
                kwargs.setdefault("linewidth", 0.0)
                legend_handles.append(Rectangle(**rect_kwargs, **kwargs))
                legend_labels.append(kwargs.get("label", k))

        return [legend_handles, legend_labels], data_vars_plot_kwargs

    def apply_analysis_steps(
        data: xr.Dataset,
        analysis_steps: Sequence[Union[str, Tuple[str, str]]],
        *,
        analysis_funcs: dict,
        analysis_kwargs: dict,
    ):
        """Perform the sequence of analysis steps until the first conclusive.

        Args:
            data (xarray.Dataset): The data to analyse.
            analysis_steps (Sequence[Union[str, Tuple[str, str]]]): The
                analysis steps that are to be made until one is conclusive.
                Applied per universe.
            analysis_funcs (dict): The entries need to match the
                analysis_steps. Map of the analysis_steps to their Callables
            analysis_kwargs (dict): The entries need to match the
                analysis_steps. The subentry (dict) is passed on to the
                corresponding analysis function.

        Raises:
            ValueError: If an analysis function name can neither be found in
                ``analysis_funcs`` nor in the ``utdp`` module.

        Returns:
            Tuple[str, xarray.Dataset]: The key of the conclusive analysis
                step and the data corresponding to this analysis, or
                ``(None, None)`` if no step was conclusive.
        """
        for analysis_key, analysis_func in analysis_steps:
            analysis_func_kwargs = analysis_kwargs.get(analysis_func) or {}

            # resolve the analysis function from its name
            if isinstance(analysis_func, str):
                if analysis_func in analysis_funcs:
                    analysis_func = analysis_funcs[analysis_func]

                else:
                    # Try to get it from the module that also provides the
                    # default analysis functions ... might fail.
                    try:
                        analysis_func = getattr(utdp, analysis_func)
                    except AttributeError as err:
                        raise ValueError(
                            f"Unknown analysis function '{analysis_func}'! "
                            "Available analysis functions: "
                            f"{', '.join(analysis_funcs)}"
                        ) from err

            # Perform the analysis step
            conclusive, attractor = analysis_func(data, **analysis_func_kwargs)

            # Return if conclusive
            if conclusive:
                return analysis_key, attractor

        # Return non-conclusive
        return None, None

    def resolve_rectangle(coord: dict, rectangles: xr.Dataset) -> Rectangle:
        """Resolve the rectangle patch at this coordinate.

        Args:
            coord (dict): The bifurcation parameter's coordinate
            rectangles (xarray.Dataset): The rectangles that cover the 2D
                space spanned by the coordinates. The `coord` should be one
                entry of rectangles.coords.

        Raises:
            ValueError: Coordinate not available in rectangles.

        Returns:
            Rectangle: A rectangle around a universe with coord and shape
                defined by rectangles.
        """
        try:
            rectangle = rectangles.sel(coord)

        except Exception as exc:
            raise ValueError(
                f"The requested paramspace coordinate(s) {coord} are "
                f"not coordinates of rectangles {rectangles.coords}. "
                "Plot failed."
            ) from exc

        rect_spec = rectangle["rect_spec"]
        return Rectangle(*rect_spec.item())

    def append_vis_patch(
        attractor_key: str,
        attractor: xr.Dataset,
        vis_patches: dict,
        vis_kwargs: dict,
        **resolve_rectangle_args,
    ):
        """Append visualization patch.

        Performs a postprocessing step for the attractor keys 'fixpoint' and
        'endpoint' if 'to_plot' is in the corresponding entry of vis_kwargs:
        the data variable with the highest valued datapoint is determined
        and the patch is appended to the list belonging to that variable.

        Args:
            attractor_key (str): Key according to which to decode the
                attractor
            attractor (xarray.Dataset): The Dataset with the encoded
                attractor information.
            vis_patches (dict): The map of attractor_key to List[Rectangle]
                where to append the new patch
            vis_kwargs (dict): The visualization kwargs
            **resolve_rectangle_args: Passed on to resolve_rectangle

        Raises:
            ValueError: Bad postprocess key, i.e. the data variable with the
                highest value could not be determined or has no entry in
                ``vis_patches``.

        Returns:
            dict: The new map of attractor_key to List[Rectangle]
        """
        # Depending on the kind of attractor, add different patches
        kwargs = vis_kwargs.get(attractor_key)
        if kwargs is None:
            return vis_patches

        # Postprocess fixpoint and to_plot, append rectangle
        if attractor_key in ("fixpoint", "endpoint") and "to_plot" in kwargs:
            max_value = -np.inf
            max_name = None
            for data_var_name, data_var in attractor.data_vars.items():
                if data_var.max() > max_value:
                    max_value = data_var.max()
                    max_name = data_var_name

            if max_name is None:
                raise ValueError(
                    "Could not determine the data variable with the highest "
                    f"value for attractor key '{attractor_key}'!"
                )

            attractor_var_key = attractor_key + "_" + max_name
            if attractor_var_key not in vis_patches:
                raise ValueError(
                    f"Bad postprocess key '{attractor_var_key}'! Available "
                    f"keys: {', '.join(vis_patches)}"
                )

            rect = resolve_rectangle(**resolve_rectangle_args)
            vis_patches[attractor_var_key].append(rect)

        # Append rectangle
        elif vis_patches.get(attractor_key) is not None:
            rect = resolve_rectangle(**resolve_rectangle_args)
            vis_patches[attractor_key].append(rect)

        return vis_patches

    def append_plot_attractor(
        attractor_key: str,
        attractor: xr.Dataset,
        *,
        coord: float = None,
        scatter_kwargs: list,
        **plot_kwargs,
    ):
        """Resolves how to plot attractor of specified type at specific
        bifurcation parameter value.

        Args:
            attractor_key (str): Key according to which to decode the
                attractor
            attractor (xarray.Dataset): The Dataset with the encoded
                attractor information. See possible encodings
            coord (float, optional): The bifurcation parameter's coordinate,
                if None it is derived from the attractor's coordinates
            scatter_kwargs (list): The list of scatter datasets where to
                append the new scatter.
            plot_kwargs (dict, optional): The kwargs used to specify
                ax.scatter where the entries match the attractor.data_vars

        Possible encodings, i.e. values for ``attractor_key`` (of type
        :py:class:`xarray.Dataset` or a compatible type):

            - ``fixpoint``: dataset with dimensions ``()``
            - ``scatter``: dataset with dimensions ``(time: >=1)``
            - ``multistability``: dataset with dimensions
              ``(<initial_state>: >= 1)``
            - ``oscillation``: dataset with dimensions ``(osc: 2)``, the
              minimum and maximum

        .. note::

            The attractor must contain the bifurcation parameter coordinate
            if ``coord`` is not given.

        Raises:
            ValueError: Unknown ``attractor_key`` or no bifurcation
                coordinate available

        Returns:
            list: The new list of scatter datasets
        """
        # Get the bifurcation parameter coordinate. Need an explicit None
        # check here, because 0 is a perfectly valid coordinate value.
        if coord is None:
            try:
                coord = attractor[dim]

            except KeyError as err:
                raise ValueError(
                    f"No bifurcation parameter coordinate '{dim}' "
                    "could be found! Either have it as a "
                    "coordinate in 'attractor' or pass it to "
                    "'plot_attractor' explicitly."
                ) from err

        # Resolve the scatter kwargs depending on attractor key
        if attractor_key in ("fixpoint", "endpoint", "multistability"):
            for data_var_name, data_var in attractor.data_vars.items():
                # Drop NaN entries. A comparison against np.nan would always
                # evaluate to True, thus need notnull() here.
                data_var = data_var.where(data_var.notnull(), drop=True)

                entries = len(data_var) if data_var.shape else 1
                scatter_kwargs.append(
                    dict(
                        x=[coord] * entries,
                        y=data_var,
                        **plot_kwargs.get(data_var_name, {}),
                    )
                )

        elif attractor_key in ("scatter", "oscillation"):
            for data_var_name, data_var in attractor.data_vars.items():
                var_kwargs = plot_kwargs.get(data_var_name, {})

                # With a colormap, scatter points are colored by time
                extra = {}
                if attractor_key == "scatter" and "cmap" in var_kwargs:
                    extra["c"] = attractor["time"]

                scatter_kwargs.append(
                    dict(
                        x=[coord] * len(data_var.data),
                        y=data_var,
                        **extra,
                        **var_kwargs,
                    )
                )

        elif attractor_key:
            raise ValueError(
                f"Invalid attractor-key '{attractor_key}'! "
                "Available keys: 'endpoint', 'fixpoint',"
                " 'multistability', 'scatter', 'oscillation'."
            )

        return scatter_kwargs

    def resolve_scatter(
        data: xr.Dataset, *, spin_up_time: int = 0, **kwargs
    ) -> tuple:
        """A mock analysis function to plot all times larger than a spin up
        time.

        Args:
            data (xarray.Dataset): The dataset to analyse
            spin_up_time (int, optional): The spin-up-time
            **kwargs: Ignored

        Returns:
            tuple: Always conclusive, i.e. ``True``, and the data with
                ``time >= spin_up_time``.
        """
        return True, data.where(data.time >= spin_up_time, drop=True)

    # .........................................................................
    # Check argument values
    if not dim and not dims:
        raise ValueError(
            "No dim (str) or dims (Tuple[str, str]) specified. "
            "Use dim for a 1d-bifurcation diagram and dims for a "
            "2d-bifurcation diagram."
        )

    if dim and dims:
        raise ValueError(
            f"dim='{dim}' and dims='{dims}' specified. "
            "Use either dim for a 1d-bifurcation diagram or dims "
            "for a 2d-bifurcation diagram."
        )

    if dims is not None and len(dims) != 2:
        raise ValueError(
            f"Argument dims should be of length 2, but was: {dims}"
        )

    # TODO In the future, consider not using `dim` below here but handling it
    #      via the length of `dims`.

    # Default values. The configuration dicts are adjusted below (labels are
    # popped, default linewidths are set); work on copies such that the
    # caller's objects stay untouched.
    visualization_kwargs = (
        copy.deepcopy(visualization_kwargs) if visualization_kwargs else {}
    )
    to_plot = copy.deepcopy(to_plot) if to_plot else to_plot

    if analysis_kwargs is None:
        analysis_kwargs = {}

    # Resolve legend handles and visualization kwargs
    data_vars_plot_kwargs = resolve_to_plot_kwargs(to_plot)
    (
        (legend_handles, legend_labels),
        data_vars_plot_kwargs,
    ) = create_legend_handles(
        visualization_kwargs=visualization_kwargs,
        data_vars_plot_kwargs=data_vars_plot_kwargs,
    )

    # Define default analysis functions
    analysis_funcs = dict(
        endpoint=utdp.find_endpoint,
        fixpoint=utdp.find_fixpoint,
        multistability=utdp.find_multistability,
        oscillation=utdp.find_oscillation,
        scatter=resolve_scatter,
    )

    # If given, update
    if custom_analysis_funcs:
        log.debug("Updating with custom analysis functions ...")
        analysis_funcs = recursive_update(
            analysis_funcs, custom_analysis_funcs
        )

    analysis_steps = resolve_analysis_steps(analysis_steps)

    # Obtain the rectangles covering space spanned by the coordinates
    rectangle_map_kwargs = kwargs.get("rectangle_map_kwargs", {})
    if dim:
        rects, limits = calc_pxmap_rectangles(
            x_coords=mv_data[dim].values, y_coords=None, **rectangle_map_kwargs
        )

    elif dims:
        rects, limits = calc_pxmap_rectangles(
            x_coords=mv_data[dims[0]].values,
            y_coords=mv_data[dims[1]].values,
            **rectangle_map_kwargs,
        )

    # Obtain the list of param_coords to iterate
    if dim:
        param_iter = mv_data[dim].values

    elif dims:
        param_iter = itertools.product(
            mv_data[dims[0]].values, mv_data[dims[1]].values
        )

    # Map of analysis_key to list[mpatch.Rectangle]
    vis_patches = {}
    for analysis_key, _ in analysis_steps:
        if not visualization_kwargs.get(analysis_key):
            continue

        if "to_plot" in visualization_kwargs[analysis_key]:
            for var_key in visualization_kwargs[analysis_key]["to_plot"]:
                vis_patches[analysis_key + "_" + var_key] = []

        else:
            vis_patches[analysis_key] = []

    # The List[dict] passed to ax.scatter
    scatter_kwargs = []
    scatter_coords_kwargs = []

    # Resolve the kwargs for marking the universes' coordinates once, without
    # touching the passed dict: the y-position for the 1d case has to be the
    # same for all coordinates, not only for the first one.
    plot_coords_kwargs = kwargs.get("plot_coords_kwargs")
    if plot_coords_kwargs:
        coords_1d_kwargs = dict(plot_coords_kwargs)
        coords_1d_y = coords_1d_kwargs.pop("y", 0.0)

    # Iterate the parameter coordinates
    for param_coord in param_iter:
        # Resolve the param_coord to dict
        if dim:
            param_coord = {dim: param_coord}

        elif dims:
            param_coord = {dims[0]: param_coord[0], dims[1]: param_coord[1]}

        # Plot coord
        if dim and plot_coords_kwargs:
            scatter_coords_kwargs.append(
                {
                    "x": param_coord[dim],
                    "y": coords_1d_y,
                    **coords_1d_kwargs,
                }
            )

        if dims and plot_coords_kwargs:
            scatter_coords_kwargs.append(
                {
                    "x": param_coord[dims[0]],
                    "y": param_coord[dims[1]],
                    **plot_coords_kwargs,
                }
            )

        # Select the data and analyse
        data = mv_data.sel(param_coord)
        attractor_key, attractor = apply_analysis_steps(
            data,
            analysis_steps,
            analysis_funcs=analysis_funcs,
            analysis_kwargs=analysis_kwargs,
        )

        # If conclusive, append a rectangular patch to the attractor_key's
        # patch collection
        if attractor_key:
            # Determine coordinate value
            if dim:
                y = rectangle_map_kwargs.get("default_pos", (0.0, 0.0))[1]
                coord = dict(x=param_coord[dim], y=y)

            elif dims:
                coord = dict(x=param_coord[dims[0]], y=param_coord[dims[1]])

            vis_patches = append_vis_patch(
                attractor_key,
                attractor,
                vis_patches,
                visualization_kwargs,
                coord=coord,
                rectangles=rects,
            )

        # For 1d case ... (is a no-op for a non-conclusive analysis)
        if dim and to_plot:
            scatter_kwargs = append_plot_attractor(
                attractor_key,
                attractor,
                coord=param_coord[dim],
                scatter_kwargs=scatter_kwargs,
                **data_vars_plot_kwargs,
            )

    # Draw collection of visualization patches
    for analysis_key, _ in analysis_steps:
        vis_kwargs = visualization_kwargs.get(analysis_key)
        if not vis_kwargs:
            # Nothing to do
            continue

        if "to_plot" in vis_kwargs:
            for var_key, var_kwargs in vis_kwargs["to_plot"].items():
                attractor_var_key = analysis_key + "_" + var_key
                pc = PatchCollection(
                    vis_patches[attractor_var_key], **var_kwargs
                )
                hlpr.ax.add_collection(pc)

        else:
            pc = PatchCollection(vis_patches[analysis_key], **vis_kwargs)
            hlpr.ax.add_collection(pc)

    # Scatter the universe's coordinates
    for kws in scatter_coords_kwargs:
        hlpr.ax.scatter(**kws)

    # Scatter the attractor
    for kws in scatter_kwargs:
        hlpr.ax.scatter(**kws)

    # Provide PlotHelper defaults
    hlpr.provide_defaults("set_limits", **limits)

    if dim:
        hlpr.provide_defaults("set_labels", x=dim, y="state")

    elif dims:
        hlpr.provide_defaults("set_labels", x=dims[0], y=dims[1])

    # Only draw a legend if there is something to show in it
    if legend_handles:
        hlpr.ax.legend(
            legend_handles,
            legend_labels,
            handler_map={Circle: HandlerEllipse()},
            **kwargs.get("legend_kwargs", {}),
        )