# Source-viewer header residue (not part of the module):
# 898 lines, 28 KiB; snapshot dated 2024-10-02 22:15:59 +04:00.
"""Utility functions, mostly for internal use."""
import os
import inspect
import warnings
import colorsys
from contextlib import contextmanager
from urllib.request import urlopen, urlretrieve
from types import ModuleType
import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib.colors import to_rgb
import matplotlib.pyplot as plt
from matplotlib.cbook import normalize_kwargs
from seaborn._core.typing import deprecated
from seaborn.external.version import Version
from seaborn.external.appdirs import user_cache_dir
# Public names exported by a star-import of this module.
__all__ = ["desaturate", "saturate", "set_hls_values", "move_legend",
"despine", "get_dataset_names", "get_data_home", "load_dataset"]
# Base URL under which the example dataset CSV files are looked up.
DATASET_SOURCE = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master"
# Text file, under the same base URL, that lists the available dataset names.
DATASET_NAMES_URL = f"{DATASET_SOURCE}/dataset_names.txt"
def ci_to_errsize(cis, heights):
    """Convert intervals to error arguments relative to plot heights.

    Parameters
    ----------
    cis : 2 x n sequence
        sequence of confidence interval limits
    heights : n sequence
        sequence of plot heights

    Returns
    -------
    errsize : 2 x n array
        sequence of error size relative to height values in correct
        format as argument for plt.bar

    """
    cis = np.atleast_2d(cis).reshape(2, -1)
    heights = np.atleast_1d(heights)

    # Each column of ``cis`` is one (low, high) pair; express both limits as
    # non-signed distances from the matching bar height.
    errsize = [
        [heights[i] - low, high - heights[i]]
        for i, (low, high) in enumerate(np.transpose(cis))
    ]

    # Rows become (lower errors, upper errors).
    return np.asarray(errsize).T
def _draw_figure(fig):
    """Force draw of a matplotlib figure, accounting for back-compat."""
    # See https://github.com/matplotlib/matplotlib/issues/19197 for context
    fig.canvas.draw()
    if fig.stale:
        # The canvas draw did not clear the stale flag; redraw through the
        # renderer directly when the canvas exposes one.
        try:
            fig.draw(fig.canvas.get_renderer())
        except AttributeError:
            # Canvas has no ``get_renderer``; nothing more we can do.
            pass
def _default_color(method, hue, color, kws, saturation=1):
    """If needed, get a default color by using the matplotlib property cycle.

    Parameters
    ----------
    method : bound matplotlib Axes method
        One of ``plot``, ``scatter``, ``bar`` or ``fill_between``; it is called
        with empty/masked data to discover the color matplotlib would pick.
    hue : object or None
        When not None, no single default color applies and None is returned.
    color : matplotlib color or None
        Explicit color; returned (after optional desaturation) when given.
    kws : dict
        Keyword arguments destined for ``method``; not modified.
    saturation : float
        Values below 1 desaturate the resulting color.

    Returns
    -------
    color : matplotlib color or None

    """
    if hue is not None:
        # NOTE: `color` is silently ignored when `hue` is assigned. A warning
        # would be friendlier, but it would currently also be triggered in a
        # FacetGrid context, so none is issued here.
        return None

    kws = kws.copy()
    kws.pop("label", None)

    if color is not None:
        if saturation < 1:
            color = desaturate(color, saturation)
        return color

    elif method.__name__ == "plot":

        color = normalize_kwargs(kws, mpl.lines.Line2D).get("color")
        scout, = method([], [], scalex=False, scaley=False, color=color)
        color = scout.get_color()
        scout.remove()

    elif method.__name__ == "scatter":

        # Matplotlib will raise if the size of x/y don't match s/c,
        # and the latter might be in the kws dict
        scout_size = max(
            np.atleast_1d(kws.get(key, [])).shape[0]
            for key in ["s", "c", "fc", "facecolor", "facecolors"]
        )
        scout_x = scout_y = np.full(scout_size, np.nan)

        scout = method(scout_x, scout_y, **kws)
        facecolors = scout.get_facecolors()

        if not len(facecolors):
            # Handle bug in matplotlib <= 3.2 (I think)
            # This will limit the ability to use non color= kwargs to specify
            # a color in versions of matplotlib with the bug, but trying to
            # work out what the user wanted by re-implementing the broken logic
            # of inspecting the kwargs is probably too brittle.
            single_color = False
        else:
            single_color = np.unique(facecolors, axis=0).shape[0] == 1

        # Allow the user to specify an array of colors through various kwargs
        if "c" not in kws and single_color:
            color = to_rgb(facecolors[0])

        scout.remove()

    elif method.__name__ == "bar":

        # bar() needs masked, not empty data, to generate a patch
        scout, = method([np.nan], [np.nan], **kws)
        color = to_rgb(scout.get_facecolor())
        scout.remove()
        # Axes.bar adds both a patch and a container
        method.__self__.containers.pop(-1)

    elif method.__name__ == "fill_between":

        kws = normalize_kwargs(kws, mpl.collections.PolyCollection)
        scout = method([], [], **kws)
        facecolor = scout.get_facecolor()
        color = to_rgb(facecolor[0])
        scout.remove()

    if saturation < 1:
        color = desaturate(color, saturation)

    return color
