Skip to content

Commit

Permalink
Rework OpenTelemetryTracer
Browse files Browse the repository at this point in the history
  • Loading branch information
NiklasKoehneckeAA committed May 22, 2024
1 parent 4bcf987 commit 91381eb
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 114 deletions.
45 changes: 18 additions & 27 deletions src/intelligence_layer/core/tracer/open_telemetry_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@
from opentelemetry.trace import set_span_in_context

from intelligence_layer.core.tracer.tracer import (
Context,
ExportedSpan,
LogEntry,
PydanticSerializable,
Span,
TaskSpan,
Expand All @@ -21,60 +21,56 @@ class OpenTelemetryTracer(Tracer):
"""A `Tracer` that uses open telemetry."""

def __init__(self, tracer: OpenTTracer) -> None:
self.O_tracer = tracer
self._tracer = tracer

def span(
self,
name: str,
timestamp: Optional[datetime] = None,
trace_id: Optional[str] = None,
) -> "OpenTelemetrySpan":
trace_id = self.ensure_id(trace_id)
tracer_span = self._tracer.start_span(
name,
attributes={"trace_id": trace_id},
start_time=None if not timestamp else _open_telemetry_timestamp(timestamp),
)
token = attach(set_span_in_context(tracer_span))
self._tracer
return OpenTelemetrySpan(tracer_span, self._tracer, token, trace_id)
return OpenTelemetrySpan(tracer_span, self._tracer, token, self.context)

def task_span(
self,
task_name: str,
input: PydanticSerializable,
timestamp: Optional[datetime] = None,
trace_id: Optional[str] = None,
) -> "OpenTelemetryTaskSpan":
trace_id = self.ensure_id(trace_id)

tracer_span = self._tracer.start_span(
task_name,
attributes={"input": _serialize(input), "trace_id": trace_id},
attributes={"input": _serialize(input)},
start_time=None if not timestamp else _open_telemetry_timestamp(timestamp),
)
token = attach(set_span_in_context(tracer_span))
return OpenTelemetryTaskSpan(tracer_span, self._tracer, token, trace_id)
return OpenTelemetryTaskSpan(tracer_span, self._tracer, token, self.context)

def export_for_viewing(self) -> Sequence[ExportedSpan]:
raise NotImplementedError("The OpenTelemetryTracer does not support export for viewing, as it can not acces its own traces.")
raise NotImplementedError(
"The OpenTelemetryTracer does not support export for viewing, as it can not access its own traces."
)


class OpenTelemetrySpan(Span, OpenTelemetryTracer):
"""A `Span` created by `OpenTelemetryTracer.span`."""

end_timestamp: Optional[datetime] = None

def id(self) -> str:
return self._trace_id

def __init__(
self, span: OpenTSpan, tracer: OpenTTracer, token: object, trace_id: str
self,
span: OpenTSpan,
tracer: OpenTTracer,
token: object,
context: Optional[Context] = None,
) -> None:
super().__init__(tracer)
OpenTelemetryTracer.__init__(self, tracer)
Span.__init__(self, context=context)
self.open_ts_span = span
self._token = token
self._trace_id = trace_id

def log(
self,
Expand All @@ -84,28 +80,23 @@ def log(
) -> None:
self.open_ts_span.add_event(
message,
{"value": _serialize(value), "trace_id": self.id()},
{"value": _serialize(value)},
None if not timestamp else _open_telemetry_timestamp(timestamp),
)

def end(self, timestamp: Optional[datetime] = None) -> None:
super().end(timestamp)
detach(self._token)
self.open_ts_span.end(
_open_telemetry_timestamp(timestamp) if timestamp is not None else None
)
super().end(timestamp)


class OpenTelemetryTaskSpan(TaskSpan, OpenTelemetrySpan):
"""A `TaskSpan` created by `OpenTelemetryTracer.task_span`."""

output: Optional[PydanticSerializable] = None

def __init__(
self, span: OpenTSpan, tracer: OpenTTracer, token: object, trace_id: str
) -> None:
super().__init__(span, tracer, token, trace_id)

def record_output(self, output: PydanticSerializable) -> None:
self.open_ts_span.set_attribute("output", _serialize(output))

Expand Down
203 changes: 116 additions & 87 deletions tests/core/tracer/test_open_telemetry_tracer.py
Original file line number Diff line number Diff line change
@@ -1,120 +1,149 @@
import json
import time
from typing import Any, Optional
from typing import Any
from uuid import uuid4

import pytest
import requests
from aleph_alpha_client import Prompt
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace import ReadableSpan, TracerProvider
from opentelemetry.sdk.trace.export import (
SimpleSpanProcessor,
SpanExporter,
SpanExportResult,
)
from pytest import fixture

from intelligence_layer.core import (
CompleteInput,
CompleteOutput,
OpenTelemetryTracer,
Task,
)
from intelligence_layer.core import OpenTelemetryTracer, Task


class DummyExporter(SpanExporter):
def __init__(self) -> None:
self.spans: list[ReadableSpan] = []

def export(self, spans: trace.Sequence[ReadableSpan]) -> SpanExportResult:
self.spans.extend(spans)
return SpanExportResult.SUCCESS

def shutdown(self) -> None:
self.spans = []

def force_flush(self, timeout_millis: int = 30000) -> bool:
return True


@fixture
def open_telemetry_tracer() -> tuple[str, OpenTelemetryTracer]:
service_name = "test-service"
url = "http://localhost:16686/api/traces?service=" + service_name
def exporter() -> DummyExporter:
return DummyExporter()


@fixture
def service_name() -> str:
return "test-service"


@fixture
def test_opentelemetry_tracer(
exporter: DummyExporter, service_name: str
) -> OpenTelemetryTracer:
resource = Resource.create({SERVICE_NAME: service_name})
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
processor = BatchSpanProcessor(OTLPSpanExporter())
processor = SimpleSpanProcessor(exporter)
provider.add_span_processor(processor)
openTracer = OpenTelemetryTracer(trace.get_tracer("intelligence-layer"))
return (url, openTracer)


def _get_trace_by_id(tracing_service: str, wanted_trace_id: str) -> Optional[Any]:
request_timeout_in_seconds = 10
traces = _get_current_traces(tracing_service)
if traces:
for current_trace in traces:
trace_id = _get_trace_id_from_trace(current_trace)
if trace_id == wanted_trace_id:
return trace

request_start = time.time()
while time.time() - request_start < request_timeout_in_seconds:
traces = _get_current_traces(tracing_service)
if traces:
for current_trace in traces:
trace_id = _get_trace_id_from_trace(current_trace)
if trace_id == wanted_trace_id:
return current_trace
time.sleep(0.1)
return None


def _get_current_traces(tracing_service: str) -> Any:
response = requests.get(tracing_service)
response_text = json.loads(response.text)
return response_text["data"]
tracer = OpenTelemetryTracer(trace.get_tracer("intelligence-layer"))
return tracer


def _get_trace_id_from_trace(trace: Any) -> Optional[str]:
spans = trace["spans"]
if not spans:
return None
return _get_trace_id_from_span(spans[0])

@fixture
def jaeger_compatible_tracer(service_name: str) -> OpenTelemetryTracer:
resource = Resource.create({SERVICE_NAME: service_name})
provider = TracerProvider(resource=resource)
trace.set_tracer_provider(provider)
processor = SimpleSpanProcessor(OTLPSpanExporter())
provider.add_span_processor(processor)
tracer = OpenTelemetryTracer(trace.get_tracer("intelligence-layer"))

def _get_trace_id_from_span(span: Any) -> Optional[str]:
tags = span["tags"]
if not tags:
return None
trace_id_tag = next(tag for tag in tags if tag["key"] == "trace_id")
return str(trace_id_tag["value"])
return tracer


@pytest.mark.docker
def test_open_telemetry_tracer_check_consistency_in_trace_ids(
open_telemetry_tracer: tuple[str, OpenTelemetryTracer],
def test_open_telemetry_tracer_has_consistent_trace_id(
test_opentelemetry_tracer: OpenTelemetryTracer,
exporter: DummyExporter,
test_task: Task[str, str],
) -> None:
tracing_service, tracer = open_telemetry_tracer
expected_trace_id = tracer.ensure_id(None)
test_task.run("test-input", tracer, trace_id=expected_trace_id)
trace = _get_trace_by_id(tracing_service, expected_trace_id)

assert trace is not None
assert _get_trace_id_from_trace(trace) == expected_trace_id
spans = trace["spans"]
test_task.run("test-input", test_opentelemetry_tracer)
spans = exporter.spans
assert len(spans) == 4
for span in spans:
assert _get_trace_id_from_span(span) == expected_trace_id
assert len(set(span.context.trace_id for span in spans)) == 1


@pytest.mark.docker
def test_open_telemetry_tracer_loggs_input_and_output(
open_telemetry_tracer: tuple[str, OpenTelemetryTracer],
complete: Task[CompleteInput, CompleteOutput],
def test_open_telemetry_tracer_sets_attributes_correctly(
test_opentelemetry_tracer: OpenTelemetryTracer,
exporter: DummyExporter,
test_task: Task[str, str],
) -> None:
tracing_service, tracer = open_telemetry_tracer
input = CompleteInput(prompt=Prompt.from_text("test"))
trace_id = tracer.ensure_id(None)
complete.run(input, tracer, trace_id)
trace = _get_trace_by_id(tracing_service, trace_id)
test_task.run("test-input", test_opentelemetry_tracer)
spans = exporter.spans
assert len(spans) == 4
spans_sorted_by_start: list[ReadableSpan] = sorted(
spans, key=lambda span: span.start_time
)

assert spans_sorted_by_start[0].name == "TestTask"
assert spans_sorted_by_start[0].attributes["input"] == '"test-input"'
assert spans_sorted_by_start[0].attributes["output"] == '"output"'

assert spans_sorted_by_start[1].name == "span"
assert "input" not in spans_sorted_by_start[1].attributes.keys()

assert spans_sorted_by_start[2].name == "TestSubTask"
assert spans_sorted_by_start[2].attributes["input"] == "null"
assert spans_sorted_by_start[2].attributes["output"] == "null"

assert spans_sorted_by_start[3].name == "TestSubTask"
assert spans_sorted_by_start[3].attributes["input"] == "null"
assert spans_sorted_by_start[3].attributes["output"] == "null"

spans_sorted_by_end: list[ReadableSpan] = sorted(
spans_sorted_by_start, key=lambda span: span.end_time
)

assert spans_sorted_by_end[0] == spans_sorted_by_start[2]
assert spans_sorted_by_end[1] == spans_sorted_by_start[1]
assert spans_sorted_by_end[2] == spans_sorted_by_start[3]
assert spans_sorted_by_end[3] == spans_sorted_by_start[0]

assert trace is not None

spans = trace["spans"]
assert spans is not []
def has_span_with_input(trace: Any, input_value: str) -> bool:
return any(
tag["key"] == "input" and tag["value"] == f'"{input_value}"'
for span in trace["spans"]
for tag in span["tags"]
)

task_span = next((span for span in spans if span["references"] == []), None)
assert task_span is not None

tags = task_span["tags"]
open_tel_input_tag = [tag for tag in tags if tag["key"] == "input"]
assert len(open_tel_input_tag) == 1
def get_current_traces(tracing_service: str) -> Any:
response = requests.get(tracing_service)
response_text = json.loads(response.text)
return response_text["data"]


@pytest.mark.docker
def test_open_telemetry_tracer_works_with_jaeger(
jaeger_compatible_tracer: OpenTelemetryTracer,
test_task: Task[str, str],
service_name: str,
) -> None:
url = "http://localhost:16686/api/traces?service=" + service_name
input_value = str(uuid4())
test_task.run(input_value, jaeger_compatible_tracer)
# the processor needs time to submit the trace to jaeger
time.sleep(1)
res = get_current_traces(url)

test_res = [trace_ for trace_ in res if has_span_with_input(trace_, input_value)]

open_tel_output_tag = [tag for tag in tags if tag["key"] == "output"]
assert len(open_tel_output_tag) == 1
assert len(test_res) == 1

0 comments on commit 91381eb

Please sign in to comment.