
Commit

#5337: Merge branch 'mixtral-async' of https://github.com/tenstorrent/tt-metal into mixtral-async
mtairum committed May 20, 2024
2 parents e18700d + fcd5b26 commit a682bb7
Showing 13 changed files with 205 additions and 150 deletions.
7 changes: 5 additions & 2 deletions models/experimental/llama2_70b/demo/demo.py
@@ -19,7 +19,7 @@
get_model_config,
)
from models.utility_functions import get_devices_for_t3000
from models.experimental.llama2_70b.tt.llama_common import get_llama_path
from models.experimental.llama2_70b.tt.llama_common import get_llama_path, load_llama_state_dict


def main(args):
@@ -55,9 +55,12 @@ def build_generator(args):
n_layers=1 if args.implementation == "tt" else args.num_layers,
)

state_dict = load_llama_state_dict(args.ckpt_dir, n_layers=args.num_layers)

if args.implementation == "tt":
generator.model = TtLlamaModelForGeneration(
reference_model=generator.model,
configuration=generator.model.params,
state_dict=state_dict,
device_mesh=args.device_mesh,
n_devices=args.n_devices,
n_layers=args.num_layers,
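Note on the hunk above (illustrative, not part of the commit): the demo now loads the full checkpoint once via load_llama_state_dict and hands the resulting state dict to the TT model, so the reference model only needs to be built with a single layer. A minimal sketch of that pattern, with placeholder path and layer count:

from models.experimental.llama2_70b.tt.llama_common import load_llama_state_dict

ckpt_dir = "/path/to/llama-2-70b-repacked"  # placeholder path, not taken from this commit
n_layers = 80                               # placeholder layer count
state_dict = load_llama_state_dict(ckpt_dir, n_layers=n_layers)
# state_dict is then passed to TtLlamaModelForGeneration(..., state_dict=state_dict, ...)
# instead of calling .state_dict() on a fully materialized reference model.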
1 change: 0 additions & 1 deletion models/experimental/llama2_70b/tests/test_llama_mlp.py
@@ -139,7 +139,6 @@ def test_LlamaMLP_inference(
n_devices,
t3k_device_mesh,
emulated,
use_program_cache,
):
model_config = get_model_config(model_config_str="BFLOAT16-DRAM", num_devices=n_devices, seq_len=seq_len)

184 changes: 89 additions & 95 deletions models/experimental/llama2_70b/tests/test_llama_perf_decode.py
@@ -11,17 +11,16 @@
from torch import nn
import tt_lib
import ttnn
from ttnn import ShardTensorToMesh, ReplicateTensorToMesh, ConcatMeshToTensor, ListMeshToTensor


from models.experimental.llama2_70b.reference.llama.llama import Llama

from models.experimental.llama2_70b.tt.llama_model_optimized import TtLlamaModel_optimized
from models.experimental.llama2_70b.tt.model_config import (
get_model_config,
)
from models.experimental.llama2_70b.tt.llama_common import (
get_llama_path,
MAX_SEQ_LEN,
BASE_URL,
)
from models.experimental.llama2_70b.tt.llama_common import get_llama_path, MAX_SEQ_LEN, BASE_URL, load_llama_state_dict
from models.utility_functions import (
torch2tt_tensor,
tt2torch_tensor,
@@ -34,7 +33,11 @@
get_devices_for_t3000,
)
from models.perf.perf_utils import prep_perf_report
from tracy import signpost

if os.getenv("CI") == "true":
os.environ["LLAMA_CKPT_DIR"] = "/mnt/MLPerf/tt_dnn-models/llama-2/llama-2-70b-repacked/"
os.environ["LLAMA_TOKENIZER_PATH"] = "/mnt/MLPerf/tt_dnn-models/llama-2/tokenizer.model"
os.environ["LLAMA_CACHE_PATH"] = "/mnt/MLPerf/tt_dnn-models/llama-2/llama-data-cache/weights-cache-2"


def load_prompts_file(tokenizer, prefill_length, generation_length, gap=64):
@@ -106,8 +109,36 @@ def calculate_decode_times(profiler, generation_length):
return times, times[f"decode_time_{generation_length}"]


def run_inference(tt_model, tokenizer, tokens, device_mesh, configuration, total_len, input_text_mask):
start_pos = 0
prev_pos = 0
for cur_pos in range(start_pos + 1, total_len):
logger.info(f"Generating token: {cur_pos}")

tt_inp_emb, prev_pos, rot_mat, attn_mask = tt_model.prepare_inputs(tokens[:, prev_pos:cur_pos], prev_pos)

tt_logits = tt_model(
tt_inp_emb,
rot_mat,
prev_pos,
attn_mask,
)

del tt_inp_emb
del rot_mat
del attn_mask
logits = ttnn.to_torch(tt_logits, device=device_mesh, mesh_composer=ConcatMeshToTensor(device_mesh, dim=3))
logits = logits[..., : configuration.vocab_size].float() # [1, batch, vocab_size]
del tt_logits

next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(-1)

tokens, eos_reached, prev_pos = prepare_next_input(tokenizer, tokens, input_text_mask, cur_pos, next_token)
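
For reference (illustrative, not part of the commit): after ttnn.to_torch with ConcatMeshToTensor gathers the per-device logit shards into one host tensor, the loop picks the next token greedily. That last step is plain torch and can be sketched standalone with made-up shapes:

import torch

batch, vocab_size = 32, 32000               # illustrative sizes only
logits = torch.randn(1, batch, vocab_size)  # [1, batch, vocab_size], as in the loop above
next_token = torch.argmax(logits, dim=-1)   # greedy choice of the highest-scoring token per user
next_token = next_token.reshape(-1)         # flatten to [batch] before writing back into `tokens`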


def run_test_LlamaModel_end_to_end(
devices,
device_mesh,
batch,
seq_len,
model_config,
@@ -117,22 +148,22 @@ def run_test_LlamaModel_end_to_end(
expected_compile_time,
expected_inference_time,
emulated,
num_users,
):
devices, ckpt_dir, tokenizer_path, cache_path = get_llama_path(devices, model_config, n_devices, emulated)
device_mesh, ckpt_dir, tokenizer_path, cache_path = get_llama_path(device_mesh, model_config, n_devices, emulated)
logger.info(f"Running num_layer: {n_layers}")

generator = Llama.build(
ckpt_dir,
tokenizer_path,
max_seq_len=MAX_SEQ_LEN,
max_batch_size=num_users,
max_batch_size=batch,
n_layers=1,
skip_model_load=False,
)
hugging_face_reference_model, tokenizer = generator.model, generator.tokenizer
hugging_face_reference_model.eval()
state_dict = hugging_face_reference_model.state_dict()
# state_dict = hugging_face_reference_model.state_dict()
state_dict = load_llama_state_dict(ckpt_dir, n_layers=n_layers)
configuration = hugging_face_reference_model.params

# Prepare input -----------------------------------------------------------------------
@@ -141,15 +172,15 @@
prefill_ids, ground_truth_texts = load_prompts_file(
tokenizer, prefill_length=32 if generation_length > 32 else 20, generation_length=generation_length
)
tokens, input_text_mask = intialize_inputs(tokenizer, prefill_ids, num_users, total_len)
tokens, input_text_mask = intialize_inputs(tokenizer, prefill_ids, batch, total_len)
# Clear global profiler state before starting measurements
profiler.clear()

# Set up model -----------------------------------------------------------------------
logger.info("Moving weights to devices; might take some time...")
profiler.start("TT_llama_model_setup")
tt_model = TtLlamaModel_optimized(
devices,
device_mesh,
state_dict,
BASE_URL,
n_layers,
@@ -158,136 +189,100 @@
batch,
emulated=emulated,
cache_path=cache_path,
read_cache=False,
)
for device in devices:

for i in device_mesh.get_device_ids():
device = device_mesh.get_device(i)
tt_lib.device.Synchronize(device)

profiler.end("TT_llama_model_setup")

del state_dict

logger.info("Running 1st run decode stage with compile...")

start_pos = 0
prev_pos = start_pos
enable_persistent_kernel_cache()
for cur_pos in range(start_pos + 1, total_len):
logger.info(f"Generating token: {cur_pos}")

# Initialize profiling based on specific intervals and generation lengths
should_profile = (cur_pos == start_pos + 1) or is_in_profiling_range(
cur_pos, generation_length, profiling_ranges
)

if should_profile:
profiler.start(f"processing_of_decode_input_{cur_pos}")
signpost(header="Prepare Inputs", message="Prepare Inputs")

tt_inp_emb, prev_pos, rot_mat, attn_mask = tt_model.prepare_inputs(tokens[:, prev_pos:cur_pos], prev_pos)

if should_profile:
signpost(header="End of prepare Inputs", message="End of prepare Inputs")
profiler.end(f"processing_of_decode_input_{cur_pos}")
profiler.start(f"model_run_for_inference_{cur_pos}")

tt_logits = tt_model(
tt_inp_emb,
rot_mat,
prev_pos,
attn_mask,
)

if should_profile:
profiler.end(f"model_run_for_inference_{cur_pos}")

del tt_inp_emb
del rot_mat
del attn_mask

for device in devices:
tt_lib.device.Synchronize(device)

logits = torch.cat([tt2torch_tensor(tt_o).squeeze(1) for tt_o in tt_logits], -1)
logits = logits[..., : configuration.vocab_size].float() # [1, batch, vocab_size]
del tt_logits

next_token = torch.argmax(logits, dim=-1)
next_token = next_token.reshape(-1)

tokens, eos_reached, prev_pos = prepare_next_input(tokenizer, tokens, input_text_mask, cur_pos, next_token)

for user_id in range(3):
text = tokenizer.decode(tokens[user_id, : cur_pos + 1].tolist())
logger.info(f"Loop {cur_pos} user {user_id}: {text}\n")

for user_id in range(3):
logger.info(f"Ground Truth Texts: User {user_id}: {ground_truth_texts[user_id]}")
profiler.start(f"end_to_end_inference_with_compile")
run_inference(tt_model, tokenizer, tokens, device_mesh, configuration, total_len, input_text_mask)
profiler.end(f"end_to_end_inference_with_compile")
profiler.print()
compile_and_loop_time = profiler.get("end_to_end_inference_with_compile")
compile_iter_time = compile_and_loop_time / total_len
logger.info(f"decode with compile time, single iter latency: {compile_iter_time}")

logger.info("Finished 1st run decode stage with compile!")
profiler.start(f"end_to_end_inference")
run_inference(tt_model, tokenizer, tokens, device_mesh, configuration, total_len, input_text_mask)
profiler.end(f"end_to_end_inference")
profiler.print()
loop_time = profiler.get("end_to_end_inference")
iter_time = loop_time / total_len
logger.info(f"decode cached, single iter latency: {iter_time}")

comment = f"num_layers={n_layers}L_n_devices={n_devices}"
compile_time = profiler.get("TT_llama_model_setup")
decode_compile_time = profiler.get(f"model_run_for_inference_{start_pos + 1}")

decode_times, decode_time = calculate_decode_times(profiler, generation_length)
logger.info(decode_times)

prep_perf_report(
model_name=f"llama2_70b_{comment}",
batch_size=batch,
inference_and_compile_time=decode_compile_time,
inference_time=decode_time,
inference_and_compile_time=compile_iter_time,
inference_time=iter_time,
expected_compile_time=expected_compile_time,
expected_inference_time=expected_inference_time,
comments=comment,
)

logger.info(f"llama2_70b_{comment} inference time: {decode_time}")
tokens_per_s_per_user = 1 / decode_time
tokens_per_s_overall = tokens_per_s_per_user * batch * seq_len
tokens_per_s_per_user = 1 / iter_time
tokens_per_s_overall = tokens_per_s_per_user * batch

logger.info(f"Time per iteration: {decode_time}")
logger.info(f"Time per iteration: {iter_time}")
logger.info(f"Tokens per s per user: {tokens_per_s_per_user}")
logger.info(f"Tokens per s overall: {tokens_per_s_overall}")

# assert compile_time <= expected_compile_time
assert decode_time <= expected_inference_time
assert iter_time <= expected_inference_time
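
A worked example of the throughput arithmetic above, with hypothetical numbers (each decode iteration emits one token per user, so per-user throughput is simply the reciprocal of the iteration latency):

iter_time = 0.15                                      # hypothetical seconds per decode iteration
batch = 32                                            # users decoded in parallel
tokens_per_s_per_user = 1 / iter_time                 # ~6.7 tokens/s per user
tokens_per_s_overall = tokens_per_s_per_user * batch  # ~213 tokens/s across the whole batch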


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.timeout(240000)
@pytest.mark.models_performance_bare_metal
@pytest.mark.model_perf_t3000
@pytest.mark.parametrize(
"generation_length, expected_compile_time, expected_inference_time",
(
(32, 550, 1.2),
(128, 550, 1.4),
(2048, 550, 2.0),
(32, 10000, 0.139 + 0.02), # TODO: decrease expected compile time once as_tensor gets speedup
(128, 10000, 0.138 + 0.02), # Fudge delta
(2048, 10000, 0.153 + 0.02),
),
ids=["quick", "short", "long"],
ids=["gen32", "gen128", "gen2048"],
)
def test_Llama_perf_host(
generation_length,
expected_compile_time,
expected_inference_time,
all_devices,
use_program_cache,
n_layers=1,
t3k_device_mesh,
n_layers=80,
n_devices=8,
emulated=False,
num_users=32,
):
devices = get_devices_for_t3000(all_devices, num_devices=n_devices if not emulated else 1)
batch, seq_len = 32, 1
if generation_length == 2048:
pytest.skip("Skipping 2048 test for now. segfault issue #8637")
batch = 32
seq_len = 1
model_config = get_model_config(model_config_str="BFLOAT16-DRAM", num_devices=n_devices, seq_len=seq_len)
compute_grid_size = devices[0].compute_with_storage_grid_size()

if t3k_device_mesh.get_num_devices() < n_devices and not emulated:
pytest.skip(f"Requires at least {n_devices} devices to run")

compute_grid_size = t3k_device_mesh.get_device(0).compute_with_storage_grid_size()
if compute_grid_size.x < model_config["MAX_GRID_SIZE"][0] or compute_grid_size.y < model_config["MAX_GRID_SIZE"][1]:
pytest.skip(f"Requires grid size of at least {model_config['MAX_GRID_SIZE']} to run")

for i in t3k_device_mesh.get_device_ids():
device = t3k_device_mesh.get_device(i)
device.enable_program_cache()
device.enable_async(True)
disable_compilation_reports()

run_test_LlamaModel_end_to_end(
devices,
t3k_device_mesh,
batch,
seq_len,
model_config,
@@ -297,5 +292,4 @@ def test_Llama_perf_host(
expected_compile_time,
expected_inference_time,
emulated,
num_users,
)
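
The test body above switches from a flat device list to the t3k_device_mesh fixture, enabling the program cache and async mode on every device in the mesh, which lines up with the use_program_cache fixture being dropped from the test signatures. A minimal sketch of that setup, assuming a device-mesh object with the accessors used in this diff:

# Sketch only: `t3k_device_mesh` is assumed to come from the pytest fixture of the same name.
for i in t3k_device_mesh.get_device_ids():
    device = t3k_device_mesh.get_device(i)
    device.enable_program_cache()  # replaces the old use_program_cache fixture
    device.enable_async(True)      # turn on asynchronous dispatch for this device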

