# This file is part of Patsy
# Copyright (C) 2011-2012 Nathaniel Smith
# See file LICENSE.txt for license information.

# This file defines the ModelDesc class, which describes a model at a high
# level, as a list of interactions of factors. It also has the code to convert
# a formula parse tree (from patsy.parse_formula) into a ModelDesc.

from __future__ import print_function

import six

from patsy import PatsyError
from patsy.parse_formula import ParseNode, Token, parse_formula
from patsy.eval import EvalEnvironment, EvalFactor
from patsy.util import uniqueify_list
from patsy.util import repr_pretty_delegate, repr_pretty_impl
from patsy.util import no_pickling, assert_no_pickling

# These are made available in the patsy.* namespace
__all__ = ["Term", "ModelDesc", "INTERCEPT"]


# One might think it would make more sense for 'factors' to be a set, rather
# than a tuple-with-guaranteed-unique-entries-that-compares-like-a-set. The
# reason we do it this way is that it preserves the order that the user typed
# and is expecting, which then ends up producing nicer names in our final
# output, nicer column ordering, etc. (A similar comment applies to the
# ordering of terms in ModelDesc objects as a whole.)
class Term(object):
    """The interaction between a collection of factor objects.

    This is one of the basic types used in representing formulas, and
    corresponds to an expression like ``"a:b:c"`` in a formula string.
    For details, see :ref:`formulas` and :ref:`expert-model-specification`.

    Terms are hashable and compare by value.

    Attributes:

    .. attribute:: factors

       A tuple of factor objects.
    """

    def __init__(self, factors):
        # Duplicates are dropped while the caller's ordering is kept.
        self.factors = tuple(uniqueify_list(factors))

    def __eq__(self, other):
        # Equality ignores factor order: Term([a, b]) == Term([b, a]).
        return (isinstance(other, Term)
                and frozenset(other.factors) == frozenset(self.factors))

    def __ne__(self, other):
        return not self == other

    def __hash__(self):
        # Must agree with __eq__, so hash the unordered factor set.
        return hash((Term, frozenset(self.factors)))

    __repr__ = repr_pretty_delegate

    def _repr_pretty_(self, p, cycle):
        assert not cycle
        repr_pretty_impl(p, self, [list(self.factors)])

    def name(self):
        """Return a human-readable name for this term."""
        if self.factors:
            return ":".join([f.name() for f in self.factors])
        else:
            return "Intercept"

    __getstate__ = no_pickling


# The term with no factors at all stands for the intercept.
INTERCEPT = Term([])


class _MockFactor(object):
    """Minimal stand-in for a factor in tests: it only knows its name."""

    def __init__(self, name):
        self._name = name

    def name(self):
        return self._name


def test_Term():
    assert Term([1, 2, 1]).factors == (1, 2)
    assert Term([1, 2]) == Term([2, 1])
    assert hash(Term([1, 2])) == hash(Term([2, 1]))
    f1 = _MockFactor("a")
    f2 = _MockFactor("b")
    assert Term([f1, f2]).name() == "a:b"
    assert Term([f2, f1]).name() == "b:a"
    assert Term([]).name() == "Intercept"

    assert_no_pickling(Term([]))


