2024-10-02 22:15:59 +04:00

131 lines
3.6 KiB
Python

from __future__ import annotations
from dataclasses import dataclass
from typing import ClassVar, Callable
import pandas as pd
from pandas import DataFrame
from seaborn._core.scales import Scale
from seaborn._core.groupby import GroupBy
from seaborn._stats.base import Stat
from seaborn._statistics import (
EstimateAggregator,
WeightedAggregator,
)
from seaborn._core.typing import Vector
@dataclass
class Agg(Stat):
    """
    Aggregate data along the value axis using given method.

    Parameters
    ----------
    func : str or callable
        Name of a :class:`pandas.Series` method or a vector -> scalar function.

    See Also
    --------
    objects.Est : Aggregation with error bars.

    Examples
    --------
    .. include:: ../docstrings/objects.Agg.rst

    """
    func: str | Callable[[Vector], float] = "mean"

    # ClassVar, so not a dataclass field. Presumably tells the plotting
    # machinery to also group by the orient variable; it is consumed outside
    # this class (TODO confirm against Stat/Plot).
    group_by_orient: ClassVar[bool] = True

    def __call__(
        self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
    ) -> DataFrame:
        # The value axis is the one orthogonal to ``orient``. Index the mapping
        # (rather than using ``.get``) so an unexpected orient raises a clear
        # KeyError here, instead of silently aggregating a ``None`` column.
        # This matches how Est resolves the same axis.
        var = {"x": "y", "y": "x"}[orient]
        res = (
            groupby
            .agg(data, {var: self.func})
            # Discard groups whose aggregated value is missing.
            .dropna(subset=[var])
            .reset_index(drop=True)
        )
        return res
@dataclass
class Est(Stat):
    """
    Calculate a point estimate and error bar interval.

    For more information about the various `errorbar` choices, see the
    :doc:`errorbar tutorial </tutorial/error_bars>`.

    Additional variables:

    - **weight**: When passed to a layer that uses this stat, a weighted estimate
      will be computed. Note that use of weights currently limits the choice of
      function and error bar method to `"mean"` and `"ci"`, respectively.

    Parameters
    ----------
    func : str or callable
        Name of a :class:`numpy.ndarray` method or a vector -> scalar function.
    errorbar : str, (str, float) tuple, or callable
        Name of errorbar method (one of "ci", "pi", "se" or "sd"), or a tuple
        with a method name and a level parameter, or a function that maps from a
        vector to a (min, max) interval.
    n_boot : int
        Number of bootstrap samples to draw for "ci" errorbars.
    seed : int or None
        Seed for the PRNG used to draw bootstrap samples.

    Examples
    --------
    .. include:: ../docstrings/objects.Est.rst

    """
    func: str | Callable[[Vector], float] = "mean"
    # The annotation includes the callable form the docstring promises; the
    # previous annotation covered only the str / (str, float) forms.
    errorbar: (
        str | tuple[str, float] | Callable[[Vector], tuple[float, float]]
    ) = ("ci", 95)
    n_boot: int = 1000
    seed: int | None = None

    # ClassVar, so not a dataclass field. Presumably tells the plotting
    # machinery to also group by the orient variable (TODO confirm against
    # Stat/Plot).
    group_by_orient: ClassVar[bool] = True

    def _process(
        self,
        data: DataFrame,
        var: str,
        estimator: EstimateAggregator | WeightedAggregator,
    ) -> DataFrame:
        """Run ``estimator`` on one group and wrap its result as a one-row frame."""
        # Needed because GroupBy.apply assumes func is DataFrame -> DataFrame
        # which we could probably make more general to allow Series return
        res = estimator(data, var)
        return pd.DataFrame([res])

    def __call__(
        self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
    ) -> DataFrame:

        boot_kws = {"n_boot": self.n_boot, "seed": self.seed}
        # A "weight" column switches to the weighted aggregator. Per the class
        # docstring, that restricts func/errorbar to "mean"/"ci".
        if "weight" in data:
            engine = WeightedAggregator(self.func, self.errorbar, **boot_kws)
        else:
            engine = EstimateAggregator(self.func, self.errorbar, **boot_kws)

        # The value axis is the one orthogonal to ``orient``.
        var = {"x": "y", "y": "x"}[orient]
        res = (
            groupby
            .apply(data, self._process, var, engine)
            .dropna(subset=[var])
            .reset_index(drop=True)
        )

        # Where an interval bound is missing, collapse it onto the point
        # estimate. The "<var>min"/"<var>max" columns are assumed to come
        # from the aggregator's output (TODO confirm in seaborn._statistics).
        res = res.fillna({f"{var}min": res[var], f"{var}max": res[var]})

        return res
@dataclass
class Rolling(Stat):
    """
    Unimplemented stub.

    The name suggests a rolling-window stat is intended (TODO confirm). As
    written, it defines no fields and calling it does nothing.
    """

    def __call__(self, data, groupby, orient, scales):
        # Placeholder body: all arguments are ignored and None is returned.
        return None