AIM-PIbd-32-Kurbanova-A-A/aimenv/Lib/site-packages/seaborn/axisgrid.py

2402 lines
86 KiB
Python
Raw Normal View History

2024-10-02 22:15:59 +04:00
from __future__ import annotations
from itertools import product
from inspect import signature
import warnings
from textwrap import dedent
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from ._base import VectorPlotter, variable_type, categorical_order
from ._core.data import handle_data_source
from ._compat import share_axis, get_legend_handles
from . import utils
from .utils import (
adjust_legend_subtitles,
set_hls_values,
_check_argument,
_draw_figure,
_disable_autolayout
)
from .palettes import color_palette, blend_palette
from ._docstrings import (
DocstringComponents,
_core_docs,
)
__all__ = ["FacetGrid", "PairGrid", "JointGrid", "pairplot", "jointplot"]
_param_docs = DocstringComponents.from_nested_components(
core=_core_docs["params"],
)
class _BaseGrid:
    """Behaviour shared by every grid-of-subplots class in this module."""

    def set(self, **kwargs):
        """Call ``Axes.set(**kwargs)`` on every subplot and return self."""
        # ``None`` entries mark axes that were removed from the grid
        live_axes = (ax for ax in self.axes.flat if ax is not None)
        for ax in live_axes:
            ax.set(**kwargs)
        return self

    @property
    def fig(self):
        """DEPRECATED: prefer the `figure` property."""
        # Grid.figure is preferred because it matches the Axes attribute name.
        # The maintenance burden of keeping this alias is minimal, so the
        # deprecation is deliberately slow: for now it is only noted in the
        # docstring; a warning is planned for version 0.13, then removal.
        return self._figure

    @property
    def figure(self):
        """Access the :class:`matplotlib.figure.Figure` object underlying the grid."""
        return self._figure

    def apply(self, func, *args, **kwargs):
        """
        Call ``func(self, *args, **kwargs)`` and return self.

        The `func` must accept an object of this type as its first
        positional argument; extra arguments are forwarded unchanged.
        Whatever `func` returns is discarded. Use `pipe` when the return
        value is needed.

        Added in v0.12.0.

        """
        func(self, *args, **kwargs)
        return self

    def pipe(self, func, *args, **kwargs):
        """
        Call ``func(self, *args, **kwargs)`` and return its result.

        The `func` must accept an object of this type as its first
        positional argument; extra arguments are forwarded unchanged.
        Use `apply` when the grid itself should be returned instead.

        Added in v0.12.0.

        """
        result = func(self, *args, **kwargs)
        return result

    def savefig(self, *args, **kwargs):
        """
        Save an image of the plot.

        Thin wrapper around :meth:`matplotlib.figure.Figure.savefig` that
        defaults to bbox_inches="tight". All other parameters are passed
        through untouched.

        """
        options = dict(kwargs)
        if "bbox_inches" not in options:
            options["bbox_inches"] = "tight"
        self.figure.savefig(*args, **options)
class Grid(_BaseGrid):
    """A grid that can have multiple subplots and an external legend."""
    # Class-level defaults read by ``add_legend``; ``FacetGrid.__init__``
    # overwrites both on the instance.
    # _margin_titles: selects the wider (.04 vs .01) right margin for the legend.
    _margin_titles = False
    # _legend_out: draw the legend on the figure (True) or in the first axes.
    _legend_out = True
def __init__(self):
    """Initialize the layout and legend bookkeeping shared by all grids."""
    # HACK: this flag is flipped from outside the class to support newer
    # functions that don't add proxy artists onto the Axes. A cleaner
    # overall approach is still needed.
    self._extract_legend_handles = False

    # Defaults consumed by ``tight_layout``; ``add_legend`` later shrinks
    # the right edge of the rect to leave room for an outside legend.
    self._tight_layout_pad = None
    self._tight_layout_rect = [0, 0, 1, 1]
def tight_layout(self, *args, **kwargs):
    """Call ``Figure.tight_layout`` within the rect that excludes the legend.

    ``rect`` defaults to ``self._tight_layout_rect`` and ``pad`` defaults to
    ``self._tight_layout_pad`` (when that is not None); explicit keyword
    arguments take precedence. Returns self.
    """
    layout_kws = dict(kwargs)
    layout_kws.setdefault("rect", self._tight_layout_rect)
    pad = self._tight_layout_pad
    if pad is not None:
        layout_kws.setdefault("pad", pad)
    self._figure.tight_layout(*args, **layout_kws)
    return self
def add_legend(self, legend_data=None, title=None, label_order=None,
               adjust_subtitles=False, **kwargs):
    """Draw a legend, maybe placing it outside axes and resizing the figure.

    When ``self._legend_out`` is true the legend is attached to the figure,
    the figure is widened by the legend's width, and the subplots are
    squeezed to the left so the legend does not overlap them. Otherwise the
    legend is drawn inside the first axes of the grid.

    Parameters
    ----------
    legend_data : dict
        Dictionary mapping label names (or two-element tuples where the
        second element is a label name) to matplotlib artist handles. The
        default reads from ``self._legend_data``.
    title : string
        Title for the legend. The default reads from ``self._hue_var``.
    label_order : list of labels
        The order that the legend entries should appear in. The default
        reads from ``self.hue_names``.
    adjust_subtitles : bool
        If True, modify entries with invisible artists to left-align
        the labels and set the font size to that of a title.
    kwargs : key, value pairings
        Other keyword arguments are passed to the underlying legend methods
        on the Figure or Axes object.

    Returns
    -------
    self : Grid instance
        Returns self for easy chaining.

    """
    # Find the data for the legend
    if legend_data is None:
        legend_data = self._legend_data
    if label_order is None:
        if self.hue_names is None:
            label_order = list(legend_data.keys())
        else:
            label_order = list(map(utils.to_utf8, self.hue_names))

    # Labels with no recorded artist get a fully transparent placeholder so
    # that ``handles`` stays aligned with ``label_order``
    blank_handle = mpl.patches.Patch(alpha=0, linewidth=0)
    handles = [legend_data.get(lab, blank_handle) for lab in label_order]
    title = self._hue_var if title is None else title
    title_size = mpl.rcParams["legend.title_fontsize"]

    # Unpack nested labels from a hierarchical legend
    labels = []
    for entry in label_order:
        if isinstance(entry, tuple):
            _, label = entry
        else:
            label = entry
        labels.append(label)

    # Set default legend kwargs
    kwargs.setdefault("scatterpoints", 1)

    if self._legend_out:

        kwargs.setdefault("frameon", False)
        kwargs.setdefault("loc", "center right")

        # Draw a full-figure legend outside the grid
        figlegend = self._figure.legend(handles, labels, **kwargs)

        self._legend = figlegend
        figlegend.set_title(title, prop={"size": title_size})

        if adjust_subtitles:
            adjust_legend_subtitles(figlegend)

        # Draw the plot to set the bounding boxes correctly
        _draw_figure(self._figure)

        # Calculate and set the new width of the figure so the legend fits
        # (window extent is divided by dpi to convert it to inches)
        legend_width = figlegend.get_window_extent().width / self._figure.dpi
        fig_width, fig_height = self._figure.get_size_inches()
        self._figure.set_size_inches(fig_width + legend_width, fig_height)

        # Draw the plot again to get the new transformations
        _draw_figure(self._figure)

        # Now calculate how much space we need on the right side
        # (as a fraction of the widened figure, plus a small margin)
        legend_width = figlegend.get_window_extent().width / self._figure.dpi
        space_needed = legend_width / (fig_width + legend_width)
        margin = .04 if self._margin_titles else .01
        self._space_needed = margin + space_needed
        right = 1 - self._space_needed

        # Place the subplot axes to give space for the legend
        self._figure.subplots_adjust(right=right)
        # Remember the boundary so later ``tight_layout`` calls respect it
        self._tight_layout_rect[2] = right

    else:
        # Draw a legend in the first axis
        ax = self.axes.flat[0]
        kwargs.setdefault("loc", "best")

        leg = ax.legend(handles, labels, **kwargs)
        leg.set_title(title, prop={"size": title_size})
        self._legend = leg

        if adjust_subtitles:
            adjust_legend_subtitles(leg)

    return self
def _update_legend_data(self, ax):
    """Merge the label -> handle pairs found on ``ax`` into ``self._legend_data``.

    After harvesting, the legend attached to ``ax`` is discarded by
    resetting ``ax.legend_`` to None.
    """
    found = {}

    # Newer functions don't add labeled proxy artists, so when asked to,
    # read the entries straight off an existing legend object first.
    existing = ax.legend_
    if existing is not None and self._extract_legend_handles:
        legend_handles = get_legend_handles(existing)
        legend_labels = [text.get_text() for text in existing.texts]
        found.update(zip(legend_labels, legend_handles))

    # Labeled artists on the axes are applied second and win on duplicates
    artist_handles, artist_labels = ax.get_legend_handles_labels()
    found.update(zip(artist_labels, artist_handles))

    self._legend_data.update(found)

    # Now clear the legend
    ax.legend_ = None
def _get_palette(self, data, hue, hue_order, palette):
    """Return the colors to use for the levels of the hue variable.

    Without a ``hue`` variable a one-color palette is returned. Otherwise
    one color per hue level is resolved from ``palette``, which may be
    None, a dict keyed by hue level, or anything ``color_palette`` accepts.
    """
    if hue is None:
        return color_palette(n_colors=1)

    hue_names = categorical_order(data[hue], hue_order)
    n_colors = len(hue_names)

    if palette is None:
        # By default use either the current color palette or HUSL,
        # depending on whether the current cycle has enough colors
        cycle_size = len(utils.get_color_cycle())
        if n_colors > cycle_size:
            colors = color_palette("husl", n_colors)
        else:
            colors = color_palette(n_colors=n_colors)
    elif isinstance(palette, dict):
        # Allow for palette to map from hue variable names
        ordered_colors = [palette[name] for name in hue_names]
        colors = color_palette(ordered_colors, n_colors)
    else:
        # Otherwise act as if we just got a list of colors
        colors = color_palette(palette, n_colors)

    return color_palette(colors, n_colors)
@property
def legend(self):
    """The :class:`matplotlib.legend.Legend` object, if present."""
    # ``_legend`` may never have been assigned; report that as None
    return getattr(self, "_legend", None)
def tick_params(self, axis='both', **kwargs):
    """Modify the ticks, tick labels, and gridlines.

    The settings are applied to every axes attached to the figure.

    Parameters
    ----------
    axis : {'x', 'y', 'both'}
        The axis on which to apply the formatting.
    kwargs : keyword arguments
        Additional keyword arguments to pass to
        :meth:`matplotlib.axes.Axes.tick_params`.

    Returns
    -------
    self : Grid instance
        Returns self for easy chaining.

    """
    params = dict(axis=axis, **kwargs)
    for target in self.figure.axes:
        target.tick_params(**params)
    return self
