# 2024-10-02 22:15:59 +04:00
#
# 941 lines
# 33 KiB
# Python

"""Plotting functions for linear models (broadly construed)."""
import copy
from textwrap import dedent
import warnings
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
# statsmodels is an optional dependency. Probe for it once at import time and
# record the outcome in ``_has_statsmodels``; ``_check_statsmodels`` reads that
# flag to raise a clear error when an option that needs statsmodels is used.
try:
    import statsmodels
    # Presumably only here to reference the name so the import is not
    # reported as unused; it has no other effect.
    assert statsmodels
    _has_statsmodels = True
except ImportError:
    _has_statsmodels = False
from . import utils
from . import algorithms as algo
from .axisgrid import FacetGrid, _facet_docs
__all__ = ["lmplot", "regplot", "residplot"]
class _LinearPlotter:
    """Base class for plotting relational data in tidy format.

    To get anything useful done you'll have to inherit from this, but setup
    code that can be abstracted out should be put here.

    """
    def establish_variables(self, data, **kws):
        """Extract variables from data or use directly.

        Each keyword becomes an attribute on ``self`` with the same name.

        Parameters
        ----------
        data : object supporting ``data[name]`` lookup, or None
            Source for variables that are given as strings (e.g. a
            DataFrame). Stored on ``self.data``.
        **kws
            Mapping of attribute name to a variable specification: a string
            (looked up in ``data``), a list or tuple (converted to an
            array), ``None``, or any object exposing ``.shape`` (used as is).

        Raises
        ------
        ValueError
            If a string variable is given without ``data``, or if a variable
            still has more than one dimension after squeezing.

        """
        self.data = data

        # Validate the inputs
        any_strings = any(isinstance(v, str) for v in kws.values())
        if any_strings and data is None:
            raise ValueError("Must pass `data` if using named variables.")

        # Set the variables
        for var, val in kws.items():
            if isinstance(val, str):
                vector = data[val]
            elif isinstance(val, (list, tuple)):
                # Plain sequences have no ``.shape``; convert them first so
                # the checks below work for tuples as well as lists.
                vector = np.asarray(val)
            else:
                vector = val
            # Squeeze away singleton dimensions, but leave a length-1 vector
            # alone so that it is not collapsed to a 0-d value.
            if vector is not None and vector.shape != (1,):
                vector = np.squeeze(vector)
            if np.ndim(vector) > 1:
                err = "regplot inputs must be 1d"
                raise ValueError(err)
            setattr(self, var, vector)

    def dropna(self, *vars):
        """Remove observations with missing data.

        An observation is kept only if it is non-null in every named
        variable that is not ``None``. The named attributes are replaced
        in place with their filtered versions.

        Parameters
        ----------
        *vars : str
            Names of attributes on ``self`` to filter jointly.

        """
        vals = [getattr(self, var) for var in vars]
        vals = [v for v in vals if v is not None]
        if not vals:
            # Nothing is defined, so there is nothing to filter.
            return
        not_na = np.all(np.column_stack([pd.notnull(v) for v in vals]), axis=1)
        for var in vars:
            val = getattr(self, var)
            if val is not None:
                setattr(self, var, val[not_na])

    def plot(self, ax):
        """Draw the plot onto ``ax``; must be implemented by subclasses."""
        raise NotImplementedError