class ModelDesc(object):
    """A simple container representing the termlists parsed from a formula.

    This is a simple container object which has exactly the same
    representational power as a formula string, but is a Python object
    instead. You can construct one by hand, and pass it to functions like
    :func:`dmatrix` or :func:`incr_dbuilder` that are expecting a formula
    string, but without having to do any messy string manipulation. For
    details see :ref:`expert-model-specification`.

    Attributes:

    .. attribute:: lhs_termlist
                   rhs_termlist

       Two termlists representing the left- and right-hand sides of a
       formula, suitable for passing to :func:`design_matrix_builders`.
    """

    def __init__(self, lhs_termlist, rhs_termlist):
        self.lhs_termlist = uniqueify_list(lhs_termlist)
        self.rhs_termlist = uniqueify_list(rhs_termlist)

    __repr__ = repr_pretty_delegate

    def _repr_pretty_(self, p, cycle):
        assert not cycle
        return repr_pretty_impl(p, self,
                                [],
                                [("lhs_termlist", self.lhs_termlist),
                                 ("rhs_termlist", self.rhs_termlist)])

    def describe(self):
        """Returns a human-readable representation of this :class:`ModelDesc`
        in pseudo-formula notation.

        .. warning:: There is no guarantee that the strings returned by this
           function can be parsed as formulas. They are best-effort
           descriptions intended for human users. However, if this ModelDesc
           was created by parsing a formula, then it should work in
           practice. If you *really* have to.
        """
        def term_code(term):
            if term == INTERCEPT:
                return "1"
            else:
                return term.name()

        result = " + ".join([term_code(term) for term in self.lhs_termlist])
        if result:
            result += " ~ "
        else:
            result += "~ "
        if self.rhs_termlist == [INTERCEPT]:
            result += term_code(INTERCEPT)
        else:
            term_names = []
            # A right-hand side without an intercept is spelled with an
            # explicit leading "0".
            if INTERCEPT not in self.rhs_termlist:
                term_names.append("0")
            term_names += [term_code(term) for term in self.rhs_termlist
                           if term != INTERCEPT]
            result += " + ".join(term_names)
        return result

    @classmethod
    def from_formula(cls, tree_or_string):
        """Construct a :class:`ModelDesc` from a formula string.

        :arg tree_or_string: A formula string. (Or an unevaluated formula
          parse tree, but the API for generating those isn't public yet. Shh,
          it can be our secret.)
        :returns: A new :class:`ModelDesc`.
        """
        if isinstance(tree_or_string, ParseNode):
            tree = tree_or_string
        else:
            tree = parse_formula(tree_or_string)
        # The top-level node is the one place where a ModelDesc (rather than
        # an IntermediateExpr) is an acceptable evaluation result.
        value = Evaluator().eval(tree, require_evalexpr=False)
        assert isinstance(value, cls)
        return value

    __getstate__ = no_pickling


def test_ModelDesc():
    f1 = _MockFactor("a")
    f2 = _MockFactor("b")
    m = ModelDesc([INTERCEPT, Term([f1])], [Term([f1]), Term([f1, f2])])
    assert m.lhs_termlist == [INTERCEPT, Term([f1])]
    assert m.rhs_termlist == [Term([f1]), Term([f1, f2])]
    print(m.describe())
    assert m.describe() == "1 + a ~ 0 + a + a:b"

    assert_no_pickling(m)

    assert ModelDesc([], []).describe() == "~ 0"
    assert ModelDesc([INTERCEPT], []).describe() == "1 ~ 0"
    assert ModelDesc([INTERCEPT], [INTERCEPT]).describe() == "1 ~ 1"
    assert (ModelDesc([INTERCEPT], [INTERCEPT, Term([f2])]).describe()
            == "1 ~ b")


def test_ModelDesc_from_formula():
    for formula in ("y ~ x", parse_formula("y ~ x")):
        md = ModelDesc.from_formula(formula)
        assert md.lhs_termlist == [Term([EvalFactor("y")]),]
        assert md.rhs_termlist == [INTERCEPT, Term([EvalFactor("x")])]


class IntermediateExpr(object):
    """This class holds an intermediate result while we're evaluating a tree.

    Attributes:

    * ``intercept``: true when the expression explicitly includes an
      intercept.
    * ``intercept_origin``: origin to blame in error messages about the
      intercept; must be set whenever ``intercept`` is true.
    * ``intercept_removed``: true when the intercept was explicitly removed;
      never true together with ``intercept``.
    * ``terms``: tuple of the expression's terms, duplicates dropped.
    """

    def __init__(self, intercept, intercept_origin, intercept_removed, terms):
        self.intercept = intercept
        self.intercept_origin = intercept_origin
        self.intercept_removed = intercept_removed
        self.terms = tuple(uniqueify_list(terms))
        if self.intercept:
            assert self.intercept_origin
        assert not (self.intercept and self.intercept_removed)

    __repr__ = repr_pretty_delegate

    # Named `_repr_pretty_` to match Term and ModelDesc, which pair this same
    # hook name with `__repr__ = repr_pretty_delegate`.
    def _repr_pretty_(self, p, cycle):  # pragma: no cover
        assert not cycle
        return repr_pretty_impl(p, self,
                                [self.intercept, self.intercept_origin,
                                 self.intercept_removed, self.terms])

    # Backward-compatible alias for the hook's previous, misspelled name.
    _pretty_repr_ = _repr_pretty_

    __getstate__ = no_pickling


