#8189: Llama prefill/decode supports bfp8 KV cache with multidevice tensors

(cherry picked from commit 70f45db)
cglagovichTT committed May 15, 2024
1 parent 19185f4 commit 1e6b24f
Showing 1 changed file with 11 additions and 7 deletions.
models/experimental/llama2_70b/tt/llama_attention_optimized.py (11 additions, 7 deletions)
@@ -114,7 +114,7 @@ def init_kv_cache(self):
                     mesh_mapper=ShardTensorToMesh(self.device_mesh, dim=0),
                     layout=ttnn.TILE_LAYOUT,
                     memory_config=self.model_config["DRAM_MEMCFG"],
-                    dtype=ttnn.bfloat16,
+                    dtype=ttnn.bfloat8_b,
                     cache_file_name=self.cache_path / f"empty_attn_cache{cache_k.shape}",
                 ),
                 self.device_mesh,
@@ -672,19 +672,23 @@ def prefill_attn_mqa(
         keys = self.layer_past[0]
         # Fill cache expects batch in dim0
         keys_reshaped = ttnn.reshape(keys, [self.max_batch_size, self.n_local_kv_heads, -1, self.head_dim])
-        # tt_lib.tensor.fill_cache(keys, tt_lib.tensor.typecast(key_layer, self.model_config["BFP8_DTYPE"]), user_id)
         tt_lib.tensor.fill_cache(
-            keys_reshaped, key_layer, user_id
-        )  # TODO: Set back to bfp8_b when typecast supports MD tensors
+            keys_reshaped, tt_lib.tensor.typecast(key_layer, self.model_config["BFP8_DTYPE"]), user_id
+        )
+        # tt_lib.tensor.fill_cache(
+        #     keys_reshaped, key_layer, user_id
+        # )
 
         # FILL V CACHE
         values = self.layer_past[1]
         # Fill cache expects batch in dim0
         values_reshaped = ttnn.reshape(values, [self.max_batch_size, self.n_local_kv_heads, -1, self.head_dim])
-        # tt_lib.tensor.fill_cache(values, tt_lib.tensor.typecast(value_layer, self.model_config["BFP8_DTYPE"]), user_id)
         tt_lib.tensor.fill_cache(
-            values_reshaped, value_layer, user_id
-        )  # TODO: Set back to bfp8_b when typecast supports MD tensors
+            values_reshaped, tt_lib.tensor.typecast(value_layer, self.model_config["BFP8_DTYPE"]), user_id
+        )
+        # tt_lib.tensor.fill_cache(
+        #     values_reshaped, value_layer, user_id
+        # )
 
         # PRE-SOFTMAX MM
 
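For reference, the pattern this commit restores: typecast the incoming bf16 key/value activations to block-float8 (bfp8) on device, then write them into the preallocated bfp8 KV cache with fill_cache. As of this change, typecast also handles multidevice (mesh-sharded) tensors, which is why the bf16 fallback path above could be retired. The sketch below is illustrative only; the helper name fill_kv_cache_bfp8 and its standalone form are hypothetical, and the real code lives inside prefill_attn_mqa.

import tt_lib
import ttnn

def fill_kv_cache_bfp8(attn, key_layer, value_layer, user_id):
    # Sketch only: `attn` stands in for the attention module above and is assumed
    # to carry layer_past, max_batch_size, n_local_kv_heads, head_dim, model_config.
    for cache, layer in zip(attn.layer_past, (key_layer, value_layer)):
        # fill_cache expects batch in dim0, so view the cache as [batch, kv_heads, seq, head_dim]
        reshaped = ttnn.reshape(
            cache, [attn.max_batch_size, attn.n_local_kv_heads, -1, attn.head_dim]
        )
        # Typecast the bf16 activations to bfp8 before writing; with this commit,
        # typecast works on multidevice tensors, so the bfp8 path is used again.
        tt_lib.tensor.fill_cache(
            reshaped,
            tt_lib.tensor.typecast(layer, attn.model_config["BFP8_DTYPE"]),
            user_id,
        )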
