From 546136312d5c8b144222ab86fdd225e1ccebccef Mon Sep 17 00:00:00 2001 From: Christophe Papazian <114495376+christophe-papazian@users.noreply.github.com> Date: Thu, 10 Oct 2024 17:36:35 +0200 Subject: [PATCH] chore(asm): refactor replacing old asm_request context by core.context (#10899) - Make core.ExecutionContext as a proper context. Add exception handler to suppress exception propagation and configure that mechanism for each context (using it now for BlockingException for threats). - Remove asm request context as a standalone context. Use core contexts instead with a specific class for appsec request data - move all appsec handlers to appsec directory - ensure appsec is loaded explicitely with all core listeners set, either when an Appsec Span Processor is created or when the IAST processor module is loaded (whichever comes first). - update gazillions of tests using the previous instrumentation with the new instrumentation - slightly modify appsec integrations by refactoring some common mechanism (set_blocked, get_blocked) ## Checklist - [x] PR author has checked that all the criteria below are met - The PR description includes an overview of the change - The PR description articulates the motivation for the change - The change includes tests OR the PR description describes a testing strategy - The PR description notes risks associated with the change, if any - Newly-added code is easy to change - The change follows the [library release note guidelines](https://ddtrace.readthedocs.io/en/stable/releasenotes.html) - The change includes or references documentation updates if necessary - Backport labels are set (if [applicable](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting)) ## Reviewer Checklist - [x] Reviewer has checked that all the criteria below are met - Title is accurate - All changes are related to the pull request's stated goal - Avoids breaking [API](https://ddtrace.readthedocs.io/en/stable/versioning.html#interfaces) changes - Testing strategy adequately addresses listed risks - Newly-added code is easy to change - Release note makes sense to a user of the library - If necessary, author has acknowledged and discussed the performance implications of this PR as reported in the benchmarks PR comment - Backport labels are set in a manner that is consistent with the [release branch maintenance policy](https://ddtrace.readthedocs.io/en/latest/contributing.html#backporting) --- ddtrace/_trace/trace_handlers.py | 165 +-------- ddtrace/appsec/__init__.py | 11 + ddtrace/appsec/_asm_request_context.py | 339 ++++++++---------- ddtrace/appsec/_common_module_patches.py | 16 +- ddtrace/appsec/_constants.py | 3 +- ddtrace/appsec/_handlers.py | 143 +++++++- ddtrace/appsec/_iast/processor.py | 4 + ddtrace/appsec/_processor.py | 27 +- ddtrace/appsec/_trace_utils.py | 18 +- ddtrace/contrib/internal/asgi/middleware.py | 25 +- ddtrace/contrib/internal/django/patch.py | 23 +- ddtrace/contrib/internal/flask/patch.py | 22 +- ddtrace/contrib/internal/starlette/patch.py | 8 +- ddtrace/contrib/internal/wsgi/wsgi.py | 9 +- ddtrace/internal/core/__init__.py | 41 ++- ddtrace/internal/utils/__init__.py | 18 + ddtrace/settings/asm.py | 11 +- .../appsec/appsec/test_appsec_trace_utils.py | 63 ++-- .../appsec/appsec/test_asm_request_context.py | 129 +++---- tests/appsec/appsec/test_processor.py | 320 +++++++---------- .../appsec/appsec/test_remoteconfiguration.py | 48 ++- tests/appsec/appsec/test_telemetry.py | 91 +++-- tests/appsec/contrib_appsec/utils.py | 11 +- tests/appsec/iast/test_telemetry.py | 4 +- tests/appsec/utils.py | 38 ++ .../django/test_django_appsec_snapshots.py | 14 +- tests/contrib/fastapi/test_fastapi_appsec.py | 5 +- tests/tracer/test_trace_utils.py | 3 +- tests/tracer/test_tracer.py | 6 +- 29 files changed, 765 insertions(+), 850 deletions(-) diff --git a/ddtrace/_trace/trace_handlers.py b/ddtrace/_trace/trace_handlers.py index 7702c6050d2..7ddffd12588 100644 --- a/ddtrace/_trace/trace_handlers.py +++ b/ddtrace/_trace/trace_handlers.py @@ -1,6 +1,6 @@ import functools -import re import sys +from typing import TYPE_CHECKING from typing import Any from typing import Callable from typing import Dict @@ -10,7 +10,6 @@ import wrapt from ddtrace._trace._span_pointer import _SpanPointerDescription -from ddtrace._trace.span import Span from ddtrace._trace.utils import extract_DD_context_from_messages from ddtrace._trace.utils_botocore.span_pointers import extract_span_pointers_from_successful_botocore_response from ddtrace._trace.utils_botocore.span_tags import ( @@ -22,7 +21,6 @@ from ddtrace.constants import SPAN_MEASURED_KEY from ddtrace.contrib import trace_utils from ddtrace.contrib.internal.botocore.constants import BOTOCORE_STEPFUNCTIONS_INPUT_KEY -from ddtrace.contrib.trace_utils import _get_request_header_user_agent from ddtrace.contrib.trace_utils import _set_url_tag from ddtrace.ext import SpanKind from ddtrace.ext import db @@ -34,14 +32,15 @@ from ddtrace.internal.constants import FLASK_ENDPOINT from ddtrace.internal.constants import FLASK_URL_RULE from ddtrace.internal.constants import FLASK_VIEW_ARGS -from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED -from ddtrace.internal.constants import RESPONSE_HEADERS from ddtrace.internal.logger import get_logger from ddtrace.internal.schema.span_attribute_schema import SpanDirection -from ddtrace.internal.utils import http as http_utils from ddtrace.propagation.http import HTTPPropagator +if TYPE_CHECKING: + from ddtrace import Span + + log = get_logger(__name__) @@ -106,7 +105,7 @@ def _get_parameters_for_new_span_directly_from_context(ctx: core.ExecutionContex return span_kwargs -def _start_span(ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> Span: +def _start_span(ctx: core.ExecutionContext, call_trace: bool = True, **kwargs) -> "Span": span_kwargs = _get_parameters_for_new_span_directly_from_context(ctx) call_trace = ctx.get_item("call_trace", call_trace) tracer = (ctx.get_item("middleware") or ctx["pin"]).tracer @@ -160,126 +159,6 @@ def _maybe_start_http_response_span(ctx: core.ExecutionContext) -> None: ) -def _use_html(headers) -> bool: - """decide if the response should be html or json. - - Add support for quality values in the Accept header. - """ - ctype = headers.get("Accept", headers.get("accept", "")) - if not ctype: - return False - html_score = 0.0 - json_score = 0.0 - ctypes = ctype.split(",") - for ct in ctypes: - if len(ct) > 128: - # ignore long (and probably malicious) headers to avoid performances issues - continue - m = re.match(r"([^/;]+/[^/;]+)(?:;q=([01](?:\.\d*)?))?", ct.strip()) - if m: - if m.group(1) == "text/html": - html_score = max(html_score, min(1.0, float(1.0 if m.group(2) is None else m.group(2)))) - elif m.group(1) == "text/*": - html_score = max(html_score, min(1.0, float(0.2 if m.group(2) is None else m.group(2)))) - elif m.group(1) == "application/json": - json_score = max(json_score, min(1.0, float(1.0 if m.group(2) is None else m.group(2)))) - elif m.group(1) == "application/*": - json_score = max(json_score, min(1.0, float(0.2 if m.group(2) is None else m.group(2)))) - return html_score > json_score - - -def _ctype_from_headers(block_config, headers) -> str: - """compute MIME type of the blocked response.""" - desired_type = block_config.get("type", "auto") - if desired_type == "auto": - return "text/html" if _use_html(headers) else "application/json" - else: - return "text/html" if block_config["type"] == "html" else "application/json" - - -def _wsgi_make_block_content(ctx, construct_url): - middleware = ctx.get_item("middleware") - req_span = ctx.get_item("req_span") - headers = ctx.get_item("headers") - environ = ctx.get_item("environ") - if req_span is None: - raise ValueError("request span not found") - block_config = core.get_item(HTTP_REQUEST_BLOCKED, span=req_span) - desired_type = block_config.get("type", "auto") - ctype = None - if desired_type == "none": - content = "" - resp_headers = [("content-type", "text/plain; charset=utf-8"), ("location", block_config.get("location", ""))] - else: - ctype = _ctype_from_headers(block_config, headers) - content = http_utils._get_blocked_template(ctype).encode("UTF-8") - resp_headers = [("content-type", ctype)] - status = block_config.get("status_code", 403) - try: - req_span.set_tag_str(RESPONSE_HEADERS + ".content-length", str(len(content))) - if ctype is not None: - req_span.set_tag_str(RESPONSE_HEADERS + ".content-type", ctype) - req_span.set_tag_str(http.STATUS_CODE, str(status)) - url = construct_url(environ) - query_string = environ.get("QUERY_STRING") - _set_url_tag(middleware._config, req_span, url, query_string) - if query_string and middleware._config.trace_query_string: - req_span.set_tag_str(http.QUERY_STRING, query_string) - method = environ.get("REQUEST_METHOD") - if method: - req_span.set_tag_str(http.METHOD, method) - user_agent = _get_request_header_user_agent(headers, headers_are_case_sensitive=True) - if user_agent: - req_span.set_tag_str(http.USER_AGENT, user_agent) - except Exception as e: - log.warning("Could not set some span tags on blocked request: %s", str(e)) # noqa: G200 - resp_headers.append(("Content-Length", str(len(content)))) - return status, resp_headers, content - - -def _asgi_make_block_content(ctx, url): - middleware = ctx.get_item("middleware") - req_span = ctx.get_item("req_span") - headers = ctx.get_item("headers") - environ = ctx.get_item("environ") - if req_span is None: - raise ValueError("request span not found") - block_config = core.get_item(HTTP_REQUEST_BLOCKED, span=req_span) - desired_type = block_config.get("type", "auto") - ctype = None - if desired_type == "none": - content = "" - resp_headers = [ - (b"content-type", b"text/plain; charset=utf-8"), - (b"location", block_config.get("location", "").encode()), - ] - else: - ctype = _ctype_from_headers(block_config, headers) - content = http_utils._get_blocked_template(ctype).encode("UTF-8") - # ctype = f"{ctype}; charset=utf-8" can be considered at some point - resp_headers = [(b"content-type", ctype.encode())] - status = block_config.get("status_code", 403) - try: - req_span.set_tag_str(RESPONSE_HEADERS + ".content-length", str(len(content))) - if ctype is not None: - req_span.set_tag_str(RESPONSE_HEADERS + ".content-type", ctype) - req_span.set_tag_str(http.STATUS_CODE, str(status)) - query_string = environ.get("QUERY_STRING") - _set_url_tag(middleware.integration_config, req_span, url, query_string) - if query_string and middleware._config.trace_query_string: - req_span.set_tag_str(http.QUERY_STRING, query_string) - method = environ.get("REQUEST_METHOD") - if method: - req_span.set_tag_str(http.METHOD, method) - user_agent = _get_request_header_user_agent(headers, headers_are_case_sensitive=True) - if user_agent: - req_span.set_tag_str(http.USER_AGENT, user_agent) - except Exception as e: - log.warning("Could not set some span tags on blocked request: %s", str(e)) # noqa: G200 - resp_headers.append((b"Content-Length", str(len(content)).encode())) - return status, resp_headers, content - - def _on_request_prepare(ctx, start_response): middleware = ctx.get_item("middleware") req_span = ctx.get_item("req_span") @@ -438,25 +317,6 @@ def _cookies_from_response_headers(response_headers): return cookies -def _on_flask_blocked_request(span): - span.set_tag_str(http.STATUS_CODE, "403") - request = core.get_item("flask_request") - try: - base_url = getattr(request, "base_url", None) - query_string = getattr(request, "query_string", None) - if base_url and query_string: - _set_url_tag(core.get_item("flask_config"), span, base_url, query_string) - if query_string and core.get_item("flask_config").trace_query_string: - span.set_tag_str(http.QUERY_STRING, query_string) - if request.method is not None: - span.set_tag_str(http.METHOD, request.method) - user_agent = _get_request_header_user_agent(request.headers) - if user_agent: - span.set_tag_str(http.USER_AGENT, user_agent) - except Exception as e: - log.warning("Could not set some span tags on blocked request: %s", str(e)) # noqa: G200 - - def _on_flask_render(template, flask_config): span = core.get_item("current_span") if not span: @@ -508,10 +368,6 @@ def _on_request_span_modifier_post(ctx, flask_config, request, req_body): ) -def _on_start_response_blocked(ctx, flask_config, response_headers, status): - trace_utils.set_http_meta(ctx["req_span"], flask_config, status_code=status, response_headers=response_headers) - - def _on_traced_get_response_pre(_, ctx: core.ExecutionContext, request, before_request_tags): before_request_tags(ctx["pin"], ctx["call"], request) ctx["call"]._metrics[SPAN_MEASURED_KEY] = 1 @@ -568,7 +424,7 @@ def _on_django_block_request(ctx: core.ExecutionContext, metadata: Dict[str, str def _on_django_after_request_headers_post( request_headers, response_headers, - span: Span, + span: "Span", django_config, request, url, @@ -818,7 +674,7 @@ def _on_test_visibility_is_enabled() -> bool: return CIVisibility.enabled -def _set_span_pointer(span: Span, span_pointer_description: _SpanPointerDescription) -> None: +def _set_span_pointer(span: "Span", span_pointer_description: _SpanPointerDescription) -> None: span._add_span_pointer( pointer_kind=span_pointer_description.pointer_kind, pointer_direction=span_pointer_description.pointer_direction, @@ -828,8 +684,6 @@ def _set_span_pointer(span: Span, span_pointer_description: _SpanPointerDescript def listen(): - core.on("wsgi.block.started", _wsgi_make_block_content, "status_headers_content") - core.on("asgi.block.started", _asgi_make_block_content, "status_headers_content") core.on("wsgi.request.prepare", _on_request_prepare) core.on("wsgi.request.prepared", _on_request_prepared) core.on("wsgi.app.success", _on_app_success) @@ -837,11 +691,9 @@ def listen(): core.on("wsgi.request.complete", _on_request_complete, "traced_iterable") core.on("wsgi.response.prepared", _on_response_prepared) core.on("flask.start_response.pre", _on_start_response_pre) - core.on("flask.blocked_request_callable", _on_flask_blocked_request) core.on("flask.request_call_modifier", _on_request_span_modifier) core.on("flask.request_call_modifier.post", _on_request_span_modifier_post) core.on("flask.render", _on_flask_render) - core.on("flask.start_response.blocked", _on_start_response_blocked) core.on("context.started.wsgi.response", _maybe_start_http_response_span) core.on("context.started.flask._patched_request", _on_traced_request_context_started_flask) core.on("django.traced_get_response.pre", _on_traced_get_response_pre) @@ -891,6 +743,7 @@ def listen(): "flask.call", "flask.jsonify", "flask.render_template", + "asgi.__call__", "wsgi.__call__", "django.traced_get_response", "django.cache", diff --git a/ddtrace/appsec/__init__.py b/ddtrace/appsec/__init__.py index e69de29bb2d..2b5f0802c51 100644 --- a/ddtrace/appsec/__init__.py +++ b/ddtrace/appsec/__init__.py @@ -0,0 +1,11 @@ +_APPSEC_TO_BE_LOADED = True + + +def load_appsec(): + """Lazily load the appsec module listeners.""" + from ddtrace.appsec._asm_request_context import listen + + global _APPSEC_TO_BE_LOADED + if _APPSEC_TO_BE_LOADED: + listen() + _APPSEC_TO_BE_LOADED = False diff --git a/ddtrace/appsec/_asm_request_context.py b/ddtrace/appsec/_asm_request_context.py index 39229bfff6b..6d168f834a4 100644 --- a/ddtrace/appsec/_asm_request_context.py +++ b/ddtrace/appsec/_asm_request_context.py @@ -1,11 +1,10 @@ -import contextlib import functools import json +import re import sys from typing import Any from typing import Callable from typing import Dict -from typing import Generator from typing import List from typing import Optional from typing import Set @@ -13,7 +12,6 @@ from urllib import parse from ddtrace._trace.span import Span -from ddtrace.appsec import _handlers from ddtrace.appsec._constants import APPSEC from ddtrace.appsec._constants import EXPLOIT_PREVENTION from ddtrace.appsec._constants import SPAN_DATA_NAMES @@ -37,6 +35,7 @@ else: from typing_extensions import Literal # noqa:F401 +_ASM_CONTEXT: Literal["_asm_env"] = "_asm_env" _WAF_ADDRESSES: Literal["waf_addresses"] = "waf_addresses" _CALLBACKS: Literal["callbacks"] = "callbacks" _TELEMETRY: Literal["telemetry"] = "telemetry" @@ -46,7 +45,7 @@ _TELEMETRY_WAF_RESULTS: Literal["t_waf_results"] = "t_waf_results" -GLOBAL_CALLBACKS: Dict[str, List[Callable]] = {} +GLOBAL_CALLBACKS: Dict[str, List[Callable]] = {_CONTEXT_CALL: []} class ASM_Environment: @@ -56,64 +55,108 @@ class ASM_Environment: It is contained into a ContextVar. """ - def __init__(self, active: bool = False): - self.active: bool = active - self.span: Optional[Span] = None - self.span_asm_context: Optional[contextlib.AbstractContextManager] = None + def __init__(self, span: Optional[Span] = None): + self.root = not in_context() + if self.root: + core.add_suppress_exception(BlockingException) + if span is None: + self.span: Span = core.get_item(core.get_item("call_key")) + else: + self.span = span self.waf_addresses: Dict[str, Any] = {} - self.callbacks: Dict[str, Any] = {} - self.telemetry: Dict[str, Any] = {} + self.callbacks: Dict[str, Any] = {_CONTEXT_CALL: []} + self.telemetry: Dict[str, Any] = { + _TELEMETRY_WAF_RESULTS: { + "blocked": False, + "triggered": False, + "timeout": False, + "version": None, + "duration": 0.0, + "total_duration": 0.0, + "rasp": { + "sum_eval": 0, + "duration": 0.0, + "total_duration": 0.0, + "eval": {t: 0 for _, t in EXPLOIT_PREVENTION.TYPE}, + "match": {t: 0 for _, t in EXPLOIT_PREVENTION.TYPE}, + "timeout": {t: 0 for _, t in EXPLOIT_PREVENTION.TYPE}, + }, + } + } self.addresses_sent: Set[str] = set() - self.must_call_globals: bool = True self.waf_triggers: List[Dict[str, Any]] = [] + self.blocked: Optional[Dict[str, Any]] = None -def _get_asm_context() -> ASM_Environment: - env = core.get_item("asm_env") - if env is None: - env = ASM_Environment() - core.set_item("asm_env", env) - return env - - -def free_context_available() -> bool: - env = _get_asm_context() - return env.active and env.span is None +def _get_asm_context() -> Optional[ASM_Environment]: + return core.get_item(_ASM_CONTEXT) def in_context() -> bool: - env = _get_asm_context() - return env.active + return core.get_item(_ASM_CONTEXT) is not None def is_blocked() -> bool: - try: - env = _get_asm_context() - if not env.active or env.span is None: - return False - return bool(core.get_item(WAF_CONTEXT_NAMES.BLOCKED, span=env.span)) - except Exception: + env = _get_asm_context() + if env is None: return False + return env.blocked is not None -def register(span: Span, span_asm_context=None) -> None: +def get_blocked() -> Dict[str, Any]: env = _get_asm_context() - if not env.active: - log.debug("registering a span with no active asm context") - return - env.span = span - env.span_asm_context = span_asm_context + if env is None: + return {} + return env.blocked or {} + +def _use_html(headers) -> bool: + """decide if the response should be html or json. -def unregister(span: Span) -> None: + Add support for quality values in the Accept header. + """ + ctype = headers.get("Accept", headers.get("accept", "")) + if not ctype: + return False + html_score = 0.0 + json_score = 0.0 + ctypes = ctype.split(",") + for ct in ctypes: + if len(ct) > 128: + # ignore long (and probably malicious) headers to avoid performances issues + continue + m = re.match(r"([^/;]+/[^/;]+)(?:;q=([01](?:\.\d*)?))?", ct.strip()) + if m: + if m.group(1) == "text/html": + html_score = max(html_score, min(1.0, float(1.0 if m.group(2) is None else m.group(2)))) + elif m.group(1) == "text/*": + html_score = max(html_score, min(1.0, float(0.2 if m.group(2) is None else m.group(2)))) + elif m.group(1) == "application/json": + json_score = max(json_score, min(1.0, float(1.0 if m.group(2) is None else m.group(2)))) + elif m.group(1) == "application/*": + json_score = max(json_score, min(1.0, float(0.2 if m.group(2) is None else m.group(2)))) + return html_score > json_score + + +def _ctype_from_headers(block_config, headers) -> None: + """compute MIME type of the blocked response and store it in the block config""" + desired_type = block_config.get("type", "auto") + if desired_type == "auto": + block_config["content-type"] = "text/html" if _use_html(headers) else "application/json" + else: + block_config["content-type"] = "text/html" if block_config["type"] == "html" else "application/json" + + +def set_blocked(blocked: Dict[str, Any]) -> None: + blocked = blocked.copy() env = _get_asm_context() - if env.span_asm_context is not None and env.span is span: - env.span_asm_context.__exit__(None, None, None) - elif env.span is span and env.must_call_globals: - # needed for api security flushing information before end of the span - for function in GLOBAL_CALLBACKS.get(_CONTEXT_CALL, []): - function(env) - env.must_call_globals = False + if env is None: + log.debug("setting blocked with no active asm context") + return + _ctype_from_headers(blocked, get_headers()) + env.blocked = blocked + # DEV: legacy code, to be removed + core.set_item(WAF_CONTEXT_NAMES.BLOCKED, True, span=env.span) def update_span_metrics(span: Span, name: str, value: Union[float, int]) -> None: @@ -121,9 +164,16 @@ def update_span_metrics(span: Span, name: str, value: Union[float, int]) -> None def flush_waf_triggers(env: ASM_Environment) -> None: - if not env.span: - return - root_span = env.span._local_root or env.span + # Make sure we find a root span to attach the triggers to + if env.span is None: + from ddtrace import tracer + + current_span = tracer.current_span() + if current_span is None: + return + root_span = current_span._local_root or current_span + else: + root_span = env.span._local_root or env.span if env.waf_triggers: report_list = get_triggers(root_span) if report_list is not None: @@ -142,70 +192,29 @@ def flush_waf_triggers(env: ASM_Environment) -> None: root_span.set_tag_str(APPSEC.WAF_VERSION, DDWAF_VERSION) if telemetry_results["total_duration"]: update_span_metrics(root_span, APPSEC.WAF_DURATION, telemetry_results["duration"]) + telemetry_results["duration"] = 0.0 update_span_metrics(root_span, APPSEC.WAF_DURATION_EXT, telemetry_results["total_duration"]) + telemetry_results["total_duration"] = 0.0 if telemetry_results["rasp"]["sum_eval"]: update_span_metrics(root_span, APPSEC.RASP_DURATION, telemetry_results["rasp"]["duration"]) + telemetry_results["rasp"]["duration"] = 0.0 update_span_metrics(root_span, APPSEC.RASP_DURATION_EXT, telemetry_results["rasp"]["total_duration"]) + telemetry_results["rasp"]["total_duration"] = 0.0 update_span_metrics(root_span, APPSEC.RASP_RULE_EVAL, telemetry_results["rasp"]["sum_eval"]) + telemetry_results["rasp"]["sum_eval"] = 0 -GLOBAL_CALLBACKS[_CONTEXT_CALL] = [flush_waf_triggers] - - -class _DataHandler: - """ - An object of this class is created by each asm request context. - It handles the creation and destruction of ASM_Environment object. - It allows the ASM context to be reentrant. - """ - - main_id = 0 - - def __init__(self): - _DataHandler.main_id += 1 - env = ASM_Environment(True) - - self._id = _DataHandler.main_id - self._root = not in_context() - self.active = True - self.execution_context = core.ExecutionContext(__name__, **{"asm_env": env}) - - env.telemetry[_TELEMETRY_WAF_RESULTS] = { - "blocked": False, - "triggered": False, - "timeout": False, - "version": None, - "duration": 0.0, - "total_duration": 0.0, - "rasp": { - "sum_eval": 0, - "duration": 0.0, - "total_duration": 0.0, - "eval": {t: 0 for _, t in EXPLOIT_PREVENTION.TYPE}, - "match": {t: 0 for _, t in EXPLOIT_PREVENTION.TYPE}, - "timeout": {t: 0 for _, t in EXPLOIT_PREVENTION.TYPE}, - }, - } - env.callbacks[_CONTEXT_CALL] = [] - - def finalise(self): - if self.active: - self.active = False - env = self.execution_context.get_item("asm_env") - if env is not None: - callbacks = GLOBAL_CALLBACKS.get(_CONTEXT_CALL, []) if env.must_call_globals else [] - env.must_call_globals = False - if env.callbacks is not None and env.callbacks.get(_CONTEXT_CALL): - callbacks = callbacks + env.callbacks.get(_CONTEXT_CALL) - if callbacks: - for function in callbacks: - function(env) - self.execution_context.end() +def finalize_asm_env(env: ASM_Environment) -> None: + callbacks = GLOBAL_CALLBACKS[_CONTEXT_CALL] + env.callbacks[_CONTEXT_CALL] + for function in callbacks: + function(env) + flush_waf_triggers(env) + core.discard_local_item(_ASM_CONTEXT) def set_value(category: str, address: str, value: Any) -> None: env = _get_asm_context() - if not env.active: + if env is None: log.debug("setting %s address %s with no active asm context", category, address) return asm_context_attr = getattr(env, category, None) @@ -215,7 +224,7 @@ def set_value(category: str, address: str, value: Any) -> None: def set_headers_response(headers: Any) -> None: if headers is not None: - set_waf_address(SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES, headers, _get_asm_context().span) + set_waf_address(SPAN_DATA_NAMES.RESPONSE_HEADERS_NO_COOKIES, headers) def set_body_response(body_response): @@ -225,7 +234,7 @@ def set_body_response(body_response): set_waf_address(SPAN_DATA_NAMES.RESPONSE_BODY, lambda: parse_response_body(body_response)) -def set_waf_address(address: str, value: Any, span: Optional[Span] = None) -> None: +def set_waf_address(address: str, value: Any) -> None: if address == SPAN_DATA_NAMES.REQUEST_URI_RAW: parse_address = parse.urlparse(value) no_scheme = parse.ParseResult("", "", *parse_address[2:]) @@ -233,15 +242,15 @@ def set_waf_address(address: str, value: Any, span: Optional[Span] = None) -> No set_value(_WAF_ADDRESSES, address, waf_value) else: set_value(_WAF_ADDRESSES, address, value) - if span is None: - span = _get_asm_context().span - if span: - core.set_item(address, value, span=span) + env = _get_asm_context() + if env and env.span: + root = env.span._local_root or env.span + root._set_ctx_item(address, value) def get_value(category: str, address: str, default: Any = None) -> Any: env = _get_asm_context() - if not env.active: + if env is None: log.debug("getting %s address %s with no active asm context", category, address) return default asm_context_attr = getattr(env, category, None) @@ -254,11 +263,11 @@ def get_waf_address(address: str, default: Any = None) -> Any: return get_value(_WAF_ADDRESSES, address, default=default) -def get_waf_addresses(default: Any = None) -> Any: +def get_waf_addresses() -> Dict[str, Any]: env = _get_asm_context() - if not env.active: + if env is None: log.debug("getting WAF addresses with no active asm context") - return default + return {} return env.waf_addresses @@ -297,7 +306,7 @@ def call_waf_callback(custom_data: Optional[Dict[str, Any]] = None, **kwargs) -> def set_ip(ip: Optional[str]) -> None: if ip is not None: - set_waf_address(SPAN_DATA_NAMES.REQUEST_HTTP_IP, ip, _get_asm_context().span) + set_waf_address(SPAN_DATA_NAMES.REQUEST_HTTP_IP, ip) def get_ip() -> Optional[str]: @@ -311,7 +320,7 @@ def get_ip() -> Optional[str]: def set_headers(headers: Any) -> None: if headers is not None: - set_waf_address(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES, headers, _get_asm_context().span) + set_waf_address(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES, headers) def get_headers() -> Optional[Any]: @@ -319,7 +328,7 @@ def get_headers() -> Optional[Any]: def set_headers_case_sensitive(case_sensitive: bool) -> None: - set_waf_address(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES_CASE, case_sensitive, _get_asm_context().span) + set_waf_address(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES_CASE, case_sensitive) def get_headers_case_sensitive() -> bool: @@ -332,7 +341,7 @@ def set_block_request_callable(_callable: Optional[Callable], *_) -> None: the callable need any params, like headers, they should be curried with functools.partial. """ - if _callable: + if asm_config._asm_enabled and _callable: set_value(_CALLBACKS, _BLOCK_CALL, _callable) @@ -349,7 +358,7 @@ def block_request() -> None: def get_data_sent() -> Set[str]: env = _get_asm_context() - if not env.active: + if env is None: log.debug("getting addresses sent with no active asm context") return set() return env.addresses_sent @@ -405,82 +414,38 @@ def store_waf_results_data(data) -> None: if not data: return env = _get_asm_context() - if not env.active: + if env is None: log.debug("storing waf results data with no active asm context") return - if not env.span: - log.debug("storing waf results data with no active span") - return for d in data: d["span_id"] = env.span.span_id env.waf_triggers.extend(data) -@contextlib.contextmanager -def asm_request_context_manager( - remote_ip: Optional[str] = None, - headers: Any = None, - headers_case_sensitive: bool = False, - block_request_callable: Optional[Callable] = None, -) -> Generator[Optional[_DataHandler], None, None]: - """ - The ASM context manager - """ - resources = _start_context(remote_ip, headers, headers_case_sensitive, block_request_callable) - if resources is not None: - try: - yield resources - except BlockingException as e: - # ensure that the BlockingRequest that is never raised outside a context - # is also never propagated outside the context - core.set_item(WAF_CONTEXT_NAMES.BLOCKED, e.args[0]) - if not resources._root: - raise - finally: - _end_context(resources) - else: - yield None - - -def _start_context( - remote_ip: Optional[str], headers: Any, headers_case_sensitive: bool, block_request_callable: Optional[Callable] -) -> Optional[_DataHandler]: - if asm_config._asm_enabled or asm_config._iast_enabled: - resources = _DataHandler() - if asm_config._asm_enabled: - asm_request_context_set(remote_ip, headers, headers_case_sensitive, block_request_callable) - _handlers.listen() - listen_context_handlers() - return resources - return None - - -def _on_context_started(ctx): - resources = _start_context( - ctx.get_item("remote_addr"), - ctx.get_item("headers"), - ctx.get_item("headers_case_sensitive"), - ctx.get_item("block_request_callable"), - ) - ctx.set_item("resources", resources) +def start_context(span: Span): + if asm_config._asm_enabled: + # it should only be called at start of a core context, when ASM_Env is not set yet + core.set_item(_ASM_CONTEXT, ASM_Environment(span=span)) + asm_request_context_set( + core.get_local_item("remote_addr"), + core.get_local_item("headers"), + core.get_local_item("headers_case_sensitive"), + core.get_local_item("block_request_callable"), + ) + elif asm_config._iast_enabled: + core.set_item(_ASM_CONTEXT, ASM_Environment()) -def _end_context(resources): - resources.finalise() - core.set_item("asm_env", None) +def end_context(span: Span): + env = _get_asm_context() + if env is not None and env.span is span: + finalize_asm_env(env) def _on_context_ended(ctx): - resources = ctx.get_item("resources") - if resources is not None: - _end_context(resources) - - -core.on("context.started.wsgi.__call__", _on_context_started) -core.on("context.ended.wsgi.__call__", _on_context_ended) -core.on("context.started.django.traced_get_response", _on_context_started) -core.on("context.ended.django.traced_get_response", _on_context_ended) -core.on("django.traced_get_response.pre", set_block_request_callable) + env = ctx.get_local_item(_ASM_CONTEXT) + if env is not None: + finalize_asm_env(env) def _on_wrapped_view(kwargs): @@ -540,7 +505,7 @@ def _on_pre_tracedrequest(ctx): current_span = ctx["current_span"] if asm_config._asm_enabled: set_block_request_callable(functools.partial(block_request_callable, current_span)) - if core.get_item(WAF_CONTEXT_NAMES.BLOCKED): + if get_blocked(): block_request() @@ -591,7 +556,10 @@ def _get_headers_if_appsec(): return get_headers() -def listen_context_handlers(): +def listen(): + from ddtrace.appsec._handlers import listen + + listen() core.on("flask.finalize_request.post", _set_headers_and_response) core.on("flask.wrapped_view", _on_wrapped_view, "callback_and_args") core.on("flask._patched_request", _on_pre_tracedrequest) @@ -608,3 +576,12 @@ def listen_context_handlers(): core.on("asgi.start_request", _call_waf_first) core.on("asgi.start_response", _call_waf) core.on("asgi.finalize_response", _set_headers_and_response) + + core.on("asm.set_blocked", set_blocked) + core.on("asm.get_blocked", get_blocked, "block_config") + + core.on("context.ended.wsgi.__call__", _on_context_ended) + core.on("context.ended.asgi.__call__", _on_context_ended) + + core.on("context.ended.django.traced_get_response", _on_context_ended) + core.on("django.traced_get_response.pre", set_block_request_callable) diff --git a/ddtrace/appsec/_common_module_patches.py b/ddtrace/appsec/_common_module_patches.py index 5ff91867848..2660617dc78 100644 --- a/ddtrace/appsec/_common_module_patches.py +++ b/ddtrace/appsec/_common_module_patches.py @@ -12,8 +12,8 @@ from wrapt import resolve_path import ddtrace +from ddtrace.appsec._asm_request_context import get_blocked from ddtrace.appsec._constants import WAF_ACTIONS -from ddtrace.appsec._constants import WAF_CONTEXT_NAMES from ddtrace.appsec._iast._metrics import _set_metric_iast_instrumented_sink from ddtrace.appsec._iast.constants import VULN_PATH_TRAVERSAL from ddtrace.internal import core @@ -100,7 +100,7 @@ def wrapped_open_CFDDB7ABBA9081B6(original_open_callable, instance, args, kwargs rule_type=EXPLOIT_PREVENTION.TYPE.LFI, ) if res and _must_block(res.actions): - raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "lfi", filename) + raise BlockingException(get_blocked(), "exploit_prevention", "lfi", filename) try: return original_open_callable(*args, **kwargs) except Exception as e: @@ -144,7 +144,7 @@ def wrapped_open_ED4CF71136E15EBF(original_open_callable, instance, args, kwargs rule_type=EXPLOIT_PREVENTION.TYPE.SSRF, ) if res and _must_block(res.actions): - raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "ssrf", url) + raise BlockingException(get_blocked(), "exploit_prevention", "ssrf", url) return original_open_callable(*args, **kwargs) @@ -182,7 +182,7 @@ def wrapped_request_D8CB81E472AF98A2(original_request_callable, instance, args, rule_type=EXPLOIT_PREVENTION.TYPE.SSRF, ) if res and _must_block(res.actions): - raise BlockingException(core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "ssrf", url) + raise BlockingException(get_blocked(), "exploit_prevention", "ssrf", url) return original_request_callable(*args, **kwargs) @@ -218,9 +218,7 @@ def wrapped_system_5542593D237084A7(original_command_callable, instance, args, k rule_type=EXPLOIT_PREVENTION.TYPE.CMDI, ) if res and _must_block(res.actions): - raise BlockingException( - core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "cmdi", command - ) + raise BlockingException(get_blocked(), "exploit_prevention", "cmdi", command) try: return original_command_callable(*args, **kwargs) except Exception as e: @@ -274,9 +272,7 @@ def execute_4C9BAC8E228EB347(instrument_self, query, args, kwargs) -> None: rule_type=EXPLOIT_PREVENTION.TYPE.SQLI, ) if res and _must_block(res.actions): - raise BlockingException( - core.get_item(WAF_CONTEXT_NAMES.BLOCKED), "exploit_prevention", "sqli", query - ) + raise BlockingException(get_blocked(), "exploit_prevention", "sqli", query) def try_unwrap(module, name): diff --git a/ddtrace/appsec/_constants.py b/ddtrace/appsec/_constants.py index 9054dc41882..a585cc48411 100644 --- a/ddtrace/appsec/_constants.py +++ b/ddtrace/appsec/_constants.py @@ -13,6 +13,7 @@ from typing import Any from typing import Iterator +from typing import Tuple from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED from ddtrace.internal.constants import REQUEST_PATH_PARAMS @@ -32,7 +33,7 @@ class Constant_Class(type): def __setattr__(self, __name: str, __value: Any) -> None: raise TypeError("Constant class does not support item assignment: %s.%s" % (self.__name__, __name)) - def __iter__(self) -> Iterator[str]: + def __iter__(self) -> Iterator[Tuple[str, Any]]: def aux(): for t in self.__dict__.items(): if not t[0].startswith("_"): diff --git a/ddtrace/appsec/_handlers.py b/ddtrace/appsec/_handlers.py index 5b8051d0e5b..a815edaf360 100644 --- a/ddtrace/appsec/_handlers.py +++ b/ddtrace/appsec/_handlers.py @@ -7,15 +7,21 @@ from wrapt import wrap_function_wrapper as _w import xmltodict +from ddtrace.appsec._asm_request_context import get_blocked from ddtrace.appsec._constants import SPAN_DATA_NAMES from ddtrace.appsec._iast._patch import if_iast_taint_returned_object_for from ddtrace.appsec._iast._patch import if_iast_taint_yield_tuple_for from ddtrace.appsec._iast._utils import _is_iast_enabled from ddtrace.contrib import trace_utils +from ddtrace.contrib.trace_utils import _get_request_header_user_agent +from ddtrace.contrib.trace_utils import _set_url_tag from ddtrace.ext import SpanTypes +from ddtrace.ext import http from ddtrace.internal import core from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED +from ddtrace.internal.constants import RESPONSE_HEADERS from ddtrace.internal.logger import get_logger +from ddtrace.internal.utils import http as http_utils from ddtrace.internal.utils.http import parse_form_multipart from ddtrace.settings.asm import config as asm_config @@ -89,10 +95,7 @@ def _on_set_http_meta( ] for k, v in addresses: if v is not None: - set_waf_address(k, v, span) - - -core.on("set_http_meta_for_asm", _on_set_http_meta) + set_waf_address(k, v) # ASGI @@ -278,10 +281,6 @@ def _on_flask_patch(flask_version): _set_metric_iast_instrumented_source(OriginType.QUERY) -def _on_flask_blocked_request(_): - core.set_item(HTTP_REQUEST_BLOCKED, True) - - def _on_django_func_wrapped(fn_args, fn_kwargs, first_arg_expected_type, *_): # If IAST is enabled and we're wrapping a Django view call, taint the kwargs (view's # path parameters) @@ -482,19 +481,131 @@ def _on_grpc_server_data(headers, request_message, method, metadata): set_waf_address(SPAN_DATA_NAMES.GRPC_SERVER_REQUEST_METADATA, dict(metadata)) +def _wsgi_make_block_content(ctx, construct_url): + middleware = ctx.get_item("middleware") + req_span = ctx.get_item("req_span") + headers = ctx.get_item("headers") + environ = ctx.get_item("environ") + if req_span is None: + raise ValueError("request span not found") + block_config = get_blocked() + desired_type = block_config.get("type", "auto") + ctype = None + if desired_type == "none": + content = "" + resp_headers = [("content-type", "text/plain; charset=utf-8"), ("location", block_config.get("location", ""))] + else: + ctype = block_config.get("content-type", "application/json") + content = http_utils._get_blocked_template(ctype).encode("UTF-8") + resp_headers = [("content-type", ctype)] + status = block_config.get("status_code", 403) + try: + req_span.set_tag_str(RESPONSE_HEADERS + ".content-length", str(len(content))) + if ctype is not None: + req_span.set_tag_str(RESPONSE_HEADERS + ".content-type", ctype) + req_span.set_tag_str(http.STATUS_CODE, str(status)) + url = construct_url(environ) + query_string = environ.get("QUERY_STRING") + _set_url_tag(middleware._config, req_span, url, query_string) + if query_string and middleware._config.trace_query_string: + req_span.set_tag_str(http.QUERY_STRING, query_string) + method = environ.get("REQUEST_METHOD") + if method: + req_span.set_tag_str(http.METHOD, method) + user_agent = _get_request_header_user_agent(headers, headers_are_case_sensitive=True) + if user_agent: + req_span.set_tag_str(http.USER_AGENT, user_agent) + except Exception as e: + log.warning("Could not set some span tags on blocked request: %s", str(e)) # noqa: G200 + resp_headers.append(("Content-Length", str(len(content)))) + return status, resp_headers, content + + +def _asgi_make_block_content(ctx, url): + middleware = ctx.get_item("middleware") + req_span = ctx.get_item("req_span") + headers = ctx.get_item("headers") + environ = ctx.get_item("environ") + if req_span is None: + raise ValueError("request span not found") + block_config = get_blocked() + desired_type = block_config.get("type", "auto") + ctype = None + if desired_type == "none": + content = "" + resp_headers = [ + (b"content-type", b"text/plain; charset=utf-8"), + (b"location", block_config.get("location", "").encode()), + ] + else: + ctype = block_config.get("content-type", "application/json") + content = http_utils._get_blocked_template(ctype).encode("UTF-8") + # ctype = f"{ctype}; charset=utf-8" can be considered at some point + resp_headers = [(b"content-type", ctype.encode())] + status = block_config.get("status_code", 403) + try: + req_span.set_tag_str(RESPONSE_HEADERS + ".content-length", str(len(content))) + if ctype is not None: + req_span.set_tag_str(RESPONSE_HEADERS + ".content-type", ctype) + req_span.set_tag_str(http.STATUS_CODE, str(status)) + query_string = environ.get("QUERY_STRING") + _set_url_tag(middleware.integration_config, req_span, url, query_string) + if query_string and middleware._config.trace_query_string: + req_span.set_tag_str(http.QUERY_STRING, query_string) + method = environ.get("REQUEST_METHOD") + if method: + req_span.set_tag_str(http.METHOD, method) + user_agent = _get_request_header_user_agent(headers, headers_are_case_sensitive=True) + if user_agent: + req_span.set_tag_str(http.USER_AGENT, user_agent) + except Exception as e: + log.warning("Could not set some span tags on blocked request: %s", str(e)) # noqa: G200 + resp_headers.append((b"Content-Length", str(len(content)).encode())) + return status, resp_headers, content + + +def _on_flask_blocked_request(span): + core.set_item(HTTP_REQUEST_BLOCKED, True) + span.set_tag_str(http.STATUS_CODE, "403") + request = core.get_item("flask_request") + try: + base_url = getattr(request, "base_url", None) + query_string = getattr(request, "query_string", None) + if base_url and query_string: + _set_url_tag(core.get_item("flask_config"), span, base_url, query_string) + if query_string and core.get_item("flask_config").trace_query_string: + span.set_tag_str(http.QUERY_STRING, query_string) + if request.method is not None: + span.set_tag_str(http.METHOD, request.method) + user_agent = _get_request_header_user_agent(request.headers) + if user_agent: + span.set_tag_str(http.USER_AGENT, user_agent) + except Exception as e: + log.warning("Could not set some span tags on blocked request: %s", str(e)) # noqa: G200 + + +def _on_start_response_blocked(ctx, flask_config, response_headers, status): + trace_utils.set_http_meta(ctx["req_span"], flask_config, status_code=status, response_headers=response_headers) + + def listen(): + core.on("set_http_meta_for_asm", _on_set_http_meta) core.on("flask.request_call_modifier", _on_request_span_modifier, "request_body") core.on("flask.request_init", _on_request_init) core.on("flask.blocked_request_callable", _on_flask_blocked_request) + core.on("flask.start_response.blocked", _on_start_response_blocked) + + core.on("django.func.wrapped", _on_django_func_wrapped) + core.on("django.wsgi_environ", _on_wsgi_environ, "wrapped_result") + core.on("django.patch", _on_django_patch) + core.on("flask.patch", _on_flask_patch) -core.on("django.func.wrapped", _on_django_func_wrapped) -core.on("django.wsgi_environ", _on_wsgi_environ, "wrapped_result") -core.on("django.patch", _on_django_patch) -core.on("flask.patch", _on_flask_patch) + core.on("asgi.request.parse.body", _on_asgi_request_parse_body, "await_receive_and_body") -core.on("asgi.request.parse.body", _on_asgi_request_parse_body, "await_receive_and_body") + core.on("grpc.client.response.message", _on_grpc_response) + core.on("grpc.server.response.message", _on_grpc_server_response) + core.on("grpc.server.data", _on_grpc_server_data) -core.on("grpc.client.response.message", _on_grpc_response) -core.on("grpc.server.response.message", _on_grpc_server_response) -core.on("grpc.server.data", _on_grpc_server_data) + core.on("wsgi.block.started", _wsgi_make_block_content, "status_headers_content") + core.on("asgi.block.started", _asgi_make_block_content, "status_headers_content") diff --git a/ddtrace/appsec/_iast/processor.py b/ddtrace/appsec/_iast/processor.py index 92001f2f026..0a04e8b2be1 100644 --- a/ddtrace/appsec/_iast/processor.py +++ b/ddtrace/appsec/_iast/processor.py @@ -3,6 +3,7 @@ from ddtrace._trace.processor import SpanProcessor from ddtrace._trace.span import Span +from ddtrace.appsec import load_appsec from ddtrace.appsec._constants import APPSEC from ddtrace.appsec._constants import IAST from ddtrace.constants import ORIGIN_KEY @@ -84,3 +85,6 @@ def on_span_finish(self, span: Span): span.set_tag_str(ORIGIN_KEY, APPSEC.ORIGIN_VALUE) oce.release_request() + + +load_appsec() diff --git a/ddtrace/appsec/_processor.py b/ddtrace/appsec/_processor.py index b2e9e7b58f1..ab7a74ee99b 100644 --- a/ddtrace/appsec/_processor.py +++ b/ddtrace/appsec/_processor.py @@ -17,13 +17,13 @@ from ddtrace._trace.processor import SpanProcessor from ddtrace._trace.span import Span from ddtrace.appsec import _asm_request_context +from ddtrace.appsec import load_appsec from ddtrace.appsec._constants import APPSEC from ddtrace.appsec._constants import DEFAULT from ddtrace.appsec._constants import EXPLOIT_PREVENTION from ddtrace.appsec._constants import FINGERPRINTING from ddtrace.appsec._constants import SPAN_DATA_NAMES from ddtrace.appsec._constants import WAF_ACTIONS -from ddtrace.appsec._constants import WAF_CONTEXT_NAMES from ddtrace.appsec._constants import WAF_DATA_NAMES from ddtrace.appsec._ddwaf import DDWaf_result from ddtrace.appsec._ddwaf.ddwaf_types import ddwaf_context_capsule @@ -143,6 +143,7 @@ def enabled(self): def __post_init__(self) -> None: from ddtrace.appsec._ddwaf import DDWaf + load_appsec() self.obfuscation_parameter_key_regexp = asm_config._asm_obfuscation_parameter_key_regexp.encode() self.obfuscation_parameter_value_regexp = asm_config._asm_obfuscation_parameter_value_regexp.encode() self._rules = None @@ -236,12 +237,7 @@ def on_span_start(self, span: Span) -> None: if span.span_type not in {SpanTypes.WEB, SpanTypes.GRPC}: return - if _asm_request_context.free_context_available(): - _asm_request_context.register(span) - else: - new_asm_context = _asm_request_context.asm_request_context_manager() - new_asm_context.__enter__() - _asm_request_context.register(span, new_asm_context) + _asm_request_context.start_context(span) ctx = self._ddwaf._at_request_start() self._span_to_waf_ctx[span] = ctx @@ -258,19 +254,18 @@ def waf_callable(custom_data=None, **kwargs): _asm_request_context.set_waf_callback(waf_callable) _asm_request_context.add_context_callback(_set_waf_request_metrics) if headers is not None: - _asm_request_context.set_waf_address(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES, headers, span) + _asm_request_context.set_waf_address(SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES, headers) _asm_request_context.set_waf_address( - SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES_CASE, headers_case_sensitive, span + SPAN_DATA_NAMES.REQUEST_HEADERS_NO_COOKIES_CASE, headers_case_sensitive ) if not peer_ip: return ip = trace_utils._get_request_header_client_ip(headers, peer_ip, headers_case_sensitive) # Save the IP and headers in the context so the retrieval can be skipped later - _asm_request_context.set_waf_address(SPAN_DATA_NAMES.REQUEST_HTTP_IP, ip, span) + _asm_request_context.set_waf_address(SPAN_DATA_NAMES.REQUEST_HTTP_IP, ip) if ip and self._is_needed(WAF_DATA_NAMES.REQUEST_HTTP_IP): log.debug("[DDAS-001-00] Executing ASM WAF for checking IP block") - # _asm_request_context.call_callback() _asm_request_context.call_waf_callback({"REQUEST_HTTP_IP": None}) def _waf_action( @@ -295,7 +290,7 @@ def _waf_action( if span.span_type not in (SpanTypes.WEB, SpanTypes.HTTP, SpanTypes.GRPC): return None - if core.get_item(WAF_CONTEXT_NAMES.BLOCKED, span=span) or core.get_item(WAF_CONTEXT_NAMES.BLOCKED): + if _asm_request_context.get_blocked(): # We still must run the waf if we need to extract schemas for API SECURITY if not custom_data or not custom_data.get("PROCESSOR_SETTINGS", {}).get("extract-schema", False): return None @@ -367,8 +362,7 @@ def _waf_action( waf_results.total_runtime, ) if blocked: - core.set_item(WAF_CONTEXT_NAMES.BLOCKED, blocked, span=span) - core.set_item(WAF_CONTEXT_NAMES.BLOCKED, blocked) + _asm_request_context.set_blocked(blocked) try: info = self._ddwaf.info @@ -441,9 +435,10 @@ def on_span_finish(self, span: Span) -> None: _asm_request_context.call_waf_callback() self._ddwaf._at_request_end() + _asm_request_context.end_context(span) finally: - # release asm context if it was created by the span - _asm_request_context.unregister(span) + # release asm context associated with that span if it was not already done + _asm_request_context.end_context(span) if span.span_type not in {SpanTypes.WEB, SpanTypes.GRPC}: return diff --git a/ddtrace/appsec/_trace_utils.py b/ddtrace/appsec/_trace_utils.py index 9d0db0ab40b..0c84d45bca1 100644 --- a/ddtrace/appsec/_trace_utils.py +++ b/ddtrace/appsec/_trace_utils.py @@ -4,9 +4,9 @@ from ddtrace import constants from ddtrace._trace.span import Span from ddtrace.appsec import _asm_request_context +from ddtrace.appsec._asm_request_context import get_blocked from ddtrace.appsec._constants import APPSEC from ddtrace.appsec._constants import LOGIN_EVENTS_MODE -from ddtrace.appsec._constants import WAF_CONTEXT_NAMES from ddtrace.appsec._utils import _hash_user_id from ddtrace.contrib.trace_utils import set_user from ddtrace.ext import SpanTypes @@ -246,22 +246,12 @@ def should_block_user(tracer: Tracer, userid: str) -> bool: ) return False - # Early check to avoid calling the WAF if the request is already blocked - span = tracer.current_root_span() - if not span: - log.warning( - "No root span in the current execution. should_block_user returning False" - "See https://docs.datadoghq.com/security_platform/application_security" - "/setup_and_configure/" - "?tab=set_user&code-lang=python for more information.", - ) - return False - - if core.get_item(WAF_CONTEXT_NAMES.BLOCKED, span=span): + # Early check to avoid calling the WAF if the request is already blockedxw + if get_blocked(): return True _asm_request_context.call_waf_callback(custom_data={"REQUEST_USER_ID": str(userid)}) - return bool(core.get_item(WAF_CONTEXT_NAMES.BLOCKED, span=span)) + return bool(get_blocked()) def block_request() -> None: diff --git a/ddtrace/contrib/internal/asgi/middleware.py b/ddtrace/contrib/internal/asgi/middleware.py index 765f50044d4..993f2500bd6 100644 --- a/ddtrace/contrib/internal/asgi/middleware.py +++ b/ddtrace/contrib/internal/asgi/middleware.py @@ -23,6 +23,7 @@ from ddtrace.internal.logger import get_logger from ddtrace.internal.schema import schematize_url_operation from ddtrace.internal.schema.span_attribute_schema import SpanDirection +from ddtrace.internal.utils import get_blocked log = get_logger(__name__) @@ -138,20 +139,19 @@ async def __call__(self, scope, receive, send): operation_name = schematize_url_operation(operation_name, direction=SpanDirection.INBOUND, protocol="http") pin = ddtrace.pin.Pin(service="asgi", tracer=self.tracer) - with pin.tracer.trace( - name=operation_name, - service=trace_utils.int_service(None, self.integration_config), - resource=resource, - span_type=SpanTypes.WEB, - ) as span, core.context_with_data( + with core.context_with_data( "asgi.__call__", remote_addr=scope.get("REMOTE_ADDR"), headers=headers, headers_case_sensitive=True, environ=scope, middleware=self, - span=span, - ) as ctx: + span_name=operation_name, + resource=resource, + span_type=SpanTypes.WEB, + service=trace_utils.int_service(None, self.integration_config), + pin=pin, + ) as ctx, ctx.get_item("call") as span: span.set_tag_str(COMPONENT, self.integration_config.integration_name) ctx.set_item("req_span", span) @@ -242,9 +242,9 @@ async def wrapped_send(message): ) core.dispatch("asgi.start_response", ("asgi",)) core.dispatch("asgi.finalize_response", (message.get("body"), response_headers)) - - if core.get_item(HTTP_REQUEST_BLOCKED): - raise trace_utils.InterruptException("wrapped_send") + blocked = get_blocked() + if blocked: + raise BlockingException(blocked) try: return await send(message) finally: @@ -287,12 +287,11 @@ async def wrapped_blocked_send(message): try: core.dispatch("asgi.start_request", ("asgi",)) + # Do not block right here. Wait for route to be resolved in starlette/patch.py return await self.app(scope, receive, wrapped_send) except BlockingException as e: core.set_item(HTTP_REQUEST_BLOCKED, e.args[0]) return await _blocked_asgi_app(scope, receive, wrapped_blocked_send) - except trace_utils.InterruptException: - return await _blocked_asgi_app(scope, receive, wrapped_blocked_send) except Exception as exc: (exc_type, exc_val, exc_tb) = sys.exc_info() span.set_exc_info(exc_type, exc_val, exc_tb) diff --git a/ddtrace/contrib/internal/django/patch.py b/ddtrace/contrib/internal/django/patch.py index 5fcaa325099..f0835d50945 100644 --- a/ddtrace/contrib/internal/django/patch.py +++ b/ddtrace/contrib/internal/django/patch.py @@ -19,7 +19,6 @@ from ddtrace import Pin from ddtrace import config -from ddtrace._trace.trace_handlers import _ctype_from_headers from ddtrace.appsec._utils import _UserInfoRetriever from ddtrace.constants import SPAN_KIND from ddtrace.contrib import dbapi @@ -36,15 +35,15 @@ from ddtrace.internal.compat import Iterable from ddtrace.internal.compat import maybe_stringify from ddtrace.internal.constants import COMPONENT -from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED -from ddtrace.internal.constants import STATUS_403_TYPE_AUTO from ddtrace.internal.core.event_hub import ResultType from ddtrace.internal.logger import get_logger from ddtrace.internal.schema import schematize_service_name from ddtrace.internal.schema import schematize_url_operation from ddtrace.internal.schema.span_attribute_schema import SpanDirection from ddtrace.internal.utils import get_argument_value +from ddtrace.internal.utils import get_blocked from ddtrace.internal.utils import http as http_utils +from ddtrace.internal.utils import set_blocked from ddtrace.internal.utils.formats import asbool from ddtrace.internal.utils.importlib import func_name from ddtrace.propagation._database_monitoring import _DBM_Propagator @@ -443,7 +442,7 @@ def _block_request_callable(request, request_headers, ctx: core.ExecutionContext # at any point so it's a callable stored in the ASM context. from django.core.exceptions import PermissionDenied - core.root.set_item(HTTP_REQUEST_BLOCKED, STATUS_403_TYPE_AUTO) + set_blocked() _gather_block_metadata(request, request_headers, ctx) raise PermissionDenied() @@ -496,7 +495,7 @@ def traced_get_response(django, pin, func, instance, args, kwargs): def blocked_response(): from django.http import HttpResponse - block_config = core.get_item(HTTP_REQUEST_BLOCKED) or {} + block_config = get_blocked() or {} desired_type = block_config.get("type", "auto") status = block_config.get("status_code", 403) if desired_type == "none": @@ -505,7 +504,7 @@ def blocked_response(): if location: response["location"] = location else: - ctype = _ctype_from_headers(block_config, request_headers) + ctype = block_config.get("content-type", "application/json") content = http_utils._get_blocked_template(ctype) response = HttpResponse(content, content_type=ctype, status=status) response.content = content @@ -514,7 +513,7 @@ def blocked_response(): return response try: - if core.get_item(HTTP_REQUEST_BLOCKED): + if get_blocked(): response = blocked_response() return response @@ -535,27 +534,27 @@ def blocked_response(): ) core.dispatch("django.start_response.post", ("Django",)) - if core.get_item(HTTP_REQUEST_BLOCKED): + if get_blocked(): response = blocked_response() return response try: response = func(*args, **kwargs) except BlockingException as e: - core.set_item(HTTP_REQUEST_BLOCKED, e.args[0]) + set_blocked(e.args[0]) response = blocked_response() return response - if core.get_item(HTTP_REQUEST_BLOCKED): + if get_blocked(): response = blocked_response() return response return response finally: core.dispatch("django.finalize_response.pre", (ctx, utils._after_request_tags, request, response)) - if not core.get_item(HTTP_REQUEST_BLOCKED): + if not get_blocked(): core.dispatch("django.finalize_response", ("Django",)) - if core.get_item(HTTP_REQUEST_BLOCKED): + if get_blocked(): response = blocked_response() return response # noqa: B012 diff --git a/ddtrace/contrib/internal/flask/patch.py b/ddtrace/contrib/internal/flask/patch.py index ba02d966af2..7b1f1f7c127 100644 --- a/ddtrace/contrib/internal/flask/patch.py +++ b/ddtrace/contrib/internal/flask/patch.py @@ -6,18 +6,17 @@ from werkzeug.exceptions import NotFound from werkzeug.exceptions import abort -from ddtrace._trace.trace_handlers import _ctype_from_headers from ddtrace.contrib import trace_utils from ddtrace.ext import SpanTypes from ddtrace.internal import core from ddtrace.internal.constants import COMPONENT -from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED -from ddtrace.internal.constants import STATUS_403_TYPE_AUTO from ddtrace.internal.packages import get_version_for_package from ddtrace.internal.schema import schematize_service_name from ddtrace.internal.schema import schematize_url_operation from ddtrace.internal.schema.span_attribute_schema import SpanDirection +from ddtrace.internal.utils import get_blocked from ddtrace.internal.utils import http as http_utils +from ddtrace.internal.utils import set_blocked # Not all versions of flask/werkzeug have this mixin @@ -110,25 +109,22 @@ class _FlaskWSGIMiddleware(_DDWSGIMiddlewareBase): def _wrapped_start_response(self, start_response, ctx, status_code, headers, exc_info=None): core.dispatch("flask.start_response.pre", (flask.request, ctx, config.flask, status_code, headers)) - if not core.get_item(HTTP_REQUEST_BLOCKED): - headers_from_context = "" - result_waf = core.dispatch_with_results("flask.start_response", ("Flask",)).waf - if result_waf: - headers_from_context = result_waf.value - if core.get_item(HTTP_REQUEST_BLOCKED): + if not get_blocked(): + core.dispatch("flask.start_response", ("Flask",)) + if get_blocked(): # response code must be set here, or it will be too late result_content = core.dispatch_with_results("flask.block.request.content", ()).block_requested if result_content: _, status, response_headers = result_content.value result = start_response(str(status), response_headers) else: - block_config = core.get_item(HTTP_REQUEST_BLOCKED) + block_config = get_blocked() desired_type = block_config.get("type", "auto") status = block_config.get("status_code", 403) if desired_type == "none": response_headers = [] else: - ctype = _ctype_from_headers(block_config, headers_from_context) + ctype = block_config.get("content-type", "application/json") response_headers = [("content-type", ctype)] result = start_response(str(status), response_headers) core.dispatch("flask.start_response.blocked", (ctx, config.flask, response_headers, status)) @@ -518,9 +514,9 @@ def _wrap(code_or_exception, f): def _block_request_callable(call): - core.set_item(HTTP_REQUEST_BLOCKED, STATUS_403_TYPE_AUTO) + set_blocked() core.dispatch("flask.blocked_request_callable", (call,)) - ctype = _ctype_from_headers(STATUS_403_TYPE_AUTO, flask.request.headers) + ctype = get_blocked().get("content-type", "application/json") abort(flask.Response(http_utils._get_blocked_template(ctype), content_type=ctype, status=403)) diff --git a/ddtrace/contrib/internal/starlette/patch.py b/ddtrace/contrib/internal/starlette/patch.py index 2a944a813d3..b872a77ecd7 100644 --- a/ddtrace/contrib/internal/starlette/patch.py +++ b/ddtrace/contrib/internal/starlette/patch.py @@ -21,10 +21,11 @@ from ddtrace.contrib.trace_utils import with_traced_module from ddtrace.ext import http from ddtrace.internal import core -from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED +from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.logger import get_logger from ddtrace.internal.schema import schematize_service_name from ddtrace.internal.utils import get_argument_value +from ddtrace.internal.utils import get_blocked from ddtrace.internal.utils import set_argument_value from ddtrace.internal.utils.wrappers import unwrap as _u from ddtrace.vendor.packaging.version import parse as parse_version @@ -180,8 +181,9 @@ def traced_handler(wrapped, instance, args, kwargs): route=request_spans[0].get_tag(http.ROUTE), ) core.dispatch("asgi.start_request", ("starlette",)) - if core.get_item(HTTP_REQUEST_BLOCKED): - raise trace_utils.InterruptException("starlette") + blocked = get_blocked() + if blocked: + raise BlockingException(blocked) # https://github.com/encode/starlette/issues/1336 if _STARLETTE_VERSION <= parse_version("0.33.0") and len(request_spans) > 1: diff --git a/ddtrace/contrib/internal/wsgi/wsgi.py b/ddtrace/contrib/internal/wsgi/wsgi.py index 331adafbd25..4de10f5d503 100644 --- a/ddtrace/contrib/internal/wsgi/wsgi.py +++ b/ddtrace/contrib/internal/wsgi/wsgi.py @@ -29,9 +29,10 @@ from ddtrace.internal import core from ddtrace.internal._exceptions import BlockingException from ddtrace.internal.constants import COMPONENT -from ddtrace.internal.constants import HTTP_REQUEST_BLOCKED from ddtrace.internal.logger import get_logger from ddtrace.internal.schema import schematize_url_operation +from ddtrace.internal.utils import get_blocked +from ddtrace.internal.utils import set_blocked from ddtrace.propagation._utils import from_wsgi_header from ddtrace.propagation.http import HTTPPropagator @@ -119,7 +120,7 @@ def blocked_view(): status, headers, content = 403, [], "" return content, status, headers - if core.get_item(HTTP_REQUEST_BLOCKED): + if get_blocked(): content, status, headers = blocked_view() start_response(str(status), headers) closing_iterable = [content] @@ -133,7 +134,7 @@ def blocked_view(): try: closing_iterable = self.app(environ, ctx.get_item("intercept_start_response")) except BlockingException as e: - core.set_item(HTTP_REQUEST_BLOCKED, e.args[0]) + set_blocked(e.args[0]) content, status, headers = blocked_view() start_response(str(status), headers) closing_iterable = [content] @@ -153,7 +154,7 @@ def blocked_view(): core.dispatch("wsgi.app.exception", (ctx,)) raise else: - if core.get_item(HTTP_REQUEST_BLOCKED): + if get_blocked(): _, _, content = core.dispatch_with_results( "wsgi.block.started", (ctx, construct_url) ).status_headers_content.value or (None, None, "") diff --git a/ddtrace/internal/core/__init__.py b/ddtrace/internal/core/__init__.py index f31e7d0cb7d..672634f2e0d 100644 --- a/ddtrace/internal/core/__init__.py +++ b/ddtrace/internal/core/__init__.py @@ -101,7 +101,7 @@ def _on_jsonify_context_started_flask(ctx): The names of these events follow the pattern ``context.[started|ended].``. """ -from contextlib import contextmanager +from contextlib import AbstractContextManager import logging import sys from typing import TYPE_CHECKING # noqa:F401 @@ -163,23 +163,24 @@ def _deprecate_span_kwarg(span): ) -class ExecutionContext: - __slots__ = ["identifier", "_data", "_parents", "_span", "_token"] - +class ExecutionContext(AbstractContextManager): def __init__(self, identifier, parent=None, span=None, **kwargs): _deprecate_span_kwarg(span) self.identifier = identifier self._data = {} self._parents = [] self._span = span + self._suppress_exceptions = [] if parent is not None: self.addParent(parent) self._data.update(kwargs) + def __enter__(self): if self._span is None and "_CURRENT_CONTEXT" in globals(): self._token = _CURRENT_CONTEXT.set(self) dispatch("context.started.%s" % self.identifier, (self,)) dispatch("context.started.start_span.%s" % self.identifier, (self,)) + return self def __repr__(self): return self.__class__.__name__ + " '" + self.identifier + "' @ " + str(id(self)) @@ -192,8 +193,8 @@ def parents(self): def parent(self): return self._parents[0] if self._parents else None - def end(self): - dispatch_result = dispatch_with_results("context.ended.%s" % self.identifier, (self,)) + def __exit__(self, exc_type, exc_value, traceback): + dispatch("context.ended.%s" % self.identifier, (self,)) if self._span is None: try: _CURRENT_CONTEXT.reset(self._token) @@ -209,22 +210,18 @@ def end(self): ) if id(self) in DEPRECATION_MEMO: DEPRECATION_MEMO.remove(id(self)) - return dispatch_result + + return ( + True + if exc_type is None + else any(issubclass(exc_type, exc_type_) for exc_type_ in self._suppress_exceptions) + ) def addParent(self, context): if self.identifier == ROOT_CONTEXT_ID: raise ValueError("Cannot add parent to root context") self._parents.append(context) - @classmethod - @contextmanager - def context_with_data(cls, identifier, parent=None, span=None, **kwargs): - new_context = cls(identifier, parent=parent, span=span, **kwargs) - try: - yield new_context - finally: - new_context.end() - def get_item(current, data_key: str, default: Optional[Any] = None) -> Any: # NB mimic the behavior of `ddtrace.internal._context` by doing lazy inheritance while current is not None: @@ -294,15 +291,18 @@ def _reset_context(): def context_with_data(identifier, parent=None, **kwargs): - return _CONTEXT_CLASS.context_with_data(identifier, parent=(parent or _CURRENT_CONTEXT.get()), **kwargs) + return _CONTEXT_CLASS(identifier, parent=(parent or _CURRENT_CONTEXT.get()), **kwargs) + + +def add_suppress_exception(exc_type: type) -> None: + _CURRENT_CONTEXT.get()._suppress_exceptions.append(exc_type) def get_item(data_key: str, span: Optional["Span"] = None) -> Any: _deprecate_span_kwarg(span) if span is not None and span._local_root is not None: return span._local_root._get_ctx_item(data_key) - else: - return _CURRENT_CONTEXT.get().get_item(data_key) + return _CURRENT_CONTEXT.get().get_item(data_key) def get_local_item(data_key: str, span: Optional["Span"] = None) -> Any: @@ -313,8 +313,7 @@ def get_items(data_keys: List[str], span: Optional["Span"] = None) -> List[Optio _deprecate_span_kwarg(span) if span is not None and span._local_root is not None: return [span._local_root._get_ctx_item(key) for key in data_keys] - else: - return _CURRENT_CONTEXT.get().get_items(data_keys) + return _CURRENT_CONTEXT.get().get_items(data_keys) def set_safe(data_key: str, data_value: Optional[Any]) -> None: diff --git a/ddtrace/internal/utils/__init__.py b/ddtrace/internal/utils/__init__.py index 94631bbad60..294e99d1263 100644 --- a/ddtrace/internal/utils/__init__.py +++ b/ddtrace/internal/utils/__init__.py @@ -79,3 +79,21 @@ def _get_metas_to_propagate(context): if isinstance(k, str) and k.startswith("_dd.p."): metas_to_propagate.append((k, v)) return metas_to_propagate + + +def get_blocked() -> Optional[Dict[str, Any]]: + # local import to avoid circular dependency + from ddtrace.internal import core + + res = core.dispatch_with_results("asm.get_blocked") + if res and res.block_config: + return res.block_config.value + return None + + +def set_blocked(block_settings: Optional[Dict[str, Any]] = None) -> None: + # local imports to avoid circular dependency + from ddtrace.internal import core + from ddtrace.internal.constants import STATUS_403_TYPE_AUTO + + core.dispatch("asm.set_blocked", (block_settings or STATUS_403_TYPE_AUTO,)) diff --git a/ddtrace/settings/asm.py b/ddtrace/settings/asm.py index d3fb62b6c1e..b871c595a33 100644 --- a/ddtrace/settings/asm.py +++ b/ddtrace/settings/asm.py @@ -213,6 +213,11 @@ def __init__(self): # Only for deprecation phase if self._auto_user_instrumentation_local_mode == "": self._auto_user_instrumentation_local_mode = self._automatic_login_events_mode or LOGIN_EVENTS_MODE.IDENT + if not self._asm_libddwaf_available: + self._asm_enabled = False + self._asm_can_be_enabled = False + self._iast_enabled = False + self._api_security_enabled = False def reset(self): """For testing puposes, reset the configuration to its default values given current environment variables.""" @@ -233,9 +238,3 @@ def _user_event_mode(self) -> str: config = ASMConfig() _report_telemetry(config) - -if not config._asm_libddwaf_available: - config._asm_enabled = False - config._asm_can_be_enabled = False - config._iast_enabled = False - config._api_security_enabled = False diff --git a/tests/appsec/appsec/test_appsec_trace_utils.py b/tests/appsec/appsec/test_appsec_trace_utils.py index 7cdc89747ee..782f0e9869c 100644 --- a/tests/appsec/appsec/test_appsec_trace_utils.py +++ b/tests/appsec/appsec/test_appsec_trace_utils.py @@ -12,25 +12,27 @@ from ddtrace.appsec.trace_utils import track_user_login_success_event from ddtrace.appsec.trace_utils import track_user_signup_event from ddtrace.contrib.trace_utils import set_user -from ddtrace.ext import SpanTypes from ddtrace.ext import user from ddtrace.internal import core -from tests.appsec.appsec.test_processor import tracer_appsec # noqa: F401 import tests.appsec.rules as rules +from tests.appsec.utils import asm_context from tests.utils import TracerTestCase -from tests.utils import override_global_config + + +config_asm = {"_asm_enabled": True} +config_good_rules = {"_asm_static_rule_file": rules.RULES_GOOD_PATH, "_asm_enabled": True} class EventsSDKTestCase(TracerTestCase): _BLOCKED_USER = "123456" @pytest.fixture(autouse=True) - def inject_fixtures(self, tracer_appsec, caplog): # noqa: F811 + def inject_fixtures(self, tracer, caplog): # noqa: F811 self._caplog = caplog - self._tracer_appsec = tracer_appsec + self.tracer = tracer def test_track_user_login_event_success_without_metadata(self): - with self.trace("test_success1"): + with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm): track_user_login_success_event( self.tracer, "1234", @@ -59,7 +61,7 @@ def test_track_user_login_event_success_without_metadata(self): assert root_span.get_tag(user.SESSION_ID) == "test_session_id" def test_track_user_login_event_success_in_span_without_metadata(self): - with self.trace("test_success1") as parent_span: + with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm) as parent_span: user_span = self.trace("user_span") user_span.parent_id = parent_span.span_id track_user_login_success_event( @@ -93,7 +95,7 @@ def test_track_user_login_event_success_in_span_without_metadata(self): ) def test_track_user_login_event_success_auto_mode_safe(self): - with self.trace("test_success1"): + with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm): track_user_login_success_event( self.tracer, "1234", @@ -113,7 +115,7 @@ def test_track_user_login_event_success_auto_mode_safe(self): assert root_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) == str(LOGIN_EVENTS_MODE.ANON) def test_track_user_login_event_success_auto_mode_extended(self): - with self.trace("test_success1"): + with asm_context(tracer=self.tracer, span_name="test_success1", config=config_asm): track_user_login_success_event( self.tracer, "1234", @@ -133,7 +135,7 @@ def test_track_user_login_event_success_auto_mode_extended(self): assert root_span.get_tag(APPSEC.AUTO_LOGIN_EVENTS_SUCCESS_MODE) == str(LOGIN_EVENTS_MODE.IDENT) def test_track_user_login_event_success_with_metadata(self): - with self.trace("test_success2"): + with asm_context(tracer=self.tracer, span_name="test_success2", config=config_asm): track_user_login_success_event(self.tracer, "1234", metadata={"foo": "bar"}) root_span = self.tracer.current_root_span() assert root_span.get_tag("appsec.events.users.login.success.track") == "true" @@ -150,7 +152,7 @@ def test_track_user_login_event_success_with_metadata(self): assert not root_span.get_tag(user.SESSION_ID) def test_track_user_login_event_failure_user_exists(self): - with self.trace("test_failure"): + with asm_context(tracer=self.tracer, span_name="test_failure", config=config_asm): track_user_login_failure_event( self.tracer, "1234", @@ -216,27 +218,24 @@ def test_custom_event(self): assert root_span.get_tag("%s.%s.track" % (APPSEC.CUSTOM_EVENT_PREFIX, event)) == "true" def test_set_user_blocked(self): - tracer = self._tracer_appsec - with override_global_config(dict(_asm_enabled="true", _asm_static_rule_file=rules.RULES_GOOD_PATH)): - tracer.configure(api_version="v0.4") - with tracer.trace("fake_span", span_type=SpanTypes.WEB) as span: - set_user( - self.tracer, - user_id=self._BLOCKED_USER, - email="usr.email", - name="usr.name", - session_id="usr.session_id", - role="usr.role", - scope="usr.scope", - ) - assert span.get_tag(user.ID) - assert span.get_tag(user.EMAIL) - assert span.get_tag(user.SESSION_ID) - assert span.get_tag(user.NAME) - assert span.get_tag(user.ROLE) - assert span.get_tag(user.SCOPE) - assert span.get_tag(user.SESSION_ID) - assert core.get_item("http.request.blocked", span=span) + with asm_context(tracer=self.tracer, span_name="fake_span", config=config_good_rules) as span: + set_user( + self.tracer, + user_id=self._BLOCKED_USER, + email="usr.email", + name="usr.name", + session_id="usr.session_id", + role="usr.role", + scope="usr.scope", + ) + assert span.get_tag(user.ID) + assert span.get_tag(user.EMAIL) + assert span.get_tag(user.SESSION_ID) + assert span.get_tag(user.NAME) + assert span.get_tag(user.ROLE) + assert span.get_tag(user.SCOPE) + assert span.get_tag(user.SESSION_ID) + assert core.get_item("http.request.blocked", span=span) def test_no_span_doesnt_raise(self): from ddtrace import tracer diff --git a/tests/appsec/appsec/test_asm_request_context.py b/tests/appsec/appsec/test_asm_request_context.py index 487401f00ed..ad7fb102cee 100644 --- a/tests/appsec/appsec/test_asm_request_context.py +++ b/tests/appsec/appsec/test_asm_request_context.py @@ -2,60 +2,67 @@ from ddtrace.appsec import _asm_request_context from ddtrace.internal._exceptions import BlockingException -from tests.utils import override_global_config +from tests.appsec.utils import asm_context _TEST_IP = "1.2.3.4" _TEST_HEADERS = {"foo": "bar"} +config_asm = {"_asm_enabled": True} + def test_context_set_and_reset(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS, True, lambda: True): - assert _asm_request_context.get_ip() == _TEST_IP - assert _asm_request_context.get_headers() == _TEST_HEADERS - assert _asm_request_context.get_headers_case_sensitive() - assert _asm_request_context.get_value("callbacks", "block") is not None - assert _asm_request_context.get_ip() is None - assert _asm_request_context.get_headers() == {} - assert _asm_request_context.get_value("callbacks", "block") is None + with asm_context( + ip_addr=_TEST_IP, + headers=_TEST_HEADERS, + headers_case_sensitive=True, + block_request_callable=(lambda: True), + config=config_asm, + ): + assert _asm_request_context.get_ip() == _TEST_IP + assert _asm_request_context.get_headers() == _TEST_HEADERS + assert _asm_request_context.get_headers_case_sensitive() + assert _asm_request_context.get_value("callbacks", "block") is not None + assert _asm_request_context.get_ip() is None + assert _asm_request_context.get_headers() == {} + assert _asm_request_context.get_value("callbacks", "block") is None + assert not _asm_request_context.get_headers_case_sensitive() + with asm_context( + ip_addr=_TEST_IP, + headers=_TEST_HEADERS, + config=config_asm, + ): assert not _asm_request_context.get_headers_case_sensitive() - with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS): - assert not _asm_request_context.get_headers_case_sensitive() - assert not _asm_request_context.block_request() + assert not _asm_request_context.block_request() def test_set_get_ip(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_ip(_TEST_IP) - assert _asm_request_context.get_ip() == _TEST_IP + with asm_context(config=config_asm): + _asm_request_context.set_ip(_TEST_IP) + assert _asm_request_context.get_ip() == _TEST_IP def test_set_get_headers(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_headers(_TEST_HEADERS) - assert _asm_request_context.get_headers() == _TEST_HEADERS + with asm_context(config=config_asm): + _asm_request_context.set_headers(_TEST_HEADERS) + assert _asm_request_context.get_headers() == _TEST_HEADERS def test_call_block_callable_none(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_block_request_callable(None) - assert not _asm_request_context.block_request() + with asm_context(config=config_asm): + _asm_request_context.set_block_request_callable(None) assert not _asm_request_context.block_request() + assert not _asm_request_context.block_request() def test_call_block_callable_noargs(): def _callable(): return 42 - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_block_request_callable(_callable) - assert _asm_request_context.get_value("callbacks", "block")() == 42 - assert not _asm_request_context.get_value("callbacks", "block") + with asm_context(config=config_asm): + _asm_request_context.set_block_request_callable(_callable) + assert _asm_request_context.get_value("callbacks", "block")() == 42 + assert not _asm_request_context.get_value("callbacks", "block") def test_call_block_callable_curried(): @@ -65,31 +72,34 @@ class TestException(Exception): def _callable(): raise TestException() - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_block_request_callable(_callable) - with pytest.raises(TestException): - assert _asm_request_context.block_request() + with asm_context(config=config_asm): + _asm_request_context.set_block_request_callable(_callable) + with pytest.raises(TestException): + assert _asm_request_context.block_request() def test_set_get_headers_case_sensitive(): # default reset value should be False assert not _asm_request_context.get_headers_case_sensitive() - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - _asm_request_context.set_headers_case_sensitive(True) - assert _asm_request_context.get_headers_case_sensitive() - _asm_request_context.set_headers_case_sensitive(False) - assert not _asm_request_context.get_headers_case_sensitive() + with asm_context(config=config_asm): + _asm_request_context.set_headers_case_sensitive(True) + assert _asm_request_context.get_headers_case_sensitive() + _asm_request_context.set_headers_case_sensitive(False) + assert not _asm_request_context.get_headers_case_sensitive() def test_asm_request_context_manager(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(_TEST_IP, _TEST_HEADERS, True, lambda: 42): - assert _asm_request_context.get_ip() == _TEST_IP - assert _asm_request_context.get_headers() == _TEST_HEADERS - assert _asm_request_context.get_headers_case_sensitive() - assert _asm_request_context.get_value("callbacks", "block")() == 42 + with asm_context( + ip_addr=_TEST_IP, + headers=_TEST_HEADERS, + headers_case_sensitive=True, + block_request_callable=(lambda: 42), + config=config_asm, + ): + assert _asm_request_context.get_ip() == _TEST_IP + assert _asm_request_context.get_headers() == _TEST_HEADERS + assert _asm_request_context.get_headers_case_sensitive() + assert _asm_request_context.get_value("callbacks", "block")() == 42 assert _asm_request_context.get_ip() is None assert _asm_request_context.get_headers() == {} @@ -98,16 +108,15 @@ def test_asm_request_context_manager(): def test_blocking_exception_correctly_propagated(): - with override_global_config({"_asm_enabled": True}): - with _asm_request_context.asm_request_context_manager(): - witness = 0 - with _asm_request_context.asm_request_context_manager(): - witness = 1 - raise BlockingException({}, "rule", "type", "value") - # should be skipped by exception - witness = 3 - # should be also skipped by exception - witness = 4 - # no more exception there - # ensure that the exception was raised and caught at the end of the last context manager - assert witness == 1 + with asm_context(config=config_asm): + witness = 0 + with asm_context(config=config_asm): + witness = 1 + raise BlockingException({}, "rule", "type", "value") + # should be skipped by exception + witness = 3 + # should be also skipped by exception + witness = 4 + # no more exception there + # ensure that the exception was raised and caught at the end of the last context manager + assert witness == 1 diff --git a/tests/appsec/appsec/test_processor.py b/tests/appsec/appsec/test_processor.py index 36c17d127d4..f2704d6b7dd 100644 --- a/tests/appsec/appsec/test_processor.py +++ b/tests/appsec/appsec/test_processor.py @@ -5,7 +5,6 @@ import mock import pytest -from ddtrace.appsec import _asm_request_context from ddtrace.appsec._constants import APPSEC from ddtrace.appsec._constants import DEFAULT from ddtrace.appsec._constants import FINGERPRINTING @@ -19,6 +18,7 @@ from ddtrace.ext import SpanTypes from ddtrace.internal import core import tests.appsec.rules as rules +from tests.appsec.utils import asm_context from tests.utils import override_env from tests.utils import override_global_config from tests.utils import snapshot @@ -32,19 +32,9 @@ APPSEC_JSON_TAG = f"meta.{APPSEC.JSON}" - - -@pytest.fixture -def tracer_appsec(tracer): - with override_global_config(dict(_asm_enabled=True)): - yield _enable_appsec(tracer) - - -def _enable_appsec(tracer): - tracer._asm_enabled = True - # Hack: need to pass an argument to configure so that the processors are recreated - tracer.configure(api_version="v0.4") - return tracer +config_asm = {"_asm_enabled": True} +config_good_rules = {"_asm_static_rule_file": rules.RULES_GOOD_PATH, "_asm_enabled": True} +config_bad_rules = {"_asm_static_rule_file": rules.RULES_BAD_PATH, "_asm_enabled": True, "_raise": True} def test_transform_headers(): @@ -64,10 +54,8 @@ def test_transform_headers(): assert set(transformed["foo"]) == {"bar1", "bar2", "bar3"} -def test_enable(tracer_appsec): - tracer = tracer_appsec - - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: +def test_enable(tracer): + with asm_context(tracer=tracer, config=config_asm) as span: set_http_meta(span, {}, raw_uri="http://example.com/.git", status_code="404") assert span.get_metric("_dd.appsec.enabled") == 1.0 @@ -81,68 +69,56 @@ def test_enable_custom_rules(): assert processor.rule_filename == rules.RULES_GOOD_PATH -def test_ddwaf_ctx(tracer_appsec): - tracer = tracer_appsec - - with override_global_config(dict(_asm_static_rule_file=rules.RULES_GOOD_PATH)): - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: - processor = AppSecSpanProcessor() - processor.on_span_start(span) - ctx = processor._span_to_waf_ctx.get(span) - assert ctx - processor.on_span_finish(span) - assert span not in processor._span_to_waf_ctx +def test_ddwaf_ctx(tracer): + with asm_context(tracer=tracer, config=config_good_rules) as span: + processor = AppSecSpanProcessor() + processor.on_span_start(span) + ctx = processor._span_to_waf_ctx.get(span) + assert ctx + processor.on_span_finish(span) + assert span not in processor._span_to_waf_ctx -@pytest.mark.parametrize("rule,exc", [(rules.RULES_MISSING_PATH, IOError), (rules.RULES_BAD_PATH, ValueError)]) -def test_enable_bad_rules(rule, exc, tracer): +@pytest.mark.parametrize("rule, _exc", [(rules.RULES_MISSING_PATH, IOError), (rules.RULES_BAD_PATH, ValueError)]) +def test_enable_bad_rules(rule, _exc, tracer): # by default enable must not crash but display errors in the logs - with override_env(dict(DD_APPSEC_RULES=rule)): - with override_global_config(dict(_raise=False)): - _enable_appsec(tracer) - + with asm_context(tracer=tracer, config=config_bad_rules) as span: + set_http_meta(span, {}, raw_uri="http://example.com/.git", status_code="404") -def test_retain_traces(tracer_appsec): - tracer = tracer_appsec - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: +def test_retain_traces(tracer): + with asm_context(tracer=tracer, config=config_asm) as span: + print(">>> set HTTP meta", flush=True) set_http_meta(span, {}, raw_uri="http://example.com/.git", status_code="404") assert span.context.sampling_priority == USER_KEEP -def test_valid_json(tracer_appsec): - tracer = tracer_appsec - - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: +def test_valid_json(tracer): + with asm_context(tracer=tracer, config=config_asm) as span: set_http_meta(span, {}, raw_uri="http://example.com/.git", status_code="404") assert get_triggers(span) -def test_header_attack(tracer_appsec): - tracer = tracer_appsec - - with override_global_config(dict(retrieve_client_ip=True, _asm_enabled=True)): - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - request_headers={ - "User-Agent": "Arachni/v1", - "user-agent": "aa", - "x-forwarded-for": "8.8.8.8", - }, - ) - - assert get_triggers(span) - assert span.get_tag("actor.ip") == "8.8.8.8" +def test_header_attack(tracer): + with asm_context(tracer=tracer, config=dict(retrieve_client_ip=True, _asm_enabled=True)) as span: + set_http_meta( + span, + rules.Config(), + request_headers={ + "User-Agent": "Arachni/v1", + "user-agent": "aa", + "x-forwarded-for": "8.8.8.8", + }, + ) + assert get_triggers(span) + assert span.get_tag("actor.ip") == "8.8.8.8" -def test_headers_collection(tracer_appsec): - tracer = tracer_appsec - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: +def test_headers_collection(tracer): + with asm_context(tracer=tracer, config=config_asm) as span: set_http_meta( span, rules.Config(), @@ -181,18 +157,16 @@ def test_headers_collection(tracer_appsec): def test_appsec_cookies_no_collection_snapshot(tracer): # We use tracer instead of tracer_appsec because snapshot is looking for tracer fixture and not understands # other fixtures - with override_global_config(dict(_asm_enabled=True)): - _enable_appsec(tracer) - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - {}, - raw_uri="http://example.com/.git", - status_code="404", - request_cookies={"cookie1": "im the cookie1"}, - ) + with asm_context(tracer=tracer, config=config_asm) as span: + set_http_meta( + span, + {}, + raw_uri="http://example.com/.git", + status_code="404", + request_cookies={"cookie1": "im the cookie1"}, + ) - assert get_triggers(span) + assert get_triggers(span) @snapshot( @@ -208,53 +182,43 @@ def test_appsec_cookies_no_collection_snapshot(tracer): ], ) def test_appsec_body_no_collection_snapshot(tracer): - with override_global_config(dict(_asm_enabled=True)): - _enable_appsec(tracer) - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - {}, - raw_uri="http://example.com/.git", - status_code="404", - request_body={"somekey": "somekey value"}, - ) + with asm_context(tracer=tracer, config=config_asm) as span: + set_http_meta( + span, + {}, + raw_uri="http://example.com/.git", + status_code="404", + request_body={"somekey": "somekey value"}, + ) - assert get_triggers(span) + assert get_triggers(span) def test_ip_block(tracer): - with override_global_config(dict(_asm_enabled=True, _asm_static_rule_file=rules.RULES_GOOD_PATH)): - _enable_appsec(tracer) - with _asm_request_context.asm_request_context_manager(rules._IP.BLOCKED, {}): - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - ) - - assert get_triggers(span) - assert core.get_item("http.request.remote_ip", span) == rules._IP.BLOCKED - assert core.get_item("http.request.blocked", span) + with asm_context(tracer=tracer, ip_addr=rules._IP.BLOCKED, config=config_good_rules) as span: + set_http_meta( + span, + rules.Config(), + ) + assert get_triggers(span) + assert core.get_item("http.request.remote_ip", span) == rules._IP.BLOCKED + assert core.get_item("http.request.blocked", span) @pytest.mark.parametrize("ip", [rules._IP.MONITORED, rules._IP.BYPASS, rules._IP.DEFAULT]) def test_ip_not_block(tracer, ip): - with override_global_config(dict(_asm_enabled=True, _asm_static_rule_file=rules.RULES_GOOD_PATH)): - _enable_appsec(tracer) - with _asm_request_context.asm_request_context_manager(ip, {}): - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - ) + with asm_context(tracer=tracer, ip_addr=ip, config=config_good_rules) as span: + set_http_meta( + span, + rules.Config(), + ) - assert core.get_item("http.request.remote_ip", span) == ip - assert core.get_item("http.request.blocked", span) is None + assert core.get_item("http.request.remote_ip", span) == ip + assert core.get_item("http.request.blocked", span) is None def test_ip_update_rules_and_block(tracer): - with override_global_config(dict(_asm_enabled=True)): - _enable_appsec(tracer) + with asm_context(tracer=tracer, ip_addr=rules._IP.BLOCKED, config=config_asm): tracer._appsec_processor._update_rules( { "rules_data": [ @@ -268,20 +232,18 @@ def test_ip_update_rules_and_block(tracer): ] } ) - with _asm_request_context.asm_request_context_manager(rules._IP.BLOCKED, {}): - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - ) + with tracer.trace("test", span_type=SpanTypes.WEB) as span: + set_http_meta( + span, + rules.Config(), + ) - assert core.get_item("http.request.remote_ip", span) == rules._IP.BLOCKED - assert core.get_item("http.request.blocked", span) + assert core.get_item("http.request.remote_ip", span) == rules._IP.BLOCKED + assert core.get_item("http.request.blocked", span) def test_ip_update_rules_expired_no_block(tracer): - with override_global_config(dict(_asm_enabled=True)): - _enable_appsec(tracer) + with asm_context(tracer=tracer, ip_addr=rules._IP.BLOCKED, config=config_asm): tracer._appsec_processor._update_rules( { "rules_data": [ @@ -295,15 +257,14 @@ def test_ip_update_rules_expired_no_block(tracer): ] } ) - with _asm_request_context.asm_request_context_manager(rules._IP.BLOCKED, {}): - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - ) + with tracer.trace("test", span_type=SpanTypes.WEB) as span: + set_http_meta( + span, + rules.Config(), + ) - assert core.get_item("http.request.remote_ip", span) == rules._IP.BLOCKED - assert core.get_item("http.request.blocked", span) is None + assert core.get_item("http.request.remote_ip", span) == rules._IP.BLOCKED + assert core.get_item("http.request.blocked", span) is None @snapshot( @@ -319,15 +280,11 @@ def test_ip_update_rules_expired_no_block(tracer): ], ) def test_appsec_span_tags_snapshot(tracer): - with override_global_config(dict(_asm_enabled=True)): - _enable_appsec(tracer) - with _asm_request_context.asm_request_context_manager(), tracer.trace( - "test", service="test", span_type=SpanTypes.WEB - ) as span: - span.set_tag("http.url", "http://example.com/.git") - set_http_meta(span, {}, raw_uri="http://example.com/.git", status_code="404") + with asm_context(tracer=tracer, config=config_asm, service="test") as span: + span.set_tag("http.url", "http://example.com/.git") + set_http_meta(span, {}, raw_uri="http://example.com/.git", status_code="404") - assert get_triggers(span) + assert get_triggers(span) @snapshot( @@ -341,34 +298,28 @@ def test_appsec_span_tags_snapshot(tracer): ], ) def test_appsec_span_tags_snapshot_with_errors(tracer): - with override_global_config( - dict( - _asm_enabled=True, - _asm_static_rule_file=os.path.join(rules.ROOT_DIR, "rules-with-2-errors.json"), - _waf_timeout=50_000, - ) - ): - _enable_appsec(tracer) - with _asm_request_context.asm_request_context_manager(), tracer.trace( - "test", service="test", span_type=SpanTypes.WEB - ) as span: - span.set_tag("http.url", "http://example.com/.git") - set_http_meta(span, {}, raw_uri="http://example.com/.git", status_code="404") + config = dict( + _asm_enabled=True, + _asm_static_rule_file=os.path.join(rules.ROOT_DIR, "rules-with-2-errors.json"), + _waf_timeout=50_000, + ) + with asm_context(tracer=tracer, config=config, service="test") as span: + span.set_tag("http.url", "http://example.com/.git") + set_http_meta(span, {}, raw_uri="http://example.com/.git", status_code="404") assert get_triggers(span) is None def test_appsec_span_rate_limit(tracer): - with override_global_config(dict(_asm_enabled=True)), override_env(dict(DD_APPSEC_TRACE_RATE_LIMIT="1")): - _enable_appsec(tracer) - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span1: + with override_env(dict(DD_APPSEC_TRACE_RATE_LIMIT="1")): + with asm_context(tracer=tracer, config=config_asm) as span1: set_http_meta(span1, {}, raw_uri="http://example.com/.git", status_code="404") - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span2: + with asm_context(tracer=tracer, config={}) as span2: set_http_meta(span2, {}, raw_uri="http://example.com/.git", status_code="404") span2.start_ns = span1.start_ns + 1 - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span3: + with asm_context(tracer=tracer, config={}) as span3: set_http_meta(span3, {}, raw_uri="http://example.com/.git", status_code="404") span2.start_ns = span1.start_ns + 2 @@ -433,10 +384,8 @@ def test_obfuscation_parameter_key_and_value_invalid_regex(): assert processor.enabled -def test_obfuscation_parameter_value_unconfigured_not_matching(tracer_appsec): - tracer = tracer_appsec - - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: +def test_obfuscation_parameter_value_unconfigured_not_matching(tracer): + with asm_context(tracer=tracer, config=config_asm) as span: set_http_meta(span, rules.Config(), raw_uri="http://example.com/.git?hello=goodbye", status_code="404") triggers = get_triggers(span) @@ -453,10 +402,8 @@ def test_obfuscation_parameter_value_unconfigured_not_matching(tracer_appsec): @pytest.mark.parametrize("key", ["password", "public_key", "jsessionid", "jwt"]) -def test_obfuscation_parameter_value_unconfigured_matching(tracer_appsec, key): - tracer = tracer_appsec - - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: +def test_obfuscation_parameter_value_unconfigured_matching(tracer, key): + with asm_context(tracer=tracer, config=config_asm) as span: set_http_meta(span, rules.Config(), raw_uri=f"http://example.com/.git?{key}=goodbye", status_code="404") triggers = get_triggers(span) @@ -473,11 +420,9 @@ def test_obfuscation_parameter_value_unconfigured_matching(tracer_appsec, key): def test_obfuscation_parameter_value_configured_not_matching(tracer): - with override_global_config(dict(_asm_enabled=True, _asm_obfuscation_parameter_value_regexp="token")): - _enable_appsec(tracer) - - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta(span, rules.Config(), raw_uri="http://example.com/.git?password=goodbye", status_code="404") + config = dict(_asm_enabled=True, _asm_obfuscation_parameter_value_regexp="token") + with asm_context(tracer=tracer, config=config) as span: + set_http_meta(span, rules.Config(), raw_uri="http://example.com/.git?password=goodbye", status_code="404") triggers = get_triggers(span) assert triggers @@ -493,11 +438,9 @@ def test_obfuscation_parameter_value_configured_not_matching(tracer): def test_obfuscation_parameter_value_configured_matching(tracer): - with override_global_config(dict(_asm_enabled=True, _asm_obfuscation_parameter_value_regexp="token")): - _enable_appsec(tracer) - - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta(span, rules.Config(), raw_uri="http://example.com/.git?token=goodbye", status_code="404") + config = dict(_asm_enabled=True, _asm_obfuscation_parameter_value_regexp="token") + with asm_context(tracer=tracer, config=config) as span: + set_http_meta(span, rules.Config(), raw_uri="http://example.com/.git?token=goodbye", status_code="404") triggers = get_triggers(span) assert triggers @@ -587,15 +530,14 @@ def test_ddwaf_info_with_3_errors(): assert info.errors == {"missing key 'name'": ["crs-942-100", "crs-913-120"]} -def test_ddwaf_info_with_json_decode_errors(tracer_appsec, caplog): - tracer = tracer_appsec +def test_ddwaf_info_with_json_decode_errors(tracer, caplog): config = rules.Config() config.http_tag_query_string = True with caplog.at_level(logging.WARNING), mock.patch( "ddtrace.appsec._processor.json.dumps", side_effect=JSONDecodeError("error", "error", 0) ), mock.patch.object(DDWaf, "info"): - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: + with asm_context(tracer=tracer, config=config_asm) as span: set_http_meta( span, config, @@ -623,16 +565,14 @@ def test_ddwaf_info_with_json_decode_errors(tracer_appsec, caplog): assert "Error parsing data ASM In-App WAF metrics report" in caplog.text -def test_ddwaf_run_contained_typeerror(tracer_appsec, caplog): - tracer = tracer_appsec - +def test_ddwaf_run_contained_typeerror(tracer, caplog): config = rules.Config() config.http_tag_query_string = True with caplog.at_level(logging.DEBUG), mock.patch( "ddtrace.appsec._ddwaf.ddwaf_run", side_effect=TypeError("expected c_long instead of int") ): - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: + with asm_context(tracer=tracer, config=config_asm) as span: set_http_meta( span, config, @@ -661,16 +601,14 @@ def test_ddwaf_run_contained_typeerror(tracer_appsec, caplog): assert "TypeError: expected c_long instead of int" in caplog.text -def test_ddwaf_run_contained_oserror(tracer_appsec, caplog): - tracer = tracer_appsec - +def test_ddwaf_run_contained_oserror(tracer, caplog): config = rules.Config() config.http_tag_query_string = True with caplog.at_level(logging.DEBUG), mock.patch( "ddtrace.appsec._ddwaf.ddwaf_run", side_effect=OSError("ddwaf run failed") ): - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: + with asm_context(tracer=tracer, config=config_asm) as span: set_http_meta( span, config, @@ -699,19 +637,19 @@ def test_ddwaf_run_contained_oserror(tracer_appsec, caplog): assert "OSError: ddwaf run failed" in caplog.text -def test_asm_context_registration(tracer_appsec): - tracer = tracer_appsec +def test_asm_context_registration(tracer): + from ddtrace.appsec._asm_request_context import _ASM_CONTEXT # For a web type span, a context manager is added, but then removed - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - assert core.get_item("asm_env") is not None - assert core.get_item("asm_env") is None + with asm_context(tracer=tracer, config=config_asm) as span: + assert core.get_item(_ASM_CONTEXT) is not None + assert core.get_item(_ASM_CONTEXT) is None # Regression test, if the span type changes after being created, we always removed - with tracer.trace("test", span_type=SpanTypes.WEB) as span: + with asm_context(tracer=tracer, config=config_asm) as span: span.span_type = SpanTypes.HTTP - assert core.get_item("asm_env") is not None - assert core.get_item("asm_env") is None + assert core.get_item(_ASM_CONTEXT) is not None + assert core.get_item(_ASM_CONTEXT) is None CUSTOM_RULE_METHOD = { @@ -779,9 +717,7 @@ def test_ephemeral_addresses(mock_run, persistent, ephemeral): processor = AppSecSpanProcessor() processor._update_rules(CUSTOM_RULE_METHOD) - with override_global_config( - dict(_asm_enabled=True) - ), _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: + with asm_context(tracer=tracer, config=config_asm) as span: # first call must send all data to the waf processor._waf_action(span, None, {persistent: {"key_1": "value_1"}, ephemeral: {"key_2": "value_2"}}) assert mock_run.call_args[0][1] == {WAF_DATA_NAMES[persistent]: {"key_1": "value_1"}} diff --git a/tests/appsec/appsec/test_remoteconfiguration.py b/tests/appsec/appsec/test_remoteconfiguration.py index 127b25b500e..82fe04f61e0 100644 --- a/tests/appsec/appsec/test_remoteconfiguration.py +++ b/tests/appsec/appsec/test_remoteconfiguration.py @@ -8,7 +8,6 @@ from mock.mock import ANY import pytest -from ddtrace.appsec import _asm_request_context from ddtrace.appsec._capabilities import _appsec_rc_capabilities from ddtrace.appsec._constants import APPSEC from ddtrace.appsec._constants import DEFAULT @@ -21,7 +20,6 @@ from ddtrace.appsec._remoteconfiguration import enable_appsec_rc from ddtrace.appsec._utils import get_triggers from ddtrace.contrib.trace_utils import set_http_meta -from ddtrace.ext import SpanTypes from ddtrace.internal import core from ddtrace.internal.remoteconfig.client import AgentPayload from ddtrace.internal.remoteconfig.client import ConfigMetadata @@ -32,12 +30,13 @@ from ddtrace.settings.asm import config as asm_config import tests.appsec.rules as rules from tests.appsec.utils import Either +from tests.appsec.utils import asm_context from tests.utils import override_env from tests.utils import override_global_config def _set_and_get_appsec_tags(tracer): - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: + with asm_context(tracer) as span: set_http_meta( span, {}, @@ -923,14 +922,13 @@ def test_rc_activation_ip_blocking_data(tracer, remote_config_worker): assert remoteconfig_poller.status == ServiceStatus.STOPPED _appsec_callback(rc_config, tracer) - with _asm_request_context.asm_request_context_manager("8.8.4.4", {}): - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - ) - assert get_triggers(span) - assert core.get_item("http.request.remote_ip", span) == "8.8.4.4" + with asm_context(tracer, ip_addr="8.8.4.4") as span: + set_http_meta( + span, + rules.Config(), + ) + assert get_triggers(span) + assert core.get_item("http.request.remote_ip", span) == "8.8.4.4" def test_rc_activation_ip_blocking_data_expired(tracer, remote_config_worker): @@ -954,13 +952,12 @@ def test_rc_activation_ip_blocking_data_expired(tracer, remote_config_worker): _appsec_callback(rc_config, tracer) - with _asm_request_context.asm_request_context_manager("8.8.4.4", {}): - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - ) - assert get_triggers(span) is None + with asm_context(tracer, ip_addr="8.8.4.4") as span: + set_http_meta( + span, + rules.Config(), + ) + assert get_triggers(span) is None def test_rc_activation_ip_blocking_data_not_expired(tracer, remote_config_worker): @@ -984,14 +981,13 @@ def test_rc_activation_ip_blocking_data_not_expired(tracer, remote_config_worker _appsec_callback(rc_config, tracer) - with _asm_request_context.asm_request_context_manager("8.8.4.4", {}): - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - ) - assert get_triggers(span) - assert core.get_item("http.request.remote_ip", span) == "8.8.4.4" + with asm_context(tracer, ip_addr="8.8.4.4") as span: + set_http_meta( + span, + rules.Config(), + ) + assert get_triggers(span) + assert core.get_item("http.request.remote_ip", span) == "8.8.4.4" def test_rc_rules_data(tracer): diff --git a/tests/appsec/appsec/test_telemetry.py b/tests/appsec/appsec/test_telemetry.py index ddf3ce58bb4..bf260b60c9f 100644 --- a/tests/appsec/appsec/test_telemetry.py +++ b/tests/appsec/appsec/test_telemetry.py @@ -3,7 +3,6 @@ import pytest -from ddtrace.appsec import _asm_request_context from ddtrace.appsec._ddwaf import version from ddtrace.appsec._deduplications import deduplication from ddtrace.appsec._processor import AppSecSpanProcessor @@ -12,11 +11,15 @@ from ddtrace.internal.telemetry.constants import TELEMETRY_NAMESPACE_TAG_APPSEC from ddtrace.internal.telemetry.constants import TELEMETRY_TYPE_DISTRIBUTION from ddtrace.internal.telemetry.constants import TELEMETRY_TYPE_GENERATE_METRICS -from tests.appsec.appsec.test_processor import _enable_appsec import tests.appsec.rules as rules +from tests.appsec.utils import asm_context from tests.utils import override_global_config +config_asm = {"_asm_enabled": True} +config_good_rules = {"_asm_static_rule_file": rules.RULES_GOOD_PATH, "_asm_enabled": True} + + def _assert_generate_metrics(metrics_result, is_rule_triggered=False, is_blocked_request=False): generate_metrics = metrics_result[TELEMETRY_TYPE_GENERATE_METRICS][TELEMETRY_NAMESPACE_TAG_APPSEC] assert len(generate_metrics) == 2, "Expected 2 generate_metrics" @@ -65,37 +68,29 @@ def test_metrics_when_appsec_doesnt_runs(telemetry_writer, tracer): def test_metrics_when_appsec_runs(telemetry_writer, tracer): - with override_global_config(dict(_asm_enabled=True)): - telemetry_writer._namespace.flush() - _enable_appsec(tracer) - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - ) + telemetry_writer._namespace.flush() + with asm_context(tracer=tracer, span_name="test", config=config_asm) as span: + set_http_meta( + span, + rules.Config(), + ) _assert_generate_metrics(telemetry_writer._namespace._metrics_data) def test_metrics_when_appsec_attack(telemetry_writer, tracer): - with override_global_config(dict(_asm_enabled=True, _asm_static_rule_file=rules.RULES_GOOD_PATH)): - telemetry_writer._namespace.flush() - _enable_appsec(tracer) - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta(span, rules.Config(), request_cookies={"attack": "1' or '1' = '1'"}) + telemetry_writer._namespace.flush() + with asm_context(tracer=tracer, span_name="test", config=config_good_rules) as span: + set_http_meta(span, rules.Config(), request_cookies={"attack": "1' or '1' = '1'"}) _assert_generate_metrics(telemetry_writer._namespace._metrics_data, is_rule_triggered=True) def test_metrics_when_appsec_block(telemetry_writer, tracer): - with override_global_config(dict(_asm_enabled=True, _asm_static_rule_file=rules.RULES_GOOD_PATH)): - telemetry_writer._namespace.flush() - _enable_appsec(tracer) - with _asm_request_context.asm_request_context_manager(rules._IP.BLOCKED, {}): - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - ) - + telemetry_writer._namespace.flush() + with asm_context(tracer=tracer, ip_addr=rules._IP.BLOCKED, span_name="test", config=config_good_rules) as span: + set_http_meta( + span, + rules.Config(), + ) _assert_generate_metrics(telemetry_writer._namespace._metrics_data, is_rule_triggered=True, is_blocked_request=True) @@ -117,35 +112,31 @@ def test_log_metric_error_ddwaf_init(telemetry_writer): def test_log_metric_error_ddwaf_timeout(telemetry_writer, tracer): - with override_global_config( - dict( - _asm_enabled=True, - _waf_timeout=0.0, - _deduplication_enabled=False, - _asm_static_rule_file=rules.RULES_GOOD_PATH, + config = dict( + _asm_enabled=True, + _waf_timeout=0.0, + _deduplication_enabled=False, + _asm_static_rule_file=rules.RULES_GOOD_PATH, + ) + with asm_context(tracer=tracer, ip_addr=rules._IP.BLOCKED, span_name="test", config=config) as span: + set_http_meta( + span, + rules.Config(), ) - ): - _enable_appsec(tracer) - with _asm_request_context.asm_request_context_manager(rules._IP.BLOCKED, {}): - with tracer.trace("test", span_type=SpanTypes.WEB) as span: - set_http_meta( - span, - rules.Config(), - ) - list_metrics_logs = list(telemetry_writer._logs) - assert len(list_metrics_logs) == 0 + list_metrics_logs = list(telemetry_writer._logs) + assert len(list_metrics_logs) == 0 - generate_metrics = telemetry_writer._namespace._metrics_data[TELEMETRY_TYPE_GENERATE_METRICS][ - TELEMETRY_NAMESPACE_TAG_APPSEC - ] + generate_metrics = telemetry_writer._namespace._metrics_data[TELEMETRY_TYPE_GENERATE_METRICS][ + TELEMETRY_NAMESPACE_TAG_APPSEC + ] - timeout_found = False - for _metric_id, metric in generate_metrics.items(): - if metric.name == "waf.requests": - assert ("waf_timeout", "true") in metric._tags - timeout_found = True - assert timeout_found + timeout_found = False + for _metric_id, metric in generate_metrics.items(): + if metric.name == "waf.requests": + assert ("waf_timeout", "true") in metric._tags + timeout_found = True + assert timeout_found def test_log_metric_error_ddwaf_update(telemetry_writer): diff --git a/tests/appsec/contrib_appsec/utils.py b/tests/appsec/contrib_appsec/utils.py index 7d03ace98b1..e1f5e31d9dd 100644 --- a/tests/appsec/contrib_appsec/utils.py +++ b/tests/appsec/contrib_appsec/utils.py @@ -115,16 +115,17 @@ def test_healthcheck(self, interface: Interface, get_tag, asm_enabled: bool): assert get_tag("http.status_code") == "200" assert self.headers(response)["content-type"] == "text/html; charset=utf-8" - def test_simple_attack(self, interface: Interface, root_span): + def test_simple_attack(self, interface: Interface, root_span, get_tag): with override_global_config(dict(_asm_enabled=True)): self.update_tracer(interface) response = interface.client.get("/.git?q=1") assert response.status_code == 404 triggers = get_triggers(root_span()) assert triggers is not None, "no appsec struct in root span" - assert core.get_item("http.request.uri", span=root_span()) == "http://localhost:8000/.git?q=1" - assert core.get_item("http.request.headers", span=root_span()) is not None - query = dict(core.get_item("http.request.query", span=root_span())) + assert root_span()._get_ctx_item("http.request.uri") == "http://localhost:8000/.git?q=1" + assert root_span()._get_ctx_item("http.request.headers") is not None + assert root_span()._get_ctx_item("http.request.method") == "GET" + query = dict(root_span()._get_ctx_item("http.request.query")) assert query == {"q": "1"} or query == {"q": ["1"]} def test_querystrings(self, interface: Interface, root_span): @@ -1249,7 +1250,7 @@ def test_global_callback_list_length(self, interface): assert self.status(response) == 200 assert self.body(response) == "awesome_test" # only two global callbacks are expected for API Security and Nested Events - assert len(_asm_request_context.GLOBAL_CALLBACKS.get(_asm_request_context._CONTEXT_CALL, [])) == 2 + assert len(_asm_request_context.GLOBAL_CALLBACKS.get(_asm_request_context._CONTEXT_CALL, [])) == 1 @pytest.mark.parametrize("asm_enabled", [True, False]) @pytest.mark.parametrize("metastruct", [True, False]) diff --git a/tests/appsec/iast/test_telemetry.py b/tests/appsec/iast/test_telemetry.py index 8777a377d08..42470e61d5b 100644 --- a/tests/appsec/iast/test_telemetry.py +++ b/tests/appsec/iast/test_telemetry.py @@ -1,6 +1,5 @@ import pytest -from ddtrace.appsec import _asm_request_context from ddtrace.appsec._common_module_patches import patch_common_modules from ddtrace.appsec._common_module_patches import unpatch_common_modules from ddtrace.appsec._constants import IAST @@ -28,6 +27,7 @@ from ddtrace.internal.telemetry.constants import TELEMETRY_NAMESPACE_TAG_IAST from ddtrace.internal.telemetry.constants import TELEMETRY_TYPE_GENERATE_METRICS from tests.appsec.iast.aspects.conftest import _iast_patched_module +from tests.appsec.utils import asm_context from tests.utils import DummyTracer from tests.utils import override_env from tests.utils import override_global_config @@ -74,7 +74,7 @@ def test_metric_executed_sink(no_request_sampling, telemetry_writer): tracer = DummyTracer(iast_enabled=True) telemetry_writer._namespace.flush() - with _asm_request_context.asm_request_context_manager(), tracer.trace("test", span_type=SpanTypes.WEB) as span: + with asm_context(tracer=tracer) as span: import hashlib m = hashlib.new("md5") diff --git a/tests/appsec/utils.py b/tests/appsec/utils.py index d66f037816e..b4c010d1123 100644 --- a/tests/appsec/utils.py +++ b/tests/appsec/utils.py @@ -1,4 +1,12 @@ +import contextlib import sys +import typing + +from ddtrace import tracer as default_tracer +from ddtrace.ext import SpanTypes +import ddtrace.internal.core as core +from ddtrace.settings.asm import config as asm_config +from tests.utils import override_global_config class Either: @@ -10,3 +18,33 @@ def __eq__(self, other): print(f"Either: Expected {other} to be in {self.possibilities}", file=sys.stderr, flush=True) return False return True + + +@contextlib.contextmanager +def asm_context( + tracer=None, + span_name: str = "", + ip_addr: typing.Optional[str] = None, + headers_case_sensitive: bool = False, + headers: typing.Optional[typing.Dict[str, str]] = None, + block_request_callable: typing.Optional[typing.Callable[[], bool]] = None, + service: typing.Optional[str] = None, + config=None, +): + with override_global_config(config) if config else contextlib.nullcontext(): + if tracer is None: + tracer = default_tracer + if asm_config._asm_enabled: + tracer._asm_enabled = True + if config: + tracer.configure(api_version="v0.4") + + with core.context_with_data( + "test.asm", + remote_addr=ip_addr, + headers_case_sensitive=headers_case_sensitive, + headers=headers, + block_request_callable=block_request_callable, + service=service, + ), tracer.trace(span_name or "test", span_type=SpanTypes.WEB, service=service) as span: + yield span diff --git a/tests/contrib/django/test_django_appsec_snapshots.py b/tests/contrib/django/test_django_appsec_snapshots.py index d040417d50b..6a1f2eb28db 100644 --- a/tests/contrib/django/test_django_appsec_snapshots.py +++ b/tests/contrib/django/test_django_appsec_snapshots.py @@ -40,20 +40,14 @@ def daphne_client(django_asgi, additional_env=None): # ddtrace-run uses execl which replaces the process but the webserver process itself might spawn new processes. # Right now it doesn't but it's possible that it might in the future (ex. uwsgi). cmd = ["ddtrace-run", "daphne", "-p", str(SERVER_PORT), "tests.contrib.django.asgi:%s" % django_asgi] - proc = subprocess.Popen( - cmd, - stdout=subprocess.DEVNULL, - stderr=subprocess.DEVNULL, - close_fds=True, - env=env, - ) + proc = subprocess.Popen(cmd, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL, env=env) client = Client("http://localhost:%d" % SERVER_PORT) - # Wait for the server to start up - client.wait() - try: + # Wait for the server to start up + client.wait() + yield client finally: resp = client.get_ignored("/shutdown-tracer") diff --git a/tests/contrib/fastapi/test_fastapi_appsec.py b/tests/contrib/fastapi/test_fastapi_appsec.py index f9ce2931a14..69284807d09 100644 --- a/tests/contrib/fastapi/test_fastapi_appsec.py +++ b/tests/contrib/fastapi/test_fastapi_appsec.py @@ -36,14 +36,13 @@ async def test_route(request: Request): body = await request._receive() return PlainTextResponse(body["body"]) - # test if asgi middleware is ok without any callback registered - core.reset_listeners(event_id="asgi.request.parse.body") - payload, content_type = '{"attack": "yqrweytqwreasldhkuqwgervflnmlnli"}', "application/json" with override_global_config(dict(_asm_enabled=True, _asm_static_rule_file=rules.RULES_SRB)): # disable callback _aux_appsec_prepare_tracer(tracer, asm_enabled=True) + # test if asgi middleware is ok without any callback registered + core.reset_listeners(event_id="asgi.request.parse.body") resp = client.post( "/index.html?args=test", data=payload, diff --git a/tests/tracer/test_trace_utils.py b/tests/tracer/test_trace_utils.py index d9c5508df93..cdf6460aec6 100644 --- a/tests/tracer/test_trace_utils.py +++ b/tests/tracer/test_trace_utils.py @@ -30,6 +30,7 @@ from ddtrace.propagation.http import HTTP_HEADER_TRACE_ID from ddtrace.settings import Config from ddtrace.settings import IntegrationConfig +from tests.appsec.utils import asm_context from tests.utils import override_global_config @@ -407,7 +408,7 @@ def test_set_http_meta( int_config.http.trace_headers(["my-header"]) int_config.trace_query_string = True span.span_type = span_type - with override_global_config({"_asm_enabled": appsec_enabled}): + with asm_context(config={"_asm_enabled": appsec_enabled}): trace_utils.set_http_meta( span, int_config, diff --git a/tests/tracer/test_tracer.py b/tests/tracer/test_tracer.py index 59bcbf345cb..bb43e0e6e63 100644 --- a/tests/tracer/test_tracer.py +++ b/tests/tracer/test_tracer.py @@ -2,6 +2,7 @@ """ tests for Tracer and utilities. """ + import contextlib import gc import logging @@ -40,7 +41,6 @@ from ddtrace.internal.writer import AgentWriter from ddtrace.internal.writer import LogWriter from ddtrace.settings import Config -from tests.appsec.appsec.test_processor import tracer_appsec from tests.subprocesstest import run_in_subprocess from tests.utils import TracerTestCase from tests.utils import override_global_config @@ -50,9 +50,9 @@ class TracerTestCases(TracerTestCase): @pytest.fixture(autouse=True) - def inject_fixtures(self, caplog): + def inject_fixtures(self, tracer, caplog): self._caplog = caplog - self._tracer_appsec = tracer_appsec + self._tracer_appsec = tracer def test_tracer_vars(self): span = self.trace("a", service="s", resource="r", span_type="t")