#8349: Move LM head configuration to model config
s-jovic committed May 29, 2024
1 parent d0000da commit 61f2230
Showing 5 changed files with 65 additions and 217 deletions.

This file was deleted.

46 changes: 29 additions & 17 deletions models/demos/falcon7b/tt/falcon_causallm.py
@@ -6,7 +6,6 @@

 import torch
 import ttnn
-from models.demos.falcon7b.tt.falcon_lm_head import falcon_lm_head_matmul_2d
 from models.demos.falcon7b.tt.falcon_model import TtFalconModelShared
 from models.demos.falcon7b.tt.model_utils import get_falcon_default_core_grid, get_weights_cached
 from models.utility_functions import torch_tensors_to_tt_tensors
@@ -47,22 +46,22 @@ def __init__(

         if self.model_config["PREFILL_OPTIMIZED_MODE"] and self.seq_len > 512:
             # Optimization for lm_head matmul
-            self.num_slices = 4 if self.seq_len <= 1024 else 8
+            num_slices = self.model_config["LM_HEAD_NUM_SLICES"][seq_len]
             if lm_head_weight is not None:
-                PADDING = torch.zeros([64, lm_head_weight.shape[1] // self.num_slices])
-                lm_head_weights = torch.chunk(lm_head_weight, self.num_slices, dim=-1)
+                PADDING = torch.zeros([64, lm_head_weight.shape[1] // num_slices])
+                lm_head_weights = torch.chunk(lm_head_weight, num_slices, dim=-1)
                 lm_head_weights_padded = [torch.cat([weight, PADDING], 0) for weight in lm_head_weights]
             # Cache sliced weights for lm_head with different seq_len
             self.lm_head_sliced_weights = [
                 get_weights_cached(
                     devices,
                     model_config,
                     tt_cache_path,
-                    f"lm_head.weight_slice_{i}_of_{self.num_slices}",
+                    f"lm_head.weight_slice_{i}_of_{num_slices}",
                     weight_config_str="LM_HEAD_MM_WEIGHTS",
                     weights_to_cache=lm_head_weights_padded[i] if lm_head_weight is not None else None,
                 )
-                for i in range(self.num_slices)
+                for i in range(num_slices)
             ]
             # Generate padding for lm_head > 512
             padding = torch.zeros([1, 1, seq_len, 64])
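
For intuition, here is a standalone sketch of what the slicing above produces. The shapes are assumptions based on Falcon-7B's usual dimensions (hidden size 4544, vocabulary size 65024), not values taken from this diff:

```python
import torch

# Assumed Falcon-7B dimensions: hidden size 4544, vocabulary size 65024.
hidden, vocab, num_slices = 4544, 65024, 4  # 4 slices for seq_len <= 1024

lm_head_weight = torch.randn(hidden, vocab)

# Split the vocab dimension into num_slices chunks of 65024 // 4 = 16256 columns.
lm_head_weights = torch.chunk(lm_head_weight, num_slices, dim=-1)

# Pad each slice with 64 zero rows, growing the inner dim from 4544 to 4608.
PADDING = torch.zeros([64, vocab // num_slices])
lm_head_weights_padded = [torch.cat([w, PADDING], 0) for w in lm_head_weights]

assert lm_head_weights_padded[0].shape == (4608, 16256)
```

Since the padding is zeros it cannot change the matmul result; it takes the inner dimension from 142 to 144 tiles of 32, presumably so that it divides evenly across an 8-wide core grid.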
@@ -106,18 +105,31 @@ def forward(
         )

         if llm_mode == "prefill":
-            if self.model_config["PREFILL_OPTIMIZED_MODE"] and hidden_states[0].get_legacy_shape()[-2] > 512:
-                lm_logits = [
-                    falcon_lm_head_matmul_2d(
-                        hidden_states[device_id],
-                        [weights[device_id] for weights in self.lm_head_sliced_weights],
-                        self.num_slices,
-                        lm_head_padding=self.lm_head_padding[device_id],
-                        out_mem_config=self.model_config["LM_HEAD_MM_OUTPUT_MEMCFG"],
-                        out_dtype=self.model_config["LM_HEAD_MM_OUTPUT_DTYPE"],
-                    )
-                    for device_id in range(self.num_devices)
-                ]
+            if self.model_config["PREFILL_OPTIMIZED_MODE"] and self.seq_len > 512:
+                lm_logits = []
+                for device_id in range(self.num_devices):
+                    hidden_states[device_id] = ttnn.experimental.tensor.concat(
+                        [hidden_states[device_id], self.lm_head_padding[device_id]], -1
+                    )
+
+                    out_slices = []
+                    for slice_id in range(self.model_config["LM_HEAD_NUM_SLICES"][self.seq_len]):
+                        out_slices.append(
+                            ttnn.experimental.operations.primary.matmul(
+                                hidden_states[device_id],
+                                self.lm_head_sliced_weights[slice_id][device_id],
+                                program_config=self.model_config["LM_HEAD_PROGCFG"][self.seq_len],
+                                output_mem_config=self.model_config["LM_HEAD_MM_OUTPUT_MEMCFG"],
+                                output_dtype=self.model_config["LM_HEAD_MM_OUTPUT_DTYPE"],
+                                compute_kernel_config=self.model_config["LM_HEAD_KERNEL_CONFIG"],
+                            )
+                        )
+
+                    out = ttnn.experimental.tensor.concat(out_slices, -1)
+                    lm_logits.append(out)
+
+                    for slice_id in range(self.model_config["LM_HEAD_NUM_SLICES"][self.seq_len]):
+                        out_slices[slice_id].deallocate(True)
             else:
                 lm_logits = [
                     ttnn.matmul(
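Numerically, the inlined prefill path above is a padded matmul computed slice by slice and concatenated back along the vocabulary dimension. A minimal single-device torch model of the same computation, with assumed shapes (seq_len 1024, hidden 4544, vocab 65024, 4 slices):

```python
import torch

# Assumed shapes: seq_len 1024, hidden 4544 (+64 padding), vocab 65024, 4 slices.
seq_len, hidden, pad, vocab, num_slices = 1024, 4544, 64, 65024, 4

hidden_states = torch.randn(1, 1, seq_len, hidden)
weight_slices = [torch.randn(hidden + pad, vocab // num_slices) for _ in range(num_slices)]

# Pad activations on the last dim so the inner dims match the padded weights;
# the zero activation columns line up with the zero weight rows and contribute nothing.
hidden_states = torch.cat([hidden_states, torch.zeros(1, 1, seq_len, pad)], dim=-1)

# One matmul per weight slice, then concat along the vocab dim.
out_slices = [hidden_states @ w for w in weight_slices]
lm_logits = torch.cat(out_slices, dim=-1)

assert lm_logits.shape == (1, 1, seq_len, vocab)
```

Slicing the weights bounds the size of each intermediate output, and the deallocate(True) calls release the per-slice buffers once the concat has materialized the full logits.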
92 changes: 0 additions & 92 deletions models/demos/falcon7b/tt/falcon_lm_head.py

This file was deleted.

39 changes: 35 additions & 4 deletions models/demos/falcon7b/tt/model_config.py
@@ -5,7 +5,7 @@

 import ttnn
 from loguru import logger
 from pathlib import Path
-from models.utility_functions import is_grayskull, is_wormhole_b0
+from models.utility_functions import is_grayskull, is_wormhole_b0, nearest_y

 OP_KEYS = (
     # Inputs
@@ -205,9 +205,6 @@ def get_model_config(model_config_str, prefill_seq_len=0):

 def set_prefill_config(model_config, seq_len, dram_memcfg):
     model_config["PREFILL_OPTIMIZED_MODE"] = not is_grayskull()
-    model_config["MLP_SEQ_LEN"] = seq_len
-    model_config["MLP_PADDING_VALUE"] = 4608
-    model_config["MLP_GRID_SIZE"] = (8, 8)

     if is_wormhole_b0():
         default_kernel_config = ttnn.experimental.tensor.WormholeComputeKernelConfig(
@@ -221,6 +218,11 @@ def set_prefill_config(model_config, seq_len, dram_memcfg):
             math_fidelity=ttnn.experimental.tensor.MathFidelity.LoFi,
             math_approx_mode=True,
         )
+
+    # ---- mlp config
+    model_config["MLP_SEQ_LEN"] = seq_len
+    model_config["MLP_PADDING_VALUE"] = 4608
+    model_config["MLP_GRID_SIZE"] = (8, 8)
     model_config["MLP_KERNEL_CONFIG"] = default_kernel_config

     mm_h_to_4h_prog_cfg = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig(
@@ -248,6 +250,7 @@ def set_prefill_config(model_config, seq_len, dram_memcfg):
     model_config["DENSE_4H_TO_H_MM_PROGCFG"] = mm_4h_to_h_prog_cfg
     model_config["MLP_INTERLEAVED_TO_SHARDED_MEM_CFG"] = dram_memcfg

+    # ---- attention config
     model_config["FUSED_QKV_MM_OPTIMIZED_MEMCFG"] = dram_memcfg
     model_config[
         "FUSED_QKV_MM_OPTIMIZED_PROGCFG"
@@ -323,7 +326,35 @@ def set_prefill_config(model_config, seq_len, dram_memcfg):
         mcast_in0=False,
     )

+    # ---- lm head config
+    model_config["LM_HEAD_KERNEL_CONFIG"] = default_kernel_config
+    model_config["LM_HEAD_NUM_SLICES"] = {1024: 4, 2048: 8}
+
+    model_config["LM_HEAD_PROGCFG"] = {}
+    for seq_len in [1024, 2048]:
+        grid = (8, 8)
+        activations_m_in_tiles = seq_len // 32
+        weights_n_in_tiles = 65024 // model_config["LM_HEAD_NUM_SLICES"][seq_len] // 32
+
+        # calculate parameters for the given sequence length
+        out_subblock_h = 2
+        out_subblock_w = 4
+        per_core_M = nearest_y(activations_m_in_tiles / grid[0], out_subblock_h)
+        per_core_N = nearest_y(weights_n_in_tiles / grid[1], out_subblock_w)
+        in0_block_w = 4 if seq_len <= 1024 else 8
+
+        model_config["LM_HEAD_PROGCFG"][
+            seq_len
+        ] = ttnn.experimental.operations.primary.MatmulMultiCoreReuseMultiCastProgramConfig(
+            compute_with_storage_grid_size=grid,
+            in0_block_w=in0_block_w,
+            out_subblock_h=out_subblock_h,
+            out_subblock_w=out_subblock_w,
+            per_core_M=per_core_M,
+            per_core_N=per_core_N,
+            transpose_mcast=False,
+            fused_activation=None,
+        )


 model_config_entries = {
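A quick check of the tiling arithmetic in the new LM_HEAD_PROGCFG loop, assuming nearest_y(x, y) rounds x up to the nearest multiple of y (an assumption about the helper this commit starts importing):

```python
import math

def nearest_y(x, y):
    # Assumed behavior of models.utility_functions.nearest_y.
    return math.ceil(x / y) * y

for seq_len, num_slices in {1024: 4, 2048: 8}.items():
    m_tiles = seq_len // 32                 # activation rows in 32-high tiles: 32, 64
    n_tiles = 65024 // num_slices // 32     # slice columns in 32-wide tiles: 508, 254
    per_core_M = nearest_y(m_tiles / 8, 2)  # 4 for seq_len 1024, 8 for 2048
    per_core_N = nearest_y(n_tiles / 8, 4)  # 64 for seq_len 1024, 32 for 2048
    print(seq_len, per_core_M, per_core_N)
```

Both per-core counts come out as multiples of the 2x4 output subblock, which is what the rounding guarantees.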
2 changes: 1 addition & 1 deletion models/demos/wormhole/falcon7b/demo_wormhole.py
@@ -9,7 +9,7 @@

 @pytest.mark.parametrize(
     "perf_mode, expected_perf_prefill_decode, greedy_sampling, expected_greedy_output_path",
     (
-        (True, [1100, 335], False, None),
+        (True, [2000, 335], False, None),
         (True, None, False, None),
         (False, None, True, "models/demos/wormhole/falcon7b/expected_greedy_output.json"),
         (False, None, True, None),
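The expected prefill figure for the first perf case rises from 1100 to 2000 while the decode figure stays at 335; the diff does not state units, but the higher bar presumably reflects a faster prefill path after the LM-head rework.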
