201 lines
6.5 KiB
Python
201 lines
6.5 KiB
Python
from __future__ import annotations
|
|
from dataclasses import dataclass
|
|
|
|
import numpy as np
|
|
import matplotlib as mpl
|
|
|
|
from seaborn._marks.base import (
|
|
Mark,
|
|
Mappable,
|
|
MappableBool,
|
|
MappableFloat,
|
|
MappableString,
|
|
MappableColor,
|
|
MappableStyle,
|
|
resolve_properties,
|
|
resolve_color,
|
|
document_properties,
|
|
)
|
|
|
|
from typing import TYPE_CHECKING
|
|
if TYPE_CHECKING:
|
|
from typing import Any
|
|
from matplotlib.artist import Artist
|
|
from seaborn._core.scales import Scale
|
|
|
|
|
|
class DotBase(Mark):
    """Shared machinery for marks that draw each datum as a marker path.

    Subclasses (``Dot`` and ``Dots``) extend ``_resolve_properties`` to add
    the ``facecolor``, ``edgecolor`` and ``linewidth`` entries that
    ``_plot`` and ``_legend_artist`` pass to ``PathCollection``.
    """

    def _resolve_paths(self, data):
        """Return the transformed marker path(s) for the resolved ``marker``.

        If ``data["marker"]`` is a single ``MarkerStyle`` (single-dot case),
        one path is returned. Otherwise it is iterated and a list with one
        path per entry is returned, computing each distinct marker's path
        only once.
        """
        marker = data["marker"]

        def get_transformed_path(m):
            return m.get_path().transformed(m.get_transform())

        if isinstance(marker, mpl.markers.MarkerStyle):
            return get_transformed_path(marker)

        # Many rows may share one marker object; compute its path only once
        path_cache = {}
        paths = []
        for m in marker:
            if m not in path_cache:
                path_cache[m] = get_transformed_path(m)
            paths.append(path_cache[m])
        return paths

    def _resolve_properties(self, data, scales):
        """Resolve mark properties and add the dot-specific derived entries.

        Adds ``path`` (see ``_resolve_paths``) and ``size`` (``pointsize``
        squared, as passed to ``PathCollection(sizes=...)``), and turns
        ``fill`` off wherever the marker itself reports it cannot be filled.

        ``data`` is either a dict describing a single dot (as built by
        ``_legend_artist``) or, presumably, a DataFrame with one row per dot.
        """
        resolved = resolve_properties(self, data, scales)
        resolved["path"] = self._resolve_paths(resolved)
        resolved["size"] = resolved["pointsize"] ** 2

        if isinstance(data, dict):  # Properties for single dot
            filled_marker = resolved["marker"].is_filled()
        else:
            # A boolean array guarantees the product below is elementwise;
            # with a plain list, a Python-scalar ``fill`` would instead
            # trigger sequence repetition (``False * [...] == []``).
            filled_marker = np.array(
                [m.is_filled() for m in resolved["marker"]], dtype=bool
            )

        resolved["fill"] = resolved["fill"] * filled_marker

        return resolved

    def _plot(self, split_gen, scales, orient):
        """Add one ``PathCollection`` to each axes subset from ``split_gen``.

        ``orient`` is not used by this implementation.
        """
        # TODO Not backcompat with allowed (but nonfunctional) univariate plots
        # (That should be solved upstream by defaulting to "" for unset x/y?)
        # (Be mindful of xmin/xmax, etc!)

        for _, data, ax in split_gen():

            offsets = np.column_stack([data["x"], data["y"]])
            props = self._resolve_properties(data, scales)

            # Offsets are placed in data coordinates (ax.transData) while the
            # marker paths use an identity transform.
            points = mpl.collections.PathCollection(
                offsets=offsets,
                paths=props["path"],
                sizes=props["size"],
                facecolors=props["facecolor"],
                edgecolors=props["edgecolor"],
                linewidths=props["linewidth"],
                linestyles=props["edgestyle"],
                transOffset=ax.transData,
                transform=mpl.transforms.IdentityTransform(),
                **self.artist_kws,
            )
            ax.add_collection(points)

    def _legend_artist(
        self, variables: list[str], value: Any, scales: dict[str, Scale],
    ) -> Artist:
        """Return a single-dot ``PathCollection`` representing ``value``.

        Every name in ``variables`` is mapped to ``value`` and resolved as a
        single dot (dict input to ``_resolve_properties``). Each scalar
        property is then wrapped in a one-element list for the collection.
        """
        key = {v: value for v in variables}
        res = self._resolve_properties(key, scales)

        return mpl.collections.PathCollection(
            paths=[res["path"]],
            sizes=[res["size"]],
            facecolors=[res["facecolor"]],
            edgecolors=[res["edgecolor"]],
            linewidths=[res["linewidth"]],
            linestyles=[res["edgestyle"]],
            transform=mpl.transforms.IdentityTransform(),
            **self.artist_kws,
        )