class _RegressionPlotter(_LinearPlotter):
    """Plotter for numeric independent variables with regression model.

    This does the computations and drawing for the `regplot` function, and
    is thus also used indirectly by `lmplot`.

    """
    def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
                 x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
                 units=None, seed=None, order=1, logistic=False, lowess=False,
                 robust=False, logx=False, x_partial=None, y_partial=None,
                 truncate=False, dropna=True, x_jitter=None, y_jitter=None,
                 color=None, label=None):
        """Store the plot options and preprocess the input variables.

        After saving the options, the variables are resolved against
        ``data``, observations with missing values are dropped (when
        ``dropna`` is true), ``x_partial`` / ``y_partial`` are regressed out
        of ``x`` / ``y``, and ``x`` is optionally binned.

        Raises
        ------
        ValueError
            If more than one of ``order > 1``, ``logistic``, ``robust``,
            ``lowess`` and ``logx`` is requested.

        """
        # Set member attributes
        self.x_estimator = x_estimator
        self.ci = ci
        # ``x_ci="ci"`` means "use the same size as the regression CI"
        self.x_ci = ci if x_ci == "ci" else x_ci
        self.n_boot = n_boot
        self.seed = seed
        self.scatter = scatter
        self.fit_reg = fit_reg
        self.order = order
        self.logistic = logistic
        self.lowess = lowess
        self.robust = robust
        self.logx = logx
        self.truncate = truncate
        self.x_jitter = x_jitter
        self.y_jitter = y_jitter
        self.color = color
        self.label = label

        # Validate the regression options:
        if sum((order > 1, logistic, robust, lowess, logx)) > 1:
            raise ValueError("Mutually exclusive regression options.")

        # Extract the data vals from the arguments or passed dataframe
        self.establish_variables(data, x=x, y=y, units=units,
                                 x_partial=x_partial, y_partial=y_partial)

        # Drop null observations
        if dropna:
            self.dropna("x", "y", "units", "x_partial", "y_partial")

        # Regress nuisance variables out of the data
        if self.x_partial is not None:
            self.x = self.regress_out(self.x, self.x_partial)
        if self.y_partial is not None:
            self.y = self.regress_out(self.y, self.y_partial)

        # Possibly bin the predictor variable, which implies a point estimate
        if x_bins is not None:
            self.x_estimator = np.mean if x_estimator is None else x_estimator
            x_discrete, x_bins = self.bin_predictor(x_bins)
            self.x_discrete = x_discrete
        else:
            self.x_discrete = self.x

        # Disable regression in case of singleton inputs
        if len(self.x) <= 1:
            self.fit_reg = False

        # Save the range of the x variable for the grid later
        if self.fit_reg:
            self.x_range = self.x.min(), self.x.max()

    @property
    def scatter_data(self):
        """Data where each observation is a point.

        Returns ``(x, y)``. When ``x_jitter`` / ``y_jitter`` is set, uniform
        noise in ``[-jitter, jitter]`` is added to a copy of that variable.
        The noise comes from ``np.random.uniform`` and is redrawn on every
        access; ``self.seed`` is not used here.

        """
        x_j = self.x_jitter
        if x_j is None:
            x = self.x
        else:
            x = self.x + np.random.uniform(-x_j, x_j, len(self.x))

        y_j = self.y_jitter
        if y_j is None:
            y = self.y
        else:
            y = self.y + np.random.uniform(-y_j, y_j, len(self.y))

        return x, y

    @property
    def estimate_data(self):
        """Data with a point estimate and CI for each discrete x value.

        Returns ``(vals, points, cis)``: the sorted unique values of
        ``x_discrete``, ``x_estimator`` applied to the ``y`` values at each
        of them, and the matching interval. Each interval is ``None`` when
        ``x_ci`` is None, ``est +/- std`` when ``x_ci == "sd"``, and
        otherwise ``utils.ci`` of the bootstrapped estimates.

        """
        x, y = self.x_discrete, self.y
        vals = sorted(np.unique(x))
        points, cis = [], []

        for val in vals:

            # Get the point estimate of the y variable
            _y = y[x == val]
            est = self.x_estimator(_y)
            points.append(est)

            # Compute the confidence interval for this estimate
            if self.x_ci is None:
                cis.append(None)
            else:
                units = None
                if self.x_ci == "sd":
                    sd = np.std(_y)
                    _ci = est - sd, est + sd
                else:
                    if self.units is not None:
                        units = self.units[x == val]
                    boots = algo.bootstrap(_y,
                                           func=self.x_estimator,
                                           n_boot=self.n_boot,
                                           units=units,
                                           seed=self.seed)
                    _ci = utils.ci(boots, self.x_ci)
                cis.append(_ci)

        return vals, points, cis

    def _check_statsmodels(self):
        """Check whether statsmodels is installed if any boolean options require it.

        Raises
        ------
        RuntimeError
            If ``logistic``, ``robust`` or ``lowess`` is enabled while the
            module-level ``_has_statsmodels`` flag is false.

        """
        options = "logistic", "robust", "lowess"
        err = "`{}=True` requires statsmodels, an optional dependency, to be installed."
        for option in options:
            if getattr(self, option) and not _has_statsmodels:
                raise RuntimeError(err.format(option))

    def fit_regression(self, ax=None, x_range=None, grid=None):
        """Fit the regression model.

        Parameters
        ----------
        ax : matplotlib Axes, optional
            Used for its current x limits when no ``grid`` is given and
            ``truncate`` is false.
        x_range : pair of numbers, optional
            Grid limits to use when neither ``grid`` nor ``ax`` is given and
            ``truncate`` is false.
        grid : vector, optional
            x positions at which to evaluate the model. When omitted, 100
            evenly spaced points are generated.

        Returns
        -------
        grid, yhat, err_bands
            Evaluation points, predictions, and the confidence band from
            ``utils.ci`` (``None`` when ``ci`` is None or for lowess fits).

        """
        self._check_statsmodels()

        # Create the grid for the regression
        if grid is None:
            if self.truncate:
                x_min, x_max = self.x_range
            else:
                if ax is None:
                    x_min, x_max = x_range
                else:
                    x_min, x_max = ax.get_xlim()
            grid = np.linspace(x_min, x_max, 100)
        ci = self.ci

        # Fit the regression; the first matching option wins
        if self.order > 1:
            yhat, yhat_boots = self.fit_poly(grid, self.order)
        elif self.logistic:
            from statsmodels.genmod.generalized_linear_model import GLM
            from statsmodels.genmod.families import Binomial
            yhat, yhat_boots = self.fit_statsmodels(grid, GLM,
                                                    family=Binomial())
        elif self.lowess:
            # Lowess returns its own grid and has no bootstrap samples
            ci = None
            grid, yhat = self.fit_lowess()
        elif self.robust:
            from statsmodels.robust.robust_linear_model import RLM
            yhat, yhat_boots = self.fit_statsmodels(grid, RLM)
        elif self.logx:
            yhat, yhat_boots = self.fit_logx(grid)
        else:
            yhat, yhat_boots = self.fit_fast(grid)

        # Compute the confidence interval at each grid point
        if ci is None:
            err_bands = None
        else:
            err_bands = utils.ci(yhat_boots, ci, axis=0)

        return grid, yhat, err_bands

    def fit_fast(self, grid):
        """Low-level regression and prediction using linear algebra.

        Fits ``y`` on ``[1, x]`` with the pseudoinverse and predicts on
        ``grid``. Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is ``None``
        when ``self.ci`` is None, otherwise the predictions obtained from
        the coefficients returned by ``algo.bootstrap``.

        """
        def reg_func(_x, _y):
            return np.linalg.pinv(_x).dot(_y)

        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]
        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots

    def fit_poly(self, grid, order):
        """Regression using numpy polyfit for higher-order trends.

        Returns ``(yhat, yhat_boots)`` evaluated on ``grid``; ``yhat_boots``
        is ``None`` when ``self.ci`` is None.

        """
        def reg_func(_x, _y):
            return np.polyval(np.polyfit(_x, _y, order), grid)

        x, y = self.x, self.y
        yhat = reg_func(x, y)
        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(x, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots

    def fit_statsmodels(self, grid, model, **kwargs):
        """More general regression function using statsmodels objects.

        ``model`` is called as ``model(y, X, **kwargs)`` with an intercept
        column prepended to ``x``; the fitted result predicts on ``grid``.
        A fit that hits perfect separation yields an all-NaN prediction
        instead of raising. Returns ``(yhat, yhat_boots)``; ``yhat_boots``
        is ``None`` when ``self.ci`` is None.

        """
        import statsmodels.tools.sm_exceptions as sme
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]

        def reg_func(_x, _y):
            err_classes = (sme.PerfectSeparationError,)
            try:
                with warnings.catch_warnings():
                    if hasattr(sme, "PerfectSeparationWarning"):
                        # statsmodels>=0.14.0
                        # Promote the warning to an error so it is caught
                        # by the same handler as the older exception.
                        warnings.simplefilter("error", sme.PerfectSeparationWarning)
                        err_classes = (*err_classes, sme.PerfectSeparationWarning)
                    yhat = model(_y, _x, **kwargs).fit().predict(grid)
            except err_classes:
                yhat = np.empty(len(grid))
                yhat.fill(np.nan)
            return yhat

        yhat = reg_func(X, y)
        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots

    def fit_lowess(self):
        """Fit a locally-weighted regression, which returns its own grid."""
        from statsmodels.nonparametric.smoothers_lowess import lowess
        grid, yhat = lowess(self.y, self.x).T
        return grid, yhat

    def fit_logx(self, grid):
        """Fit the model in log-space.

        Regresses ``y`` on ``[1, log(x)]`` and predicts on ``log(grid)``.
        Returns ``(yhat, yhat_boots)``; ``yhat_boots`` is ``None`` when
        ``self.ci`` is None.

        """
        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), np.log(grid)]

        def reg_func(_x, _y):
            # The log is applied here (not to X above) so that bootstrap
            # resamples of the raw design matrix are transformed as well.
            _x = np.c_[_x[:, 0], np.log(_x[:, 1])]
            return np.linalg.pinv(_x).dot(_y)

        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots

    def bin_predictor(self, bins):
        """Discretize a predictor by assigning value to closest bin.

        Parameters
        ----------
        bins : int or vector
            A scalar is the number of bins, whose centers are placed at the
            interior points of ``np.linspace(0, 100, bins + 2)`` taken as
            percentiles of ``x``. Anything else is flattened and used as
            the bin centers directly.

        Returns
        -------
        x_binned, bins
            For every observation the nearest bin center, and the centers.

        """
        x = np.asarray(self.x)
        if np.isscalar(bins):
            percentiles = np.linspace(0, 100, bins + 2)[1:-1]
            bins = np.percentile(x, percentiles)
        else:
            bins = np.ravel(bins)

        dist = np.abs(np.subtract.outer(x, bins))
        x_binned = bins[np.argmin(dist, axis=1)].ravel()

        return x_binned, bins

    def regress_out(self, a, b):
        """Regress b from a keeping a's original mean."""
        a_mean = a.mean()
        a = a - a_mean
        b = b - b.mean()
        b = np.c_[b]
        a_prime = a - b.dot(np.linalg.pinv(b).dot(a))
        return np.asarray(a_prime + a_mean).reshape(a.shape)

    def plot(self, ax, scatter_kws, line_kws):
        """Draw the full plot.

        ``scatter_kws`` and ``line_kws`` are modified in place (label and
        default color are inserted), so callers should pass copies.

        """
        # Insert the plot label into the correct set of keyword arguments
        if self.scatter:
            scatter_kws["label"] = self.label
        else:
            line_kws["label"] = self.label

        # Use the current color cycle state as a default
        if self.color is None:
            lines, = ax.plot([], [])
            color = lines.get_color()
            lines.remove()
        else:
            color = self.color

        # Ensure that color is hex to avoid matplotlib weirdness
        color = mpl.colors.to_hex(color)

        # Let color in keyword arguments override overall plot color
        scatter_kws.setdefault("color", color)
        line_kws.setdefault("color", color)

        # Draw the constituent plots
        if self.scatter:
            self.scatterplot(ax, scatter_kws)

        if self.fit_reg:
            self.lineplot(ax, line_kws)

        # Label the axes
        if hasattr(self.x, "name"):
            ax.set_xlabel(self.x.name)
        if hasattr(self.y, "name"):
            ax.set_ylabel(self.y.name)

    def scatterplot(self, ax, kws):
        """Draw the data.

        Draws the raw observations when ``x_estimator`` is None; otherwise
        draws one point per discrete x value plus its interval. ``kws`` must
        contain ``"color"`` and is modified in place with defaults.

        """
        # Treat the line-based markers specially, explicitly setting larger
        # linewidth than is provided by the seaborn style defaults.
        # This would ideally be handled better in matplotlib (i.e., distinguish
        # between edgewidth for solid glyphs and linewidth for line glyphs
        # but this should do for now.
        line_markers = ["1", "2", "3", "4", "+", "x", "|", "_"]
        if self.x_estimator is None:
            if "marker" in kws and kws["marker"] in line_markers:
                lw = mpl.rcParams["lines.linewidth"]
            else:
                lw = mpl.rcParams["lines.markeredgewidth"]
            kws.setdefault("linewidths", lw)

            # Only supply a default alpha when the color does not already
            # carry an alpha channel. Inspect the trailing dimension so that
            # a single 1-D RGB(A) array works as well as an (n, 3|4) array.
            color_shape = getattr(kws["color"], "shape", None)
            has_alpha = bool(color_shape) and color_shape[-1] >= 4
            if not has_alpha:
                kws.setdefault("alpha", .8)

            x, y = self.scatter_data
            ax.scatter(x, y, **kws)
        else:
            # TODO abstraction
            ci_kws = {"color": kws["color"]}
            if "alpha" in kws:
                ci_kws["alpha"] = kws["alpha"]
            ci_kws["linewidth"] = mpl.rcParams["lines.linewidth"] * 1.75
            kws.setdefault("s", 50)

            xs, ys, cis = self.estimate_data
            if any(ci is not None for ci in cis):
                for x, ci in zip(xs, cis):
                    ax.plot([x, x], ci, **ci_kws)
            ax.scatter(xs, ys, **kws)

    def lineplot(self, ax, kws):
        """Draw the model.

        ``kws`` must contain ``"color"``; ``"lw"`` is popped from it and
        used as the default ``"linewidth"``.

        """
        # Fit the regression model
        grid, yhat, err_bands = self.fit_regression(ax)
        edges = grid[0], grid[-1]

        # Get set default aesthetics
        fill_color = kws["color"]
        lw = kws.pop("lw", mpl.rcParams["lines.linewidth"] * 1.5)
        kws.setdefault("linewidth", lw)

        # Draw the regression line and confidence interval
        line, = ax.plot(grid, yhat, **kws)
        if not self.truncate:
            line.sticky_edges.x[:] = edges  # Prevent mpl from adding margin
        if err_bands is not None:
            ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15)