def desaturate(color, prop):
    """Decrease the saturation channel of a color by some percent.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    prop : float
        saturation channel of color will be multiplied by this value

    Returns
    -------
    new_color : rgb tuple
        desaturated color code in RGB tuple representation

    Raises
    ------
    ValueError
        If ``prop`` is not between 0 and 1 (inclusive).

    """
    # Check inputs
    if not 0 <= prop <= 1:
        raise ValueError("prop must be between 0 and 1")

    # Get rgb tuple rep
    rgb = to_rgb(color)

    # Short circuit to avoid floating point issues
    if prop == 1:
        return rgb

    # Scale the saturation channel in HLS space, then convert back to RGB
    h, l, s = colorsys.rgb_to_hls(*rgb)  # noqa: E741
    return colorsys.hls_to_rgb(h, l, s * prop)
def saturate(color):
    """Return a fully saturated color with the same hue.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name

    Returns
    -------
    new_color : rgb tuple
        saturated color code in RGB tuple representation

    """
    return set_hls_values(color, s=1)
def set_hls_values(color, h=None, l=None, s=None):  # noqa
    """Independently manipulate the h, l, or s channels of a color.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    h, l, s : floats between 0 and 1, or None
        new values for each channel in hls space

    Returns
    -------
    new_color : rgb tuple
        new color code in RGB tuple representation

    """
    # Get an RGB tuple representation
    rgb = to_rgb(color)
    current = colorsys.rgb_to_hls(*rgb)

    # Keep the existing channel wherever no override was supplied
    vals = [
        old if new is None else new
        for old, new in zip(current, (h, l, s))
    ]

    return colorsys.hls_to_rgb(*vals)
def axlabel(xlabel, ylabel, **kwargs):
    """Grab current axis and label it.

    DEPRECATED: will be removed in a future version.

    Parameters
    ----------
    xlabel, ylabel : str
        Labels for the x and y axes of the current Axes.
    kwargs
        Passed to both ``set_xlabel`` and ``set_ylabel``.

    """
    msg = "This function is deprecated and will be removed in a future version"
    warnings.warn(msg, FutureWarning)
    ax = plt.gca()
    ax.set_xlabel(xlabel, **kwargs)
    ax.set_ylabel(ylabel, **kwargs)
def remove_na(vector):
    """Helper method for removing null values from data vectors.

    Parameters
    ----------
    vector : vector object
        Must implement boolean masking with [] subscript syntax.

    Returns
    -------
    clean : same type as ``vector``
        Vector of data with null values removed. May be a copy or a view.

    """
    return vector[pd.notnull(vector)]
