Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement tracing in action server #1016

Merged
merged 8 commits into from
Aug 2, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions changelog/1016.feature.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
Added tracing functionality to the Rasa SDK, bringing enhanced monitoring, execution profiling and debugging capabilities to the Rasa Actions Server.
See the [Rasa Documentation on Tracing](https://rasa.com/docs/rasa/monitoring/tracing/#configuring-a-tracing-backend-or-collector) to learn more about configuring a tracing backend or collector.
462 changes: 453 additions & 9 deletions poetry.lock

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,10 @@ prompt-toolkit = "^3.0,<3.0.29"
"ruamel.yaml" = ">=0.16.5,<0.18.0"
websockets = ">=10.0,<12.0"
pluggy = "^1.0.0"
opentelemetry-api = "~1.15.0"
opentelemetry-sdk = "~1.15.0"
opentelemetry-exporter-jaeger = "~1.15.0"
opentelemetry-exporter-otlp = "~1.15.0"
Tawakalt marked this conversation as resolved.
Show resolved Hide resolved

[tool.poetry.dev-dependencies]
pytest-cov = "^4.1.0"
Expand Down
3 changes: 3 additions & 0 deletions rasa_sdk/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from rasa_sdk import utils
from rasa_sdk.endpoint import create_argument_parser, run
from rasa_sdk.constants import APPLICATION_ROOT_LOGGER_NAME
from rasa_sdk.tracing.utils import get_tracer_provider


def main_from_args(args):
Expand All @@ -17,6 +18,7 @@ def main_from_args(args):
args.logging_config_file,
)
utils.update_sanic_log_level()
tracer_provider = get_tracer_provider(args)

run(
args.actions,
Expand All @@ -26,6 +28,7 @@ def main_from_args(args):
args.ssl_keyfile,
args.ssl_password,
args.auto_reload,
tracer_provider,
)


Expand Down
70 changes: 41 additions & 29 deletions rasa_sdk/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import types
import zlib
import json
from opentelemetry.sdk.trace import TracerProvider
from typing import List, Text, Union, Optional
from ssl import SSLContext

