#0: Update all_gather to work for multi_link. Update falcon-40b to use 2 links for all gathers
tt-aho committed Feb 9, 2024
1 parent 69d9860 commit 21954dc
Showing 15 changed files with 407 additions and 318 deletions.
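Note on the API change: the all_gather call sites in the falcon40b diffs below now pass a num_links argument read from a new ALL_GATHER_NUM_LINKS entry in model_config (set to 2 in model_config.py). A minimal sketch of the updated call pattern, assuming a tensors list with one shard per device and a model_config dict from get_model_config(); the wrapper name gather_across_devices is illustrative and not part of the commit:

from typing import List

import tt_lib


def gather_across_devices(
    tensors: List[tt_lib.tensor.Tensor], model_config: dict
) -> List[tt_lib.tensor.Tensor]:
    # Mirrors the call sites in falcon_attention/decoder/mlp/model below:
    # gather the per-device shards along the width dim over two ethernet links.
    return tt_lib.tensor.all_gather(
        tensors,
        dim=3,
        num_links=model_config["ALL_GATHER_NUM_LINKS"],
        output_mem_config=model_config["DEFAULT_MEMCFG"],
    )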
171 changes: 93 additions & 78 deletions models/demos/falcon40b/tests/test_perf_falcon.py
@@ -48,6 +48,7 @@ def run_test_FalconCausalLM_end_to_end(
model_config_str,
tt_cache_path,
model_location_generator,
expected_compile_time,
expected_inference_time,
inference_iterations,
):
@@ -179,6 +180,10 @@ def run_test_FalconCausalLM_end_to_end(
input_ids=model_input, past_key_values=past_key_values, use_cache=use_cache
)
profiler.end("hugging_face_reference_model")
del past_key_values
del pytorch_layer_present
del pytorch_out
del pytorch_FalconCausalLM

# NOTE: Passing in pytorch tensor here instead of ll buda tensor
# since we don't yet have embedding support on device
@@ -200,17 +205,19 @@ def run_test_FalconCausalLM_end_to_end(
tt_lib.device.Synchronize(device)
profiler.end("TtFalcon_model_setup")

del state_dict

profiler.start("processing_of_input")
if llm_mode == "prefill":
model_inputs = torch.split(model_input, 1)
tt_embeddings, tt_attention_mask = zip(
tt_embeddings_host, tt_attention_mask_host = zip(
*[
tt_FalconCausalLM.model_preprocessing(llm_mode, m_i, kv_cache_len, num_input_tokens=seq_len)
for m_i in model_inputs
]
)
elif llm_mode == "decode":
tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing(
tt_embeddings_host, tt_attention_mask_host = tt_FalconCausalLM.model_preprocessing(
llm_mode, model_input, kv_cache_len, num_input_tokens=kv_len
)
attention_mask_memconfig = model_config["ATTN_MASK_MEMCFG"]
@@ -228,9 +235,11 @@ def run_test_FalconCausalLM_end_to_end(
# Use force enable to only record this profiler call while others are disabled
profiler.start("first_model_run_with_compile", force_enable=True)
tt_embeddings = [
tt_embeddings[i].to(devices[i], model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"]) for i in range(len(devices))
tt_embeddings_host[i].to(devices[i], model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"]) for i in range(len(devices))
]
tt_attention_mask = [
tt_attention_mask_host[i].to(devices[i], attention_mask_memconfig) for i in range(len(devices))
]
tt_attention_mask = [tt_attention_mask[i].to(devices[i], attention_mask_memconfig) for i in range(len(devices))]
if llm_mode == "prefill":
tt_outs = []
for user_id in range(batch):
@@ -270,25 +279,63 @@ def run_test_FalconCausalLM_end_to_end(
profiler.enable()
enable_persistent_kernel_cache()

if llm_mode == "prefill":
model_inputs = torch.split(model_input, 1)
tt_embeddings, tt_attention_mask = zip(
*[
tt_FalconCausalLM.model_preprocessing(llm_mode, m_i, kv_cache_len, num_input_tokens=seq_len)
for m_i in model_inputs
]
)
elif llm_mode == "decode":
tt_embeddings, tt_attention_mask = tt_FalconCausalLM.model_preprocessing(
llm_mode, model_input, kv_cache_len, num_input_tokens=kv_len
)
def run_inference():
tt_embeddings = [
tt_embeddings_host[i].to(devices[i], model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
for i in range(len(devices))
]
tt_attention_mask = [
tt_attention_mask_host[i].to(devices[i], attention_mask_memconfig) for i in range(len(devices))
]
for _ in range(inference_iterations - 1):
if llm_mode == "prefill":
tt_outs = []
model_inputs = torch.split(model_input, 1)
tt_embeddings, tt_attention_mask = zip(
*[
tt_FalconCausalLM.model_preprocessing(llm_mode, m_i, kv_cache_len, num_input_tokens=seq_len)
for m_i in model_inputs
]
)
for user_id in range(batch):
tt_out, tt_layer_present = tt_FalconCausalLM(
input_embeddings=tt_embeddings[user_id],
llm_mode=llm_mode,
attention_mask=tt_attention_mask[user_id],
user_id=user_id,
layer_past=tt_layer_past,
layer_past_len=kv_cache_len,
use_cache=use_cache,
)
tt_outs.append(tt_out)
tt_embeddings = [
tt_embeddings[i].to(devices[i], model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
for i in range(len(devices))
]
tt_attention_mask = [
tt_attention_mask[i].to(devices[i], attention_mask_memconfig) for i in range(len(devices))
]
for i in range(len(tt_outs)):
tt_outs[i] = [tt_o.cpu() for tt_o in tt_outs[i]]

elif llm_mode == "decode":
tt_out, tt_layer_present = tt_FalconCausalLM(
input_embeddings=tt_embeddings,
llm_mode=llm_mode,
attention_mask=tt_attention_mask,
layer_past=tt_layer_past,
layer_past_len=kv_cache_len,
use_cache=use_cache,
)
tt_embeddings = [
tt_embeddings_host[i].to(devices[i], model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
for i in range(len(devices))
]
tt_attention_mask = [
tt_attention_mask_host[i].to(devices[i], attention_mask_memconfig) for i in range(len(devices))
]
tt_out = [tt_o.cpu() for tt_o in tt_out]

profiler.start(f"model_run_for_inference")
tt_embeddings = [
tt_embeddings[i].to(devices[i], model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"]) for i in range(len(devices))
]
tt_attention_mask = [tt_attention_mask[i].to(devices[i], attention_mask_memconfig) for i in range(len(devices))]
for _ in range(inference_iterations - 1):
if llm_mode == "prefill":
tt_outs = []
model_inputs = torch.split(model_input, 1)
Expand All @@ -309,13 +356,6 @@ def run_test_FalconCausalLM_end_to_end(
use_cache=use_cache,
)
tt_outs.append(tt_out)
tt_embeddings = [
tt_embeddings[i].to(devices[i], model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
for i in range(len(devices))
]
tt_attention_mask = [
tt_attention_mask[i].to(devices[i], attention_mask_memconfig) for i in range(len(devices))
]
for i in range(len(tt_outs)):
tt_outs[i] = [tt_o.cpu() for tt_o in tt_outs[i]]

@@ -328,49 +368,16 @@ def run_test_FalconCausalLM_end_to_end(
layer_past_len=kv_cache_len,
use_cache=use_cache,
)
tt_embeddings = [
tt_embeddings[i].to(devices[i], model_config["WORD_EMBEDDING_OUTPUT_MEMCFG"])
for i in range(len(devices))
]
tt_attention_mask = [
tt_attention_mask[i].to(devices[i], attention_mask_memconfig) for i in range(len(devices))
]
tt_out = [tt_o.cpu() for tt_o in tt_out]

if llm_mode == "prefill":
tt_outs = []
model_inputs = torch.split(model_input, 1)
tt_embeddings, tt_attention_mask = zip(
*[
tt_FalconCausalLM.model_preprocessing(llm_mode, m_i, kv_cache_len, num_input_tokens=seq_len)
for m_i in model_inputs
]
)
for user_id in range(batch):
tt_out, tt_layer_present = tt_FalconCausalLM(
input_embeddings=tt_embeddings[user_id],
llm_mode=llm_mode,
attention_mask=tt_attention_mask[user_id],
user_id=user_id,
layer_past=tt_layer_past,
layer_past_len=kv_cache_len,
use_cache=use_cache,
)
tt_outs.append(tt_out)
for i in range(len(tt_outs)):
tt_outs[i] = [tt_o.cpu() for tt_o in tt_outs[i]]

elif llm_mode == "decode":
tt_out, tt_layer_present = tt_FalconCausalLM(
input_embeddings=tt_embeddings,
llm_mode=llm_mode,
attention_mask=tt_attention_mask,
layer_past=tt_layer_past,
layer_past_len=kv_cache_len,
use_cache=use_cache,
)
tt_out = [tt_o.cpu() for tt_o in tt_out]
profiler.start(f"model_warmup_run_for_inference")
run_inference()
profiler.end(f"model_warmup_run_for_inference")
for device in devices:
tt_lib.device.Synchronize(device)

profiler.start(f"model_run_for_inference")
run_inference()
profiler.end(f"model_run_for_inference")
for device in devices:
tt_lib.device.Synchronize(device)
@@ -381,7 +388,6 @@ def run_test_FalconCausalLM_end_to_end(
cpu_time = profiler.get("hugging_face_reference_model")
first_iter_time = profiler.get("first_model_run_with_compile")
second_iter_time = profiler.get("model_run_for_inference") / inference_iterations
expected_compile_time = 30
prep_perf_report(
model_name=f"Falcon_{llm_mode}_{comment}",
batch_size=batch,
@@ -397,15 +403,26 @@ def run_test_FalconCausalLM_end_to_end(
logger.info(f"falcon {comment} inference time: {second_iter_time}")
logger.info(f"falcon {comment} compile time: {compile_time}")

tokens_per_s_per_user = 1 / second_iter_time
tokens_per_s_overall = tokens_per_s_per_user * batch
logger.info(f"{inference_iterations} Iterations inference time: {profiler.get('model_run_for_inference')}")
logger.info(f"Time per iteration: {second_iter_time}")
logger.info(f"Tokens per s per user: {tokens_per_s_per_user}")
logger.info(f"Tokens per s overall: {tokens_per_s_overall}")

# This script will assert since this is not a part of regular perf pipeline
# assert second_iter_time <= expected_inference_time
# assert compile_time <= expected_compile_time


@skip_for_grayskull("Requires eth connected devices to run")
@pytest.mark.models_performance_bare_metal
@pytest.mark.parametrize(
"llm_mode, batch, seq_len, kv_cache_len, expected_inference_time, inference_iterations",
"llm_mode, batch, seq_len, kv_cache_len, expected_compile_time, expected_inference_time, inference_iterations",
(
# ("prefill", 1, 128, 0, 0.30),
# ("prefill", 1, 256, 0, 0.44),
("decode", 32, 1, 128, 0.27, 10),
("decode", 32, 1, 128, 60, 0.22, 10),
# ("decode", 32, 1, 1024, 0.35, 10),
# ("decode", 32, 1, 2047, 0.48, 10),
),
@@ -419,12 +436,8 @@ def run_test_FalconCausalLM_end_to_end(
)
@pytest.mark.parametrize(
"num_layers",
(
1,
2,
60,
),
ids=["layers_1", "layers_2", "layers_60"],
(60,),
ids=["layers_60"],
)
@pytest.mark.parametrize(
"model_version",
@@ -439,6 +452,7 @@ def test_perf_bare_metal(
batch,
seq_len,
kv_cache_len,
expected_compile_time,
expected_inference_time,
inference_iterations,
num_layers,
@@ -475,6 +489,7 @@ def test_perf_bare_metal(
model_config_str,
tt_cache_path,
model_location_generator,
expected_compile_time,
expected_inference_time,
inference_iterations,
)
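
For readability, the restructured timing section of test_perf_falcon.py above boils down to the outline sketched here (reconstructed from the hunks, not verbatim; profiler, devices, tt_lib, inference_iterations, and batch are the test's own names, and the run_inference body is elided):

def run_inference():
    # Re-upload the host-side embeddings and attention mask to each device,
    # then run `inference_iterations` forward passes in prefill or decode mode.
    ...

profiler.start("model_warmup_run_for_inference")
run_inference()  # warmup pass after the compile run above
profiler.end("model_warmup_run_for_inference")
for device in devices:
    tt_lib.device.Synchronize(device)  # drain all device work before timing

profiler.start("model_run_for_inference")
run_inference()  # measured pass
profiler.end("model_run_for_inference")
for device in devices:
    tt_lib.device.Synchronize(device)

second_iter_time = profiler.get("model_run_for_inference") / inference_iterations
tokens_per_s_per_user = 1 / second_iter_time  # decode: one token per user per iteration
tokens_per_s_overall = tokens_per_s_per_user * batch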
5 changes: 4 additions & 1 deletion models/demos/falcon40b/tt/falcon_attention.py
@@ -493,7 +493,10 @@ def __call__(
)

attn_output = tt_lib.tensor.all_gather(
attn_output, dim=3, output_mem_config=self.model_config["DEFAULT_MEMCFG"]
attn_output,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
)

for i in range(len(attn_output)):
5 changes: 4 additions & 1 deletion models/demos/falcon40b/tt/falcon_decoder.py
@@ -174,7 +174,10 @@ def __call__(
)
)
replicated_hidden_states = tt_lib.tensor.all_gather(
replicated_hidden_states, dim=3, output_mem_config=self.model_config["DEFAULT_MEMCFG"]
replicated_hidden_states,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
dim=3,
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
)
for i in range(len(replicated_hidden_states)):
replicated_hidden_states[i] = tt_lib.tensor.interleaved_to_sharded(
5 changes: 4 additions & 1 deletion models/demos/falcon40b/tt/falcon_mlp.py
@@ -112,7 +112,10 @@ def __call__(self, x: List[tt_lib.tensor.Tensor]) -> List[tt_lib.tensor.Tensor]:
hidden_states[i], output_mem_config=self.model_config["DEFAULT_MEMCFG"]
)
hidden_states = tt_lib.tensor.all_gather(
hidden_states, dim=3, output_mem_config=self.model_config["DEFAULT_MEMCFG"]
hidden_states,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
)
for i in range(len(hidden_states)):
hidden_states[i] = tt_lib.tensor.interleaved_to_sharded(
5 changes: 4 additions & 1 deletion models/demos/falcon40b/tt/falcon_model.py
@@ -229,7 +229,10 @@ def __call__(
layer_output[i], output_mem_config=self.model_config["DEFAULT_MEMCFG"]
)
layer_output = tt_lib.tensor.all_gather(
layer_output, dim=3, output_mem_config=self.model_config["DEFAULT_MEMCFG"]
layer_output,
dim=3,
num_links=self.model_config["ALL_GATHER_NUM_LINKS"],
output_mem_config=self.model_config["DEFAULT_MEMCFG"],
)
for i in range(len(layer_output)):
layer_output[i] = tt_lib.tensor.interleaved_to_sharded(
1 change: 1 addition & 0 deletions models/demos/falcon40b/tt/model_config.py
@@ -121,6 +121,7 @@ def get_model_config(model_config_str):
"MOVE_DECODER_OUTPUT_BOOL": False,
"NUM_DEVICES": 4,
"MAX_GRID_SIZE": (8, 4),
"ALL_GATHER_NUM_LINKS": 2,
"DEFAULT_CACHE_PATH": Path(f"models/demos/falcon40b/datasets/"),
}
model_config.update({f"{key}_MEMCFG": mem_config for key in OP_KEYS if key not in NO_MEMCFG})
13 changes: 10 additions & 3 deletions tests/tt_eager/python_api_testing/unit_testing/test_all_gather.py
@@ -38,14 +38,21 @@
ttl.tensor.MemoryConfig(buffer_type=ttl.tensor.BufferType.L1),
],
)
def test_all_gather(pcie_devices, input_shape, dim, layout, mem_config, use_program_cache, function_level_defaults):
if layout == ttl.tensor.Layout.ROW_MAJOR and mem_config.buffer_type == ttl.tensor.BufferType.DRAM:
@pytest.mark.parametrize("num_links", [1, 2])
def test_all_gather(
pcie_devices, input_shape, dim, num_links, layout, mem_config, use_program_cache, function_level_defaults
):
if (
layout == ttl.tensor.Layout.ROW_MAJOR or num_links == 2
) and mem_config.buffer_type == ttl.tensor.BufferType.DRAM:
pytest.skip("All gather tests are hanging for RM in DRAM")
devices = pcie_devices
input_tensor = torch.rand(input_shape).bfloat16()
num_devices = len(devices)
if num_devices < 2:
pytest.skip("Requires multiple devices to run")
elif num_devices == 2 and num_links == 2:
pytest.skip("Not enough links to run")

if input_shape[dim] % num_devices != 0 or (dim == 3 and input_shape[dim] // num_devices % 32 != 0):
pytest.skip("Unsupported test case")
@@ -57,7 +64,7 @@ def test_all_gather(pcie_devices, input_shape, dim, layout, mem_config, use_prog
ttl.tensor.Tensor(t, ttl.tensor.DataType.BFLOAT16).to(layout).to(devices[i], mem_config)
)

tt_out_tensors = ttl.tensor.all_gather(tt_input_tensors, dim, mem_config)
tt_out_tensors = ttl.tensor.all_gather(tt_input_tensors, dim, num_links, output_mem_config=mem_config)

for i, t in enumerate(tt_out_tensors):
tt_output_tensor = t.cpu().to(ttl.tensor.Layout.ROW_MAJOR).to_torch()
(Diffs for the remaining 8 of the 15 changed files are not rendered in this view.)