def get_color_cycle():
    """Return the list of colors in the current matplotlib color cycle

    Parameters
    ----------
    None

    Returns
    -------
    colors : list
        List of matplotlib colors in the current cycle, or dark gray if
        the current color cycle is empty.

    """
    cycler = mpl.rcParams['axes.prop_cycle']
    if 'color' in cycler.keys:
        return cycler.by_key()['color']
    # The property cycle does not vary color; fall back to dark gray.
    return [".15"]
def despine(fig=None, ax=None, top=True, right=True, left=False,
            bottom=False, offset=None, trim=False):
    """Remove the top and right spines from plot(s).

    Parameters
    ----------
    fig : matplotlib figure, optional
        Figure to despine all axes of, defaults to the current figure.
    ax : matplotlib axes, optional
        Specific axes object to despine. Ignored if fig is provided.
    top, right, left, bottom : boolean, optional
        If True, remove that spine.
    offset : int or dict, optional
        Absolute distance, in points, spines should be moved away
        from the axes (negative values move spines inward). A single value
        applies to all spines; a dict can be used to set offset values per
        side.
    trim : bool, optional
        If True, limit spines to the smallest and largest major tick
        on each non-despined axis.

    Returns
    -------
    None

    """
    # Get references to the axes we want
    if fig is None and ax is None:
        axes = plt.gcf().axes
    elif fig is not None:
        axes = fig.axes
    elif ax is not None:
        axes = [ax]

    # Explicit side -> "remove this spine" mapping (ordered as the spines are
    # processed) instead of looking the flags up through ``locals()``.
    remove_side = {"top": top, "right": right, "left": left, "bottom": bottom}

    for ax_i in axes:
        for side, remove in remove_side.items():
            # Toggle the spine objects
            is_visible = not remove
            ax_i.spines[side].set_visible(is_visible)
            if offset is not None and is_visible:
                try:
                    val = offset.get(side, 0)
                except AttributeError:
                    # Scalar offset: applies to every side
                    val = offset
                ax_i.spines[side].set_position(('outward', val))

        # Potentially move the ticks
        if left and not right:
            maj_on = any(
                t.tick1line.get_visible()
                for t in ax_i.yaxis.majorTicks
            )
            min_on = any(
                t.tick1line.get_visible()
                for t in ax_i.yaxis.minorTicks
            )
            ax_i.yaxis.set_ticks_position("right")
            for t in ax_i.yaxis.majorTicks:
                t.tick2line.set_visible(maj_on)
            for t in ax_i.yaxis.minorTicks:
                t.tick2line.set_visible(min_on)

        if bottom and not top:
            maj_on = any(
                t.tick1line.get_visible()
                for t in ax_i.xaxis.majorTicks
            )
            min_on = any(
                t.tick1line.get_visible()
                for t in ax_i.xaxis.minorTicks
            )
            ax_i.xaxis.set_ticks_position("top")
            for t in ax_i.xaxis.majorTicks:
                t.tick2line.set_visible(maj_on)
            for t in ax_i.xaxis.minorTicks:
                t.tick2line.set_visible(min_on)

        if trim:
            # clip off the parts of the spines that extend past major ticks
            xticks = np.asarray(ax_i.get_xticks())
            if xticks.size:
                firsttick = np.compress(xticks >= min(ax_i.get_xlim()),
                                        xticks)[0]
                lasttick = np.compress(xticks <= max(ax_i.get_xlim()),
                                       xticks)[-1]
                ax_i.spines['bottom'].set_bounds(firsttick, lasttick)
                ax_i.spines['top'].set_bounds(firsttick, lasttick)
                newticks = xticks.compress(xticks <= lasttick)
                newticks = newticks.compress(newticks >= firsttick)
                ax_i.set_xticks(newticks)

            yticks = np.asarray(ax_i.get_yticks())
            if yticks.size:
                firsttick = np.compress(yticks >= min(ax_i.get_ylim()),
                                        yticks)[0]
                lasttick = np.compress(yticks <= max(ax_i.get_ylim()),
                                       yticks)[-1]
                ax_i.spines['left'].set_bounds(firsttick, lasttick)
                ax_i.spines['right'].set_bounds(firsttick, lasttick)
                newticks = yticks.compress(yticks <= lasttick)
                newticks = newticks.compress(newticks >= firsttick)
                ax_i.set_yticks(newticks)