Expand All @@ -18,6 +19,7 @@
from rasa_sdk.executor import ActionExecutor
from rasa_sdk.interfaces import ActionExecutionRejection, ActionNotFoundException
from rasa_sdk.plugin import plugin_manager
from rasa_sdk.tracing.utils import get_tracer_and_context, set_span_attributes

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -65,6 +67,7 @@ def create_app(
action_package_name: Union[Text, types.ModuleType],
cors_origins: Union[Text, List[Text], None] = "*",
auto_reload: bool = False,
tracer_provider: Optional[TracerProvider] = None,
) -> Sanic:
"""Create a Sanic application and return it.

Expand All @@ -73,6 +76,7 @@ def create_app(
from.
cors_origins: CORS origins to allow.
auto_reload: When `True`, auto-reloading of actions is enabled.
tracer_provider: Tracer provider to use for tracing.

Returns:
A new Sanic application ready to be run.
Expand All @@ -93,34 +97,38 @@ async def health(_) -> HTTPResponse:
@app.post("/webhook")
async def webhook(request: Request) -> HTTPResponse:
"""Webhook to retrieve action calls."""
if request.headers.get("Content-Encoding") == "deflate":
# Decompress the request data using zlib
decompressed_data = zlib.decompress(request.body)
# Load the JSON data from the decompressed request data
action_call = json.loads(decompressed_data)
else:
action_call = request.json
if action_call is None:
body = {"error": "Invalid body request"}
return response.json(body, status=400)

utils.check_version_compatibility(action_call.get("version"))

if auto_reload:
executor.reload()

try:
result = await executor.run(action_call)
except ActionExecutionRejection as e:
logger.debug(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=400)
except ActionNotFoundException as e:
logger.error(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=404)

return response.json(result, status=200)
tracer, context, span_name = get_tracer_and_context(tracer_provider, request)

with tracer.start_as_current_span(span_name, context=context) as span:
if request.headers.get("Content-Encoding") == "deflate":
# Decompress the request data using zlib
decompressed_data = zlib.decompress(request.body)
# Load the JSON data from the decompressed request data
action_call = json.loads(decompressed_data)
else:
action_call = request.json
if action_call is None:
body = {"error": "Invalid body request"}
return response.json(body, status=400)

utils.check_version_compatibility(action_call.get("version"))

if auto_reload:
executor.reload()
try:
result = await executor.run(action_call)
except ActionExecutionRejection as e:
logger.debug(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=400)
except ActionNotFoundException as e:
logger.error(e)
body = {"error": e.message, "action_name": e.action_name}
return response.json(body, status=404)

set_span_attributes(span, action_call)

return response.json(result, status=200)

@app.get("/actions")
async def actions(_) -> HTTPResponse:
Expand Down Expand Up @@ -151,11 +159,15 @@ def run(
ssl_keyfile: Optional[Text] = None,
ssl_password: Optional[Text] = None,
auto_reload: bool = False,
tracer_provider: Optional[TracerProvider] = None,
) -> None:
"""Starts the action endpoint server with given config values."""
logger.info("Starting action endpoint server...")
app = create_app(
action_package_name, cors_origins=cors_origins, auto_reload=auto_reload
action_package_name,
cors_origins=cors_origins,
auto_reload=auto_reload,
tracer_provider=tracer_provider,
)
## Attach additional sanic extensions: listeners, middleware and routing
logger.info("Starting plugins...")
Expand Down
166 changes: 166 additions & 0 deletions rasa_sdk/tracing/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
from __future__ import annotations

import abc
import logging
import os
from typing import Any, Dict, Optional, Text

import grpc
from opentelemetry.exporter.jaeger.thrift import JaegerExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from rasa_sdk.tracing.endpoints import EndpointConfig, read_endpoint_config


# Fallback service name attached to exported traces; read once at import time
# from the TRACING_SERVICE_NAME environment variable, defaulting to "rasa_sdk".
TRACING_SERVICE_NAME = os.environ.get("TRACING_SERVICE_NAME", "rasa_sdk")

# Top-level key of the endpoints file under which the tracing section is looked up.
ENDPOINTS_TRACING_KEY = "tracing"

# Module-level logger named after this module.
logger = logging.getLogger(__name__)


def get_tracer_provider(endpoints_file: Text) -> Optional[TracerProvider]:
    """Configure tracing backend.

    When a known tracing backend is defined in the endpoints file, this
    function will configure the tracing infrastructure. When no or an unknown
    tracing backend is defined, this function does nothing.

    :param endpoints_file: The configuration file containing information about the
        tracing backend.
    :return: The `TracerProvider` to be used for all subsequent tracing, or `None`
        when no tracing section is found or its type is not supported.
    """
    cfg = read_endpoint_config(endpoints_file, ENDPOINTS_TRACING_KEY)

    if not cfg:
        # Note the trailing space: the two literals are concatenated verbatim.
        logger.info(
            f"No endpoint for tracing type available in {endpoints_file}, "
            f"tracing will not be configured."
        )
        return None

    if cfg.type == "jaeger":
        return JaegerTracerConfigurer.configure_from_endpoint_config(cfg)
    if cfg.type == "otlp":
        return OTLPCollectorConfigurer.configure_from_endpoint_config(cfg)

    # Unsupported backends are reported but never abort server start-up.
    logger.warning(
        f"Unknown tracing type {cfg.type} read from {endpoints_file}, ignoring."
    )
    return None


class TracerConfigurer(abc.ABC):
    """Common interface for backend-specific tracing setup.

    Every supported tracing backend is expected to provide its own concrete
    subclass of `TracerConfigurer`.
    """

    @classmethod
    @abc.abstractmethod
    def configure_from_endpoint_config(cls, cfg: EndpointConfig) -> TracerProvider:
        """Build a `TracerProvider` from an endpoint configuration.

        Concrete subclasses must override this method. An implementation is
        expected to read its settings from `cfg`, set up whatever is needed to
        export traces to its backend, and hand back the resulting provider.

        :param cfg: The endpoint configuration describing the tracing backend.
        :return: The `TracerProvider` to use for tracing.
        """


class JaegerTracerConfigurer(TracerConfigurer):
    """Tracing setup for a Jaeger backend."""

    @classmethod
    def configure_from_endpoint_config(cls, cfg: EndpointConfig) -> TracerProvider:
        """Create a `TracerProvider` that exports spans to Jaeger.

        The service name is taken from the `service_name` option of `cfg`,
        falling back to `TRACING_SERVICE_NAME`. The exporter options are
        derived from `cfg` by `_extract_config`, and spans are handed to the
        exporter through a `BatchSpanProcessor`.

        :param cfg: The endpoint configuration holding the Jaeger options.
        :return: The configured `TracerProvider`.
        """
        service_name = cfg.kwargs.get("service_name", TRACING_SERVICE_NAME)
        resource = Resource.create({SERVICE_NAME: service_name})
        provider = TracerProvider(resource=resource)

        exporter_options = cls._extract_config(cfg)
        jaeger_exporter = JaegerExporter(
            udp_split_oversized_batches=True, **exporter_options
        )
        logger.info(
            f"Registered {cfg.type} endpoint for tracing. Traces will be exported to"
            f" {jaeger_exporter.agent_host_name}:{jaeger_exporter.agent_port}"
        )

        span_processor = BatchSpanProcessor(jaeger_exporter)
        provider.add_span_processor(span_processor)
        return provider

    @classmethod
    def _extract_config(cls, cfg: EndpointConfig) -> Dict[str, Any]:
        """Map endpoint options onto `JaegerExporter` keyword arguments.

        `host` defaults to "localhost" and `port` to 6831; `username` and
        `password` are `None` when not configured.
        """
        options = cfg.kwargs
        host = options.get("host", "localhost")
        port = options.get("port", 6831)
        return {
            "agent_host_name": host,
            "agent_port": port,
            "username": options.get("username"),
            "password": options.get("password"),
        }


class OTLPCollectorConfigurer(TracerConfigurer):
    """The `TracerConfigurer` for an OTLP collector backend."""

    @classmethod
    def configure_from_endpoint_config(cls, cfg: EndpointConfig) -> TracerProvider:
        """Configure tracing for an OTLP collector.

        This will read the OTLP collector-specific configuration from the
        `EndpointConfig` and create a corresponding `TracerProvider` that exports to
        the given OTLP collector through the gRPC span exporter. The `insecure`
        option is passed through to the exporter; when it is not truthy and a
        `root_certificates` file is configured, channel credentials built from
        that file are passed as well.

        :param cfg: The configuration to be read for configuring tracing. It must
            contain an `endpoint` option.
        :return: The configured `TracerProvider`.
        :raises KeyError: If `endpoint` is missing from the configuration.
        """
        provider = TracerProvider(
            resource=Resource.create(
                {SERVICE_NAME: cfg.kwargs.get("service_name", TRACING_SERVICE_NAME)}
            )
        )

        # May be `None` when the option is absent; it is forwarded unchanged.
        insecure = cfg.kwargs.get("insecure")

        credentials = cls._get_credentials(cfg, insecure)

        otlp_exporter = OTLPSpanExporter(
            endpoint=cfg.kwargs["endpoint"],
            insecure=insecure,
            credentials=credentials,
        )
        # Note the trailing space: the two literals are concatenated verbatim.
        logger.info(
            f"Registered {cfg.type} endpoint for tracing. "
            f"Traces will be exported to {cfg.kwargs['endpoint']}"
        )
        provider.add_span_processor(BatchSpanProcessor(otlp_exporter))

        return provider

    @classmethod
    def _get_credentials(
        cls, cfg: EndpointConfig, insecure: Optional[bool]
    ) -> Optional[grpc.ChannelCredentials]:
        """Build gRPC channel credentials from the configured root certificates.

        :param cfg: The configuration that may name a `root_certificates` file.
        :param insecure: The `insecure` option as read from the configuration.
        :return: SSL channel credentials, or `None` when `insecure` is truthy or
            no `root_certificates` file is configured.
        """
        if insecure or "root_certificates" not in cfg.kwargs:
            return None

        with open(cfg.kwargs["root_certificates"], "rb") as f:
            root_cert = f.read()
        return grpc.ssl_channel_credentials(root_certificates=root_cert)
64 changes: 64 additions & 0 deletions rasa_sdk/tracing/endpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
import logging

import os
from typing import Any, Dict, Optional, Text
import rasa_sdk.utils


# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# NOTE(review): not referenced in the visible part of this module — confirm
# whether it is still needed.
DEFAULT_ENCODING = "utf-8"


def read_endpoint_config(
    filename: Text, endpoint_type: Text
) -> Optional["EndpointConfig"]:
    """Load a single endpoint section from an endpoints file on disk.

    Args:
        filename: Path of the endpoints file. A falsy value means there is
            nothing to read.
        endpoint_type: Top-level key of the section to extract.

    Returns:
        The `EndpointConfig` built from the requested section, or `None` when
        no filename is given, the section is missing or null, or the file
        does not exist (the latter is also logged as an error).
    """
    if not filename:
        return None

    try:
        raw_text = rasa_sdk.utils.read_file(filename)
        parsed = rasa_sdk.utils.read_yaml(raw_text)

        section = parsed.get(endpoint_type)
        if section is None:
            return None

        return EndpointConfig.from_dict(section)
    except FileNotFoundError:
        missing_path = os.path.abspath(filename)
        logger.error(
            "Failed to read endpoint configuration "
            "from {}. No such file.".format(missing_path)
        )
        return None


class EndpointConfig:
    """Configuration for an external HTTP endpoint."""

    def __init__(
        self,
        url: Optional[Text] = None,
        params: Optional[Dict[Text, Any]] = None,
        headers: Optional[Dict[Text, Any]] = None,
        basic_auth: Optional[Dict[Text, Text]] = None,
        token: Optional[Text] = None,
        token_name: Text = "token",
        cafile: Optional[Text] = None,
        **kwargs: Any,
    ) -> None:
        """Creates an `EndpointConfig` instance.

        Args:
            url: Endpoint URL.
            params: Stored as given; `None` becomes an empty dict.
            headers: Stored as given; `None` becomes an empty dict.
            basic_auth: Stored as given; `None` becomes an empty dict.
            token: Stored as given.
            token_name: Stored as given; defaults to "token".
            cafile: Stored as given.
            **kwargs: Any further options. `store_type` and `type` are
                consumed to populate `self.type`; everything else is kept
                in `self.kwargs`.
        """
        self.url = url
        self.params = params or {}
        self.headers = headers or {}
        self.basic_auth = basic_auth or {}
        self.token = token
        self.token_name = token_name
        # Both keys are always removed from `kwargs`; `store_type` wins when
        # both are present because the inner pop only supplies the default.
        self.type = kwargs.pop("store_type", kwargs.pop("type", None))
        self.cafile = cafile
        self.kwargs = kwargs

    @classmethod
    def from_dict(cls, data: Dict[Text, Any]) -> "EndpointConfig":
        """Build an instance from a mapping of constructor keyword arguments.

        Uses `cls` so that subclasses get an instance of their own type.
        """
        return cls(**data)
Loading