190 lines
7.8 KiB
Raw Permalink Normal View History

2024-10-02 22:15:59 +04:00
# This file is part of Patsy
# Copyright (C) 2012-2013 Nathaniel Smith <njs@pobox.com>
# See file LICENSE.txt for license information.
from __future__ import print_function
import numpy as np
from patsy.state import Center, Standardize, center
from patsy.util import atleast_2d_column_default
def check_stateful(cls, accepts_multicolumn, input, output, *args, **kwargs):
    """Exercise a stateful transform class over many input layouts.

    The same data is presented to ``cls`` as lists, scalars, 0-d/1-d/2-d
    arrays and (when pandas is available) Series/DataFrames, both as a single
    chunk and split into one chunk per element.  For every layout the
    transform is checked chunk-by-chunk and then once more on the
    re-assembled input.

    Parameters
    ----------
    cls : callable
        Zero-argument factory returning an object that provides
        ``memorize_chunk``, ``memorize_finish`` and ``transform``.
    accepts_multicolumn : bool
        If true, two-column inputs (the data next to its reverse) are
        tested as well.
    input, output : array-like
        Data to feed in, and the values the transform is expected to return.
    *args, **kwargs
        Forwarded unchanged to ``memorize_chunk`` and ``transform``.

    Raises
    ------
    AssertionError
        If any result has the wrong values, dimensionality, length or
        pandas index.
    """
    input = np.asarray(input)
    output = np.asarray(output)
    # Each test case is (iterable of input chunks, expected full output).
    test_cases = [
        # List input, one chunk
        ([input], output),
        # Scalar input, many chunks
        (input, output),
        # List input, many chunks:
        ([[n] for n in input], output),
        # 0-d array input, many chunks:
        ([np.array(n) for n in input], output),
        # 1-d array input, one chunk:
        ([np.array(input)], output),
        # 1-d array input, many chunks:
        ([np.array([n]) for n in input], output),
        # 2-d but 1 column input, one chunk:
        ([np.array(input)[:, None]], atleast_2d_column_default(output)),
        # 2-d but 1 column input, many chunks:
        ([np.array([[n]]) for n in input], atleast_2d_column_default(output)),
    ]
    if accepts_multicolumn:
        # The second column is the first one reversed; the expected output
        # is built the same way from ``output``.
        test_cases += [
            # 2-d array input, one chunk:
            (
                [np.column_stack((input, input[::-1]))],
                np.column_stack((output, output[::-1])),
            ),
            # 2-d array input, many chunks:
            (
                [np.array([[input[i], input[-i - 1]]]) for i in range(len(input))],
                np.column_stack((output, output[::-1])),
            ),
        ]
    from patsy.util import have_pandas

    if have_pandas:
        import pandas

        pandas_type = (pandas.Series, pandas.DataFrame)
        # Non-default index, so index preservation can actually be verified.
        pandas_index = np.linspace(0, 1, num=len(input))
        # 1d and 2d here refer to the dimensionality of the input
        if output.ndim == 1:
            output_1d = pandas.Series(output, index=pandas_index)
        else:
            output_1d = pandas.DataFrame(output, index=pandas_index)
        test_cases += [
            # Series input, one chunk
            ([pandas.Series(input, index=pandas_index)], output_1d),
            # Series input, many chunks
            (
                [
                    pandas.Series([x], index=[idx])
                    for (x, idx) in zip(input, pandas_index)
                ],
                output_1d,
            ),
        ]
        if accepts_multicolumn:
            input_2d_2col = np.column_stack((input, input[::-1]))
            output_2d_2col = np.column_stack((output, output[::-1]))
            output_2col_dataframe = pandas.DataFrame(
                output_2d_2col, index=pandas_index
            )
            test_cases += [
                # DataFrame input, one chunk
                (
                    [pandas.DataFrame(input_2d_2col, index=pandas_index)],
                    output_2col_dataframe,
                ),
                # DataFrame input, many chunks
                (
                    [
                        pandas.DataFrame(
                            [input_2d_2col[i, :]], index=[pandas_index[i]]
                        )
                        for i in range(len(input))
                    ],
                    output_2col_dataframe,
                ),
            ]
    for input_obj, output_obj in test_cases:
        t = cls()
        for input_chunk in input_obj:
            t.memorize_chunk(input_chunk, *args, **kwargs)
        # Stateful-transform protocol: memorizing must be finalized before
        # transform() is used.
        t.memorize_finish()

        # Pass 1: transform chunk by chunk and stitch the pieces together.
        all_outputs = []
        for input_chunk in input_obj:
            output_chunk = t.transform(input_chunk, *args, **kwargs)
            if input.ndim == output.ndim:
                assert output_chunk.ndim == np.asarray(input_chunk).ndim
            all_outputs.append(output_chunk)
        if have_pandas and isinstance(all_outputs[0], pandas_type):
            all_output1 = pandas.concat(all_outputs)
            assert np.array_equal(all_output1.index, pandas_index)
        elif all_outputs[0].ndim == 0:
            all_output1 = np.array(all_outputs)
        elif all_outputs[0].ndim == 1:
            all_output1 = np.concatenate(all_outputs)
        else:
            all_output1 = np.vstack(all_outputs)
        assert all_output1.shape[0] == len(input)
        # output_obj_reshaped = np.asarray(output_obj).reshape(all_output1.shape)
        # assert np.allclose(all_output1, output_obj_reshaped)
        assert np.allclose(all_output1, output_obj)

        # Pass 2: re-assemble the chunks first, then transform in one call.
        if np.asarray(input_obj[0]).ndim == 0:
            all_input = np.array(input_obj)
        elif have_pandas and isinstance(input_obj[0], pandas_type):
            # handles both Series and DataFrames
            all_input = pandas.concat(input_obj)
        elif np.asarray(input_obj[0]).ndim == 1:
            # Don't use vstack, because that would turn this into a 1xn
            # matrix:
            all_input = np.concatenate(input_obj)
        else:
            all_input = np.vstack(input_obj)
        all_output2 = t.transform(all_input, *args, **kwargs)
        if have_pandas and isinstance(input_obj[0], pandas_type):
            assert np.array_equal(all_output2.index, pandas_index)
        if input.ndim == output.ndim:
            assert all_output2.ndim == all_input.ndim
        assert np.allclose(all_output2, output_obj)