_facet_docs = dict(
data=dedent("""\
data : DataFrame
Tidy ("long-form") dataframe where each column is a variable and each
row is an observation.\
"""),
rowcol=dedent("""\
row, col : vectors or keys in ``data``
Variables that define subsets to plot on different facets.\
"""),
rowcol_order=dedent("""\
{row,col}_order : vector of strings
Specify the order in which levels of the ``row`` and/or ``col`` variables
appear in the grid of subplots.\
"""),
col_wrap=dedent("""\
col_wrap : int
"Wrap" the column variable at this width, so that the column facets
span multiple rows. Incompatible with a ``row`` facet.\
"""),
share_xy=dedent("""\
share{x,y} : bool, 'col', or 'row', optional
If true, the facets will share y axes across columns and/or x axes
across rows.\
"""),
height=dedent("""\
height : scalar
Height (in inches) of each facet. See also: ``aspect``.\
"""),
aspect=dedent("""\
aspect : scalar
Aspect ratio of each facet, so that ``aspect * height`` gives the width
of each facet in inches.\
"""),
palette=dedent("""\
palette : palette name, list, or dict
Colors to use for the different levels of the ``hue`` variable. Should
be something that can be interpreted by :func:`color_palette`, or a
dictionary mapping hue levels to matplotlib colors.\
"""),
legend_out=dedent("""\
legend_out : bool
If ``True``, the figure size will be extended, and the legend will be
drawn outside the plot on the center right.\
"""),
margin_titles=dedent("""\
margin_titles : bool
If ``True``, the titles for the row variable are drawn to the right of
the last column. This option is experimental and may not work in all
cases.\
"""),
facet_kws=dedent("""\
facet_kws : dict
Additional parameters passed to :class:`FacetGrid`.
"""),
)
class FacetGrid(Grid):
    """Multi-plot grid for plotting conditional relationships.

    Subsets of a dataset, defined by the levels of optional ``row``, ``col``
    and ``hue`` variables, are drawn by the ``map`` / ``map_dataframe``
    methods onto a grid of matplotlib axes created in ``__init__``.
    """
def __init__(
    self, data, *,
    row=None, col=None, hue=None, col_wrap=None,
    sharex=True, sharey=True, height=3, aspect=1, palette=None,
    row_order=None, col_order=None, hue_order=None, hue_kws=None,
    dropna=False, legend_out=True, despine=True,
    margin_titles=False, xlim=None, ylim=None, subplot_kws=None,
    gridspec_kws=None,
):
    # NOTE: the user-facing docstring for this method is attached right
    # after the definition via ``__init__.__doc__``.
    #
    # Raises ValueError when ``row`` and ``col_wrap`` are both given.

    super().__init__()
    data = handle_data_source(data)

    # Determine the hue facet layer information
    hue_var = hue
    if hue is None:
        hue_names = None
    else:
        hue_names = categorical_order(data[hue], hue_order)

    colors = self._get_palette(data, hue, hue_order, palette)

    # Set up the lists of names for the row and column facet variables
    if row is None:
        row_names = []
    else:
        row_names = categorical_order(data[row], row_order)

    if col is None:
        col_names = []
    else:
        col_names = categorical_order(data[col], col_order)

    # Additional dict of kwarg -> list of values for mapping the hue var
    hue_kws = hue_kws if hue_kws is not None else {}

    # Make a boolean mask that is True anywhere there is an NA
    # value in one of the faceting variables, but only if dropna is True
    none_na = np.zeros(len(data), bool)
    if dropna:
        row_na = none_na if row is None else data[row].isnull()
        col_na = none_na if col is None else data[col].isnull()
        hue_na = none_na if hue is None else data[hue].isnull()
        not_na = ~(row_na | col_na | hue_na)
    else:
        # Nothing is excluded: the mask is all True
        not_na = ~none_na

    # Compute the grid shape
    ncol = 1 if col is None else len(col_names)
    nrow = 1 if row is None else len(row_names)
    # The facet count is taken before any wrapping reshapes the grid; the
    # wrapped-layout helpers use it to work out how many cells are empty
    self._n_facets = ncol * nrow

    self._col_wrap = col_wrap
    if col_wrap is not None:
        if row is not None:
            err = "Cannot use `row` and `col_wrap` together."
            raise ValueError(err)
        ncol = col_wrap
        nrow = int(np.ceil(len(col_names) / col_wrap))
    self._ncol = ncol
    self._nrow = nrow

    # Calculate the base figure size
    # This can get stretched later by a legend
    # TODO this doesn't account for axis labels
    figsize = (ncol * height * aspect, nrow * height)

    # Validate some inputs
    # (margin titles are switched off whenever the columns are wrapped)
    if col_wrap is not None:
        margin_titles = False

    # Build the subplot keyword dictionary
    # (copies, so the caller's dictionaries are never modified)
    subplot_kws = {} if subplot_kws is None else subplot_kws.copy()
    gridspec_kws = {} if gridspec_kws is None else gridspec_kws.copy()
    if xlim is not None:
        subplot_kws["xlim"] = xlim
    if ylim is not None:
        subplot_kws["ylim"] = ylim

    # --- Initialize the subplot grid

    with _disable_autolayout():
        fig = plt.figure(figsize=figsize)

    if col_wrap is None:

        kwargs = dict(squeeze=False,
                      sharex=sharex, sharey=sharey,
                      subplot_kw=subplot_kws,
                      gridspec_kw=gridspec_kws)

        axes = fig.subplots(nrow, ncol, **kwargs)

        # Map facet level(s) to axes; keys are tuples when both row and
        # col are assigned
        if col is None and row is None:
            axes_dict = {}
        elif col is None:
            axes_dict = dict(zip(row_names, axes.flat))
        elif row is None:
            axes_dict = dict(zip(col_names, axes.flat))
        else:
            facet_product = product(row_names, col_names)
            axes_dict = dict(zip(facet_product, axes.flat))

    else:

        # If wrapping the col variable we need to make the grid ourselves
        if gridspec_kws:
            warnings.warn("`gridspec_kws` ignored when using `col_wrap`")

        # One axes per column level, stored in a flat (1-D) object array
        n_axes = len(col_names)
        axes = np.empty(n_axes, object)
        axes[0] = fig.add_subplot(nrow, ncol, 1, **subplot_kws)
        # Every later subplot shares its axis with the first one
        if sharex:
            subplot_kws["sharex"] = axes[0]
        if sharey:
            subplot_kws["sharey"] = axes[0]
        for i in range(1, n_axes):
            axes[i] = fig.add_subplot(nrow, ncol, i + 1, **subplot_kws)

        axes_dict = dict(zip(col_names, axes))

    # --- Set up the class attributes

    # Attributes that are part of the public API but accessed through
    # a property so that Sphinx adds them to the auto class doc
    self._figure = fig
    self._axes = axes
    self._axes_dict = axes_dict
    self._legend = None

    # Public attributes that aren't explicitly documented
    # (It's not obvious that having them be public was a good idea)
    self.data = data
    self.row_names = row_names
    self.col_names = col_names
    self.hue_names = hue_names
    self.hue_kws = hue_kws

    # Next the private variables
    self._nrow = nrow
    self._row_var = row
    self._ncol = ncol
    self._col_var = col

    self._margin_titles = margin_titles
    self._margin_titles_texts = []
    self._col_wrap = col_wrap
    self._hue_var = hue_var
    self._colors = colors
    self._legend_out = legend_out
    self._legend_data = {}
    self._x_var = None
    self._y_var = None
    self._sharex = sharex
    self._sharey = sharey
    self._dropna = dropna
    self._not_na = not_na

    # --- Make the axes look good

    self.set_titles()
    self.tight_layout()

    if despine:
        self.despine()

    # With shared x axes, hide tick labels, offset text and axis label on
    # every facet that is not on the bottom of the grid
    if sharex in [True, 'col']:
        for ax in self._not_bottom_axes:
            for label in ax.get_xticklabels():
                label.set_visible(False)
            ax.xaxis.offsetText.set_visible(False)
            ax.xaxis.label.set_visible(False)

    # Likewise for shared y axes on facets outside the left column
    if sharey in [True, 'row']:
        for ax in self._not_left_axes:
            for label in ax.get_yticklabels():
                label.set_visible(False)
            ax.yaxis.offsetText.set_visible(False)
            ax.yaxis.label.set_visible(False)
__init__.__doc__ = dedent("""\
Initialize the matplotlib figure and FacetGrid object.
This class maps a dataset onto multiple axes arrayed in a grid of rows
and columns that correspond to *levels* of variables in the dataset.
The plots it produces are often called "lattice", "trellis", or
"small-multiple" graphics.
It can also represent levels of a third variable with the ``hue``
parameter, which plots different subsets of data in different colors.
This uses color to resolve elements on a third dimension, but only
draws subsets on top of each other and will not tailor the ``hue``
parameter for the specific visualization the way that axes-level
functions that accept ``hue`` will.
The basic workflow is to initialize the :class:`FacetGrid` object with
the dataset and the variables that are used to structure the grid. Then
one or more plotting functions can be applied to each subset by calling
:meth:`FacetGrid.map` or :meth:`FacetGrid.map_dataframe`. Finally, the
plot can be tweaked with other methods to do things like change the
axis labels, use different ticks, or add a legend. See the detailed
code examples below for more information.
.. warning::
When using seaborn functions that infer semantic mappings from a
dataset, care must be taken to synchronize those mappings across
facets (e.g., by defining the ``hue`` mapping with a palette dict or
setting the data type of the variables to ``category``). In most cases,
it will be better to use a figure-level function (e.g. :func:`relplot`
or :func:`catplot`) than to use :class:`FacetGrid` directly.
See the :ref:`tutorial <grid_tutorial>` for more information.
Parameters
----------
{data}
row, col, hue : strings
Variables that define subsets of the data, which will be drawn on
separate facets in the grid. See the ``{{var}}_order`` parameters to
control the order of levels of this variable.
{col_wrap}
{share_xy}
{height}
{aspect}
{palette}
{{row,col,hue}}_order : lists
Order for the levels of the faceting variables. By default, this
will be the order that the levels appear in ``data`` or, if the
variables are pandas categoricals, the category order.
hue_kws : dictionary of param -> list of values mapping
Other keyword arguments to insert into the plotting call to let
other plot attributes vary across levels of the hue variable (e.g.
the markers in a scatterplot).
{legend_out}
despine : boolean
Remove the top and right spines from the plots.
{margin_titles}
{{x, y}}lim: tuples
Limits for each of the axes on each facet (only relevant when
share{{x, y}} is True).
subplot_kws : dict
Dictionary of keyword arguments passed to matplotlib subplot(s)
methods.
gridspec_kws : dict
Dictionary of keyword arguments passed to
:class:`matplotlib.gridspec.GridSpec`
(via :meth:`matplotlib.figure.Figure.subplots`).
Ignored if ``col_wrap`` is not ``None``.
See Also
--------
PairGrid : Subplot grid for plotting pairwise relationships
relplot : Combine a relational plot and a :class:`FacetGrid`
displot : Combine a distribution plot and a :class:`FacetGrid`
catplot : Combine a categorical plot and a :class:`FacetGrid`
lmplot : Combine a regression plot and a :class:`FacetGrid`
Examples
--------
.. note::
These examples use seaborn functions to demonstrate some of the
advanced features of the class, but in most cases you will want
to use figure-level functions (e.g. :func:`displot`, :func:`relplot`)
to make the plots shown here.
.. include:: ../docstrings/FacetGrid.rst
""").format(**_facet_docs)
def facet_data(self):
    """Generator for name indices and data subsets for each facet.

    Yields
    ------
    (i, j, k), data_ijk : tuple of ints, DataFrame
        The ints provide an index into the {row, col, hue}_names attribute,
        and the dataframe contains a subset of the full data corresponding
        to each facet. The generator yields subsets that correspond with
        the self.axes.flat iterator, or self.axes[i, j] when `col_wrap`
        is None.

    """
    frame = self.data

    def level_masks(levels, variable):
        # One boolean mask per level, or a single all-True mask when the
        # variable is not used for faceting
        if levels:
            return [frame[variable] == level for level in levels]
        return [np.repeat(True, len(frame))]

    row_masks = level_masks(self.row_names, self._row_var)
    col_masks = level_masks(self.col_names, self._col_var)
    hue_masks = level_masks(self.hue_names, self._hue_var)

    # Walk every (row, col, hue) combination, with hue varying fastest
    combos = product(
        enumerate(row_masks), enumerate(col_masks), enumerate(hue_masks)
    )
    for (i, in_row), (j, in_col), (k, in_hue) in combos:
        keep = in_row & in_col & in_hue & self._not_na
        yield (i, j, k), frame[keep]
