diff --git a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
index 407555216..0acb7dc95 100644
--- a/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
+++ b/shortfin/python/shortfin_apps/llm/components/kvcache/page_pool.py
@@ -104,6 +104,10 @@ def __init__(self, *, devices: Sequence[sf.ScopedDevice], config: PagePoolConfig
             page_table = sf.array.device_array.for_device(
                 device, page_table_shape, self.config.dtype
             )
+            page_table_host = page_table.for_transfer()
+            with page_table_host.map(discard=True) as m:
+                m.fill(0)
+            page_table_host.copy_to(page_table)
             self.page_tables.append(page_table)
 
     def acquire_free_pages(self, count: int) -> list[PageInfo] | None:
diff --git a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py
index a4e1f2284..bd97bde02 100644
--- a/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py
+++ b/shortfin/tests/apps/llm/components/kvcache/trie_attention_cache_test.py
@@ -56,6 +56,21 @@ def __repr__(self):
 def mock_device_array():
     """Create mock device array with proper interface implementation"""
 
+    class MockMapping:
+        def __enter__(self):
+            return self
+
+        def __exit__(
+            self,
+            exc_type: object | None,
+            exc_value: object | None,
+            exc_tb: object | None,
+        ):
+            pass
+
+        def fill(self, value: int):
+            pass
+
     class MockDeviceArray:
         def __init__(self):
             self.shape = None
@@ -67,6 +82,17 @@ def view(self, *args):
         def copy_from(self, src):
             pass
 
+        def copy_to(self, dst):
+            pass
+
+        def for_transfer(self):
+            return MockDeviceArray()
+
+        def map(
+            self, *, read: bool = False, write: bool = False, discard: bool = False
+        ):
+            return MockMapping()
+
     return MockDeviceArray()
 
 