def test_Center():
    """Run ``Center`` through ``check_stateful`` with multi-column inputs enabled."""
    # (input data, expected output) pairs.
    cases = (
        ([1, 2, 3], [-1, 0, 1]),
        ([1, 2, 1, 2], [-0.5, 0.5, -0.5, 0.5]),
        ([1.3, -10.1, 7.0, 12.0], [-1.25, -12.65, 4.45, 9.45]),
    )
    for raw, expected in cases:
        check_stateful(Center, True, raw, expected)
def test_stateful_transform_wrapper():
    """Check ``center`` on plain sequences, typed arrays and pandas objects."""
    # Values: (input, expected result) pairs.
    value_cases = (
        ([1, 2, 3], [-1, 0, 1]),
        ([1, 2, 1, 2], [-0.5, 0.5, -0.5, 0.5]),
    )
    for raw, expected in value_cases:
        assert np.allclose(center(raw), expected)

    # dtypes: float input stays float, float32 stays float32, and integer
    # input comes back as float.
    default_float = np.dtype(float)
    assert center([1.0, 2.0, 3.0]).dtype == default_float
    single_precision = np.array([1.0, 2.0, 3.0], dtype=np.float32)
    assert center(single_precision).dtype == np.dtype(np.float32)
    assert center([1, 2, 3]).dtype == default_float

    from patsy.util import have_pandas

    if not have_pandas:
        return

    import pandas

    series_in = pandas.Series([1, 2, 3], index=["a", "b", "c"])
    frame_in = pandas.DataFrame(
        [[1, 2], [2, 4], [3, 6]],
        columns=["x1", "x2"],
        index=[10, 20, 30],
    )

    # A Series comes back as a Series with its index intact.
    series_out = center(series_in)
    assert isinstance(series_out, pandas.Series)
    assert np.array_equal(series_out.index, ["a", "b", "c"])
    assert np.allclose(series_out, [-1, 0, 1])

    # A DataFrame comes back as a DataFrame with index and columns intact.
    frame_out = center(frame_in)
    assert isinstance(frame_out, pandas.DataFrame)
    assert np.array_equal(frame_out.index, [10, 20, 30])
    assert np.array_equal(frame_out.columns, ["x1", "x2"])
    assert np.allclose(frame_out, [[-1, -2], [0, 0], [1, 2]])
def test_Standardize():
    """Run ``Standardize`` through ``check_stateful`` for several option sets.

    Covers the default settings on integer and float input, then the
    ``ddof``, ``center`` and ``rescale`` keyword arguments on ``range(20)``.
    """
    check_stateful(Standardize, True, [1, -1], [1, -1])
    check_stateful(Standardize, True, [12, 10], [1, -1])
    check_stateful(
        Standardize,
        True,
        [12, 11, 10],
        [np.sqrt(3.0 / 2), 0, -np.sqrt(3.0 / 2)],
    )
    check_stateful(
        Standardize,
        True,
        [12.0, 11.0, 10.0],
        [np.sqrt(3.0 / 2), 0, -np.sqrt(3.0 / 2)],
    )
    # XX: see the comment in Standardize.transform about why this doesn't
    # work:
    # check_stateful(Standardize,
    #                [12.0+0j, 11.0+0j, 10.0],
    #                [np.sqrt(3./2)+0j, 0, -np.sqrt(3./2)])

    r20 = list(range(20))

    # Sample std of [1, -1] is sqrt(2), hence +-sqrt(2)/2 with ddof=1.
    check_stateful(
        Standardize,
        True,
        [1, -1],
        [np.sqrt(2) / 2, -np.sqrt(2) / 2],
        ddof=1,
    )
    # range(20) has mean 9.5; its population std is sqrt(33.25) and its
    # sample std is sqrt(35), which are the two divisors used below.
    check_stateful(
        Standardize,
        True,
        r20,
        list((np.arange(20) - 9.5) / 5.7662812973353983),
        ddof=0,
    )
    check_stateful(
        Standardize,
        True,
        r20,
        list((np.arange(20) - 9.5) / 5.9160797830996161),
        ddof=1,
    )
    # Centering only.
    check_stateful(
        Standardize,
        True,
        r20,
        list((np.arange(20) - 9.5)),
        rescale=False,
        ddof=1,
    )
    # Rescaling only.
    check_stateful(
        Standardize,
        True,
        r20,
        list(np.arange(20) / 5.9160797830996161),
        center=False,
        ddof=1,
    )
    # Neither: the data must come back unchanged.
    check_stateful(
        Standardize,
        True,
        r20,
        r20,
        center=False,
        rescale=False,
        ddof=1,
    )