Skip to content

Commit

Permalink
use custom clock
Browse files Browse the repository at this point in the history
  • Loading branch information
jj22ee committed Feb 9, 2024
1 parent f3c26cb commit 77042d9
Show file tree
Hide file tree
Showing 8 changed files with 61 additions and 41 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import datetime


class _Clock:
    """Single access point for the date/time operations used by the sampler.

    The ``datetime.datetime`` class is stored on the instance and every
    lookup goes through it.
    NOTE(review): presumably this lets tests substitute a controllable
    clock -- confirm against the test suite.
    """

    def __init__(self):
        # Hold the datetime *class* (not an instance); now() and
        # from_timestamp() call its classmethods.
        self.__datetime = datetime.datetime

    def now(self):
        """Return the current time as produced by ``datetime.now()``.

        No ``tz`` argument is passed, so the result is a naive datetime
        in local time.
        """
        current = self.__datetime.now()
        return current

    def from_timestamp(self, timestamp: float):
        """Convert a POSIX ``timestamp`` (seconds) to a datetime.

        Uses ``datetime.fromtimestamp`` without a ``tz`` argument, so the
        result is naive local time, the same kind of value ``now()`` returns.
        """
        converted = self.__datetime.fromtimestamp(timestamp)
        return converted

    def time_delta(self, seconds):
        """Return a ``datetime.timedelta`` spanning ``seconds`` seconds."""
        span = datetime.timedelta(seconds=seconds)
        return span
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._reservoir_sampler import _ReservoirSampler
from opentelemetry.sdk.trace.sampling import ALWAYS_ON, Decision, Sampler, SamplingResult, TraceIdRatioBased


class _FallbackSampler(Sampler):
def __init__(self):
self.__reservoir_sampler = _ReservoirSampler(1)
def __init__(self, clock: _Clock):
self.__reservoir_sampler = _ReservoirSampler(1, clock)
self.__fixed_rate_sampler = TraceIdRatioBased(0.05)

# pylint: disable=no-self-use
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import time as _time
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._reservoir_wallet import _ReservoirWallet
from opentelemetry.context import Context
from opentelemetry.sdk.trace.sampling import Decision, Sampler, SamplingResult
Expand All @@ -14,12 +15,10 @@
class _ReservoirSampler:
def __init__(
self,
quota: int
quota: int,
clock: _Clock
):
self.quota = quota
self.quota_balance = 0

self.reservoir = _ReservoirWallet(1, quota, _time)
self.reservoir = _ReservoirWallet(1, quota, clock)

