updated llama model for python interface
yingchen21 committed Aug 8, 2024
1 parent: 4acab6c · commit: 4207293
Showing 1 changed file with 20 additions and 4 deletions.
python/flexflow/serve/models/llama.py: 24 changes (20 additions, 4 deletions)
@@ -128,9 +128,17 @@ def build_model(self, max_tokens_per_batch):
                 name=f"layers.{i}.input_layernorm",
             )
 
+            qkv_proj = ffmodel.dense(
+                attn_norm,
+                3 * self.llama_config.hidden_size,
+                ActiMode.AC_MODE_NONE,
+                False,
+                name=f"layers.{i}.self_attn.qkv_proj",
+            )
+
             if self.mode == InferenceMode.BEAM_SEARCH_MODE:
                 mha = ffmodel.spec_inc_multiquery_self_attention(
-                    attn_norm,
+                    qkv_proj,
                     self.llama_config.hidden_size,
                     self.llama_config.num_attention_heads,
                     self.llama_config.num_key_value_heads,
@@ -149,7 +157,7 @@ def build_model(self, max_tokens_per_batch):
                 )
             elif self.mode == InferenceMode.TREE_VERIFY_MODE:
                 mha = ffmodel.inc_multiquery_self_attention_verify(
-                    attn_norm,
+                    qkv_proj,
                     self.llama_config.hidden_size,
                     self.llama_config.num_attention_heads,
                     self.llama_config.num_key_value_heads,
@@ -168,7 +176,7 @@ def build_model(self, max_tokens_per_batch):
                )
             elif self.mode == InferenceMode.INC_DECODING_MODE:
                 mha = ffmodel.inc_multiquery_self_attention(
-                    attn_norm,
+                    qkv_proj,
                     self.llama_config.hidden_size,
                     self.llama_config.num_attention_heads,
                     self.llama_config.num_key_value_heads,
@@ -188,9 +196,17 @@ def build_model(self, max_tokens_per_batch):
             else:
                 assert False
 
+            o_proj = ffmodel.dense(
+                mha,
+                self.llama_config.hidden_size,
+                ActiMode.AC_MODE_NONE,
+                False,
+                name=f"layers.{i}.self_attn.o_proj"
+            )
+
             token, ff_norm = ffmodel.residual_rms_norm(
                 token,
-                mha,
+                o_proj,
                 self.llama_config.rms_norm_eps,
                 self.llama_config.hidden_size,
                 name=f"layers.{i}.post_attention_layernorm",
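
After this change, each decoder layer computes the fused QKV projection and the attention output projection as explicit dense layers around the mode-specific multi-query attention operator, and the residual RMS norm consumes the output projection instead of the raw attention output. The sketch below is not the committed code verbatim; it illustrates the resulting wiring using only the calls visible in the diff. The names attention_block and attention_op are hypothetical: attention_op stands in for the mode-specific attention call (spec_inc_multiquery_self_attention, inc_multiquery_self_attention_verify, or inc_multiquery_self_attention), whose remaining arguments the hunks above elide.

# Minimal sketch, assuming the FlexFlow Python API as used in the diff.
from flexflow.core import *  # ActiMode etc., as in the existing serve model files


def attention_block(ffmodel, token, attn_norm, llama_config, i, attention_op):
    # Fused query/key/value projection: a dense layer of width 3 * hidden_size,
    # no activation, no bias, named to match the checkpoint layout.
    qkv_proj = ffmodel.dense(
        attn_norm,
        3 * llama_config.hidden_size,
        ActiMode.AC_MODE_NONE,
        False,
        name=f"layers.{i}.self_attn.qkv_proj",
    )

    # Mode-specific multi-query attention now reads the fused projection
    # rather than the layer-norm output (hypothetical stand-in callable).
    mha = attention_op(qkv_proj)

    # Explicit output projection back to hidden_size.
    o_proj = ffmodel.dense(
        mha,
        llama_config.hidden_size,
        ActiMode.AC_MODE_NONE,
        False,
        name=f"layers.{i}.self_attn.o_proj",
    )

    # Residual add + RMS norm now takes the output projection instead of the
    # raw attention output.
    token, ff_norm = ffmodel.residual_rms_norm(
        token,
        o_proj,
        llama_config.rms_norm_eps,
        llama_config.hidden_size,
        name=f"layers.{i}.post_attention_layernorm",
    )
    return token, ff_norm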