def map(self, func, *args, **kwargs):
    """Apply a plotting function to each facet's subset of the data.

    Parameters
    ----------
    func : callable
        A plotting function that takes data and keyword arguments. It
        must plot to the currently active matplotlib Axes and take a
        `color` keyword argument. If faceting on the `hue` dimension,
        it must also take a `label` keyword argument.
    args : strings
        Column names in self.data that identify variables with data to
        plot. The data for each variable is passed to `func` in the
        order the variables are specified in the call.
    kwargs : keyword arguments
        All keyword arguments are passed to the plotting function.

    Returns
    -------
    self : object
        Returns self.

    """
    # An explicit color overrides the per-hue palette color
    explicit_color = kwargs.pop("color", None)

    # How the function is called depends on the module it comes from
    func_module = str(getattr(func, "__module__", ""))
    from_matplotlib = func_module.startswith("matplotlib")
    # Only non-seaborn functions need the pyplot "current axes" updated
    modify_state = not func_module.startswith("seaborn")

    # Check for categorical plots without order information
    if func_module == "seaborn.categorical":
        if "order" not in kwargs:
            warning = ("Using the {} function without specifying "
                       "`order` is likely to produce an incorrect "
                       "plot.".format(func.__name__))
            warnings.warn(warning)
        if len(args) == 3 and "hue_order" not in kwargs:
            warning = ("Using the {} function without specifying "
                       "`hue_order` is likely to produce an incorrect "
                       "plot.".format(func.__name__))
            warnings.warn(warning)

    columns = list(args)

    for (row_i, col_j, hue_k), subset in self.facet_data():

        # Nothing to draw for an empty subset
        if not subset.values.size:
            continue

        ax = self.facet_axis(row_i, col_j, modify_state)

        # Per-hue aesthetics: color first, then any extra hue_kws entries
        kwargs["color"] = self._facet_color(hue_k, explicit_color)
        for kw, val_list in self.hue_kws.items():
            kwargs[kw] = val_list[hue_k]

        # The label is what later ends up in the legend
        if self._hue_var is not None:
            kwargs["label"] = utils.to_utf8(self.hue_names[hue_k])

        # Pull out the requested columns, one positional argument each
        plot_data = subset[columns]
        if self._dropna:
            plot_data = plot_data.dropna()
        plot_args = [series for _, series in plot_data.items()]

        # Some matplotlib functions don't handle pandas objects correctly
        if from_matplotlib:
            plot_args = [series.values for series in plot_args]

        self._facet_plot(func, ax, plot_args, kwargs)

    # Finalize the annotations and layout
    self._finalize_grid(args[:2])

    return self
def map_dataframe(self, func, *args, **kwargs):
    """Like ``.map`` but passes args as strings and inserts data in kwargs.

    This method is suitable for plotting with functions that accept a
    long-form DataFrame as a `data` keyword argument and access the
    data in that DataFrame using string variable names.

    Parameters
    ----------
    func : callable
        A plotting function that takes data and keyword arguments. Unlike
        the `map` method, a function used here must "understand" Pandas
        objects. It also must plot to the currently active matplotlib Axes
        and take a `color` keyword argument. If faceting on the `hue`
        dimension, it must also take a `label` keyword argument.
    args : strings
        Column names in self.data that identify variables with data to
        plot. The data for each variable is passed to `func` in the
        order the variables are specified in the call.
    kwargs : keyword arguments
        All keyword arguments are passed to the plotting function.

    Returns
    -------
    self : object
        Returns self.

    """
    # If color was a keyword argument, grab it here
    kw_color = kwargs.pop("color", None)

    # Look the module up defensively, exactly as ``map`` does: callables
    # that expose no ``__module__`` attribute are treated as non-seaborn
    # instead of raising AttributeError. The answer does not change between
    # facets, so compute it once.
    func_module = str(getattr(func, "__module__", ""))
    modify_state = not func_module.startswith("seaborn")

    # Iterate over the data subsets
    for (row_i, col_j, hue_k), data_ijk in self.facet_data():

        # If this subset is null, move on
        if not data_ijk.values.size:
            continue

        # Get the current axis
        ax = self.facet_axis(row_i, col_j, modify_state)

        # Decide what color to plot with
        kwargs["color"] = self._facet_color(hue_k, kw_color)

        # Insert the other hue aesthetics if appropriate
        for kw, val_list in self.hue_kws.items():
            kwargs[kw] = val_list[hue_k]

        # Insert a label in the keyword arguments for the legend
        if self._hue_var is not None:
            kwargs["label"] = self.hue_names[hue_k]

        # Stick the facet dataframe into the kwargs
        if self._dropna:
            data_ijk = data_ijk.dropna()
        kwargs["data"] = data_ijk

        # Draw the plot
        self._facet_plot(func, ax, args, kwargs)

    # For axis labels, prefer to use positional args for backcompat
    # but also extract the x/y kwargs and use if no corresponding arg
    axis_labels = [kwargs.get("x", None), kwargs.get("y", None)]
    for i, val in enumerate(args[:2]):
        axis_labels[i] = val
    self._finalize_grid(axis_labels)

    return self
def _facet_color(self, hue_index, kw_color):
    """Pick the color for one hue level.

    An explicit ``kw_color`` wins; otherwise the entry of ``self._colors``
    at ``hue_index`` is returned (which may itself be None).
    """
    # The palette is always indexed, even when an explicit color is given
    palette_color = self._colors[hue_index]
    return palette_color if kw_color is None else kw_color
def _facet_plot(self, func, ax, plot_args, plot_kwargs):
    """Call ``func`` for one facet and record its legend entries.

    Functions whose module name starts with "seaborn" receive the
    positional ``plot_args`` as the ``x``, ``y``, ``hue``, ``size`` and
    ``style`` keywords (in that order) plus an explicit ``ax``; any other
    callable is invoked as ``func(*plot_args, **plot_kwargs)``.
    """
    # Look the module up defensively, matching ``map``: a callable without
    # a ``__module__`` attribute is treated as a non-seaborn function
    # rather than raising AttributeError.
    func_module = str(getattr(func, "__module__", ""))

    # Draw the plot
    if func_module.startswith("seaborn"):
        # Work on a copy so the caller's kwargs are left untouched
        plot_kwargs = plot_kwargs.copy()
        semantics = ["x", "y", "hue", "size", "style"]
        for key, val in zip(semantics, plot_args):
            plot_kwargs[key] = val
        plot_args = []
        plot_kwargs["ax"] = ax
    func(*plot_args, **plot_kwargs)

    # Sort out the supporting information
    self._update_legend_data(ax)
def _finalize_grid(self, axlabels):
    """Finalize the annotations and layout.

    ``axlabels`` is unpacked positionally into ``set_axis_labels`` (so it
    holds the x label and, optionally, the y label), after which the
    layout is tightened.
    """
    self.set_axis_labels(*axlabels)
    self.tight_layout()
def facet_axis(self, row_i, col_j, modify_state=True):
    """Make the axis identified by these indices active and return it.

    When ``modify_state`` is False the axes is only looked up; pyplot's
    notion of the current axes is left alone.
    """
    if self._col_wrap is None:
        target = self.axes[row_i, col_j]
    else:
        # Wrapped grids store their axes flat, indexed by column level only
        target = self.axes.flat[col_j]

    if modify_state:
        plt.sca(target)
    return target
def despine(self, **kwargs):
    """Remove axis spines from the facets.

    Keyword arguments are forwarded to ``utils.despine``, which is applied
    to the whole figure. Returns self.
    """
    utils.despine(self._figure, **kwargs)
    return self
def set_axis_labels(self, x_var=None, y_var=None, clear_inner=True, **kwargs):
    """Set axis labels on the left column and bottom row of the grid.

    Each label that is not None is remembered on the grid (``_x_var`` /
    ``_y_var``) and applied through ``set_xlabels`` / ``set_ylabels``.
    Returns self.
    """
    requests = (
        ("_x_var", x_var, self.set_xlabels),
        ("_y_var", y_var, self.set_ylabels),
    )
    for attr, label, apply_label in requests:
        if label is None:
            continue
        setattr(self, attr, label)
        apply_label(label, clear_inner=clear_inner, **kwargs)
    return self
def set_xlabels(self, label=None, clear_inner=True, **kwargs):
    """Label the x axis on the bottom row of the grid.

    ``label`` falls back to the stored ``_x_var``. With ``clear_inner``,
    the x label of every facet above the bottom row is blanked. Returns
    self.
    """
    text = self._x_var if label is None else label
    for ax in self._bottom_axes:
        ax.set_xlabel(text, **kwargs)
    if not clear_inner:
        return self
    for ax in self._not_bottom_axes:
        ax.set_xlabel("")
    return self