def move_legend(obj, loc, **kwargs):
    """
    Recreate a plot's legend at a new location.

    The name is a slight misnomer. Matplotlib legends do not expose public
    control over their position parameters. So this function creates a new legend,
    copying over the data from the original object, which is then removed.

    Parameters
    ----------
    obj : the object with the plot
        This argument can be either a seaborn or matplotlib object:

        - :class:`seaborn.FacetGrid` or :class:`seaborn.PairGrid`
        - :class:`matplotlib.axes.Axes` or :class:`matplotlib.figure.Figure`

    loc : str or int
        Location argument, as in :meth:`matplotlib.axes.Axes.legend`.

    kwargs
        Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.legend`.

    Raises
    ------
    TypeError
        If ``obj`` is not a seaborn Grid or a matplotlib Axes or Figure.
    ValueError
        If ``obj`` has no legend, or if ``labels`` is given with a length
        different from the existing legend's.

    Examples
    --------

    .. include:: ../docstrings/move_legend.rst

    """
    # This is a somewhat hackish solution that will hopefully be obviated by
    # upstream improvements to matplotlib legends that make them easier to
    # modify after creation.

    from seaborn.axisgrid import Grid  # Avoid circular import

    # Locate the legend object and a method to recreate the legend
    if isinstance(obj, Grid):
        old_legend = obj.legend
        legend_func = obj.figure.legend
    elif isinstance(obj, mpl.axes.Axes):
        old_legend = obj.legend_
        legend_func = obj.legend
    elif isinstance(obj, mpl.figure.Figure):
        if obj.legends:
            old_legend = obj.legends[-1]
        else:
            old_legend = None
        legend_func = obj.legend
    else:
        err = "`obj` must be a seaborn Grid or matplotlib Axes or Figure instance."
        raise TypeError(err)

    if old_legend is None:
        err = f"{obj} has no legend attached."
        raise ValueError(err)

    # Extract the components of the legend we need to reuse
    # Import here to avoid a circular import
    from seaborn._compat import get_legend_handles
    handles = get_legend_handles(old_legend)
    labels = [t.get_text() for t in old_legend.get_texts()]

    # Handle the case where the user is trying to override the labels
    if (new_labels := kwargs.pop("labels", None)) is not None:
        if len(new_labels) != len(labels):
            err = "Length of new labels does not match existing legend."
            raise ValueError(err)
        labels = new_labels

    # Extract legend properties that can be passed to the recreation method
    # (Vexingly, these don't all round-trip)
    legend_kws = inspect.signature(mpl.legend.Legend).parameters
    props = {k: v for k, v in old_legend.properties().items() if k in legend_kws}

    # Delegate default bbox_to_anchor rules to matplotlib
    props.pop("bbox_to_anchor")

    # Try to propagate the existing title and font properties; respect new ones too
    title = props.pop("title")
    if "title" in kwargs:
        title.set_text(kwargs.pop("title"))
    title_kwargs = {k: v for k, v in kwargs.items() if k.startswith("title_")}
    for key, val in title_kwargs.items():
        # Strip the "title_" prefix to get the Text property name
        title.set(**{key[6:]: val})
        kwargs.pop(key)

    # Try to respect the frame visibility
    kwargs.setdefault("frameon", old_legend.legendPatch.get_visible())

    # Remove the old legend and create the new one
    props.update(kwargs)
    old_legend.remove()
    new_legend = legend_func(handles, labels, loc=loc, **props)
    new_legend.set_title(title.get_text(), title.get_fontproperties())

    # Let the Grid object continue to track the correct legend object
    if isinstance(obj, Grid):
        obj._legend = new_legend
