Add GroupQueryAttention on CPU in model builder (#420)
### Description

This PR adds `GroupQueryAttention` to ONNX models generated for CPU.

### Motivation and Context

This PR is a follow-up to [this PR](#270).
kunal-vaishnavi authored May 9, 2024
1 parent 1c3a4be commit e2aa89e
Showing 1 changed file with 24 additions and 19 deletions.
43 changes: 24 additions & 19 deletions src/python/py/models/builder.py
@@ -167,17 +167,21 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
"use_rotemb_in_attn": False, # Use rotary embeddings within attention op (instead of a separate RotaryEmbedding op)
"use_packed_matmul": False, # Use packed MatMul (instead of 3 separate MatMuls for Q/K/V)
}
- enable_GQA_on_CPU = True if "enable_GQA_on_CPU" in extra_options and extra_options["enable_GQA_on_CPU"] == "1" else False
- if (self.ep in {"cuda", "dml"} and self.io_dtype == TensorProto.FLOAT16) or (enable_GQA_on_CPU and self.ep == "cpu" and self.io_dtype == TensorProto.FLOAT):
+ valid_gqa_configurations = [
+     ("cpu", TensorProto.FLOAT),
+     ("cuda", TensorProto.FLOAT16),
+     ("dml", TensorProto.FLOAT16),
+ ]
+ if (self.ep, self.io_dtype) in valid_gqa_configurations:
# Change model settings for GroupQueryAttention
self.attention_attrs["op_type"] = "GroupQueryAttention"
print("GroupQueryAttention (GQA) is used in this model.")

# DML doesn't support packed Q/K/V for GQA yet
self.attention_attrs["use_packed_matmul"] = self.ep != "dml" and self.num_attn_heads == self.num_kv_heads
self.attention_attrs["use_packed_matmul"] = self.ep != "dml"

# GQA + Rot.Emb. does not require `position ids` as input
if self.ep in {"cuda", "cpu"}:
if self.ep != "dml":
self.attention_attrs["use_rotemb_in_attn"] = True
self.input_names.remove("position_ids")
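
The new eligibility rule above can be exercised on its own. The sketch below restates the same `(execution provider, I/O dtype)` table as a membership test; `uses_gqa` is an illustrative helper written for this note, not part of the builder's API.

```python
from onnx import TensorProto

# Same table as the diff above: (execution provider, I/O dtype) pairs that use GQA
valid_gqa_configurations = {
    ("cpu", TensorProto.FLOAT),
    ("cuda", TensorProto.FLOAT16),
    ("dml", TensorProto.FLOAT16),
}

def uses_gqa(ep: str, io_dtype: int) -> bool:
    # The membership test replaces the old `enable_GQA_on_CPU` extra option
    return (ep, io_dtype) in valid_gqa_configurations

assert uses_gqa("cpu", TensorProto.FLOAT)       # newly enabled by this PR
assert not uses_gqa("cuda", TensorProto.FLOAT)  # FP32 CUDA still takes the non-GQA path
```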

@@ -199,7 +203,7 @@ def __init__(self, config, io_dtype, onnx_dtype, ep, cache_dir, extra_options):
}

def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
- config = GenerationConfig.from_pretrained(model_name_or_path, **extra_kwargs)
+ config = GenerationConfig.from_pretrained(model_name_or_path, use_auth_token=True, trust_remote_code=True, **extra_kwargs)
inputs = dict(zip(self.input_names, self.input_names))
inputs.update({
"past_key_names": "past_key_values.%d.key",
@@ -212,7 +216,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
"decoder": {
"session_options" : {
"log_id": "onnxruntime-genai",
"provider_options" : []
"provider_options" : [],
},
"filename": self.filename,
"head_size": self.head_size,
@@ -259,7 +263,7 @@ def make_genai_config(self, model_name_or_path, extra_kwargs, out_dir):
json.dump(genai_config, f, indent=4)

def save_processing(self, model_name_or_path, extra_kwargs, out_dir):
- tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, **extra_kwargs)
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, use_auth_token=True, trust_remote_code=True, **extra_kwargs)
print(f"Saving processing files in {out_dir} for GenAI")
tokenizer.save_pretrained(out_dir)

@@ -563,11 +567,13 @@ def make_matmul_fp16_or_fp32(self, matmul, name, root_input, **kwargs):
# self.make_value_info(output, self.io_dtype, shape=['batch_size', 'sequence_length', self.hidden_size])

def make_packed_matmul(self, q_matmul, k_matmul, v_matmul, name, root_input, **kwargs):
- # N = num_heads * head_size, H = hidden_size
- # Combine 3 Matmuls of shape NxH into 1 packed MatMul of shape 3NxH
- # Note: Packed MatMul is of shape 3NxH instead of Hx3N because `make_matmul` will apply a transpose before saving
- N, H = q_matmul.shape
- matmul = np.stack((q_matmul.transpose(), k_matmul.transpose(), v_matmul.transpose()), axis=1).reshape(H, 3*N).transpose()
+ # N_q = num_attention_heads * head_size, N_kv = num_key_value_heads * head_size, H = hidden_size
+ # Combine 3 MatMuls of shape N_q x H, N_kv x H, N_kv x H into 1 packed MatMul of shape (N_q+N_kv+N_kv)xH
+ # Note: Packed MatMul is of shape (N_q+N_kv+N_kv)xH instead of Hx(N_q+N_kv+N_kv) because `make_matmul` will
+ # apply a transpose before saving
+ N_q, H = q_matmul.shape
+ N_kv, _ = k_matmul.shape
+ matmul = np.concatenate([q_matmul, k_matmul, v_matmul], axis=0).reshape(N_q + N_kv + N_kv, H)
self.make_matmul(matmul, name, root_input, **kwargs)

def make_add_bias(self, add, name, root_input, **kwargs):
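
The shape bookkeeping in the updated `make_packed_matmul` is easy to verify with a quick NumPy check. The dimensions below are assumed for illustration (a 7B-class model with 32 query heads and 8 KV heads), not taken from the diff.

```python
import numpy as np

# Assumed example dimensions, for the shape arithmetic only
hidden_size, head_size = 4096, 128
num_attention_heads, num_key_value_heads = 32, 8
N_q = num_attention_heads * head_size    # 4096
N_kv = num_key_value_heads * head_size   # 1024

q_w = np.zeros((N_q, hidden_size), dtype=np.float32)
k_w = np.zeros((N_kv, hidden_size), dtype=np.float32)
v_w = np.zeros((N_kv, hidden_size), dtype=np.float32)

# Same stacking as the builder: Q, K, V weights concatenated along the output dimension
packed = np.concatenate([q_w, k_w, v_w], axis=0)
assert packed.shape == (N_q + 2 * N_kv, hidden_size)   # (6144, 4096)
```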
@@ -1001,7 +1007,7 @@ def make_attention(self, layer_id, attention, root_input, **kwargs):
self.make_rotary_embedding(attention.rotary_emb, k_rotary_name, root_input=k_input_to_attention, position_ids=kwargs.get("position_ids", "position_ids"))
k_input_to_attention = f"{k_rotary_name}/output_0"

- # Make repeat KV nodes (TODO: remove once ORT supports GQA for CPU)
+ # Make repeat KV nodes (Note: `repeat_kv` needs to be kept since GroupQueryAttention isn't supported for FP32 CUDA)
past_k = f"past_key_values.{layer_id}.key"
past_v = f"past_key_values.{layer_id}.value"
present_k = f"present.{layer_id}.key"
@@ -1191,8 +1197,8 @@ def make_model(self, input_path):
self.layernorm_attrs["add_offset"] = 0 # add offset already done for GGUF models
else:
# Load PyTorch model
extra_kwargs = {"trust_remote_code": True} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers, "trust_remote_code": True} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir, "use_auth_token": True, "trust_remote_code": True}
model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, **extra_kwargs)
extra_kwargs = {} if os.path.exists(self.model_name_or_path) else {"num_hidden_layers": self.num_layers} if "num_hidden_layers" in self.extra_options else {"cache_dir": self.cache_dir}
model = AutoModelForCausalLM.from_pretrained(self.model_name_or_path, use_auth_token=True, trust_remote_code=True, **extra_kwargs)

# Loop through model and map each module to ONNX/ORT ops
self.layer_id = 0
@@ -1858,9 +1864,9 @@ def create_model(model_name, input_path, output_dir, precision, execution_provider,
os.makedirs(cache_dir, exist_ok=True)

# Load model config
extra_kwargs = {"trust_remote_code": True} if os.path.isdir(input_path) else {"cache_dir": cache_dir, "use_auth_token": True, "trust_remote_code": True}
extra_kwargs = {} if os.path.isdir(input_path) else {"cache_dir": cache_dir}
hf_name = input_path if os.path.isdir(input_path) else model_name
- config = AutoConfig.from_pretrained(hf_name, **extra_kwargs)
+ config = AutoConfig.from_pretrained(hf_name, use_auth_token=True, trust_remote_code=True, **extra_kwargs)

# Set input/output precision of ONNX model
io_dtype = TensorProto.FLOAT if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu") else TensorProto.FLOAT16
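
The precision-to-I/O-dtype rule on the line above decides which builds land in the new `("cpu", FLOAT)` GQA configuration. Restated as a hypothetical standalone helper (`io_dtype_for` is not part of the builder):

```python
from onnx import TensorProto

def io_dtype_for(precision: str, execution_provider: str) -> int:
    # Mirrors the selection above: FP32 I/O for int8/fp32, and for int4 on CPU
    if precision in {"int8", "fp32"} or (precision == "int4" and execution_provider == "cpu"):
        return TensorProto.FLOAT
    return TensorProto.FLOAT16

# INT4 CPU keeps FP32 I/O, so it also picks up GroupQueryAttention after this PR
assert io_dtype_for("int4", "cpu") == TensorProto.FLOAT
assert io_dtype_for("int4", "cuda") == TensorProto.FLOAT16
```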
@@ -1980,12 +1986,11 @@ def get_args():
enable_cuda_graph = 1 : The model can use CUDA graph capture for CUDA execution provider. If enabled, all nodes being placed on the CUDA EP
is the prerequisite for the CUDA graph to be used correctly. It is not guaranteed that cuda graph be enabled as it depends on the model
and the graph structure.
- enable_GQA_on_CPU = Enalbe G(Group)Query(Q)Attention(A) on CPU.
"""),
)

args = parser.parse_args()
print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, INT4 CPU, INT4 CUDA")
print("Valid precision + execution provider combinations are: FP32 CPU, FP32 CUDA, FP16 CUDA, FP16 DML, INT4 CPU, INT4 CUDA, INT4 DML")
return args

if __name__ == '__main__':
