diff --git a/src/intelligence_layer/core/tracer/open_telemetry_tracer.py b/src/intelligence_layer/core/tracer/open_telemetry_tracer.py index cc7b6149d..de0386b17 100644 --- a/src/intelligence_layer/core/tracer/open_telemetry_tracer.py +++ b/src/intelligence_layer/core/tracer/open_telemetry_tracer.py @@ -7,8 +7,8 @@ from opentelemetry.trace import set_span_in_context from intelligence_layer.core.tracer.tracer import ( + Context, ExportedSpan, - LogEntry, PydanticSerializable, Span, TaskSpan, @@ -21,43 +21,38 @@ class OpenTelemetryTracer(Tracer): """A `Tracer` that uses open telemetry.""" def __init__(self, tracer: OpenTTracer) -> None: - self.O_tracer = tracer + self._tracer = tracer def span( self, name: str, timestamp: Optional[datetime] = None, - trace_id: Optional[str] = None, ) -> "OpenTelemetrySpan": - trace_id = self.ensure_id(trace_id) tracer_span = self._tracer.start_span( name, - attributes={"trace_id": trace_id}, start_time=None if not timestamp else _open_telemetry_timestamp(timestamp), ) token = attach(set_span_in_context(tracer_span)) - self._tracer - return OpenTelemetrySpan(tracer_span, self._tracer, token, trace_id) + return OpenTelemetrySpan(tracer_span, self._tracer, token, self.context) def task_span( self, task_name: str, input: PydanticSerializable, timestamp: Optional[datetime] = None, - trace_id: Optional[str] = None, ) -> "OpenTelemetryTaskSpan": - trace_id = self.ensure_id(trace_id) - tracer_span = self._tracer.start_span( task_name, - attributes={"input": _serialize(input), "trace_id": trace_id}, + attributes={"input": _serialize(input)}, start_time=None if not timestamp else _open_telemetry_timestamp(timestamp), ) token = attach(set_span_in_context(tracer_span)) - return OpenTelemetryTaskSpan(tracer_span, self._tracer, token, trace_id) - + return OpenTelemetryTaskSpan(tracer_span, self._tracer, token, self.context) + def export_for_viewing(self) -> Sequence[ExportedSpan]: - raise NotImplementedError("The OpenTelemetryTracer does not support export for viewing, as it can not acces its own traces.") + raise NotImplementedError( + "The OpenTelemetryTracer does not support export for viewing, as it can not access its own traces." + ) class OpenTelemetrySpan(Span, OpenTelemetryTracer): @@ -65,16 +60,17 @@ class OpenTelemetrySpan(Span, OpenTelemetryTracer): end_timestamp: Optional[datetime] = None - def id(self) -> str: - return self._trace_id - def __init__( - self, span: OpenTSpan, tracer: OpenTTracer, token: object, trace_id: str + self, + span: OpenTSpan, + tracer: OpenTTracer, + token: object, + context: Optional[Context] = None, ) -> None: - super().__init__(tracer) + OpenTelemetryTracer.__init__(self, tracer) + Span.__init__(self, context=context) self.open_ts_span = span self._token = token - self._trace_id = trace_id def log( self, @@ -84,16 +80,16 @@ def log( ) -> None: self.open_ts_span.add_event( message, - {"value": _serialize(value), "trace_id": self.id()}, + {"value": _serialize(value)}, None if not timestamp else _open_telemetry_timestamp(timestamp), ) def end(self, timestamp: Optional[datetime] = None) -> None: + super().end(timestamp) detach(self._token) self.open_ts_span.end( _open_telemetry_timestamp(timestamp) if timestamp is not None else None ) - super().end(timestamp) class OpenTelemetryTaskSpan(TaskSpan, OpenTelemetrySpan): @@ -101,11 +97,6 @@ class OpenTelemetryTaskSpan(TaskSpan, OpenTelemetrySpan): output: Optional[PydanticSerializable] = None - def __init__( - self, span: OpenTSpan, tracer: OpenTTracer, token: object, trace_id: str - ) -> None: - super().__init__(span, tracer, token, trace_id) - def record_output(self, output: PydanticSerializable) -> None: self.open_ts_span.set_attribute("output", _serialize(output)) diff --git a/tests/core/tracer/test_open_telemetry_tracer.py b/tests/core/tracer/test_open_telemetry_tracer.py index 04628dda7..3b3889f47 100644 --- a/tests/core/tracer/test_open_telemetry_tracer.py +++ b/tests/core/tracer/test_open_telemetry_tracer.py @@ -1,120 +1,149 @@ import json import time -from typing import Any, Optional +from typing import Any +from uuid import uuid4 import pytest import requests -from aleph_alpha_client import Prompt from opentelemetry import trace from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter from opentelemetry.sdk.resources import SERVICE_NAME, Resource -from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor +from opentelemetry.sdk.trace import ReadableSpan, TracerProvider +from opentelemetry.sdk.trace.export import ( + SimpleSpanProcessor, + SpanExporter, + SpanExportResult, +) from pytest import fixture -from intelligence_layer.core import ( - CompleteInput, - CompleteOutput, - OpenTelemetryTracer, - Task, -) +from intelligence_layer.core import OpenTelemetryTracer, Task + + +class DummyExporter(SpanExporter): + def __init__(self) -> None: + self.spans: list[ReadableSpan] = [] + + def export(self, spans: trace.Sequence[ReadableSpan]) -> SpanExportResult: + self.spans.extend(spans) + return SpanExportResult.SUCCESS + + def shutdown(self) -> None: + self.spans = [] + + def force_flush(self, timeout_millis: int = 30000) -> bool: + return True @fixture -def open_telemetry_tracer() -> tuple[str, OpenTelemetryTracer]: - service_name = "test-service" - url = "http://localhost:16686/api/traces?service=" + service_name +def exporter() -> DummyExporter: + return DummyExporter() + + +@fixture +def service_name() -> str: + return "test-service" + + +@fixture +def test_opentelemetry_tracer( + exporter: DummyExporter, service_name: str +) -> OpenTelemetryTracer: resource = Resource.create({SERVICE_NAME: service_name}) provider = TracerProvider(resource=resource) trace.set_tracer_provider(provider) - processor = BatchSpanProcessor(OTLPSpanExporter()) + processor = SimpleSpanProcessor(exporter) provider.add_span_processor(processor) - openTracer = OpenTelemetryTracer(trace.get_tracer("intelligence-layer")) - return (url, openTracer) - - -def _get_trace_by_id(tracing_service: str, wanted_trace_id: str) -> Optional[Any]: - request_timeout_in_seconds = 10 - traces = _get_current_traces(tracing_service) - if traces: - for current_trace in traces: - trace_id = _get_trace_id_from_trace(current_trace) - if trace_id == wanted_trace_id: - return trace - - request_start = time.time() - while time.time() - request_start < request_timeout_in_seconds: - traces = _get_current_traces(tracing_service) - if traces: - for current_trace in traces: - trace_id = _get_trace_id_from_trace(current_trace) - if trace_id == wanted_trace_id: - return current_trace - time.sleep(0.1) - return None - - -def _get_current_traces(tracing_service: str) -> Any: - response = requests.get(tracing_service) - response_text = json.loads(response.text) - return response_text["data"] + tracer = OpenTelemetryTracer(trace.get_tracer("intelligence-layer")) + return tracer -def _get_trace_id_from_trace(trace: Any) -> Optional[str]: - spans = trace["spans"] - if not spans: - return None - return _get_trace_id_from_span(spans[0]) - +@fixture +def jaeger_compatible_tracer(service_name: str) -> OpenTelemetryTracer: + resource = Resource.create({SERVICE_NAME: service_name}) + provider = TracerProvider(resource=resource) + trace.set_tracer_provider(provider) + processor = SimpleSpanProcessor(OTLPSpanExporter()) + provider.add_span_processor(processor) + tracer = OpenTelemetryTracer(trace.get_tracer("intelligence-layer")) -def _get_trace_id_from_span(span: Any) -> Optional[str]: - tags = span["tags"] - if not tags: - return None - trace_id_tag = next(tag for tag in tags if tag["key"] == "trace_id") - return str(trace_id_tag["value"]) + return tracer -@pytest.mark.docker -def test_open_telemetry_tracer_check_consistency_in_trace_ids( - open_telemetry_tracer: tuple[str, OpenTelemetryTracer], +def test_open_telemetry_tracer_has_consistent_trace_id( + test_opentelemetry_tracer: OpenTelemetryTracer, + exporter: DummyExporter, test_task: Task[str, str], ) -> None: - tracing_service, tracer = open_telemetry_tracer - expected_trace_id = tracer.ensure_id(None) - test_task.run("test-input", tracer, trace_id=expected_trace_id) - trace = _get_trace_by_id(tracing_service, expected_trace_id) - - assert trace is not None - assert _get_trace_id_from_trace(trace) == expected_trace_id - spans = trace["spans"] + test_task.run("test-input", test_opentelemetry_tracer) + spans = exporter.spans assert len(spans) == 4 - for span in spans: - assert _get_trace_id_from_span(span) == expected_trace_id + assert len(set(span.context.trace_id for span in spans)) == 1 -@pytest.mark.docker -def test_open_telemetry_tracer_loggs_input_and_output( - open_telemetry_tracer: tuple[str, OpenTelemetryTracer], - complete: Task[CompleteInput, CompleteOutput], +def test_open_telemetry_tracer_sets_attributes_correctly( + test_opentelemetry_tracer: OpenTelemetryTracer, + exporter: DummyExporter, + test_task: Task[str, str], ) -> None: - tracing_service, tracer = open_telemetry_tracer - input = CompleteInput(prompt=Prompt.from_text("test")) - trace_id = tracer.ensure_id(None) - complete.run(input, tracer, trace_id) - trace = _get_trace_by_id(tracing_service, trace_id) + test_task.run("test-input", test_opentelemetry_tracer) + spans = exporter.spans + assert len(spans) == 4 + spans_sorted_by_start: list[ReadableSpan] = sorted( + spans, key=lambda span: span.start_time + ) + + assert spans_sorted_by_start[0].name == "TestTask" + assert spans_sorted_by_start[0].attributes["input"] == '"test-input"' + assert spans_sorted_by_start[0].attributes["output"] == '"output"' + + assert spans_sorted_by_start[1].name == "span" + assert "input" not in spans_sorted_by_start[1].attributes.keys() + + assert spans_sorted_by_start[2].name == "TestSubTask" + assert spans_sorted_by_start[2].attributes["input"] == "null" + assert spans_sorted_by_start[2].attributes["output"] == "null" + + assert spans_sorted_by_start[3].name == "TestSubTask" + assert spans_sorted_by_start[3].attributes["input"] == "null" + assert spans_sorted_by_start[3].attributes["output"] == "null" + + spans_sorted_by_end: list[ReadableSpan] = sorted( + spans_sorted_by_start, key=lambda span: span.end_time + ) + + assert spans_sorted_by_end[0] == spans_sorted_by_start[2] + assert spans_sorted_by_end[1] == spans_sorted_by_start[1] + assert spans_sorted_by_end[2] == spans_sorted_by_start[3] + assert spans_sorted_by_end[3] == spans_sorted_by_start[0] - assert trace is not None - spans = trace["spans"] - assert spans is not [] +def has_span_with_input(trace: Any, input_value: str) -> bool: + return any( + tag["key"] == "input" and tag["value"] == f'"{input_value}"' + for span in trace["spans"] + for tag in span["tags"] + ) - task_span = next((span for span in spans if span["references"] == []), None) - assert task_span is not None - tags = task_span["tags"] - open_tel_input_tag = [tag for tag in tags if tag["key"] == "input"] - assert len(open_tel_input_tag) == 1 +def get_current_traces(tracing_service: str) -> Any: + response = requests.get(tracing_service) + response_text = json.loads(response.text) + return response_text["data"] + + +@pytest.mark.docker +def test_open_telemetry_tracer_works_with_jaeger( + jaeger_compatible_tracer: OpenTelemetryTracer, + test_task: Task[str, str], + service_name: str, +) -> None: + url = "http://localhost:16686/api/traces?service=" + service_name + input_value = str(uuid4()) + test_task.run(input_value, jaeger_compatible_tracer) + # the processor needs time to submit the trace to jaeger + time.sleep(1) + res = get_current_traces(url) + + test_res = [trace_ for trace_ in res if has_span_with_input(trace_, input_value)] - open_tel_output_tag = [tag for tag in tags if tag["key"] == "output"] - assert len(open_tel_output_tag) == 1 + assert len(test_res) == 1