def set_ylabels(self, label=None, clear_inner=True, **kwargs):
    """Label the y axis on the left column of the grid.

    ``label`` falls back to the stored ``_y_var``. With ``clear_inner``,
    the y label of every facet outside the left column is blanked. Returns
    self.
    """
    text = self._y_var if label is None else label
    for ax in self._left_axes:
        ax.set_ylabel(text, **kwargs)
    if not clear_inner:
        return self
    for ax in self._not_left_axes:
        ax.set_ylabel("")
    return self
def set_xticklabels(self, labels=None, step=None, **kwargs):
    """Set x axis tick labels of the grid.

    With explicit ``labels`` they are applied to every facet. Otherwise the
    current tick label texts are re-applied, keeping only every ``step``-th
    tick when ``step`` is given. ``kwargs`` go to ``Axes.set_xticklabels``.
    Returns self.
    """
    for ax in self.axes.flat:
        # Re-set the current tick positions before assigning labels
        ax.set_xticks(ax.get_xticks())

        if labels is not None:
            ax.set_xticklabels(labels, **kwargs)
            continue

        texts = [tick.get_text() for tick in ax.get_xticklabels()]
        if step is not None:
            # Thin out ticks and labels together so they stay aligned
            thinned = ax.get_xticks()[::step]
            texts = texts[::step]
            ax.set_xticks(thinned)
        ax.set_xticklabels(texts, **kwargs)
    return self
def set_yticklabels(self, labels=None, **kwargs):
    """Set y axis tick labels on every facet of the grid.

    With explicit ``labels`` they are applied as given; otherwise the
    current tick label texts are re-applied. ``kwargs`` go to
    ``Axes.set_yticklabels``. Returns self.
    """
    for ax in self.axes.flat:
        # Re-set the current tick positions before assigning labels
        ax.set_yticks(ax.get_yticks())

        if labels is not None:
            ax.set_yticklabels(labels, **kwargs)
            continue

        texts = [tick.get_text() for tick in ax.get_yticklabels()]
        ax.set_yticklabels(texts, **kwargs)
    return self
def set_titles(self, template=None, row_template=None, col_template=None, **kwargs):
    """Draw titles either above each facet or on the grid margins.

    Parameters
    ----------
    template : string
        Template for all titles with the formatting keys {col_var} and
        {col_name} (if using a `col` faceting variable) and/or {row_var}
        and {row_name} (if using a `row` faceting variable).
    row_template:
        Template for the row variable when titles are drawn on the grid
        margins. Must have {row_var} and {row_name} formatting keys.
    col_template:
        Template for the column variable when titles are drawn on the grid
        margins. Must have {col_var} and {col_name} formatting keys.

    Returns
    -------
    self: object
        Returns self.

    """
    # Values available to the templates; level names are filled in below
    fmt_args = dict(row_var=self._row_var, col_var=self._col_var)
    kwargs["size"] = kwargs.pop("size", mpl.rcParams["axes.labelsize"])

    # Establish default templates
    row_template = "{row_var} = {row_name}" if row_template is None else row_template
    col_template = "{col_var} = {col_name}" if col_template is None else col_template
    if template is None:
        if self._row_var is None:
            template = col_template
        elif self._col_var is None:
            template = row_template
        else:
            template = " | ".join([row_template, col_template])

    row_template, col_template, template = (
        utils.to_utf8(t) for t in (row_template, col_template, template)
    )

    if self._margin_titles:

        # Remove any existing title texts
        for stale in self._margin_titles_texts:
            stale.remove()
        self._margin_titles_texts = []

        if self.row_names is not None:
            # Row titles are annotations on the right edge of the last column
            for i, row_name in enumerate(self.row_names):
                fmt_args["row_name"] = row_name
                note = self.axes[i, -1].annotate(
                    row_template.format(**fmt_args),
                    xy=(1.02, .5), xycoords="axes fraction",
                    rotation=270, ha="left", va="center",
                    **kwargs
                )
                self._margin_titles_texts.append(note)

        if self.col_names is not None:
            # Column titles are ordinary titles on the top row
            for j, col_name in enumerate(self.col_names):
                fmt_args["col_name"] = col_name
                self.axes[0, j].set_title(col_template.format(**fmt_args), **kwargs)

        return self

    # Otherwise title each facet with all the necessary information
    if (self._row_var is not None) and (self._col_var is not None):
        cells = product(enumerate(self.row_names), enumerate(self.col_names))
        for (i, row_name), (j, col_name) in cells:
            fmt_args.update(row_name=row_name, col_name=col_name)
            self.axes[i, j].set_title(template.format(**fmt_args), **kwargs)
    elif self.row_names is not None and len(self.row_names):
        for i, row_name in enumerate(self.row_names):
            fmt_args["row_name"] = row_name
            self.axes[i, 0].set_title(template.format(**fmt_args), **kwargs)
    elif self.col_names is not None and len(self.col_names):
        for i, col_name in enumerate(self.col_names):
            fmt_args["col_name"] = col_name
            # Index the flat array so col_wrap works
            self.axes.flat[i].set_title(template.format(**fmt_args), **kwargs)
    return self
def refline(self, *, x=None, y=None, color='.5', linestyle='--', **line_kws):
    """Add a reference line(s) to each facet.

    Parameters
    ----------
    x, y : numeric
        Value(s) to draw the line(s) at.
    color : :mod:`matplotlib color <matplotlib.colors>`
        Specifies the color of the reference line(s). Pass ``color=None`` to
        use ``hue`` mapping.
    linestyle : str
        Specifies the style of the reference line(s).
    line_kws : key, value mappings
        Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.axvline`
        when ``x`` is not None and :meth:`matplotlib.axes.Axes.axhline` when ``y``
        is not None.

    Returns
    -------
    :class:`FacetGrid` instance
        Returns ``self`` for easy method chaining.

    """
    line_kws.update(color=color, linestyle=linestyle)

    # A vertical line for ``x``, a horizontal one for ``y``; both are
    # drawn on every facet through ``map``
    if x is not None:
        self.map(plt.axvline, x=x, **line_kws)
    if y is not None:
        self.map(plt.axhline, y=y, **line_kws)

    return self
# ------ Properties that are part of the public API and documented by Sphinx
@property
def axes(self):
    """An array of the :class:`matplotlib.axes.Axes` objects in the grid."""
    # Backed by ``_axes``, which ``__init__`` fills with a 2-D array, or a
    # flat array when ``col_wrap`` is used
    return self._axes
@property
def ax(self):
    """The :class:`matplotlib.axes.Axes` when no faceting variables are assigned.

    Raises AttributeError when the grid is anything other than 1 x 1.
    """
    if self.axes.shape != (1, 1):
        err = (
            "Use the `.axes` attribute when facet variables are assigned."
        )
        raise AttributeError(err)
    return self.axes[0, 0]
@property
def axes_dict(self):
    """A mapping of facet names to corresponding :class:`matplotlib.axes.Axes`.

    If only one of ``row`` or ``col`` is assigned, each key is a string
    representing a level of that variable. If both facet dimensions are
    assigned, each key is a ``({row_level}, {col_level})`` tuple.

    """
    # Built once in ``__init__``; empty when neither ``row`` nor ``col``
    # is assigned
    return self._axes_dict
