diff --git a/elasticsearch_serverless/_otel.py b/elasticsearch_serverless/_otel.py index 9264569..9d0863b 100644 --- a/elasticsearch_serverless/_otel.py +++ b/elasticsearch_serverless/_otel.py @@ -90,3 +90,25 @@ def span( endpoint_id=endpoint_id, body_strategy=self.body_strategy, ) + + @contextlib.contextmanager + def helpers_span(self, span_name: str) -> Generator[OpenTelemetrySpan, None, None]: + if not self.enabled or self.tracer is None: + yield OpenTelemetrySpan(None) + return + + with self.tracer.start_as_current_span(span_name) as otel_span: + otel_span.set_attribute("db.system", "elasticsearch") + otel_span.set_attribute("db.operation", span_name) + # Without a request method, Elastic APM does not display the traces + otel_span.set_attribute("http.request.method", "null") + yield OpenTelemetrySpan(otel_span) + + @contextlib.contextmanager + def use_span(self, span: OpenTelemetrySpan) -> Generator[None, None, None]: + if not self.enabled or self.tracer is None: + yield + return + + with trace.use_span(span.otel_span): + yield diff --git a/elasticsearch_serverless/helpers/actions.py b/elasticsearch_serverless/helpers/actions.py index 0e194b6..26ea929 100644 --- a/elasticsearch_serverless/helpers/actions.py +++ b/elasticsearch_serverless/helpers/actions.py @@ -34,6 +34,8 @@ Union, ) +from elastic_transport import OpenTelemetrySpan + from .. import Elasticsearch from ..compat import to_bytes from ..exceptions import ApiError, NotFoundError, TransportError @@ -322,6 +324,7 @@ def _process_bulk_chunk( Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY], ] ], + otel_span: OpenTelemetrySpan, raise_on_exception: bool = True, raise_on_error: bool = True, ignore_status: Union[int, Collection[int]] = (), @@ -331,28 +334,29 @@ def _process_bulk_chunk( """ Send a bulk request to elasticsearch and process the output. """ - if isinstance(ignore_status, int): - ignore_status = (ignore_status,) - - try: - # send the actual request - resp = client.bulk(*args, operations=bulk_actions, **kwargs) # type: ignore[arg-type] - except ApiError as e: - gen = _process_bulk_chunk_error( - error=e, - bulk_data=bulk_data, - ignore_status=ignore_status, - raise_on_exception=raise_on_exception, - raise_on_error=raise_on_error, - ) - else: - gen = _process_bulk_chunk_success( - resp=resp.body, - bulk_data=bulk_data, - ignore_status=ignore_status, - raise_on_error=raise_on_error, - ) - yield from gen + with client._otel.use_span(otel_span): + if isinstance(ignore_status, int): + ignore_status = (ignore_status,) + + try: + # send the actual request + resp = client.bulk(*args, operations=bulk_actions, **kwargs) # type: ignore[arg-type] + except ApiError as e: + gen = _process_bulk_chunk_error( + error=e, + bulk_data=bulk_data, + ignore_status=ignore_status, + raise_on_exception=raise_on_exception, + raise_on_error=raise_on_error, + ) + else: + gen = _process_bulk_chunk_success( + resp=resp.body, + bulk_data=bulk_data, + ignore_status=ignore_status, + raise_on_error=raise_on_error, + ) + yield from gen def streaming_bulk( @@ -370,6 +374,7 @@ def streaming_bulk( max_backoff: float = 600, yield_ok: bool = True, ignore_status: Union[int, Collection[int]] = (), + span_name: str = "helpers.streaming_bulk", *args: Any, **kwargs: Any, ) -> Iterable[Tuple[bool, Dict[str, Any]]]: @@ -406,73 +411,78 @@ def streaming_bulk( :arg yield_ok: if set to False will skip successful documents in the output :arg ignore_status: list of HTTP status code that you want to ignore """ - client = client.options() - client._client_meta = (("h", "bp"),) + with client._otel.helpers_span(span_name) as otel_span: + client = client.options() + client._client_meta = (("h", "bp"),) - serializer = client.transport.serializers.get_serializer("application/json") + serializer = client.transport.serializers.get_serializer("application/json") - bulk_data: List[ - Union[ - Tuple[_TYPE_BULK_ACTION_HEADER], - Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY], + bulk_data: List[ + Union[ + Tuple[_TYPE_BULK_ACTION_HEADER], + Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY], + ] ] - ] - bulk_actions: List[bytes] - for bulk_data, bulk_actions in _chunk_actions( - map(expand_action_callback, actions), chunk_size, max_chunk_bytes, serializer - ): - for attempt in range(max_retries + 1): - to_retry: List[bytes] = [] - to_retry_data: List[ - Union[ - Tuple[_TYPE_BULK_ACTION_HEADER], - Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY], - ] - ] = [] - if attempt: - time.sleep(min(max_backoff, initial_backoff * 2 ** (attempt - 1))) - - try: - for data, (ok, info) in zip( - bulk_data, - _process_bulk_chunk( - client, - bulk_actions, + bulk_actions: List[bytes] + for bulk_data, bulk_actions in _chunk_actions( + map(expand_action_callback, actions), + chunk_size, + max_chunk_bytes, + serializer, + ): + for attempt in range(max_retries + 1): + to_retry: List[bytes] = [] + to_retry_data: List[ + Union[ + Tuple[_TYPE_BULK_ACTION_HEADER], + Tuple[_TYPE_BULK_ACTION_HEADER, _TYPE_BULK_ACTION_BODY], + ] + ] = [] + if attempt: + time.sleep(min(max_backoff, initial_backoff * 2 ** (attempt - 1))) + + try: + for data, (ok, info) in zip( bulk_data, - raise_on_exception, - raise_on_error, - ignore_status, - *args, - **kwargs, - ), - ): - if not ok: - action, info = info.popitem() - # retry if retries enabled, we get 429, and we are not - # in the last attempt - if ( - max_retries - and info["status"] == 429 - and (attempt + 1) <= max_retries - ): - # _process_bulk_chunk expects bytes so we need to - # re-serialize the data - to_retry.extend(map(serializer.dumps, data)) - to_retry_data.append(data) - else: - yield ok, {action: info} - elif yield_ok: - yield ok, info - - except ApiError as e: - # suppress 429 errors since we will retry them - if attempt == max_retries or e.status_code != 429: - raise - else: - if not to_retry: - break - # retry only subset of documents that didn't succeed - bulk_actions, bulk_data = to_retry, to_retry_data + _process_bulk_chunk( + client, + bulk_actions, + bulk_data, + otel_span, + raise_on_exception, + raise_on_error, + ignore_status, + *args, + **kwargs, + ), + ): + if not ok: + action, info = info.popitem() + # retry if retries enabled, we get 429, and we are not + # in the last attempt + if ( + max_retries + and info["status"] == 429 + and (attempt + 1) <= max_retries + ): + # _process_bulk_chunk expects bytes so we need to + # re-serialize the data + to_retry.extend(map(serializer.dumps, data)) + to_retry_data.append(data) + else: + yield ok, {action: info} + elif yield_ok: + yield ok, info + + except ApiError as e: + # suppress 429 errors since we will retry them + if attempt == max_retries or e.status_code != 429: + raise + else: + if not to_retry: + break + # retry only subset of documents that didn't succeed + bulk_actions, bulk_data = to_retry, to_retry_data def bulk( @@ -519,7 +529,7 @@ def bulk( # make streaming_bulk yield successful results so we can count them kwargs["yield_ok"] = True for ok, item in streaming_bulk( - client, actions, ignore_status=ignore_status, *args, **kwargs # type: ignore[misc] + client, actions, ignore_status=ignore_status, span_name="helpers.bulk", *args, **kwargs # type: ignore[misc] ): # go through request-response pairs and detect failures if not ok: @@ -589,27 +599,31 @@ def _setup_queues(self) -> None: ] = Queue(max(queue_size, thread_count)) self._quick_put = self._inqueue.put - pool = BlockingPool(thread_count) + with client._otel.helpers_span("helpers.parallel_bulk") as otel_span: + pool = BlockingPool(thread_count) - try: - for result in pool.imap( - lambda bulk_chunk: list( - _process_bulk_chunk( - client, - bulk_chunk[1], - bulk_chunk[0], - ignore_status=ignore_status, # type: ignore[misc] - *args, - **kwargs, - ) - ), - _chunk_actions(expanded_actions, chunk_size, max_chunk_bytes, serializer), - ): - yield from result - - finally: - pool.close() - pool.join() + try: + for result in pool.imap( + lambda bulk_chunk: list( + _process_bulk_chunk( + client, + bulk_chunk[1], + bulk_chunk[0], + otel_span=otel_span, + ignore_status=ignore_status, # type: ignore[misc] + *args, + **kwargs, + ) + ), + _chunk_actions( + expanded_actions, chunk_size, max_chunk_bytes, serializer + ), + ): + yield from result + + finally: + pool.close() + pool.join() def scan( diff --git a/test_elasticsearch_serverless/test_otel.py b/test_elasticsearch_serverless/test_otel.py index 1177f38..42fda36 100644 --- a/test_elasticsearch_serverless/test_otel.py +++ b/test_elasticsearch_serverless/test_otel.py @@ -16,9 +16,12 @@ # under the License. import os +from unittest import mock import pytest +from elasticsearch_serverless import Elasticsearch, helpers + try: from opentelemetry.sdk.trace import TracerProvider, export from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( @@ -95,3 +98,25 @@ def test_detailed_span(): "db.elasticsearch.cluster.name": "e9106fc68e3044f0b1475b04bf4ffd5f", "db.elasticsearch.node.name": "instance-0000000001", } + + +@mock.patch("elasticsearch_serverless._otel.OpenTelemetry.use_span") +@mock.patch("elasticsearch_serverless._otel.OpenTelemetry.helpers_span") +@mock.patch("elasticsearch_serverless.helpers.actions._process_bulk_chunk_success") +@mock.patch("elasticsearch_serverless.Elasticsearch.bulk") +def test_forward_otel_context_to_subthreads( + _call_bulk_mock, + _process_bulk_success_mock, + _mock_otel_helpers_span, + _mock_otel_use_span, +): + tracer, memory_exporter = setup_tracing() + es_client = Elasticsearch("http://localhost:9200") + es_client._otel = OpenTelemetry(enabled=True, tracer=tracer) + + _call_bulk_mock.return_value = mock.Mock() + actions = ({"x": i} for i in range(100)) + list(helpers.parallel_bulk(es_client, actions, chunk_size=4)) + # Ensures that the OTEL context has been forwarded to all chunks + assert es_client._otel.helpers_span.call_count == 1 + assert es_client._otel.use_span.call_count == 25 diff --git a/test_elasticsearch_serverless/test_server/test_helpers.py b/test_elasticsearch_serverless/test_server/test_helpers.py index 2b83737..306d53d 100644 --- a/test_elasticsearch_serverless/test_server/test_helpers.py +++ b/test_elasticsearch_serverless/test_server/test_helpers.py @@ -41,6 +41,7 @@ def __init__( ), ): self.client = client + self._otel = client._otel self._called = 0 self._fail_at = fail_at self.transport = client.transport diff --git a/test_elasticsearch_serverless/test_server/test_otel.py b/test_elasticsearch_serverless/test_server/test_otel.py index e9ebcd1..96c2a86 100644 --- a/test_elasticsearch_serverless/test_server/test_otel.py +++ b/test_elasticsearch_serverless/test_server/test_otel.py @@ -19,14 +19,9 @@ import pytest -try: - from opentelemetry import trace - from opentelemetry.sdk.trace import TracerProvider, export - from opentelemetry.sdk.trace.export.in_memory_span_exporter import ( - InMemorySpanExporter, - ) -except ModuleNotFoundError: - pass +import elasticsearch_serverless.helpers + +from ..test_otel import setup_tracing pytestmark = [ pytest.mark.skipif( @@ -36,13 +31,9 @@ ] -def test_otel_end_to_end(monkeypatch, sync_client): - # Sets the global default tracer provider - tracer_provider = TracerProvider() - memory_exporter = InMemorySpanExporter() - span_processor = export.SimpleSpanProcessor(memory_exporter) - tracer_provider.add_span_processor(span_processor) - trace.set_tracer_provider(tracer_provider) +def test_otel_end_to_end(sync_client): + tracer, memory_exporter = setup_tracing() + sync_client._otel.tracer = tracer resp = sync_client.search(index="logs-*", query={"match_all": {}}) assert resp.meta.status == 200 @@ -59,3 +50,59 @@ def test_otel_end_to_end(monkeypatch, sync_client): # Assert expected atttributes are here, but allow other attributes too # to make this test robust to elastic-transport changes assert expected_attributes.items() <= spans[0].attributes.items() + + +@pytest.mark.parametrize( + "bulk_helper_name", ["bulk", "streaming_bulk", "parallel_bulk"] +) +def test_otel_bulk(sync_client, elasticsearch_url, bulk_helper_name): + tracer, memory_exporter = setup_tracing() + + # Create a new client with our tracer + sync_client = sync_client.options() + sync_client._otel.tracer = tracer + # "Disable" options to keep our custom tracer + sync_client.options = lambda: sync_client + + docs = [{"answer": x, "helper": bulk_helper_name, "_id": x} for x in range(10)] + bulk_function = getattr(elasticsearch_serverless.helpers, bulk_helper_name) + if bulk_helper_name == "bulk": + success, failed = bulk_function( + sync_client, docs, index="test-index", chunk_size=2, refresh=True + ) + assert success, failed == (5, 0) + else: + for ok, resp in bulk_function( + sync_client, docs, index="test-index", chunk_size=2, refresh=True + ): + assert ok is True + + memory_exporter.shutdown() + + assert 10 == sync_client.count(index="test-index")["count"] + assert {"answer": 4, "helper": bulk_helper_name} == sync_client.get( + index="test-index", id=4 + )["_source"] + + spans = list(memory_exporter.get_finished_spans()) + parent_span = spans.pop() + assert parent_span.name == f"helpers.{bulk_helper_name}" + assert parent_span.attributes == { + "db.system": "elasticsearch", + "db.operation": f"helpers.{bulk_helper_name}", + "http.request.method": "null", + } + + assert len(spans) == 5 + for span in spans: + assert span.name == "bulk" + expected_attributes = { + "http.request.method": "PUT", + "db.system": "elasticsearch", + "db.operation": "bulk", + "db.elasticsearch.path_parts.index": "test-index", + } + # Assert expected atttributes are here, but allow other attributes too + # to make this test robust to elastic-transport changes + assert expected_attributes.items() <= spans[0].attributes.items() + assert span.parent.trace_id == parent_span.context.trace_id