# Reusable docstring fragments. They are substituted by name into the
# ``lmplot`` and ``regplot`` docstrings with ``str.format``. Every entry ends
# with a line-continuation backslash so that no stray blank line follows it
# once it has been substituted.
_regression_docs = dict(

    model_api=dedent("""\
    There are a number of mutually exclusive options for estimating the
    regression model. See the :ref:`tutorial <regression_tutorial>` for more
    information.\
    """),
    regplot_vs_lmplot=dedent("""\
    The :func:`regplot` and :func:`lmplot` functions are closely related, but
    the former is an axes-level function while the latter is a figure-level
    function that combines :func:`regplot` and :class:`FacetGrid`.\
    """),
    x_estimator=dedent("""\
    x_estimator : callable that maps vector -> scalar, optional
        Apply this function to each unique value of ``x`` and plot the
        resulting estimate. This is useful when ``x`` is a discrete variable.
        If ``x_ci`` is given, this estimate will be bootstrapped and a
        confidence interval will be drawn.\
    """),
    x_bins=dedent("""\
    x_bins : int or vector, optional
        Bin the ``x`` variable into discrete bins and then estimate the central
        tendency and a confidence interval. This binning only influences how
        the scatterplot is drawn; the regression is still fit to the original
        data. This parameter is interpreted either as the number of
        evenly-sized (not necessarily spaced) bins or the positions of the bin
        centers. When this parameter is used, it implies that the default of
        ``x_estimator`` is ``numpy.mean``.\
    """),
    x_ci=dedent("""\
    x_ci : "ci", "sd", int in [0, 100] or None, optional
        Size of the confidence interval used when plotting a central tendency
        for discrete values of ``x``. If ``"ci"``, defer to the value of the
        ``ci`` parameter. If ``"sd"``, skip bootstrapping and show the
        standard deviation of the observations in each bin.\
    """),
    scatter=dedent("""\
    scatter : bool, optional
        If ``True``, draw a scatterplot with the underlying observations (or
        the ``x_estimator`` values).\
    """),
    fit_reg=dedent("""\
    fit_reg : bool, optional
        If ``True``, estimate and plot a regression model relating the ``x``
        and ``y`` variables.\
    """),
    ci=dedent("""\
    ci : int in [0, 100] or None, optional
        Size of the confidence interval for the regression estimate. This will
        be drawn using translucent bands around the regression line. The
        confidence interval is estimated using a bootstrap; for large
        datasets, it may be advisable to avoid that computation by setting
        this parameter to None.\
    """),
    n_boot=dedent("""\
    n_boot : int, optional
        Number of bootstrap resamples used to estimate the ``ci``. The default
        value attempts to balance time and stability; you may want to increase
        this value for "final" versions of plots.\
    """),
    units=dedent("""\
    units : variable name in ``data``, optional
        If the ``x`` and ``y`` observations are nested within sampling units,
        those can be specified here. This will be taken into account when
        computing the confidence intervals by performing a multilevel bootstrap
        that resamples both units and observations (within unit). This does not
        otherwise influence how the regression is estimated or drawn.\
    """),
    seed=dedent("""\
    seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
        Seed or random number generator for reproducible bootstrapping.\
    """),
    order=dedent("""\
    order : int, optional
        If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a
        polynomial regression.\
    """),
    logistic=dedent("""\
    logistic : bool, optional
        If ``True``, assume that ``y`` is a binary variable and use
        ``statsmodels`` to estimate a logistic regression model. Note that this
        is substantially more computationally intensive than linear regression,
        so you may wish to decrease the number of bootstrap resamples
        (``n_boot``) or set ``ci`` to None.\
    """),
    lowess=dedent("""\
    lowess : bool, optional
        If ``True``, use ``statsmodels`` to estimate a nonparametric lowess
        model (locally weighted linear regression). Note that confidence
        intervals cannot currently be drawn for this kind of model.\
    """),
    robust=dedent("""\
    robust : bool, optional
        If ``True``, use ``statsmodels`` to estimate a robust regression. This
        will de-weight outliers. Note that this is substantially more
        computationally intensive than standard linear regression, so you may
        wish to decrease the number of bootstrap resamples (``n_boot``) or set
        ``ci`` to None.\
    """),
    logx=dedent("""\
    logx : bool, optional
        If ``True``, estimate a linear regression of the form y ~ log(x), but
        plot the scatterplot and regression model in the input space. Note that
        ``x`` must be positive for this to work.\
    """),
    xy_partial=dedent("""\
    {x,y}_partial : strings in ``data`` or matrices
        Confounding variables to regress out of the ``x`` or ``y`` variables
        before plotting.\
    """),
    truncate=dedent("""\
    truncate : bool, optional
        If ``True``, the regression line is bounded by the data limits. If
        ``False``, it extends to the ``x`` axis limits.\
    """),
    xy_jitter=dedent("""\
    {x,y}_jitter : floats, optional
        Add uniform random noise of this size to either the ``x`` or ``y``
        variables. The noise is added to a copy of the data after fitting the
        regression, and only influences the look of the scatterplot. This can
        be helpful when plotting variables that take discrete values.\
    """),
    scatter_line_kws=dedent("""\
    {scatter,line}_kws : dictionaries
        Additional keyword arguments to pass to ``plt.scatter`` and
        ``plt.plot``.\
    """),
)
# Also make the shared FacetGrid parameter descriptions available.
_regression_docs.update(_facet_docs)
def lmplot(
    data, *,
    x=None, y=None, hue=None, col=None, row=None,
    palette=None, col_wrap=None, height=5, aspect=1, markers="o",
    sharex=None, sharey=None, hue_order=None, col_order=None, row_order=None,
    legend=True, legend_out=None, x_estimator=None, x_bins=None,
    x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
    units=None, seed=None, order=1, logistic=False, lowess=False,
    robust=False, logx=False, x_partial=None, y_partial=None,
    truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
    line_kws=None, facet_kws=None,
):
    # NOTE: the public docstring is attached after the definition, because
    # it is assembled from the shared ``_regression_docs`` fragments.

    # Work on a private copy: the deprecation shim below writes into this
    # dictionary, and the caller's object must not be modified.
    facet_kws = {} if facet_kws is None else copy.copy(facet_kws)

    def facet_kw_deprecation(key, val):
        """Warn about a deprecated argument and forward it to ``facet_kws``."""
        msg = (
            f"{key} is deprecated from the `lmplot` function signature. "
            "Please update your code to pass it using `facet_kws`."
        )
        if val is not None:
            warnings.warn(msg, UserWarning)
            facet_kws[key] = val

    facet_kw_deprecation("sharex", sharex)
    facet_kw_deprecation("sharey", sharey)
    facet_kw_deprecation("legend_out", legend_out)

    if data is None:
        raise TypeError("Missing required keyword argument `data`.")

    # Reduce the dataframe to only needed columns
    need_cols = [x, y, hue, col, row, units, x_partial, y_partial]
    cols = np.unique([a for a in need_cols if a is not None]).tolist()
    data = data[cols]

    # Initialize the grid
    facets = FacetGrid(
        data, row=row, col=col, hue=hue,
        palette=palette,
        row_order=row_order, col_order=col_order, hue_order=hue_order,
        height=height, aspect=aspect, col_wrap=col_wrap,
        **facet_kws,
    )

    # Add the markers here as FacetGrid has figured out how many levels of the
    # hue variable are needed and we don't want to duplicate that process
    if facets.hue_names is None:
        n_markers = 1
    else:
        n_markers = len(facets.hue_names)
    if not isinstance(markers, list):
        markers = [markers] * n_markers
    if len(markers) != n_markers:
        raise ValueError("markers must be a singleton or a list of markers "
                         "for each level of the hue variable")
    facets.hue_kws = {"marker": markers}

    def update_datalim(data, x, y, ax, **kws):
        """Extend the x data limits of ``ax`` to cover this facet's data."""
        xys = data[[x, y]].to_numpy().astype(float)
        ax.update_datalim(xys, updatey=False)
        ax.autoscale_view(scaley=False)

    # Run this before drawing: ``regplot`` reads the axes x limits when
    # ``truncate`` is false.
    facets.map_dataframe(update_datalim, x=x, y=y)

    # Draw the regression plot on each facet
    regplot_kws = dict(
        x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
        scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units,
        seed=seed, order=order, logistic=logistic, lowess=lowess,
        robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial,
        truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter,
        scatter_kws=scatter_kws, line_kws=line_kws,
    )
    facets.map_dataframe(regplot, x=x, y=y, **regplot_kws)
    facets.set_axis_labels(x, y)

    # Add a legend
    if legend and (hue is not None) and (hue not in [col, row]):
        facets.add_legend()
    return facets