def _maybe_add_intercept(doit, terms):
    """Return `terms` (a tuple) with INTERCEPT prepended when `doit` is true."""
    if doit:
        return (INTERCEPT,) + terms
    else:
        return terms


def _eval_any_tilde(evaluator, tree):
    """Evaluate a unary or binary "~" node into a :class:`ModelDesc`."""
    exprs = [evaluator.eval(arg) for arg in tree.args]
    if len(exprs) == 1:
        # Formula was like: "~ foo"
        # We pretend that instead it was like: "0 ~ foo"
        exprs.insert(0, IntermediateExpr(False, None, True, []))
    assert len(exprs) == 2
    # Note that only the RHS gets an implicit intercept:
    return ModelDesc(_maybe_add_intercept(exprs[0].intercept, exprs[0].terms),
                     _maybe_add_intercept(not exprs[1].intercept_removed,
                                          exprs[1].terms))


def _eval_binary_plus(evaluator, tree):
    """Evaluate "left + right".

    "+ 0" marks the intercept as removed. Otherwise the terms are
    concatenated; an intercept on the right overrides the left's intercept
    state, and without one the left's intercept state is carried through.
    """
    left_expr = evaluator.eval(tree.args[0])
    if tree.args[1].type == "ZERO":
        return IntermediateExpr(False, None, True, left_expr.terms)
    else:
        right_expr = evaluator.eval(tree.args[1])
        if right_expr.intercept:
            return IntermediateExpr(True, right_expr.intercept_origin, False,
                                    left_expr.terms + right_expr.terms)
        else:
            return IntermediateExpr(left_expr.intercept,
                                    left_expr.intercept_origin,
                                    left_expr.intercept_removed,
                                    left_expr.terms + right_expr.terms)


def _eval_binary_minus(evaluator, tree):
    """Evaluate "left - right".

    "- 0" adds an intercept and "- 1" removes it. Otherwise every term found
    on the right is dropped from the left; an intercept on the right removes
    the intercept, and without one the left's intercept state is kept.
    """
    left_expr = evaluator.eval(tree.args[0])
    if tree.args[1].type == "ZERO":
        return IntermediateExpr(True, tree.args[1], False,
                                left_expr.terms)
    elif tree.args[1].type == "ONE":
        return IntermediateExpr(False, None, True, left_expr.terms)
    else:
        right_expr = evaluator.eval(tree.args[1])
        terms = [term for term in left_expr.terms
                 if term not in right_expr.terms]
        if right_expr.intercept:
            return IntermediateExpr(False, None, True, terms)
        else:
            return IntermediateExpr(left_expr.intercept,
                                    left_expr.intercept_origin,
                                    left_expr.intercept_removed,
                                    terms)


def _check_interactable(expr):
    """Raise :class:`PatsyError` if `expr` carries an explicit intercept."""
    if expr.intercept:
        raise PatsyError("intercept term cannot interact with "
                         "anything else", expr.intercept_origin)


def _interaction(left_expr, right_expr):
    """Return the pairwise interaction of two intermediate expressions.

    Each left term is combined with each right term, left terms varying
    slowest. Raises :class:`PatsyError` if either side has an intercept.
    """
    for expr in (left_expr, right_expr):
        _check_interactable(expr)
    terms = [Term(l_term.factors + r_term.factors)
             for l_term in left_expr.terms
             for r_term in right_expr.terms]
    return IntermediateExpr(False, None, False, terms)


