Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
Tawakalt committed Jul 28, 2023
1 parent ea745a9 commit 933637c
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 35 deletions.
3 changes: 2 additions & 1 deletion rasa_sdk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from rasa_sdk import utils
from rasa_sdk.endpoint import create_argument_parser, run
from rasa_sdk.constants import APPLICATION_ROOT_LOGGER_NAME
from rasa_sdk.tracing.utils import get_tracer_provider


def main_from_args(args):
Expand All @@ -17,7 +18,7 @@ def main_from_args(args):
args.logging_config_file,
)
utils.update_sanic_log_level()
tracer_provider = utils.get_tracer_provider(args)
tracer_provider = get_tracer_provider(args)

run(
args.actions,
Expand Down
22 changes: 3 additions & 19 deletions rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@
import types
import zlib
import json
from opentelemetry import trace
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator
from typing import List, Text, Union, Optional
from ssl import SSLContext

Expand All @@ -21,6 +19,7 @@
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException
from rasa_sdk.plugin import plugin_manager
from rasa_sdk.tracing.utils import get_tracer_and_context, set_span_attributes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -97,13 +96,7 @@ async def health(_) -> HTTPResponse:
@app.post("/webhook")
async def webhook(request: Request) -> HTTPResponse:
"""Webhook to retrieve action calls."""
span_name = "rasa_sdk.create_app.webhook"
if tracer_provider is None:
tracer = trace.get_tracer(span_name)
context = None
else:
tracer = tracer_provider.get_tracer(span_name)
context = TraceContextTextMapPropagator().extract(request.headers)
tracer, context, span_name = get_tracer_and_context(tracer_provider, request)

with tracer.start_as_current_span(span_name, context=context) as span:
if request.headers.get("Content-Encoding") == "deflate":
Expand Down Expand Up @@ -132,16 +125,7 @@ async def webhook(request: Request) -> HTTPResponse:
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=404)

if span.is_recording():
span.set_attribute("http.method", "POST")
span.set_attribute("http.route", "/webhook")
span.set_attribute("next_action", action_call.get("next_action"))
span.set_attribute("version", action_call.get("version"))
span.set_attribute("sender_id", action_call.get("tracker")["sender_id"])
span.set_attribute(
"message_id",
action_call.get("tracker")["latest_message"]["message_id"],
)
set_span_attributes(span, action_call)

return response.json(result, status=200)

Expand Down
56 changes: 56 additions & 0 deletions rasa_sdk/tracing/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import argparse
from rasa_sdk.tracing import config
from opentelemetry import trace
from opentelemetry.trace.propagation.tracecontext import TraceContextTextMapPropagator

from opentelemetry.sdk.trace import TracerProvider
from sanic.request import Request

from typing import Optional, Tuple, Any, Text


def get_tracer_provider(
cmdline_arguments: argparse.Namespace,
) -> Optional[TracerProvider]:
"""Gets the tracer provider from the command line arguments."""
tracer_provider = None
if "endpoints" in cmdline_arguments:
endpoints_file = cmdline_arguments.endpoints

if endpoints_file is not None:
tracer_provider = config.get_tracer_provider(endpoints_file)
return tracer_provider


def get_tracer_and_context(
tracer_provider: Optional[TracerProvider], request: Request
) -> Tuple[Any, Any, Text]:
"""Gets tracer and context"""
span_name = "rasa_sdk.create_app.webhook"
if tracer_provider is None:
tracer = trace.get_tracer(span_name)
context = None
else:
tracer = tracer_provider.get_tracer(span_name)
context = TraceContextTextMapPropagator().extract(request.headers)
return (tracer, context, span_name)


def set_span_attributes(span: Any, action_call: dict) -> None:
"""Sets span attributes"""
set_span_attributes = {
"http.method": "POST",
"http.route": "/webhook",
"next_action": action_call.get("next_action"),
"version": action_call.get("version"),
"sender_id": action_call.get("tracker")["sender_id"], # type: ignore
"message_id": action_call.get("tracker")["latest_message"][
"message_id"
], # type: ignore
}

if span.is_recording():
for key, value in set_span_attributes.items():
span.set_attribute(key, value)

return None
15 changes: 0 additions & 15 deletions rasa_sdk/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import argparse
import asyncio
import inspect
import logging
Expand All @@ -13,8 +12,6 @@
from typing import AbstractSet, Any, Dict, List, Text, Optional, Coroutine, Union

import rasa_sdk
from rasa_sdk.tracing import config
from opentelemetry.sdk.trace import TracerProvider

from rasa_sdk.constants import (
DEFAULT_ENCODING,
Expand Down Expand Up @@ -367,15 +364,3 @@ def read_yaml_file(filename: Union[Text, Path]) -> Dict[Text, Any]:
return read_yaml(read_file(filename, DEFAULT_ENCODING))
except (YAMLError, DuplicateKeyError) as e:
raise YamlSyntaxException(filename, e)


def get_tracer_provider(
cmdline_arguments: argparse.Namespace,
) -> Optional[TracerProvider]:
tracer_provider = None
if "endpoints" in cmdline_arguments:
endpoints_file = cmdline_arguments.endpoints

if endpoints_file is not None:
tracer_provider = config.get_tracer_provider(endpoints_file)
return tracer_provider

0 comments on commit 933637c

Please sign in to comment.