# pylint: disable=no-self-use
def should_sample(
Expand Down
Original file line number Diff line number Diff line change
@@ -1,19 +1,19 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
import time
from threading import Lock

from amazon.opentelemetry.distro.sampler._clock import _Clock


class _ReservoirWallet:
def __init__(self, max_balance_in_seconds: int, quota: int, _time: time):
def __init__(self, max_balance_in_seconds: int, quota: int, clock: _Clock):
# max_balance_in_seconds is usually 1
self.MAX_BALANCE_MILLIS = max_balance_in_seconds * 1000
self.clock = clock

# if income_rate is 5 samples per second, treat multiplier as 5 ms to spend per ms
self._quota = quota
self._time = _time

self.wallet_floor_millis = datetime.datetime.now().timestamp() * 1000
self.wallet_floor_millis = self.clock.now().timestamp() * 1000
# "wallet_ceiling_millis" would be current time.process_time_ns
# current "wallet_balance" would be ceiling - floor

Expand All @@ -27,7 +27,7 @@ def try_spend(self, cost: int, borrow: bool):

self.__lock.acquire()
try:
wallet_ceiling_millis = datetime.datetime.now().timestamp() * 1000
wallet_ceiling_millis = self.clock.now().timestamp() * 1000
current_balance_millis = wallet_ceiling_millis - self.wallet_floor_millis
if current_balance_millis > self.MAX_BALANCE_MILLIS:
current_balance_millis = self.MAX_BALANCE_MILLIS
Expand Down
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
from threading import Lock
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._matcher import _Matcher, cloud_platform_mapping
from amazon.opentelemetry.distro.sampler._reservoir_sampler import _ReservoirSampler
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
Expand All @@ -20,9 +20,9 @@
DEFAULT_TARGET_POLLING_INTERVAL_SECONDS = 10

class _Rule:
def __init__(self, sampling_rule: _SamplingRule, client_id: str, date_time: datetime):
def __init__(self, sampling_rule: _SamplingRule, client_id: str, clock: _Clock):
self.client_id = client_id
self.date_time = date_time
self.clock = clock
self.sampling_rule = sampling_rule

self.statistics = _SamplingStatisticsDocument(self.client_id, self.sampling_rule.RuleName)
Expand All @@ -32,14 +32,13 @@ def __init__(self, sampling_rule: _SamplingRule, client_id: str, date_time: date
# TODO rename to rate limiter
# Initialize as borrowing if there will be a quota > 0
if self.sampling_rule.ReservoirSize > 0:
self.reservoir_sampler = _ReservoirSampler(1)
self.reservoir_sampler = _ReservoirSampler(1, self.clock)
self.borrowing = True
else:
self.reservoir_sampler = _ReservoirSampler(0)
self.reservoir_sampler = _ReservoirSampler(0, self.clock)
self.borrowing = False

# TODO add self.next_target_fetch_time from maybe time.process_time() or cache's datetime object
self.reservoir_expiry = self.date_time.datetime.now()
self.reservoir_expiry = self.clock.now()
self.polling_interval = DEFAULT_TARGET_POLLING_INTERVAL_SECONDS

def should_sample(
Expand All @@ -56,7 +55,7 @@ def should_sample(

# print("self.reservoir_expiry: %s", self.reservoir_expiry)
# print(type(self.reservoir_expiry))
reservoir_expired: bool = self.date_time.datetime.now() > self.reservoir_expiry
reservoir_expired: bool = self.clock.now() > self.reservoir_expiry
sampling_result = SamplingResult(decision=Decision.DROP, attributes=attributes, trace_state=trace_state)
if reservoir_expired:
self.borrowing = True
Expand Down Expand Up @@ -96,14 +95,18 @@ def should_sample(

def update_target(self, target):
print("%s ..... %s", target["RuleName"], target["ReservoirQuota"])
self.reservoir_sampler = _ReservoirSampler(target["ReservoirQuota"] if target["ReservoirQuota"] is not None else 0)
self.fixed_rate_sampler = TraceIdRatioBased(target["FixedRate"] if target["FixedRate"] is not None else 0)

new_quota = target["ReservoirQuota"] if target["ReservoirQuota"] is not None else 0
new_fixed_rate = target["FixedRate"] if target["FixedRate"] is not None else 0
self.reservoir_sampler = _ReservoirSampler(new_quota, self.clock)
self.fixed_rate_sampler = TraceIdRatioBased(new_fixed_rate)

if target["ReservoirQuotaTTL"] is not None:
self.reservoir_expiry = self.date_time.datetime.fromtimestamp(target["ReservoirQuotaTTL"])
self.reservoir_expiry = self.clock.from_timestamp(target["ReservoirQuotaTTL"])
# ^^^ test minus 4.5 minutes
else:
# Still Expired
self.reservoir_expiry = self.date_time.datetime.now()
self.reservoir_expiry = self.clock.now()

self.polling_interval = target["Interval"]
print("%s ..... %s ..... %s", self.reservoir_expiry, 10, 12)
Expand All @@ -117,7 +120,7 @@ def get_then_reset_statistics(self) -> _SamplingStatisticsDocument:
self.statistics = _SamplingStatisticsDocument(self.client_id, self.sampling_rule.RuleName)
self.statistics_lock.release()

return old_stats.snapshot()
return old_stats.snapshot(self.clock)

def matches(self, resource: Resource, attributes: Attributes) -> bool:
http_target = None
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import copy
import datetime
from logging import getLogger
from operator import itemgetter
from threading import Lock
from typing import Optional, Sequence

from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler
from amazon.opentelemetry.distro.sampler._rule import DEFAULT_TARGET_POLLING_INTERVAL_SECONDS, _Rule
from amazon.opentelemetry.distro.sampler._sampling_rule import _SamplingRule
Expand All @@ -25,13 +25,13 @@
class _RuleCache:
rules: [_Rule] = []

def __init__(self, resource: Resource, fallback_sampler: _FallbackSampler, client_id: str, date_time: datetime, lock: Lock):
def __init__(self, resource: Resource, fallback_sampler: _FallbackSampler, client_id: str, clock: _Clock, lock: Lock):
self.client_id = client_id
self.__cache_lock = lock
self.__resource = resource
self._fallback_sampler = fallback_sampler
self._date_time = date_time
self._last_modified = self._date_time.datetime.now()
self._clock = clock
self._last_modified = self._clock.now()

def should_sample(
self,
Expand Down Expand Up @@ -74,7 +74,7 @@ def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None:
if sampling_rule.Version != 1:
_logger.info("sampling rule without Version 1 is not supported: RuleName: %s", sampling_rule.RuleName)
continue
temp_rules.append(_Rule(copy.deepcopy(sampling_rule), self.client_id, self._date_time))
temp_rules.append(_Rule(copy.deepcopy(sampling_rule), self.client_id, self._clock))

self.__cache_lock.acquire()

Expand All @@ -90,7 +90,7 @@ def update_sampling_rules(self, new_sampling_rules: [_SamplingRule]) -> None:
if new_rule.sampling_rule == previous_rule.sampling_rule:
temp_rules[index] = previous_rule
self.rules = temp_rules
self._last_modified = datetime.datetime.now()
self._last_modified = self._clock.now()

self.__cache_lock.release()

Expand All @@ -112,7 +112,7 @@ def update_sampling_targets(self, sampling_targets_response) -> (bool, int):

self.__cache_lock.release()

last_rule_modification = self._date_time.datetime.fromtimestamp(sampling_targets_response["LastRuleModification"])
last_rule_modification = self._clock.from_timestamp(sampling_targets_response["LastRuleModification"])
if last_rule_modification > self._last_modified:
return (True, min_polling_interval)
return (False, min_polling_interval)
Expand All @@ -133,6 +133,6 @@ def get_all_statistics(self):
def expired(self) -> bool:
self.__cache_lock.acquire()
try:
return datetime.datetime.now() > self._last_modified + datetime.timedelta(seconds=CACHE_TTL_SECONDS)
return self._clock.now() > self._last_modified + self._clock.time_delta(seconds=CACHE_TTL_SECONDS)
finally:
self.__cache_lock.release()
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
import datetime
import json

from amazon.opentelemetry.distro.sampler._clock import _Clock


# Disable snake_case naming style so this class can match the statistics document response from X-Ray
# pylint: disable=invalid-name
Expand All @@ -18,11 +20,11 @@ def __init__(self, clientID, ruleName, RequestCount=0, BorrowCount=0, SampleCoun
self.BorrowCount = BorrowCount
self.SampleCount = SampleCount

def snapshot(self) -> dict:
def snapshot(self, clock: _Clock) -> dict:
return {
"ClientID": self.ClientID,
"RuleName": self.RuleName,
"Timestamp": datetime.datetime.now().timestamp(),
"Timestamp": clock.now().timestamp(),
"RequestCount": self.RequestCount,
"BorrowCount": self.BorrowCount,
"SampleCount": self.SampleCount
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import datetime
import random
from logging import getLogger
from threading import Lock, Timer
Expand All @@ -9,6 +8,7 @@
from typing_extensions import override

from amazon.opentelemetry.distro.sampler._aws_xray_sampling_client import _AwsXRaySamplingClient
from amazon.opentelemetry.distro.sampler._clock import _Clock
from amazon.opentelemetry.distro.sampler._fallback_sampler import _FallbackSampler
from amazon.opentelemetry.distro.sampler._rule import DEFAULT_TARGET_POLLING_INTERVAL_SECONDS
from amazon.opentelemetry.distro.sampler._rule_cache import _RuleCache
Expand Down Expand Up @@ -49,9 +49,9 @@ def __init__(

self.__client_id = self.__generate_client_id()
print("client ID: %s", self.__client_id)
self.__date_time = datetime
self.__clock = _Clock()
self.__xray_client = _AwsXRaySamplingClient(endpoint, log_level=log_level)
self.__fallback_sampler = _FallbackSampler()
self.__fallback_sampler = _FallbackSampler(self.__clock)

self.__polling_interval = polling_interval
self.__target_polling_interval = DEFAULT_TARGET_POLLING_INTERVAL_SECONDS
Expand All @@ -66,7 +66,7 @@ def __init__(

self.__rule_cache_lock = Lock()
self.__rule_cache = _RuleCache(
self.__resource, self.__fallback_sampler, self.__client_id, self.__date_time, self.__rule_cache_lock
self.__resource, self.__fallback_sampler, self.__client_id, self.__clock, self.__rule_cache_lock
)

# Schedule the next rule poll now
Expand Down

0 comments on commit 77042d9

Please sign in to comment.