def _kde_support(data, bw, gridsize, cut, clip):
    """Establish support for a kernel density estimate.

    The grid spans ``bw * cut`` beyond the data extremes, limited to the
    ``clip`` (low, high) bounds, with ``gridsize`` evenly spaced points.
    """
    support_min = max(data.min() - bw * cut, clip[0])
    support_max = min(data.max() + bw * cut, clip[1])
    return np.linspace(support_min, support_max, gridsize)
def ci(a, which=95, axis=None):
    """Return a percentile range from an array of values.

    ``which`` is the width of the central interval in percent; NaN values
    are ignored (``np.nanpercentile``).
    """
    p = 50 - which / 2, 50 + which / 2
    return np.nanpercentile(a, p, axis)
def get_dataset_names():
    """Report available example datasets, useful for reporting issues.

    Requires an internet connection.

    Returns
    -------
    names : list of str
        Non-empty, whitespace-stripped lines of the remote names file.

    """
    with urlopen(DATASET_NAMES_URL) as resp:
        txt = resp.read()

    stripped = (name.strip() for name in txt.decode().split("\n"))
    # Drop blank lines (e.g. the one after a trailing newline)
    return [name for name in stripped if name]
def get_data_home(data_home=None):
    """Return a path to the cache directory for example datasets.

    This directory is used by :func:`load_dataset`.

    If the ``data_home`` argument is not provided, it will use a directory
    specified by the `SEABORN_DATA` environment variable (if it exists)
    or otherwise default to an OS-appropriate user cache location.

    The directory is created if it does not already exist.

    """
    if data_home is None:
        data_home = os.environ.get("SEABORN_DATA", user_cache_dir("seaborn"))
    data_home = os.path.expanduser(data_home)
    # ``exist_ok`` avoids a check-then-create race between processes
    os.makedirs(data_home, exist_ok=True)
    return data_home
def load_dataset(name, cache=True, data_home=None, **kws):
    """Load an example dataset from the online repository (requires internet).

    This function provides quick access to a small number of example datasets
    that are useful for documenting seaborn or generating reproducible examples
    for bug reports. It is not necessary for normal usage.

    Note that some of the datasets have a small amount of preprocessing applied
    to define a proper ordering for categorical variables.

    Use :func:`get_dataset_names` to see a list of available datasets.

    Parameters
    ----------
    name : str
        Name of the dataset (``{name}.csv`` on
        https://github.com/mwaskom/seaborn-data).
    cache : boolean, optional
        If True, try to load from the local cache first, and save to the cache
        if a download is required.
    data_home : string, optional
        The directory in which to cache data; see :func:`get_data_home`.
    kws : keys and values, optional
        Additional keyword arguments are passed through to
        :func:`pandas.read_csv`.

    Returns
    -------
    df : :class:`pandas.DataFrame`
        Tabular data, possibly with some preprocessing applied.

    Raises
    ------
    TypeError
        If ``name`` is a DataFrame rather than a dataset name.
    ValueError
        If the dataset must be downloaded and ``name`` is not a known dataset.

    """
    # A common beginner mistake is to assume that one's personal data needs
    # to be passed through this function to be usable with seaborn.
    # Let's provide a more helpful error than you would otherwise get.
    if isinstance(name, pd.DataFrame):
        err = (
            "This function accepts only strings (the name of an example dataset). "
            "You passed a pandas DataFrame. If you have your own dataset, "
            "it is not necessary to use this function before plotting."
        )
        raise TypeError(err)

    url = f"{DATASET_SOURCE}/{name}.csv"

    if cache:
        cache_path = os.path.join(get_data_home(data_home), os.path.basename(url))
        if not os.path.exists(cache_path):
            if name not in get_dataset_names():
                raise ValueError(f"'{name}' is not one of the example datasets.")
            urlretrieve(url, cache_path)
        full_path = cache_path
    else:
        full_path = url

    df = pd.read_csv(full_path, **kws)

    # Drop a trailing row that is entirely null
    if df.iloc[-1].isnull().all():
        df = df.iloc[:-1]

    # Set some columns as a categorical type with ordered levels

    if name == "tips":
        df["day"] = pd.Categorical(df["day"], ["Thur", "Fri", "Sat", "Sun"])
        df["sex"] = pd.Categorical(df["sex"], ["Male", "Female"])
        df["time"] = pd.Categorical(df["time"], ["Lunch", "Dinner"])
        df["smoker"] = pd.Categorical(df["smoker"], ["Yes", "No"])

    elif name == "flights":
        months = df["month"].str[:3]
        df["month"] = pd.Categorical(months, months.unique())

    elif name == "exercise":
        df["time"] = pd.Categorical(df["time"], ["1 min", "15 min", "30 min"])
        df["kind"] = pd.Categorical(df["kind"], ["rest", "walking", "running"])
        df["diet"] = pd.Categorical(df["diet"], ["no fat", "low fat"])

    elif name == "titanic":
        df["class"] = pd.Categorical(df["class"], ["First", "Second", "Third"])
        df["deck"] = pd.Categorical(df["deck"], list("ABCDEFG"))

    elif name == "penguins":
        df["sex"] = df["sex"].str.title()

    elif name == "diamonds":
        df["color"] = pd.Categorical(
            df["color"], ["D", "E", "F", "G", "H", "I", "J"],
        )
        df["clarity"] = pd.Categorical(
            df["clarity"], ["IF", "VVS1", "VVS2", "VS1", "VS2", "SI1", "SI2", "I1"],
        )
        df["cut"] = pd.Categorical(
            df["cut"], ["Ideal", "Premium", "Very Good", "Good", "Fair"],
        )

    elif name == "taxis":
        df["pickup"] = pd.to_datetime(df["pickup"])
        df["dropoff"] = pd.to_datetime(df["dropoff"])

    elif name == "seaice":
        df["Date"] = pd.to_datetime(df["Date"])

    elif name == "dowjones":
        df["Date"] = pd.to_datetime(df["Date"])

    return df