# ------ Private properties, that require some computation to get
@property
def _inner_axes(self):
"""Return a flat array of the inner axes."""
if self._col_wrap is None:
return self.axes[:-1, 1:].flat
else:
axes = []
n_empty = self._nrow * self._ncol - self._n_facets
for i, ax in enumerate(self.axes):
append = (
i % self._ncol
and i < (self._ncol * (self._nrow - 1))
and i < (self._ncol * (self._nrow - 1) - n_empty)
)
if append:
axes.append(ax)
return np.array(axes, object).flat
@property
def _left_axes(self):
"""Return a flat array of the left column of axes."""
if self._col_wrap is None:
return self.axes[:, 0].flat
else:
axes = []
for i, ax in enumerate(self.axes):
if not i % self._ncol:
axes.append(ax)
return np.array(axes, object).flat
@property
def _not_left_axes(self):
"""Return a flat array of axes that aren't on the left column."""
if self._col_wrap is None:
return self.axes[:, 1:].flat
else:
axes = []
for i, ax in enumerate(self.axes):
if i % self._ncol:
axes.append(ax)
return np.array(axes, object).flat
@property
def _bottom_axes(self):
"""Return a flat array of the bottom row of axes."""
if self._col_wrap is None:
return self.axes[-1, :].flat
else:
axes = []
n_empty = self._nrow * self._ncol - self._n_facets
for i, ax in enumerate(self.axes):
append = (
i >= (self._ncol * (self._nrow - 1))
or i >= (self._ncol * (self._nrow - 1) - n_empty)
)
if append:
axes.append(ax)
return np.array(axes, object).flat
@property
def _not_bottom_axes(self):
"""Return a flat array of axes that aren't on the bottom row."""
if self._col_wrap is None:
return self.axes[:-1, :].flat
else:
axes = []
n_empty = self._nrow * self._ncol - self._n_facets
for i, ax in enumerate(self.axes):
append = (
i < (self._ncol * (self._nrow - 1))
and i < (self._ncol * (self._nrow - 1) - n_empty)
)
if append:
axes.append(ax)
return np.array(axes, object).flat
class PairGrid(Grid):
"""Subplot grid for plotting pairwise relationships in a dataset.
This object maps each variable in a dataset onto a column and row in a
grid of multiple axes. Different axes-level plotting functions can be
used to draw bivariate plots in the upper and lower triangles, and the
marginal distribution of each variable can be shown on the diagonal.
Several different common plots can be generated in a single line using
:func:`pairplot`. Use :class:`PairGrid` when you need more flexibility.
See the :ref:`tutorial <grid_tutorial>` for more information.
"""
def __init__(
    self, data, *, hue=None, vars=None, x_vars=None, y_vars=None,
    hue_order=None, palette=None, hue_kws=None, corner=False, diag_sharey=True,
    height=2.5, aspect=1, layout_pad=.5, despine=True, dropna=False,
):
    """Initialize the plot figure and PairGrid object.

    Parameters
    ----------
    data : DataFrame
        Tidy (long-form) dataframe where each column is a variable and
        each row is an observation.
    hue : string (variable name)
        Variable in ``data`` to map plot aspects to different colors. This
        variable will be excluded from the default x and y variables.
    vars : list of variable names
        Variables within ``data`` to use, otherwise use every column with
        a numeric datatype. When given, it takes precedence over
        ``x_vars`` and ``y_vars``.
    {x, y}_vars : lists of variable names
        Variables within ``data`` to use separately for the rows and
        columns of the figure; i.e. to make a non-square plot.
    hue_order : list of strings
        Order for the levels of the hue variable in the palette
    palette : dict or seaborn color palette
        Set of colors for mapping the ``hue`` variable. If a dict, keys
        should be values in the ``hue`` variable.
    hue_kws : dictionary of param -> list of values mapping
        Other keyword arguments to insert into the plotting call to let
        other plot attributes vary across levels of the hue variable (e.g.
        the markers in a scatterplot).
    corner : bool
        If True, don't add axes to the upper (off-diagonal) triangle of the
        grid, making this a "corner" plot.
    diag_sharey : bool
        If True, the diagonal axes created by :meth:`map_diag` share their
        y axis.
    height : scalar
        Height (in inches) of each facet.
    aspect : scalar
        Aspect * height gives the width (in inches) of each facet.
    layout_pad : scalar
        Padding between axes; passed to ``fig.tight_layout``.
    despine : boolean
        Remove the top and right spines from the plots.
    dropna : boolean
        Drop missing values from the data before plotting.

    Raises
    ------
    ValueError
        If no variables are found for the grid columns or the grid rows.

    See Also
    --------
    pairplot : Easily drawing common uses of :class:`PairGrid`.
    FacetGrid : Subplot grid for plotting conditional relationships.

    Examples
    --------

    .. include:: ../docstrings/PairGrid.rst

    """

    super().__init__()
    data = handle_data_source(data)

    # Sort out the variables that define the grid
    numeric_cols = self._find_numeric_cols(data)
    # The hue variable is not plotted against itself by default
    if hue in numeric_cols:
        numeric_cols.remove(hue)
    # `vars` overrides both `x_vars` and `y_vars`
    if vars is not None:
        x_vars = list(vars)
        y_vars = list(vars)
    if x_vars is None:
        x_vars = numeric_cols
    if y_vars is None:
        y_vars = numeric_cols

    # Allow a single variable name in place of a list
    if np.isscalar(x_vars):
        x_vars = [x_vars]
    if np.isscalar(y_vars):
        y_vars = [y_vars]

    self.x_vars = x_vars = list(x_vars)
    self.y_vars = y_vars = list(y_vars)
    # The grid is "square" when rows and columns show the same variables
    self.square_grid = self.x_vars == self.y_vars

    if not x_vars:
        raise ValueError("No variables found for grid columns.")
    if not y_vars:
        raise ValueError("No variables found for grid rows.")

    # Create the figure and the array of subplots
    figsize = len(x_vars) * height * aspect, len(y_vars) * height

    with _disable_autolayout():
        fig = plt.figure(figsize=figsize)

    axes = fig.subplots(len(y_vars), len(x_vars),
                        sharex="col", sharey="row",
                        squeeze=False)

    # Possibly remove upper axes to make a corner grid
    # Note: setting up the axes is usually the most time-intensive part
    # of using the PairGrid. We are foregoing the speed improvement that
    # we would get by just not setting up the hidden axes so that we can
    # avoid implementing fig.subplots ourselves. But worth thinking about.
    self._corner = corner
    if corner:
        hide_indices = np.triu_indices_from(axes, 1)
        for i, j in zip(*hide_indices):
            axes[i, j].remove()
            # Removed axes are represented by None in the array
            axes[i, j] = None

    self._figure = fig
    self.axes = axes
    self.data = data

    # Save what we are going to do with the diagonal
    self.diag_sharey = diag_sharey
    self.diag_vars = None
    self.diag_axes = None

    self._dropna = dropna

    # Label the axes
    self._add_axis_labels()

    # Sort out the hue variable
    self._hue_var = hue
    if hue is None:
        # Without a hue variable every row belongs to one placeholder level
        self.hue_names = hue_order = ["_nolegend_"]
        self.hue_vals = pd.Series(["_nolegend_"] * len(data),
                                  index=data.index)
    else:
        # We need hue_order and hue_names because the former is used to control
        # the order of drawing and the latter is used to control the order of
        # the legend. hue_names can become string-typed while hue_order must
        # retain the type of the input data. This is messy but results from
        # the fact that PairGrid can implement the hue-mapping logic itself
        # (and was originally written exclusively that way) but now can delegate
        # to the axes-level functions, while always handling legend creation.
        # See GH2307
        hue_names = hue_order = categorical_order(data[hue], hue_order)
        if dropna:
            # Filter NA from the list of unique hue names
            hue_names = list(filter(pd.notnull, hue_names))
        self.hue_names = hue_names
        self.hue_vals = data[hue]

    # Additional dict of kwarg -> list of values for mapping the hue var
    self.hue_kws = hue_kws if hue_kws is not None else {}

    # Keep the palette as passed in, alongside the resolved one
    self._orig_palette = palette
    self._hue_order = hue_order
    self.palette = self._get_palette(data, hue, hue_order, palette)
    self._legend_data = {}

    # Make the plot look nice
    # Hide x tick labels and x axis labels on every row but the last
    for ax in axes[:-1, :].flat:
        if ax is None:
            continue
        for label in ax.get_xticklabels():
            label.set_visible(False)
        ax.xaxis.offsetText.set_visible(False)
        ax.xaxis.label.set_visible(False)

    # Hide y tick labels and y axis labels on every column but the first
    for ax in axes[:, 1:].flat:
        if ax is None:
            continue
        for label in ax.get_yticklabels():
            label.set_visible(False)
        ax.yaxis.offsetText.set_visible(False)
        ax.yaxis.label.set_visible(False)

    self._tight_layout_rect = [.01, .01, .99, .99]
    self._tight_layout_pad = layout_pad
    self._despine = despine
    if despine:
        utils.despine(fig=fig)
    self.tight_layout(pad=layout_pad)
def map(self, func, **kwargs):
    """Draw with one bivariate function on every subplot of the grid.

    Parameters
    ----------
    func : callable plotting function
        Must take x, y arrays as positional arguments and draw onto the
        "currently active" matplotlib Axes. Also needs to accept kwargs
        called ``color`` and ``label``.

    Returns
    -------
    :class:`PairGrid` instance
        Returns ``self`` for easy method chaining.

    """
    # Visit every (row, column) position in row-major order
    every_cell = np.ndindex(*self.axes.shape)
    self._map_bivariate(func, every_cell, **kwargs)
    return self
def map_lower(self, func, **kwargs):
    """Draw with a bivariate function on the subplots below the diagonal.

    Parameters
    ----------
    func : callable plotting function
        Must take x, y arrays as positional arguments and draw onto the
        "currently active" matplotlib Axes. Also needs to accept kwargs
        called ``color`` and ``label``.

    Returns
    -------
    :class:`PairGrid` instance
        Returns ``self`` for easy method chaining.

    """
    # k=-1 leaves the diagonal itself out
    rows, cols = np.tril_indices_from(self.axes, -1)
    self._map_bivariate(func, zip(rows, cols), **kwargs)
    return self
def map_upper(self, func, **kwargs):
    """Draw with a bivariate function on the subplots above the diagonal.

    Parameters
    ----------
    func : callable plotting function
        Must take x, y arrays as positional arguments and draw onto the
        "currently active" matplotlib Axes. Also needs to accept kwargs
        called ``color`` and ``label``.

    Returns
    -------
    :class:`PairGrid` instance
        Returns ``self`` for easy method chaining.

    """
    # k=1 leaves the diagonal itself out
    rows, cols = np.triu_indices_from(self.axes, 1)
    self._map_bivariate(func, zip(rows, cols), **kwargs)
    return self
def map_offdiag(self, func, **kwargs):
    """Draw with a bivariate function on all off-diagonal subplots.

    Parameters
    ----------
    func : callable plotting function
        Must take x, y arrays as positional arguments and draw onto the
        "currently active" matplotlib Axes. Also needs to accept kwargs
        called ``color`` and ``label``.

    Returns
    -------
    :class:`PairGrid` instance
        Returns ``self`` for easy method chaining.

    """
    if not self.square_grid:
        # Rows and columns differ, so a cell counts as "diagonal" when its
        # row and column variables are the same.
        indices = [
            (i, j)
            for i, y_var in enumerate(self.y_vars)
            for j, x_var in enumerate(self.x_vars)
            if x_var != y_var
        ]
        self._map_bivariate(func, indices, **kwargs)
        return self

    self.map_lower(func, **kwargs)
    # The upper triangle is skipped for corner plots
    if not self._corner:
        self.map_upper(func, **kwargs)
    return self
def map_diag(self, func, **kwargs):
    """Plot with a univariate function on each diagonal subplot.

    The first call creates an extra twin axes on top of every subplot whose
    row and column variable are the same; later calls reuse those axes.

    Parameters
    ----------
    func : callable plotting function
        Must take an x array as a positional argument and draw onto the
        "currently active" matplotlib Axes. Also needs to accept kwargs
        called ``color`` and ``label``.

    Returns
    -------
    :class:`PairGrid` instance
        Returns ``self`` for easy method chaining.

    """
    # Add special diagonal axes for the univariate plot
    if self.diag_axes is None:
        diag_vars = []
        diag_axes = []
        for i, y_var in enumerate(self.y_vars):
            for j, x_var in enumerate(self.x_vars):
                if x_var == y_var:

                    # Make the density axes
                    diag_vars.append(x_var)
                    ax = self.axes[i, j]
                    diag_ax = ax.twinx()
                    diag_ax.set_axis_off()
                    diag_axes.append(diag_ax)

                    # Work around matplotlib bug
                    # https://github.com/matplotlib/matplotlib/issues/15188
                    if not plt.rcParams.get("ytick.left", True):
                        for tick in ax.yaxis.majorTicks:
                            tick.tick1line.set_visible(False)

                    # Remove main y axis from density axes in a corner plot
                    if self._corner:
                        ax.yaxis.set_visible(False)
                        if self._despine:
                            utils.despine(ax=ax, left=True)
                        # TODO add optional density ticks (on the right)
                        # when drawing a corner plot?

        # Tie the y axis of every diagonal axes to the first one
        if self.diag_sharey and diag_axes:
            for ax in diag_axes[1:]:
                share_axis(diag_axes[0], ax, "y")

        self.diag_vars = diag_vars
        self.diag_axes = diag_axes

    # Functions without a `hue` parameter are called once per hue level
    if "hue" not in signature(func).parameters:
        return self._map_diag_iter_hue(func, **kwargs)

    # Loop over diagonal variables and axes, making one plot in each
    for var, ax in zip(self.diag_vars, self.diag_axes):

        plot_kwargs = kwargs.copy()
        # seaborn functions receive the target axes explicitly; any other
        # function draws on the current axes.
        if str(func.__module__).startswith("seaborn"):
            plot_kwargs["ax"] = ax
        else:
            plt.sca(ax)

        vector = self.data[var]
        if self._hue_var is not None:
            hue = self.data[self._hue_var]
        else:
            hue = None

        if self._dropna:
            # Keep only rows where both the variable and the hue are present
            not_na = vector.notna()
            if hue is not None:
                not_na &= hue.notna()
            vector = vector[not_na]
            if hue is not None:
                hue = hue[not_na]

        # Caller-supplied hue arguments win over the grid's own
        plot_kwargs.setdefault("hue", hue)
        plot_kwargs.setdefault("hue_order", self._hue_order)
        plot_kwargs.setdefault("palette", self._orig_palette)
        func(x=vector, **plot_kwargs)
        # Reset the legend reference on the diagonal axes after plotting
        ax.legend_ = None

    self._add_axis_labels()
    return self
