
Commit

#8049: Move kvcache initialization to dedicated function
johanna-rock-tt committed May 16, 2024
1 parent 3f17525 commit d04684f
Showing 6 changed files with 74 additions and 96 deletions.
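The net effect of the commit is a lazy-initialization pattern: each attention layer allocates (or loads from disk) its empty K/V cache only when `initialize_kvcache()` is first called, and the model-level `initialize_kv_cache()` concatenates the per-layer results, so the demo no longer needs its own `initialize_kv_cache` helper. A minimal sketch of that pattern, using plain PyTorch tensors and hypothetical class names (`SimplifiedAttention`, `SimplifiedModel`) in place of the actual tt_lib device tensors:

```python
import torch


class SimplifiedAttention:
    """Toy stand-in for FalconAttention: allocates its KV cache lazily."""

    def __init__(self, batch_size, num_kv_heads, max_seq_len, head_dim):
        self.cache_shape = (batch_size, num_kv_heads, max_seq_len, head_dim)
        self.layer_past = None  # nothing allocated until first requested

    def initialize_kvcache(self):
        if self.layer_past is None:
            # Allocate empty K and V caches on the first call only;
            # later calls reuse the same tensors.
            k_cache = torch.zeros(self.cache_shape)
            v_cache = torch.zeros(self.cache_shape)
            self.layer_past = ((k_cache, v_cache),)
        return self.layer_past


class SimplifiedModel:
    """Toy stand-in for FalconModel: aggregates per-layer caches."""

    def __init__(self, num_layers, **cache_dims):
        self.layers = [SimplifiedAttention(**cache_dims) for _ in range(num_layers)]

    def initialize_kv_cache(self):
        layer_past = ()
        for layer in self.layers:
            layer_past += layer.initialize_kvcache()
        return layer_past


if __name__ == "__main__":
    model = SimplifiedModel(num_layers=2, batch_size=32, num_kv_heads=2, max_seq_len=128, head_dim=64)
    kv_cache = model.initialize_kv_cache()  # allocated here, not in __init__
    assert len(kv_cache) == 2
```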
44 changes: 6 additions & 38 deletions models/demos/t3000/falcon40b/demo/demo.py
@@ -85,42 +85,6 @@ def preprocess_and_validate_inputs(input_prompts, tokenizer, max_seq_len, perf_m
return prefill_ids, num_users, num_input_tokens


def initialize_kv_cache(model_config, configuration, num_layers, batch_size, max_seq_len, devices):
head_dim = configuration.hidden_size // configuration.num_attention_heads
num_kv_heads = configuration.num_kv_heads

tt_kv_cache = ()
tt_k_cache_host = torch.zeros(batch_size, num_kv_heads, max_seq_len, head_dim)
tt_v_cache_host = torch.zeros(batch_size, num_kv_heads, max_seq_len, head_dim)
tt_k_cache_host = torch.chunk(tt_k_cache_host, len(devices), 1)
tt_v_cache_host = torch.chunk(tt_v_cache_host, len(devices), 1)

for _ in tqdm(range(num_layers), desc="Initializing kv cache on devices for each layer"):
tt_k_cache = []
tt_v_cache = []
for j in range(len(devices)):
tt_k_cache.append(
torch2tt_tensor(
tt_k_cache_host[j],
devices[j],
tt_lib.tensor.Layout.TILE,
model_config["KV_CACHE_MEMCFG"],
model_config["KV_CACHE_DTYPE"],
)
)
tt_v_cache.append(
torch2tt_tensor(
tt_v_cache_host[j],
devices[j],
tt_lib.tensor.Layout.TILE,
model_config["KV_CACHE_MEMCFG"],
model_config["KV_CACHE_DTYPE"],
)
)
tt_kv_cache += ((tt_k_cache, tt_v_cache),)
return tt_kv_cache


# TODO: Remove once we have prefill on device
def initialize_and_fill_kv_cache(
pytorch_FalconCausalLM, model_config, configuration, prefill_ids, num_layers, batch_size, max_seq_len, devices
@@ -286,7 +250,7 @@ def run_falcon_demo_kv(

synchronize_devices(devices)

kv_cache_singlelayer = tt_FalconCausalLM_singlelayer.get_kv_cache() # only used for compile
kv_cache_singlelayer = tt_FalconCausalLM_singlelayer.initialize_kv_cache() # only used for compile

enable_persistent_kernel_cache()

@@ -388,7 +352,9 @@ def run_falcon_demo_kv(
logger.info("Moved weights (all layers) to device!")
profiler.end(f"moving_to_device")

kv_cache = tt_FalconCausalLM.get_kv_cache() # Initialized kv cache for all layers
profiler.start(f"initializing_KV_cache")
kv_cache = tt_FalconCausalLM.initialize_kv_cache() # Initialized kv cache for all layers
profiler.end(f"initializing_KV_cache")

### Second prefill run without compile ###
enable_persistent_kernel_cache()
@@ -549,6 +515,7 @@ def run_falcon_demo_kv(
"preprocessing": profiler.get("tokenizing_inputs"),
"loading_weights": profiler.get("loading_weights"),
"moving_to_device": profiler.get("moving_to_device"),
"initializing_KV_cache": profiler.get("initializing_KV_cache"),
"compile_prefill": time_prefill_compile if not prefill_on_host else None,
"compile_decode": time_decode_compile,
"compile_total": time_prefill_compile + time_decode_compile,
@@ -566,6 +533,7 @@ def run_falcon_demo_kv(
logger.info(
f"conversion to TT (if downloaded) and moving weights to device: {round(measurements['moving_to_device'], 5)} s"
)
logger.info(f"initializing KV cache: {round(measurements['initializing_KV_cache'], 5)} s")
if not prefill_on_host:
logger.info(f"prefill compile time: {round(measurements['compile_prefill'],5)} s")
logger.info(f"decode compile time: {round(measurements['compile_decode'], 5)} s")
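The demo now times cache initialization the same way it already times weight loading: a `profiler.start`/`profiler.end` pair around the call and a `profiler.get` lookup when building the `measurements` dictionary. A minimal sketch of a timer with that start/end/get interface (a simplified stand-in, not the project's actual profiler object):

```python
import time


class MinimalProfiler:
    """Tracks wall-clock durations keyed by name via start()/end()/get()."""

    def __init__(self):
        self._start_times = {}
        self._durations = {}

    def start(self, key: str) -> None:
        self._start_times[key] = time.perf_counter()

    def end(self, key: str) -> None:
        self._durations[key] = time.perf_counter() - self._start_times.pop(key)

    def get(self, key: str) -> float:
        return self._durations[key]


profiler = MinimalProfiler()
profiler.start("initializing_KV_cache")
time.sleep(0.01)  # stands in for tt_FalconCausalLM.initialize_kv_cache()
profiler.end("initializing_KV_cache")
print(f"initializing KV cache: {round(profiler.get('initializing_KV_cache'), 5)} s")
```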
26 changes: 14 additions & 12 deletions models/demos/t3000/falcon40b/tests/test_falcon_end_to_end.py
@@ -87,25 +87,27 @@ def run_test_FalconCausalLM_end_to_end(
assert q_len % 32 == 0, "For prefill, seq_len must be multiple of 32!"
assert kv_cache_len == 0, "For prefill, no kv_cache is passed in!"

past_key_values = None

elif llm_mode == "decode":
q_len, kv_len = seq_len, kv_cache_len + 1
assert batch % 32 == 0, "For decode, batch must be multiple of 32!"
assert q_len == 1, "For decode, q_len must be 1!"

past_key_values = ()
for i in range(num_layers):
k_cache = torch.zeros(batch, num_kv_heads, kv_cache_len, head_dim)
v_cache = torch.zeros(batch, num_kv_heads, kv_cache_len, head_dim)
past_key_values += (
(
torch.repeat_interleave(k_cache, num_attention_heads // num_kv_heads, 1),
(torch.repeat_interleave(v_cache, num_attention_heads // num_kv_heads, 1)),
),
)

else:
raise NotImplementedError(f"Llm mode {llm_mode} is not supported! Must be one of prefill or decode.")

past_key_values = ()
for i in range(num_layers):
k_cache = torch.zeros(batch, num_kv_heads, kv_cache_len, head_dim)
v_cache = torch.zeros(batch, num_kv_heads, kv_cache_len, head_dim)
past_key_values += (
(
torch.repeat_interleave(k_cache, num_attention_heads // num_kv_heads, 1),
(torch.repeat_interleave(v_cache, num_attention_heads // num_kv_heads, 1)),
),
)

# Prepare output -----------------------------------------------------------------------
logger.info("Running HF reference model")
profiler.start("hugging_face_reference_model")
Expand Down Expand Up @@ -137,7 +139,7 @@ def run_test_FalconCausalLM_end_to_end(
logger.info("Done loading TT Falcon Model")

# Initialize past layer values
tt_layer_past = tt_FalconCausalLM.get_kv_cache()
tt_layer_past = tt_FalconCausalLM.initialize_kv_cache()

profiler.start("processing_of_input")
if llm_mode == "prefill":
2 changes: 1 addition & 1 deletion models/demos/t3000/falcon40b/tests/test_perf_e2e_falcon.py
@@ -115,7 +115,7 @@ def run_test_FalconCausalLM_end_to_end(
del state_dict

# Initialize past layer values
tt_layer_past = tt_FalconCausalLM.get_kv_cache()
tt_layer_past = tt_FalconCausalLM.initialize_kv_cache()

if llm_mode == "prefill":
model_inputs = torch.split(model_input, 1)
2 changes: 1 addition & 1 deletion models/demos/t3000/falcon40b/tests/test_perf_falcon.py
@@ -121,7 +121,7 @@ def run_test_FalconCausalLM_end_to_end(
del state_dict

# Initialize past layer values
tt_layer_past = tt_FalconCausalLM.get_kv_cache()
tt_layer_past = tt_FalconCausalLM.initialize_kv_cache()

profiler.start("processing_of_input")
if llm_mode == "prefill":
92 changes: 50 additions & 42 deletions models/demos/t3000/falcon40b/tt/falcon_attention.py
@@ -163,6 +163,7 @@ def __init__(
self.state_dict = state_dict
self.model_config = model_config
self.num_heads_per_device = self.num_heads // len(devices)
self.tt_cache_path = tt_cache_path
self.max_batch_size = 32

if (self.head_dim * self.num_heads) != self.hidden_size:
@@ -250,50 +251,57 @@ def __init__(
# self.scalar = pad_by_zero(torch.Tensor([1 / math.sqrt(self.head_dim)]), self.device)[0]
self.scalar = 1 / math.sqrt(self.head_dim)

# Preloading the kvcache
attn_cache_shape = (
self.max_batch_size,
self.num_kv_heads // len(devices),
self.max_position_embeddings,
# 2048 + 128, # Meets benchmarking spec needs
self.head_dim,
)
kvcache_path = tt_cache_path / f"empty_attn_cache{attn_cache_shape}.bin"
k_cache = []
v_cache = []
if (kvcache_path).exists():
for i in range(len(self.devices)):
k_cache.append(
tt_lib.tensor.load_tensor(str(kvcache_path)).to(devices[i], self.model_config["DRAM_MEMCFG"])
)
v_cache.append(
tt_lib.tensor.load_tensor(str(kvcache_path)).to(devices[i], self.model_config["DRAM_MEMCFG"])
self.preprocessing(self.model_config["LLM_MODE"], self.model_config["BATCH_SIZE"], self.model_config["SEQ_LEN"])
self.layer_past = None

def initialize_kvcache(self):
if self.layer_past is None:
# Preloading the kvcache
attn_cache_shape = (
self.max_batch_size,
self.num_kv_heads // len(self.devices),
self.max_position_embeddings,
self.head_dim,
)
kvcache_path = self.tt_cache_path / f"empty_attn_cache{attn_cache_shape}.bin"
k_cache = []
v_cache = []
if (kvcache_path).exists():
for i in range(len(self.devices)):
k_cache.append(
tt_lib.tensor.load_tensor(str(kvcache_path)).to(
self.devices[i], self.model_config["DRAM_MEMCFG"]
)
)
v_cache.append(
tt_lib.tensor.load_tensor(str(kvcache_path)).to(
self.devices[i], self.model_config["DRAM_MEMCFG"]
)
)
else:
attn_cache = torch.zeros(attn_cache_shape)
tt_attn_cache = torch2tt_tensor(
attn_cache,
None,
tt_memory_config=self.model_config["DRAM_MEMCFG"],
tt_dtype=ttnn.bfloat8_b,
)
else:
attn_cache = torch.zeros(attn_cache_shape)
tt_attn_cache = torch2tt_tensor(
attn_cache,
None,
tt_memory_config=self.model_config["DRAM_MEMCFG"],
tt_dtype=ttnn.bfloat8_b,
)
for i in range(len(self.devices)):
k_cache.append(tt_attn_cache.to(devices[i], self.model_config["DRAM_MEMCFG"]))
for i in range(len(self.devices)):
v_cache.append(tt_attn_cache.to(devices[i], self.model_config["DRAM_MEMCFG"]))

tt_lib.tensor.dump_tensor(
str(kvcache_path),
tt_attn_cache,
)
self.layer_past = (
(
k_cache,
v_cache,
),
)
for i in range(len(self.devices)):
k_cache.append(tt_attn_cache.to(self.devices[i], self.model_config["DRAM_MEMCFG"]))
for i in range(len(self.devices)):
v_cache.append(tt_attn_cache.to(self.devices[i], self.model_config["DRAM_MEMCFG"]))

self.preprocessing(self.model_config["LLM_MODE"], self.model_config["BATCH_SIZE"], self.model_config["SEQ_LEN"])
tt_lib.tensor.dump_tensor(
str(kvcache_path),
tt_attn_cache,
)
self.layer_past = (
(
k_cache,
v_cache,
),
)
return self.layer_past

def set_model_config(self, model_config):
self.model_config = model_config
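The attention-level initializer keeps the existing on-disk reuse of the empty cache: if `empty_attn_cache{shape}.bin` exists under `tt_cache_path`, it is loaded and moved to each device; otherwise a zero tensor is created, replicated, and dumped for the next run. A simplified sketch of that load-or-create-and-dump logic, with `torch.save`/`torch.load` and a hypothetical `load_or_create_empty_cache` helper standing in for the tt_lib tensor serialization and device transfers:

```python
from pathlib import Path

import torch


def load_or_create_empty_cache(cache_dir: Path, cache_shape: tuple) -> torch.Tensor:
    """Return an empty attention cache, reusing an on-disk copy when present."""
    cache_path = cache_dir / f"empty_attn_cache{cache_shape}.bin"
    if cache_path.exists():
        # Reuse the tensor dumped by an earlier run.
        return torch.load(cache_path)
    cache = torch.zeros(cache_shape)
    cache_dir.mkdir(parents=True, exist_ok=True)
    # Persist so later runs skip allocation and conversion.
    torch.save(cache, cache_path)
    return cache


if __name__ == "__main__":
    cache = load_or_create_empty_cache(Path("kv_cache_store"), (32, 2, 2048, 64))
    print(cache.shape)
```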
4 changes: 2 additions & 2 deletions models/demos/t3000/falcon40b/tt/falcon_model.py
@@ -133,10 +133,10 @@ def __init__(

self.layernorm_eps = config.layer_norm_epsilon

def get_kv_cache(self):
def initialize_kv_cache(self):
layer_past = ()
for layer_num in range(self.num_layers):
layer_past += self.layers[layer_num].self_attn.layer_past
layer_past += self.layers[layer_num].self_attn.initialize_kvcache()
return layer_past

def set_model_config(self, model_config):
