diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py index 770d175aa..6663c249f 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_aws_xray_sampling_client.py @@ -6,6 +6,7 @@ import requests from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule +from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTargetResponse _logger = getLogger(__name__) @@ -19,6 +20,7 @@ def __init__(self, endpoint: str = None, log_level: str = None): if endpoint is None: _logger.error("endpoint must be specified") self.__get_sampling_rules_endpoint = endpoint + "/GetSamplingRules" + self.__get_sampling_targets_endpoint = endpoint + "/SamplingTargets" def get_sampling_rules(self) -> [_SamplingRule]: sampling_rules = [] @@ -30,12 +32,11 @@ def get_sampling_rules(self) -> [_SamplingRule]: _logger.error("GetSamplingRules response is None") return [] sampling_rules_response = xray_response.json() - if "SamplingRuleRecords" not in sampling_rules_response: + if sampling_rules_response is None or "SamplingRuleRecords" not in sampling_rules_response: _logger.error( "SamplingRuleRecords is missing in getSamplingRules response: %s", sampling_rules_response ) return [] - sampling_rules_records = sampling_rules_response["SamplingRuleRecords"] for record in sampling_rules_records: if "SamplingRule" not in record: @@ -47,5 +48,43 @@ def get_sampling_rules(self) -> [_SamplingRule]: _logger.error("Request error occurred: %s", req_err) except json.JSONDecodeError as json_err: _logger.error("Error in decoding JSON response: %s", json_err) + # pylint: disable=broad-exception-caught + except Exception as err: + _logger.error("Error occurred when attempting to fetch rules: %s", err) return sampling_rules + + def get_sampling_targets(self, statistics: [dict]) -> _SamplingTargetResponse: + sampling_targets_response = _SamplingTargetResponse( + LastRuleModification=None, SamplingTargetDocuments=None, UnprocessedStatistics=None + ) + headers = {"content-type": "application/json"} + try: + xray_response = requests.post( + url=self.__get_sampling_targets_endpoint, + headers=headers, + timeout=20, + json={"SamplingStatisticsDocuments": statistics}, + ) + if xray_response is None: + _logger.debug("GetSamplingTargets response is None. Unable to update targets.") + return sampling_targets_response + xray_response_json = xray_response.json() + if ( + xray_response_json is None + or "SamplingTargetDocuments" not in xray_response_json + or "LastRuleModification" not in xray_response_json + ): + _logger.debug("getSamplingTargets response is invalid. Unable to update targets.") + return sampling_targets_response + + sampling_targets_response = _SamplingTargetResponse(**xray_response_json) + except requests.exceptions.RequestException as req_err: + _logger.debug("Request error occurred: %s", req_err) + except json.JSONDecodeError as json_err: + _logger.debug("Error in decoding JSON response: %s", json_err) + # pylint: disable=broad-exception-caught + except Exception as err: + _logger.debug("Error occurred when attempting to fetch targets: %s", err) + + return sampling_targets_response diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_clock.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_clock.py new file mode 100644 index 000000000..c521e96bd --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_clock.py @@ -0,0 +1,22 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import datetime + + +class _Clock: + def __init__(self): + self.__datetime = datetime.datetime + + def now(self) -> datetime.datetime: + return self.__datetime.now() + + # pylint: disable=no-self-use + def from_timestamp(self, timestamp: float) -> datetime.datetime: + return datetime.datetime.fromtimestamp(timestamp) + + def time_delta(self, seconds: float) -> datetime.timedelta: + return datetime.timedelta(seconds=seconds) + + def max(self) -> datetime.datetime: + return datetime.datetime.max diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_fallback_sampler.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_fallback_sampler.py index 986ee0f16..1a3af8239 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_fallback_sampler.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_fallback_sampler.py @@ -2,17 +2,18 @@ # SPDX-License-Identifier: Apache-2.0 from typing import Optional, Sequence +from amazon.opentelemetry.distro.sampler._clock import _Clock +from amazon.opentelemetry.distro.sampler._rate_limiting_sampler import _RateLimitingSampler from opentelemetry.context import Context -from opentelemetry.sdk.trace.sampling import ALWAYS_ON, Sampler, SamplingResult, TraceIdRatioBased +from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult, TraceIdRatioBased from opentelemetry.trace import Link, SpanKind from opentelemetry.trace.span import TraceState from opentelemetry.util.types import Attributes class _FallbackSampler(Sampler): - def __init__(self): - # TODO: Add Reservoir sampler - # pylint: disable=unused-private-member + def __init__(self, clock: _Clock): + self.__rate_limiting_sampler = _RateLimitingSampler(1, clock) self.__fixed_rate_sampler = TraceIdRatioBased(0.05) # pylint: disable=no-self-use @@ -26,8 +27,12 @@ def should_sample( links: Sequence[Link] = None, trace_state: TraceState = None, ) -> SamplingResult: - # TODO: add reservoir + fixed rate sampling - return ALWAYS_ON.should_sample( + sampling_result = self.__rate_limiting_sampler.should_sample( + parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state + ) + if sampling_result.decision is not Decision.DROP: + return sampling_result + return self.__fixed_rate_sampler.should_sample( parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state ) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rate_limiter.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rate_limiter.py new file mode 100644 index 000000000..64ce3a109 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rate_limiter.py @@ -0,0 +1,42 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from decimal import Decimal +from threading import Lock + +from amazon.opentelemetry.distro.sampler._clock import _Clock + + +class _RateLimiter: + def __init__(self, max_balance_in_seconds: int, quota: int, clock: _Clock): + # max_balance_in_seconds is usually 1 + # pylint: disable=invalid-name + self.MAX_BALANCE_MILLIS = Decimal(max_balance_in_seconds * 1000.0) + self._clock = clock + + self._quota = Decimal(quota) + self.__wallet_floor_millis = Decimal(self._clock.now().timestamp() * 1000.0) + # current "wallet_balance" would be ceiling - floor + + self.__lock = Lock() + + def try_spend(self, cost: float) -> bool: + if self._quota == 0: + return False + + quota_per_millis = self._quota / Decimal(1000.0) + + # assume divide by zero not possible + cost_in_millis = Decimal(cost) / quota_per_millis + + with self.__lock: + wallet_ceiling_millis = Decimal(self._clock.now().timestamp() * 1000.0) + current_balance_millis = wallet_ceiling_millis - self.__wallet_floor_millis + if current_balance_millis > self.MAX_BALANCE_MILLIS: + current_balance_millis = self.MAX_BALANCE_MILLIS + + pending_remaining_balance_millis = current_balance_millis - cost_in_millis + if pending_remaining_balance_millis >= 0: + self.__wallet_floor_millis = wallet_ceiling_millis - pending_remaining_balance_millis + return True + # No changes to the wallet state + return False diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rate_limiting_sampler.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rate_limiting_sampler.py new file mode 100644 index 000000000..e5c4dc3a7 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rate_limiting_sampler.py @@ -0,0 +1,41 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from typing import Optional, Sequence + +from amazon.opentelemetry.distro.sampler._clock import _Clock +from amazon.opentelemetry.distro.sampler._rate_limiter import _RateLimiter +from opentelemetry.context import Context +from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult +from opentelemetry.trace import Link, SpanKind +from opentelemetry.trace.span import TraceState +from opentelemetry.util.types import Attributes + + +class _RateLimitingSampler(Sampler): + def __init__(self, quota: int, clock: _Clock): + self.__quota = quota + self.__reservoir = _RateLimiter(1, quota, clock) + + # pylint: disable=no-self-use + def should_sample( + self, + parent_context: Optional[Context], + trace_id: int, + name: str, + kind: SpanKind = None, + attributes: Attributes = None, + links: Sequence[Link] = None, + trace_state: TraceState = None, + ) -> SamplingResult: + if self.__reservoir.try_spend(1): + return SamplingResult(decision=Decision.RECORD_AND_SAMPLE, attributes=attributes, trace_state=trace_state) + return SamplingResult(decision=Decision.DROP, attributes=attributes, trace_state=trace_state) + + # pylint: disable=no-self-use + def get_description(self) -> str: + description = ( + "RateLimitingSampler{rate limiting sampling with sampling config of " + + self.__quota + + " req/sec and 0% of additional requests}" + ) + return description diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule_cache.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule_cache.py index 98afd6448..e57f2bae3 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule_cache.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_rule_cache.py @@ -1,13 +1,14 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import datetime from logging import getLogger from threading import Lock -from typing import Optional, Sequence +from typing import Dict, Optional, Sequence +from amazon.opentelemetry.distro.sampler._clock import _Clock from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule from amazon.opentelemetry.distro.sampler._sampling_rule_applier import _SamplingRuleApplier +from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTarget, _SamplingTargetResponse from opentelemetry.context import Context from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace.sampling import SamplingResult @@ -18,16 +19,20 @@ _logger = getLogger(__name__) CACHE_TTL_SECONDS = 3600 +DEFAULT_TARGET_POLLING_INTERVAL_SECONDS = 10 class _RuleCache: - def __init__(self, resource: Resource, fallback_sampler: _FallbackSampler, date_time: datetime, lock: Lock): + def __init__( + self, resource: Resource, fallback_sampler: _FallbackSampler, client_id: str, clock: _Clock, lock: Lock + ): + self.__client_id = client_id self.__rule_appliers: [_SamplingRuleApplier] = [] self.__cache_lock = lock self.__resource = resource self._fallback_sampler = fallback_sampler - self._date_time = date_time - self._last_modified = self._date_time.datetime.now() + self._clock = clock + self._last_modified = self._clock.now() def should_sample( self, @@ -39,6 +44,7 @@ def should_sample( links: Sequence[Link] = None, trace_state: TraceState = None, ) -> SamplingResult: + rule_applier: _SamplingRuleApplier for rule_applier in self.__rule_appliers: if rule_applier.matches(self.__resource, attributes): return rule_applier.should_sample( @@ -51,6 +57,8 @@ def should_sample( trace_state=trace_state, ) + _logger.debug("No sampling rules were matched") + # Should not ever reach fallback sampler as default rule is able to match return self._fallback_sampler.should_sample( parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state ) @@ -65,14 +73,17 @@ def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None: if sampling_rule.Version != 1: _logger.debug("sampling rule without Version 1 is not supported: RuleName: %s", sampling_rule.RuleName) continue - temp_rule_appliers.append(_SamplingRuleApplier(sampling_rule)) + temp_rule_appliers.append(_SamplingRuleApplier(sampling_rule, self.__client_id, self._clock)) self.__cache_lock.acquire() # map list of rule appliers by each applier's sampling_rule name - rule_applier_map = {rule.sampling_rule.RuleName: rule for rule in self.__rule_appliers} + rule_applier_map: Dict[str, _SamplingRuleApplier] = { + applier.sampling_rule.RuleName: applier for applier in self.__rule_appliers + } # If a sampling rule has not changed, keep its respective applier in the cache. + new_applier: _SamplingRuleApplier for index, new_applier in enumerate(temp_rule_appliers): rule_name_to_check = new_applier.sampling_rule.RuleName if rule_name_to_check in rule_applier_map: @@ -80,13 +91,52 @@ def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None: if new_applier.sampling_rule == old_applier.sampling_rule: temp_rule_appliers[index] = old_applier self.__rule_appliers = temp_rule_appliers - self._last_modified = datetime.datetime.now() + self._last_modified = self._clock.now() self.__cache_lock.release() + def update_sampling_targets(self, sampling_targets_response: _SamplingTargetResponse) -> (bool, int): + targets: [_SamplingTarget] = sampling_targets_response.SamplingTargetDocuments + + with self.__cache_lock: + next_polling_interval = DEFAULT_TARGET_POLLING_INTERVAL_SECONDS + min_polling_interval = None + + target_map: Dict[str, _SamplingTarget] = {target.RuleName: target for target in targets} + + new_appliers = [] + applier: _SamplingRuleApplier + for applier in self.__rule_appliers: + if applier.sampling_rule.RuleName in target_map: + target = target_map[applier.sampling_rule.RuleName] + new_appliers.append(applier.with_target(target)) + + if target.Interval is not None: + if min_polling_interval is None or min_polling_interval > target.Interval: + min_polling_interval = target.Interval + else: + new_appliers.append(applier) + + self.__rule_appliers = new_appliers + + if min_polling_interval is not None: + next_polling_interval = min_polling_interval + + last_rule_modification = self._clock.from_timestamp(sampling_targets_response.LastRuleModification) + refresh_rules = last_rule_modification > self._last_modified + + return (refresh_rules, next_polling_interval) + + def get_all_statistics(self) -> [dict]: + all_statistics = [] + applier: _SamplingRuleApplier + for applier in self.__rule_appliers: + all_statistics.append(applier.get_then_reset_statistics()) + return all_statistics + def expired(self) -> bool: self.__cache_lock.acquire() try: - return datetime.datetime.now() > self._last_modified + datetime.timedelta(seconds=CACHE_TTL_SECONDS) + return self._clock.now() > self._last_modified + self._clock.time_delta(seconds=CACHE_TTL_SECONDS) finally: self.__cache_lock.release() diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py index 0f64f6b96..73dce1404 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule.py @@ -36,7 +36,7 @@ def __init__( self.URLPath = URLPath if URLPath is not None else "" self.Version = Version if Version is not None else 0 - def __lt__(self, other) -> bool: + def __lt__(self, other: "_SamplingRule") -> bool: if self.Priority == other.Priority: # String order priority example: # "A","Abc","a","ab","abc","abcdef" diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule_applier.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule_applier.py index 43c9bf59b..8284e7cc6 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule_applier.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_rule_applier.py @@ -1,13 +1,18 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +from threading import Lock from typing import Optional, Sequence from urllib.parse import urlparse +from amazon.opentelemetry.distro.sampler._clock import _Clock from amazon.opentelemetry.distro.sampler._matcher import _Matcher, cloud_platform_mapping +from amazon.opentelemetry.distro.sampler._rate_limiting_sampler import _RateLimitingSampler from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule +from amazon.opentelemetry.distro.sampler._sampling_statistics_document import _SamplingStatisticsDocument +from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTarget from opentelemetry.context import Context from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace.sampling import ALWAYS_ON, SamplingResult +from opentelemetry.sdk.trace.sampling import Decision, ParentBased, Sampler, SamplingResult, TraceIdRatioBased from opentelemetry.semconv.resource import CloudPlatformValues, ResourceAttributes from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.trace import Link, SpanKind @@ -16,14 +21,46 @@ class _SamplingRuleApplier: - def __init__(self, sampling_rule: _SamplingRule): + def __init__( + self, + sampling_rule: _SamplingRule, + client_id: str, + clock: _Clock, + statistics: _SamplingStatisticsDocument = None, + target: _SamplingTarget = None, + ): + self.__client_id = client_id + self._clock = clock self.sampling_rule = sampling_rule - # TODO add self.next_target_fetch_time from maybe time.process_time() or cache's datetime object - # TODO add statistics - # TODO change to rate limiter given rate, add fixed rate sampler - self.reservoir_sampler = ALWAYS_ON - # self.fixed_rate_sampler = None - # TODO add clientId + + if statistics is None: + self.__statistics = _SamplingStatisticsDocument(self.__client_id, self.sampling_rule.RuleName) + else: + self.__statistics = statistics + self.__statistics_lock = Lock() + + self.__borrowing = False + + if target is None: + self.__fixed_rate_sampler = ParentBased(TraceIdRatioBased(self.sampling_rule.FixedRate)) + # Until targets are fetched, initialize as borrowing=True if there will be a quota > 0 + if self.sampling_rule.ReservoirSize > 0: + self.__reservoir_sampler = self.__create_reservoir_sampler(quota=1) + self.__borrowing = True + else: + self.__reservoir_sampler = self.__create_reservoir_sampler(quota=0) + # No targets are present, borrow until the end of time if there is any quota + self.__reservoir_expiry = self._clock.max() + else: + new_quota = target.ReservoirQuota if target.ReservoirQuota is not None else 0 + new_fixed_rate = target.FixedRate if target.FixedRate is not None else 0 + self.__reservoir_sampler = self.__create_reservoir_sampler(quota=new_quota) + self.__fixed_rate_sampler = ParentBased(TraceIdRatioBased(new_fixed_rate)) + if target.ReservoirQuotaTTL is not None: + self.__reservoir_expiry = self._clock.from_timestamp(target.ReservoirQuotaTTL) + else: + # assume expired if no TTL + self.__reservoir_expiry = self._clock.now() def should_sample( self, @@ -35,9 +72,43 @@ def should_sample( links: Sequence[Link] = None, trace_state: TraceState = None, ) -> SamplingResult: - return self.reservoir_sampler.should_sample( - parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state - ) + has_borrowed = False + has_sampled = False + sampling_result = SamplingResult(decision=Decision.DROP, attributes=attributes, trace_state=trace_state) + + reservoir_expired: bool = self._clock.now() >= self.__reservoir_expiry + if not reservoir_expired: + sampling_result = self.__reservoir_sampler.should_sample( + parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state + ) + + if sampling_result.decision is not Decision.DROP: + has_borrowed = self.__borrowing + has_sampled = True + else: + sampling_result = self.__fixed_rate_sampler.should_sample( + parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state + ) + if sampling_result.decision is not Decision.DROP: + has_sampled = True + + with self.__statistics_lock: + self.__statistics.RequestCount += 1 + self.__statistics.BorrowCount += 1 if has_borrowed else 0 + self.__statistics.SampleCount += 1 if has_sampled else 0 + + return sampling_result + + def get_then_reset_statistics(self) -> dict: + with self.__statistics_lock: + old_stats = self.__statistics + self.__statistics = _SamplingStatisticsDocument(self.__client_id, self.sampling_rule.RuleName) + + return old_stats.snapshot(self._clock) + + def with_target(self, target: _SamplingTarget) -> "_SamplingRuleApplier": + new_applier = _SamplingRuleApplier(self.sampling_rule, self.__client_id, self._clock, self.__statistics, target) + return new_applier def matches(self, resource: Resource, attributes: Attributes) -> bool: url_path = None @@ -47,10 +118,16 @@ def matches(self, resource: Resource, attributes: Attributes) -> bool: service_name = None if attributes is not None: - url_path = attributes.get(SpanAttributes.URL_PATH, None) - url_full = attributes.get(SpanAttributes.URL_FULL, None) - http_request_method = attributes.get(SpanAttributes.HTTP_REQUEST_METHOD, None) - server_address = attributes.get(SpanAttributes.SERVER_ADDRESS, None) + # If `URL_PATH/URL_FULL/HTTP_REQUEST_METHOD/SERVER_ADDRESS` are not populated + # also check `HTTP_TARGET/HTTP_URL/HTTP_METHOD/HTTP_HOST` respectively as backup + url_path = attributes.get(SpanAttributes.URL_PATH, attributes.get(SpanAttributes.HTTP_TARGET, None)) + url_full = attributes.get(SpanAttributes.URL_FULL, attributes.get(SpanAttributes.HTTP_URL, None)) + http_request_method = attributes.get( + SpanAttributes.HTTP_REQUEST_METHOD, attributes.get(SpanAttributes.HTTP_METHOD, None) + ) + server_address = attributes.get( + SpanAttributes.SERVER_ADDRESS, attributes.get(SpanAttributes.HTTP_HOST, None) + ) # Resource shouldn't be none as it should default to empty resource if resource is not None: @@ -60,8 +137,8 @@ def matches(self, resource: Resource, attributes: Attributes) -> bool: if url_path is None and url_full is not None: scheme_end_index = url_full.find("://") # For network calls, URL usually has `scheme://host[:port][path][?query][#fragment]` format - # Per spec, url.full is always populated with scheme://host/target. - # If scheme doesn't match, assume it's bad instrumentation and ignore. + # Per spec, url.full is always populated with scheme:// + # If scheme is not present, assume it's bad instrumentation and ignore. if scheme_end_index > -1: # urlparse("scheme://netloc/path;parameters?query#fragment") url_path = urlparse(url_full).path @@ -81,6 +158,9 @@ def matches(self, resource: Resource, attributes: Attributes) -> bool: and _Matcher.wild_card_match(self.__get_arn(resource, attributes), self.sampling_rule.ResourceARN) ) + def __create_reservoir_sampler(self, quota: int) -> Sampler: + return ParentBased(_RateLimitingSampler(quota, self._clock)) + # pylint: disable=no-self-use def __get_service_type(self, resource: Resource) -> str: if resource is None: @@ -98,10 +178,25 @@ def __get_arn(self, resource: Resource, attributes: Attributes) -> str: arn = resource.attributes.get(ResourceAttributes.AWS_ECS_CONTAINER_ARN, None) if arn is not None: return arn - if attributes is not None and self.__get_service_type(resource=resource) == cloud_platform_mapping.get( - CloudPlatformValues.AWS_LAMBDA.value + if ( + resource is not None + and resource.attributes.get(ResourceAttributes.CLOUD_PLATFORM) == CloudPlatformValues.AWS_LAMBDA.value ): - arn = attributes.get(SpanAttributes.CLOUD_RESOURCE_ID, None) - if arn is not None: - return arn + return self.__get_lambda_arn(resource, attributes) + return "" + + def __get_lambda_arn(self, resource: Resource, attributes: Attributes) -> str: + arn = resource.attributes.get( + ResourceAttributes.CLOUD_RESOURCE_ID, resource.attributes.get(ResourceAttributes.FAAS_ID, None) + ) + if arn is not None: + return arn + + # Note from `SpanAttributes.CLOUD_RESOURCE_ID`: + # "On some cloud providers, it may not be possible to determine the full ID at startup, + # so it may be necessary to set cloud.resource_id as a span attribute instead." + arn = attributes.get(SpanAttributes.CLOUD_RESOURCE_ID, attributes.get("faas.id", None)) + if arn is not None: + return arn + return "" diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_statistics_document.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_statistics_document.py new file mode 100644 index 000000000..849f2c846 --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_statistics_document.py @@ -0,0 +1,27 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +from amazon.opentelemetry.distro.sampler._clock import _Clock + + +# Disable snake_case naming style so this class can match the statistics document response from X-Ray +# pylint: disable=invalid-name +class _SamplingStatisticsDocument: + def __init__(self, clientID: str, ruleName: str, RequestCount: int = 0, BorrowCount: int = 0, SampleCount: int = 0): + self.ClientID = clientID + self.RuleName = ruleName + self.Timestamp = None + + self.RequestCount = RequestCount + self.BorrowCount = BorrowCount + self.SampleCount = SampleCount + + def snapshot(self, clock: _Clock) -> dict: + return { + "ClientID": self.ClientID, + "RuleName": self.RuleName, + "Timestamp": clock.now().timestamp(), + "RequestCount": self.RequestCount, + "BorrowCount": self.BorrowCount, + "SampleCount": self.SampleCount, + } diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_target.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_target.py new file mode 100644 index 000000000..32c32f5ce --- /dev/null +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/_sampling_target.py @@ -0,0 +1,61 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from logging import getLogger + +_logger = getLogger(__name__) + + +# Disable snake_case naming style so this class can match the sampling rules response from X-Ray +# pylint: disable=invalid-name +class _SamplingTarget: + def __init__( + self, + FixedRate: float = None, + Interval: int = None, + ReservoirQuota: int = None, + ReservoirQuotaTTL: float = None, + RuleName: str = None, + ): + self.FixedRate = FixedRate if FixedRate is not None else 0.0 + self.Interval = Interval # can be None + self.ReservoirQuota = ReservoirQuota # can be None + self.ReservoirQuotaTTL = ReservoirQuotaTTL # can be None + self.RuleName = RuleName if RuleName is not None else "" + + +class _UnprocessedStatistics: + def __init__( + self, + ErrorCode: str = None, + Message: str = None, + RuleName: str = None, + ): + self.ErrorCode = ErrorCode if ErrorCode is not None else "" + self.Message = Message if ErrorCode is not None else "" + self.RuleName = RuleName if ErrorCode is not None else "" + + +class _SamplingTargetResponse: + def __init__( + self, + LastRuleModification: float, + SamplingTargetDocuments: [dict] = None, + UnprocessedStatistics: [dict] = None, + ): + self.LastRuleModification: float = LastRuleModification if LastRuleModification is not None else 0.0 + + self.SamplingTargetDocuments: [_SamplingTarget] = [] + if SamplingTargetDocuments is not None: + for document in SamplingTargetDocuments: + try: + self.SamplingTargetDocuments.append(_SamplingTarget(**document)) + except TypeError as e: + _logger.debug("TypeError occurred: %s", e) + + self.UnprocessedStatistics: [_UnprocessedStatistics] = [] + if UnprocessedStatistics is not None: + for unprocessed in UnprocessedStatistics: + try: + self.UnprocessedStatistics.append(_UnprocessedStatistics(**unprocessed)) + except TypeError as e: + _logger.debug("TypeError occurred: %s", e) diff --git a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py index 52ed58622..e945fc049 100644 --- a/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py +++ b/aws-opentelemetry-distro/src/amazon/opentelemetry/distro/sampler/aws_xray_remote_sampler.py @@ -1,6 +1,5 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 -import datetime import random from logging import getLogger from threading import Lock, Timer @@ -9,11 +8,12 @@ from typing_extensions import override from amazon.opentelemetry.distro.sampler._aws_xray_sampling_client import _AwsXRaySamplingClient +from amazon.opentelemetry.distro.sampler._clock import _Clock from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler -from amazon.opentelemetry.distro.sampler._rule_cache import _RuleCache +from amazon.opentelemetry.distro.sampler._rule_cache import DEFAULT_TARGET_POLLING_INTERVAL_SECONDS, _RuleCache from opentelemetry.context import Context from opentelemetry.sdk.resources import Resource -from opentelemetry.sdk.trace.sampling import Sampler, SamplingResult +from opentelemetry.sdk.trace.sampling import ParentBased, Sampler, SamplingResult from opentelemetry.trace import Link, SpanKind from opentelemetry.trace.span import TraceState from opentelemetry.util.types import Attributes @@ -21,7 +21,6 @@ _logger = getLogger(__name__) DEFAULT_RULES_POLLING_INTERVAL_SECONDS = 300 -DEFAULT_TARGET_POLLING_INTERVAL_SECONDS = 10 DEFAULT_SAMPLING_PROXY_ENDPOINT = "http://127.0.0.1:2000" @@ -36,30 +35,36 @@ class AwsXRayRemoteSampler(Sampler): log_level: custom log level configuration for remote sampler (Optional) """ - __resource: Resource - __polling_interval: int - __xray_client: _AwsXRaySamplingClient - def __init__( self, resource: Resource, - endpoint=DEFAULT_SAMPLING_PROXY_ENDPOINT, - polling_interval=DEFAULT_RULES_POLLING_INTERVAL_SECONDS, + endpoint: str = None, + polling_interval: int = None, log_level=None, ): # Override default log level if log_level is not None: _logger.setLevel(log_level) - self.__date_time = datetime + if endpoint is None: + _logger.info("`endpoint` is `None`. Defaulting to %s", DEFAULT_SAMPLING_PROXY_ENDPOINT) + endpoint = DEFAULT_SAMPLING_PROXY_ENDPOINT + if polling_interval is None or polling_interval < 10: + _logger.info( + "`polling_interval` is `None` or too small. Defaulting to %s", DEFAULT_RULES_POLLING_INTERVAL_SECONDS + ) + polling_interval = DEFAULT_RULES_POLLING_INTERVAL_SECONDS + + self.__client_id = self.__generate_client_id() + self._clock = _Clock() self.__xray_client = _AwsXRaySamplingClient(endpoint, log_level=log_level) - self.__rule_polling_jitter = random.uniform(0.0, 5.0) - self.__polling_interval = polling_interval - self.__fallback_sampler = _FallbackSampler() + self.__fallback_sampler = ParentBased(_FallbackSampler(self._clock)) - # TODO add client id + self.__polling_interval = polling_interval + self.__target_polling_interval = DEFAULT_TARGET_POLLING_INTERVAL_SECONDS + self.__rule_polling_jitter = random.uniform(0.0, 5.0) + self.__target_polling_jitter = random.uniform(0.0, 0.1) - # pylint: disable=unused-private-member if resource is not None: self.__resource = resource else: @@ -68,14 +73,19 @@ def __init__( self.__rule_cache_lock = Lock() self.__rule_cache = _RuleCache( - self.__resource, self.__fallback_sampler, self.__date_time, self.__rule_cache_lock + self.__resource, self.__fallback_sampler, self.__client_id, self._clock, self.__rule_cache_lock ) # Schedule the next rule poll now # Python Timers only run once, so they need to be recreated for every poll - self._timer = Timer(0, self.__start_sampling_rule_poller) - self._timer.daemon = True # Ensures that when the main thread exits, the Timer threads are killed - self._timer.start() + self._rules_timer = Timer(0, self.__start_sampling_rule_poller) + self._rules_timer.daemon = True # Ensures that when the main thread exits, the Timer threads are killed + self._rules_timer.start() + + # set up the target poller to go off once after the default interval. Subsequent polls may use new intervals. + self._targets_timer = Timer(DEFAULT_TARGET_POLLING_INTERVAL_SECONDS, self.__start_sampling_target_poller) + self._targets_timer.daemon = True # Ensures that when the main thread exits, the Timer threads are killed + self._targets_timer.start() # pylint: disable=no-self-use @override @@ -89,7 +99,6 @@ def should_sample( links: Sequence[Link] = None, trace_state: TraceState = None, ) -> SamplingResult: - if self.__rule_cache.expired(): _logger.debug("Rule cache is expired so using fallback sampling strategy") return self.__fallback_sampler.should_sample( @@ -113,6 +122,33 @@ def __get_and_update_sampling_rules(self) -> None: def __start_sampling_rule_poller(self) -> None: self.__get_and_update_sampling_rules() # Schedule the next sampling rule poll - self._timer = Timer(self.__polling_interval + self.__rule_polling_jitter, self.__start_sampling_rule_poller) - self._timer.daemon = True - self._timer.start() + self._rules_timer = Timer( + self.__polling_interval + self.__rule_polling_jitter, self.__start_sampling_rule_poller + ) + self._rules_timer.daemon = True + self._rules_timer.start() + + def __get_and_update_sampling_targets(self) -> None: + all_statistics = self.__rule_cache.get_all_statistics() + sampling_targets_response = self.__xray_client.get_sampling_targets(all_statistics) + refresh_rules, min_polling_interval = self.__rule_cache.update_sampling_targets(sampling_targets_response) + if refresh_rules: + self.__get_and_update_sampling_rules() + if min_polling_interval is not None: + self.__target_polling_interval = min_polling_interval + + def __start_sampling_target_poller(self) -> None: + self.__get_and_update_sampling_targets() + # Schedule the next sampling targets poll + self._targets_timer = Timer( + self.__target_polling_interval + self.__target_polling_jitter, self.__start_sampling_target_poller + ) + self._targets_timer.daemon = True + self._targets_timer.start() + + def __generate_client_id(self) -> str: + hex_chars = ["0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f"] + client_id_array = [] + for _ in range(0, 24): + client_id_array.append(random.choice(hex_chars)) + return "".join(client_id_array) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/get-sampling-targets-response-sample.json b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/get-sampling-targets-response-sample.json new file mode 100644 index 000000000..498fe1505 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/get-sampling-targets-response-sample.json @@ -0,0 +1,20 @@ +{ + "LastRuleModification": 1707551387.0, + "SamplingTargetDocuments": [ + { + "FixedRate": 0.10, + "Interval": 10, + "ReservoirQuota": 30, + "ReservoirQuotaTTL": 1707764006.0, + "RuleName": "test" + }, + { + "FixedRate": 0.05, + "Interval": 10, + "ReservoirQuota": 0, + "ReservoirQuotaTTL": 1707764006.0, + "RuleName": "Default" + } + ], + "UnprocessedStatistics": [] +} \ No newline at end of file diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/test-remote-sampler_sampling-rules-response-sample.json b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/test-remote-sampler_sampling-rules-response-sample.json new file mode 100644 index 000000000..a5c0d2cb5 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/test-remote-sampler_sampling-rules-response-sample.json @@ -0,0 +1,45 @@ +{ + "NextToken": null, + "SamplingRuleRecords": [ + { + "CreatedAt": 1.676038494E9, + "ModifiedAt": 1.676038494E9, + "SamplingRule": { + "Attributes": {}, + "FixedRate": 1.0, + "HTTPMethod": "*", + "Host": "*", + "Priority": 10000, + "ReservoirSize": 0, + "ResourceARN": "*", + "RuleARN": "arn:aws:xray:us-east-1:999999999999:sampling-rule/Default", + "RuleName": "Default", + "ServiceName": "*", + "ServiceType": "*", + "URLPath": "*", + "Version": 1 + } + }, + { + "CreatedAt": 1.67799933E9, + "ModifiedAt": 1.67799933E9, + "SamplingRule": { + "Attributes": { + "abc": "1234" + }, + "FixedRate": 0, + "HTTPMethod": "*", + "Host": "*", + "Priority": 20, + "ReservoirSize": 0, + "ResourceARN": "*", + "RuleARN": "arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + "RuleName": "test", + "ServiceName": "*", + "ServiceType": "*", + "URLPath": "*", + "Version": 1 + } + } + ] +} \ No newline at end of file diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/test-remote-sampler_sampling-targets-response-sample.json b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/test-remote-sampler_sampling-targets-response-sample.json new file mode 100644 index 000000000..244bf0d06 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/data/test-remote-sampler_sampling-targets-response-sample.json @@ -0,0 +1,20 @@ +{ + "LastRuleModification": 1707551387.0, + "SamplingTargetDocuments": [ + { + "FixedRate": 0.0, + "Interval": 100000, + "ReservoirQuota": 100000, + "ReservoirQuotaTTL": 9999999999.0, + "RuleName": "test" + }, + { + "FixedRate": 0.0, + "Interval": 1000, + "ReservoirQuota": 100, + "ReservoirQuotaTTL": 9999999999.0, + "RuleName": "Default" + } + ], + "UnprocessedStatistics": [] +} \ No newline at end of file diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/mock_clock.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/mock_clock.py new file mode 100644 index 000000000..4e941c7d7 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/mock_clock.py @@ -0,0 +1,21 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +import datetime + +from amazon.opentelemetry.distro.sampler._clock import _Clock + + +class MockClock(_Clock): + def __init__(self, dt: datetime.datetime = datetime.datetime.now()): + self.time_now = dt + super() + + def now(self) -> datetime.datetime: + return self.time_now + + def add_time(self, seconds: float) -> None: + self.time_now += self.time_delta(seconds) + + def set_time(self, dt: datetime.datetime) -> None: + self.time_now = dt diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py index 17d0d5f97..b4fbd7eba 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_remote_sampler.py @@ -1,25 +1,67 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import json +import os +import threading +import time from logging import DEBUG from unittest import TestCase +from unittest.mock import patch + +from mock_clock import MockClock from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.sampling import Decision + +TEST_DIR = os.path.dirname(os.path.realpath(__file__)) +DATA_DIR = os.path.join(TEST_DIR, "data") + + +def create_spans(sampled_array, thread_id, span_attributes, remote_sampler, number_of_spans): + sampled = 0 + for _ in range(0, number_of_spans): + if remote_sampler.should_sample(None, 0, "name", attributes=span_attributes).decision != Decision.DROP: + sampled += 1 + sampled_array[thread_id] = sampled + + +def mocked_requests_get(*args, **kwargs): + class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + if kwargs["url"] == "http://127.0.0.1:2000/GetSamplingRules": + with open(f"{DATA_DIR}/test-remote-sampler_sampling-rules-response-sample.json", encoding="UTF-8") as file: + sample_response = json.load(file) + file.close() + return MockResponse(sample_response, 200) + if kwargs["url"] == "http://127.0.0.1:2000/SamplingTargets": + with open(f"{DATA_DIR}/test-remote-sampler_sampling-targets-response-sample.json", encoding="UTF-8") as file: + sample_response = json.load(file) + file.close() + return MockResponse(sample_response, 200) + return MockResponse(None, 404) class TestAwsXRayRemoteSampler(TestCase): def test_create_remote_sampler_with_empty_resource(self): rs = AwsXRayRemoteSampler(resource=Resource.get_empty()) - self.assertIsNotNone(rs._timer) + self.assertIsNotNone(rs._rules_timer) self.assertEqual(rs._AwsXRayRemoteSampler__polling_interval, 300) self.assertIsNotNone(rs._AwsXRayRemoteSampler__xray_client) self.assertIsNotNone(rs._AwsXRayRemoteSampler__resource) + self.assertTrue(len(rs._AwsXRayRemoteSampler__client_id), 24) def test_create_remote_sampler_with_populated_resource(self): rs = AwsXRayRemoteSampler( resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"}) ) - self.assertIsNotNone(rs._timer) + self.assertIsNotNone(rs._rules_timer) self.assertEqual(rs._AwsXRayRemoteSampler__polling_interval, 300) self.assertIsNotNone(rs._AwsXRayRemoteSampler__xray_client) self.assertIsNotNone(rs._AwsXRayRemoteSampler__resource) @@ -33,7 +75,7 @@ def test_create_remote_sampler_with_all_fields_populated(self): polling_interval=120, log_level=DEBUG, ) - self.assertIsNotNone(rs._timer) + self.assertIsNotNone(rs._rules_timer) self.assertEqual(rs._AwsXRayRemoteSampler__polling_interval, 120) self.assertIsNotNone(rs._AwsXRayRemoteSampler__xray_client) self.assertIsNotNone(rs._AwsXRayRemoteSampler__resource) @@ -43,3 +85,128 @@ def test_create_remote_sampler_with_all_fields_populated(self): ) self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["service.name"], "test-service-name") self.assertEqual(rs._AwsXRayRemoteSampler__resource.attributes["cloud.platform"], "test-cloud-platform") + + @patch("requests.post", side_effect=mocked_requests_get) + @patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", 2) + def test_update_sampling_rules_and_targets_with_pollers_and_should_sample(self, mock_post=None): + rs = AwsXRayRemoteSampler( + resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"}) + ) + self.assertEqual(rs._AwsXRayRemoteSampler__target_polling_interval, 2) + + time.sleep(1.0) + self.assertEqual( + rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0].sampling_rule.RuleName, "test" + ) + self.assertEqual(rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision, Decision.DROP) + + # wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s) + time.sleep(2.0) + self.assertEqual(rs._AwsXRayRemoteSampler__target_polling_interval, 1000) + self.assertEqual( + rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision, Decision.RECORD_AND_SAMPLE + ) + self.assertEqual( + rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision, Decision.RECORD_AND_SAMPLE + ) + self.assertEqual( + rs.should_sample(None, 0, "name", attributes={"abc": "1234"}).decision, Decision.RECORD_AND_SAMPLE + ) + + @patch("requests.post", side_effect=mocked_requests_get) + @patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", 3) + def test_multithreading_with_large_reservoir_with_otel_sdk(self, mock_post=None): + rs = AwsXRayRemoteSampler( + resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"}) + ) + attributes = {"abc": "1234"} + + time.sleep(2.0) + self.assertEqual(rs.should_sample(None, 0, "name", attributes=attributes).decision, Decision.DROP) + + # wait 3 more seconds since targets polling was patched to 2 seconds (rather than 10s) + time.sleep(3.0) + + number_of_spans = 100 + thread_count = 1000 + sampled_array = [] + threads = [] + + for idx in range(0, thread_count): + sampled_array.append(0) + threads.append( + threading.Thread( + target=create_spans, + name="thread_" + str(idx), + daemon=True, + args=(sampled_array, idx, attributes, rs, number_of_spans), + ) + ) + threads[idx].start() + sum_sampled = 0 + + for idx in range(0, thread_count): + threads[idx].join() + sum_sampled += sampled_array[idx] + + test_rule_applier = rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[0] + self.assertEqual( + test_rule_applier._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, + 100000, + ) + self.assertEqual(sum_sampled, 100000) + + # pylint: disable=no-member + @patch("requests.post", side_effect=mocked_requests_get) + @patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler.DEFAULT_TARGET_POLLING_INTERVAL_SECONDS", 2) + @patch("amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler._Clock", MockClock) + def test_multithreading_with_some_reservoir_with_otel_sdk(self, mock_post=None): + rs = AwsXRayRemoteSampler( + resource=Resource.create({"service.name": "test-service-name", "cloud.platform": "test-cloud-platform"}) + ) + attributes = {"abc": "non-matching attribute value, use default rule"} + + # Using normal clock, finishing all thread jobs will take more than a second, + # which will eat up more than 1 second of reservoir. Using MockClock we can freeze time + # and pretend all thread jobs start and end at the exact same time, + # assume and test exactly 1 second of reservoir (100 quota) only + mock_clock: MockClock = rs._clock + + time.sleep(1.0) + mock_clock.add_time(1.0) + self.assertEqual(mock_clock.now(), rs._clock.now()) + self.assertEqual(rs.should_sample(None, 0, "name", attributes=attributes).decision, Decision.RECORD_AND_SAMPLE) + + # wait 2 more seconds since targets polling was patched to 2 seconds (rather than 10s) + time.sleep(2.0) + mock_clock.add_time(2.0) + self.assertEqual(mock_clock.now(), rs._clock.now()) + + number_of_spans = 100 + thread_count = 1000 + sampled_array = [] + threads = [] + + for idx in range(0, thread_count): + sampled_array.append(0) + threads.append( + threading.Thread( + target=create_spans, + name="thread_" + str(idx), + daemon=True, + args=(sampled_array, idx, attributes, rs, number_of_spans), + ) + ) + threads[idx].start() + + sum_sampled = 0 + for idx in range(0, thread_count): + threads[idx].join() + sum_sampled += sampled_array[idx] + + default_rule_applier = rs._AwsXRayRemoteSampler__rule_cache._RuleCache__rule_appliers[1] + self.assertEqual( + default_rule_applier._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, + 100, + ) + self.assertEqual(sum_sampled, 100) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py index 9affdc28d..bf1ecdde1 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_aws_xray_sampling_client.py @@ -103,3 +103,32 @@ def validate_match_sampling_rules_properties_with_records(self, sampling_rules, self.assertEqual(sampling_rule.URLPath, sampling_record["SamplingRule"]["URLPath"]) self.assertIsNotNone(sampling_rule.Version) self.assertEqual(sampling_rule.Version, sampling_record["SamplingRule"]["Version"]) + + @patch("requests.post") + def test_get_sampling_targets(self, mock_post=None): + with open(f"{DATA_DIR}/get-sampling-targets-response-sample.json", encoding="UTF-8") as file: + sample_response = json.load(file) + mock_post.return_value.configure_mock(**{"json.return_value": sample_response}) + file.close() + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + sampling_targets_response = client.get_sampling_targets(statistics=[]) + self.assertEqual(len(sampling_targets_response.SamplingTargetDocuments), 2) + self.assertEqual(len(sampling_targets_response.UnprocessedStatistics), 0) + self.assertEqual(sampling_targets_response.LastRuleModification, 1707551387.0) + + @patch("requests.post") + def test_get_invalid_sampling_targets(self, mock_post=None): + mock_post.return_value.configure_mock( + **{ + "json.return_value": { + "LastRuleModification": None, + "SamplingTargetDocuments": None, + "UnprocessedStatistics": None, + } + } + ) + client = _AwsXRaySamplingClient("http://127.0.0.1:2000") + sampling_targets_response = client.get_sampling_targets(statistics=[]) + self.assertEqual(sampling_targets_response.SamplingTargetDocuments, []) + self.assertEqual(sampling_targets_response.UnprocessedStatistics, []) + self.assertEqual(sampling_targets_response.LastRuleModification, 0.0) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_clock.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_clock.py new file mode 100644 index 000000000..fc7740799 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_clock.py @@ -0,0 +1,17 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from unittest import TestCase + +from amazon.opentelemetry.distro.sampler._clock import _Clock + + +class TestClock(TestCase): + def test_from_timestamp(self): + pass + + def test_time_delta(self): + clock = _Clock() + dt = clock.from_timestamp(1707551387.0) + delta = clock.time_delta(3600) + new_dt = dt + delta + self.assertTrue(new_dt.timestamp() - dt.timestamp() == 3600) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_fallback_sampler.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_fallback_sampler.py new file mode 100644 index 000000000..44c5d7891 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_fallback_sampler.py @@ -0,0 +1,86 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import datetime +from unittest import TestCase + +from mock_clock import MockClock + +from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler +from opentelemetry.sdk.trace.sampling import ALWAYS_OFF, Decision + + +class TestRateLimitingSampler(TestCase): + # pylint: disable=too-many-branches + def test_should_sample(self): + time_now = datetime.datetime.fromtimestamp(1707551387.0) + clock = MockClock(time_now) + sampler = _FallbackSampler(clock) + # Ignore testing TraceIdRatioBased + sampler._FallbackSampler__fixed_rate_sampler = ALWAYS_OFF + + sampler.should_sample(None, 1234, "name") + + # Essentially the same tests as test_rate_limiter.py + + # 0 seconds passed, 0 quota available + sampled = 0 + for _ in range(0, 30): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 0) + + # 0.4 seconds passed, 0.4 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 0) + + # 0.8 seconds passed, 0.8 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 0) + + # 1.2 seconds passed, 1 quota consumed, 0 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 1) + + # 1.6 seconds passed, 0.4 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 0) + + # 2.0 seconds passed, 0.8 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 0) + + # 2.4 seconds passed, one more quota consumed, 0 quota available + sampled = 0 + clock.add_time(0.4) + for _ in range(0, 30): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 1) + + # 30 seconds passed, only one quota can be consumed + sampled = 0 + clock.add_time(100) + for _ in range(0, 30): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 1) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rate_limiter.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rate_limiter.py new file mode 100644 index 000000000..e146d53c9 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rate_limiter.py @@ -0,0 +1,35 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import datetime +from unittest import TestCase + +from mock_clock import MockClock + +from amazon.opentelemetry.distro.sampler._rate_limiter import _RateLimiter + + +class TestRateLimiter(TestCase): + def test_try_spend(self): + time_now = datetime.datetime.fromtimestamp(1707551387.0) + clock = MockClock(time_now) + rate_limiter = _RateLimiter(1, 30, clock) + + spent = 0 + for _ in range(0, 100): + if rate_limiter.try_spend(1): + spent += 1 + self.assertEqual(spent, 0) + + spent = 0 + clock.add_time(0.5) + for _ in range(0, 100): + if rate_limiter.try_spend(1): + spent += 1 + self.assertEqual(spent, 15) + + spent = 0 + clock.add_time(1000) + for _ in range(0, 100): + if rate_limiter.try_spend(1): + spent += 1 + self.assertEqual(spent, 30) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rate_limiting_sampler.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rate_limiting_sampler.py new file mode 100644 index 000000000..c77b2bcba --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rate_limiting_sampler.py @@ -0,0 +1,83 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import datetime +from unittest import TestCase + +from mock_clock import MockClock + +from amazon.opentelemetry.distro.sampler._rate_limiting_sampler import _RateLimitingSampler +from opentelemetry.sdk.trace.sampling import Decision + + +class TestRateLimitingSampler(TestCase): + def test_should_sample(self): + time_now = datetime.datetime.fromtimestamp(1707551387.0) + clock = MockClock(time_now) + sampler = _RateLimitingSampler(30, clock) + + # Essentially the same tests as test_rate_limiter.py + sampled = 0 + for _ in range(0, 100): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 0) + + sampled = 0 + clock.add_time(0.5) + for _ in range(0, 100): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 15) + + sampled = 0 + clock.add_time(1.0) + for _ in range(0, 100): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 30) + + sampled = 0 + clock.add_time(2.5) + for _ in range(0, 100): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 30) + + sampled = 0 + clock.add_time(1000) + for _ in range(0, 100): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 30) + + def test_should_sample_with_quota_of_one(self): + time_now = datetime.datetime.fromtimestamp(1707551387.0) + clock = MockClock(time_now) + sampler = _RateLimitingSampler(1, clock) + + sampled = 0 + for _ in range(0, 50): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 0) + + sampled = 0 + clock.add_time(0.5) + for _ in range(0, 50): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 0) + + sampled = 0 + clock.add_time(0.5) + for _ in range(0, 50): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 1) + + sampled = 0 + clock.add_time(1000) + for _ in range(0, 50): + if sampler.should_sample(None, 1234, "name").decision != Decision.DROP: + sampled += 1 + self.assertEqual(sampled, 1) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule_cache.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule_cache.py index 3adeb7efe..a70b3beee 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule_cache.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_rule_cache.py @@ -4,15 +4,23 @@ from threading import Lock from unittest import TestCase +from mock_clock import MockClock + +from amazon.opentelemetry.distro.sampler._clock import _Clock from amazon.opentelemetry.distro.sampler._rule_cache import CACHE_TTL_SECONDS, _RuleCache from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule +from amazon.opentelemetry.distro.sampler._sampling_rule_applier import _SamplingRuleApplier +from amazon.opentelemetry.distro.sampler._sampling_statistics_document import _SamplingStatisticsDocument +from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTargetResponse from opentelemetry.sdk.resources import Resource +CLIENT_ID = "12345678901234567890abcd" + # pylint: disable=no-member class TestRuleCache(TestCase): def test_cache_update_rules_and_sorts_rules(self): - cache = _RuleCache(None, None, datetime, Lock()) + cache = _RuleCache(None, None, CLIENT_ID, _Clock(), Lock()) self.assertTrue(len(cache._RuleCache__rule_appliers) == 0) rule1 = _SamplingRule(Priority=200, RuleName="only_one_rule", Version=1) @@ -39,7 +47,7 @@ def test_cache_update_rules_and_sorts_rules(self): def test_rule_cache_expiration_logic(self): dt = datetime - cache = _RuleCache(None, Resource.get_empty(), dt, Lock()) + cache = _RuleCache(None, Resource.get_empty(), CLIENT_ID, _Clock(), Lock()) self.assertFalse(cache.expired()) cache._last_modified = dt.datetime.now() - dt.timedelta(seconds=CACHE_TTL_SECONDS - 5) self.assertFalse(cache.expired()) @@ -47,8 +55,7 @@ def test_rule_cache_expiration_logic(self): self.assertTrue(cache.expired()) def test_update_cache_with_only_one_rule_changed(self): - dt = datetime - cache = _RuleCache(None, Resource.get_empty(), dt, Lock()) + cache = _RuleCache(None, Resource.get_empty(), CLIENT_ID, _Clock(), Lock()) rule1 = _SamplingRule(Priority=1, RuleName="abcdef", Version=1) rule2 = _SamplingRule(Priority=10, RuleName="ab", Version=1) rule3 = _SamplingRule(Priority=100, RuleName="Abc", Version=1) @@ -72,7 +79,7 @@ def test_update_cache_with_only_one_rule_changed(self): self.assertTrue(cache_rules_copy[2] is not cache._RuleCache__rule_appliers[1]) def test_update_rules_removes_older_rule(self): - cache = _RuleCache(None, None, datetime, Lock()) + cache = _RuleCache(None, None, CLIENT_ID, _Clock(), Lock()) self.assertTrue(len(cache._RuleCache__rule_appliers) == 0) rule1 = _SamplingRule(Priority=200, RuleName="first_rule", Version=1) @@ -86,3 +93,142 @@ def test_update_rules_removes_older_rule(self): cache.update_sampling_rules(rules) self.assertTrue(len(cache._RuleCache__rule_appliers) == 1) self.assertEqual(cache._RuleCache__rule_appliers[0].sampling_rule.RuleName, "second_rule") + + def test_update_sampling_targets(self): + sampling_rule_1 = _SamplingRule( + Attributes={}, + FixedRate=0.05, + HTTPMethod="*", + Host="*", + Priority=10000, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/default", + RuleName="default", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + sampling_rule_2 = _SamplingRule( + Attributes={}, + FixedRate=0.20, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=10, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + time_now = datetime.datetime.fromtimestamp(1707551387.0) + mock_clock = MockClock(time_now) + + rule_cache = _RuleCache(Resource.get_empty(), None, "", mock_clock, Lock()) + rule_cache.update_sampling_rules([sampling_rule_1, sampling_rule_2]) + + # quota should be 1 because of borrowing=true until targets are updated + rule_applier_0 = rule_cache._RuleCache__rule_appliers[0] + self.assertEqual( + rule_applier_0._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, 1 + ) + self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._root._rate, sampling_rule_2.FixedRate) + + rule_applier_1 = rule_cache._RuleCache__rule_appliers[1] + self.assertEqual( + rule_applier_1._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, 1 + ) + self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._root._rate, sampling_rule_1.FixedRate) + + target_1 = { + "FixedRate": 0.05, + "Interval": 15, + "ReservoirQuota": 1, + "ReservoirQuotaTTL": mock_clock.now().timestamp() + 10, + "RuleName": "default", + } + target_2 = { + "FixedRate": 0.15, + "Interval": 12, + "ReservoirQuota": 5, + "ReservoirQuotaTTL": mock_clock.now().timestamp() + 10, + "RuleName": "test", + } + target_3 = { + "FixedRate": 0.15, + "Interval": 3, + "ReservoirQuota": 5, + "ReservoirQuotaTTL": mock_clock.now().timestamp() + 10, + "RuleName": "associated rule does not exist", + } + target_response = _SamplingTargetResponse(mock_clock.now().timestamp() - 10, [target_1, target_2, target_3], []) + refresh_rules, min_polling_interval = rule_cache.update_sampling_targets(target_response) + self.assertFalse(refresh_rules) + # target_3 Interval is ignored since it's not associated with a Rule Applier + self.assertEqual(min_polling_interval, target_2["Interval"]) + + # still only 2 rule appliers should exist if for some reason 3 targets are obtained + self.assertEqual(len(rule_cache._RuleCache__rule_appliers), 2) + + # borrowing=false, use quota from targets + rule_applier_0 = rule_cache._RuleCache__rule_appliers[0] + self.assertEqual( + rule_applier_0._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, + target_2["ReservoirQuota"], + ) + self.assertEqual(rule_applier_0._SamplingRuleApplier__fixed_rate_sampler._root._rate, target_2["FixedRate"]) + + rule_applier_1 = rule_cache._RuleCache__rule_appliers[1] + self.assertEqual( + rule_applier_1._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, + target_1["ReservoirQuota"], + ) + self.assertEqual(rule_applier_1._SamplingRuleApplier__fixed_rate_sampler._root._rate, target_1["FixedRate"]) + + # Test target response modified after Rule cache's last modified date + target_response.LastRuleModification = mock_clock.now().timestamp() + 1 + refresh_rules, _ = rule_cache.update_sampling_targets(target_response) + self.assertTrue(refresh_rules) + + def test_get_all_statistics(self): + time_now = datetime.datetime.fromtimestamp(1707551387.0) + mock_clock = MockClock(time_now) + rule_applier_1 = _SamplingRuleApplier(_SamplingRule(RuleName="test"), CLIENT_ID, mock_clock) + rule_applier_2 = _SamplingRuleApplier(_SamplingRule(RuleName="default"), CLIENT_ID, mock_clock) + + rule_applier_1._SamplingRuleApplier__statistics = _SamplingStatisticsDocument(CLIENT_ID, "test", 4, 2, 2) + rule_applier_2._SamplingRuleApplier__statistics = _SamplingStatisticsDocument(CLIENT_ID, "default", 5, 5, 5) + + rule_cache = _RuleCache(Resource.get_empty(), None, "", mock_clock, Lock()) + rule_cache._RuleCache__rule_appliers = [rule_applier_1, rule_applier_2] + + mock_clock.add_time(10) + statistics = rule_cache.get_all_statistics() + + self.assertEqual( + statistics, + [ + { + "ClientID": CLIENT_ID, + "RuleName": "test", + "Timestamp": mock_clock.now().timestamp(), + "RequestCount": 4, + "BorrowCount": 2, + "SampleCount": 2, + }, + { + "ClientID": CLIENT_ID, + "RuleName": "default", + "Timestamp": mock_clock.now().timestamp(), + "RequestCount": 5, + "BorrowCount": 5, + "SampleCount": 5, + }, + ], + ) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_rule_applier.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_rule_applier.py index 5a4dc383e..85123843d 100644 --- a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_rule_applier.py +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_rule_applier.py @@ -1,12 +1,20 @@ # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. # SPDX-License-Identifier: Apache-2.0 +import datetime import json import os from unittest import TestCase +from unittest.mock import patch +from mock_clock import MockClock + +from amazon.opentelemetry.distro.sampler._clock import _Clock +from amazon.opentelemetry.distro.sampler._rate_limiting_sampler import _RateLimitingSampler from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule from amazon.opentelemetry.distro.sampler._sampling_rule_applier import _SamplingRuleApplier +from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTarget from opentelemetry.sdk.resources import Resource +from opentelemetry.sdk.trace.sampling import Decision, SamplingResult, TraceIdRatioBased from opentelemetry.semconv.resource import ResourceAttributes from opentelemetry.semconv.trace import SpanAttributes from opentelemetry.util.types import Attributes @@ -14,7 +22,10 @@ TEST_DIR = os.path.dirname(os.path.realpath(__file__)) DATA_DIR = os.path.join(TEST_DIR, "data") +CLIENT_ID = "12345678901234567890abcd" + +# pylint: disable=no-member class TestSamplingRuleApplier(TestCase): def test_applier_attribute_matching_from_xray_response(self): default_rule = None @@ -40,7 +51,18 @@ def test_applier_attribute_matching_from_xray_response(self): "abc": "1234", } - rule_applier = _SamplingRuleApplier(default_rule) + rule_applier = _SamplingRuleApplier(default_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(res, attr)) + + # Test again using deprecated Span Attributes + attr: Attributes = { + SpanAttributes.HTTP_TARGET: "target", + SpanAttributes.HTTP_METHOD: "method", + SpanAttributes.HTTP_URL: "url", + SpanAttributes.HTTP_HOST: "host", + "foo": "bar", + "abc": "1234", + } self.assertTrue(rule_applier.matches(res, attr)) def test_applier_matches_with_all_attributes(self): @@ -51,13 +73,13 @@ def test_applier_matches_with_all_attributes(self): Host="localhost", Priority=20, ReservoirSize=1, - # ResourceARN can only be "*" + # Note that ResourceARN is usually only able to be "*" # See: https://docs.aws.amazon.com/xray/latest/devguide/xray-console-sampling.html#xray-console-sampling-options # noqa: E501 - ResourceARN="*", + ResourceARN="arn:aws:lambda:us-west-2:123456789012:function:my-function", RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", RuleName="test", ServiceName="myServiceName", - ServiceType="AWS::EKS::Container", + ServiceType="AWS::Lambda::Function", URLPath="/helloworld", Version=1, ) @@ -65,19 +87,37 @@ def test_applier_matches_with_all_attributes(self): attributes: Attributes = { "server.address": "localhost", SpanAttributes.HTTP_REQUEST_METHOD: "GET", + SpanAttributes.CLOUD_RESOURCE_ID: "arn:aws:lambda:us-west-2:123456789012:function:my-function", "url.full": "http://127.0.0.1:5000/helloworld", "abc": "123", "def": "456", "ghi": "789", + # Test that deprecated attributes are not used in matching when above new attributes are set + "http.host": "deprecated and will not be used in matching", + SpanAttributes.HTTP_METHOD: "deprecated and will not be used in matching", + "faas.id": "deprecated and will not be used in matching", + "http.url": "deprecated and will not be used in matching", } resource_attr: Resource = { ResourceAttributes.SERVICE_NAME: "myServiceName", - ResourceAttributes.CLOUD_PLATFORM: "aws_eks", + ResourceAttributes.CLOUD_PLATFORM: "aws_lambda", # CloudPlatformValues.AWS_LAMBDA.value } resource = Resource.create(attributes=resource_attr) - rule_applier = _SamplingRuleApplier(sampling_rule) + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(resource, attributes)) + + # Test using deprecated Span Attributes + attributes: Attributes = { + "http.host": "localhost", + SpanAttributes.HTTP_METHOD: "GET", + "faas.id": "arn:aws:lambda:us-west-2:123456789012:function:my-function", + "http.url": "http://127.0.0.1:5000/helloworld", + "abc": "123", + "def": "456", + "ghi": "789", + } self.assertTrue(rule_applier.matches(resource, attributes)) def test_applier_wild_card_attributes_matches_span_attributes(self): @@ -119,7 +159,7 @@ def test_applier_wild_card_attributes_matches_span_attributes(self): "attr9": "Bye.World", } - rule_applier = _SamplingRuleApplier(sampling_rule) + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) self.assertTrue(rule_applier.matches(Resource.get_empty(), attributes)) def test_applier_wild_card_attributes_matches_http_span_attributes(self): @@ -145,7 +185,16 @@ def test_applier_wild_card_attributes_matches_http_span_attributes(self): SpanAttributes.URL_FULL: "http://127.0.0.1:5000/helloworld", } - rule_applier = _SamplingRuleApplier(sampling_rule) + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(Resource.get_empty(), attributes)) + + # Test using deprecated Span Attributes + attributes: Attributes = { + SpanAttributes.HTTP_HOST: "localhost", + SpanAttributes.HTTP_METHOD: "GET", + SpanAttributes.HTTP_URL: "http://127.0.0.1:5000/helloworld", + } + self.assertTrue(rule_applier.matches(Resource.get_empty(), attributes)) def test_applier_wild_card_attributes_matches_with_empty_attributes(self): @@ -172,7 +221,7 @@ def test_applier_wild_card_attributes_matches_with_empty_attributes(self): } resource = Resource.create(attributes=resource_attr) - rule_applier = _SamplingRuleApplier(sampling_rule) + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) self.assertTrue(rule_applier.matches(resource, attributes)) self.assertTrue(rule_applier.matches(resource, None)) self.assertTrue(rule_applier.matches(Resource.get_empty(), attributes)) @@ -204,7 +253,7 @@ def test_applier_does_not_match_without_http_target(self): } resource = Resource.create(attributes=resource_attr) - rule_applier = _SamplingRuleApplier(sampling_rule) + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) self.assertFalse(rule_applier.matches(resource, attributes)) def test_applier_matches_with_http_target(self): @@ -231,7 +280,11 @@ def test_applier_matches_with_http_target(self): } resource = Resource.create(attributes=resource_attr) - rule_applier = _SamplingRuleApplier(sampling_rule) + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(resource, attributes)) + + # Test again using deprecated Span Attributes + attributes: Attributes = {SpanAttributes.HTTP_TARGET: "/helloworld"} self.assertTrue(rule_applier.matches(resource, attributes)) def test_applier_matches_with_span_attributes(self): @@ -252,7 +305,7 @@ def test_applier_matches_with_span_attributes(self): ) attributes: Attributes = { - "http.host": "localhost", + "server.address": "localhost", SpanAttributes.HTTP_REQUEST_METHOD: "GET", "url.full": "http://127.0.0.1:5000/helloworld", "abc": "123", @@ -266,7 +319,18 @@ def test_applier_matches_with_span_attributes(self): } resource = Resource.create(attributes=resource_attr) - rule_applier = _SamplingRuleApplier(sampling_rule) + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) + self.assertTrue(rule_applier.matches(resource, attributes)) + + # Test again using deprecated Span Attributes + attributes: Attributes = { + "http.host": "localhost", + SpanAttributes.HTTP_METHOD: "GET", + "http.url": "http://127.0.0.1:5000/helloworld", + "abc": "123", + "def": "456", + "ghi": "789", + } self.assertTrue(rule_applier.matches(resource, attributes)) def test_applier_does_not_match_with_less_span_attributes(self): @@ -299,5 +363,130 @@ def test_applier_does_not_match_with_less_span_attributes(self): } resource = Resource.create(attributes=resource_attr) - rule_applier = _SamplingRuleApplier(sampling_rule) + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, _Clock()) self.assertFalse(rule_applier.matches(resource, attributes)) + + def test_update_sampling_applier(self): + sampling_rule = _SamplingRule( + Attributes={}, + FixedRate=0.11, + HTTPMethod="*", + Host="*", + Priority=20, + ReservoirSize=1, + ResourceARN="*", + RuleARN="arn:aws:xray:us-east-1:999999999999:sampling-rule/test", + RuleName="test", + ServiceName="*", + ServiceType="*", + URLPath="*", + Version=1, + ) + + time_now = datetime.datetime.fromtimestamp(1707551387.0) + mock_clock = MockClock(time_now) + + rule_applier = _SamplingRuleApplier(sampling_rule, CLIENT_ID, mock_clock) + + self.assertEqual(rule_applier._SamplingRuleApplier__fixed_rate_sampler._root._rate, 0.11) + self.assertEqual( + rule_applier._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, 1 + ) + self.assertEqual(rule_applier._SamplingRuleApplier__reservoir_expiry, datetime.datetime.max) + + target = _SamplingTarget( + FixedRate=1.0, Interval=10, ReservoirQuota=30, ReservoirQuotaTTL=1707764006.0, RuleName="test" + ) + # Update rule applier + rule_applier = rule_applier.with_target(target) + + time_now = datetime.datetime.fromtimestamp(target.ReservoirQuotaTTL) + mock_clock.set_time(time_now) + + self.assertEqual(rule_applier._SamplingRuleApplier__fixed_rate_sampler._root._rate, 1.0) + self.assertEqual( + rule_applier._SamplingRuleApplier__reservoir_sampler._root._RateLimitingSampler__reservoir._quota, 30 + ) + self.assertEqual(rule_applier._SamplingRuleApplier__reservoir_expiry, mock_clock.now()) + + @staticmethod + def fake_reservoir_do_sample(*args, **kwargs): + return SamplingResult(decision=Decision.RECORD_AND_SAMPLE, attributes=None, trace_state=None) + + @staticmethod + def fake_ratio_do_sample(*args, **kwargs): + return SamplingResult(decision=Decision.RECORD_AND_SAMPLE, attributes=None, trace_state=None) + + @staticmethod + def fake_ratio_do_not_sample(*args, **kwargs): + return SamplingResult(decision=Decision.RECORD_AND_SAMPLE, attributes=None, trace_state=None) + + @patch.object(TraceIdRatioBased, "should_sample", fake_ratio_do_sample) + @patch.object(_RateLimitingSampler, "should_sample", fake_reservoir_do_sample) + def test_populate_and_get_then_reset_statistics(self): + mock_clock = MockClock() + rule_applier = _SamplingRuleApplier(_SamplingRule(RuleName="test", ReservoirSize=10), CLIENT_ID, mock_clock) + rule_applier.should_sample(None, 0, "name") + rule_applier.should_sample(None, 0, "name") + rule_applier.should_sample(None, 0, "name") + + statistics = rule_applier.get_then_reset_statistics() + + self.assertEqual(statistics["ClientID"], CLIENT_ID) + self.assertEqual(statistics["RuleName"], "test") + self.assertEqual(statistics["Timestamp"], mock_clock.now().timestamp()) + self.assertEqual(statistics["RequestCount"], 3) + self.assertEqual(statistics["BorrowCount"], 3) + self.assertEqual(statistics["SampleCount"], 3) + self.assertEqual(rule_applier._SamplingRuleApplier__statistics.RequestCount, 0) + self.assertEqual(rule_applier._SamplingRuleApplier__statistics.BorrowCount, 0) + self.assertEqual(rule_applier._SamplingRuleApplier__statistics.SampleCount, 0) + + def test_should_sample_logic_from_reservoir(self): + reservoir_size = 10 + time_now = datetime.datetime.fromtimestamp(1707551387.0) + mock_clock = MockClock(time_now) + rule_applier = _SamplingRuleApplier( + _SamplingRule(RuleName="test", ReservoirSize=reservoir_size, FixedRate=0.0), CLIENT_ID, mock_clock + ) + + mock_clock.add_time(seconds=2.0) + sampled_count = 0 + for _ in range(0, reservoir_size + 10): + if rule_applier.should_sample(None, 0, "name").decision != Decision.DROP: + sampled_count += 1 + self.assertEqual(sampled_count, 1) + # borrow means only 1 sampled + + target = _SamplingTarget( + FixedRate=0.0, + Interval=10, + ReservoirQuota=10, + ReservoirQuotaTTL=mock_clock.now().timestamp() + 10, + RuleName="test", + ) + rule_applier = rule_applier.with_target(target) + + # Use only 100% of quota (10 out of 10), even if 2 seconds have passed + mock_clock.add_time(seconds=2.0) + sampled_count = 0 + for _ in range(0, reservoir_size + 10): + if rule_applier.should_sample(None, 0, "name").decision != Decision.DROP: + sampled_count += 1 + self.assertEqual(sampled_count, reservoir_size) + + # Use only 50% of quota (5 out of 10) + mock_clock.add_time(seconds=0.5) + sampled_count = 0 + for _ in range(0, reservoir_size + 10): + if rule_applier.should_sample(None, 0, "name").decision != Decision.DROP: + sampled_count += 1 + self.assertEqual(sampled_count, 5) + + # Expired at 10s, do not sample + mock_clock.add_time(seconds=7.5) + sampled_count = 0 + for _ in range(0, reservoir_size + 10): + if rule_applier.should_sample(None, 0, "name").decision != Decision.DROP: + sampled_count += 1 + self.assertEqual(sampled_count, 0) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_statistics_document.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_statistics_document.py new file mode 100644 index 000000000..bf147ff1b --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_statistics_document.py @@ -0,0 +1,34 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +import datetime +from unittest import TestCase + +from mock_clock import MockClock + +from amazon.opentelemetry.distro.sampler._sampling_statistics_document import _SamplingStatisticsDocument + + +class TestSamplingStatisticsDocument(TestCase): + def test_sampling_statistics_document_inputs(self): + statistics = _SamplingStatisticsDocument("", "") + self.assertEqual(statistics.ClientID, "") + self.assertEqual(statistics.RuleName, "") + self.assertEqual(statistics.BorrowCount, 0) + self.assertEqual(statistics.SampleCount, 0) + self.assertEqual(statistics.RequestCount, 0) + + statistics = _SamplingStatisticsDocument("client_id", "rule_name", 1, 2, 3) + self.assertEqual(statistics.ClientID, "client_id") + self.assertEqual(statistics.RuleName, "rule_name") + self.assertEqual(statistics.RequestCount, 1) + self.assertEqual(statistics.BorrowCount, 2) + self.assertEqual(statistics.SampleCount, 3) + + clock = MockClock(datetime.datetime.fromtimestamp(1707551387.0)) + snapshot = statistics.snapshot(clock) + self.assertEqual(snapshot.get("ClientID"), "client_id") + self.assertEqual(snapshot.get("RuleName"), "rule_name") + self.assertEqual(snapshot.get("Timestamp"), 1707551387.0) + self.assertEqual(snapshot.get("RequestCount"), 1) + self.assertEqual(snapshot.get("BorrowCount"), 2) + self.assertEqual(snapshot.get("SampleCount"), 3) diff --git a/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_target.py b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_target.py new file mode 100644 index 000000000..e22b0e0a7 --- /dev/null +++ b/aws-opentelemetry-distro/tests/amazon/opentelemetry/distro/sampler/test_sampling_target.py @@ -0,0 +1,32 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 +from unittest import TestCase + +from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTargetResponse + + +class TestSamplingTarget(TestCase): + def test_sampling_target_response_with_none_inputs(self): + target_response = _SamplingTargetResponse(None, None, None) + self.assertEqual(target_response.LastRuleModification, 0.0) + self.assertEqual(target_response.SamplingTargetDocuments, []) + self.assertEqual(target_response.UnprocessedStatistics, []) + + def test_sampling_target_response_with_invalid_inputs(self): + target_response = _SamplingTargetResponse(1.0, [{}], [{}]) + self.assertEqual(target_response.LastRuleModification, 1.0) + self.assertEqual(len(target_response.SamplingTargetDocuments), 1) + self.assertEqual(target_response.SamplingTargetDocuments[0].FixedRate, 0) + self.assertEqual(target_response.SamplingTargetDocuments[0].Interval, None) + self.assertEqual(target_response.SamplingTargetDocuments[0].ReservoirQuota, None) + self.assertEqual(target_response.SamplingTargetDocuments[0].ReservoirQuotaTTL, None) + self.assertEqual(target_response.SamplingTargetDocuments[0].RuleName, "") + + self.assertEqual(len(target_response.UnprocessedStatistics), 1) + self.assertEqual(target_response.UnprocessedStatistics[0].ErrorCode, "") + self.assertEqual(target_response.UnprocessedStatistics[0].Message, "") + self.assertEqual(target_response.UnprocessedStatistics[0].RuleName, "") + + target_response = _SamplingTargetResponse(1.0, [{"foo": "bar"}], [{"dog": "cat"}]) + self.assertEqual(len(target_response.SamplingTargetDocuments), 0) + self.assertEqual(len(target_response.UnprocessedStatistics), 0)