Source code for utopya.eval.plots.snsplot

"""Implements seaborn-based plotting functions"""

import copy
import logging
from typing import Hashable, List, Sequence, Tuple, Union

import pandas as pd
import seaborn as sns
import xarray as xr
from dantro.exceptions import PlottingError
from dantro.plot.funcs.generic import (
    determine_encoding,
    figure_leak_prevention,
    make_facet_grid_plot,
)

from .. import PlotHelper, is_plot_func

# Module-level logger, named after this module's import path.
# NOTE(review): the plot function below calls ``log.note`` and ``log.remark``,
# which are not methods of the standard-library Logger; presumably they are
# added by the dantro/utopya logging setup -- confirm before using this
# module standalone.
log = logging.getLogger(__name__)


# .. Seaborn's figure-level plot functions ....................................
# Maps each supported ``sns_kind`` value to the seaborn function of the same
# name. The insertion order is kept because it determines the order in which
# the available kinds are listed in error messages.
SNS_PLOT_FUNCS = {
    _sns_name: getattr(sns, _sns_name)
    for _sns_name in (
        "relplot",
        "displot",
        "catplot",
        "lmplot",
        "clustermap",
        "pairplot",
        "jointplot",
    )
}

# Names of those plot kinds that are treated as FacetGrid-based
SNS_FACETGRID_KINDS = ("relplot", "displot", "catplot", "lmplot")

# .. Encodings for seaborn's figure-level plot functions ......................
# TODO Check if all are correct
# Per plot kind: the encoding names handed to ``determine_encoding`` as the
# default encodings, in the order given here.
SNS_ENCODINGS = dict(
    # FacetGrid: Distributions
    displot=("col", "row", "hue"),
    catplot=("y", "hue", "col", "row"),
    # FacetGrid: Relational
    relplot=("x", "y", "hue", "col", "row", "style", "size"),
    lmplot=("x", "y", "hue", "col", "row"),
    # Others
    clustermap=("hue", "col", "row"),
    pairplot=("hue",),
    jointplot=("x", "y", "hue"),
)


# -----------------------------------------------------------------------------


