205 lines
7.2 KiB
Python
205 lines
7.2 KiB
Python
"""
Authors: Josef Perktold, Skipper Seabold, Denis A. Engemann
"""
|
|
from statsmodels.compat.python import lrange
|
|
|
|
import numpy as np
|
|
|
|
from statsmodels.graphics.plottools import rainbow
|
|
import statsmodels.graphics.utils as utils
|
|
|
|
|
|
def _name_or_default(obj, default):
    """Return ``obj.name`` unless it is missing or None, else `default`."""
    name = getattr(obj, 'name', default)
    # An unnamed pandas.Series has ``name is None``; fall back to the
    # default so that the literal ``None`` never ends up in a label.
    return default if name is None else name


def interaction_plot(x, trace, response, func="mean", ax=None, plottype='b',
                     xlabel=None, ylabel=None, colors=None, markers=None,
                     linestyles=None, legendloc='best', legendtitle=None,
                     **kwargs):
    """
    Interaction plot for factor level statistics.

    Note. If categorical factors are supplied levels will be internally
    recoded to integers. This ensures matplotlib compatibility. Uses
    a DataFrame to calculate an `aggregate` statistic for each level of the
    factor or group given by `trace`.

    Parameters
    ----------
    x : array_like
        The `x` factor levels constitute the x-axis. If a `pandas.Series` is
        given its name will be used in `xlabel` if `xlabel` is None.
    trace : array_like
        The `trace` factor levels will be drawn as lines in the plot.
        If `trace` is a `pandas.Series` its name will be used as the
        `legendtitle` if `legendtitle` is None.
    response : array_like
        The response or dependent variable. If a `pandas.Series` is given
        its name will be used in `ylabel` if `ylabel` is None.
    func : function or str
        Anything accepted by `pandas.DataFrame.aggregate`. This is applied to
        the response variable grouped by the trace and x levels. Default is
        "mean".
    ax : axes, optional
        Matplotlib axes instance
    plottype : str {'line', 'scatter', 'both'}, optional
        The type of plot to return. Can be 'l', 's', or 'b'
    xlabel : str, optional
        Label to use for `x`. Default is 'X'. If `x` is a `pandas.Series` it
        will use the series names.
    ylabel : str, optional
        Name to use for `response` in the y-axis label, which is rendered as
        'func of ylabel'. Default is 'func of response'. If `response` is a
        `pandas.Series` it will use the series names.
    colors : list, optional
        If given, must have length == number of levels in trace.
    markers : list, optional
        If given, must have length == number of levels in trace
    linestyles : list, optional
        If given, must have length == number of levels in trace.
    legendloc : {None, str, int}
        Location passed to the legend command.
    legendtitle : {None, str}
        Title of the legend.
    **kwargs
        These will be passed to the plot command used either plot or scatter.
        If you want to control the overall plotting options, use kwargs.

    Returns
    -------
    Figure
        The figure given by `ax.figure` or a new instance.

    Raises
    ------
    ValueError
        If `colors`, `markers` or `linestyles` do not have one entry per
        trace level, or if `plottype` is not understood.

    Examples
    --------
    >>> import numpy as np
    >>> np.random.seed(12345)
    >>> weight = np.random.randint(1,4,size=60)
    >>> duration = np.random.randint(1,3,size=60)
    >>> days = np.log(np.random.randint(1,30, size=60))
    >>> fig = interaction_plot(weight, duration, days,
    ...             colors=['red','blue'], markers=['D','^'], ms=10)
    >>> import matplotlib.pyplot as plt
    >>> plt.show()

    .. plot::

       import numpy as np
       from statsmodels.graphics.factorplots import interaction_plot
       np.random.seed(12345)
       weight = np.random.randint(1,4,size=60)
       duration = np.random.randint(1,3,size=60)
       days = np.log(np.random.randint(1,30, size=60))
       fig = interaction_plot(weight, duration, days,
                   colors=['red','blue'], markers=['D','^'], ms=10)
       import matplotlib.pyplot as plt
       #plt.show()
    """

    from pandas import DataFrame

    fig, ax = utils.create_mpl_ax(ax)

    response_name = ylabel or _name_or_default(response, 'response')
    func_name = getattr(func, "__name__", str(func))
    ylabel = f'{func_name} of {response_name}'
    xlabel = xlabel or _name_or_default(x, 'X')
    legendtitle = legendtitle or _name_or_default(trace, 'Trace')

    ax.set_ylabel(ylabel)
    ax.set_xlabel(xlabel)

    # Look at the first element *by position*: ``x[0]`` on a pandas.Series
    # is a label lookup and fails when the index has no label 0.
    first_x = x.iloc[0] if hasattr(x, 'iloc') else x[0]

    x_values = x_levels = None
    if isinstance(first_x, str):
        # String factor: map each distinct level to an integer position so
        # matplotlib can place it, and remember the labels for the ticks.
        x_levels = list(np.unique(x))
        x_values = lrange(len(x_levels))
        x = _recode(x, dict(zip(x_levels, x_values)))

    data = DataFrame(dict(x=x, trace=trace, response=response))
    plot_data = data.groupby(['trace', 'x']).aggregate(func).reset_index()

    # check plot args
    n_trace = len(plot_data['trace'].unique())

    linestyles = ['-'] * n_trace if linestyles is None else linestyles
    markers = ['.'] * n_trace if markers is None else markers
    colors = rainbow(n_trace) if colors is None else colors

    if len(linestyles) != n_trace:
        raise ValueError("Must be a linestyle for each trace level")
    if len(markers) != n_trace:
        raise ValueError("Must be a marker for each trace level")
    if len(colors) != n_trace:
        raise ValueError("Must be a color for each trace level")

    # Resolve the plot type once, before drawing, so an invalid value is
    # rejected without plotting anything.
    if plottype == 'both' or plottype == 'b':
        kind = 'both'
    elif plottype == 'line' or plottype == 'l':
        kind = 'line'
    elif plottype == 'scatter' or plottype == 's':
        kind = 'scatter'
    else:
        raise ValueError("Plot type %s not understood" % plottype)

    # One artist per trace level; the i-th level gets the i-th style entry.
    for i, (_, group) in enumerate(plot_data.groupby('trace')):
        # trace label
        label = str(group['trace'].values[0])
        if kind == 'both':
            ax.plot(group['x'], group['response'], color=colors[i],
                    marker=markers[i], label=label,
                    linestyle=linestyles[i], **kwargs)
        elif kind == 'line':
            ax.plot(group['x'], group['response'], color=colors[i],
                    label=label, linestyle=linestyles[i], **kwargs)
        else:
            ax.scatter(group['x'], group['response'], color=colors[i],
                       label=label, marker=markers[i], **kwargs)

    ax.legend(loc=legendloc, title=legendtitle)
    ax.margins(.1)

    # Only a recoded (string) x factor needs its original labels restored.
    if x_levels is not None:
        ax.set_xticks(x_values)
        ax.set_xticklabels(x_levels)
    return fig
