add locks to trie_attention_cache.py in attempt to address trie-related token corruption
renxida committed Jan 17, 2025
1 parent 1f50538 commit 1ebbe4e
Showing 1 changed file with 78 additions and 56 deletions.
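
For context on the failure mode this commit targets, here is a minimal sketch (illustrative, not from the repository) of the lost-update race a plain-int reference count is exposed to: `+=` on a shared counter is a read-modify-write sequence that threads can interleave.

import threading

# Illustrative only: a bare int counter can lose updates under contention.
# This is the class of race the locks in this commit are meant to rule out.
count = 0

def bump(n: int) -> None:
    global count
    for _ in range(n):
        count += 1  # load, add, store: threads may interleave between steps

threads = [threading.Thread(target=bump, args=(100_000,)) for _ in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()

print(count)  # often less than 400000 on CPython without a lock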
@@ -3,6 +3,7 @@
 import time
 import math
 import heapq
+import threading
 from .page_pool import PagePool, PageInfo
 from .base_attention_cache import (
     BasePagedAttentionCache,
@@ -14,21 +15,28 @@
 @dataclass
 class RefCount:
     """
-    A reference counter to replace simple int.
+    A thread-safe reference counter.
     """
 
     count: int = 0
+    _lock: threading.Lock = None
+
+    def __post_init__(self):
+        self._lock = threading.Lock()
 
     def increment(self) -> int:
-        self.count += 1
-        return self.count
+        with self._lock:
+            self.count += 1
+            return self.count
 
     def decrement(self) -> int:
-        self.count -= 1
-        return self.count
+        with self._lock:
+            self.count -= 1
+            return self.count
 
     def is_empty(self) -> bool:
-        return self.count <= 0
+        with self._lock:
+            return self.count <= 0
 
 
 @dataclass
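
An aside on the pattern above (not part of the commit): dataclass fields cannot take a mutable default directly, hence the `_lock = None` placeholder plus `__post_init__`. An equivalent sketch using `default_factory`, under a hypothetical name:

import threading
from dataclasses import dataclass, field


@dataclass
class RefCountAlt:
    """Hypothetical variant of RefCount: same behavior, but the lock is
    created via default_factory instead of in __post_init__."""

    count: int = 0
    _lock: threading.Lock = field(default_factory=threading.Lock, repr=False)

    def increment(self) -> int:
        with self._lock:
            self.count += 1
            return self.count

    def decrement(self) -> int:
        with self._lock:
            self.count -= 1
            return self.count

    def is_empty(self) -> bool:
        with self._lock:
            return self.count <= 0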
@@ -46,6 +54,7 @@ class TrieNode:
         parent: Parent node in the trie (None for root)
         ref_count: Number of active references to this node
         access_time: Last access timestamp for LRU eviction
+        _lock: Lock for synchronizing node modifications
     """
 
     tokens: Tuple[int, ...]
@@ -54,13 +63,15 @@ class TrieNode:
     parent: Optional["TrieNode"] = None
     ref_count: RefCount = None
     access_time: float = 0.0
+    _lock: threading.Lock = None
 
     def __post_init__(self) -> None:
         """Initialize children dict and access time if not provided."""
         if self.children is None:
             self.children = {}
         self.access_time = time.monotonic()
         self.ref_count = RefCount()
+        self._lock = threading.Lock()
 
     def create_child(self, tokens: Tuple[int, ...], page: PageInfo) -> "TrieNode":
         """Create a new child node with the given tokens and page.
@@ -72,15 +83,21 @@ def create_child(self, tokens: Tuple[int, ...], page: PageInfo) -> "TrieNode":
         Returns:
             The newly created child node
         """
-        new_node = TrieNode(tokens=tokens, page=page, parent=self)
-        self.children[tokens] = new_node
-        return new_node
+        with self._lock:
+            new_node = TrieNode(tokens=tokens, page=page, parent=self)
+            self.children[tokens] = new_node
+            return new_node
 
     def unlink(self) -> None:
-        """Remove this node from its parent's children."""
+        """Remove this node from its parent's children.
+
+        Thread-safe unlinking of nodes.
+        """
         if self.parent is not None:
-            del self.parent.children[self.tokens]
-            self.parent = None
+            with self.parent._lock:
+                if self.tokens in self.parent.children:
+                    del self.parent.children[self.tokens]
+                self.parent = None
 
     def __hash__(self) -> int:
         """Nodes are uniquely identified by their memory address."""
@@ -139,8 +156,6 @@ def publish_pages_for_tokens(
         Raises:
             ValueError: If tokens don't match allocation or exceed available pages
         """
-        # If we have more tokens, publish pages up to the incoming tokens.
-        # If incoming has more tokens, replace our tokens with incoming tokens and publish pages up to the incoming tokens.
 
         def has_common_prefix(tokens1, tokens2):
             for t1, t2 in zip(tokens1, tokens2):
@@ -181,19 +196,22 @@ def has_common_prefix(tokens1, tokens2):
             raise NotImplementedError(
                 "Additional work needed here to support publishing incomplete pages to ensure that we finish up a page before attaching child nodes to it."
             )
-        cur_node = self.last_cached_node
-        for token_block, page in zip(unpublished_tokens, unpublished_pages):
-            new_node = cur_node.create_child(token_block, page)
-            cur_node = new_node
-
-        if cur_node is not self.cache.root:
-            self.cache.leaves.add(cur_node)
+        # Acquire cache lock for trie modifications
+        with self.cache._trie_lock:  # prevent matching and eviction until we are done
+            cur_node = self.last_cached_node
+            for token_block, page in zip(unpublished_tokens, unpublished_pages):
+                new_node = cur_node.create_child(token_block, page)
+                cur_node = new_node
 
-        # Update reference counts
-        if unpublished_tokens:
-            cur_node.ref_count.increment()
-            self.last_cached_node.ref_count.decrement()
-            self.last_cached_node = cur_node
+            if cur_node is not self.cache.root:
+                self.cache.leaves.add(cur_node)
+
+            # Update reference counts
+            if unpublished_tokens:
+                cur_node.ref_count.increment()
+                self.last_cached_node.ref_count.decrement()
+                self.last_cached_node = cur_node
 
         self.number_of_published_pages = number_of_pages_to_publish
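
Worth noting about the block above (an observation, not text from the commit): while `cache._trie_lock` is held, `create_child` also takes the parent node's `_lock`, so locks nest in one order, trie lock first, node lock second. A minimal self-contained sketch of that invariant, with stand-in names:

import threading

trie_lock = threading.Lock()  # stands in for TriePagedAttentionCache._trie_lock
node_lock = threading.Lock()  # stands in for TrieNode._lock

def publish_like_path() -> None:
    # Mirrors publish_pages_for_tokens -> create_child: the coarse trie
    # lock is taken before any per-node lock. Keeping one global order
    # for nested locks is what rules out lock-order deadlocks.
    with trie_lock:
        with node_lock:
            pass  # structural mutation happens here

publish_like_path()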

@@ -276,6 +294,7 @@ class TriePagedAttentionCache(BasePagedAttentionCache):
         leaves: Set of leaf nodes for efficient eviction
         page_pool: Pool providing page allocations
         tokens_per_page: Number of tokens that fit in each page
+        _trie_lock: Lock for synchronizing trie structure modifications
     """
 
     def __init__(self, page_pool: PagePool, tokens_per_page: int):
@@ -300,6 +319,7 @@ def __init__(self, page_pool: PagePool, tokens_per_page: int):
         )
         self.root = TrieNode(tokens=tuple(), page=dummy_page)
         self.leaves: Set[TrieNode] = set()
+        self._trie_lock = threading.Lock()
 
     def _match(self, tokens: List[int]) -> Tuple[TrieNode, List[PageInfo]]:
         """
@@ -318,14 +338,15 @@ def _match(self, tokens: List[int]) -> Tuple[TrieNode, List[PageInfo]]:
         matched_pages = []
         cur = self.root
 
-        for i in range(0, len(tokens), self.tokens_per_page):
-            token_block = tokens[i : i + self.tokens_per_page]
+        with self._trie_lock:
+            for i in range(0, len(tokens), self.tokens_per_page):
+                token_block = tokens[i : i + self.tokens_per_page]
 
-            if token_block not in cur.children:
-                break
-            cur = cur.children[token_block]
-            cur.access_time = time.monotonic()
-            matched_pages.append(cur.page)
+                if token_block not in cur.children:
+                    break
+                cur = cur.children[token_block]
+                cur.access_time = time.monotonic()
+                matched_pages.append(cur.page)
 
         return cur, matched_pages

@@ -400,29 +421,30 @@ def _evict_pages(self, max_pages: int) -> int:
         """
         pages_to_evict = []
 
-        # Initialize heap with unreferenced leaves
-        unused_leaf_heap = [
-            (leaf.access_time, leaf)
-            for leaf in self.leaves
-            if leaf.ref_count.is_empty()
-        ]
-        heapq.heapify(unused_leaf_heap)
-
-        # Evict least recently used nodes
-        while unused_leaf_heap and len(pages_to_evict) < max_pages:
-            _, leaf = heapq.heappop(unused_leaf_heap)
-            pages_to_evict.append(leaf.page)
-            parent = leaf.parent
-            leaf.unlink()
-            self.leaves.remove(leaf)
-
-            # If parent becomes childless, it becomes a leaf
-            if parent is not self.root and not parent.children:
-                self.leaves.add(parent)
-                if parent.ref_count.is_empty():
-                    heapq.heappush(unused_leaf_heap, (parent.access_time, parent))
-
-        if pages_to_evict:
-            self.page_pool.free_pages(pages_to_evict)
+        with self._trie_lock:
+            # Initialize heap with unreferenced leaves
+            unused_leaf_heap = [
+                (leaf.access_time, leaf)
+                for leaf in self.leaves
+                if leaf.ref_count.is_empty()
+            ]
+            heapq.heapify(unused_leaf_heap)
+
+            # Evict least recently used nodes
+            while unused_leaf_heap and len(pages_to_evict) < max_pages:
+                _, leaf = heapq.heappop(unused_leaf_heap)
+                pages_to_evict.append(leaf.page)
+                parent = leaf.parent
+                leaf.unlink()
+                self.leaves.remove(leaf)
+
+                # If parent becomes childless, it becomes a leaf
+                if parent is not self.root and not parent.children:
+                    self.leaves.add(parent)
+                    if parent.ref_count.is_empty():
+                        heapq.heappush(unused_leaf_heap, (parent.access_time, parent))
+
+            if pages_to_evict:
+                self.page_pool.free_pages(pages_to_evict)
 
         return len(pages_to_evict)
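
Taken together, the commit funnels `_match`, `publish_pages_for_tokens`, and `_evict_pages` through the single `_trie_lock`. A self-contained toy model of that coarse-grained scheme (illustrative names only, not the real API):

import threading
from typing import Dict, Optional, Tuple

class ToyTrieCache:
    """Toy model of the locking scheme above: every structural
    operation, reads included, holds one trie-wide lock."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._pages: Dict[Tuple[int, ...], str] = {}

    def match(self, key: Tuple[int, ...]) -> Optional[str]:
        with self._lock:  # lookups serialize too, as in _match
            return self._pages.get(key)

    def publish(self, key: Tuple[int, ...], page: str) -> None:
        with self._lock:  # as in publish_pages_for_tokens
            self._pages[key] = page

    def evict(self, key: Tuple[int, ...]) -> None:
        with self._lock:  # as in _evict_pages
            self._pages.pop(key, None)

cache = ToyTrieCache()
cache.publish((1, 2, 3), "page-0")
assert cache.match((1, 2, 3)) == "page-0"
cache.evict((1, 2, 3))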