def _eval_binary_prod(evaluator, tree):
    """Evaluate "a * b" as the terms of a, the terms of b, and a:b."""
    exprs = [evaluator.eval(arg) for arg in tree.args]
    return IntermediateExpr(False, None, False,
                            exprs[0].terms
                            + exprs[1].terms
                            + _interaction(*exprs).terms)


# Division (nesting) is right-ward distributive:
#   a / (b + c) -> a/b + a/c -> a + a:b + a:c
# But left-ward, in S/R it has a quirky behavior:
#   (a + b)/c -> a + b + a:b:c
# This is because it's meaningless for a factor to be "nested" under two
# different factors. (This is documented in Chambers and Hastie (page 30) as a
# "Slightly more subtle..." rule, with no further elaboration. Hopefully we
# will do better.)
def _eval_binary_div(evaluator, tree):
    """Evaluate "left / right" (nesting).

    The result holds the left-hand terms, followed by the interaction of one
    combined term (all left-hand factors together) with each right-hand term.
    Raises :class:`PatsyError` if either side has an intercept.
    """
    left_expr = evaluator.eval(tree.args[0])
    right_expr = evaluator.eval(tree.args[1])
    terms = list(left_expr.terms)
    _check_interactable(left_expr)
    # Build a single giant combined term for everything on the left:
    left_factors = []
    for term in left_expr.terms:
        left_factors += list(term.factors)
    left_combined_expr = IntermediateExpr(False, None, False,
                                          [Term(left_factors)])
    # Then interact it with everything on the right:
    terms += list(_interaction(left_combined_expr, right_expr).terms)
    return IntermediateExpr(False, None, False, terms)


def _eval_binary_interact(evaluator, tree):
    """Evaluate "a:b" as the pairwise interaction of both sides."""
    exprs = [evaluator.eval(arg) for arg in tree.args]
    return _interaction(*exprs)


def _eval_binary_power(evaluator, tree):
    """Evaluate "expr ** n" for a positive integer literal n.

    The result accumulates the terms of expr, then expr:expr, and so on, up
    to n-way interactions. Raises :class:`PatsyError` if the exponent is not
    a positive integer literal or if expr has an intercept.
    """
    left_expr = evaluator.eval(tree.args[0])
    _check_interactable(left_expr)
    # -1 is a sentinel meaning "no valid exponent found".
    power = -1
    if tree.args[1].type in ("ONE", "NUMBER"):
        expr = tree.args[1].token.extra
        try:
            power = int(expr)
        except ValueError:
            pass
    if power < 1:
        raise PatsyError("'**' requires a positive integer", tree.args[1])
    all_terms = left_expr.terms
    big_expr = left_expr
    # Small optimization: (a + b)**100 is just the same as (a + b)**2.
    power = min(len(left_expr.terms), power)
    for i in range(1, power):
        big_expr = _interaction(left_expr, big_expr)
        all_terms = all_terms + big_expr.terms
    return IntermediateExpr(False, None, False, all_terms)


def _eval_unary_plus(evaluator, tree):
    """Evaluate "+expr", which is simply expr."""
    return evaluator.eval(tree.args[0])


def _eval_unary_minus(evaluator, tree):
    """Evaluate "-0" (adds an intercept) or "-1" (removes it).

    Raises :class:`PatsyError` for any other operand.
    """
    if tree.args[0].type == "ZERO":
        return IntermediateExpr(True, tree.origin, False, [])
    elif tree.args[0].type == "ONE":
        return IntermediateExpr(False, None, True, [])
    else:
        raise PatsyError("Unary minus can only be applied to 1 or 0", tree)


def _eval_zero(evaluator, tree):
    """Evaluate a literal 0: no terms, intercept removed."""
    return IntermediateExpr(False, None, True, [])


def _eval_one(evaluator, tree):
    """Evaluate a literal 1: no terms, explicit intercept."""
    return IntermediateExpr(True, tree.origin, False, [])


def _eval_number(evaluator, tree):
    """Reject a bare number; NUMBER nodes are only consumed by "**"."""
    raise PatsyError("numbers besides '0' and '1' are "
                     "only allowed with **", tree)


