Skip to content

Commit

Permalink
Rebased
Browse files Browse the repository at this point in the history
Signed-off-by: Ann <[email protected]>
  • Loading branch information
quic-akuruvil committed Nov 13, 2024
2 parents 20346cc + fff4f94 commit 61ca907
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 68 deletions.
33 changes: 1 addition & 32 deletions QEfficient/transformers/pytorch_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,38 +267,7 @@ class KVCacheTransform(ModuleMappingTransform):
CohereModel: QEffCohereModel,
CohereRotaryEmbedding: QEffCohereRotaryEmbedding,
CohereDecoderLayer: QEffCohereDecoderLayer,
# Gemma
GemmaAttention: QEffGemmaAttention,
GemmaDecoderLayer: QEffGemmaDecoderLayer,
GemmaModel: QEffGemmaModel,
GemmaForCausalLM: QEffGemmaForCausalLM,
# Gemma2
Gemma2Attention: QEffGemma2Attention,
Gemma2DecoderLayer: QEffGemma2DecoderLayer,
Gemma2Model: QEffGemma2Model,
Gemma2ForCausalLM: QEffGemma2ForCausalLM,
# Cohere
CohereForCausalLM: QEffCohereForCausalLM,
CohereAttention: QEffCohereAttention,
CohereModel: QEffCohereModel,
CohereRotaryEmbedding: QEffCohereRotaryEmbedding,
CohereDecoderLayer: QEffCohereDecoderLayer,
# Gemma
GemmaAttention: QEffGemmaAttention,
GemmaDecoderLayer: QEffGemmaDecoderLayer,
GemmaModel: QEffGemmaModel,
GemmaForCausalLM: QEffGemmaForCausalLM,
# Gemma2
Gemma2Attention: QEffGemma2Attention,
Gemma2DecoderLayer: QEffGemma2DecoderLayer,
Gemma2Model: QEffGemma2Model,
Gemma2ForCausalLM: QEffGemma2ForCausalLM,
# Cohere
CohereForCausalLM: QEffCohereForCausalLM,
CohereAttention: QEffCohereAttention,
CohereModel: QEffCohereModel,
CohereRotaryEmbedding: QEffCohereRotaryEmbedding,
CohereDecoderLayer: QEffCohereDecoderLayer,

# Mistral
MistralAttention: QEffMistralAttention,
MistralDecoderLayer: QEffMistralDecoderLayer,
Expand Down
16 changes: 16 additions & 0 deletions QEfficient/utils/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,22 @@ def get_embeddings(
return model_kv.model.get_input_embeddings(), model_kv.model.config


def get_embeddings(
    model_name: str,
    hf_token: Optional[str] = None,
    cache_dir: Optional[str] = None,
    local_model_dir: Optional[str] = None,
):
    """Load a model through ``QEFFCommonLoader`` and return its input-embedding
    module together with the model's config.

    Args:
        model_name: Hugging Face model card name used when no local copy is given.
        hf_token: Optional auth token forwarded to ``from_pretrained``.
        cache_dir: Optional download/cache directory forwarded to ``from_pretrained``.
        local_model_dir: Optional path to a local model directory; when set it
            takes precedence over ``model_name``.

    Returns:
        Tuple of ``(input_embeddings_module, model_config)``.
    """
    # NOTE(review): an identical `get_embeddings` appears to be defined earlier
    # in this module; this later definition would shadow it (likely a
    # rebase/merge duplicate) — confirm and deduplicate.
    # Imported lazily, presumably to avoid a circular import — TODO confirm.
    from QEfficient.base.common import QEFFCommonLoader

    # A local directory, when provided, wins over the model card name.
    pretrained_source = local_model_dir or model_name

    loader = QEFFCommonLoader.from_pretrained(
        pretrained_model_name_or_path=pretrained_source,
        token=hf_token,
        cache_dir=cache_dir,
    )
    underlying_model = loader.model
    return underlying_model.get_input_embeddings(), underlying_model.config


def get_qpc_dir_path(
model_card_name, num_cores, mos, batch_size, prompt_len, ctx_len, mxfp6, mxint8, device_group, full_batch_size
):
Expand Down
1 change: 1 addition & 0 deletions QEfficient/utils/generate_inputs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ def prepare_pytorch_inputs(self):
padding=True,
)
input_ids = inputs["input_ids"]

batch_size, input_len = input_ids.shape
inputs.pop("attention_mask")
inputs.pop("token_type_ids", None)
Expand Down
37 changes: 1 addition & 36 deletions tests/transformers/models/test_causal_lm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,42 +40,7 @@
"ibm-granite/granite-20b-code-base",

"CohereForAI/c4ai-command-r-v01",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"gpt2",
"Salesforce/codegen-350M-mono",
"microsoft/phi-2",
"microsoft/Phi-3-mini-4k-instruct",
"tiiuae/falcon-7b",
"Qwen/Qwen2-0.5B",
"bigcode/starcoder2-3b",
"Felladrin/Minueza-32M-Base",
"wtang06/mpt-125m-c4",
"hakurei/gpt-j-random-tinier",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"unsloth/gemma-2b",
"unsloth/gemma-2-2b",
"TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", # AWQ model
"TheBloke/Llama-2-7B-GPTQ", # GPTQ model
"ibm-granite/granite-20b-code-base",
"CohereForAI/c4ai-command-r-v01",
"TinyLlama/TinyLlama-1.1B-Chat-v1.0",
"gpt2",
"Salesforce/codegen-350M-mono",
"microsoft/phi-2",
"microsoft/Phi-3-mini-4k-instruct",
"tiiuae/falcon-7b",
"Qwen/Qwen2-0.5B",
"bigcode/starcoder2-3b",
"Felladrin/Minueza-32M-Base",
"wtang06/mpt-125m-c4",
"hakurei/gpt-j-random-tinier",
"mistralai/Mixtral-8x7B-Instruct-v0.1",
"unsloth/gemma-2b",
"unsloth/gemma-2-2b",
"TheBloke/TinyLlama-1.1B-Chat-v0.3-AWQ", # AWQ model
"TheBloke/Llama-2-7B-GPTQ", # GPTQ model
"ibm-granite/granite-20b-code-base",
"CohereForAI/c4ai-command-r-v01",

]


Expand Down

0 comments on commit 61ca907

Please sign in to comment.