"""Patch asyncio to allow nested event loops.""" import asyncio import asyncio.events as events import os import sys import threading from contextlib import contextmanager, suppress from heapq import heappop def apply(loop=None): """Patch asyncio to make its event loop reentrant.""" _patch_asyncio() _patch_policy() _patch_tornado() loop = loop or asyncio.get_event_loop() _patch_loop(loop) def _patch_asyncio(): """Patch asyncio module to use pure Python tasks and futures.""" def run(main, *, debug=False): loop = asyncio.get_event_loop() loop.set_debug(debug) task = asyncio.ensure_future(main) try: return loop.run_until_complete(task) finally: if not task.done(): task.cancel() with suppress(asyncio.CancelledError): loop.run_until_complete(task) def _get_event_loop(stacklevel=3): loop = events._get_running_loop() if loop is None: loop = events.get_event_loop_policy().get_event_loop() return loop # Use module level _current_tasks, all_tasks and patch run method. if hasattr(asyncio, '_nest_patched'): return if sys.version_info >= (3, 6, 0): asyncio.Task = asyncio.tasks._CTask = asyncio.tasks.Task = \ asyncio.tasks._PyTask asyncio.Future = asyncio.futures._CFuture = asyncio.futures.Future = \ asyncio.futures._PyFuture if sys.version_info < (3, 7, 0): asyncio.tasks._current_tasks = asyncio.tasks.Task._current_tasks asyncio.all_tasks = asyncio.tasks.Task.all_tasks if sys.version_info >= (3, 9, 0): events._get_event_loop = events.get_event_loop = \ asyncio.get_event_loop = _get_event_loop asyncio.run = run asyncio._nest_patched = True def _patch_policy(): """Patch the policy to always return a patched loop.""" def get_event_loop(self): if self._local._loop is None: loop = self.new_event_loop() _patch_loop(loop) self.set_event_loop(loop) return self._local._loop policy = events.get_event_loop_policy() policy.__class__.get_event_loop = get_event_loop def _patch_loop(loop): """Patch loop to make it reentrant.""" def run_forever(self): with manage_run(self), manage_asyncgens(self): while True: self._run_once() if self._stopping: break self._stopping = False def run_until_complete(self, future): with manage_run(self): f = asyncio.ensure_future(future, loop=self) if f is not future: f._log_destroy_pending = False while not f.done(): self._run_once() if self._stopping: break if not f.done(): raise RuntimeError( 'Event loop stopped before Future completed.') return f.result() def _run_once(self): """ Simplified re-implementation of asyncio's _run_once that runs handles as they become ready. """ ready = self._ready scheduled = self._scheduled while scheduled and scheduled[0]._cancelled: heappop(scheduled) timeout = ( 0 if ready or self._stopping else min(max( scheduled[0]._when - self.time(), 0), 86400) if scheduled else None) event_list = self._selector.select(timeout) self._process_events(event_list) end_time = self.time() + self._clock_resolution while scheduled and scheduled[0]._when < end_time: handle = heappop(scheduled) ready.append(handle) for _ in range(len(ready)): if not ready: break handle = ready.popleft() if not handle._cancelled: # preempt the current task so that that checks in # Task.__step do not raise curr_task = curr_tasks.pop(self, None) try: handle._run() finally: # restore the current task if curr_task is not None: curr_tasks[self] = curr_task handle = None @contextmanager def manage_run(self): """Set up the loop for running.""" self._check_closed() old_thread_id = self._thread_id old_running_loop = events._get_running_loop() try: self._thread_id = threading.get_ident() events._set_running_loop(self) self._num_runs_pending += 1 if self._is_proactorloop: if self._self_reading_future is None: self.call_soon(self._loop_self_reading) yield finally: self._thread_id = old_thread_id events._set_running_loop(old_running_loop) self._num_runs_pending -= 1 if self._is_proactorloop: if (self._num_runs_pending == 0 and self._self_reading_future is not None): ov = self._self_reading_future._ov self._self_reading_future.cancel() if ov is not None: self._proactor._unregister(ov) self._self_reading_future = None @contextmanager def manage_asyncgens(self): if not hasattr(sys, 'get_asyncgen_hooks'): # Python version is too old. return old_agen_hooks = sys.get_asyncgen_hooks() try: self._set_coroutine_origin_tracking(self._debug) if self._asyncgens is not None: sys.set_asyncgen_hooks( firstiter=self._asyncgen_firstiter_hook, finalizer=self._asyncgen_finalizer_hook) yield finally: self._set_coroutine_origin_tracking(False) if self._asyncgens is not None: sys.set_asyncgen_hooks(*old_agen_hooks) def _check_running(self): """Do not throw exception if loop is already running.""" pass if hasattr(loop, '_nest_patched'): return if not isinstance(loop, asyncio.BaseEventLoop): raise ValueError('Can\'t patch loop of type %s' % type(loop)) cls = loop.__class__ cls.run_forever = run_forever cls.run_until_complete = run_until_complete cls._run_once = _run_once cls._check_running = _check_running cls._check_runnung = _check_running # typo in Python 3.7 source cls._num_runs_pending = 1 if loop.is_running() else 0 cls._is_proactorloop = ( os.name == 'nt' and issubclass(cls, asyncio.ProactorEventLoop)) if sys.version_info < (3, 7, 0): cls._set_coroutine_origin_tracking = cls._set_coroutine_wrapper curr_tasks = asyncio.tasks._current_tasks \ if sys.version_info >= (3, 7, 0) else asyncio.Task._current_tasks cls._nest_patched = True def _patch_tornado(): """ If tornado is imported before nest_asyncio, make tornado aware of the pure-Python asyncio Future. """ if 'tornado' in sys.modules: import tornado.concurrent as tc # type: ignore tc.Future = asyncio.Future if asyncio.Future not in tc.FUTURES: tc.FUTURES += (asyncio.Future,)