329 lines
11 KiB
Python
329 lines
11 KiB
Python
|
"""Base classes to manage a Client's interaction with a running kernel"""
|
||
|
# Copyright (c) Jupyter Development Team.
|
||
|
# Distributed under the terms of the Modified BSD License.
|
||
|
import asyncio
|
||
|
import atexit
|
||
|
import time
|
||
|
import typing as t
|
||
|
from queue import Empty
|
||
|
from threading import Event, Thread
|
||
|
|
||
|
import zmq.asyncio
|
||
|
from jupyter_core.utils import ensure_async
|
||
|
|
||
|
from ._version import protocol_version_info
|
||
|
from .channelsabc import HBChannelABC
|
||
|
from .session import Session
|
||
|
|
||
|
# import ZMQError in top-level namespace, to avoid ugly attribute-error messages
|
||
|
# during garbage collection of threads at exit
|
||
|
|
||
|
# -----------------------------------------------------------------------------
|
||
|
# Constants and exceptions
|
||
|
# -----------------------------------------------------------------------------
|
||
|
|
||
|
major_protocol_version = protocol_version_info[0]
|
||
|
|
||
|
|
||
|
class InvalidPortNumber(Exception): # noqa
|
||
|
"""An exception raised for an invalid port number."""
|
||
|
|
||
|
pass
|
||
|
|
||
|
|
||
|
class HBChannel(Thread):
|
||
|
"""The heartbeat channel which monitors the kernel heartbeat.
|
||
|
|
||
|
Note that the heartbeat channel is paused by default. As long as you start
|
||
|
this channel, the kernel manager will ensure that it is paused and un-paused
|
||
|
as appropriate.
|
||
|
"""
|
||
|
|
||
|
session = None
|
||
|
socket = None
|
||
|
address = None
|
||
|
_exiting = False
|
||
|
|
||
|
time_to_dead: float = 1.0
|
||
|
_running = None
|
||
|
_pause = None
|
||
|
_beating = None
|
||
|
|
||
|
def __init__(
|
||
|
self,
|
||
|
context: t.Optional[zmq.Context] = None,
|
||
|
session: t.Optional[Session] = None,
|
||
|
address: t.Union[t.Tuple[str, int], str] = "",
|
||
|
) -> None:
|
||
|
"""Create the heartbeat monitor thread.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
context : :class:`zmq.Context`
|
||
|
The ZMQ context to use.
|
||
|
session : :class:`session.Session`
|
||
|
The session to use.
|
||
|
address : zmq url
|
||
|
Standard (ip, port) tuple that the kernel is listening on.
|
||
|
"""
|
||
|
super().__init__()
|
||
|
self.daemon = True
|
||
|
|
||
|
self.context = context
|
||
|
self.session = session
|
||
|
if isinstance(address, tuple):
|
||
|
if address[1] == 0:
|
||
|
message = "The port number for a channel cannot be 0."
|
||
|
raise InvalidPortNumber(message)
|
||
|
address_str = "tcp://%s:%i" % address
|
||
|
else:
|
||
|
address_str = address
|
||
|
self.address = address_str
|
||
|
|
||
|
# running is False until `.start()` is called
|
||
|
self._running = False
|
||
|
self._exit = Event()
|
||
|
# don't start paused
|
||
|
self._pause = False
|
||
|
self.poller = zmq.Poller()
|
||
|
|
||
|
@staticmethod
|
||
|
@atexit.register
|
||
|
def _notice_exit() -> None:
|
||
|
# Class definitions can be torn down during interpreter shutdown.
|
||
|
# We only need to set _exiting flag if this hasn't happened.
|
||
|
if HBChannel is not None:
|
||
|
HBChannel._exiting = True
|
||
|
|
||
|
def _create_socket(self) -> None:
|
||
|
if self.socket is not None:
|
||
|
# close previous socket, before opening a new one
|
||
|
self.poller.unregister(self.socket) # type:ignore[unreachable]
|
||
|
self.socket.close()
|
||
|
assert self.context is not None
|
||
|
self.socket = self.context.socket(zmq.REQ)
|
||
|
self.socket.linger = 1000
|
||
|
assert self.address is not None
|
||
|
self.socket.connect(self.address)
|
||
|
|
||
|
self.poller.register(self.socket, zmq.POLLIN)
|
||
|
|
||
|
async def _async_run(self) -> None:
|
||
|
"""The thread's main activity. Call start() instead."""
|
||
|
self._create_socket()
|
||
|
self._running = True
|
||
|
self._beating = True
|
||
|
assert self.socket is not None
|
||
|
|
||
|
while self._running:
|
||
|
if self._pause:
|
||
|
# just sleep, and skip the rest of the loop
|
||
|
self._exit.wait(self.time_to_dead)
|
||
|
continue
|
||
|
|
||
|
since_last_heartbeat = 0.0
|
||
|
# no need to catch EFSM here, because the previous event was
|
||
|
# either a recv or connect, which cannot be followed by EFSM)
|
||
|
await ensure_async(self.socket.send(b"ping"))
|
||
|
request_time = time.time()
|
||
|
# Wait until timeout
|
||
|
self._exit.wait(self.time_to_dead)
|
||
|
# poll(0) means return immediately (see http://api.zeromq.org/2-1:zmq-poll)
|
||
|
self._beating = bool(self.poller.poll(0))
|
||
|
if self._beating:
|
||
|
# the poll above guarantees we have something to recv
|
||
|
await ensure_async(self.socket.recv())
|
||
|
continue
|
||
|
elif self._running:
|
||
|
# nothing was received within the time limit, signal heart failure
|
||
|
since_last_heartbeat = time.time() - request_time
|
||
|
self.call_handlers(since_last_heartbeat)
|
||
|
# and close/reopen the socket, because the REQ/REP cycle has been broken
|
||
|
self._create_socket()
|
||
|
continue
|
||
|
|
||
|
def run(self) -> None:
|
||
|
"""Run the heartbeat thread."""
|
||
|
loop = asyncio.new_event_loop()
|
||
|
asyncio.set_event_loop(loop)
|
||
|
loop.run_until_complete(self._async_run())
|
||
|
loop.close()
|
||
|
|
||
|
def pause(self) -> None:
|
||
|
"""Pause the heartbeat."""
|
||
|
self._pause = True
|
||
|
|
||
|
def unpause(self) -> None:
|
||
|
"""Unpause the heartbeat."""
|
||
|
self._pause = False
|
||
|
|
||
|
def is_beating(self) -> bool:
|
||
|
"""Is the heartbeat running and responsive (and not paused)."""
|
||
|
if self.is_alive() and not self._pause and self._beating: # noqa
|
||
|
return True
|
||
|
else:
|
||
|
return False
|
||
|
|
||
|
def stop(self) -> None:
|
||
|
"""Stop the channel's event loop and join its thread."""
|
||
|
self._running = False
|
||
|
self._exit.set()
|
||
|
self.join()
|
||
|
self.close()
|
||
|
|
||
|
def close(self) -> None:
|
||
|
"""Close the heartbeat thread."""
|
||
|
if self.socket is not None:
|
||
|
try:
|
||
|
self.socket.close(linger=0)
|
||
|
except Exception:
|
||
|
pass
|
||
|
self.socket = None
|
||
|
|
||
|
def call_handlers(self, since_last_heartbeat: float) -> None:
|
||
|
"""This method is called in the ioloop thread when a message arrives.
|
||
|
|
||
|
Subclasses should override this method to handle incoming messages.
|
||
|
It is important to remember that this method is called in the thread
|
||
|
so that some logic must be done to ensure that the application level
|
||
|
handlers are called in the application thread.
|
||
|
"""
|
||
|
pass
|
||
|
|
||
|
|
||
|
HBChannelABC.register(HBChannel)
|
||
|
|
||
|
|
||
|
class ZMQSocketChannel:
|
||
|
"""A ZMQ socket wrapper"""
|
||
|
|
||
|
def __init__(self, socket: zmq.Socket, session: Session, loop: t.Any = None) -> None:
|
||
|
"""Create a channel.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
socket : :class:`zmq.Socket`
|
||
|
The ZMQ socket to use.
|
||
|
session : :class:`session.Session`
|
||
|
The session to use.
|
||
|
loop
|
||
|
Unused here, for other implementations
|
||
|
"""
|
||
|
super().__init__()
|
||
|
|
||
|
self.socket: t.Optional[zmq.Socket] = socket
|
||
|
self.session = session
|
||
|
|
||
|
def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]:
|
||
|
assert self.socket is not None
|
||
|
msg = self.socket.recv_multipart(**kwargs)
|
||
|
ident, smsg = self.session.feed_identities(msg)
|
||
|
return self.session.deserialize(smsg)
|
||
|
|
||
|
def get_msg(self, timeout: t.Optional[float] = None) -> t.Dict[str, t.Any]:
|
||
|
"""Gets a message if there is one that is ready."""
|
||
|
assert self.socket is not None
|
||
|
timeout_ms = None if timeout is None else int(timeout * 1000) # seconds to ms
|
||
|
ready = self.socket.poll(timeout_ms)
|
||
|
if ready:
|
||
|
res = self._recv()
|
||
|
return res
|
||
|
else:
|
||
|
raise Empty
|
||
|
|
||
|
def get_msgs(self) -> t.List[t.Dict[str, t.Any]]:
|
||
|
"""Get all messages that are currently ready."""
|
||
|
msgs = []
|
||
|
while True:
|
||
|
try:
|
||
|
msgs.append(self.get_msg())
|
||
|
except Empty:
|
||
|
break
|
||
|
return msgs
|
||
|
|
||
|
def msg_ready(self) -> bool:
|
||
|
"""Is there a message that has been received?"""
|
||
|
assert self.socket is not None
|
||
|
return bool(self.socket.poll(timeout=0))
|
||
|
|
||
|
def close(self) -> None:
|
||
|
"""Close the socket channel."""
|
||
|
if self.socket is not None:
|
||
|
try:
|
||
|
self.socket.close(linger=0)
|
||
|
except Exception:
|
||
|
pass
|
||
|
self.socket = None
|
||
|
|
||
|
stop = close
|
||
|
|
||
|
def is_alive(self) -> bool:
|
||
|
"""Test whether the channel is alive."""
|
||
|
return self.socket is not None
|
||
|
|
||
|
def send(self, msg: t.Dict[str, t.Any]) -> None:
|
||
|
"""Pass a message to the ZMQ socket to send"""
|
||
|
assert self.socket is not None
|
||
|
self.session.send(self.socket, msg)
|
||
|
|
||
|
def start(self) -> None:
|
||
|
"""Start the socket channel."""
|
||
|
pass
|
||
|
|
||
|
|
||
|
class AsyncZMQSocketChannel(ZMQSocketChannel):
|
||
|
"""A ZMQ socket in an async API"""
|
||
|
|
||
|
socket: zmq.asyncio.Socket
|
||
|
|
||
|
def __init__(self, socket: zmq.asyncio.Socket, session: Session, loop: t.Any = None) -> None:
|
||
|
"""Create a channel.
|
||
|
|
||
|
Parameters
|
||
|
----------
|
||
|
socket : :class:`zmq.asyncio.Socket`
|
||
|
The ZMQ socket to use.
|
||
|
session : :class:`session.Session`
|
||
|
The session to use.
|
||
|
loop
|
||
|
Unused here, for other implementations
|
||
|
"""
|
||
|
if not isinstance(socket, zmq.asyncio.Socket):
|
||
|
msg = "Socket must be asyncio" # type:ignore[unreachable]
|
||
|
raise ValueError(msg)
|
||
|
super().__init__(socket, session)
|
||
|
|
||
|
async def _recv(self, **kwargs: t.Any) -> t.Dict[str, t.Any]: # type:ignore[override]
|
||
|
assert self.socket is not None
|
||
|
msg = await self.socket.recv_multipart(**kwargs)
|
||
|
_, smsg = self.session.feed_identities(msg)
|
||
|
return self.session.deserialize(smsg)
|
||
|
|
||
|
async def get_msg( # type:ignore[override]
|
||
|
self, timeout: t.Optional[float] = None
|
||
|
) -> t.Dict[str, t.Any]:
|
||
|
"""Gets a message if there is one that is ready."""
|
||
|
assert self.socket is not None
|
||
|
timeout_ms = None if timeout is None else int(timeout * 1000) # seconds to ms
|
||
|
ready = await self.socket.poll(timeout_ms)
|
||
|
if ready:
|
||
|
res = await self._recv()
|
||
|
return res
|
||
|
else:
|
||
|
raise Empty
|
||
|
|
||
|
async def get_msgs(self) -> t.List[t.Dict[str, t.Any]]: # type:ignore[override]
|
||
|
"""Get all messages that are currently ready."""
|
||
|
msgs = []
|
||
|
while True:
|
||
|
try:
|
||
|
msgs.append(await self.get_msg())
|
||
|
except Empty:
|
||
|
break
|
||
|
return msgs
|
||
|
|
||
|
async def msg_ready(self) -> bool: # type:ignore[override]
|
||
|
"""Is there a message that has been received?"""
|
||
|
assert self.socket is not None
|
||
|
return bool(await self.socket.poll(timeout=0))
|