Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix async pace #1654

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
210 changes: 210 additions & 0 deletions examples/experimental/pace_example.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# This notebook demonstrates how to use the Pace utility\n",
"\n",
"This utility can be used to limit the rate of API requests to external endpoints. "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"import random\n",
"import threading\n",
"import time\n",
"\n",
"from IPython import display\n",
"from trulens.core.utils.pace import Pace"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a pace instance with 2 seconds per period and 20 marks per second. The\n",
"# average of 20 marks per second will be maintained across any 2 second period\n",
"# but this makes it possible for an initial burst of up to 40 marks immediately. This\n",
"# is due to the assumption that there were no marks before the process started.\n",
"\n",
"# If seconds_per_period is increased, a larger burst of marks will be possible\n",
"# before the average marks per second since the start of the process stabilizes.\n",
"# A larger burst also means there will be a delay until the next period before\n",
"# marks can return again. A \"burstiness\" warning is issued the first time a delay\n",
"# longer than half of the seconds_per_period is encountered.\n",
"\n",
"p = Pace(seconds_per_period=2, marks_per_second=20)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# start time and counter\n",
"st = time.time()\n",
"count = 0\n",
"\n",
"while True:\n",
" # Mark and increment counter. Calls to mark will block to maintain pace.\n",
" p.mark()\n",
" count += 1\n",
"\n",
" et = time.time()\n",
" display.clear_output(wait=True)\n",
"\n",
" # Show stats of the marks rate since the start of this cell.\n",
" print(f\"\"\"\n",
"Elapsed time: {et - st}\n",
"Marks count: {count}\n",
"Marks per second: {count / (et - st)}\n",
"\"\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pace across Threads\n",
"\n",
"The pacing should be maintained even if a single Pace instance is used across\n",
"multiple threads."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"num_threads = 10\n",
"count = 0\n",
"\n",
"\n",
"# Create a function to run in each thread and update the count for each mark:\n",
"def marker():\n",
" global count\n",
"\n",
" while True:\n",
" # Mark and increment counter. Calls to mark will block to maintain pace.\n",
" p.mark()\n",
" count += 1\n",
"\n",
" # Add a bit of sleep to simulate some work.\n",
" time.sleep(random.random() / 100.0)\n",
"\n",
"\n",
"# Start time.\n",
"st = time.time()\n",
"\n",
"# Start the threads.\n",
"for i in range(num_threads):\n",
" t = threading.Thread(target=marker)\n",
" t.start()\n",
"\n",
"while True:\n",
" # Report count stats every second.\n",
" time.sleep(1)\n",
"\n",
" display.clear_output(wait=True)\n",
"\n",
" et = time.time()\n",
"\n",
" # Show stats of the marks rate since the start of this cell.\n",
" print(f\"\"\"\n",
"Elapsed time: {et - st}\n",
"Marks count: {count}\n",
"Marks per second: {count / (et - st)}\n",
"\"\"\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Pace in Async Tasks\n",
"\n",
"Pace can also be maintained when using asynchronous tasks. For this, the `amark`\n",
"method must be used and awaited."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"num_tasks = 10\n",
"count = 0\n",
"\n",
"\n",
"# Create a function to run in each task and update the count for each mark:\n",
"async def async_marker():\n",
" global count\n",
"\n",
" while True:\n",
" # Mark and increment counter. Calls to amark will block to maintain pace.\n",
" await p.amark()\n",
" count += 1\n",
"\n",
" # Add a bit of sleep to simulate some work.\n",
" await asyncio.sleep(random.random() / 100.0)\n",
"\n",
"\n",
"# Start time.\n",
"st = time.time()\n",
"\n",
"loop = asyncio.get_event_loop()\n",
"\n",
"# Start the tasks.\n",
"for i in range(num_tasks):\n",
" task = loop.create_task(async_marker())\n",
"\n",
"while True:\n",
" # Report count stats every second.\n",
"\n",
" await asyncio.sleep(1)\n",
"\n",
" display.clear_output(wait=True)\n",
"\n",
" et = time.time()\n",
"\n",
" # Show stats of the marks rate since the start of this cell.\n",
" print(f\"\"\"\n",
"Elapsed time: {et - st}\n",
"Marks count: {count}\n",
"Marks per second: {count / (et - st)}\n",
"\"\"\")"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "trulens-9bG3yHQd-py3.11",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
64 changes: 44 additions & 20 deletions src/core/trulens/core/utils/pace.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from _thread import LockType
import asyncio
from collections import deque
from datetime import datetime
Expand All @@ -10,6 +9,7 @@

from pydantic import BaseModel
from pydantic import Field
from pydantic import PrivateAttr

logger = logging.getLogger(__name__)

Expand All @@ -21,13 +21,18 @@ class Pace(BaseModel):
constraint: the number of returns in the given period of time cannot exceed
`marks_per_second * seconds_per_period`. This means the average number of
returns in that period is bounded above exactly by `marks_per_second`.