def _eval_python_expr(evaluator, tree):
    """Wrap a Python expression token as a one-factor term."""
    factor = EvalFactor(tree.token.extra, origin=tree.origin)
    return IntermediateExpr(False, None, False, [Term([factor])])


class Evaluator(object):
    """Turns a formula parse tree into intermediate results and a ModelDesc.

    Handlers are looked up by ``(node type, number of arguments)`` and called
    as ``handler(evaluator, tree)``.
    """

    def __init__(self):
        self._evaluators = {}
        self.add_op("~", 2, _eval_any_tilde)
        self.add_op("~", 1, _eval_any_tilde)

        self.add_op("+", 2, _eval_binary_plus)
        self.add_op("-", 2, _eval_binary_minus)
        self.add_op("*", 2, _eval_binary_prod)
        self.add_op("/", 2, _eval_binary_div)
        self.add_op(":", 2, _eval_binary_interact)
        self.add_op("**", 2, _eval_binary_power)

        self.add_op("+", 1, _eval_unary_plus)
        self.add_op("-", 1, _eval_unary_minus)

        self.add_op("ZERO", 0, _eval_zero)
        self.add_op("ONE", 0, _eval_one)
        self.add_op("NUMBER", 0, _eval_number)
        self.add_op("PYTHON_EXPR", 0, _eval_python_expr)

        # Not used by Patsy -- provided for the convenience of eventual
        # user-defined operators.
        self.stash = {}

    # This should not be considered a public API yet (to use for actually
    # adding new operator semantics) because I wrote in some of the relevant
    # code sort of speculatively, but it isn't actually tested.
    def add_op(self, op, arity, evaluator):
        """Register `evaluator` as the handler for `op` with `arity` args."""
        self._evaluators[op, arity] = evaluator

    def eval(self, tree, require_evalexpr=True):
        """Evaluate the parse node `tree` with its registered handler.

        :arg tree: A :class:`ParseNode`.
        :arg require_evalexpr: If true (the default), the handler's result
          must be an :class:`IntermediateExpr`.
        :returns: Whatever the handler returned.
        :raises PatsyError: If no handler is registered for this node's type
          and argument count, or if `require_evalexpr` is true and the
          handler returned something else (e.g. a :class:`ModelDesc` from a
          nested "~").
        """
        assert isinstance(tree, ParseNode)
        key = (tree.type, len(tree.args))
        if key not in self._evaluators:
            raise PatsyError("I don't know how to evaluate this "
                             "'%s' operator" % (tree.type,),
                             tree.token)
        result = self._evaluators[key](self, tree)
        if require_evalexpr and not isinstance(result, IntermediateExpr):
            if isinstance(result, ModelDesc):
                raise PatsyError("~ can only be used once, and "
                                 "only at the top level",
                                 tree)
            else:
                raise PatsyError("custom operator returned an "
                                 "object that I don't know how to "
                                 "handle", tree)
        return result


#############

