From 1e6b24fda9189ee19a117f173d8bd5e6947c9598 Mon Sep 17 00:00:00 2001
From: Colman Glagovich
Date: Wed, 15 May 2024 12:07:57 +0000
Subject: [PATCH] #8189: Llama prefill/decode supports bfp8 KV cache with
 multidevice tensors

(cherry picked from commit 70f45db154612c69a656867ad87b0158161e3604)
---
 .../llama2_70b/tt/llama_attention_optimized.py | 18 +++++++++++-------
 1 file changed, 11 insertions(+), 7 deletions(-)

diff --git a/models/experimental/llama2_70b/tt/llama_attention_optimized.py b/models/experimental/llama2_70b/tt/llama_attention_optimized.py
index addb9c3e36c..0620c8fbac4 100644
--- a/models/experimental/llama2_70b/tt/llama_attention_optimized.py
+++ b/models/experimental/llama2_70b/tt/llama_attention_optimized.py
@@ -114,7 +114,7 @@ def init_kv_cache(self):
                     mesh_mapper=ShardTensorToMesh(self.device_mesh, dim=0),
                     layout=ttnn.TILE_LAYOUT,
                     memory_config=self.model_config["DRAM_MEMCFG"],
-                    dtype=ttnn.bfloat16,
+                    dtype=ttnn.bfloat8_b,
                     cache_file_name=self.cache_path / f"empty_attn_cache{cache_k.shape}",
                 ),
                 self.device_mesh,
@@ -672,19 +672,23 @@ def prefill_attn_mqa(
         keys = self.layer_past[0]
         # Fill cache expects batch in dim0
         keys_reshaped = ttnn.reshape(keys, [self.max_batch_size, self.n_local_kv_heads, -1, self.head_dim])
-        # tt_lib.tensor.fill_cache(keys, tt_lib.tensor.typecast(key_layer, self.model_config["BFP8_DTYPE"]), user_id)
         tt_lib.tensor.fill_cache(
-            keys_reshaped, key_layer, user_id
-        )  # TODO: Set back to bfp8_b when typecast supports MD tensors
+            keys_reshaped, tt_lib.tensor.typecast(key_layer, self.model_config["BFP8_DTYPE"]), user_id
+        )
+        # tt_lib.tensor.fill_cache(
+        #     keys_reshaped, key_layer, user_id
+        # )
 
         # FILL V CACHE
         values = self.layer_past[1]
         # Fill cache expects batch in dim0
         values_reshaped = ttnn.reshape(values, [self.max_batch_size, self.n_local_kv_heads, -1, self.head_dim])
-        # tt_lib.tensor.fill_cache(values, tt_lib.tensor.typecast(value_layer, self.model_config["BFP8_DTYPE"]), user_id)
         tt_lib.tensor.fill_cache(
-            values_reshaped, value_layer, user_id
-        )  # TODO: Set back to bfp8_b when typecast supports MD tensors
+            values_reshaped, tt_lib.tensor.typecast(value_layer, self.model_config["BFP8_DTYPE"]), user_id
+        )
+        # tt_lib.tensor.fill_cache(
+        #     values_reshaped, value_layer, user_id
+        # )
 
         # PRE-SOFTMAX MM
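
Note: the patch above allocates the KV cache in bfp8_b and re-enables the typecast of K/V activations before `fill_cache`, now that typecast supports multi-device tensors. For readers unfamiliar with that fill pattern, the sketch below models it in plain PyTorch. It is illustrative only: `fill_kv_cache` and every name in it are hypothetical, not tt_lib/ttnn APIs, and `torch.bfloat16` stands in for the block-float `bfp8_b` dtype, which PyTorch does not provide.

```python
# A minimal, hypothetical sketch of the fill pattern this patch restores:
# the preallocated KV cache lives in a reduced-precision dtype, so each
# user's K/V slice is cast to that dtype before being written in.
import torch

def fill_kv_cache(cache: torch.Tensor, layer: torch.Tensor, user_id: int) -> None:
    """Write one user's K or V activations into the preallocated cache.

    cache: [max_batch_size, n_local_kv_heads, max_seq_len, head_dim],
           stored in a reduced-precision dtype (bfloat16 stands in here
           for bfp8_b).
    layer: [1, n_local_kv_heads, seq_len, head_dim] in the compute dtype.
    """
    seq_len = layer.shape[2]
    # Cast to the cache dtype first -- the analogue of
    # tt_lib.tensor.typecast(key_layer, BFP8_DTYPE) in the patch.
    cache[user_id, :, :seq_len, :] = layer[0].to(cache.dtype)

# Usage: an 8-user cache with 1 local KV head, 128 max positions, head_dim 64.
kv_cache = torch.zeros(8, 1, 128, 64, dtype=torch.bfloat16)
key_layer = torch.randn(1, 1, 32, 64, dtype=torch.float32)
fill_kv_cache(kv_cache, key_layer, user_id=3)
```

The design point the patch makes is that the cast happens on the activation (small, per-user) rather than on the whole cache, so the cache can stay permanently in the compact dtype and DRAM footprint is roughly halved versus bfloat16.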