def _map_diag_iter_hue(self, func, **kwargs):
"""Put marginal plot on each diagonal axes, iterating over hue."""
# Plot on each of the diagonal axes
fixed_color = kwargs.pop("color", None)
for var, ax in zip(self.diag_vars, self.diag_axes):
hue_grouped = self.data[var].groupby(self.hue_vals, observed=True)
plot_kwargs = kwargs.copy()
if str(func.__module__).startswith("seaborn"):
plot_kwargs["ax"] = ax
else:
plt.sca(ax)
for k, label_k in enumerate(self._hue_order):
# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = pd.Series([], dtype=float)
if fixed_color is None:
color = self.palette[k]
else:
color = fixed_color
if self._dropna:
data_k = utils.remove_na(data_k)
if str(func.__module__).startswith("seaborn"):
func(x=data_k, label=label_k, color=color, **plot_kwargs)
else:
func(data_k, label=label_k, color=color, **plot_kwargs)
self._add_axis_labels()
return self
def _map_bivariate(self, func, indices, **kwargs):
    """Draw a bivariate plot on the axes at each ``(row, column)`` in ``indices``."""
    # This is a hack to handle the fact that new distribution plots don't add
    # their artists onto the axes. This is probably superior in general, but
    # we'll need a better way to handle it in the axisgrid functions.
    from .distributions import histplot, kdeplot
    if any(func is dist_func for dist_func in (histplot, kdeplot)):
        self._extract_legend_handles = True

    kws = kwargs.copy()  # Use copy as we insert other kwargs
    for i, j in indices:
        ax = self.axes[i, j]
        if ax is None:  # i.e. we are in corner mode
            continue
        self._plot_bivariate(self.x_vars[j], self.y_vars[i], ax, func, **kws)
    self._add_axis_labels()

    # Functions that map hue themselves define the legend entries
    if "hue" in signature(func).parameters:
        self.hue_names = list(self._legend_data)
def _plot_bivariate(self, x_var, y_var, ax, func, **kwargs):
"""Draw a bivariate plot on the specified axes."""
if "hue" not in signature(func).parameters:
self._plot_bivariate_iter_hue(x_var, y_var, ax, func, **kwargs)
return
kwargs = kwargs.copy()
if str(func.__module__).startswith("seaborn"):
kwargs["ax"] = ax
else:
plt.sca(ax)
if x_var == y_var:
axes_vars = [x_var]
else:
axes_vars = [x_var, y_var]
if self._hue_var is not None and self._hue_var not in axes_vars:
axes_vars.append(self._hue_var)
data = self.data[axes_vars]
if self._dropna:
data = data.dropna()
x = data[x_var]
y = data[y_var]
if self._hue_var is None:
hue = None
else:
hue = data.get(self._hue_var)
if "hue" not in kwargs:
kwargs.update({
"hue": hue, "hue_order": self._hue_order, "palette": self._orig_palette,
})
func(x=x, y=y, **kwargs)
self._update_legend_data(ax)
def _plot_bivariate_iter_hue(self, x_var, y_var, ax, func, **kwargs):
"""Draw a bivariate plot while iterating over hue subsets."""
kwargs = kwargs.copy()
if str(func.__module__).startswith("seaborn"):
kwargs["ax"] = ax
else:
plt.sca(ax)
if x_var == y_var:
axes_vars = [x_var]
else:
axes_vars = [x_var, y_var]
hue_grouped = self.data.groupby(self.hue_vals, observed=True)
for k, label_k in enumerate(self._hue_order):
kws = kwargs.copy()
# Attempt to get data for this level, allowing for empty
try:
data_k = hue_grouped.get_group(label_k)
except KeyError:
data_k = pd.DataFrame(columns=axes_vars,
dtype=float)
if self._dropna:
data_k = data_k[axes_vars].dropna()
x = data_k[x_var]
y = data_k[y_var]
for kw, val_list in self.hue_kws.items():
kws[kw] = val_list[k]
kws.setdefault("color", self.palette[k])
if self._hue_var is not None:
kws["label"] = label_k
if str(func.__module__).startswith("seaborn"):
func(x=x, y=y, **kws)
else:
func(x, y, **kws)
self._update_legend_data(ax)
def _add_axis_labels(self):
"""Add labels to the left and bottom Axes."""
for ax, label in zip(self.axes[-1, :], self.x_vars):
ax.set_xlabel(label)
for ax, label in zip(self.axes[:, 0], self.y_vars):
ax.set_ylabel(label)
def _find_numeric_cols(self, data):
"""Find which variables in a DataFrame are numeric."""
numeric_cols = []
for col in data:
if variable_type(data[col]) == "numeric":
numeric_cols.append(col)
return numeric_cols
class JointGrid(_BaseGrid):
"""Grid for drawing a bivariate plot with marginal univariate plots.
Many plots can be drawn by using the figure-level interface :func:`jointplot`.
Use this class directly when you need more flexibility.
"""
def __init__(
    self, data=None, *,
    x=None, y=None, hue=None,
    height=6, ratio=5, space=.2,
    palette=None, hue_order=None, hue_norm=None,
    dropna=False, xlim=None, ylim=None, marginal_ticks=False,
):
    # NOTE: the docstring for this method is assigned after the class body
    # (``JointGrid.__init__.__doc__ = ...``) so it can be built with
    # ``str.format``.

    # Set up the subplot grid
    # The figure is square; the joint axes spans ``ratio`` grid cells per
    # side and each marginal axes spans a single cell along its short side.
    f = plt.figure(figsize=(height, height))
    gs = plt.GridSpec(ratio + 1, ratio + 1)

    ax_joint = f.add_subplot(gs[1:, :-1])
    ax_marg_x = f.add_subplot(gs[0, :-1], sharex=ax_joint)
    ax_marg_y = f.add_subplot(gs[1:, -1], sharey=ax_joint)

    self._figure = f
    self.ax_joint = ax_joint
    self.ax_marg_x = ax_marg_x
    self.ax_marg_y = ax_marg_y

    # Turn off tick visibility for the measure axis on the marginal plots
    plt.setp(ax_marg_x.get_xticklabels(), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(), visible=False)
    plt.setp(ax_marg_x.get_xticklabels(minor=True), visible=False)
    plt.setp(ax_marg_y.get_yticklabels(minor=True), visible=False)

    # Turn off the ticks on the density axis for the marginal plots
    if not marginal_ticks:
        plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False)
        plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False)
        plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False)
        plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False)
        plt.setp(ax_marg_x.get_yticklabels(), visible=False)
        plt.setp(ax_marg_y.get_xticklabels(), visible=False)
        plt.setp(ax_marg_x.get_yticklabels(minor=True), visible=False)
        plt.setp(ax_marg_y.get_xticklabels(minor=True), visible=False)
        ax_marg_x.yaxis.grid(False)
        ax_marg_y.xaxis.grid(False)

    # Process the input variables
    p = VectorPlotter(data=data, variables=dict(x=x, y=y, hue=hue))
    # Keep only the columns that contain at least one non-missing value
    plot_data = p.plot_data.loc[:, p.plot_data.notna().any()]

    # Possibly drop NA
    if dropna:
        plot_data = plot_data.dropna()

    def get_var(var):
        # Return the column for `var`, renamed to the name recorded in
        # `p.variables`, or None when the column is absent.
        vector = plot_data.get(var, None)
        if vector is not None:
            vector = vector.rename(p.variables.get(var, None))
        return vector

    self.x = get_var("x")
    self.y = get_var("y")
    self.hue = get_var("hue")

    # Label the joint axes with the variable names, when they are known
    for axis in "xy":
        name = p.variables.get(axis, None)
        if name is not None:
            getattr(ax_joint, f"set_{axis}label")(name)

    if xlim is not None:
        ax_joint.set_xlim(xlim)
    if ylim is not None:
        ax_joint.set_ylim(ylim)

    # Store the semantic mapping parameters for axes-level functions
    self._hue_params = dict(palette=palette, hue_order=hue_order, hue_norm=hue_norm)

    # Make the grid look nice
    utils.despine(f)
    if not marginal_ticks:
        utils.despine(ax=ax_marg_x, left=True)
        utils.despine(ax=ax_marg_y, bottom=True)
    # The marginal axes never show axis labels
    for axes in [ax_marg_x, ax_marg_y]:
        for axis in [axes.xaxis, axes.yaxis]:
            axis.label.set_visible(False)
    f.tight_layout()
    f.subplots_adjust(hspace=space, wspace=space)
def _inject_kwargs(self, func, kws, params):
"""Add params to kws if they are accepted by func."""
func_params = signature(func).parameters
for key, val in params.items():
if key in func_params:
kws.setdefault(key, val)
def plot(self, joint_func, marginal_func, **kwargs):
    """Draw the full plot from one joint and one marginal plotting function.

    The same ``kwargs`` dictionary is handed to both functions. For finer
    control, call :meth:`JointGrid.plot_joint` and
    :meth:`JointGrid.plot_marginals` directly with specific parameters.

    Parameters
    ----------
    joint_func, marginal_func : callables
        Functions to draw the bivariate and univariate plots. See methods
        referenced above for information about the required characteristics
        of these functions.
    kwargs
        Additional keyword arguments are passed to both functions.

    Returns
    -------
    :class:`JointGrid` instance
        Returns ``self`` for easy method chaining.

    """
    # The marginal plots are drawn first, then the joint plot
    steps = (
        (self.plot_marginals, marginal_func),
        (self.plot_joint, joint_func),
    )
    for draw, plot_func in steps:
        draw(plot_func, **kwargs)
    return self
def plot_joint(self, func, **kwargs):
    """Draw a bivariate plot on the joint axes of the grid.

    Parameters
    ----------
    func : plotting callable
        If a seaborn function, it should accept ``x`` and ``y``. Otherwise,
        it must accept ``x`` and ``y`` vectors of data as the first two
        positional arguments, and it must plot on the "current" axes.
        If ``hue`` was defined in the class constructor, the function must
        accept ``hue`` as a parameter.
    kwargs
        Keyword argument are passed to the plotting function.

    Returns
    -------
    :class:`JointGrid` instance
        Returns ``self`` for easy method chaining.

    """
    seaborn_func = str(func.__module__).startswith("seaborn")
    call_kws = dict(kwargs)

    # seaborn functions receive the target axes explicitly; any other
    # function draws on the current axes.
    if seaborn_func:
        call_kws["ax"] = self.ax_joint
    else:
        plt.sca(self.ax_joint)

    if self.hue is not None:
        call_kws["hue"] = self.hue
        self._inject_kwargs(func, call_kws, self._hue_params)

    if seaborn_func:
        func(x=self.x, y=self.y, **call_kws)
    else:
        func(self.x, self.y, **call_kws)

    return self
