#5383: [Falcon7b] Tilize attention mask on device, repeat mask across heads on device for prefill non-optimized mode

Signed-off-by: Salar Hosseini <[email protected]>
skhorasganiTT committed May 3, 2024
1 parent d569105 commit 62d7a12
Showing 2 changed files with 61 additions and 34 deletions.
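
Before the diffs, here is a minimal illustrative sketch (not part of the commit) of the mask-preprocessing flow this change moves onto device for the non-optimized prefill path: the attention mask is sent to the device untiled (ROW_MAJOR, bfloat16) with a single head, then repeated across all attention heads and tilized on device via tt_lib.tensor.repeat and tt_lib.tensor.tilize. The helper build_causal_mask, the wrapper preprocess_prefill_attn_mask, and the import path for torch_tensors_to_tt_tensors are assumptions for illustration; the tt_lib calls, argument order, and config keys mirror the falcon_model.py diff below.

# Illustrative sketch only: host builds an additive causal mask with a single head,
# then the device repeats it across heads and tilizes it.
import torch
import tt_lib
# Assumed import location for the helper used in falcon_model.py:
from models.demos.falcon7b.tt.model_utils import torch_tensors_to_tt_tensors


def build_causal_mask(batch, num_input_tokens):
    # Hypothetical helper: boolean mask that is True for future (masked) positions.
    mask_bool = torch.ones(num_input_tokens, num_input_tokens).triu(diagonal=1).bool()
    return mask_bool.view(1, 1, num_input_tokens, num_input_tokens).expand(batch, -1, -1, -1)


def preprocess_prefill_attn_mask(devices, model_config, num_attention_heads, batch, num_input_tokens):
    # Host: additive mask (0 for visible positions, -1e3 for masked), single head only.
    attention_mask_ = build_causal_mask(batch, num_input_tokens) * -1e3
    attention_masks = [attention_mask_.clone() for _ in devices]

    # Send to device in ROW_MAJOR bfloat16; the subsequent on-device tilize op expects bfloat16 inputs.
    tt_attention_mask = torch_tensors_to_tt_tensors(
        attention_masks,
        tt_lib.tensor.Layout.ROW_MAJOR,
        tt_lib.tensor.DataType.BFLOAT16,
        model_config["ATTN_MASK_MEMCFG"],
        devices,
    )

    for i in range(len(devices)):
        # Repeat the single-head mask across all attention heads on device.
        tt_attention_mask[i] = tt_lib.tensor.repeat(
            tt_attention_mask[i],
            [1, num_attention_heads, 1, 1],
            output_mem_config=model_config["ATTN_MASK_MEMCFG"],
        )
        # Tilize on device, casting to the configured attention-mask dtype.
        tt_attention_mask[i] = tt_lib.tensor.tilize(
            tt_attention_mask[i],
            output_mem_config=model_config["ATTN_MASK_MEMCFG"],
            output_dtype=model_config["ATTN_MASK_DTYPE"],
        )
    return tt_attention_mask

On this path the sketch replaces what the old code did on host: expanding the mask across num_attention_heads with torch and sending an already TILE-layout tensor to the device.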
46 changes: 23 additions & 23 deletions models/demos/falcon7b/tests/test_perf_falcon.py
@@ -374,13 +374,13 @@ class TestParametrized:
"llm_mode, num_layers, batch, seq_len, kv_cache_len, model_config_str, expected_output_pcc, expected_k_cache_pcc, expected_v_cache_pcc, expected_inference_time",
(
("prefill", 32, 1, 128, 0, "BFLOAT16-DRAM", 0.85, 0.97, 0.86, 0.33),
("prefill", 32, 1, 128, 0, "BFLOAT16-L1", 0.85, 0.97, 0.86, 0.33),
("prefill", 32, 1, 256, 0, "BFLOAT16-DRAM", 0.90, 0.97, 0.87, 0.60),
("prefill", 32, 1, 256, 0, "BFLOAT16-L1", 0.90, 0.97, 0.87, 0.48),
("decode", 32, 32, 1, 128, "BFLOAT16-DRAM", 0.63, 0.80, 0.84, 0.27),
("decode", 32, 32, 1, 128, "BFLOAT16-L1", 0.63, 0.80, 0.84, 0.27),
("decode", 32, 32, 1, 1024, "BFLOAT16-DRAM", 0.56, 0.86, 0.88, 0.35),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1", 0.56, 0.86, 0.88, 0.35),
("prefill", 32, 1, 128, 0, "BFLOAT16-L1", 0.85, 0.97, 0.86, 0.31),
("prefill", 32, 1, 256, 0, "BFLOAT16-DRAM", 0.90, 0.97, 0.87, 0.48),
("prefill", 32, 1, 256, 0, "BFLOAT16-L1", 0.90, 0.97, 0.87, 0.39),
("decode", 32, 32, 1, 128, "BFLOAT16-DRAM", 0.63, 0.80, 0.84, 0.30),
("decode", 32, 32, 1, 128, "BFLOAT16-L1", 0.63, 0.80, 0.84, 0.30),
("decode", 32, 32, 1, 1024, "BFLOAT16-DRAM", 0.56, 0.86, 0.88, 0.40),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1", 0.56, 0.86, 0.88, 0.34),
("decode", 32, 32, 1, 2047, "BFLOAT16-DRAM", 0.55, 0.91, 0.89, 0.40),
("decode", 32, 32, 1, 2047, "BFLOAT16-L1", 0.55, 0.91, 0.89, 0.35),
),
@@ -504,18 +504,18 @@ def run_perf_wh_bare_metal(
@pytest.mark.parametrize(
"llm_mode, num_layers, batch, seq_len, kv_cache_len, model_config_str, expected_output_pcc, expected_k_cache_pcc, expected_v_cache_pcc, expected_inference_time",
(
("prefill", 32, 1, 128, 0, "BFLOAT16-DRAM", 0.97, 0.99, 0.96, 0.2),
("prefill", 32, 1, 128, 0, "BFLOAT16-L1", 0.97, 0.99, 0.96, 0.2),
("prefill", 32, 1, 256, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 0.4),
("prefill", 32, 1, 256, 0, "BFLOAT16-L1", 0.98, 0.99, 0.96, 0.4),
("decode", 32, 32, 1, 128, "BFLOAT16-DRAM", 0.91, 0.92, 0.93, 0.2),
("decode", 32, 32, 1, 128, "BFLOAT16-L1", 0.91, 0.92, 0.93, 0.2),
("decode", 32, 32, 1, 128, "BFLOAT16-L1_SHARDED", 0.92, 0.95, 0.95, 0.2),
("decode", 32, 32, 1, 1024, "BFLOAT16-DRAM", 0.86, 0.92, 0.92, 0.5),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1", 0.86, 0.92, 0.92, 0.5),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.85, 0.93, 0.94, 0.2),
("decode", 32, 32, 1, 2047, "BFLOAT16-DRAM", 0.88, 0.93, 0.93, 0.8),
("decode", 32, 32, 1, 2047, "BFLOAT16-L1", 0.88, 0.93, 0.93, 0.8),
("prefill", 32, 1, 128, 0, "BFLOAT16-DRAM", 0.97, 0.99, 0.96, 0.17),
("prefill", 32, 1, 128, 0, "BFLOAT16-L1", 0.97, 0.99, 0.96, 0.17),
("prefill", 32, 1, 256, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 0.2),
("prefill", 32, 1, 256, 0, "BFLOAT16-L1", 0.98, 0.99, 0.96, 0.2),
("decode", 32, 32, 1, 128, "BFLOAT16-DRAM", 0.91, 0.92, 0.93, 0.15),
("decode", 32, 32, 1, 128, "BFLOAT16-L1", 0.91, 0.92, 0.93, 0.15),
("decode", 32, 32, 1, 128, "BFLOAT16-L1_SHARDED", 0.92, 0.95, 0.95, 0.1),
("decode", 32, 32, 1, 1024, "BFLOAT16-DRAM", 0.86, 0.92, 0.92, 0.4),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1", 0.86, 0.92, 0.92, 0.35),
("decode", 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.85, 0.93, 0.94, 0.1),
("decode", 32, 32, 1, 2047, "BFLOAT16-DRAM", 0.88, 0.93, 0.93, 0.75),
("decode", 32, 32, 1, 2047, "BFLOAT16-L1", 0.88, 0.93, 0.93, 0.6),
),
ids=[
"prefill_seq128_bf16_dram",
@@ -585,10 +585,10 @@ def test_perf_wh_bare_metal(
@pytest.mark.parametrize(
"llm_mode, num_devices, num_layers, batch, seq_len, kv_cache_len, model_config_str, expected_output_pcc, expected_k_cache_pcc, expected_v_cache_pcc, expected_inference_time, async_mode",
(
("prefill", 4, 32, 1, 256, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 0.60, False),
("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.91, 0.91, 0.25, False),
("prefill", 4, 32, 1, 256, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 0.30, True),
("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.91, 0.91, 0.13, True),
("prefill", 4, 32, 1, 256, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 0.20, False),
("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.91, 0.91, 0.23, False),
("prefill", 4, 32, 1, 256, 0, "BFLOAT16-DRAM", 0.98, 0.99, 0.96, 0.18, True),
("decode", 4, 32, 32, 1, 1024, "BFLOAT16-L1_SHARDED", 0.87, 0.91, 0.91, 0.10, True),
),
ids=[
"prefill_seq256",
49 changes: 38 additions & 11 deletions models/demos/falcon7b/tt/falcon_model.py
@@ -123,33 +123,53 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
self.config.num_attention_heads,
num_input_tokens,
)

# Send attn masks to device
attn_masks_unordered = [
torch_tensors_to_tt_tensors(
[attention_mask_slice for _ in self.devices],
- tt_lib.tensor.Layout.TILE,
- self.model_config["ATTN_MASK_DTYPE"],
+ tt_lib.tensor.Layout.ROW_MAJOR,
+ tt_lib.tensor.DataType.BFLOAT16,  # subsequent tilize op expects bfloat16 inputs
self.model_config["ATTN_MASK_MEMCFG"],
self.devices,
)
for attention_mask_slice in attention_mask_
]
+ # Tilize attn masks
+ for tt_attention_mask_slice in attn_masks_unordered:
+ for i in range(self.num_devices):
+ tt_attention_mask_slice[i] = tt_lib.tensor.tilize(
+ tt_attention_mask_slice[i],
+ output_mem_config=self.model_config["ATTN_MASK_MEMCFG"],
+ output_dtype=self.model_config["ATTN_MASK_DTYPE"],
+ )
# Expected output attention_masks
# [dev0: [slice0, slice1, ...], dev1: [slice0, slice1, ...], ...]
tt_attention_mask = [list(x) for x in zip(*attn_masks_unordered)]
else:
- attention_mask_ = (attention_mask_bool_padded * -1e3).expand(
- -1, self.config.num_attention_heads, -1, -1
- )
+ attention_mask_ = attention_mask_bool_padded * -1e3
attention_masks = [attention_mask_.clone() for _ in self.devices]
# Send attn masks to device
tt_attention_mask = torch_tensors_to_tt_tensors(
attention_masks,
- tt_lib.tensor.Layout.TILE,
- self.model_config["ATTN_MASK_DTYPE"],
+ tt_lib.tensor.Layout.ROW_MAJOR,
+ tt_lib.tensor.DataType.BFLOAT16,  # subsequent tilize op expects bfloat16 inputs
self.model_config["ATTN_MASK_MEMCFG"],
self.devices,
)
+ # Repeat attn masks for all heads
+ for i in range(self.num_devices):
+ tt_attention_mask[i] = tt_lib.tensor.repeat(
+ tt_attention_mask[i],
+ [1, self.config.num_attention_heads, 1, 1],
+ output_mem_config=self.model_config["ATTN_MASK_MEMCFG"],
+ )
+ # Tilize attn masks
+ for i in range(self.num_devices):
+ tt_attention_mask[i] = tt_lib.tensor.tilize(
+ tt_attention_mask[i],
+ output_mem_config=self.model_config["ATTN_MASK_MEMCFG"],
+ output_dtype=self.model_config["ATTN_MASK_DTYPE"],
+ )

tt_input_ids = []
for i, device in enumerate(self.devices):
@@ -192,14 +212,21 @@ def model_preprocessing(self, llm_mode, input_ids, kv_cache_len, num_input_token
-1, -1, nearest_32(self.config.num_attention_heads), -1
)
)

# Send attn masks to device
tt_attention_mask = torch_tensors_to_tt_tensors(
attention_masks,
- tt_lib.tensor.Layout.TILE,
- self.model_config["ATTN_MASK_DTYPE"],
+ tt_lib.tensor.Layout.ROW_MAJOR,
+ tt_lib.tensor.DataType.BFLOAT16,  # subsequent tilize op expects bfloat16 inputs
self.model_config["ATTN_MASK_MEMCFG"],
self.devices,
)
+ # Tilize attn masks
+ for i in range(self.num_devices):
+ tt_attention_mask[i] = tt_lib.tensor.tilize(
+ tt_attention_mask[i],
+ output_mem_config=self.model_config["ATTN_MASK_MEMCFG"],
+ output_dtype=self.model_config["ATTN_MASK_DTYPE"],
+ )

if self.model_config["l1_sharded"]:
for i, device in enumerate(self.devices):
