#15008: Move xattn cache generation to text prefill forward
(cherry picked from commit d0f78cb)
cglagovichTT committed Nov 14, 2024
1 parent 758f8c9 commit 5cde377
Showing 8 changed files with 120 additions and 170 deletions.
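The gist of the change, before the per-file diffs: the tests used to fill the cross-attention K/V cache with an explicit per-user compute_xattn_kv_cache pass, and now they hand each user's vision tokens to the text prefill forward, which generates the cache as a side effect. A toy torch sketch of that pattern (the module, names, and shapes are illustrative assumptions, not the ttnn implementation):

    import torch

    class ToyCrossAttention(torch.nn.Module):
        # Illustrates the pattern this commit adopts: no separate
        # compute_xattn_kv_cache() pass; prefill fills the cache inside
        # forward() whenever vision_tokens are provided.
        def __init__(self, dim=64, n_heads=4):
            super().__init__()
            self.n_heads, self.head_dim = n_heads, dim // n_heads
            self.wk = torch.nn.Linear(dim, dim, bias=False)
            self.wv = torch.nn.Linear(dim, dim, bias=False)

        def forward(self, x, xattn_cache, mode, user_id=0, vision_tokens=None):
            if mode == "prefill" and vision_tokens is not None:
                seq = vision_tokens.shape[1]
                # Write this user's K/V rows into the preallocated cache.
                k = self.wk(vision_tokens).view(1, seq, self.n_heads, self.head_dim).transpose(1, 2)
                v = self.wv(vision_tokens).view(1, seq, self.n_heads, self.head_dim).transpose(1, 2)
                xattn_cache[0][user_id : user_id + 1] = k
                xattn_cache[1][user_id : user_id + 1] = v
            # ...cross-attention over xattn_cache would follow here...
            return x

The diffs below make exactly this move in the three multimodal tests: the per-user cache loop disappears from test setup, vision_tokens becomes a forward argument, and the cache PCC check moves to after prefill.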
55 changes: 26 additions & 29 deletions models/demos/llama3/tests/multimodal/test_llama_cross_attention.py
@@ -93,7 +93,6 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0)
pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache]

# Iterate over batch
# Preallocate K and V caches
tt_xattn_cache = [
ttnn.from_torch(
@@ -106,34 +105,6 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
)
for _ in range(2)
]
for b in range(batch):
tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
tt_xattn_tokens[b : b + 1],
force_replicated=True,
)
tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_tensor_xattn_tokens, tt_xattn_cache, user_id=b)
tt_xattn_cache_torch = [
ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
batch,
n_heads,
vision_seq_len,
head_dim,
)
for x in tt_xattn_cache
]

for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
passing, pcc_message = comp_pcc(pt, tt, pcc_required)

logger.info(comp_allclose(pt, tt))
logger.info(f"PCC: {pcc_message}")
if passing:
logger.info(f"compute_xattn_kv_cache Passed!")
else:
logger.warning(f"compute_xattn_kv_cache Failed!")
all_tests_pass = False

assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"

"""
Test forward, prefill and decode!
@@ -179,6 +150,10 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
if mode == "prefill":
outputs = []
for b in range(batch):
tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
tt_xattn_tokens[b : b + 1],
force_replicated=True,
)
tt_tensor_x = model_args.prepare_inputs_ttnn_prefill(
tt_x[b : b + 1],
force_replicated=True,
@@ -206,6 +181,7 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
xattn_cache=tt_xattn_cache,
mode=mode,
user_id=b,
vision_tokens=tt_tensor_xattn_tokens,
)

tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
@@ -271,4 +247,25 @@ def test_llama_cross_attention_inference(text_seq_len, batch, mesh_device, reset
logger.info(f"PCC: {pcc_message}")
all_tests_pass = all_tests_pass and passing

if mode == "prefill":
tt_xattn_cache_torch = [
ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
batch,
n_heads,
vision_seq_len,
head_dim,
)
for x in tt_xattn_cache
]
for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
passing, pcc_message = comp_pcc(pt, tt, pcc_required)

logger.info(comp_allclose(pt, tt))
logger.info(f"PCC: {pcc_message}")
if passing:
logger.info(f"compute_xattn_kv_cache Passed!")
else:
logger.warning(f"compute_xattn_kv_cache Failed!")
all_tests_pass = False

assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
models/demos/llama3/tests/multimodal/test_llama_cross_attention_transformer_text.py
@@ -109,46 +109,11 @@ def test_llama_cross_attention_transformer_text_inference(
# unstack k/v
pt_xattn_cache_chunks = [torch.chunk(x, 2, dim=1) for x in pt_xattn_cache_chunks]
pt_xattn_cache_chunks = [x for xx in pt_xattn_cache_chunks for x in xx]
# slice out replicated k/v heads
pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache_chunks]

# Iterate over batch
# Preallocate K and V caches
tt_xattn_cache = tt_model.setup_cache(max_batch_size=batch)
for b in range(batch):
tt_tensor_vision_tokens = model_args.prepare_inputs_ttnn_prefill(
tt_vision_tokens[b : b + 1],
force_replicated=True,
)

tt_xattn_cache = [
layer.compute_xattn_kv_cache(tt_tensor_vision_tokens, tt_xattn_cache[layer_num], user_id=b)
for layer_num, layer in enumerate(tt_model.cross_attention_layers)
]
tt_xattn_cache_torch = [
ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
batch,
n_heads,
vision_seq_len,
head_dim,
)
for kv_cache in tt_xattn_cache
for x in kv_cache
]

for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
passing, pcc_message = comp_pcc(pt, tt, prefill_pcc_required)

logger.info(comp_allclose(pt, tt))
logger.info(f"PCC: {pcc_message}")

if passing:
logger.info(f"compute_xattn_kv_cache Passed!")
else:
logger.warning(f"compute_xattn_kv_cache Failed!")
all_tests_pass = False

assert all_tests_pass

# Test forward pass of the model
n_iter = 10
@@ -214,6 +179,10 @@ def test_llama_cross_attention_transformer_text_inference(
if mode == "prefill":
outputs = []
for b in range(batch):
tt_tensor_vision_tokens = model_args.prepare_inputs_ttnn_prefill(
tt_vision_tokens[b : b + 1],
force_replicated=True,
)
tt_h = model_args.prepare_inputs_ttnn_prefill(
h[b : b + 1],
)
@@ -267,6 +236,7 @@ def test_llama_cross_attention_transformer_text_inference(
user_id=b,
mode=mode,
text_only_inference=TEXT_ONLY,
vision_tokens=tt_tensor_vision_tokens,
)

tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=0))
@@ -357,5 +327,31 @@ def test_llama_cross_attention_transformer_text_inference(
passing, pcc_message = comp_pcc(logits, tt_out, pcc_required)
logger.info(comp_allclose(logits, tt_out))
logger.info(f"PCC: {pcc_message}")
prev_pos = cur_pos
assert passing, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
prev_pos = cur_pos

if mode == "prefill":
tt_xattn_cache_torch = [
ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
batch,
n_heads,
vision_seq_len,
head_dim,
)
for kv_cache in tt_xattn_cache
for x in kv_cache
]

for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
passing, pcc_message = comp_pcc(pt, tt, prefill_pcc_required)

logger.info(comp_allclose(pt, tt))
logger.info(f"PCC: {pcc_message}")

if passing:
logger.info(f"compute_xattn_kv_cache Passed!")
else:
logger.warning(f"compute_xattn_kv_cache Failed!")
all_tests_pass = False

assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
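Note the cache structure this second test implies: setup_cache returns one [K, V] pair per cross-attention layer, which is why the readback comprehension is nested ("for kv_cache in tt_xattn_cache for x in kv_cache"). A torch sketch of that layout with illustrative dimensions (the real tensors are ttnn device tensors sharded across the mesh):

    import torch

    def setup_cache_sketch(n_xattn_layers, batch, n_heads, vision_seq_len, head_dim):
        # One preallocated [K, V] pair per cross-attention layer, so prefill
        # can write each user's rows in place.
        return [
            [torch.zeros(batch, n_heads, vision_seq_len, head_dim) for _ in range(2)]
            for _ in range(n_xattn_layers)
        ]

    # Flattening for the PCC check mirrors the test: layer-major, K before V.
    cache = setup_cache_sketch(n_xattn_layers=4, batch=2, n_heads=8, vision_seq_len=1024, head_dim=128)
    flat = [x for kv_cache in cache for x in kv_cache]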
56 changes: 27 additions & 29 deletions models/demos/llama3/tests/multimodal/test_llama_cross_block.py
@@ -87,7 +87,6 @@ def test_llama_cross_attention_transformer_block_inference(
pt_xattn_cache_chunks = torch.chunk(pt_xattn_cache, 2, dim=0)
pt_xattn_cache_chunks = [x.view(batch, n_heads, vision_seq_len, head_dim) for x in pt_xattn_cache]

# Iterate over batch
# Preallocate K and V caches
tt_xattn_cache = [
ttnn.from_torch(
@@ -100,34 +99,6 @@ def test_llama_cross_attention_transformer_block_inference(
)
for _ in range(2)
]
for b in range(batch):
tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
tt_xattn_tokens[b : b + 1],
force_replicated=True,
)
tt_xattn_cache = tt_model.compute_xattn_kv_cache(tt_tensor_xattn_tokens, tt_xattn_cache, user_id=b)
tt_xattn_cache_torch = [
ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
batch,
n_heads,
vision_seq_len,
head_dim,
)
for x in tt_xattn_cache
]

for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
passing, pcc_message = comp_pcc(pt, tt, pcc_required)

logger.info(comp_allclose(pt, tt))
logger.info(f"PCC: {pcc_message}")
if passing:
logger.info(f"compute_xattn_kv_cache Passed!")
else:
logger.warning(f"compute_xattn_kv_cache Failed!")
all_tests_pass = False

assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"

"""
Test forward, prefill and decode!
@@ -176,6 +147,10 @@ def test_llama_cross_attention_transformer_block_inference(
if mode == "prefill":
outputs = []
for b in range(batch):
tt_tensor_xattn_tokens = model_args.prepare_inputs_ttnn_prefill(
tt_xattn_tokens[b : b + 1],
force_replicated=True,
)
tt_tensor_x = model_args.prepare_inputs_ttnn_prefill(
tt_x[b : b + 1],
)
@@ -211,6 +186,7 @@ def test_llama_cross_attention_transformer_block_inference(
xattn_cache=tt_xattn_cache,
mode=mode,
user_id=b,
vision_tokens=tt_tensor_xattn_tokens,
)

tt_output_torch = ttnn.to_torch(tt_out, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=-1))
@@ -274,4 +250,26 @@ def test_llama_cross_attention_transformer_block_inference(
logger.info(f"PCC: {pcc_message}")
all_tests_pass = all_tests_pass and passing

if mode == "prefill":
tt_xattn_cache_torch = [
ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1)).view(
batch,
n_heads,
vision_seq_len,
head_dim,
)
for x in tt_xattn_cache
]

for pt, tt in zip(pt_xattn_cache_chunks, tt_xattn_cache_torch):
passing, pcc_message = comp_pcc(pt, tt, pcc_required)

logger.info(comp_allclose(pt, tt))
logger.info(f"PCC: {pcc_message}")
if passing:
logger.info(f"compute_xattn_kv_cache Passed!")
else:
logger.warning(f"compute_xattn_kv_cache Failed!")
all_tests_pass = False

assert all_tests_pass, f"PCC value is lower than {pcc_required} for some of the outputs. Check Warnings!"
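All three tests read the device cache back the same way before comparing: the per-device shards are concatenated along the heads dimension (dim=1) and reshaped to the torch reference layout. As a small helper sketch using the same ttnn calls as the tests:

    import ttnn

    def gather_cache_for_check(x, mesh_device, batch, n_heads, vision_seq_len, head_dim):
        # Concatenate per-device head shards into one host tensor, then
        # reshape to match the golden torch cache.
        host = ttnn.to_torch(x, mesh_composer=ttnn.ConcatMeshToTensor(mesh_device, dim=1))
        return host.view(batch, n_heads, vision_seq_len, head_dim)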