def plot_marginals(self, func, **kwargs):
    """Draw univariate plots on each marginal axes.

    Parameters
    ----------
    func : plotting callable
        If a seaborn function, it should accept ``x`` and ``y`` and plot
        when only one of them is defined. Otherwise, it must accept a vector
        of data as the first positional argument and determine its orientation
        using the ``orientation`` or ``vertical`` parameter, and it must plot
        on the "current" axes.
        If ``hue`` was defined in the class constructor, it must accept ``hue``
        as a parameter.
    kwargs
        Keyword argument are passed to the plotting function.

    Returns
    -------
    :class:`JointGrid` instance
        Returns ``self`` for easy method chaining.

    Raises
    ------
    TypeError
        If ``func`` is not a seaborn function and accepts neither an
        ``orientation`` nor a ``vertical`` parameter.

    """
    seaborn_func = (
        str(func.__module__).startswith("seaborn")
        # deprecated distplot has a legacy API, special case it
        and not func.__name__ == "distplot"
    )
    func_params = signature(func).parameters

    # Work out how to tell a non-seaborn function which way to draw.
    # seaborn functions get `x=` or `y=` instead and ignore these dicts.
    orient_kw_x, orient_kw_y = {}, {}
    if "orientation" in func_params:
        # e.g. plt.hist
        orient_kw_x = {"orientation": "vertical"}
        orient_kw_y = {"orientation": "horizontal"}
    elif "vertical" in func_params:
        # e.g. sns.distplot (also how did this get backwards?)
        orient_kw_x = {"vertical": False}
        orient_kw_y = {"vertical": True}
    elif not seaborn_func:
        # Fail early with a clear message rather than with an unbound
        # local variable once plotting has already started.
        raise TypeError(
            f"The marginal plotting function {func!r} must accept either an "
            "`orientation` or a `vertical` parameter."
        )

    kwargs = kwargs.copy()
    if self.hue is not None:
        kwargs["hue"] = self.hue
        self._inject_kwargs(func, kwargs, self._hue_params)

    # Marginal plots do not get their own legend unless explicitly requested
    if "legend" in func_params:
        kwargs.setdefault("legend", False)

    if seaborn_func:
        func(x=self.x, ax=self.ax_marg_x, **kwargs)
    else:
        plt.sca(self.ax_marg_x)
        func(self.x, **orient_kw_x, **kwargs)

    if seaborn_func:
        func(y=self.y, ax=self.ax_marg_y, **kwargs)
    else:
        plt.sca(self.ax_marg_y)
        func(self.y, **orient_kw_y, **kwargs)

    # Hide the label of the count/density axis on both marginal plots
    self.ax_marg_x.yaxis.get_label().set_visible(False)
    self.ax_marg_y.xaxis.get_label().set_visible(False)

    return self
def refline(
    self, *, x=None, y=None, joint=True, marginal=True,
    color='.5', linestyle='--', **line_kws
):
    """Draw reference line(s) on the joint and/or marginal axes.

    Parameters
    ----------
    x, y : numeric
        Value(s) to draw the line(s) at. ``x`` produces vertical lines on the
        joint axes and the x marginal axes; ``y`` produces horizontal lines on
        the joint axes and the y marginal axes.
    joint, marginal : bools
        Whether to add the reference line(s) to the joint/marginal axes.
    color : :mod:`matplotlib color <matplotlib.colors>`
        Specifies the color of the reference line(s).
    linestyle : str
        Specifies the style of the reference line(s).
    line_kws : key, value mappings
        Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.axvline`
        when ``x`` is not None and :meth:`matplotlib.axes.Axes.axhline` when ``y``
        is not None.

    Returns
    -------
    :class:`JointGrid` instance
        Returns ``self`` for easy method chaining.

    """
    # The explicit color/linestyle arguments travel together with the extras
    style = dict(line_kws, color=color, linestyle=linestyle)

    if x is not None:
        for wanted, target in ((joint, self.ax_joint), (marginal, self.ax_marg_x)):
            if wanted:
                target.axvline(x, **style)

    if y is not None:
        for wanted, target in ((joint, self.ax_joint), (marginal, self.ax_marg_y)):
            if wanted:
                target.axhline(y, **style)

    return self
def set_axis_labels(self, xlabel="", ylabel="", **kwargs):
    """Label the x and y axis of the joint (bivariate) axes.

    Parameters
    ----------
    xlabel, ylabel : strings
        Label names for the x and y variables.
    kwargs : key, value mappings
        Other keyword arguments are passed to the following functions:

        - :meth:`matplotlib.axes.Axes.set_xlabel`
        - :meth:`matplotlib.axes.Axes.set_ylabel`

    Returns
    -------
    :class:`JointGrid` instance
        Returns ``self`` for easy method chaining.

    """
    joint = self.ax_joint
    for apply_label, text in ((joint.set_xlabel, xlabel), (joint.set_ylabel, ylabel)):
        apply_label(text, **kwargs)
    return self
# The constructor docstring is attached after the class body so that shared
# parameter descriptions can be interpolated with ``str.format``. Literal
# braces in the text are therefore doubled (``{{x, y}}lim``).
JointGrid.__init__.__doc__ = """\
Set up the grid of subplots and store data internally for easy plotting.
Parameters
----------
{params.core.data}
{params.core.xy}
height : number
Size of each side of the figure in inches (it will be square).
ratio : number
Ratio of joint axes height to marginal axes height.
space : number
Space between the joint and marginal axes
dropna : bool
If True, remove missing observations before plotting.
{{x, y}}lim : pairs of numbers
Set axis limits to these values before plotting.
marginal_ticks : bool
If False, suppress ticks on the count/density axis of the marginal plots.
{params.core.hue}
Note: unlike in :class:`FacetGrid` or :class:`PairGrid`, the axes-level
functions must support ``hue`` to use it in :class:`JointGrid`.
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
See Also
--------
{seealso.jointplot}
{seealso.pairgrid}
{seealso.pairplot}
Examples
--------
.. include:: ../docstrings/JointGrid.rst
""".format(
    params=_param_docs,
    seealso=_core_docs["seealso"],
)
def pairplot(
    data, *,
    hue=None, hue_order=None, palette=None,
    vars=None, x_vars=None, y_vars=None,
    kind="scatter", diag_kind="auto", markers=None,
    height=2.5, aspect=1, corner=False, dropna=False,
    plot_kws=None, diag_kws=None, grid_kws=None, size=None,
):
    """Plot pairwise relationships in a dataset.

    By default, this function will create a grid of Axes such that each numeric
    variable in ``data`` will be shared across the y-axes across a single row and
    the x-axes across a single column. The diagonal plots are treated
    differently: a univariate distribution plot is drawn to show the marginal
    distribution of the data in each column.

    It is also possible to show a subset of variables or plot different
    variables on the rows and columns.

    This is a high-level interface for :class:`PairGrid` that is intended to
    make it easy to draw a few common styles. You should use :class:`PairGrid`
    directly if you need more flexibility.

    Parameters
    ----------
    data : `pandas.DataFrame`
        Tidy (long-form) dataframe where each column is a variable and
        each row is an observation.
    hue : name of variable in ``data``
        Variable in ``data`` to map plot aspects to different colors.
    hue_order : list of strings
        Order for the levels of the hue variable in the palette
    palette : dict or seaborn color palette
        Set of colors for mapping the ``hue`` variable. If a dict, keys
        should be values in the ``hue`` variable.
    vars : list of variable names
        Variables within ``data`` to use, otherwise use every column with
        a numeric datatype.
    {x, y}_vars : lists of variable names
        Variables within ``data`` to use separately for the rows and
        columns of the figure; i.e. to make a non-square plot.
    kind : {'scatter', 'kde', 'hist', 'reg'}
        Kind of plot to make.
    diag_kind : {'auto', 'hist', 'kde', None}
        Kind of plot for the diagonal subplots. If 'auto', choose based on
        whether or not ``hue`` is used.
    markers : single matplotlib marker code or list
        Either the marker to use for all scatterplot points or a list of markers
        with a length the same as the number of levels in the hue variable so that
        differently colored points will also have different scatterplot
        markers.
    height : scalar
        Height (in inches) of each facet.
    aspect : scalar
        Aspect * height gives the width (in inches) of each facet.
    corner : bool
        If True, don't add axes to the upper (off-diagonal) triangle of the
        grid, making this a "corner" plot.
    dropna : boolean
        Drop missing values from the data before plotting.
    {plot, diag, grid}_kws : dicts
        Dictionaries of keyword arguments. ``plot_kws`` are passed to the
        bivariate plotting function, ``diag_kws`` are passed to the univariate
        plotting function, and ``grid_kws`` are passed to the :class:`PairGrid`
        constructor.
    size : scalar
        Deprecated name for ``height``. When given, it overrides ``height``
        and a ``UserWarning`` is issued.

    Returns
    -------
    grid : :class:`PairGrid`
        Returns the underlying :class:`PairGrid` instance for further tweaking.

    Raises
    ------
    TypeError
        If ``data`` is not a :class:`pandas.DataFrame`.
    ValueError
        If ``kind="reg"`` and ``markers`` is a list whose length differs from
        the number of hue levels.

    See Also
    --------
    PairGrid : Subplot grid for more flexible plotting of pairwise relationships.
    JointGrid : Grid for plotting joint and marginal distributions of two variables.

    Examples
    --------

    .. include:: ../docstrings/pairplot.rst

    """
    # Avoid circular import
    from .distributions import histplot, kdeplot

    # Handle deprecations
    if size is not None:
        height = size
        msg = ("The `size` parameter has been renamed to `height`; "
               "please update your code.")
        warnings.warn(msg, UserWarning)

    if not isinstance(data, pd.DataFrame):
        raise TypeError(
            f"'data' must be pandas DataFrame object, not: {type(data)}")

    # Work on copies so the caller's dictionaries are never modified
    plot_kws = {} if plot_kws is None else plot_kws.copy()
    diag_kws = {} if diag_kws is None else diag_kws.copy()
    grid_kws = {} if grid_kws is None else grid_kws.copy()

    # Resolve "auto" diag kind
    if diag_kind == "auto":
        if hue is None:
            diag_kind = "kde" if kind == "kde" else "hist"
        else:
            diag_kind = "hist" if kind == "hist" else "kde"

    # Set up the PairGrid
    # By default the diagonal axes share a y axis only for histograms
    grid_kws.setdefault("diag_sharey", diag_kind == "hist")
    grid = PairGrid(data, vars=vars, x_vars=x_vars, y_vars=y_vars, hue=hue,
                    hue_order=hue_order, palette=palette, corner=corner,
                    height=height, aspect=aspect, dropna=dropna, **grid_kws)

    # Add the markers here as PairGrid has figured out how many levels of the
    # hue variable are needed and we don't want to duplicate that process
    if markers is not None:
        if kind == "reg":
            # Needed until regplot supports style
            if grid.hue_names is None:
                n_markers = 1
            else:
                n_markers = len(grid.hue_names)
            if not isinstance(markers, list):
                markers = [markers] * n_markers
            if len(markers) != n_markers:
                raise ValueError("markers must be a singleton or a list of "
                                 "markers for each level of the hue variable")
            grid.hue_kws = {"marker": markers}
        elif kind == "scatter":
            # A single marker code applies to all points; otherwise the
            # markers are mapped to the hue variable through `style`.
            if isinstance(markers, str):
                plot_kws["marker"] = markers
            elif hue is not None:
                plot_kws["style"] = data[hue]
                plot_kws["markers"] = markers

    # Draw the marginal plots on the diagonal
    diag_kws = diag_kws.copy()
    diag_kws.setdefault("legend", False)
    if diag_kind == "hist":
        grid.map_diag(histplot, **diag_kws)
    elif diag_kind == "kde":
        diag_kws.setdefault("fill", True)
        diag_kws.setdefault("warn_singular", False)
        grid.map_diag(kdeplot, **diag_kws)

    # Maybe plot on the off-diagonals
    # With no diagonal plot, the bivariate function fills every subplot
    if diag_kind is not None:
        plotter = grid.map_offdiag
    else:
        plotter = grid.map

    if kind == "scatter":
        from .relational import scatterplot  # Avoid circular import
        plotter(scatterplot, **plot_kws)
    elif kind == "reg":
        from .regression import regplot  # Avoid circular import
        plotter(regplot, **plot_kws)
    elif kind == "kde":
        from .distributions import kdeplot  # Avoid circular import
        plot_kws.setdefault("warn_singular", False)
        plotter(kdeplot, **plot_kws)
    elif kind == "hist":
        from .distributions import histplot  # Avoid circular import
        plotter(histplot, **plot_kws)

    # Add a legend
    if hue is not None:
        grid.add_legend()

    grid.tight_layout()

    return grid
