diff --git a/tensorrt_llm/models/qwen/convert.py b/tensorrt_llm/models/qwen/convert.py
index 8cfaf58e5..917b279c7 100644
--- a/tensorrt_llm/models/qwen/convert.py
+++ b/tensorrt_llm/models/qwen/convert.py
@@ -297,8 +297,8 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
     q, k, v = torch.split(data, [local_dim, head_size, head_size], dim=-1)
     q_split = torch.split(q, q.shape[-1] // tp_size, dim=-1)
-    k_split = torch.split(k, q.shape[-1] // tp_size, dim=-1)
-    v_split = torch.split(v, q.shape[-1] // tp_size, dim=-1)
+    k_split = torch.split(k, k.shape[-1] // tp_size, dim=-1)
+    v_split = torch.split(v, v.shape[-1] // tp_size, dim=-1)
     return [
         torch.concat((q_split[ii], k_split[ii], v_split[ii]), dim=-1)
         for ii in range(tp_size)
@@ -318,8 +318,7 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
             cur_weights = multi_query_split(original_weights, local_dim,
                                             head_size, tensor_parallel, rank)
         else:
-            cur_weights = torch.split(original_weights,
-                                      original_weights.shape[-1] //
+            cur_weights = torch.chunk(original_weights,
                                       tensor_parallel,
                                       dim=cat_dim)[rank]
         if is_qkv:
@@ -370,8 +369,7 @@ def multi_query_split(data, local_dim, head_size, tp_size, cur_rank):
             cur_weights = multi_query_split(original_weights, local_dim,
                                             head_size, tensor_parallel, rank)
         else:
-            cur_weights = torch.split(original_weights,
-                                      original_weights.shape[-1] //
+            cur_weights = torch.chunk(original_weights,
                                       tensor_parallel,
                                       dim=cat_dim)[rank]
         if is_qkv:
@@ -823,7 +821,7 @@ def convert_hf_qwen(hf_model,
                         1, intermediate_size // tensor_parallel
                     ],
                     rank=mapping.tp_rank,
-                    cat_dim=-1))
+                    cat_dim=0))
         else:
             weights.update(
                 get_tllm_linear_weight(split_v, tllm_prex + 'mlp.proj.',