Skip to content

Commit

Permalink
Use partial import form utils
Browse files Browse the repository at this point in the history
  • Loading branch information
radovanZRasa committed Jun 17, 2024
1 parent 5f94cab commit ed05393
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions rasa_sdk/grpc_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import asyncio

import grpc
import utils
import logging
import ssl
import types
Expand Down Expand Up @@ -36,11 +35,14 @@
get_tracer_and_context,
TracerProvider,
)
from rasa_sdk.utils import check_version_compatibility, number_of_sanic_workers

logger = logging.getLogger(__name__)


class ActionServerWebhook(action_webhook_pb2_grpc.ActionServerWebhookServicer):
class GRPCActionServerWebhook(action_webhook_pb2_grpc.ActionServerWebhookServicer):
"""Runs webhook RPC which is served through gRPC server."""

def __init__(
self,
executor: ActionExecutor,
Expand All @@ -65,13 +67,15 @@ async def webhook(
Args:
request: The webhook request.
context: The context of the request.
Returns:
gRPC response.
"""
await asyncio.sleep(50)
tracer, tracer_context, span_name = get_tracer_and_context(
self.tracer_provider, request
)
with tracer.start_as_current_span(span_name, context=tracer_context):
utils.check_version_compatibility(request.version)
check_version_compatibility(request.version)
try:
action_call = MessageToDict(request, preserving_proto_field_name=True)
result = await self.executor.run(action_call)
Expand Down Expand Up @@ -121,6 +125,7 @@ def initialise_interrupts(server: grpc.aio.Server) -> None:
"""Initialise handlers for kernel signal interrupts."""

async def handle_sigint(signal_received: int):
"""Handle the received signal."""
logger.info(
f"Received {get_signal_name(signal_received)} signal."
"Stopping gRPC server..."
Expand Down Expand Up @@ -155,15 +160,15 @@ async def run_grpc(
ssl_password: Password for the SSL key file.
endpoints: Path to the endpoints file.
"""
workers = utils.number_of_sanic_workers()
workers = number_of_sanic_workers()
server = aio.server(futures.ThreadPoolExecutor(max_workers=workers))
initialise_interrupts(server)
executor = ActionExecutor()
executor.register_package(action_package_name)
# tracer_provider = get_tracer_provider(endpoints)
tracer_provider = None
action_webhook_pb2_grpc.add_ActionServerWebhookServicer_to_server(
ActionServerWebhook(executor, tracer_provider), server
GRPCActionServerWebhook(executor, tracer_provider), server
)
if ssl_certificate and ssl_keyfile:
# Use SSL/TLS if certificate and key are provided
Expand Down

0 comments on commit ed05393

Please sign in to comment.