lmplot.__doc__ = dedent("""\
    Plot data and regression model fits across a FacetGrid.

    This function combines :func:`regplot` and :class:`FacetGrid`. It is
    intended as a convenient interface to fit regression models across
    conditional subsets of a dataset.

    When thinking about how to assign variables to different facets, a general
    rule is that it makes sense to use ``hue`` for the most important
    comparison, followed by ``col`` and ``row``. However, always think about
    your particular dataset and the goals of the visualization you are
    creating.

    {model_api}

    The parameters to this function span most of the options in
    :class:`FacetGrid`, although there may be occasional cases where you will
    want to use that class and :func:`regplot` directly.

    Parameters
    ----------
    {data}
    x, y : strings, optional
        Input variables; these should be column names in ``data``.
    hue, col, row : strings
        Variables that define subsets of the data, which will be drawn on
        separate facets in the grid. See the ``*_order`` parameters to control
        the order of levels of this variable.
    {palette}
    {col_wrap}
    {height}
    {aspect}
    markers : matplotlib marker code or list of marker codes, optional
        Markers for the scatterplot. If a list, each marker in the list will be
        used for each level of the ``hue`` variable.
    {share_xy}

        .. deprecated:: 0.12.0
            Pass using the `facet_kws` dictionary.

    {{hue,col,row}}_order : lists, optional
        Order for the levels of the faceting variables. By default, this will
        be the order that the levels appear in ``data`` or, if the variables
        are pandas categoricals, the category order.
    legend : bool, optional
        If ``True`` and there is a ``hue`` variable, add a legend.
    {legend_out}

        .. deprecated:: 0.12.0
            Pass using the `facet_kws` dictionary.

    {x_estimator}
    {x_bins}
    {x_ci}
    {scatter}
    {fit_reg}
    {ci}
    {n_boot}
    {units}
    {seed}
    {order}
    {logistic}
    {lowess}
    {robust}
    {logx}
    {xy_partial}
    {truncate}
    {xy_jitter}
    {scatter_line_kws}
    facet_kws : dict
        Dictionary of keyword arguments for :class:`FacetGrid`.

    See Also
    --------
    regplot : Plot data and a conditional model fit.
    FacetGrid : Subplot grid for plotting conditional relationships.
    pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
               ``kind="reg"``).

    Notes
    -----

    {regplot_vs_lmplot}

    Examples
    --------

    .. include:: ../docstrings/lmplot.rst

    """).format(**_regression_docs)
