253 lines
8.9 KiB
Python
253 lines
8.9 KiB
Python
from __future__ import annotations
|
|
from collections import defaultdict
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
import matplotlib as mpl
|
|
|
|
from seaborn._marks.base import (
|
|
Mark,
|
|
Mappable,
|
|
MappableBool,
|
|
MappableColor,
|
|
MappableFloat,
|
|
MappableStyle,
|
|
resolve_properties,
|
|
resolve_color,
|
|
document_properties
|
|
)
|
|
|
|
from typing import TYPE_CHECKING
|
|
if TYPE_CHECKING:
|
|
from typing import Any
|
|
from matplotlib.artist import Artist
|
|
from seaborn._core.scales import Scale
|
|
|
|
|
|
class BarBase(Mark):
    """Shared machinery for bar-style marks: patch construction and property resolution."""

    def _make_patches(self, data, scales, orient):
        """
        Build one ``Rectangle`` per row of ``data``.

        Each bar is centered on the ``orient`` coordinate; its half-width offset
        is applied in the orient scale's transformed space and then mapped back
        to data coordinates. Along the other axis the bar spans from
        ``baseline`` to the data value. Rows whose value extent is zero or NaN
        produce no patch.

        Returns
        -------
        bars : list of matplotlib.patches.Rectangle
        vals : list
            For each returned bar (same order), its extent along the value
            axis: the height when ``orient == "x"``, the width otherwise.
        """
        scale_transform = scales[orient]._matplotlib_scale.get_transform()
        to_scaled = scale_transform.transform
        from_scaled = scale_transform.inverted().transform

        value_var = {"x": "y", "y": "x"}[orient]

        center = to_scaled(data[orient])
        half_width = data["width"] / 2
        pos = from_scaled(center - half_width)
        extent = from_scaled(center + half_width) - pos

        base = data["baseline"].to_numpy()
        val = (data[value_var] - data["baseline"]).to_numpy()

        kws = self._resolve_properties(data, scales)
        geometry = (
            {"x": pos, "y": base, "w": extent, "h": val}
            if orient == "x"
            else {"x": base, "y": pos, "w": val, "h": extent}
        )
        kws.update(geometry)
        # These mark properties were consumed above and are not Rectangle kwargs.
        for consumed in ("width", "baseline"):
            kws.pop(consumed, None)

        value_key = {"x": "h", "y": "w"}[orient]
        bars, vals = [], []

        for i in range(len(data)):
            row = {key: column[i] for key, column in kws.items()}

            # Bars with a zero or missing value get no artist at all. An option
            # to keep them (e.g. for animating or annotating) may be worth
            # adding later, but skipping keeps things simple for now.
            if not np.nan_to_num(row[value_key]):
                continue

            bars.append(
                mpl.patches.Rectangle(
                    xy=(row["x"], row["y"]),
                    width=row["w"],
                    height=row["h"],
                    facecolor=row["facecolor"],
                    edgecolor=row["edgecolor"],
                    linestyle=row["edgestyle"],
                    linewidth=row["edgewidth"],
                    **self.artist_kws,
                )
            )
            vals.append(row[value_key])

        return bars, vals

    def _resolve_properties(self, data, scales):
        """
        Resolve mark properties, adding ``facecolor`` and ``edgecolor``.

        The facecolor's alpha channel is multiplied by ``fill``. This handles
        both a single RGBA tuple and a 2D array of RGBA rows.
        """
        resolved = resolve_properties(self, data, scales)

        resolved["facecolor"] = resolve_color(self, data, "", scales)
        resolved["edgecolor"] = resolve_color(self, data, "edge", scales)

        face, fill = resolved["facecolor"], resolved["fill"]
        if isinstance(face, tuple):
            face = face[:3] + (face[3] * fill,)
        else:
            # Work on a copy so the array returned by resolve_color is untouched.
            face = face.copy()
            face[:, 3] = face[:, 3] * fill
        resolved["facecolor"] = face

        return resolved

    def _legend_artist(
        self, variables: list[str], value: Any, scales: dict[str, Scale],
    ) -> Artist:
        """Return a legend ``Patch`` styled as if each of ``variables`` equals ``value``."""
        # TODO return some sensible default?
        props = self._resolve_properties(dict.fromkeys(variables, value), scales)
        return mpl.patches.Patch(
            facecolor=props["facecolor"],
            edgecolor=props["edgecolor"],
            linewidth=props["edgewidth"],
            linestyle=props["edgestyle"],
        )
