WIP: Feature: fork kernel #410

Open · wants to merge 7 commits into base: main

1 change: 1 addition & 0 deletions .gitignore
@@ -24,3 +24,4 @@ __pycache__

data_kernelspec
.pytest_cache
.idea
2 changes: 1 addition & 1 deletion ipykernel/heartbeat.py
@@ -32,7 +32,7 @@ class Heartbeat(Thread):
def __init__(self, context, addr=None):
if addr is None:
addr = ('tcp', localhost(), 0)
Thread.__init__(self)
Thread.__init__(self, name="Heartbeat")
self.context = context
self.transport, self.ip, self.port = addr
self.original_port = self.port
2 changes: 1 addition & 1 deletion ipykernel/iostream.py
@@ -69,7 +69,7 @@ def __init__(self, socket, pipe=False):
self._local = threading.local()
self._events = deque()
self._setup_event_pipe()
self.thread = threading.Thread(target=self._thread_main)
self.thread = threading.Thread(target=self._thread_main, name="IOPub")
self.thread.daemon = True

def _thread_main(self):
2 changes: 1 addition & 1 deletion ipykernel/ipkernel.py
@@ -164,7 +164,7 @@ def finish_metadata(self, parent, metadata, reply_content):
# This is required by ipyparallel < 5.0
metadata['status'] = reply_content['status']
if reply_content['status'] == 'error' and reply_content['ename'] == 'UnmetDependency':
metadata['dependencies_met'] = False
metadata['dependencies_met'] = False

return metadata

89 changes: 75 additions & 14 deletions ipykernel/kernelapp.py
@@ -32,7 +32,7 @@
from ipython_genutils.importstring import import_item
from jupyter_core.paths import jupyter_runtime_dir
from jupyter_client import write_connection_file
from jupyter_client.connect import ConnectionFileMixin
from jupyter_client.connect import ConnectionFileMixin, port_names

# local imports
from .iostream import IOPubThread
@@ -436,16 +436,22 @@ def init_kernel(self):

kernel_factory = self.kernel_class.instance

kernel = kernel_factory(parent=self, session=self.session,
control_stream=control_stream,
shell_streams=[shell_stream, control_stream],
iopub_thread=self.iopub_thread,
iopub_socket=self.iopub_socket,
stdin_socket=self.stdin_socket,
log=self.log,
profile_dir=self.profile_dir,
user_ns=self.user_ns,
params = dict(
parent=self,
session=self.session,
control_stream=control_stream,
shell_streams=[shell_stream, control_stream],
iopub_thread=self.iopub_thread,
iopub_socket=self.iopub_socket,
stdin_socket=self.stdin_socket,
log=self.log,
profile_dir=self.profile_dir,
user_ns=self.user_ns,
)
kernel = kernel_factory(**params)
for k, v in params.items():
setattr(kernel, k, v)

kernel.record_ports({
name + '_port': port for name, port in self.ports.items()
})
@@ -559,10 +565,64 @@ def start(self):
self.poller.start()
self.kernel.start()
self.io_loop = ioloop.IOLoop.current()
try:
self.io_loop.start()
except KeyboardInterrupt:
pass
keep_running = True
while keep_running:
try:
self.io_loop.start()
except KeyboardInterrupt:
pass
if not getattr(self.io_loop, '_fork_requested', False):
keep_running = False
else:
self.fork()

def fork(self):
# Create a temporary connection file that will be inherited by the child process.
connection_file, conn = write_connection_file()

parent_pid = os.getpid()
pid = os.fork()
self.io_loop._fork_requested = False # reset for parent AND child
if pid == 0:
import asyncio
            self.log.debug('Child kernel with pid %s', os.getpid())

# close all sockets and ioloops
self.close()

self.io_loop.close(all_fds=True)
self.io_loop.clear_current()
ioloop.IOLoop.clear_current()
asyncio.new_event_loop()

import tornado.platform.asyncio as tasio
# explicitly create a new io loop that will also be the current
self.io_loop = tasio.AsyncIOMainLoop(make_current=True)
assert self.io_loop == ioloop.IOLoop.current()

# Reset all ports so they will be reinitialized with the ports from the connection file
for name in port_names:
setattr(self, name, 0)
self.connection_file = connection_file

# Reset the ZMQ context for it to be recreated in initialize()
self.context = None

# Make ParentPoller work correctly (the new process is a child of the previous kernel)
self.parent_handle = parent_pid

            # Session refuses to send messages from forked processes (its `check_pid` flag), so update the pid and key for the child.
self.session.pid = os.getpid()
self.session.key = conn['key'].encode()

self.initialize(argv=['-f', self.abs_connection_file, '--debug'])
self.start()
else:
self.log.debug('Parent kernel will resume')
            # keep a reference, since _post_fork_callback is reset to None below
post_fork_callback = self.io_loop._post_fork_callback
self.io_loop.add_callback(lambda: post_fork_callback(pid, conn))
self.io_loop._post_fork_callback = None


launch_new_instance = IPKernelApp.launch_instance
@@ -577,3 +637,4 @@ def main():

if __name__ == '__main__':
main()

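Note (not part of this diff): the round trip works as follows — a client sends the new fork message on the shell channel, Kernel.fork() in kernelbase.py (below) stops the io loop, IPKernelApp.start() notices _fork_requested and calls IPKernelApp.fork(), and the parent then replies with the child's pid and connection info. The tests below call a kc.fork() client helper that is not shown in this diff; a minimal sketch of the same round trip using only standard jupyter_client calls, assuming the shell message type is 'fork' (as registered in kernelbase.Kernel) and a hypothetical connection file name:

# Minimal client-side sketch (assumptions: shell message type 'fork',
# reply type 'fork_reply'; the connection file name is hypothetical).
from jupyter_client import BlockingKernelClient

parent = BlockingKernelClient(connection_file='kernel-parent.json')
parent.load_connection_file()
parent.start_channels()
parent.wait_for_ready(timeout=60)

# Ask the running kernel to fork itself and wait for the reply.
msg = parent.session.msg('fork')
parent.shell_channel.send(msg)
reply = parent.get_shell_msg(timeout=60)
assert reply['header']['msg_type'] == 'fork_reply'

# Connect a second client to the child using the returned connection info.
child = BlockingKernelClient()
child.load_connection_info(reply['content']['conn'])
child.start_channels()
child.wait_for_ready(timeout=60)
child.execute('a = 1')  # runs in the child; the parent kernel is unaffected

child.stop_channels()
parent.stop_channels()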
22 changes: 21 additions & 1 deletion ipykernel/kernelbase.py
@@ -154,11 +154,30 @@ def _default_ident(self):
'connect_request', 'shutdown_request',
'is_complete_request',
# deprecated:
'apply_request',
'apply_request', 'fork'
]
# add deprecated ipyparallel control messages
control_msg_types = msg_types + ['clear_request', 'abort_request']

def fork(self, stream, ident, parent):
        # Forking while the (asyncio) io loop is running is not supported,
        # so we stop the loop and use it to pass the fork request up the
        # call stack to IPKernelApp, which performs the actual fork.
loop = ioloop.IOLoop.current()
loop._fork_requested = True

def post_fork_callback(pid, conn):
reply_content = json_clean({'status': 'ok', 'pid': pid, 'conn': conn})
metadata = {}
metadata = self.finish_metadata(parent, metadata, reply_content)

self.session.send(stream, u'fork_reply',
reply_content, parent, metadata=metadata,
ident=ident)

loop._post_fork_callback = post_fork_callback
loop.stop()

def __init__(self, **kwargs):
super(Kernel, self).__init__(**kwargs)
# Build dict of handlers for message types
@@ -514,6 +533,7 @@ def finish_metadata(self, parent, metadata, reply_content):
def execute_request(self, stream, ident, parent):
"""handle an execute_request"""


try:
content = parent[u'content']
code = py3compat.cast_unicode_py2(content[u'code'])
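For reference, the content assembled by post_fork_callback above has the following rough shape; the concrete pid, ports, and key are illustrative and would come from os.fork() and write_connection_file() in IPKernelApp.fork() at fork time:

# Illustrative fork_reply content (all values made up).
example_fork_reply_content = {
    'status': 'ok',
    'pid': 4242,               # pid of the forked child, as seen by the parent
    'conn': {                  # connection info the child re-initializes with
        'transport': 'tcp',
        'ip': '127.0.0.1',
        'shell_port': 53001,
        'iopub_port': 53002,
        'stdin_port': 53003,
        'control_port': 53004,
        'hb_port': 53005,
        'key': 'a0436f6c-1916-498b-8eb9-e81ab9368e84',
        'signature_scheme': 'hmac-sha256',
    },
}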
43 changes: 42 additions & 1 deletion ipykernel/tests/test_kernel.py
@@ -17,10 +17,11 @@
from IPython.paths import locate_profile
from ipython_genutils.tempdir import TemporaryDirectory

from ipykernel.tests.test_message_spec import validate_message
from .utils import (
new_kernel, kernel, TIMEOUT, assemble_output, execute,
flush_channels, wait_for_idle,
)
connect_to_kernel)


def _check_master(kc, expected=True, stream="stdout"):
@@ -326,3 +327,43 @@ def test_shutdown():
else:
break
assert not km.is_alive()

def test_fork_metadata():
with kernel() as kc:
km = kc.parent
fork_msg_id = kc.fork()
fork_reply = kc.get_shell_msg(block=True, timeout=TIMEOUT)
validate_message(fork_reply, "fork_reply", fork_msg_id)
        assert fork_reply['parent_header']['msg_id'] == fork_msg_id
assert fork_reply['content']['conn']['key'] != kc.session.key.decode()
fork_pid = fork_reply['content']['pid']
_check_status(fork_reply['content'])
wait_for_idle(kc)

assert fork_pid != km.kernel.pid
        # TODO: check whether `fork_pid` is actually running? Might need `psutil` for this in order to be cross-platform.

with connect_to_kernel(fork_reply['content']['conn'], TIMEOUT) as kc_fork:
assert fork_reply['content']['conn']['key'] == kc_fork.session.key.decode()
kc_fork.shutdown()

def test_fork():
    """Kernel forks after a fork request"""
    def execute_with_user_expression(kc, code, user_expression):
        _, reply = execute(code, kc=kc, user_expressions={"my-user-expression": user_expression})
        content = reply["user_expressions"]["my-user-expression"]["data"]["text/plain"]
        wait_for_idle(kc)
        return content

with kernel() as kc:
assert execute_with_user_expression(kc, u'a = 1', "a") == "1"
assert execute_with_user_expression(kc, u'b = 2', "b") == "2"
kc.fork()
fork_reply = kc.get_shell_msg(block=True, timeout=TIMEOUT)
wait_for_idle(kc)

with connect_to_kernel(fork_reply['content']['conn'], TIMEOUT) as kc_fork:
assert execute_with_user_expression(kc_fork, 'a = 11', "a, b") == str((11, 2))
assert execute_with_user_expression(kc_fork, 'b = 12', "a, b") == str((11, 12))
assert execute_with_user_expression(kc, 'z = 20', "a, b") == str((1, 2))
kc_fork.shutdown()
6 changes: 6 additions & 0 deletions ipykernel/tests/test_message_spec.py
@@ -195,6 +195,11 @@ class IsCompleteReplyIncomplete(Reference):
indent = Unicode()


class ForkReply(Reply):
pid = Integer()
conn = Dict()


# IOPub messages

class ExecuteInput(Reference):
@@ -240,6 +245,7 @@ class HistoryReply(Reply):
'stream' : Stream(),
'display_data' : DisplayData(),
'header' : RHeader(),
'fork_reply' : ForkReply(),
}
"""
Specifications of `content` part of the reply messages.
24 changes: 19 additions & 5 deletions ipykernel/tests/utils.py
@@ -33,11 +33,11 @@ def start_new_kernel(**kwargs):

Integrates with our output capturing for tests.
"""
kwargs['stderr'] = STDOUT
try:
stdout = nose.iptest_stdstreams_fileno()
kwargs['stdout'] = nose.iptest_stdstreams_fileno()
except AttributeError:
stdout = open(os.devnull)
kwargs.update(dict(stdout=stdout, stderr=STDOUT))
pass
return manager.start_new_kernel(startup_timeout=STARTUP_TIMEOUT, **kwargs)


@@ -131,8 +131,11 @@ def new_kernel(argv=None):
-------
kernel_client: connected KernelClient instance
"""
stdout = getattr(nose, 'iptest_stdstreams_fileno', open(os.devnull))
kwargs = dict(stdout=stdout, stderr=STDOUT)
kwargs = {'stderr': STDOUT}
try:
kwargs['stdout'] = nose.iptest_stdstreams_fileno()
except AttributeError:
pass
if argv is not None:
kwargs['extra_arguments'] = argv
return manager.run_kernel(**kwargs)
@@ -167,3 +170,14 @@ def wait_for_idle(kc):
content = msg['content']
if msg_type == 'status' and content['execution_state'] == 'idle':
break

@contextmanager
def connect_to_kernel(connection_info, timeout):
    from jupyter_client import BlockingKernelClient
    kc = BlockingKernelClient()
    kc.log.setLevel('DEBUG')
    kc.load_connection_info(connection_info)
    kc.start_channels()
    kc.wait_for_ready(timeout)
    try:
        yield kc
    finally:
        # always tear the channels down, even if the test body raises
        kc.stop_channels()
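The new connect_to_kernel helper is not specific to forked kernels; an illustrative use (not part of this diff) against a kernel started with a plain jupyter_client KernelManager:

# Illustrative use of connect_to_kernel with a manually started kernel.
from jupyter_client import KernelManager
from ipykernel.tests.utils import connect_to_kernel, TIMEOUT

km = KernelManager()
km.start_kernel()
try:
    with connect_to_kernel(km.get_connection_info(), TIMEOUT) as kc:
        kc.execute('a = 1')  # fire-and-forget execute over the shell channel
finally:
    km.shutdown_kernel()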