66 lines
2.6 KiB
Python
66 lines
2.6 KiB
Python
"""Base module for statistical transformations."""
|
|
from __future__ import annotations
|
|
from collections.abc import Iterable
|
|
from dataclasses import dataclass
|
|
from typing import ClassVar, Any
|
|
import warnings
|
|
|
|
from typing import TYPE_CHECKING
|
|
if TYPE_CHECKING:
|
|
from pandas import DataFrame
|
|
from seaborn._core.groupby import GroupBy
|
|
from seaborn._core.scales import Scale
|
|
|
|
|
|
@dataclass
class Stat:
    """Base class for objects that apply statistical transformations."""

    # The class supports a partial-function application pattern. The object is
    # initialized with desired parameters and the result is a callable that
    # accepts and returns dataframes.

    # The statistical transformation logic should not add any state to the instance
    # beyond what is defined with the initialization parameters.

    # Subclasses can declare whether the orient dimension should be used in grouping
    # TODO consider whether this should be a parameter. Motivating example:
    # use the same KDE class for violin plots and univariate density estimation.
    # In the former case, we would expect separate densities for each unique
    # value on the orient axis, but we would not in the latter case.
    group_by_orient: ClassVar[bool] = False

    def _check_param_one_of(self, param: str, options: Iterable[Any]) -> None:
        """Raise when parameter value is not one of a specified set.

        Parameters
        ----------
        param
            Name of the attribute on this object whose value is validated.
        options
            Allowed values for that attribute.

        Raises
        ------
        ValueError
            If the attribute's current value is not among ``options``. The
            message lists every allowed option.

        """
        # Materialize once: a one-shot iterator would otherwise be exhausted by
        # the membership test before the error message could be assembled.
        choices = list(options)
        value = getattr(self, param)
        if value in choices:
            return

        # Render as "'a', 'b' or 'c'"; every option must appear in the message.
        quoted = [repr(x) for x in choices]
        if len(quoted) > 1:
            option_str = ", ".join(quoted[:-1]) + f" or {quoted[-1]}"
        else:
            # Zero or one option: nothing to join with "or".
            option_str = "".join(quoted)

        err = " ".join([
            f"The `{param}` parameter for `{self.__class__.__name__}` must be",
            f"one of {option_str}; not {value!r}.",
        ])
        raise ValueError(err)

    def _check_grouping_vars(
        self, param: str, data_vars: list[str], stacklevel: int = 2,
    ) -> None:
        """Warn if vars are named in parameter without being present in the data.

        Parameters
        ----------
        param
            Name of the attribute on this object that holds the variable names.
        data_vars
            Variable names that are defined in the data.
        stacklevel
            Forwarded to ``warnings.warn`` to control which frame is reported.

        """
        known = set(data_vars)
        # De-duplicate while keeping the order given in the parameter so that
        # the warning text is deterministic (set iteration order is not).
        undefined = [
            var for var in dict.fromkeys(getattr(self, param)) if var not in known
        ]
        if undefined:
            param = f"{self.__class__.__name__}.{param}"
            names = ", ".join(f"{x!r}" for x in undefined)
            msg = f"Undefined variable(s) passed for {param}: {names}."
            warnings.warn(msg, stacklevel=stacklevel)

    def __call__(
        self,
        data: DataFrame,
        groupby: GroupBy,
        orient: str,
        scales: dict[str, Scale],
    ) -> DataFrame:
        """Apply statistical transform to data subgroups and return combined result."""
        # The base implementation is the identity: data is returned unchanged.
        return data
|