def regplot(
    data=None, *, x=None, y=None,
    x_estimator=None, x_bins=None, x_ci="ci",
    scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None,
    seed=None, order=1, logistic=False, lowess=False, robust=False,
    logx=False, x_partial=None, y_partial=None,
    truncate=True, dropna=True, x_jitter=None, y_jitter=None,
    label=None, color=None, marker="o",
    scatter_kws=None, line_kws=None, ax=None
):
    # NOTE: the public docstring is attached after the definition, because
    # it is assembled from the shared ``_regression_docs`` fragments.

    # Arguments are passed positionally, in the order of the
    # ``_RegressionPlotter.__init__`` signature.
    plotter = _RegressionPlotter(x, y, data, x_estimator, x_bins, x_ci,
                                 scatter, fit_reg, ci, n_boot, units, seed,
                                 order, logistic, lowess, robust, logx,
                                 x_partial, y_partial, truncate, dropna,
                                 x_jitter, y_jitter, color, label)

    if ax is None:
        ax = plt.gca()

    # Copy the keyword dictionaries: the plotter adds defaults to them in
    # place and the caller's objects must stay untouched.
    scatter_kws = {} if scatter_kws is None else copy.copy(scatter_kws)
    # The ``marker`` argument always wins over a marker in ``scatter_kws``.
    scatter_kws["marker"] = marker
    line_kws = {} if line_kws is None else copy.copy(line_kws)
    plotter.plot(ax, scatter_kws, line_kws)
    return ax