def jointplot(
    data=None, *, x=None, y=None, hue=None, kind="scatter",
    height=6, ratio=5, space=.2, dropna=False, xlim=None, ylim=None,
    color=None, palette=None, hue_order=None, hue_norm=None, marginal_ticks=False,
    joint_kws=None, marginal_kws=None,
    **kwargs
):
    # NOTE: the public docstring is assigned to ``jointplot.__doc__`` right
    # after this definition, so no inline docstring is written here.

    # Avoid circular imports
    from .relational import scatterplot
    from .regression import regplot, residplot
    from .distributions import histplot, kdeplot, _freedman_diaconis_bins

    # `ax` is meaningless for a figure-level function; drop it with a warning
    if kwargs.pop("ax", None) is not None:
        msg = "Ignoring `ax`; jointplot is a figure-level function."
        warnings.warn(msg, UserWarning, stacklevel=2)

    # Set up empty default kwarg dicts (copied so caller dicts are not mutated)
    joint_kws = {} if joint_kws is None else joint_kws.copy()
    joint_kws.update(kwargs)  # Extra kwargs supersede entries in joint_kws
    marginal_kws = {} if marginal_kws is None else marginal_kws.copy()

    # Handle deprecations of distplot-specific kwargs.
    # Each key must be its own list element: adjacent string literals are
    # concatenated by Python, which previously hid "norm_hist" inside the
    # bogus key "norm_histhist_kws" so it was never stripped or reported.
    distplot_keys = [
        "rug", "fit", "hist_kws", "norm_hist", "rug_kws",
    ]
    unused_keys = [key for key in distplot_keys if key in marginal_kws]
    for key in unused_keys:
        marginal_kws.pop(key)
    if unused_keys and kind != "kde":
        msg = (
            "The marginal plotting function has changed to `histplot`,"
            " which does not accept the following argument(s): {}."
        ).format(", ".join(unused_keys))
        warnings.warn(msg, UserWarning)

    # Validate the plot kind
    plot_kinds = ["scatter", "hist", "hex", "kde", "reg", "resid"]
    _check_argument("kind", plot_kinds, kind)

    # Raise early if using `hue` with a kind that does not support it
    if hue is not None and kind in ["hex", "reg", "resid"]:
        msg = f"Use of `hue` with `kind='{kind}'` is not currently supported."
        raise ValueError(msg)

    # Make a colormap based off the plot color
    # (Currently used only for kind="hex")
    if color is None:
        color = "C0"
    color_rgb = mpl.colors.colorConverter.to_rgb(color)
    colors = [set_hls_values(color_rgb, l=val) for val in np.linspace(1, 0, 12)]
    cmap = blend_palette(colors, as_cmap=True)

    # Matplotlib's hexbin plot is not na-robust
    if kind == "hex":
        dropna = True

    # Initialize the JointGrid object
    grid = JointGrid(
        data=data, x=x, y=y, hue=hue,
        palette=palette, hue_order=hue_order, hue_norm=hue_norm,
        dropna=dropna, height=height, ratio=ratio, space=space,
        xlim=xlim, ylim=ylim, marginal_ticks=marginal_ticks,
    )

    if grid.hue is not None:
        marginal_kws.setdefault("legend", False)

    # Plot the data using the grid
    if kind.startswith("scatter"):

        joint_kws.setdefault("color", color)
        grid.plot_joint(scatterplot, **joint_kws)

        # Histograms without hue; filled KDE curves when hue is assigned
        if grid.hue is None:
            marg_func = histplot
        else:
            marg_func = kdeplot
            marginal_kws.setdefault("warn_singular", False)
            marginal_kws.setdefault("fill", True)

        marginal_kws.setdefault("color", color)
        grid.plot_marginals(marg_func, **marginal_kws)

    elif kind.startswith("hist"):

        # TODO process pair parameters for bins, etc. and pass
        # to both joint and marginal plots

        joint_kws.setdefault("color", color)
        grid.plot_joint(histplot, **joint_kws)

        marginal_kws.setdefault("kde", False)
        marginal_kws.setdefault("color", color)

        marg_x_kws = marginal_kws.copy()
        marg_y_kws = marginal_kws.copy()

        # Tuple-valued binning parameters on the joint plot are split into
        # their x and y components for the corresponding marginal plot
        pair_keys = "bins", "binwidth", "binrange"
        for key in pair_keys:
            if isinstance(joint_kws.get(key), tuple):
                x_val, y_val = joint_kws[key]
                marg_x_kws.setdefault(key, x_val)
                marg_y_kws.setdefault(key, y_val)

        histplot(data=data, x=x, hue=hue, **marg_x_kws, ax=grid.ax_marg_x)
        histplot(data=data, y=y, hue=hue, **marg_y_kws, ax=grid.ax_marg_y)

    elif kind.startswith("kde"):

        joint_kws.setdefault("color", color)
        joint_kws.setdefault("warn_singular", False)
        grid.plot_joint(kdeplot, **joint_kws)

        marginal_kws.setdefault("color", color)
        # Keep the marginal fill consistent with the joint plot by default
        if "fill" in joint_kws:
            marginal_kws.setdefault("fill", joint_kws["fill"])

        grid.plot_marginals(kdeplot, **marginal_kws)

    elif kind.startswith("hex"):

        # Default gridsize: mean of the per-axis bin counts, each capped at 50
        x_bins = min(_freedman_diaconis_bins(grid.x), 50)
        y_bins = min(_freedman_diaconis_bins(grid.y), 50)
        gridsize = int(np.mean([x_bins, y_bins]))

        joint_kws.setdefault("gridsize", gridsize)
        joint_kws.setdefault("cmap", cmap)
        grid.plot_joint(plt.hexbin, **joint_kws)

        marginal_kws.setdefault("kde", False)
        marginal_kws.setdefault("color", color)
        grid.plot_marginals(histplot, **marginal_kws)

    elif kind.startswith("reg"):

        marginal_kws.setdefault("color", color)
        marginal_kws.setdefault("kde", True)
        grid.plot_marginals(histplot, **marginal_kws)

        joint_kws.setdefault("color", color)
        grid.plot_joint(regplot, **joint_kws)

    elif kind.startswith("resid"):

        joint_kws.setdefault("color", color)
        grid.plot_joint(residplot, **joint_kws)

        # Read the plotted point coordinates back from the first collection on
        # the joint axes and use them for the marginal histograms
        x, y = grid.ax_joint.collections[0].get_offsets().T
        marginal_kws.setdefault("color", color)
        histplot(x=x, hue=hue, ax=grid.ax_marg_x, **marginal_kws)
        histplot(y=y, hue=hue, ax=grid.ax_marg_y, **marginal_kws)

    # Make the main axes active in the matplotlib state machine
    plt.sca(grid.ax_joint)

    return grid
# Attach the docstring after the definition so that it can be built with
# str.format(): single-brace fields such as {params.core.data} are filled from
# the shared docstring components passed to .format() below, while doubled
# braces (e.g. {{x, y}}lim) are escapes that render as literal braces.
jointplot.__doc__ = """\
Draw a plot of two variables with bivariate and univariate graphs.
This function provides a convenient interface to the :class:`JointGrid`
class, with several canned plot kinds. This is intended to be a fairly
lightweight wrapper; if you need more flexibility, you should use
:class:`JointGrid` directly.
Parameters
----------
{params.core.data}
{params.core.xy}
{params.core.hue}
kind : {{ "scatter" | "kde" | "hist" | "hex" | "reg" | "resid" }}
Kind of plot to draw. See the examples for references to the underlying functions.
height : numeric
Size of the figure (it will be square).
ratio : numeric
Ratio of joint axes height to marginal axes height.
space : numeric
Space between the joint and marginal axes
dropna : bool
If True, remove observations that are missing from ``x`` and ``y``.
{{x, y}}lim : pairs of numbers
Axis limits to set before plotting.
{params.core.color}
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
marginal_ticks : bool
If False, suppress ticks on the count/density axis of the marginal plots.
{{joint, marginal}}_kws : dicts
Additional keyword arguments for the plot components.
kwargs
Additional keyword arguments are passed to the function used to
draw the plot on the joint Axes, superseding items in the
``joint_kws`` dictionary.
Returns
-------
{returns.jointgrid}
See Also
--------
{seealso.jointgrid}
{seealso.pairgrid}
{seealso.pairplot}
Examples
--------
.. include:: ../docstrings/jointplot.rst
""".format(
params=_param_docs,
returns=_core_docs["returns"],
seealso=_core_docs["seealso"],
)