148 lines
3.9 KiB
Python
148 lines
3.9 KiB
Python
"""Helper functions for graphics with Matplotlib."""
|
|
from statsmodels.compat.python import lrange
|
|
|
|
# Names exported by ``from statsmodels.graphics.utils import *``.
__all__ = [
    "create_mpl_ax",
    "create_mpl_fig",
]
|
|
|
|
|
|
def _import_mpl():
|
|
"""This function is not needed outside this utils module."""
|
|
try:
|
|
import matplotlib.pyplot as plt
|
|
except:
|
|
raise ImportError("Matplotlib is not found.")
|
|
|
|
return plt
|
|
|
|
|
|
def create_mpl_ax(ax=None):
    """Return a figure and a single axes object to draw on.

    Parameters
    ----------
    ax : AxesSubplot, optional
        Existing axes to plot into. When omitted, a new figure is created
        and one subplot is added to it.

    Returns
    -------
    fig : Figure
        The newly created figure when `ax` is None; otherwise ``ax.figure``,
        the figure the given axes belong to.
    ax : AxesSubplot
        The newly added subplot when `ax` is None; otherwise the `ax` that
        was passed in.

    Notes
    -----
    A new figure is obtained through ``matplotlib.pyplot``, which should
    only be used to create figures with ``plt.figure``. Any other
    functionality exposed by pyplot can and should be imported directly
    from the corresponding Matplotlib module.

    See Also
    --------
    create_mpl_fig

    Examples
    --------
    A plotting function takes a keyword ``ax=None`` and then calls:

    >>> from statsmodels.graphics import utils
    >>> fig, ax = utils.create_mpl_ax(ax)
    """
    if ax is not None:
        # Reuse the caller's axes and report the figure they belong to.
        return ax.figure, ax

    new_fig = _import_mpl().figure()
    new_ax = new_fig.add_subplot(111)
    return new_fig, new_ax
|
|
|
|
|
|
def create_mpl_fig(fig=None, figsize=None):
    """Return a figure to which several axes can be added.

    The axes themselves should be created by the function that uses them,
    e.g. with ``fig.add_subplot()``.

    Parameters
    ----------
    fig : Figure, optional
        Existing figure. If given, it is returned unchanged and `figsize`
        is ignored; otherwise a new figure is created.
    figsize : optional
        Passed through as ``plt.figure(figsize=figsize)`` when a new figure
        is created.

    Returns
    -------
    Figure
        `fig` itself when provided, otherwise the newly created figure.

    See Also
    --------
    create_mpl_ax
    """
    if fig is not None:
        return fig

    pyplot = _import_mpl()
    return pyplot.figure(figsize=figsize)
|
|
|
|
|
|
def maybe_name_or_idx(idx, model):
    """
    Give a name or an integer and return the name and integer location of the
    column in a design matrix.

    Parameters
    ----------
    idx : None, int, str, list, tuple or range
        Column selector.
        - ``None`` selects every column of ``model.exog``.
        - Integers, including NumPy integer scalars, are positions into
          ``model.exog_names``.
        - Any other scalar is looked up by name.
        - Lists, tuples and ranges are resolved element by element.
    model : model instance
        Object exposing ``exog`` (with a ``shape``) and ``exog_names``.

    Returns
    -------
    exog_name : name or list of names
        The selected column name(s).
    exog_idx : int or list of int
        The matching column position(s).

    Notes
    -----
    Lookups are delegated to ``model.exog_names`` (indexing for positions,
    ``.index`` for names), so any error those raise propagates to the caller.
    """
    import numbers

    if idx is None:
        idx = list(range(model.exog.shape[1]))

    # numbers.Integral also matches NumPy integer scalars (np.int64, ...).
    # A plain ``int`` check misses them, so they would be looked up -- and
    # not found -- as column names.
    if isinstance(idx, numbers.Integral):
        exog_name = model.exog_names[idx]
        exog_idx = int(idx)
    # anticipate index as a sequence and recurse on each element
    elif isinstance(idx, (tuple, list, range)):
        pairs = [maybe_name_or_idx(item, model) for item in idx]
        exog_name = [name for name, _ in pairs]
        exog_idx = [pos for _, pos in pairs]
    else:  # assume we've got a string variable
        exog_name = idx
        exog_idx = model.exog_names.index(idx)

    return exog_name, exog_idx
|
|
|
|
|
|
def get_data_names(series_or_dataframe):
    """
    Return the variable name(s) of 1d or 2d data.

    Input can be an array or pandas-like. Names are taken from, in order:

    1. the ``columns`` attribute (DataFrame-like);
    2. the ``name`` attribute (Series-like), if it is not None;
    3. otherwise they are generated as ``"X0"``, ``"X1"``, ... from the
       ``shape`` attribute.

    Plain sequences without a ``shape`` (e.g. nested lists) are treated as
    1d and are not inspected further.

    Parameters
    ----------
    series_or_dataframe : array_like, Series or DataFrame
        The data whose variable names are wanted.

    Returns
    -------
    str or list
        A list of the column labels for DataFrame-like input. The ``name``
        for Series-like input. For other input, a single generated name
        when there is one variable, else a list of generated names.
    """
    # Check ``columns`` before ``name``: on a DataFrame that has a column
    # labelled "name", attribute access would return that column.
    columns = getattr(series_or_dataframe, 'columns', None)
    if columns is not None:
        # Do not test ``if not columns``: the truth value of a pandas Index
        # is ambiguous and raises ValueError.
        return list(columns)

    name = getattr(series_or_dataframe, 'name', None)
    if name is not None:
        return name

    shape = getattr(series_or_dataframe, 'shape', [1])
    # 0-d and 1-d inputs both count as a single variable.
    nvars = shape[1] if len(shape) > 1 else 1
    names = ["X%d" % i for i in range(nvars)]
    return names[0] if nvars == 1 else names
|
|
|
|
|
|
def annotate_axes(index, labels, points, offset_points, size, ax, **kwargs):
    """
    Place text annotations on `ax` for the selected positions.

    For every position ``k`` in `index`, ``labels[k]`` is drawn at
    ``points[k]``, shifted by ``offset_points[k]`` measured in points.

    Parameters
    ----------
    index : iterable
        Positions to annotate; each is used to index `labels`, `points`
        and `offset_points`.
    labels, points, offset_points : indexable
        Per-position annotation text, anchor coordinates and text offsets.
    size : font size
        Passed as ``size`` to every ``ax.annotate`` call.
    ax : Axes
        The axes that receive the annotations.
    **kwargs
        Extra keyword arguments forwarded to each ``ax.annotate`` call.

    Returns
    -------
    ax
        The same `ax` that was passed in.
    """
    for k in index:
        ax.annotate(
            labels[k],
            points[k],
            xytext=offset_points[k],
            textcoords="offset points",
            size=size,
            **kwargs,
        )
    return ax
|