def axis_ticklabels_overlap(labels):
    """Return a boolean for whether the list of ticklabels have overlaps.

    Parameters
    ----------
    labels : list of matplotlib ticklabels

    Returns
    -------
    overlap : boolean
        True if any of the labels overlap.

    """
    if not labels:
        return False
    try:
        bboxes = [l.get_window_extent() for l in labels]  # noqa: E741
        overlaps = [b.count_overlaps(bboxes) for b in bboxes]
        # Every box is counted against the full list (itself included),
        # hence the threshold of 1 rather than 0.
        return max(overlaps) > 1
    except RuntimeError:
        # Issue on macos backend raises an error in the above code
        return False
def axes_ticklabels_overlap(ax):
    """Return booleans for whether the x and y ticklabels on an Axes overlap.

    Parameters
    ----------
    ax : matplotlib Axes

    Returns
    -------
    x_overlap, y_overlap : booleans
        True when the labels on that axis overlap.

    """
    return (axis_ticklabels_overlap(ax.get_xticklabels()),
            axis_ticklabels_overlap(ax.get_yticklabels()))
def locator_to_legend_entries(locator, limits, dtype):
    """Return levels and formatted levels for brief numeric legends.

    Parameters
    ----------
    locator : matplotlib ticker Locator
        Used to propose tick values across ``limits``.
    limits : pair of numbers
        (low, high) range; proposed levels outside it are dropped.
    dtype : numpy dtype
        Type the raw tick values are cast to.

    Returns
    -------
    raw_levels : list
        Tick values within ``limits``.
    formatted_levels : list of str
        Formatted labels for ``raw_levels``.

    """
    raw_levels = locator.tick_values(*limits).astype(dtype)

    # The locator can return ticks outside the limits, clip them here
    raw_levels = [
        level for level in raw_levels if limits[0] <= level <= limits[1]
    ]

    # Minimal stand-in for an Axis: the formatter asks it for the view range
    class dummy_axis:
        def get_view_interval(self):
            return limits

    if isinstance(locator, mpl.ticker.LogLocator):
        formatter = mpl.ticker.LogFormatter()
    else:
        formatter = mpl.ticker.ScalarFormatter()
        # Avoid having an offset/scientific notation which we don't currently
        # have any way of representing in the legend
        formatter.set_useOffset(False)
        formatter.set_scientific(False)
    formatter.axis = dummy_axis()

    formatted_levels = formatter.format_ticks(raw_levels)

    return raw_levels, formatted_levels