@is_plot_func(use_dag=True, required_dag_tags=("data",))
def snsplot(
    *,
    data: dict,
    hlpr: PlotHelper,
    sns_kind: str,
    free_indices: Tuple[str, ...],
    optional_free_indices: Tuple[str, ...] = (),
    auto_encoding: Union[bool, dict] = None,
    reset_index: bool = False,
    to_dataframe_kwargs: dict = None,
    dropna: bool = False,
    dropna_kwargs: dict = None,
    sample: Union[bool, int] = False,
    sample_kwargs: dict = None,
    **plot_kwargs,
) -> None:
    """Interface to seaborn's figure-level plot functions.

    Plot functions are selected via the ``sns_kind`` argument:

        - ``relplot``: :py:func:`seaborn.relplot`
        - ``displot``: :py:func:`seaborn.displot`
        - ``catplot``: :py:func:`seaborn.catplot`
        - ``lmplot``: :py:func:`seaborn.lmplot`
        - ``clustermap``: :py:func:`seaborn.clustermap` *(not faceting)*
        - ``pairplot``: :py:func:`seaborn.pairplot` *(not faceting)*
        - ``jointplot``: :py:func:`seaborn.jointplot` *(not faceting)*

    The data is (optionally) converted to a data frame, re-indexed, cleaned
    of null values and sub-sampled, in that order, before the encoding is
    determined and the selected seaborn function is invoked. The figure
    created by seaborn is then attached to the plot helper.

    Args:
        data (dict): The data transformation framework results, expecting a
            single entry ``data`` which can be a :py:class:`pandas.DataFrame`
            or an :py:class:`xarray.DataArray` or :py:class:`xarray.Dataset`.
        hlpr (PlotHelper): The plot helper instance
        sns_kind (str): Which seaborn plot to use, see list above.
        free_indices (Tuple[str]): Which index names *not* to associate with
            a layout encoding; seaborn uses these to calculate the
            distribution statistics. The passed object is not modified.
        optional_free_indices (Tuple[str], optional): These indices will be
            added to the free indices *if they are part of the data frame*.
            Otherwise, they are silently ignored.
        auto_encoding (Union[bool, dict], optional): Auto-encoding options.
        reset_index (bool, optional): Whether to reset indices such that only
            the ``free_indices`` remain as indices and all others are
            converted into columns.
        to_dataframe_kwargs (dict, optional): For xarray data types, this is
            used to convert the given data into a pandas.DataFrame.
        dropna (bool, optional): If True, null values are dropped from the
            data frame via :py:meth:`pandas.DataFrame.dropna`.
        dropna_kwargs (dict, optional): Passed to
            :py:meth:`pandas.DataFrame.dropna`.
        sample (Union[bool, int], optional): If True, will sample a subset
            from the final dataframe, controlled by ``sample_kwargs``. If an
            integer that is smaller than the length of the data frame, it is
            used as the ``n`` argument for sampling, overwriting a
            potentially existing entry from ``sample_kwargs``.
        sample_kwargs (dict, optional): Passed to
            :py:meth:`pandas.DataFrame.sample`. The passed object is not
            modified.
        **plot_kwargs: Passed on to the selected plotting function.

    Raises:
        ValueError: If ``sns_kind`` is not a key of ``SNS_PLOT_FUNCS``.
        PlottingError: If the invoked seaborn function raised an exception.
    """
    df = data["data"]

    # For xarray types, attempt conversion
    if isinstance(df, (xr.Dataset, xr.DataArray)):
        tdf_kwargs = to_dataframe_kwargs if to_dataframe_kwargs else {}
        log.note("Attempting conversion to pd.DataFrame ...")
        log.remark(
            " Arguments: %s",
            ", ".join(f"{k}: {v}" for k, v in tdf_kwargs.items()),
        )
        df = df.to_dataframe(**tdf_kwargs)

    # Re-index to get long-form data
    # See: https://seaborn.pydata.org/tutorial/data_structure.html
    log.note("Evaluating data frame ...")
    log.remark(" Length: %d", len(df))
    log.remark(" Shape: %s", df.shape)
    log.remark(" Size: %d", df.size)

    # The joins below fail with a TypeError for non-string labels (e.g.
    # unnamed index levels, which are None) and with an AttributeError if the
    # object has no such attribute; only these are handled, such that
    # unrelated errors (and KeyboardInterrupt) are no longer swallowed.
    try:
        log.remark(" Columns: %s", ", ".join(df.columns))
    except (TypeError, AttributeError):
        log.remark(" Columns: (none)")

    try:
        log.remark(" Indices: %s", ", ".join(df.index.names))
    except (TypeError, AttributeError):
        log.remark(" Indices: (no named indices)")

    log.remark(" Free indices: %s", ", ".join(free_indices))
    log.remark(" Optionally free: %s", ", ".join(optional_free_indices))

    # TODO Add an option to make all indices free, excluding some ...

    # Apply optionally free indices.
    # Build a *new* list here: an in-place ``+=`` would fail for tuples (which
    # is the annotated type) and would modify the caller's object for lists.
    free_indices = list(free_indices) + [
        n for n in optional_free_indices if n in df.index.names
    ]

    # For some kinds, it makes sense to re-index such that only the free
    # indices are used as columns
    if reset_index:
        reset_for = [n for n in df.index.names if n not in free_indices]
        if reset_for:
            df = df.reset_index(level=reset_for)
            log.remark(" Reset index for: %s", ", ".join(reset_for))

    # Might want to drop null values
    if dropna:
        dropna_kwargs = dropna_kwargs if dropna_kwargs else {}
        log.note("Dropping null values ...")
        log.remark(
            " Arguments: %s",
            ", ".join(f"{k}: {v}" for k, v in dropna_kwargs.items()),
        )
        df = df.dropna(**dropna_kwargs)
        log.remark(" Length after drop: %d", len(df))

    # Sampling
    if sample:
        # Work on a copy, such that the caller's dict is never modified
        sample_kwargs = dict(sample_kwargs) if sample_kwargs else {}

        # bool is a subclass of int; ``sample=True`` must not be interpreted
        # as a request for ``n=True`` but leave ``sample_kwargs`` in control.
        if (
            isinstance(sample, int)
            and not isinstance(sample, bool)
            and sample < len(df)
        ):
            sample_kwargs["n"] = sample

        if sample_kwargs:
            log.note("Sampling from data frame ...")
            log.remark(
                " Arguments: %s",
                ", ".join(f"{k}: {v}" for k, v in sample_kwargs.items()),
            )
            len_before = len(df)

            # Sampling is best-effort: on failure, the error is logged and
            # the unsampled data frame is used for plotting.
            try:
                df = df.sample(**sample_kwargs)

            except Exception as exc:
                log.error(
                    " Sampling failed with %s: %s", type(exc).__name__, exc
                )

            else:
                log.remark(
                    " Sampling succeeded. New length: %d (%d)",
                    len(df),
                    len(df) - len_before,
                )

        else:
            log.note("Sampling skipped (no arguments applicable).")

    # ... further preprocessing ...

    # Interface with auto-encoding
    # Need to pop any given `kind` argument (valid input to sns.pairplot)
    kind = plot_kwargs.pop("kind", None)

    # Only the non-free index levels (name -> size) are offered for encoding.
    # A non-hierarchical index has no ``levshape``, thus the fallback.
    plot_kwargs = determine_encoding(
        {
            n: s
            for n, s in zip(
                df.index.names, getattr(df.index, "levshape", [len(df.index)])
            )
            if n not in free_indices
        },
        kind=sns_kind,
        auto_encoding=auto_encoding,
        default_encodings=SNS_ENCODINGS,
        plot_kwargs=plot_kwargs,
    )
    if kind is not None:
        plot_kwargs["kind"] = kind

    # Depending on plot kinds, determine some further arguments.
    # This needs to check ``sns_kind`` (the seaborn function name), not the
    # ``kind`` keyword argument popped above, which is never one of these.
    if sns_kind in SNS_FACETGRID_KINDS:
        # Provide a best guess for the `x` encoding, if it is not given
        if "x" not in plot_kwargs and len(df.columns) == 1:
            x = str(df.columns[0])
            log.note("Using '%s' for x-axis encoding.", x)
            plot_kwargs["x"] = x

    # Retrieve the plot function
    try:
        plot_func = SNS_PLOT_FUNCS[sns_kind]

    except KeyError:
        _avail = ", ".join(SNS_PLOT_FUNCS)
        raise ValueError(
            f"Invalid plot kind '{sns_kind}'! Available: {_avail}"
        )

    # Actual plotting . . . . . . . . . . . . . . . . . . . . . . . . . . . . .
    # Close the existing figure; the seaborn functions create their own
    hlpr.close_figure()

    # Let seaborn do the plotting
    log.note("Now invoking sns.%s ...", sns_kind)

    try:
        with figure_leak_prevention():
            fg = plot_func(data=df, **plot_kwargs)

    except Exception as exc:
        raise PlottingError(
            f"sns.{sns_kind} failed! Got {type(exc).__name__}: {exc}\n\n"
            f"Data was:\n{df}\n\n"
            f"Plot function arguments were:\n {plot_kwargs}"
        ) from exc

    # Attach the created figure, including a workaround for `col_wrap`, in
    # which case `fg.axes` is one-dimensional (for whatever reason)
    if isinstance(fg, sns.JointGrid):
        fig = fg.fig
        axes = [[fg.ax_joint]]
        # TODO consider registering all axes

    else:
        # Assume it's FacetGrid-like
        fig = fg.fig
        axes = fg.axes
        if axes.ndim != 2:
            axes = axes.reshape((fg._nrow, fg._ncol))

    hlpr.attach_figure_and_axes(fig=fig, axes=axes)
# TODO Animation?!