From e2aa89e02b82eb0a93e360a99a5b2d7a17e1b7e4 Mon Sep 17 00:00:00 2001 From: kunal-vaishnavi <115581922+kunal-vaishnavi@users.noreply.github.com> Date: Thu, 9 May 2024 15:02:14 -0700 Subject: [PATCH] Add GroupQueryAttention on CPU in model builder (#420) ### Description This PR adds `GroupQueryAttention` to ONNX models generated for CPU. ### Motivation and Context This PR is a follow up to [this PR](https://github.com/microsoft/onnxruntime-genai/pull/270). --- src/python/py/models/builder.py | 43 ++++++++++++++++++--------------- 1 file changed, 24 insertions(+), 19 deletions(-) diff --git a/src/python/py/models/builder.py b/src/python/py/models/builder.py index 940f76e55..46049f1d2 100644 --- a/src/python/py/models/builder.py +++ b/src/python/py/models/builder.py @@ -167,17 +167,21 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): "use_rotemb_in_attn": False, # Use rotary embeddings within attention op (instead of a separate RotaryEmbedding op) "use_packed_matmul": False, # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V) } - enable_GQA_on_CPU = True if "enable_GQA_on_CPU" in extra_options and extra_options["enable_GQA_on_CPU"] == "1" else False - if (self.ep in {"cuda", "dml"} and self.io_dtype == TensorProto.FLOAT16) or (enable_GQA_on_CPU and self.ep == "cpu" and self.io_dtype == TensorProto.FLOAT): + valid_gqa_configurations = [ + ("cpu", TensorProto.FLOAT), + ("cuda", TensorProto.FLOAT16), + ("dml", TensorProto.FLOAT16), + ] + if (self.ep, self.io_dtype) in valid_gqa_configurations: # Change model settings for GroupQueryAttention self.attention_attrs["op_type"] = "GroupQueryAttention" print("GroupQueryAttention (GQA) is used in this model.") # DML doesn't support packed Q/K/V for GQA yet - self.attention_attrs["use_packed_matmul"] = self.ep != "dml" and self.num_attn_heads == self.num_kv_heads + self.attention_attrs["use_packed_matmul"] = self.ep != "dml" # GQA + Rot.Emb. does not require `position ids` as input - if self.ep in {"cuda", "cpu"}: + if self.ep != "dml": self.attention_attrs["use_rotemb_in_attn"] = True self.input_names.remove("position_ids") @@ -199,7 +203,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options): } def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): - config = GenerationConfig.from_pretrained(model_name_or_path, **extra_kwargs) + config = GenerationConfig.from_pretrained(model_name_or_path, use_auth_token=True, trust_remote_code=True, **extra_kwargs) inputs = dict(zip(self.input_names, self.input_names)) inputs.update({ "past_key_names": "past_key_values.%d.key", @@ -212,7 +216,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): "decoder": { "session_options" : { "log_id": "onnxruntime-genai", - "provider_options" : [] + "provider_options" : [], }, "filename": self.filename, "head_size": self.head_size, @@ -259,7 +263,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir): json.dump(genai_config, f, indent=4) def save_processing(self, model_name_or_path, extra_kwargs, out_dir): - tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **extra_kwargs) + tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_auth_token=True, trust_remote_code=True, **extra_kwargs) print(f"Saving processing files in {out_dir} for GenAI") tokenizer.save_pretrained(out_dir) @@ -563,11 +567,13 @@ def make_matmul_fp16_or_fp32(self, matmul, name, root_input, **kwargs): # self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size]) def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, name, root_input, **kwargs): - # N = num_heads * head_size, H = hidden_size - # Combine 3 Matmuls of shape NxH into 1 packed MatMul of shape 3NxH - # Note: Packed MatMul is of shape 3NxH instead of Hx3N because `make_matmul` will apply a transpose before saving - N, H = q_matmul.shape - matmul = np.stack((q_matmul.transpose(), k_matmul.transpose(), v_matmul.transpose()), axis=1).reshape(H, 3*N).transpose() + # N_q = num_attention_heads * head_size, N_kv = num_key_value_heads * head_size, H = hidden_size + # Combine 3 MatMuls of shape N_q x H, N_kv x H, N_kv x H into 1 packed MatMul of shape (N_q+N_kv+N_kv)xH + # Note: Packed MatMul is of shape (N_q+N_kv+N_kv)xH instead of Hx(N_q+N_kv+N_kv) because `make_matmul` will + # apply a transpose before saving + N_q, H = q_matmul.shape + N_kv, _ = k_matmul.shape + matmul = np.concatenate([q_matmul, k_matmul, v_matmul], axis=0).reshape(N_q + N_kv + N_kv, H) self.make_matmul(matmul, name, root_input, **kwargs) def make_add_bias(self, add, name, root_input, **kwargs): @@ -1001,7 +1007,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs): self.make_rotary_embedding(attention.rotary_emb, k_rotary_name, root_input=k_input_to_attention, position_ids=kwargs.get("position_ids", "position_ids")) k_input_to_attention = f"{k_rotary_name}/output_0" - # Make repeat KV nodes (TODO: remove once ORT supports GQA for CPU) + # Make repeat KV nodes (Note: `repeat_kv` needs to be kept since GroupQueryAttention isn't supported for FP32 CUDA) past_k = f"past_key_values.{layer_id}.key" past_v = f"past_key_values.{layer_id}.value" present_k = f"present.{layer_id}.key" @@ -1191,8 +1197,8 @@ def make_model(self, input_path): self.layernorm_attrs["add_offset"] = 0 # add offset already done for GGUF models else: # Load PyTorch model - extra_kwargs = {"trust_remote_code": True} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers, "trust_remote_code": True} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir, "use_auth_token": True, "trust_remote_code": True} - model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **extra_kwargs) + extra_kwargs = {} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir} + model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, use_auth_token=True, trust_remote_code=True, **extra_kwargs) # Loop through model and map each module to ONNX/ORT ops self.layer_id = 0 @@ -1858,9 +1864,9 @@ def create_model(model_name, input_path, output_dir, precision, execution_provid os.makedirs(cache_dir, exist_ok=True) # Load model config - extra_kwargs = {"trust_remote_code": True} if os.path.isdir(input_path) else {"cache_dir": cache_dir, "use_auth_token": True, "trust_remote_code": True} + extra_kwargs = {} if os.path.isdir(input_path) else {"cache_dir": cache_dir} hf_name = input_path if os.path.isdir(input_path) else model_name - config = AutoConfig.from_pretrained(hf_name, **extra_kwargs) + config = AutoConfig.from_pretrained(hf_name, use_auth_token=True, trust_remote_code=True, **extra_kwargs) # Set input/output precision of ONNX model io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16 @@ -1980,12 +1986,11 @@ def get_args(): enable_cuda_graph = 1 : The model can use CUDA graph capture for CUDA execution provider. If enabled, all nodes being placed on the CUDA EP is the prerequisite for the CUDA graph to be used correctly. It is not guaranteed that cuda graph be enabled as it depends on the model and the graph structure. - enable_GQA_on_CPU = Enalbe G(Group)Query(Q)Attention(A) on CPU. """), ) args = parser.parse_args() - print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, INT4 CPU, INT4 CUDA") + print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, INT4 CPU, INT4 CUDA, INT4 DML") return args if __name__ == '__main__':