# Maps formula code to the expected result: either
#   (rhs_intercept, rhs_terms)
# or
#   (lhs_intercept, lhs_terms, rhs_intercept, rhs_terms)
# where a term is a factor code string or a tuple of them.
_eval_tests = {
    "": (True, []),
    " ": (True, []),
    " \n ": (True, []),
    "a": (True, ["a"]),

    "1": (True, []),
    "0": (False, []),
    "- 1": (False, []),
    "- 0": (True, []),
    "+ 1": (True, []),
    "+ 0": (False, []),
    "0 + 1": (True, []),
    "1 + 0": (False, []),
    "1 - 0": (True, []),
    "0 - 1": (False, []),

    "1 + a": (True, ["a"]),
    "0 + a": (False, ["a"]),
    "a - 1": (False, ["a"]),
    "a - 0": (True, ["a"]),
    "1 - a": (True, []),

    "a + b": (True, ["a", "b"]),
    "(a + b)": (True, ["a", "b"]),
    "a + ((((b))))": (True, ["a", "b"]),
    "a + ((((+b))))": (True, ["a", "b"]),
    "a + ((((b - a))))": (True, ["a", "b"]),

    "a + a + a": (True, ["a"]),

    "a + (b - a)": (True, ["a", "b"]),

    "a + np.log(a, base=10)": (True, ["a", "np.log(a, base=10)"]),
    # Note different spacing:
    "a + np.log(a, base=10) - np . log(a , base = 10)": (True, ["a"]),

    "a + (I(b) + c)": (True, ["a", "I(b)", "c"]),
    "a + I(b + c)": (True, ["a", "I(b + c)"]),

    "a:b": (True, [("a", "b")]),
    "a:b:a": (True, [("a", "b")]),
    "a:(b + c)": (True, [("a", "b"), ("a", "c")]),
    "(a + b):c": (True, [("a", "c"), ("b", "c")]),
    "a:(b - c)": (True, [("a", "b")]),
    "c + a:c + a:(b - c)": (True, ["c", ("a", "c"), ("a", "b")]),
    "(a - b):c": (True, [("a", "c")]),
    "b + b:c + (a - b):c": (True, ["b", ("b", "c"), ("a", "c")]),

    "a:b - a:b": (True, []),
    "a:b - b:a": (True, []),

    "1 - (a + b)": (True, []),
    "a + b - (a + b)": (True, []),

    "a * b": (True, ["a", "b", ("a", "b")]),
    "a * b * a": (True, ["a", "b", ("a", "b")]),
    "a * (b + c)": (True, ["a", "b", "c", ("a", "b"), ("a", "c")]),
    "(a + b) * c": (True, ["a", "b", "c", ("a", "c"), ("b", "c")]),
    "a * (b - c)": (True, ["a", "b", ("a", "b")]),
    "c + a:c + a * (b - c)": (True, ["c", ("a", "c"), "a", "b", ("a", "b")]),
    "(a - b) * c": (True, ["a", "c", ("a", "c")]),
    "b + b:c + (a - b) * c": (True, ["b", ("b", "c"), "a", "c", ("a", "c")]),

    "a/b": (True, ["a", ("a", "b")]),
    "(a + b)/c": (True, ["a", "b", ("a", "b", "c")]),
    "b + b:c + (a - b)/c": (True, ["b", ("b", "c"), "a", ("a", "c")]),
    "a/(b + c)": (True, ["a", ("a", "b"), ("a", "c")]),

    "a ** 2": (True, ["a"]),
    "(a + b + c + d) ** 2": (True, ["a", "b", "c", "d",
                                    ("a", "b"), ("a", "c"), ("a", "d"),
                                    ("b", "c"), ("b", "d"), ("c", "d")]),
    "(a + b + c + d) ** 3": (True, ["a", "b", "c", "d",
                                    ("a", "b"), ("a", "c"), ("a", "d"),
                                    ("b", "c"), ("b", "d"), ("c", "d"),
                                    ("a", "b", "c"), ("a", "b", "d"),
                                    ("a", "c", "d"), ("b", "c", "d")]),

    "a + +a": (True, ["a"]),

    "~ a + b": (True, ["a", "b"]),
    "~ a*b": (True, ["a", "b", ("a", "b")]),
    "~ a*b + 0": (False, ["a", "b", ("a", "b")]),
    "~ -1": (False, []),

    "0 ~ a + b": (True, ["a", "b"]),
    "1 ~ a + b": (True, [], True, ["a", "b"]),
    "y ~ a + b": (False, ["y"], True, ["a", "b"]),
    "0 + y ~ a + b": (False, ["y"], True, ["a", "b"]),
    "0 + y * z ~ a + b": (False, ["y", "z", ("y", "z")], True, ["a", "b"]),
    "-1 ~ 1": (False, [], True, []),
    "1 + y ~ a + b": (True, ["y"], True, ["a", "b"]),

    # Check precedence:
    "a + b * c": (True, ["a", "b", "c", ("b", "c")]),
    "a * b + c": (True, ["a", "b", ("a", "b"), "c"]),
    "a * b - a": (True, ["b", ("a", "b")]),
    "a + b / c": (True, ["a", "b", ("b", "c")]),
    "a / b + c": (True, ["a", ("a", "b"), "c"]),
    "a*b:c": (True, ["a", ("b", "c"), ("a", "b", "c")]),
    "a:b*c": (True, [("a", "b"), "c", ("a", "b", "c")]),

    # Intercept handling:
    "~ 1 + 1 + 0 + 1": (True, []),
    "~ 0 + 1 + 0": (False, []),
    "~ 0 - 1 - 1 + 0 + 1": (True, []),
    "~ 1 - 1": (False, []),
    "~ 0 + a + 1": (True, ["a"]),
    "~ 1 + (a + 0)": (True, ["a"]),  # This is correct, but perhaps surprising!
    "~ 0 + (a + 1)": (True, ["a"]),  # Also correct!
    "~ 1 - (a + 1)": (False, []),
}

