#8174: [WIP] Replace ttlib custom Falcon matmuls with ttnn matmuls #8273

Open · wants to merge 1 commit into main
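Summary of the change: the Falcon-7B model code stops calling the custom tt_lib ops (ttl.tensor.falcon_*) and routes its matmuls through ttnn, with the model config now parameterized by a core grid obtained from the new get_falcon_default_core_grid(device) helper in models/demos/falcon7b/tt/model_utils.py. The sketch below is illustrative only: the helper body is a hypothetical stand-in (its real implementation is not part of this diff, only its call sites are), while the ttnn.matmul call mirrors the updated unit test further down.

import ttnn

def get_falcon_default_core_grid(device):
    # Hypothetical stand-in: derive a core grid from the device's compute grid,
    # as the tests do via device.compute_with_storage_grid_size().
    grid = device.compute_with_storage_grid_size()
    return ttnn.CoreGrid(y=grid.y, x=grid.x)

# Old (tt_lib custom op):
#   out = ttl.tensor.falcon_fused_qkv_matmul(a_t, b_t, bias_t, output_mem_config=..., output_dtype=...)
# New (ttnn, as in the updated unit test at the bottom of this diff):
#   out = ttnn.matmul(a_t, b_t, memory_config=out_mem_config, dtype=out_dtype,
#                     core_grid=get_falcon_default_core_grid(device), use_1d_systolic_array=True)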
16 changes: 10 additions & 6 deletions models/demos/falcon7b/demo/demo.py
@@ -16,6 +16,7 @@
from models.demos.falcon7b.reference.hf_modeling_falcon import FalconConfig, FalconForCausalLM
from models.demos.falcon7b.tt.falcon_causallm import TtFalconCausalLM
from models.demos.falcon7b.tt.model_config import get_model_config, model_config_entries
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid
from models.utility_functions import (
disable_compilation_reports,
disable_persistent_kernel_cache,
@@ -104,8 +105,8 @@ def print_output_prompts(generated_ids, tokenizer, batch_size, num_users_to_disp
logger.info(f"Output for user {user_id}:\n{output_prompt}")


def update_model_config(model, model_config_str, prefill_seq_len=0):
model.model_config.update(get_model_config(model_config_str, prefill_seq_len))
def update_model_config(model, model_config_str, core_grid, prefill_seq_len=0):
model.model_config.update(get_model_config(model_config_str, core_grid, prefill_seq_len))


def top_pk_logits(logits, p=0.9, k=10, temperature=1.0, return_probs=False):
@@ -184,7 +185,10 @@ def run_falcon_demo_kv(
)
profiler.end(f"tokenizing_inputs")

model_config = get_model_config(model_config_strs_prefill_decode[0], nearest_32(num_input_tokens))
default_core_grid = get_falcon_default_core_grid(devices[0])
model_config = get_model_config(
model_config_strs_prefill_decode[0], default_core_grid, nearest_32(num_input_tokens)
)
tt_cache_path = get_tt_cache_path(
model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
)
@@ -280,7 +284,7 @@ def run_falcon_demo_kv(
logger.info("Running 1st run decode stage with compile...")

# Update model config
update_model_config(tt_FalconCausalLM_singlelayer, model_config_strs_prefill_decode[1])
update_model_config(tt_FalconCausalLM_singlelayer, model_config_strs_prefill_decode[1], default_core_grid)

decode_ids = torch.randint(low=0, high=configuration.vocab_size - 1, size=(global_batch, 1), dtype=torch.int64)

@@ -333,7 +337,7 @@ def run_falcon_demo_kv(
num_layers,
configuration,
max_seq_len,
get_model_config(model_config_strs_prefill_decode[0], nearest_32(num_input_tokens)),
get_model_config(model_config_strs_prefill_decode[0], default_core_grid, nearest_32(num_input_tokens)),
tt_cache_path,
nearest_32(num_input_tokens),
)
@@ -408,7 +412,7 @@ def run_falcon_demo_kv(
logger.info("Running inference decode stage...")

# Update model config
update_model_config(tt_FalconCausalLM, model_config_strs_prefill_decode[1])
update_model_config(tt_FalconCausalLM, model_config_strs_prefill_decode[1], default_core_grid)

decode_ids = torch.zeros(global_batch, 1, dtype=torch.int64)
for user_id, output_id in enumerate(output_ids):
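The same two-line change repeats in every test file below: query the device for a default core grid, then thread it through get_model_config. For reference, a minimal sketch of the updated call pattern (values are illustrative; device and seq_len come from the surrounding test or demo code):

from models.demos.falcon7b.tt.model_config import get_model_config
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid

# device: an open TT device handle (e.g. from the pytest fixture); seq_len: prefill sequence length.
default_core_grid = get_falcon_default_core_grid(device)
model_config = get_model_config("BFLOAT16-DRAM", default_core_grid, seq_len)  # "BFLOAT16-DRAM" is an example config string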
4 changes: 3 additions & 1 deletion models/demos/falcon7b/tests/test_falcon_attention.py
@@ -11,6 +11,7 @@
)
from models.demos.falcon7b.tt.falcon_attention import TtFalconAttentionDecode, TtFalconAttentionPrefill
from models.demos.falcon7b.tt.model_config import get_model_config
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid
from models.demos.falcon7b.tests.test_utils import get_rand_falcon_inputs, concat_device_outputs
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_allclose,
@@ -181,7 +182,8 @@ def test_FalconAttention_inference(
):
devices = get_devices_for_t3000(all_devices, num_devices)

model_config = get_model_config(model_config_str, seq_len)
default_core_grid = get_falcon_default_core_grid(devices[0])
model_config = get_model_config(model_config_str, default_core_grid, seq_len)
tt_cache_path = get_tt_cache_path(
model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
)
8 changes: 4 additions & 4 deletions models/demos/falcon7b/tests/test_falcon_causallm.py
@@ -12,9 +12,8 @@
)
from models.demos.falcon7b.tt.falcon_causallm import TtFalconCausalLM

from models.demos.falcon7b.tt.model_config import (
get_model_config,
)
from models.demos.falcon7b.tt.model_config import get_model_config
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid
from models.demos.falcon7b.tests.test_utils import get_rand_falcon_inputs, concat_device_out_layer_present
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_allclose,
@@ -230,7 +229,8 @@ def test_FalconCausalLM_inference(
):
devices = get_devices_for_t3000(all_devices, num_devices)

model_config = get_model_config(model_config_str, seq_len)
default_core_grid = get_falcon_default_core_grid(devices[0])
model_config = get_model_config(model_config_str, default_core_grid, seq_len)
tt_cache_path = get_tt_cache_path(
model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
)
4 changes: 3 additions & 1 deletion models/demos/falcon7b/tests/test_falcon_decoder.py
@@ -11,6 +11,7 @@
)
from models.demos.falcon7b.tt.falcon_decoder import TtFalconDecoderLayer
from models.demos.falcon7b.tt.model_config import get_model_config
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid
from models.demos.falcon7b.tests.test_utils import get_rand_falcon_inputs, concat_device_outputs
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_pcc,
@@ -172,7 +173,8 @@ def test_FalconDecoder_inference(
):
devices = get_devices_for_t3000(all_devices, num_devices)

model_config = get_model_config(model_config_str, seq_len)
default_core_grid = get_falcon_default_core_grid(devices[0])
model_config = get_model_config(model_config_str, default_core_grid, seq_len)
tt_cache_path = get_tt_cache_path(
model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
)
4 changes: 3 additions & 1 deletion models/demos/falcon7b/tests/test_falcon_end_to_end.py
@@ -14,6 +14,7 @@
# TODO: Remove this?
from models.demos.falcon7b.tt.falcon_common import PytorchFalconCausalLM
from models.demos.falcon7b.tt.model_config import get_model_config
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid
from models.utility_functions import (
disable_compilation_reports,
disable_persistent_kernel_cache,
@@ -369,7 +370,8 @@ def test_FalconCausalLM_end_to_end_with_program_cache(
):
pytest.skip("#7933: Out of DRAM space error for tensor")

model_config = get_model_config(model_config_str, seq_len)
default_core_grid = get_falcon_default_core_grid(device)
model_config = get_model_config(model_config_str, default_core_grid, seq_len)
tt_cache_path = get_tt_cache_path(
model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
)
4 changes: 3 additions & 1 deletion models/demos/falcon7b/tests/test_falcon_mlp.py
@@ -8,6 +8,7 @@
from models.demos.falcon7b.reference.hf_modeling_falcon import FalconForCausalLM
from models.demos.falcon7b.tt.falcon_mlp import TtFalconMLPDecode, TtFalconMLPPrefill
from models.demos.falcon7b.tt.model_config import get_model_config
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid
from models.utility_functions import get_devices_for_t3000, torch2tt_tensor, tt2torch_tensor
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import comp_allclose, comp_pcc

@@ -147,7 +148,8 @@ def test_FalconMLP_inference(
):
devices = get_devices_for_t3000(all_devices, num_devices)

model_config = get_model_config(model_config_str, seq_len)
default_core_grid = get_falcon_default_core_grid(devices[0])
model_config = get_model_config(model_config_str, default_core_grid, seq_len)
tt_cache_path = get_tt_cache_path(
model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
)
8 changes: 4 additions & 4 deletions models/demos/falcon7b/tests/test_falcon_model.py
@@ -10,9 +10,8 @@
FalconForCausalLM,
)
from models.demos.falcon7b.tt.falcon_model import TtFalconModel
from models.demos.falcon7b.tt.model_config import (
get_model_config,
)
from models.demos.falcon7b.tt.model_config import get_model_config
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid
from models.demos.falcon7b.tests.test_utils import get_rand_falcon_inputs, concat_device_out_layer_present
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_allclose,
@@ -220,7 +219,8 @@ def test_FalconModel_inference(
):
devices = get_devices_for_t3000(all_devices, num_devices)

model_config = get_model_config(model_config_str)
default_core_grid = get_falcon_default_core_grid(devices[0])
model_config = get_model_config(model_config_str, default_core_grid)
tt_cache_path = get_tt_cache_path(
model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
)
8 changes: 4 additions & 4 deletions models/demos/falcon7b/tests/test_falcon_prefill_decode.py
@@ -12,9 +12,8 @@
)
from models.demos.falcon7b.tt.falcon_causallm import TtFalconCausalLM

from models.demos.falcon7b.tt.model_config import (
get_model_config,
)
from models.demos.falcon7b.tt.model_config import get_model_config
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid

from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
comp_allclose,
@@ -205,7 +204,8 @@ def test_FalconCausalLM_inference(
get_tt_cache_path,
device,
):
model_config = get_model_config(model_config_str)
default_core_grid = get_falcon_default_core_grid(device)
model_config = get_model_config(model_config_str, default_core_grid)
tt_cache_path = get_tt_cache_path(
model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
)
29 changes: 15 additions & 14 deletions models/demos/falcon7b/tests/test_perf_falcon.py
@@ -19,9 +19,8 @@
PytorchFalconCausalLM,
)

from models.demos.falcon7b.tt.model_config import (
get_model_config,
)
from models.demos.falcon7b.tt.model_config import get_model_config
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid
from models.demos.falcon7b.tests.test_utils import get_rand_falcon_inputs, concat_device_out_layer_present
from tests.tt_eager.python_api_testing.sweep_tests.comparison_funcs import (
get_atol_rtol_pcc,
@@ -426,7 +425,8 @@ def test_perf_gs_bare_metal(
if model_config_str == "BFLOAT16-L1_SHARDED":
pytest.skip("Sharded config is not supported on GS")

model_config = get_model_config(model_config_str)
default_core_grid = get_falcon_default_core_grid(device)
model_config = get_model_config(model_config_str, default_core_grid)
tt_cache_path = get_tt_cache_path(
model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
)
@@ -478,7 +478,8 @@ def run_perf_wh_bare_metal(
# Enable Async Mode
for device in devices:
device.enable_async(async_mode)
model_config = get_model_config(model_config_str)
default_core_grid = get_falcon_default_core_grid(device)
model_config = get_model_config(model_config_str, default_core_grid)
tt_cache_path = get_tt_cache_path(
model_version, model_subdir="Falcon", default_dir=model_config["DEFAULT_CACHE_PATH"]
)
@@ -511,14 +512,14 @@ def run_perf_wh_bare_metal(
("prefill", 32, 1, 128, 0, "BFLOAT16-L1", 0.97, 0.99, 0.96, 0.1),
("prefill", 32, 1, 256, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 0.18),
("prefill", 32, 1, 256, 0, "BFLOAT16-L1", 0.98, 0.99, 0.96, 0.18),
("decode", 32, 32, 1, 128, "BFLOAT16-DRAM", 0.91, 0.92, 0.93, 0.15),
("decode", 32, 32, 1, 128, "BFLOAT16-L1", 0.91, 0.92, 0.93, 0.15),
("decode", 32, 32, 1, 128, "BFLOAT16-DRAM", 0.92, 0.94, 0.94, 0.15),
("decode", 32, 32, 1, 128, "BFLOAT16-L1", 0.92, 0.94, 0.94, 0.15),
("decode", 32, 32, 1, 128, "BFLOAT16-L1_SHARDED", 0.92, 0.95, 0.95, 0.1),
("decode", 32, 32, 1, 1024, "BFLOAT16-DRAM", 0.86, 0.92, 0.92, 0.4),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1", 0.86, 0.92, 0.92, 0.35),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.85, 0.93, 0.94, 0.1),
("decode", 32, 32, 1, 2047, "BFLOAT16-DRAM", 0.88, 0.93, 0.93, 0.75),
("decode", 32, 32, 1, 2047, "BFLOAT16-L1", 0.88, 0.93, 0.93, 0.6),
("decode", 32, 32, 1, 1024, "BFLOAT16-DRAM", 0.90, 0.94, 0.94, 0.4),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1", 0.90, 0.94, 0.94, 0.35),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.89, 0.95, 0.95, 0.1),
("decode", 32, 32, 1, 2047, "BFLOAT16-DRAM", 0.89, 0.92, 0.93, 0.75),
("decode", 32, 32, 1, 2047, "BFLOAT16-L1", 0.89, 0.92, 0.93, 0.6),
),
ids=[
"prefill_seq128_bf16_dram",
@@ -589,9 +590,9 @@ def test_perf_wh_bare_metal(
"llm_mode, num_devices, num_layers, batch, seq_len, kv_cache_len, model_config_str, expected_output_pcc, expected_k_cache_pcc, expected_v_cache_pcc, expected_inference_time, async_mode",
(
("prefill", 4, 32, 1, 256, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 0.18, False), # Issue 7816 Inference time
("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.91, 0.91, 0.21, False),
("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.89, 0.90, 0.21, False),
("prefill", 4, 32, 1, 256, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 0.18, True),
("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.91, 0.91, 0.09, True),
("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.89, 0.90, 0.09, True),
),
ids=[
"prefill_seq256",
Changes to the Falcon matmul unit test (file path header not captured in this view)
@@ -5,8 +5,12 @@
import pytest
from loguru import logger

import ttnn
import tt_lib as ttl
from models.utility_functions import comp_pcc, tt2torch_tensor, torch2tt_tensor, skip_for_wormhole_b0
from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid
from models.demos.falcon7b.tt.falcon_mlp import falcon_dense_4h_to_h_matmul, falcon_dense_h_to_4h_matmul
from models.demos.falcon7b.tt.falcon_causallm import falcon_lm_head_matmul
import torch
import math

@@ -26,15 +30,15 @@ def run_falcon_matmul_test(
if out_dtype == ttl.tensor.DataType.BFLOAT8_B:
pcc = 0.98

if falcon_op == ttl.tensor.falcon_fused_qkv_matmul:
if falcon_op == "falcon_fused_qkv_matmul":
a_shape = [1, 1, seq_len, 4544]
b_shape = [1, 1, 4544, 4672]
expected_output_shape = [1, 1, seq_len, 4672]
elif falcon_op == ttl.tensor.falcon_selfout_matmul:
elif falcon_op == "falcon_selfout_matmul":
a_shape = [1, 1, seq_len, 4544]
b_shape = [1, 1, 4544, 4544]
expected_output_shape = [1, 1, seq_len, 4544]
elif falcon_op == ttl.tensor.falcon_dense_4h_to_h_matmul:
elif falcon_op == "falcon_dense_4h_to_h_matmul":
a_shape = [1, 1, seq_len, 18176]
b_shape = [1, 1, 18176, 4544]
expected_output_shape = [1, 1, seq_len, 4544]
@@ -59,7 +63,7 @@
out_mem_config = ttl.tensor.MemoryConfig(
ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM
)
elif falcon_op == ttl.tensor.falcon_dense_h_to_4h_matmul:
elif falcon_op == "falcon_dense_h_to_4h_matmul":
a_shape = [1, 1, seq_len, 4544]
b_shape = [1, 1, 4544, 18176]
expected_output_shape = [1, 1, seq_len, 18176]
@@ -77,7 +81,7 @@
out_mem_config = ttl.tensor.MemoryConfig(
ttl.tensor.TensorMemoryLayout.INTERLEAVED, ttl.tensor.BufferType.DRAM
)
elif falcon_op == ttl.tensor.falcon_lm_head_matmul:
elif falcon_op == "falcon_lm_head_matmul":
a_shape = [1, 1, seq_len, 4544]
b_shape = [1, 1, 4544, 65024]
expected_output_shape = [1, 1, seq_len, 65024]
@@ -116,7 +120,40 @@ def run_falcon_matmul_test(
b_t = ttl.tensor.Tensor(B, in1_dtype).to(ttl.tensor.Layout.TILE).to(device, in1_mem_config)
bias_t = None

out = falcon_op(a_t, b_t, bias_t, output_mem_config=out_mem_config, output_dtype=out_dtype)
default_core_grid = get_falcon_default_core_grid(device)
if falcon_op in ("falcon_fused_qkv_matmul", "falcon_selfout_matmul"):
out = ttnn.matmul(
a_t,
b_t,
memory_config=out_mem_config,
dtype=out_dtype,
core_grid=default_core_grid,
use_1d_systolic_array=True,
)
elif falcon_op == "falcon_dense_4h_to_h_matmul":
out = falcon_dense_4h_to_h_matmul(
a_t,
b_t,
core_grid=default_core_grid,
output_mem_config=out_mem_config,
output_dtype=out_dtype,
packer_l1_acc=True,
)
elif falcon_op == "falcon_dense_h_to_4h_matmul":
out = falcon_dense_h_to_4h_matmul(
a_t,
b_t,
core_grid=default_core_grid,
fused_activation=None,
output_mem_config=out_mem_config,
output_dtype=out_dtype,
)
elif falcon_op == "falcon_lm_head_matmul":
out = falcon_lm_head_matmul(
a_t, b_t, core_grid=default_core_grid, output_mem_config=out_mem_config, output_dtype=out_dtype
)
else:
raise NotImplementedError(f"Unknown falcon matmul op: {falcon_op}")

# Check memory and dtype of inputs and outputs
assert a_t.memory_config().buffer_type == in0_mem_config.buffer_type
@@ -172,11 +209,11 @@ def run_falcon_matmul_test(
@pytest.mark.parametrize(
"falcon_op",
(
ttl.tensor.falcon_fused_qkv_matmul,
ttl.tensor.falcon_selfout_matmul,
ttl.tensor.falcon_dense_4h_to_h_matmul,
ttl.tensor.falcon_dense_h_to_4h_matmul,
ttl.tensor.falcon_lm_head_matmul,
"falcon_fused_qkv_matmul",
"falcon_selfout_matmul",
"falcon_dense_4h_to_h_matmul",
"falcon_dense_h_to_4h_matmul",
"falcon_lm_head_matmul",
),
ids=["fused_qkv", "selfout", "dense_4h_to_h", "dense_h_to_4h", "lm_head"],
)
@@ -199,7 +236,7 @@ def test_falcon_matmul(
):
compute_grid_size = device.compute_with_storage_grid_size()
is_e75_grid_size = (compute_grid_size.x * compute_grid_size.y) == 88
if is_e75_grid_size and (seq_len == 512) and (falcon_op == ttl.tensor.falcon_lm_head_matmul):
if is_e75_grid_size and (seq_len == 512) and (falcon_op == "falcon_lm_head_matmul"):
pytest.skip(f"LM Head does not work on E75 grid size {compute_grid_size}")

run_falcon_matmul_test(
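The falcon_dense_*_matmul and falcon_lm_head_matmul helpers imported above live in the model code (falcon_mlp.py, falcon_causallm.py) and their bodies are not part of this diff. A rough, hypothetical sketch of what such a wrapper might look like on top of ttnn.matmul is given below; the real helpers may additionally configure the compute kernel (e.g. packer L1 accumulation), which is not shown here.

import ttnn

def falcon_dense_4h_to_h_matmul_sketch(a, b, core_grid, output_mem_config, output_dtype, packer_l1_acc=False):
    # Hypothetical wrapper: a 1D-systolic ttnn matmul over the supplied core grid.
    # packer_l1_acc is accepted to match the call site in the test above but is ignored
    # in this sketch; the real helper presumably folds it into a compute kernel config.
    return ttnn.matmul(
        a,
        b,
        memory_config=output_mem_config,
        dtype=output_dtype,
        core_grid=core_grid,
        use_1d_systolic_array=True,
    )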