!!! Warning:
The asynchronous and synchronous methods `amark` and `mark` should not be
used at the same time. That is, use either the synchronous interface or the
asynchronous one, but not both.
"""

marks_per_second: float = 1.0
"""The pace in number of mark returns per second."""

seconds_per_period: float = 60.0
"""Evaluate pace as overage over this period.
"""Evaluate pace as the average over this period.

Assumes that prior to construction of this Pace instance, the period did not
have any marks called. The longer this period is, the bigger burst of marks
Expand All @@ -42,7 +47,7 @@ class Pace(BaseModel):
mark_expirations: Deque[datetime] = Field(default_factory=deque)
"""Keep track of returns that happened in the last `period` seconds.

Store the datetime at which they expire (they become longer than `period`
Store the datetime at which they expire (they become older than `period`
seconds old).
"""

Expand All @@ -57,11 +62,20 @@ class Pace(BaseModel):
last_mark: datetime = Field(default_factory=datetime.now)
"""Time of the last mark return."""

lock: LockType = Field(default_factory=Lock)
_lock: Lock = PrivateAttr(default_factory=Lock)
"""Thread Lock to ensure mark method details run only one at a time."""

_alock: asyncio.Lock = PrivateAttr(default_factory=asyncio.Lock)
"""Asyncio Lock to ensure amark method details run only one at a time."""

model_config: ClassVar[dict] = dict(arbitrary_types_allowed=True)

_warned: bool = False
"""Whether the long delay warning has already been issued.

This is to not repeatedly give it.
"""

def __init__(
self,
seconds_per_period: float,
Expand Down Expand Up @@ -103,23 +117,28 @@ async def amark(self) -> float:
seconds since last mark returned.
"""

async with self.lock:
async with self._alock:
while len(self.mark_expirations) >= self.max_marks:
delay = (
self.mark_expirations[0] - datetime.now()
).total_seconds()

if delay >= self.seconds_per_period * 0.5:
logger.warning(
f"""
Pace has a long delay of {delay} seconds. There might have been a burst of
if not self._warned:
self._warned = True
logger.warning(
"""
Pace has a long delay of %s seconds. There might have been a burst of
requests which may become a problem for the receiver of whatever is being paced.
Consider reducing the `seconds_per_period` (currently {self.seconds_per_period} [seconds]) over which to
Consider reducing the `seconds_per_period` (currently %s [seconds]) over which to
maintain pace to reduce burstiness. " Alternatively reduce `marks_per_second`
(currently {self.marks_per_second} [1/second]) to reduce the number of marks
(currently %s [1/second]) to reduce the number of marks
per second in that period.
"""
)
""",
delay,
self.seconds_per_period,
self.marks_per_second,
)

if delay > 0.0:
await asyncio.sleep(delay)
Expand All @@ -145,23 +164,28 @@ def mark(self) -> float:
seconds since last mark returned.
"""

with self.lock:
with self._lock:
while len(self.mark_expirations) >= self.max_marks:
delay = (
self.mark_expirations[0] - datetime.now()
).total_seconds()

if delay >= self.seconds_per_period * 0.5:
logger.warning(
f"""
Pace has a long delay of {delay} seconds. There might have been a burst of
if not self._warned:
self._warned = True
logger.warning(
"""
Pace has a long delay of %s seconds. There might have been a burst of
requests which may become a problem for the receiver of whatever is being paced.
Consider reducing the `seconds_per_period` (currently {self.seconds_per_period} [seconds]) over which to
Consider reducing the `seconds_per_period` (currently %s [seconds]) over which to
maintain pace to reduce burstiness. " Alternatively reduce `marks_per_second`
(currently {self.marks_per_second} [1/second]) to reduce the number of marks
(currently %s [1/second]) to reduce the number of marks
per second in that period.
"""
)
""",
delay,
self.seconds_per_period,
self.marks_per_second,
)

if delay > 0.0:
time.sleep(delay)
Expand Down
1 change: 0 additions & 1 deletion tests/unit/static/golden/api.trulens.3.11.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -1238,7 +1238,6 @@ trulens.core.utils.pace.Pace:
attributes:
amark: builtins.function
last_mark: datetime.datetime
lock: _thread.lock
mark: builtins.function
mark_expirations: typing.Deque[datetime.datetime]
marks_per_second: builtins.float
Expand Down
1 change: 0 additions & 1 deletion tests/unit/static/golden/api.trulens_eval.3.11.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7954,7 +7954,6 @@ trulens_eval.utils.pace.Pace:
__class__: pydantic._internal._model_construction.ModelMetaclass
attributes:
last_mark: datetime.datetime
lock: _thread.lock
mark: builtins.function
mark_expirations: typing.Deque[datetime.datetime]
marks_per_second: builtins.float
Expand Down