# <> mark off where the error should be reported:
_eval_error_tests = [
    "a <+>",
    "a + <(>",

    "b + <(-a)>",

    "a:<1>",
    "(a + <1>)*b",

    "a + <2>",
    "a + <1.0>",
    # eh, catching this is a hassle, we'll just leave the user some rope if
    # they really want it:
    #"a + <0x1>",

    "a ** <b>",
    "a ** <(1 + 1)>",
    "a ** <1.5>",

    "a + b <# asdf>",

    "<)>",
    "a + <)>",
    "<*> a",
    "a + <*>",

    "a + <foo[bar>",
    "a + <foo{bar>",
    "a + <foo(bar>",

    "a + <[bar>",
    "a + <{bar>",

    "a + <{bar[]>",

    "a + foo<]>bar",
    "a + foo[]<]>bar",
    "a + foo{}<}>bar",
    "a + foo<)>bar",

    "a + b<)>",
    "(a) <.>",

    "<(>a + b",

    "<y ~ a> ~ b",
    "y ~ <(a ~ b)>",
    "<~ a> ~ b",
    "~ <(a ~ b)>",

    "1 + <-(a + b)>",

    "<- a>",
    "a + <-a**2>",
]


def _assert_terms_match(terms, expected_intercept, expecteds):  # pragma: no cover
    """Assert that `terms` matches the expected factor codes, in order.

    An expected intercept is represented by a leading empty tuple of factors.
    """
    if expected_intercept:
        expecteds = [()] + expecteds
    assert len(terms) == len(expecteds)
    for term, expected in zip(terms, expecteds):
        if isinstance(term, Term):
            # A bare string is shorthand for a single-factor term.
            if isinstance(expected, str):
                expected = (expected,)
            assert term.factors == tuple([EvalFactor(s) for s in expected])
        else:
            assert term == expected


def _do_eval_formula_tests(tests):  # pragma: no cover
    """Run every (formula code -> expected result) entry in `tests`."""
    for code, result in six.iteritems(tests):
        # Two-element results describe only the right-hand side; the
        # left-hand side is then expected to be empty.
        if len(result) == 2:
            result = (False, []) + result
        model_desc = ModelDesc.from_formula(code)
        print(repr(code))
        print(result)
        print(model_desc)
        lhs_intercept, lhs_termlist, rhs_intercept, rhs_termlist = result
        _assert_terms_match(model_desc.lhs_termlist,
                            lhs_intercept, lhs_termlist)
        _assert_terms_match(model_desc.rhs_termlist,
                            rhs_intercept, rhs_termlist)


def test_eval_formula():
    _do_eval_formula_tests(_eval_tests)


def test_eval_formula_error_reporting():
    from patsy.parse_formula import _parsing_error_test
    parse_fn = lambda formula: ModelDesc.from_formula(formula)
    _parsing_error_test(parse_fn, _eval_error_tests)


def test_formula_factor_origin():
    from patsy.origin import Origin
    desc = ModelDesc.from_formula("a + b")
    assert (desc.rhs_termlist[1].factors[0].origin
            == Origin("a + b", 0, 1))
    assert (desc.rhs_termlist[2].factors[0].origin
            == Origin("a + b", 4, 5))