2024-10-02 22:15:59 +04:00

286 lines
8.6 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar
import numpy as np
import matplotlib as mpl
from seaborn._marks.base import (
Mark,
Mappable,
MappableFloat,
MappableString,
MappableColor,
resolve_properties,
resolve_color,
document_properties,
)
@document_properties
@dataclass
class Path(Mark):
    """
    A mark connecting data points in the order they appear.

    See also
    --------
    Line : A mark connecting data points with sorting along the orientation axis.
    Paths : A faster but less-flexible mark for drawing many paths.

    Examples
    --------
    .. include:: ../docstrings/objects.Path.rst

    """
    color: MappableColor = Mappable("C0")
    alpha: MappableFloat = Mappable(1)
    linewidth: MappableFloat = Mappable(rc="lines.linewidth")
    linestyle: MappableString = Mappable(rc="lines.linestyle")
    marker: MappableString = Mappable(rc="lines.marker")
    pointsize: MappableFloat = Mappable(rc="lines.markersize")
    fillcolor: MappableColor = Mappable(depend="color")
    edgecolor: MappableColor = Mappable(depend="color")
    edgewidth: MappableFloat = Mappable(rc="lines.markeredgewidth")

    # Subclasses flip this to sort each subset along the orient variable.
    _sort: ClassVar[bool] = False

    def _plot(self, split_gen, scales, orient):
        """Add one Line2D artist to the relevant axes for each data subset."""
        # Missing values are requested from the splitter only when not sorting.
        for keys, data, ax in split_gen(keep_na=not self._sort):

            props, extra_kws = self._line2d_kws(keys, scales)

            if self._sort:
                # mergesort is stable, so ties keep their original order
                data = data.sort_values(orient, kind="mergesort")

            artist = mpl.lines.Line2D(
                data["x"].to_numpy(),
                data["y"].to_numpy(),
                **props,
                **extra_kws,
            )
            ax.add_line(artist)

    def _legend_artist(self, variables, value, scales):
        """Return an empty Line2D styled for `value` to serve as a legend handle."""
        keys = dict.fromkeys(variables, value)
        props, extra_kws = self._line2d_kws(keys, scales)
        return mpl.lines.Line2D([], [], **props, **extra_kws)

    def _line2d_kws(self, keys, scales):
        """
        Resolve the mark's properties into keyword arguments for Line2D.

        Returns a pair ``(props, artist_kws)``: the resolved properties keyed by
        their Line2D parameter names, and a copy of ``self.artist_kws`` with the
        capstyle workaround applied. The two are kept separate (rather than
        merged) so that a name present in both still raises at the call site.
        """
        vals = resolve_properties(self, keys, scales)
        vals["color"] = resolve_color(self, keys, scales=scales)
        vals["fillcolor"] = resolve_color(self, keys, prefix="fill", scales=scales)
        vals["edgecolor"] = resolve_color(self, keys, prefix="edge", scales=scales)

        artist_kws = self.artist_kws.copy()
        self._handle_capstyle(artist_kws, vals)

        props = {
            "color": vals["color"],
            "linewidth": vals["linewidth"],
            "linestyle": vals["linestyle"],
            "marker": vals["marker"],
            "markersize": vals["pointsize"],
            "markerfacecolor": vals["fillcolor"],
            "markeredgecolor": vals["edgecolor"],
            "markeredgewidth": vals["edgewidth"],
        }
        return props, artist_kws

    def _handle_capstyle(self, kws, vals):
        """Mirror the solid capstyle onto `kws["dash_capstyle"]` where needed."""
        # Work around for this matplotlib issue:
        # https://github.com/matplotlib/matplotlib/issues/23437
        if vals["linestyle"][1] is not None:
            return
        kws["dash_capstyle"] = kws.get(
            "solid_capstyle", mpl.rcParams["lines.solid_capstyle"]
        )
@document_properties
@dataclass
class Line(Path):
    """
    A mark connecting data points with sorting along the orientation axis.

    See also
    --------
    Path : A mark connecting data points in the order they appear.
    Lines : A faster but less-flexible mark for drawing many lines.

    Examples
    --------
    .. include:: ../docstrings/objects.Line.rst

    """
    # The only difference from Path: each subset is sorted along the orient
    # variable before it is drawn.
    _sort: ClassVar[bool] = True