regplot.__doc__ = dedent("""\
    Plot data and a linear regression model fit.

    {model_api}

    Parameters
    ----------
    {data}
    x, y : string, series, or vector array
        Input variables. If strings, these should correspond with column names
        in ``data``. When pandas objects are used, axes will be labeled with
        the series name.
    {x_estimator}
    {x_bins}
    {x_ci}
    {scatter}
    {fit_reg}
    {ci}
    {n_boot}
    {units}
    {seed}
    {order}
    {logistic}
    {lowess}
    {robust}
    {logx}
    {xy_partial}
    {truncate}
    {xy_jitter}
    label : string
        Label to apply to either the scatterplot or regression line (if
        ``scatter`` is ``False``) for use in a legend.
    color : matplotlib color
        Color to apply to all plot elements; will be superseded by colors
        passed in ``scatter_kws`` or ``line_kws``.
    marker : matplotlib marker code
        Marker to use for the scatterplot glyphs.
    {scatter_line_kws}
    ax : matplotlib Axes, optional
        Axes object to draw the plot onto, otherwise uses the current Axes.

    Returns
    -------
    ax : matplotlib Axes
        The Axes object containing the plot.

    See Also
    --------
    lmplot : Combine :func:`regplot` and :class:`FacetGrid` to plot multiple
             linear relationships in a dataset.
    jointplot : Combine :func:`regplot` and :class:`JointGrid` (when used with
                ``kind="reg"``).
    pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
               ``kind="reg"``).
    residplot : Plot the residuals of a linear regression model.

    Notes
    -----

    {regplot_vs_lmplot}

    It's also easy to combine :func:`regplot` and :class:`JointGrid` or
    :class:`PairGrid` through the :func:`jointplot` and :func:`pairplot`
    functions, although these do not directly accept all of :func:`regplot`'s
    parameters.

    Examples
    --------

    .. include:: ../docstrings/regplot.rst

    """).format(**_regression_docs)