def relative_luminance(color):
    """Calculate the relative luminance of a color according to W3C standards

    Parameters
    ----------
    color : matplotlib color or sequence of matplotlib colors
        Hex code, rgb-tuple, or html color name.

    Returns
    -------
    luminance : float(s) between 0 and 1

    """
    rgb = mpl.colors.colorConverter.to_rgba_array(color)[:, :3]
    # Linearize the channels, then take the weighted channel sum
    rgb = np.where(rgb <= .03928, rgb / 12.92, ((rgb + .055) / 1.055) ** 2.4)
    lum = rgb.dot([.2126, .7152, .0722])
    try:
        # A single color gives a one-element array: return a plain scalar
        return lum.item()
    except ValueError:
        # Several colors: ``item`` refuses, so return the array
        return lum
def to_utf8(obj):
    """Return a string representing a Python object.

    Strings (i.e. type ``str``) are returned unchanged.

    Byte strings (i.e. type ``bytes``) are returned as UTF-8-decoded strings.

    For other objects, the method ``__str__()`` is called, and the result is
    returned as a string.

    Parameters
    ----------
    obj : object
        Any Python object

    Returns
    -------
    s : str
        UTF-8-decoded string representation of ``obj``

    """
    if isinstance(obj, str):
        return obj
    try:
        return obj.decode(encoding="utf-8")
    except AttributeError:  # obj is not bytes-like
        return str(obj)
def _check_argument(param, options, value, prefix=False):
    """Raise if value for param is not in options.

    With ``prefix=True`` a non-None ``value`` is accepted when it starts with
    any string option; otherwise it must be a member of ``options``.
    Returns ``value`` unchanged on success and raises ValueError on failure.
    """
    if prefix and value is not None:
        failure = not any(value.startswith(p) for p in options if isinstance(p, str))
    else:
        failure = value not in options
    if failure:
        raise ValueError(
            f"The value for `{param}` must be one of {options}, "
            f"but {repr(value)} was passed."
        )
    return value
def _assign_default_kwargs(kws, call_func, source_func):
    """Assign default kwargs for call_func using values from source_func.

    ``kws`` is updated in place (and returned): every parameter of
    ``call_func`` that also appears in ``source_func``'s signature and is not
    already in ``kws`` receives ``source_func``'s default for it.
    """
    # This exists so that axes-level functions and figure-level functions can
    # both call a Plotter method while having the default kwargs be defined in
    # the signature of the axes-level function.
    # An alternative would be to have a decorator on the method that sets its
    # defaults based on those defined in the axes-level function.
    # Then the figure-level function would not need to worry about defaults.
    # I am not sure which is better.
    needed = inspect.signature(call_func).parameters
    defaults = inspect.signature(source_func).parameters

    for param in needed:
        if param in defaults and param not in kws:
            kws[param] = defaults[param].default

    return kws
def adjust_legend_subtitles(legend):
    """
    Make invisible-handle "subtitles" entries look more like titles.

    Note: This function is not part of the public API and may be changed or removed.

    """
    # Legend title not in rcParams until 3.0
    font_size = plt.rcParams.get("legend.title_fontsize", None)
    hpackers = legend.findobj(mpl.offsetbox.VPacker)[0].get_children()
    for hpack in hpackers:
        draw_area, text_area = hpack.get_children()
        handles = draw_area.get_children()
        if not all(artist.get_visible() for artist in handles):
            # Entry has a hidden handle: collapse the handle area and give
            # the text the legend-title font size
            draw_area.set_width(0)
            for text in text_area.get_children():
                if font_size is not None:
                    text.set_size(font_size)
