124 lines
4.0 KiB
Python
124 lines
4.0 KiB
Python
|
from __future__ import annotations
|
||
|
from typing import Literal
|
||
|
|
||
|
import numpy as np
|
||
|
import pandas as pd
|
||
|
import matplotlib as mpl
|
||
|
from matplotlib.figure import Figure
|
||
|
from seaborn.utils import _version_predates
|
||
|
|
||
|
|
||
|
def norm_from_scale(scale, norm):
    """Produce a Normalize object given a Scale and min/max domain limits.

    Parameters
    ----------
    scale : matplotlib scale object or None
        Its ``get_transform().transform`` is used to map data values (and the
        limits) into the space where the linear rescaling happens.
    norm : :class:`matplotlib.colors.Normalize`, (vmin, vmax) pair, or None
        An existing Normalize instance is returned unchanged. A pair supplies
        the data limits; None leaves the limits to be autoscaled from data.

    Returns
    -------
    :class:`matplotlib.colors.Normalize` or None
        ``norm`` itself when it is already a Normalize, None when ``scale`` is
        None, otherwise a new Normalize that rescales in transformed space.

    Raises
    ------
    TypeError or ValueError
        If ``norm`` is not a Normalize and cannot be unpacked into exactly
        two values (the exception type matches what unpacking raised).
    """
    # This mimics an internal matplotlib function that would simplify things
    # if it were accessible. It is likely to become part of the matplotlib API
    # at some point: https://github.com/matplotlib/matplotlib/issues/20329
    if isinstance(norm, mpl.colors.Normalize):
        return norm

    if scale is None:
        return None

    if norm is None:
        vmin = vmax = None
    else:
        try:
            vmin, vmax = norm
        except (TypeError, ValueError) as err:
            # Keep the original exception category so existing handlers still
            # match, but explain what shape of argument was expected.
            err_type = TypeError if isinstance(err, TypeError) else ValueError
            raise err_type(
                "`norm` must be a matplotlib Normalize object or a "
                f"(vmin, vmax) pair; got {norm!r}."
            ) from err

    class ScaledNorm(mpl.colors.Normalize):

        def __call__(self, value, clip=None):
            # From github.com/matplotlib/matplotlib/blob/v3.4.2/lib/matplotlib/colors.py
            # See github.com/matplotlib/matplotlib/tree/v3.4.2/LICENSE
            value, is_scalar = self.process_value(value)
            self.autoscale_None(value)
            if self.vmin > self.vmax:
                raise ValueError("vmin must be less or equal to vmax")
            if self.vmin == self.vmax:
                # Degenerate range: every value maps to 0. Unwrap scalar input
                # the same way the main path below does, so a scalar call
                # never returns a length-1 array.
                result = np.full_like(value, 0)
                return result[0] if is_scalar else result
            if clip is None:
                clip = self.clip
            if clip:
                value = np.clip(value, self.vmin, self.vmax)
            # ***** Seaborn changes start ****
            # `self.transform` is attached to the instance below, rather than
            # using matplotlib's private transform attribute.
            t_value = self.transform(value).reshape(np.shape(value))
            t_vmin, t_vmax = self.transform([self.vmin, self.vmax])
            # ***** Seaborn changes end *****
            if not np.isfinite([t_vmin, t_vmax]).all():
                raise ValueError("Invalid vmin or vmax")
            # Linearly rescale in transformed space so the limits map to 0 and 1.
            t_value -= t_vmin
            t_value /= (t_vmax - t_vmin)
            t_value = np.ma.masked_invalid(t_value, copy=False)
            return t_value[0] if is_scalar else t_value

    new_norm = ScaledNorm(vmin, vmax)
    new_norm.transform = scale.get_transform().transform

    return new_norm
