"""Plotting functions for linear models (broadly construed)."""
import copy
from textwrap import dedent
import warnings
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

try:
    import statsmodels
    assert statsmodels
    _has_statsmodels = True
except ImportError:
    _has_statsmodels = False

from . import utils
from . import algorithms as algo
from .axisgrid import FacetGrid, _facet_docs


__all__ = ["lmplot", "regplot", "residplot"]


class _LinearPlotter:
    """Base class for plotting relational data in tidy format.

    To get anything useful done you'll have to inherit from this, but setup
    code that can be abstracted out should be put here.

    """
    def establish_variables(self, data, **kws):
        """Extract variables from data or use directly."""
        self.data = data

        # Validate the inputs
        any_strings = any([isinstance(v, str) for v in kws.values()])
        if any_strings and data is None:
            raise ValueError("Must pass `data` if using named variables.")

        # Set the variables
        for var, val in kws.items():
            if isinstance(val, str):
                vector = data[val]
            elif isinstance(val, list):
                vector = np.asarray(val)
            else:
                vector = val
            if vector is not None and vector.shape != (1,):
                vector = np.squeeze(vector)
            if np.ndim(vector) > 1:
                err = "regplot inputs must be 1d"
                raise ValueError(err)
            setattr(self, var, vector)

    def dropna(self, *vars):
        """Remove observations with missing data."""
        vals = [getattr(self, var) for var in vars]
        vals = [v for v in vals if v is not None]
        not_na = np.all(np.column_stack([pd.notnull(v) for v in vals]), axis=1)
        for var in vars:
            val = getattr(self, var)
            if val is not None:
                setattr(self, var, val[not_na])

    def plot(self, ax):
        raise NotImplementedError


class _RegressionPlotter(_LinearPlotter):
    """Plotter for numeric independent variables with regression model.

    This does the computations and drawing for the `regplot` function, and
    is thus also used indirectly by `lmplot`.
    """
    def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
                 x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
                 units=None, seed=None, order=1, logistic=False, lowess=False,
                 robust=False, logx=False, x_partial=None, y_partial=None,
                 truncate=False, dropna=True, x_jitter=None, y_jitter=None,
                 color=None, label=None):

        # Set member attributes
        self.x_estimator = x_estimator
        self.ci = ci
        self.x_ci = ci if x_ci == "ci" else x_ci
        self.n_boot = n_boot
        self.seed = seed
        self.scatter = scatter
        self.fit_reg = fit_reg
        self.order = order
        self.logistic = logistic
        self.lowess = lowess
        self.robust = robust
        self.logx = logx
        self.truncate = truncate
        self.x_jitter = x_jitter
        self.y_jitter = y_jitter
        self.color = color
        self.label = label

        # Validate the regression options:
        if sum((order > 1, logistic, robust, lowess, logx)) > 1:
            raise ValueError("Mutually exclusive regression options.")

        # Extract the data vals from the arguments or passed dataframe
        self.establish_variables(data, x=x, y=y, units=units,
                                 x_partial=x_partial, y_partial=y_partial)

        # Drop null observations
        if dropna:
            self.dropna("x", "y", "units", "x_partial", "y_partial")

        # Regress nuisance variables out of the data
        if self.x_partial is not None:
            self.x = self.regress_out(self.x, self.x_partial)
        if self.y_partial is not None:
            self.y = self.regress_out(self.y, self.y_partial)

        # Possibly bin the predictor variable, which implies a point estimate
        if x_bins is not None:
            self.x_estimator = np.mean if x_estimator is None else x_estimator
            x_discrete, x_bins = self.bin_predictor(x_bins)
            self.x_discrete = x_discrete
        else:
            self.x_discrete = self.x

        # Disable regression in case of singleton inputs
        if len(self.x) <= 1:
            self.fit_reg = False

        # Save the range of the x variable for the grid later
        if self.fit_reg:
            self.x_range = self.x.min(), self.x.max()

    @property
    def scatter_data(self):
        """Data where each observation is a point."""
        x_j = self.x_jitter
        if x_j is None:
            x = self.x
        else:
            x = self.x + np.random.uniform(-x_j, x_j, len(self.x))

        y_j = self.y_jitter
        if y_j is None:
            y = self.y
        else:
            y = self.y + np.random.uniform(-y_j, y_j, len(self.y))

        return x, y

    @property
    def estimate_data(self):
        """Data with a point estimate and CI for each discrete x value."""
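        # For each unique value of the (possibly binned) predictor, the y
        # observations are reduced with `x_estimator`, and the interval is
        # either +/- one SD (x_ci="sd") or a bootstrap CI of the estimate.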
        x, y = self.x_discrete, self.y
        vals = sorted(np.unique(x))
        points, cis = [], []

        for val in vals:

            # Get the point estimate of the y variable
            _y = y[x == val]
            est = self.x_estimator(_y)
            points.append(est)

            # Compute the confidence interval for this estimate
            if self.x_ci is None:
                cis.append(None)
            else:
                units = None
                if self.x_ci == "sd":
                    sd = np.std(_y)
                    _ci = est - sd, est + sd
                else:
                    if self.units is not None:
                        units = self.units[x == val]
                    boots = algo.bootstrap(_y,
                                           func=self.x_estimator,
                                           n_boot=self.n_boot,
                                           units=units,
                                           seed=self.seed)
                    _ci = utils.ci(boots, self.x_ci)
                cis.append(_ci)

        return vals, points, cis

    def _check_statsmodels(self):
        """Check whether statsmodels is installed if any boolean options require it."""
        options = "logistic", "robust", "lowess"
        err = "`{}=True` requires statsmodels, an optional dependency, to be installed."
        for option in options:
            if getattr(self, option) and not _has_statsmodels:
                raise RuntimeError(err.format(option))

    def fit_regression(self, ax=None, x_range=None, grid=None):
        """Fit the regression model."""
        self._check_statsmodels()

        # Create the grid for the regression
        if grid is None:
            if self.truncate:
                x_min, x_max = self.x_range
            else:
                if ax is None:
                    x_min, x_max = x_range
                else:
                    x_min, x_max = ax.get_xlim()
            grid = np.linspace(x_min, x_max, 100)
        ci = self.ci

        # Fit the regression
        if self.order > 1:
            yhat, yhat_boots = self.fit_poly(grid, self.order)
        elif self.logistic:
            from statsmodels.genmod.generalized_linear_model import GLM
            from statsmodels.genmod.families import Binomial
            yhat, yhat_boots = self.fit_statsmodels(grid, GLM,
                                                    family=Binomial())
        elif self.lowess:
            ci = None
            grid, yhat = self.fit_lowess()
        elif self.robust:
            from statsmodels.robust.robust_linear_model import RLM
            yhat, yhat_boots = self.fit_statsmodels(grid, RLM)
        elif self.logx:
            yhat, yhat_boots = self.fit_logx(grid)
        else:
            yhat, yhat_boots = self.fit_fast(grid)

        # Compute the confidence interval at each grid point
        if ci is None:
            err_bands = None
        else:
            err_bands = utils.ci(yhat_boots, ci, axis=0)

        return grid, yhat, err_bands

    def fit_fast(self, grid):
        """Low-level regression and prediction using linear algebra."""
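        # Ordinary least squares via the Moore-Penrose pseudoinverse:
        # pinv(X) @ y is the coefficient vector minimizing ||X @ beta - y||.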
        def reg_func(_x, _y):
            return np.linalg.pinv(_x).dot(_y)

        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]
        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots

    def fit_poly(self, grid, order):
        """Regression using numpy polyfit for higher-order trends."""
        def reg_func(_x, _y):
            return np.polyval(np.polyfit(_x, _y, order), grid)

        x, y = self.x, self.y
        yhat = reg_func(x, y)
        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(x, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots

    def fit_statsmodels(self, grid, model, **kwargs):
        """More general regression function using statsmodels objects."""
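        # `model` is a statsmodels estimator class (e.g. GLM or RLM). It is
        # fit on a design matrix with an explicit intercept column; a failed
        # fit yields NaN predictions so that bootstrap iterations degrade
        # gracefully instead of raising.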
        import statsmodels.tools.sm_exceptions as sme
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]

        def reg_func(_x, _y):
            err_classes = (sme.PerfectSeparationError,)
            try:
                with warnings.catch_warnings():
                    if hasattr(sme, "PerfectSeparationWarning"):
                        # statsmodels>=0.14.0
                        warnings.simplefilter("error", sme.PerfectSeparationWarning)
                        err_classes = (*err_classes, sme.PerfectSeparationWarning)
                    yhat = model(_y, _x, **kwargs).fit().predict(grid)
            except err_classes:
                yhat = np.empty(len(grid))
                yhat.fill(np.nan)
            return yhat

        yhat = reg_func(X, y)
        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots

    def fit_lowess(self):
        """Fit a locally-weighted regression, which returns its own grid."""
        from statsmodels.nonparametric.smoothers_lowess import lowess
        grid, yhat = lowess(self.y, self.x).T
        return grid, yhat

    def fit_logx(self, grid):
        """Fit the model in log-space."""
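        # The model is linear in log(x): y ~ b0 + b1 * log(x). Both the
        # design matrix and the prediction grid get a logged x column, so
        # the fitted curve can be drawn in the original input space.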
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), np.log(grid)]

        def reg_func(_x, _y):
            _x = np.c_[_x[:, 0], np.log(_x[:, 1])]
            return np.linalg.pinv(_x).dot(_y)

        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots

    def bin_predictor(self, bins):
        """Discretize a predictor by assigning each value to its closest bin."""
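        # For example, a scalar bins=3 places bin centers at the 25th, 50th,
        # and 75th percentiles of x, and each observation is then snapped to
        # the nearest center.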
        x = np.asarray(self.x)
        if np.isscalar(bins):
            percentiles = np.linspace(0, 100, bins + 2)[1:-1]
            bins = np.percentile(x, percentiles)
        else:
            bins = np.ravel(bins)

        dist = np.abs(np.subtract.outer(x, bins))
        x_binned = bins[np.argmin(dist, axis=1)].ravel()

        return x_binned, bins

    def regress_out(self, a, b):
        """Regress b from a keeping a's original mean."""
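        # Center both variables, project a onto the column space of b, and
        # subtract that projection: a' = a - b @ pinv(b) @ a. Adding the
        # original mean back keeps a' on the same scale as a.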
        a_mean = a.mean()
        a = a - a_mean
        b = b - b.mean()
        b = np.c_[b]
        a_prime = a - b.dot(np.linalg.pinv(b).dot(a))
        return np.asarray(a_prime + a_mean).reshape(a.shape)

    def plot(self, ax, scatter_kws, line_kws):
        """Draw the full plot."""
        # Insert the plot label into the correct set of keyword arguments
        if self.scatter:
            scatter_kws["label"] = self.label
        else:
            line_kws["label"] = self.label

        # Use the current color cycle state as a default
        if self.color is None:
            lines, = ax.plot([], [])
            color = lines.get_color()
            lines.remove()
        else:
            color = self.color

        # Ensure that color is hex to avoid matplotlib weirdness
        color = mpl.colors.rgb2hex(mpl.colors.colorConverter.to_rgb(color))

        # Let color in keyword arguments override overall plot color
        scatter_kws.setdefault("color", color)
        line_kws.setdefault("color", color)

        # Draw the constituent plots
        if self.scatter:
            self.scatterplot(ax, scatter_kws)

        if self.fit_reg:
            self.lineplot(ax, line_kws)

        # Label the axes
        if hasattr(self.x, "name"):
            ax.set_xlabel(self.x.name)
        if hasattr(self.y, "name"):
            ax.set_ylabel(self.y.name)

    def scatterplot(self, ax, kws):
        """Draw the data."""
        # Treat the line-based markers specially, explicitly setting a larger
        # linewidth than is provided by the seaborn style defaults.
        # This would ideally be handled better in matplotlib (i.e., by
        # distinguishing between edgewidth for solid glyphs and linewidth for
        # line glyphs), but this should do for now.
        line_markers = ["1", "2", "3", "4", "+", "x", "|", "_"]
        if self.x_estimator is None:
            if "marker" in kws and kws["marker"] in line_markers:
                lw = mpl.rcParams["lines.linewidth"]
            else:
                lw = mpl.rcParams["lines.markeredgewidth"]
            kws.setdefault("linewidths", lw)
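
            # Only set a default alpha when the color spec does not already
            # carry per-point transparency (an RGBA array with a 4th column).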
            if not hasattr(kws["color"], "shape") or kws["color"].shape[1] < 4:
                kws.setdefault("alpha", .8)

            x, y = self.scatter_data
            ax.scatter(x, y, **kws)
        else:
            # TODO abstraction
            ci_kws = {"color": kws["color"]}
            if "alpha" in kws:
                ci_kws["alpha"] = kws["alpha"]
            ci_kws["linewidth"] = mpl.rcParams["lines.linewidth"] * 1.75
            kws.setdefault("s", 50)

            xs, ys, cis = self.estimate_data
            if [ci for ci in cis if ci is not None]:
                for x, ci in zip(xs, cis):
                    ax.plot([x, x], ci, **ci_kws)
            ax.scatter(xs, ys, **kws)

    def lineplot(self, ax, kws):
        """Draw the model."""
        # Fit the regression model
        grid, yhat, err_bands = self.fit_regression(ax)
        edges = grid[0], grid[-1]

        # Set the default aesthetics
        fill_color = kws["color"]
        lw = kws.pop("lw", mpl.rcParams["lines.linewidth"] * 1.5)
        kws.setdefault("linewidth", lw)

        # Draw the regression line and confidence interval
        line, = ax.plot(grid, yhat, **kws)
        if not self.truncate:
            line.sticky_edges.x[:] = edges  # Prevent mpl from adding margin
        if err_bands is not None:
            ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15)


_regression_docs = dict(

    model_api=dedent("""\
    There are a number of mutually exclusive options for estimating the
    regression model. See the :ref:`tutorial <regression_tutorial>` for more
    information.\
    """),
    regplot_vs_lmplot=dedent("""\
    The :func:`regplot` and :func:`lmplot` functions are closely related, but
    the former is an axes-level function while the latter is a figure-level
    function that combines :func:`regplot` and :class:`FacetGrid`.\
    """),
    x_estimator=dedent("""\
    x_estimator : callable that maps vector -> scalar, optional
        Apply this function to each unique value of ``x`` and plot the
        resulting estimate. This is useful when ``x`` is a discrete variable.
        If ``x_ci`` is given, this estimate will be bootstrapped and a
        confidence interval will be drawn.\
    """),
    x_bins=dedent("""\
    x_bins : int or vector, optional
        Bin the ``x`` variable into discrete bins and then estimate the central
        tendency and a confidence interval. This binning only influences how
        the scatterplot is drawn; the regression is still fit to the original
        data. This parameter is interpreted either as the number of
        evenly-sized (not necessarily evenly-spaced) bins or the positions of
        the bin centers. When this parameter is used, it implies that the
        default of ``x_estimator`` is ``numpy.mean``.\
    """),
    x_ci=dedent("""\
    x_ci : "ci", "sd", int in [0, 100] or None, optional
        Size of the confidence interval used when plotting a central tendency
        for discrete values of ``x``. If ``"ci"``, defer to the value of the
        ``ci`` parameter. If ``"sd"``, skip bootstrapping and show the
        standard deviation of the observations in each bin.\
    """),
    scatter=dedent("""\
    scatter : bool, optional
        If ``True``, draw a scatterplot with the underlying observations (or
        the ``x_estimator`` values).\
    """),
    fit_reg=dedent("""\
    fit_reg : bool, optional
        If ``True``, estimate and plot a regression model relating the ``x``
        and ``y`` variables.\
    """),
    ci=dedent("""\
    ci : int in [0, 100] or None, optional
        Size of the confidence interval for the regression estimate. This will
        be drawn using translucent bands around the regression line. The
        confidence interval is estimated using a bootstrap; for large
        datasets, it may be advisable to avoid that computation by setting
        this parameter to None.\
    """),
    n_boot=dedent("""\
    n_boot : int, optional
        Number of bootstrap resamples used to estimate the ``ci``. The default
        value attempts to balance time and stability; you may want to increase
        this value for "final" versions of plots.\
    """),
    units=dedent("""\
    units : variable name in ``data``, optional
        If the ``x`` and ``y`` observations are nested within sampling units,
        those can be specified here. This will be taken into account when
        computing the confidence intervals by performing a multilevel bootstrap
        that resamples both units and observations (within unit). This does not
        otherwise influence how the regression is estimated or drawn.\
    """),
    seed=dedent("""\
    seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
        Seed or random number generator for reproducible bootstrapping.\
    """),
    order=dedent("""\
    order : int, optional
        If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a
        polynomial regression.\
    """),
    logistic=dedent("""\
    logistic : bool, optional
        If ``True``, assume that ``y`` is a binary variable and use
        ``statsmodels`` to estimate a logistic regression model. Note that this
        is substantially more computationally intensive than linear regression,
        so you may wish to decrease the number of bootstrap resamples
        (``n_boot``) or set ``ci`` to None.\
    """),
    lowess=dedent("""\
    lowess : bool, optional
        If ``True``, use ``statsmodels`` to estimate a nonparametric lowess
        model (locally weighted linear regression). Note that confidence
        intervals cannot currently be drawn for this kind of model.\
    """),
    robust=dedent("""\
    robust : bool, optional
        If ``True``, use ``statsmodels`` to estimate a robust regression. This
        will de-weight outliers. Note that this is substantially more
        computationally intensive than standard linear regression, so you may
        wish to decrease the number of bootstrap resamples (``n_boot``) or set
        ``ci`` to None.\
    """),
    logx=dedent("""\
    logx : bool, optional
        If ``True``, estimate a linear regression of the form y ~ log(x), but
        plot the scatterplot and regression model in the input space. Note that
        ``x`` must be positive for this to work.\
    """),
    xy_partial=dedent("""\
    {x,y}_partial : strings in ``data`` or matrices
        Confounding variables to regress out of the ``x`` or ``y`` variables
        before plotting.\
    """),
    truncate=dedent("""\
    truncate : bool, optional
        If ``True``, the regression line is bounded by the data limits. If
        ``False``, it extends to the ``x`` axis limits.\
    """),
    xy_jitter=dedent("""\
    {x,y}_jitter : floats, optional
        Add uniform random noise of this size to either the ``x`` or ``y``
        variables. The noise is added to a copy of the data after fitting the
        regression, and only influences the look of the scatterplot. This can
        be helpful when plotting variables that take discrete values.\
    """),
    scatter_line_kws=dedent("""\
    {scatter,line}_kws : dictionaries
        Additional keyword arguments to pass to ``plt.scatter`` and
        ``plt.plot``.\
    """),
)
_regression_docs.update(_facet_docs)


def lmplot(
    data, *,
    x=None, y=None, hue=None, col=None, row=None,
    palette=None, col_wrap=None, height=5, aspect=1, markers="o",
    sharex=None, sharey=None, hue_order=None, col_order=None, row_order=None,
    legend=True, legend_out=None, x_estimator=None, x_bins=None,
    x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
    units=None, seed=None, order=1, logistic=False, lowess=False,
    robust=False, logx=False, x_partial=None, y_partial=None,
    truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
    line_kws=None, facet_kws=None,
):

    if facet_kws is None:
        facet_kws = {}

    def facet_kw_deprecation(key, val):
        msg = (
            f"{key} is deprecated from the `lmplot` function signature. "
            "Please update your code to pass it using `facet_kws`."
        )
        if val is not None:
            warnings.warn(msg, UserWarning)
            facet_kws[key] = val

    facet_kw_deprecation("sharex", sharex)
    facet_kw_deprecation("sharey", sharey)
    facet_kw_deprecation("legend_out", legend_out)

    if data is None:
        raise TypeError("Missing required keyword argument `data`.")

    # Reduce the dataframe to only needed columns
    need_cols = [x, y, hue, col, row, units, x_partial, y_partial]
    cols = np.unique([a for a in need_cols if a is not None]).tolist()
    data = data[cols]

    # Initialize the grid
    facets = FacetGrid(
        data, row=row, col=col, hue=hue,
        palette=palette,
        row_order=row_order, col_order=col_order, hue_order=hue_order,
        height=height, aspect=aspect, col_wrap=col_wrap,
        **facet_kws,
    )

    # Add the markers here as FacetGrid has figured out how many levels of the
    # hue variable are needed and we don't want to duplicate that process
    if facets.hue_names is None:
        n_markers = 1
    else:
        n_markers = len(facets.hue_names)
    if not isinstance(markers, list):
        markers = [markers] * n_markers
    if len(markers) != n_markers:
        raise ValueError("markers must be a singleton or a list of markers "
                         "for each level of the hue variable")
    facets.hue_kws = {"marker": markers}

    # Update the x data limits before drawing, because the extent of the
    # regression estimate is determined by the axis limits of the plot
    def update_datalim(data, x, y, ax, **kws):
        xys = data[[x, y]].to_numpy().astype(float)
        ax.update_datalim(xys, updatey=False)
        ax.autoscale_view(scaley=False)

    facets.map_dataframe(update_datalim, x=x, y=y)

    # Draw the regression plot on each facet
    regplot_kws = dict(
        x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
        scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units,
        seed=seed, order=order, logistic=logistic, lowess=lowess,
        robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial,
        truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter,
        scatter_kws=scatter_kws, line_kws=line_kws,
    )
    facets.map_dataframe(regplot, x=x, y=y, **regplot_kws)
    facets.set_axis_labels(x, y)

    # Add a legend
    if legend and (hue is not None) and (hue not in [col, row]):
        facets.add_legend()
    return facets


lmplot.__doc__ = dedent("""\
    Plot data and regression model fits across a FacetGrid.

    This function combines :func:`regplot` and :class:`FacetGrid`. It is
    intended as a convenient interface to fit regression models across
    conditional subsets of a dataset.

    When thinking about how to assign variables to different facets, a general
    rule is that it makes sense to use ``hue`` for the most important
    comparison, followed by ``col`` and ``row``. However, always think about
    your particular dataset and the goals of the visualization you are
    creating.

    {model_api}

    The parameters to this function span most of the options in
    :class:`FacetGrid`, although there may be occasional cases where you will
    want to use that class and :func:`regplot` directly.

    Parameters
    ----------
    {data}
    x, y : strings, optional
        Input variables; these should be column names in ``data``.
    hue, col, row : strings
        Variables that define subsets of the data, which will be drawn on
        separate facets in the grid. See the ``*_order`` parameters to control
        the order of levels of this variable.
    {palette}
    {col_wrap}
    {height}
    {aspect}
    markers : matplotlib marker code or list of marker codes, optional
        Markers for the scatterplot. If a list, each marker in the list will be
        used for each level of the ``hue`` variable.
    {share_xy}

        .. deprecated:: 0.12.0
            Pass using the `facet_kws` dictionary.

    {{hue,col,row}}_order : lists, optional
        Order for the levels of the faceting variables. By default, this will
        be the order that the levels appear in ``data`` or, if the variables
        are pandas categoricals, the category order.
    legend : bool, optional
        If ``True`` and there is a ``hue`` variable, add a legend.
    {legend_out}

        .. deprecated:: 0.12.0
            Pass using the `facet_kws` dictionary.

    {x_estimator}
    {x_bins}
    {x_ci}
    {scatter}
    {fit_reg}
    {ci}
    {n_boot}
    {units}
    {seed}
    {order}
    {logistic}
    {lowess}
    {robust}
    {logx}
    {xy_partial}
    {truncate}
    {xy_jitter}
    {scatter_line_kws}
    facet_kws : dict
        Dictionary of keyword arguments for :class:`FacetGrid`.

    See Also
    --------
    regplot : Plot data and a conditional model fit.
    FacetGrid : Subplot grid for plotting conditional relationships.
    pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
               ``kind="reg"``).

    Notes
    -----

    {regplot_vs_lmplot}

    Examples
    --------

    .. include:: ../docstrings/lmplot.rst

    """).format(**_regression_docs)
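# A minimal usage sketch for lmplot (hypothetical values; not executed at
# import time), assuming the bundled "tips" example dataset:
#
#     import seaborn as sns
#     tips = sns.load_dataset("tips")
#     sns.lmplot(data=tips, x="total_bill", y="tip", hue="smoker", col="time")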


def regplot(
    data=None, *, x=None, y=None,
    x_estimator=None, x_bins=None, x_ci="ci",
    scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None,
    seed=None, order=1, logistic=False, lowess=False, robust=False,
    logx=False, x_partial=None, y_partial=None,
    truncate=True, dropna=True, x_jitter=None, y_jitter=None,
    label=None, color=None, marker="o",
    scatter_kws=None, line_kws=None, ax=None
):

    plotter = _RegressionPlotter(x, y, data, x_estimator, x_bins, x_ci,
                                 scatter, fit_reg, ci, n_boot, units, seed,
                                 order, logistic, lowess, robust, logx,
                                 x_partial, y_partial, truncate, dropna,
                                 x_jitter, y_jitter, color, label)

    if ax is None:
        ax = plt.gca()

    scatter_kws = {} if scatter_kws is None else copy.copy(scatter_kws)
    scatter_kws["marker"] = marker
    line_kws = {} if line_kws is None else copy.copy(line_kws)
    plotter.plot(ax, scatter_kws, line_kws)
    return ax


regplot.__doc__ = dedent("""\
    Plot data and a linear regression model fit.

    {model_api}

    Parameters
    ----------
    x, y : string, series, or vector array
        Input variables. If strings, these should correspond with column names
        in ``data``. When pandas objects are used, axes will be labeled with
        the series name.
    {data}
    {x_estimator}
    {x_bins}
    {x_ci}
    {scatter}
    {fit_reg}
    {ci}
    {n_boot}
    {units}
    {seed}
    {order}
    {logistic}
    {lowess}
    {robust}
    {logx}
    {xy_partial}
    {truncate}
    {xy_jitter}
    label : string
        Label to apply to either the scatterplot or regression line (if
        ``scatter`` is ``False``) for use in a legend.
    color : matplotlib color
        Color to apply to all plot elements; will be superseded by colors
        passed in ``scatter_kws`` or ``line_kws``.
    marker : matplotlib marker code
        Marker to use for the scatterplot glyphs.
    {scatter_line_kws}
    ax : matplotlib Axes, optional
        Axes object to draw the plot onto, otherwise uses the current Axes.

    Returns
    -------
    ax : matplotlib Axes
        The Axes object containing the plot.

    See Also
    --------
    lmplot : Combine :func:`regplot` and :class:`FacetGrid` to plot multiple
             linear relationships in a dataset.
    jointplot : Combine :func:`regplot` and :class:`JointGrid` (when used with
                ``kind="reg"``).
    pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
               ``kind="reg"``).
    residplot : Plot the residuals of a linear regression model.

    Notes
    -----

    {regplot_vs_lmplot}


    It's also easy to combine :func:`regplot` and :class:`JointGrid` or
    :class:`PairGrid` through the :func:`jointplot` and :func:`pairplot`
    functions, although these do not directly accept all of :func:`regplot`'s
    parameters.

    Examples
    --------

    .. include:: ../docstrings/regplot.rst

    """).format(**_regression_docs)
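# A minimal usage sketch for regplot (hypothetical values; not executed at
# import time), assuming the bundled "tips" example dataset:
#
#     import seaborn as sns
#     tips = sns.load_dataset("tips")
#     ax = sns.regplot(data=tips, x="total_bill", y="tip", ci=None)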


def residplot(
    data=None, *, x=None, y=None,
    x_partial=None, y_partial=None, lowess=False,
    order=1, robust=False, dropna=True, label=None, color=None,
    scatter_kws=None, line_kws=None, ax=None
):
    """Plot the residuals of a linear regression.

    This function will regress y on x (possibly as a robust or polynomial
    regression) and then draw a scatterplot of the residuals. You can
    optionally fit a lowess smoother to the residual plot, which can
    help in determining if there is structure to the residuals.

    Parameters
    ----------
    data : DataFrame, optional
        DataFrame to use if `x` and `y` are column names.
    x : vector or string
        Data or column name in `data` for the predictor variable.
    y : vector or string
        Data or column name in `data` for the response variable.
    {x,y}_partial : vectors or strings, optional
        These variables are treated as confounding and are removed from
        the `x` or `y` variables before plotting.
    lowess : boolean, optional
        Fit a lowess smoother to the residual scatterplot.
    order : int, optional
        Order of the polynomial to fit when calculating the residuals.
    robust : boolean, optional
        Fit a robust linear regression when calculating the residuals.
    dropna : boolean, optional
        If True, ignore observations with missing data when fitting and
        plotting.
    label : string, optional
        Label that will be used in any plot legends.
    color : matplotlib color, optional
        Color to use for all elements of the plot.
    {scatter,line}_kws : dictionaries, optional
        Additional keyword arguments passed to scatter() and plot() for drawing
        the components of the plot.
    ax : matplotlib Axes, optional
        Axes object to draw the plot onto, otherwise uses the current Axes.

    Returns
    -------
    ax : matplotlib Axes
        Axes with the regression plot.

    See Also
    --------
    regplot : Plot a simple linear regression model.
    jointplot : Draw a :func:`residplot` with univariate marginal distributions
                (when used with ``kind="resid"``).

    Examples
    --------

    .. include:: ../docstrings/residplot.rst

    """
    plotter = _RegressionPlotter(x, y, data, ci=None,
                                 order=order, robust=robust,
                                 x_partial=x_partial, y_partial=y_partial,
                                 dropna=dropna, color=color, label=label)

    if ax is None:
        ax = plt.gca()

    # Calculate the residual from a linear regression
    _, yhat, _ = plotter.fit_regression(grid=plotter.x)
    plotter.y = plotter.y - yhat

    # Set the regression option on the plotter
    if lowess:
        plotter.lowess = True
    else:
        plotter.fit_reg = False

    # Plot a horizontal line at 0
    ax.axhline(0, ls=":", c=".2")

    # Draw the scatterplot
    scatter_kws = {} if scatter_kws is None else scatter_kws.copy()
    line_kws = {} if line_kws is None else line_kws.copy()
    plotter.plot(ax, scatter_kws, line_kws)
    return ax
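
# A minimal usage sketch for residplot (hypothetical values; not executed at
# import time), assuming the bundled "tips" example dataset:
#
#     import seaborn as sns
#     tips = sns.load_dataset("tips")
#     ax = sns.residplot(data=tips, x="total_bill", y="tip", lowess=True)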