def _deprecate_ci(errorbar, ci):
    """
    Warn on usage of ci= and convert to appropriate errorbar= arg.

    ci was deprecated when errorbar was added in 0.12. It should not be removed
    completely for some time, but it can be moved out of function definitions
    (and extracted from kwargs) after one cycle.

    """
    if ci is not deprecated and ci != "deprecated":
        if ci is None:
            errorbar = None
        elif ci == "sd":
            errorbar = "sd"
        else:
            errorbar = ("ci", ci)
        msg = (
            "\n\nThe `ci` parameter is deprecated. "
            f"Use `errorbar={repr(errorbar)}` for the same effect.\n"
        )
        warnings.warn(msg, FutureWarning, stacklevel=3)

    return errorbar
def _get_transform_functions(ax, axis):
    """Return the forward and inverse transforms for a given axis.

    ``axis`` is the axis letter used to look up ``ax.<axis>axis``
    (e.g. ``"x"`` reads ``ax.xaxis``).
    """
    axis_obj = getattr(ax, f"{axis}axis")
    transform = axis_obj.get_transform()
    return transform.transform, transform.inverted().transform
@contextmanager
def _disable_autolayout():
    """Context manager for preventing rc-controlled auto-layout behavior."""
    # This is a workaround for an issue in matplotlib, for details see
    # https://github.com/mwaskom/seaborn/issues/2914
    # The only effect of this rcParam is to set the default value for
    # layout= in plt.figure, so we could just do that instead.
    # But then we would need to own the complexity of the transition
    # from tight_layout=True -> layout="tight". This seems easier,
    # but can be removed when (if) that is simpler on the matplotlib side,
    # or if the layout algorithms are improved to handle figure legends.
    orig_val = mpl.rcParams["figure.autolayout"]
    try:
        mpl.rcParams["figure.autolayout"] = False
        yield
    finally:
        # Restore the caller's setting even if the managed block raises
        mpl.rcParams["figure.autolayout"] = orig_val
def _version_predates(lib: ModuleType, version: str) -> bool:
    """Helper function for checking version compatibility.

    Returns True when ``lib.__version__`` is strictly older than ``version``.
    """
    return Version(lib.__version__) < Version(version)
def _scatter_legend_artist(**kws):
    """Build an empty Line2D that mimics a scatter marker for legend use.

    Scatter-style keyword arguments (``s``, ``facecolor``, ``edgecolor``,
    ``linewidth``, ``marker``) are translated to their Line2D marker
    equivalents; remaining keywords are passed to Line2D as-is.
    """
    kws = normalize_kwargs(kws, mpl.collections.PathCollection)

    edgecolor = kws.pop("edgecolor", None)
    rc = mpl.rcParams
    line_kws = {
        "linestyle": "",
        "marker": kws.pop("marker", "o"),
        # Square root converts the scatter size ``s`` to a marker size
        "markersize": np.sqrt(kws.pop("s", rc["lines.markersize"] ** 2)),
        "markerfacecolor": kws.pop("facecolor", kws.get("color")),
        "markeredgewidth": kws.pop("linewidth", 0),
        **kws,
    }

    if edgecolor is not None:
        if edgecolor == "face":
            line_kws["markeredgecolor"] = line_kws["markerfacecolor"]
        else:
            line_kws["markeredgecolor"] = edgecolor

    return mpl.lines.Line2D([], [], **line_kws)
def _get_patch_legend_artist(fill):
    """Return a factory for zero-size Rectangle legend handles.

    The returned ``legend_artist(**kws)`` maps a ``color`` keyword to the
    face color when ``fill`` is true, otherwise to the edge color with a
    "none" face color.
    """

    def legend_artist(**kws):

        color = kws.pop("color", None)
        if color is not None:
            if fill:
                kws["facecolor"] = color
            else:
                kws["edgecolor"] = color
                kws["facecolor"] = "none"

        return mpl.patches.Rectangle((0, 0), 0, 0, **kws)

    return legend_artist