Skip to content

Commit

Permalink
AWS X-Ray Remote Sampler Part 3 - rate limiter logic and get sampling…
Browse files Browse the repository at this point in the history
… targets (#55)

*Issue #, if available:*
This is the 3rd and final part of the X-Ray Remote Sampler
implementation for Python
[See Part
2](#47)

*Description of changes:*
- Added logic to fetch sampling targets for each sampling rule applier.
- The sampling targets are periodically fetched every 10 seconds by
making the [GetSamplingTargets API call to
X-Ray](https://docs.aws.amazon.com/xray/latest/api/API_GetSamplingTargets.html).
- The targets determine the reservoir quota and the rate at which a
sampling rule applier will sample the requests.
- Each rule applier keeps and updates a sampling statistics document,
which is required in the `GetSamplingTargets` call to determine the next
target.
- Added the rate limiting and fixed rate samplers to be used in each
rule applier.
- Together these samplers determine how many requests to sample every
second and what percentage of additional requests to sample in that
second.
- The FallbackSampler is a combination of the above samplers that samples 1
req/sec and 5% of additional requests in that second.

*Testing:*
Unit Tests and Remote Sampling Testbed


## **Testbed:**
1. Have the X-Ray Daemon running, or an OTel Collector with the X-Ray Proxy
Client set up and running. Ensure the account behind the AWS credentials in
use has only the default rule, with 5% sampling and 1 req/s.

2. Check out this PR branch and install it with `pip3 install
aws-opentelemetry-distro/`

3. Download this Python Sample App:
https://github.com/jj22ee/aws-otel-community/tree/python-sample/centralized-sampling-tests/sample-apps/python-flask
Install python3 requirements.txt
Replace the following:
```
###
### Set sampler HERE
###
```
with:
```
from amazon.opentelemetry.distro.sampler.aws_xray_remote_sampler import AwsXRayRemoteSampler
xray_sampler = AwsXRayRemoteSampler(resource, polling_interval=10)
trace.set_tracer_provider(TracerProvider(sampler=xray_sampler))
```
Run Python Sample app with `python3 app.py`

4. Download this repository:
https://github.com/aws-observability/aws-otel-community/
Within directory `centralized-sampling-tests/` run: `./gradlew
:integration-tests:run`
Check that the tests have passed.

TODO: wire in the remote sampler to the ADOT Python customizer in this
PR or a new PR

By submitting this pull request, I confirm that you can use, modify,
copy, and redistribute this contribution, under the terms of your
choice.
  • Loading branch information
jj22ee authored Feb 16, 2024
1 parent ec3011e commit 453a3d5
Show file tree
Hide file tree
Showing 25 changed files with 1,428 additions and 86 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import requests

from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTargetResponse

_logger = getLogger(__name__)

Expand All @@ -19,6 +20,7 @@ def __init__(self, endpoint: str = None, log_level: str = None):
if endpoint is None:
_logger.error("endpoint must be specified")
self.__get_sampling_rules_endpoint = endpoint + "/GetSamplingRules"
self.__get_sampling_targets_endpoint = endpoint + "/SamplingTargets"

def get_sampling_rules(self) -> [_SamplingRule]:
sampling_rules = []
Expand All @@ -30,12 +32,11 @@ def get_sampling_rules(self) -> [_SamplingRule]:
_logger.error("GetSamplingRules response is None")
return []
sampling_rules_response = xray_response.json()
if "SamplingRuleRecords" not in sampling_rules_response:
if sampling_rules_response is None or "SamplingRuleRecords" not in sampling_rules_response:
_logger.error(
"SamplingRuleRecords is missing in getSamplingRules response: %s", sampling_rules_response
)
return []

sampling_rules_records = sampling_rules_response["SamplingRuleRecords"]
for record in sampling_rules_records:
if "SamplingRule" not in record:
Expand All @@ -47,5 +48,43 @@ def get_sampling_rules(self) -> [_SamplingRule]:
_logger.error("Request error occurred: %s", req_err)
except json.JSONDecodeError as json_err:
_logger.error("Error in decoding JSON response: %s", json_err)
# pylint: disable=broad-exception-caught
except Exception as err:
_logger.error("Error occurred when attempting to fetch rules: %s", err)

return sampling_rules

def get_sampling_targets(self, statistics: [dict]) -> _SamplingTargetResponse:
    """POST the sampling statistics to the sampling-targets endpoint and parse the reply.

    This call is best-effort: a request error, an undecodable JSON body, a
    payload missing the required keys, or any other exception is logged at
    debug level and the empty response is returned instead of raising.

    :param statistics: statistics documents, sent as
        ``SamplingStatisticsDocuments`` in the JSON request body.
    :return: the response built from the JSON payload, or a
        ``_SamplingTargetResponse`` constructed with every field passed as
        ``None`` when the targets could not be fetched.
    """
    # Built up front so that every failure path hands back the same object.
    empty_response = _SamplingTargetResponse(
        LastRuleModification=None, SamplingTargetDocuments=None, UnprocessedStatistics=None
    )
    try:
        xray_response = requests.post(
            url=self.__get_sampling_targets_endpoint,
            headers={"content-type": "application/json"},
            timeout=20,
            json={"SamplingStatisticsDocuments": statistics},
        )
        if xray_response is None:
            _logger.debug("GetSamplingTargets response is None. Unable to update targets.")
            return empty_response

        payload = xray_response.json()
        required_keys = ("SamplingTargetDocuments", "LastRuleModification")
        if payload is None or any(key not in payload for key in required_keys):
            _logger.debug("getSamplingTargets response is invalid. Unable to update targets.")
            return empty_response

        # The payload keys are passed straight through as constructor keyword arguments.
        return _SamplingTargetResponse(**payload)
    except requests.exceptions.RequestException as req_err:
        _logger.debug("Request error occurred: %s", req_err)
    except json.JSONDecodeError as json_err:
        _logger.debug("Error in decoding JSON response: %s", json_err)
    # pylint: disable=broad-exception-caught
    except Exception as err:
        _logger.debug("Error occurred when attempting to fetch targets: %s", err)

    return empty_response
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

import datetime


class _Clock:
    """Small wrapper around ``datetime`` so callers obtain time values from one object."""

    def __init__(self):
        # Keep a reference to the datetime class; now() is routed through it.
        self.__datetime = datetime.datetime

    def now(self) -> datetime.datetime:
        """Return the current time from the stored datetime class (no tz argument is passed)."""
        current = self.__datetime.now()
        return current

    # pylint: disable=no-self-use
    def from_timestamp(self, timestamp: float) -> datetime.datetime:
        """Convert a POSIX timestamp in seconds to a ``datetime`` via ``fromtimestamp``."""
        converted = datetime.datetime.fromtimestamp(timestamp)
        return converted

    def time_delta(self, seconds: float) -> datetime.timedelta:
        """Return a ``timedelta`` spanning the given number of seconds."""
        span = datetime.timedelta(seconds=seconds)
        return span

    def max(self) -> datetime.datetime:
        """Return the largest representable ``datetime``."""
        latest = datetime.datetime.max
        return latest
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,18 @@
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._rate_limiting_sampler import _RateLimitingSampler
from opentelemetry.context import Context
from opentelemetry.sdk.trace.sampling import ALWAYS_ON, Sampler, SamplingResult, TraceIdRatioBased
from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult, TraceIdRatioBased
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes


class _FallbackSampler(Sampler):
def __init__(self):
# TODO: Add Reservoir sampler
# pylint: disable=unused-private-member
def __init__(self, clock: _Clock):
self.__rate_limiting_sampler = _RateLimitingSampler(1, clock)
self.__fixed_rate_sampler = TraceIdRatioBased(0.05)

# pylint: disable=no-self-use
Expand All @@ -26,8 +27,12 @@ def should_sample(
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
# TODO: add reservoir + fixed rate sampling
return ALWAYS_ON.should_sample(
sampling_result = self.__rate_limiting_sampler.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)
if sampling_result.decision is not Decision.DROP:
return sampling_result
return self.__fixed_rate_sampler.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from decimal import Decimal
from threading import Lock

from amazon.opentelemetry.distro.sampler._clock import _Clock


class _RateLimiter:
    """Time-based wallet that allows spending up to ``quota`` units per second.

    The balance is the number of milliseconds elapsed since a stored "floor"
    timestamp, capped at ``MAX_BALANCE_MILLIS``. Spending ``cost`` units
    consumes ``cost / (quota / 1000)`` milliseconds of that balance.
    """

    def __init__(self, max_balance_in_seconds: int, quota: int, clock: _Clock):
        # max_balance_in_seconds is usually 1
        # pylint: disable=invalid-name
        self.MAX_BALANCE_MILLIS = Decimal(max_balance_in_seconds * 1000.0)
        self._clock = clock

        self._quota = Decimal(quota)
        # The floor starts at "now", so the wallet begins empty (ceiling - floor == 0).
        self.__wallet_floor_millis = self.__now_millis()

        self.__lock = Lock()

    def __now_millis(self) -> Decimal:
        """Return the clock's current timestamp in milliseconds as a Decimal."""
        return Decimal(self._clock.now().timestamp() * 1000.0)

    def try_spend(self, cost: float) -> bool:
        """Try to deduct ``cost`` units from the wallet.

        :param cost: number of units to spend.
        :return: ``True`` if the balance covered the cost and was deducted;
            ``False`` if the quota is zero or the balance was insufficient
            (the wallet is left untouched in that case).
        """
        if self._quota == 0:
            return False

        # The quota is non-zero past the guard above, so the division is safe.
        quota_per_millis = self._quota / Decimal(1000.0)
        cost_in_millis = Decimal(cost) / quota_per_millis

        with self.__lock:
            ceiling_millis = self.__now_millis()
            # Balance is elapsed time since the floor, never more than the cap.
            balance_millis = min(ceiling_millis - self.__wallet_floor_millis, self.MAX_BALANCE_MILLIS)

            remaining_millis = balance_millis - cost_in_millis
            if remaining_millis < 0:
                # Not enough balance: no changes to the wallet state.
                return False

            # Move the floor so that ceiling - floor equals the remaining balance.
            self.__wallet_floor_millis = ceiling_millis - remaining_millis
            return True
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._rate_limiter import _RateLimiter
from opentelemetry.context import Context
from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult
from opentelemetry.trace import Link, SpanKind
from opentelemetry.trace.span import TraceState
from opentelemetry.util.types import Attributes


class _RateLimitingSampler(Sampler):
    """Sampler that samples a request only while its rate limiter has balance.

    Decisions are delegated to a ``_RateLimiter`` built with a one-second
    maximum balance and the given ``quota``; requests that cannot be paid
    for are dropped.
    """

    def __init__(self, quota: int, clock: _Clock):
        """
        :param quota: quota handed to the underlying rate limiter (requests per second).
        :param clock: time source handed to the underlying rate limiter.
        """
        self.__quota = quota
        self.__reservoir = _RateLimiter(1, quota, clock)

    def should_sample(
        self,
        parent_context: Optional[Context],
        trace_id: int,
        name: str,
        kind: SpanKind = None,
        attributes: Attributes = None,
        links: Sequence[Link] = None,
        trace_state: TraceState = None,
    ) -> SamplingResult:
        """Return RECORD_AND_SAMPLE if one unit can be spent from the reservoir, else DROP.

        Only ``attributes`` and ``trace_state`` are used; they are passed
        through unchanged into the returned ``SamplingResult``.
        """
        if self.__reservoir.try_spend(1):
            return SamplingResult(decision=Decision.RECORD_AND_SAMPLE, attributes=attributes, trace_state=trace_state)
        return SamplingResult(decision=Decision.DROP, attributes=attributes, trace_state=trace_state)

    def get_description(self) -> str:
        """Return a human-readable description that includes the configured quota."""
        # The quota is an int: convert it explicitly, since ``str + int`` raises TypeError.
        description = (
            "RateLimitingSampler{rate limiting sampling with sampling config of "
            + str(self.__quota)
            + " req/sec and 0% of additional requests}"
        )
        return description
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from logging import getLogger
from threading import Lock
from typing import Optional, Sequence
from typing import Dict, Optional, Sequence

from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
from amazon.opentelemetry.distro.sampler._sampling_rule_applier import _SamplingRuleApplier
from amazon.opentelemetry.distro.sampler._sampling_target import _SamplingTarget, _SamplingTargetResponse
from opentelemetry.context import Context
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace.sampling import SamplingResult
Expand All @@ -18,16 +19,20 @@
_logger = getLogger(__name__)

CACHE_TTL_SECONDS = 3600
DEFAULT_TARGET_POLLING_INTERVAL_SECONDS = 10


class _RuleCache:
def __init__(self, resource: Resource, fallback_sampler: _FallbackSampler, date_time: datetime, lock: Lock):
def __init__(
self, resource: Resource, fallback_sampler: _FallbackSampler, client_id: str, clock: _Clock, lock: Lock
):
self.__client_id = client_id
self.__rule_appliers: [_SamplingRuleApplier] = []
self.__cache_lock = lock
self.__resource = resource
self._fallback_sampler = fallback_sampler
self._date_time = date_time
self._last_modified = self._date_time.datetime.now()
self._clock = clock
self._last_modified = self._clock.now()

def should_sample(
self,
Expand All @@ -39,6 +44,7 @@ def should_sample(
links: Sequence[Link] = None,
trace_state: TraceState = None,
) -> SamplingResult:
rule_applier: _SamplingRuleApplier
for rule_applier in self.__rule_appliers:
if rule_applier.matches(self.__resource, attributes):
return rule_applier.should_sample(
Expand All @@ -51,6 +57,8 @@ def should_sample(
trace_state=trace_state,
)

_logger.debug("No sampling rules were matched")
# Should not ever reach fallback sampler as default rule is able to match
return self._fallback_sampler.should_sample(
parent_context, trace_id, name, kind=kind, attributes=attributes, links=links, trace_state=trace_state
)
Expand All @@ -65,28 +73,70 @@ def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None:
if sampling_rule.Version != 1:
_logger.debug("sampling rule without Version 1 is not supported: RuleName: %s", sampling_rule.RuleName)
continue
temp_rule_appliers.append(_SamplingRuleApplier(sampling_rule))
temp_rule_appliers.append(_SamplingRuleApplier(sampling_rule, self.__client_id, self._clock))

self.__cache_lock.acquire()

# map list of rule appliers by each applier's sampling_rule name
rule_applier_map = {rule.sampling_rule.RuleName: rule for rule in self.__rule_appliers}
rule_applier_map: Dict[str, _SamplingRuleApplier] = {
applier.sampling_rule.RuleName: applier for applier in self.__rule_appliers
}

# If a sampling rule has not changed, keep its respective applier in the cache.
new_applier: _SamplingRuleApplier
for index, new_applier in enumerate(temp_rule_appliers):
rule_name_to_check = new_applier.sampling_rule.RuleName
if rule_name_to_check in rule_applier_map:
old_applier = rule_applier_map[rule_name_to_check]
if new_applier.sampling_rule == old_applier.sampling_rule:
temp_rule_appliers[index] = old_applier
self.__rule_appliers = temp_rule_appliers
self._last_modified = datetime.datetime.now()
self._last_modified = self._clock.now()

self.__cache_lock.release()

def update_sampling_targets(self, sampling_targets_response: _SamplingTargetResponse) -> (bool, int):
    """Apply fetched sampling targets to the cached rule appliers.

    Each cached applier whose rule name has a target in the response is
    replaced by ``applier.with_target(target)``; the others are kept as is.

    :param sampling_targets_response: response holding the target documents
        and the last rule modification timestamp.
    :return: ``(refresh_rules, next_polling_interval)``. ``refresh_rules`` is
        ``True`` when the response's ``LastRuleModification`` is later than
        this cache's last modification time. ``next_polling_interval`` is the
        smallest non-``None`` ``Interval`` among the applied targets, or
        ``DEFAULT_TARGET_POLLING_INTERVAL_SECONDS`` when none reported one.
    """
    targets: [_SamplingTarget] = sampling_targets_response.SamplingTargetDocuments

    with self.__cache_lock:
        targets_by_rule: Dict[str, _SamplingTarget] = {target.RuleName: target for target in targets}

        reported_intervals = []
        refreshed_appliers = []
        applier: _SamplingRuleApplier
        for applier in self.__rule_appliers:
            rule_name = applier.sampling_rule.RuleName
            if rule_name not in targets_by_rule:
                # No target for this rule: keep the applier unchanged.
                refreshed_appliers.append(applier)
                continue

            target = targets_by_rule[rule_name]
            refreshed_appliers.append(applier.with_target(target))
            if target.Interval is not None:
                reported_intervals.append(target.Interval)

        self.__rule_appliers = refreshed_appliers

        # Poll at the shortest interval any applied target requested, else use the default.
        if reported_intervals:
            next_polling_interval = min(reported_intervals)
        else:
            next_polling_interval = DEFAULT_TARGET_POLLING_INTERVAL_SECONDS

        # Rules changed after the cache was last updated => caller should refetch them.
        last_rule_modification = self._clock.from_timestamp(sampling_targets_response.LastRuleModification)
        refresh_rules = last_rule_modification > self._last_modified

        return (refresh_rules, next_polling_interval)

def get_all_statistics(self) -> [dict]:
    """Collect the statistics of every cached rule applier, in cache order.

    Calls ``get_then_reset_statistics()`` on each applier, so the appliers'
    statistics are reset as a side effect of collecting them.

    :return: one statistics entry per rule applier.
    """
    return [applier.get_then_reset_statistics() for applier in self.__rule_appliers]

def expired(self) -> bool:
self.__cache_lock.acquire()
try:
return datetime.datetime.now() > self._last_modified + datetime.timedelta(seconds=CACHE_TTL_SECONDS)
return self._clock.now() > self._last_modified + self._clock.time_delta(seconds=CACHE_TTL_SECONDS)
finally:
self.__cache_lock.release()
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def __init__(
self.URLPath = URLPath if URLPath is not None else ""
self.Version = Version if Version is not None else 0

def __lt__(self, other) -> bool:
def __lt__(self, other: "_SamplingRule") -> bool:
if self.Priority == other.Priority:
# String order priority example:
# "A","Abc","a","ab","abc","abcdef"
Expand Down
Loading

0 comments on commit 453a3d5

Please sign in to comment.