def residplot(
    data=None, *, x=None, y=None,
    x_partial=None, y_partial=None, lowess=False,
    order=1, robust=False, dropna=True, label=None, color=None,
    scatter_kws=None, line_kws=None, ax=None
):
    """Plot the residuals of a linear regression.

    This function will regress y on x (possibly as a robust or polynomial
    regression) and then draw a scatterplot of the residuals. You can
    optionally fit a lowess smoother to the residual plot, which can
    help in determining if there is structure to the residuals.

    Parameters
    ----------
    data : DataFrame, optional
        DataFrame to use if `x` and `y` are column names.
    x : vector or string
        Data or column name in `data` for the predictor variable.
    y : vector or string
        Data or column name in `data` for the response variable.
    {x, y}_partial : vectors or string(s) , optional
        These variables are treated as confounding and are removed from
        the `x` or `y` variables before plotting.
    lowess : boolean, optional
        Fit a lowess smoother to the residual scatterplot. The smoother is
        always a lowess fit, whatever `order` and `robust` were used to
        compute the residuals.
    order : int, optional
        Order of the polynomial to fit when calculating the residuals.
    robust : boolean, optional
        Fit a robust linear regression when calculating the residuals.
    dropna : boolean, optional
        If True, ignore observations with missing data when fitting and
        plotting.
    label : string, optional
        Label that will be used in any plot legends.
    color : matplotlib color, optional
        Color to use for all elements of the plot.
    {scatter, line}_kws : dictionaries, optional
        Additional keyword arguments passed to scatter() and plot() for drawing
        the components of the plot.
    ax : matplotlib axis, optional
        Plot into this axis, otherwise grab the current axis or make a new
        one if not existing.

    Returns
    -------
    ax: matplotlib axes
        Axes with the regression plot.

    See Also
    --------
    regplot : Plot a simple linear regression model.
    jointplot : Draw a :func:`residplot` with univariate marginal distributions
                (when used with ``kind="resid"``).

    Examples
    --------

    .. include:: ../docstrings/residplot.rst

    """
    plotter = _RegressionPlotter(x, y, data, ci=None,
                                 order=order, robust=robust,
                                 x_partial=x_partial, y_partial=y_partial,
                                 dropna=dropna, color=color, label=label)

    if ax is None:
        ax = plt.gca()

    # Calculate the residual from a linear regression
    # (evaluated at the observed x values rather than on a regular grid)
    _, yhat, _ = plotter.fit_regression(grid=plotter.x)
    plotter.y = plotter.y - yhat

    # Set the regression option on the plotter
    if lowess:
        # ``fit_regression`` tests ``order > 1`` before ``lowess``, so the
        # options used for the residual model must be cleared; otherwise a
        # polynomial would be drawn through the residuals instead of the
        # requested smoother.
        plotter.order = 1
        plotter.robust = False
        plotter.lowess = True
    else:
        plotter.fit_reg = False

    # Plot a horizontal line at 0
    ax.axhline(0, ls=":", c=".2")

    # Draw the scatterplot
    # (copies are passed because the plotter adds defaults in place)
    scatter_kws = {} if scatter_kws is None else scatter_kws.copy()
    line_kws = {} if line_kws is None else line_kws.copy()
    plotter.plot(ax, scatter_kws, line_kws)
    return ax