@document_properties
@dataclass
class Paths(Mark):
    """
    A faster but less-flexible mark for drawing many paths.

    See also
    --------
    Path : A mark connecting data points in the order they appear.

    Examples
    --------
    .. include:: ../docstrings/objects.Paths.rst

    """
    color: MappableColor = Mappable("C0")
    alpha: MappableFloat = Mappable(1)
    linewidth: MappableFloat = Mappable(rc="lines.linewidth")
    linestyle: MappableString = Mappable(rc="lines.linestyle")

    # Subclasses flip this to sort each subset along the orient variable.
    _sort: ClassVar[bool] = False

    def __post_init__(self):

        # LineCollection artists have a capstyle property but don't source its value
        # from the rc, so we do that manually here. Unfortunately, because we add
        # only one LineCollection, we have to use the same capstyle for all lines
        # even when they are dashed. It's a slight inconsistency, but looks fine IMO.
        self.artist_kws.setdefault("capstyle", mpl.rcParams["lines.solid_capstyle"])

    def _plot(self, split_gen, scales, orient):
        """Draw all subsets that share an axes as a single LineCollection."""
        line_data = {}
        for keys, data, ax in split_gen(keep_na=not self._sort):

            if ax not in line_data:
                line_data[ax] = {
                    "segments": [],
                    "colors": [],
                    "linewidths": [],
                    "linestyles": [],
                }

            segments = self._setup_segments(data, orient)
            line_data[ax]["segments"].extend(segments)
            n = len(segments)

            vals = resolve_properties(self, keys, scales)
            vals["color"] = resolve_color(self, keys, scales=scales)

            # Every segment produced from this subset shares the subset's style
            line_data[ax]["colors"].extend([vals["color"]] * n)
            line_data[ax]["linewidths"].extend([vals["linewidth"]] * n)
            line_data[ax]["linestyles"].extend([vals["linestyle"]] * n)

        for ax, ax_data in line_data.items():
            lines = mpl.collections.LineCollection(**ax_data, **self.artist_kws)
            # Handle datalim update manually
            # https://github.com/matplotlib/matplotlib/issues/23129
            ax.add_collection(lines, autolim=False)
            if ax_data["segments"]:
                xy = np.concatenate(ax_data["segments"])
                ax.update_datalim(xy)

    def _legend_artist(self, variables, value, scales):
        """Return an empty Line2D styled for `value` to serve as a legend handle."""
        keys = {v: value for v in variables}
        vals = resolve_properties(self, keys, scales)
        # Resolve the color the same way `_plot` does, so the legend handle gets
        # the same color treatment as the drawn lines instead of the raw
        # property value (presumably this is where alpha is applied — confirm
        # against resolve_color).
        vals["color"] = resolve_color(self, keys, scales=scales)

        # Line2D has no single `capstyle` parameter; apply the collection-wide
        # value to both the solid and the dashed variants.
        artist_kws = self.artist_kws.copy()
        capstyle = artist_kws.pop("capstyle")
        artist_kws["solid_capstyle"] = capstyle
        artist_kws["dash_capstyle"] = capstyle

        return mpl.lines.Line2D(
            [], [],
            color=vals["color"],
            linewidth=vals["linewidth"],
            linestyle=vals["linestyle"],
            **artist_kws,
        )

    def _setup_segments(self, data, orient):
        """Return a one-element list holding the subset's (x, y) array."""
        if self._sort:
            # mergesort is stable, so ties keep their original order
            data = data.sort_values(orient, kind="mergesort")

        # Column stack to avoid block consolidation
        xy = np.column_stack([data["x"], data["y"]])

        return [xy]
@document_properties
@dataclass
class Lines(Paths):
    """
    A faster but less-flexible mark for drawing many lines.

    See also
    --------
    Line : A mark connecting data points with sorting along the orientation axis.

    Examples
    --------
    .. include:: ../docstrings/objects.Lines.rst

    """
    # The only difference from Paths: each subset is sorted along the orient
    # variable before its segment is built.
    _sort: ClassVar[bool] = True
@document_properties
@dataclass
class Range(Paths):
    """
    An oriented line mark drawn between min/max values.

    Examples
    --------
    .. include:: ../docstrings/objects.Range.rst

    """
    def _setup_segments(self, data, orient):
        """Return one two-point (x, y) array per distinct value of `orient`."""
        # TODO better checks on what variables we have
        # TODO what if only one exist?
        val = {"x": "y", "y": "x"}[orient]
        lo, hi = f"{val}min", f"{val}max"

        if {lo, hi}.isdisjoint(data.columns):
            # Neither bound is present: derive both from the value variable
            bounds = {lo: (val, "min"), hi: (val, "max")}
            data = data.groupby(orient).agg(**bounds).reset_index()

        # Reshape to long form so each position has a row for each bound
        long_form = data[[orient, lo, hi]].melt(orient, value_name=val)
        points = long_form[["x", "y"]]

        return [group.to_numpy() for _, group in points.groupby(orient)]
@document_properties
@dataclass
class Dash(Paths):
    """
    A line mark drawn as an oriented segment for each datapoint.

    Examples
    --------
    .. include:: ../docstrings/objects.Dash.rst

    """
    width: MappableFloat = Mappable(.8, grouping=False)

    def _setup_segments(self, data, orient):
        """Return an (n, 2, 2) array: a start and end point for every row."""
        ori = ["x", "y"].index(orient)

        # Both endpoints start at the datapoint itself ...
        centers = data[["x", "y"]].to_numpy().astype(float)
        segments = np.stack([centers, centers], axis=1)

        # ... and are then pushed half a width apart along the orient axis.
        half_width = data["width"] / 2
        segments[:, 0, ori] -= half_width
        segments[:, 1, ori] += half_width

        return segments