Skip to content

Commit

Permalink
Handle SIGINT and SIGTERM signals when gRPC server is running
Browse files Browse the repository at this point in the history
  • Loading branch information
radovanZRasa committed Jun 15, 2024
1 parent b612b4b commit a82a231
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 13 deletions.
15 changes: 2 additions & 13 deletions rasa_sdk/__main__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import sys

import logging
import asyncio
import signal
Expand All @@ -10,17 +12,6 @@
logger = logging.getLogger(__name__)


def initialise_interrupts() -> None:
"""Initialise handlers for kernel signal interrupts."""

def handle_sigint(signum, frame):
logger.info("Received SIGINT, exiting")
asyncio.get_event_loop().stop()

signal.signal(signal.SIGINT, handle_sigint)
signal.signal(signal.SIGTERM, handle_sigint)


def main_from_args(args):
"""Run with arguments."""
logging.getLogger("matplotlib").setLevel(logging.WARN)
Expand All @@ -34,8 +25,6 @@ def main_from_args(args):
)
utils.update_sanic_log_level()

initialise_interrupts()

if args.grpc:
asyncio.run(
run_grpc(
Expand Down
30 changes: 30 additions & 0 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
from __future__ import annotations

import sys

import signal

import asyncio

import grpc
import utils
import logging
Expand Down Expand Up @@ -109,6 +115,29 @@ def password_callback() -> bytes:
return password_callback


def get_signal_name(signal_number: int) -> str:
"""Return the signal name for the given signal number."""
return signal.Signals(signal_number).name


def initialise_interrupts(server: grpc.aio.Server) -> None:
"""Initialise handlers for kernel signal interrupts."""

async def handle_sigint(signal_received: int):
logger.info(
f"Received {get_signal_name(signal_received)} signal. Stopping gRPC server..."
)
await server.stop(0)
logger.info("gRPC server stopped.")
asyncio.get_event_loop().stop()

loop = asyncio.get_event_loop()
loop.add_signal_handler(signal.SIGINT,
lambda: asyncio.create_task(handle_sigint(signal.SIGINT)))
loop.add_signal_handler(signal.SIGTERM,
lambda: asyncio.create_task(handle_sigint(signal.SIGTERM)))


async def run_grpc(
action_package_name: Union[str, types.ModuleType],
port: int = DEFAULT_SERVER_PORT,
Expand All @@ -129,6 +158,7 @@ async def run_grpc(
"""
workers = utils.number_of_sanic_workers()
server = aio.server(futures.ThreadPoolExecutor(max_workers=workers))
initialise_interrupts(server)
executor = ActionExecutor()
executor.register_package(action_package_name)
# tracer_provider = get_tracer_provider(endpoints)
Expand Down

0 comments on commit a82a231

Please sign in to comment.