Skip to content

Commit

Permalink
[shortfin] Zero out kv cache pages during allocation (#738)
Browse files Browse the repository at this point in the history
The first decode step sometimes produced bad decode values. This is likely
related to bad values in the KV cache. Zeroing the pages at allocation should
avoid NaN / inf corruption from uninitialized memory.

---------

Co-authored-by: Stephen Baione <[email protected]>
  • Loading branch information
rsuderman and stbaione authored Jan 16, 2025
1 parent 0da9f25 commit 13781bb
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig
page_table = sf.array.device_array.for_device(
device, page_table_shape, self.config.dtype
)
page_table_host = page_table.for_transfer()
with page_table_host.map(discard=True) as m:
m.fill(0)
page_table_host.copy_to(page_table)
self.page_tables.append(page_table)

def acquire_free_pages(self, count: int) -> list[PageInfo] | None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,21 @@ def __repr__(self):
def mock_device_array():
"""Create mock device array with proper interface implementation"""

class MockMapping:
def __enter__(self):
return self

def __exit__(
self,
exc_type: object | None,
exc_value: object | None,
exc_tb: object | None,
):
pass

def fill(self, value: int):
pass

class MockDeviceArray:
def __init__(self):
self.shape = None
Expand All @@ -67,6 +82,17 @@ def view(self, *args):
def copy_from(self, src):
    """No-op: accept ``src`` and return ``None`` without touching it."""
    return None

def copy_to(self, dst):
    """No-op: accept ``dst`` and return ``None`` without touching it."""
    return None

def for_transfer(self):
return MockDeviceArray()

def map(self, *, read: bool = False, write: bool = False, discard: bool = False):
    """Return a new ``MockMapping``.

    The keyword-only flags ``read``, ``write`` and ``discard`` are accepted
    so callers can pass them, but the mock ignores their values.
    """
    mapping = MockMapping()
    return mapping

return MockDeviceArray()


Expand Down

0 comments on commit 13781bb

Please sign in to comment.