|
|
|
|
|
|
@document_properties
@dataclass
class Bar(BarBase):
    """
    A bar mark drawn between baseline and data values.

    See also
    --------
    Bars : A faster bar mark with defaults more suitable for histograms.

    Examples
    --------
    .. include:: ../docstrings/objects.Bar.rst

    """
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(.7, grouping=False)
    fill: MappableBool = Mappable(True, grouping=False)
    edgecolor: MappableColor = Mappable(depend="color", grouping=False)
    edgealpha: MappableFloat = Mappable(1, grouping=False)
    edgewidth: MappableFloat = Mappable(rc="patch.linewidth", grouping=False)
    edgestyle: MappableStyle = Mappable("-", grouping=False)
    # pattern: MappableString = Mappable(None)  # TODO no Property yet

    width: MappableFloat = Mappable(.8, grouping=False)
    baseline: MappableFloat = Mappable(0, grouping=False)  # TODO *is* this mappable?

    def _plot(self, split_gen, scales, orient):
        """
        Add each bar to its Axes as an individually clipped ``Rectangle``.

        The patches from each subset are also registered on the Axes as a
        ``BarContainer``, which is useful for e.g. ``Axes.bar_label``.
        """
        # Index of the value axis within ``sticky_edges`` (orient "x" -> 1).
        val_idx = ["y", "x"].index(orient)
        orientation = {"x": "vertical", "y": "horizontal"}[orient]

        for _, data, ax in split_gen():

            bars, vals = self._make_patches(data, scales, orient)

            for bar in bars:
                self._compensate_edge_for_clip(bar)
                self._clip_to_outline(bar, ax)
                bar.sticky_edges[val_idx][:] = (0, np.inf)
                ax.add_patch(bar)

            ax.add_container(
                mpl.container.BarContainer(
                    bars, datavalues=vals, orientation=orientation,
                )
            )

    @staticmethod
    def _compensate_edge_for_clip(bar):
        """Double the edge width (and halve any dash lengths) of ``bar``."""
        # Clipping the bar to its own outline hides the outer half of the edge
        # line, so edges would otherwise look half as wide as requested. This
        # workaround may surprise anyone working with the artists directly and
        # may need revisiting. Dash lengths are halved, presumably to offset
        # matplotlib scaling dashes with the (now doubled) linewidth.
        bar.set_linewidth(bar.get_linewidth() * 2)
        style = bar.get_linestyle()
        if style[1]:
            style = (style[0], tuple(dash / 2 for dash in style[1]))
        bar.set_linestyle(style)

    def _clip_to_outline(self, bar, ax):
        """Clip ``bar`` to its own path, restoring axes clipping unless disabled."""
        # Edge lines are centered on the bar's extent and would overlap when
        # bars are stacked or dodged. Clipping to a bbox might be faster than
        # clipping to a path, but intersecting it with the axes bbox is not
        # straightforward, so the path is used.
        bar.set_clip_path(bar.get_path(), bar.get_transform() + ax.transData)
        if self.artist_kws.get("clip_on", True):
            # Setting the clip path seems to undo the default axes clipping.
            bar.set_clip_box(ax.bbox)
|
|
|
|
|
|
@document_properties
@dataclass
class Bars(BarBase):
    """
    A faster bar mark with defaults more suitable for histograms.

    See also
    --------
    Bar : A bar mark drawn between baseline and data values.

    Examples
    --------
    .. include:: ../docstrings/objects.Bars.rst

    """
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(.7, grouping=False)
    fill: MappableBool = Mappable(True, grouping=False)
    edgecolor: MappableColor = Mappable(rc="patch.edgecolor", grouping=False)
    edgealpha: MappableFloat = Mappable(1, grouping=False)
    edgewidth: MappableFloat = Mappable(auto=True, grouping=False)
    edgestyle: MappableStyle = Mappable("-", grouping=False)
    # pattern: MappableString = Mappable(None)  # TODO no Property yet

    width: MappableFloat = Mappable(1, grouping=False)
    baseline: MappableFloat = Mappable(0, grouping=False)  # TODO *is* this mappable?

    def _plot(self, split_gen, scales, orient):
        """
        Draw all bars on each Axes as a single ``PatchCollection``.

        When ``edgewidth`` is neither mapped by a scale nor set explicitly, the
        edge width is set to 10% of the narrowest bar's width in points across
        all Axes, capped at ``rcParams["patch.linewidth"]``.
        """
        ori_idx = ["x", "y"].index(orient)
        val_idx = ["y", "x"].index(orient)

        patches = defaultdict(list)
        for _, data, ax in split_gen():
            bars, _ = self._make_patches(data, scales, orient)
            patches[ax].extend(bars)

        collections = {}
        for ax, ax_patches in patches.items():

            col = mpl.collections.PatchCollection(ax_patches, match_original=True)
            col.sticky_edges[val_idx][:] = (0, np.inf)
            ax.add_collection(col, autolim=False)
            collections[ax] = col

            # Workaround for matplotlib autoscaling bug
            # https://github.com/matplotlib/matplotlib/issues/11898
            # https://github.com/matplotlib/matplotlib/issues/23129
            # The collection is empty when every bar on this Axes had a zero
            # or NaN value; np.vstack raises on an empty sequence.
            paths = col.get_paths()
            if paths:
                ax.update_datalim(np.vstack([path.vertices for path in paths]))

        if "edgewidth" not in scales and isinstance(self.edgewidth, Mappable):

            # View limits must be final before measuring bars in display space.
            for ax in collections:
                ax.autoscale_view()

            min_width = min(
                (
                    self._min_bar_width_points(ax, col, ori_idx)
                    for ax, col in collections.items()
                ),
                default=np.inf,
            )

            linewidth = min(.1 * min_width, mpl.rcParams["patch.linewidth"])
            for col in collections.values():
                col.set_linewidth(linewidth)

    @staticmethod
    def _min_bar_width_points(ax, col, ori_idx):
        """
        Return the narrowest bar in ``col`` along the orient axis, in points.

        ``ori_idx`` is 0 for ``orient="x"`` and 1 for ``orient="y"``. Returns
        ``np.inf`` when the collection holds no paths.
        """
        spans = [path.vertices[:, ori_idx] for path in col.get_paths()]
        if not spans:
            return np.inf
        lo = np.array([span.min() for span in spans])
        hi = np.array([span.max() for span in spans])

        # Transform.transform expects an (N, 2) array of (x, y) points. Put the
        # same coordinate in both columns and keep only the orient column of
        # the result. A (2, N) list like ``[values] * 2`` would be reshaped
        # row-major into mismatched pairs, silently skipping some bars.
        px_lo = ax.transData.transform(np.column_stack([lo, lo]))
        px_hi = ax.transData.transform(np.column_stack([hi, hi]))
        points = 72 / ax.figure.dpi * np.abs(px_hi - px_lo)
        return points[:, ori_idx].min()
|