|
|
|
|
|
|
@document_properties
@dataclass
class Dot(DotBase):
    """
    A mark suitable for dot plots or less-dense scatterplots.

    See also
    --------
    Dots : A dot mark defined by strokes to better handle overplotting.

    Examples
    --------
    .. include:: ../docstrings/objects.Dot.rst

    """
    marker: MappableString = Mappable("o", grouping=False)
    pointsize: MappableFloat = Mappable(6, grouping=False)  # TODO rcParam?
    stroke: MappableFloat = Mappable(.75, grouping=False)  # TODO rcParam?
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(1, grouping=False)
    fill: MappableBool = Mappable(True, grouping=False)
    edgecolor: MappableColor = Mappable(depend="color", grouping=False)
    edgealpha: MappableFloat = Mappable(depend="alpha", grouping=False)
    edgewidth: MappableFloat = Mappable(.5, grouping=False)  # TODO rcParam?
    edgestyle: MappableStyle = Mappable("-", grouping=False)

    def _resolve_properties(self, data, scales):
        """Derive outline width/color and face color from each dot's ``fill``.

        Filled dots:
            Outline uses ``edgewidth`` and the edge color; face uses ``color``.

        Unfilled dots:
            Outline uses ``stroke`` and ``color``; face keeps ``color`` but
            with its alpha set to zero.
        """
        resolved = super()._resolve_properties(data, scales)
        filled = resolved["fill"]

        main_stroke = resolved["stroke"]
        edge_stroke = resolved["edgewidth"]
        resolved["linewidth"] = np.where(filled, edge_stroke, main_stroke)

        main_color = resolve_color(self, data, "", scales)
        edge_color = resolve_color(self, data, "edge", scales)

        # Per-dot fill values need a trailing axis to broadcast against the
        # RGBA rows in np.where. np.ndim (unlike np.isscalar) also treats
        # 0-d arrays as scalars, so they are not indexed with [:, None].
        filled_rgba = filled[:, None] if np.ndim(filled) else filled
        resolved["edgecolor"] = np.where(filled_rgba, edge_color, main_color)

        # Unfilled dots get a fully transparent face; the 1-d ``filled`` is
        # used directly, so no squeeze is needed to undo the expansion above.
        if isinstance(main_color, tuple):
            # TODO handle this in resolve_color
            main_color = tuple([*main_color[:3], main_color[3] * filled])
        else:
            main_color = np.c_[main_color[:, :3], main_color[:, 3] * filled]
        resolved["facecolor"] = main_color

        return resolved
|
|
|
|
|
|
@document_properties
@dataclass
class Dots(DotBase):
    """
    A dot mark defined by strokes to better handle overplotting.

    See also
    --------
    Dot : A mark suitable for dot plots or less-dense scatterplots.

    Examples
    --------
    .. include:: ../docstrings/objects.Dots.rst

    """
    # TODO retype marker as MappableMarker
    marker: MappableString = Mappable(rc="scatter.marker", grouping=False)
    pointsize: MappableFloat = Mappable(4, grouping=False)  # TODO rcParam?
    stroke: MappableFloat = Mappable(.75, grouping=False)  # TODO rcParam?
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(1, grouping=False)  # TODO auto alpha?
    fill: MappableBool = Mappable(True, grouping=False)
    fillcolor: MappableColor = Mappable(depend="color", grouping=False)
    fillalpha: MappableFloat = Mappable(.2, grouping=False)

    def _resolve_properties(self, data, scales):
        """Use ``stroke``/``color`` for the outline and ``fillcolor`` for the face.

        The face alpha is multiplied by ``fill``, so dots whose fill is off
        get a transparent face. When no ``edgestyle`` was resolved, the
        outline defaults to ``(0, None)`` (offset 0, no dash pattern).
        """
        resolved = super()._resolve_properties(data, scales)
        resolved["linewidth"] = resolved.pop("stroke")
        resolved["facecolor"] = resolve_color(self, data, "fill", scales)
        resolved["edgecolor"] = resolve_color(self, data, "", scales)
        resolved.setdefault("edgestyle", (0, None))

        fc = resolved["facecolor"]
        fill = resolved["fill"]
        if isinstance(fc, tuple):
            resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * fill
        else:
            # Build a new RGBA array instead of scaling the alpha column in
            # place, so the array returned by resolve_color is never mutated
            # (same approach as Dot).
            resolved["facecolor"] = np.c_[fc[:, :3], fc[:, 3] * fill]

        return resolved
|