from __future__ import annotations import warnings from _pytest.recwarn import WarningsChecker from pytest import warns __all__ = ["pytest_warns"] class NoWarningsChecker: def __init__(self): self.cw = warnings.catch_warnings(record=True) self.rec = [] def __enter__(self): self.rec = self.cw.__enter__() def __exit__(self, type, value, traceback): if self.rec: warnings = [w.category.__name__ for w in self.rec] joined = "\\n".join(warnings) raise AssertionError( "Function is marked as not warning but the following " "warnings were found: \n" f"{joined}" ) def pytest_warns( warning: type[Warning] | tuple[type[Warning], ...] | None ) -> WarningsChecker | NoWarningsChecker: """ Parameters ---------- warning : {None, Warning, Tuple[Warning]} None if no warning is produced, or a single or multiple Warnings Returns ------- cm """ if warning is None: return NoWarningsChecker() else: assert warning is not None return warns(warning)