|
|
|
|
|
|
def _recode(x, levels):
    """ Recode categorical data to int factor.

    Parameters
    ----------
    x : array_like
        Categorically coded data of str (or object) dtype. May be a
        `pandas.Series`, a `numpy.ndarray` or any sequence that
        `numpy.asarray` can convert (e.g. a list of str).
    levels : dict
        mapping of labels to integer-codings. Its keys must be exactly the
        unique values found in `x`.

    Returns
    -------
    out : instance numpy.ndarray
        Integer codes, one per element of `x`. If `x` is a `pandas.Series`
        with a truthy name, a `pandas.Series` with the same name and index
        is returned instead.

    Raises
    ------
    ValueError
        If `x` is not of str/object dtype, if `levels` is not a dict, or if
        the keys of `levels` do not match the unique values of `x`.
    """
    from pandas import Series

    name = None
    index = None

    if isinstance(x, Series):
        name = x.name
        index = x.index
        x = x.values

    # Plain sequences have no ``dtype``/``shape`` and ``Series.values`` may
    # be a pandas extension array; normalise both to a numpy array.
    x = np.asarray(x)

    if x.dtype.type not in [np.str_, np.object_]:
        raise ValueError('This is not a categorial factor.'
                         ' Array of str type required.')

    if not isinstance(levels, dict):
        raise ValueError('This is not a valid value for levels.'
                         ' Dict required.')

    # ``array_equal`` also copes with a different number of levels, where
    # an elementwise ``==`` cannot broadcast.
    if not np.array_equal(np.unique(x), np.unique(list(levels))):
        raise ValueError('The levels do not match the array values.')

    out = np.empty(x.shape[0], dtype=int)
    for level, coding in levels.items():
        out[x == level] = coding

    # An unnamed Series is returned as a plain ndarray (existing behaviour).
    if name:
        out = Series(out, name=name, index=index)

    return out
|