|
||
|
|
||
|
|
||
|
def get_colormap(name):
    """Look up a colormap by name across the matplotlib 3.6 interface change.

    Prefers the ``mpl.colormaps`` registry. When that attribute is missing
    (matplotlib versions before the 3.6 change), falls back to
    ``mpl.cm.get_cmap``.
    """
    try:
        cmap = mpl.colormaps[name]
    except AttributeError:
        # No global colormap registry on this matplotlib version.
        cmap = mpl.cm.get_cmap(name)
    return cmap
|
||
|
|
||
|
|
||
|
def register_colormap(name, cmap):
    """Register ``cmap`` under ``name`` across the matplotlib 3.6 interface change.

    With the ``mpl.colormaps`` registry, an existing entry under ``name`` is
    left alone. Without that registry, falls back to ``mpl.cm.register_cmap``.
    The fallback performs no existence check.
    """
    try:
        registry = mpl.colormaps
        if name in registry:
            # Already registered; do not overwrite.
            return
        registry.register(cmap, name=name)
    except AttributeError:
        # Older matplotlib without the ``mpl.colormaps`` registry.
        mpl.cm.register_cmap(name, cmap)
|
||
|
|
||
|
|
||
|
def set_layout_engine(
    fig: Figure,
    engine: Literal["constrained", "compressed", "tight", "none"],
) -> None:
    """Handle changes to auto layout engine interface in 3.6

    Figures that provide ``set_layout_engine`` get ``engine`` passed straight
    through. Otherwise the legacy boolean toggles are used. On that legacy
    path, "compressed" has no equivalent and is ignored.
    """
    if hasattr(fig, "set_layout_engine"):
        fig.set_layout_engine(engine)
        return

    # _version_predates(mpl, 3.6): only the per-engine on/off setters exist
    if engine == "none":
        fig.set_tight_layout(False)  # type: ignore
        fig.set_constrained_layout(False)  # type: ignore
    elif engine == "tight":
        fig.set_tight_layout(True)  # type: ignore # predates typing
    elif engine == "constrained":
        fig.set_constrained_layout(True)  # type: ignore
|
||
|
|
||
|
|
||
|
def get_layout_engine(fig: Figure) -> mpl.layout_engine.LayoutEngine | None:
    """Handle changes to auto layout engine interface in 3.6

    Returns whatever ``fig.get_layout_engine()`` returns. Returns None when the
    figure has no such method (matplotlib before 3.6).
    """
    if not hasattr(fig, "get_layout_engine"):
        # _version_predates(mpl, 3.6): there is no layout engine to report
        return None
    return fig.get_layout_engine()
|
||
|
|
||
|
|
||
|
def share_axis(ax0, ax1, which):
    """Handle changes to post-hoc axis sharing.

    Makes ``ax1`` share its ``which`` axis ("x" or "y", used to build the
    method names) with ``ax0``. Before matplotlib 3.5 this joins the shared-axes
    group directly. Otherwise it uses the ``sharex``/``sharey`` methods.
    """
    if not _version_predates(mpl, "3.5"):
        share_method = getattr(ax1, f"share{which}")
        share_method(ax0)
        return

    shared_group = getattr(ax0, f"get_shared_{which}_axes")()
    shared_group.join(ax1, ax0)
|
||
|
|
||
|
|
||
|
def get_legend_handles(legend):
    """Handle legendHandles attribute rename.

    matplotlib 3.7 renamed ``Legend.legendHandles`` to ``legend_handles``.
    Reads whichever name matches the installed version.
    """
    attribute = (
        "legendHandles" if _version_predates(mpl, "3.7") else "legend_handles"
    )
    return getattr(legend, attribute)
|
||
|
|
||
|
|
||
|
def groupby_apply_include_groups(val):
    """Build keyword arguments carrying ``include_groups=val`` when supported.

    Returns ``{"include_groups": val}`` on pandas >= 2.2.0. Returns an empty
    dict on older versions, which lack that keyword. Presumably callers splat
    the result into a groupby ``.apply(...)`` call; confirm at call sites.
    """
    if not _version_predates(pd, "2.2.0"):
        return {"include_groups": val}
    return {}
|