From 2f7adb595da55ec1ec6fea634eb6987009509fba Mon Sep 17 00:00:00 2001 From: Kunshang Ji Date: Fri, 13 Dec 2024 17:41:52 +0800 Subject: [PATCH 01/14] use env var to control --- vllm/model_executor/models/llama.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 44d34a4e3f20a..710ecf01280f3 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -82,6 +82,9 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", ) + split_enable = bool(os.environ.get('VLLM_TP_SPLIT_ENABLE', '1')) + split_size = int(os.environ.get('VLLM_TP_SPLIT_SIZE', '2')) + split_threshold = int(os.environ.get('VLLM_TP_SPLIT_THRESHOLD', '128')) self.down_proj = RowParallelLinear( input_size=intermediate_size, output_size=hidden_size, From c73ff1920007ef15f0d8b8355e9471557f5024a1 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Sat, 14 Dec 2024 00:45:44 +0200 Subject: [PATCH 02/14] update strategy only split when tensor is big Signed-off-by: Chendi Xue --- vllm/model_executor/models/llama.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 710ecf01280f3..44d34a4e3f20a 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -82,9 +82,6 @@ def __init__( quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", ) - split_enable = bool(os.environ.get('VLLM_TP_SPLIT_ENABLE', '1')) - split_size = int(os.environ.get('VLLM_TP_SPLIT_SIZE', '2')) - split_threshold = int(os.environ.get('VLLM_TP_SPLIT_THRESHOLD', '128')) self.down_proj = RowParallelLinear( input_size=intermediate_size, output_size=hidden_size, From 94897d086cb2ef4c275a5a2b40adc560b6150001 Mon Sep 17 00:00:00 2001 From: Barak Goldberg Date: Mon, 25 Nov 2024 09:34:38 +0200 Subject: [PATCH 03/14] [SW-209737] prepare sin/cos buffers for rope outside model forward --- vllm/model_executor/models/llama.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 44d34a4e3f20a..a6b7ae591f789 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -476,6 +476,8 @@ def forward( residual = intermediate_tensors["residual"] if is_hpu: + for i in range(self.start_layer, self.end_layer): + self.layers[i].self_attn.rotary_emb.prepare_cos_sin(positions) import habana_frameworks.torch as htorch htorch.core.mark_step() From 8bf84bcbbaa29ec698d16245bc8c7c5520be9e94 Mon Sep 17 00:00:00 2001 From: Nir David Date: Tue, 22 Oct 2024 19:08:08 +0300 Subject: [PATCH 04/14] Split qkv --- vllm/config.py | 4 +- vllm/engine/arg_utils.py | 8 +++- vllm/engine/llm_engine.py | 22 ++++++++-- vllm/model_executor/layers/linear.py | 66 ++++++++++++++++++++++------ vllm/model_executor/models/llama.py | 30 ++++++++++--- 5 files changed, 106 insertions(+), 24 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 4e5c755055f1f..65ca86bf77cb1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -737,6 +737,7 @@ class CacheConfig: prefix caching enabled. enable_prefix_caching: Whether to enable prefix caching. cpu_offload_gb: Size of the CPU offload buffer in GiB. + split_qk_v: Whether to split qk and v calculations. 
""" def __init__( @@ -750,6 +751,7 @@ def __init__( sliding_window: Optional[int] = None, enable_prefix_caching: bool = False, cpu_offload_gb: float = 0, + split_qk_v: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -760,7 +762,7 @@ def __init__( self.sliding_window = sliding_window self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb - + self.split_qk_v = split_qk_v self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 9f932c6f26eaa..fb286c505ae81 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -123,6 +123,7 @@ class EngineArgs: swap_space: float = 4 # GiB cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 + split_qk_v: bool = False max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_num_prefill_seqs: Optional[int] = None @@ -501,7 +502,11 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: type=int, default=None, help='If specified, ignore GPU profiling result and use this number' - ' of GPU blocks. Used for testing preemption.') + 'of GPU blocks. Used for testing preemption.') + parser.add_argument('--split-qk-v', + action='store_true', + default=EngineArgs.split_qk_v, + help='Whether to separate qk and v calculations.') parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, @@ -1050,6 +1055,7 @@ def create_engine_config(self, cache_dtype=self.kv_cache_dtype, is_attention_free=model_config.is_attention_free, num_gpu_blocks_override=self.num_gpu_blocks_override, + split_qk_v=self.split_qk_v, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 34044b358faca..91a8d4dbc79f1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -249,8 +249,22 @@ def __init__( ) logger.info( - "Initializing an LLM engine (v%s) with config: %r," - "use_cached_outputs=%s, ", + "Initializing an LLM engine (v%s) with config: " + "model=%r, speculative_config=%r, tokenizer=%r, " + "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " + "override_neuron_config=%s, tokenizer_revision=%s, " + "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " + "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " + "pipeline_parallel_size=%d, " + "disable_custom_all_reduce=%s, quantization=%s, " + "weights_load_device=%s, enforce_eager=%s, kv_cache_dtype=%s, " + "quantization_param_path=%s, device_config=%s, " + "decoding_config=%r, observability_config=%r, " + "seed=%d, served_model_name=%s, " + "num_scheduler_steps=%d, chunked_prefill_enabled=%s " + "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " + "use_async_output_proc=%s, use_cached_outputs=%s, " + "mm_processor_kwargs=%s, pooler_config=%r, split_qk_v=%s)", VLLM_VERSION, vllm_config, use_cached_outputs, @@ -326,7 +340,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: "enforce_eager": self.model_config.enforce_eager, "disable_custom_all_reduce": - self.parallel_config.disable_custom_all_reduce, + parallel_config.disable_custom_all_reduce, + "split_qk_v": + cache_config.split_qk_v, }) if self.tokenizer: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3461acbb95ee9..218d435d2e9a1 100644 --- 
a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -669,7 +669,8 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = ""): + prefix: str = "", + split_qk_v: bool = False): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -686,14 +687,19 @@ def __init__(self, else: self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_head_replicas = 1 + self.split_qk_v = split_qk_v + self.q_size = self.num_heads * self.head_size * tp_size + self.kv_size = self.num_kv_heads * self.head_size * tp_size input_size = self.hidden_size - output_size = (self.num_heads + - 2 * self.num_kv_heads) * tp_size * self.head_size self.output_sizes = [ - self.num_heads * self.head_size * tp_size, # q_proj - self.num_kv_heads * self.head_size * tp_size, # k_proj - self.num_kv_heads * self.head_size * tp_size, # v_proj + self.q_size, # q_proj ] + if split_qk_v: + output_size = (self.num_heads) * tp_size * self.head_size + else: + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + self.output_sizes.append(self.kv_size) # v_proj super().__init__(input_size=input_size, output_size=output_size, @@ -704,6 +710,24 @@ def __init__(self, quant_config=quant_config, prefix=prefix) + if split_qk_v: + self.k_proj = ColumnParallelLinear(input_size=input_size, + output_size=self.kv_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix) + self.v_proj = ColumnParallelLinear(input_size=input_size, + output_size=self.kv_size, + bias=bias, + gather_output=False, + skip_bias_add=skip_bias_add, + params_dtype=params_dtype, + quant_config=quant_config, + prefix=prefix) + def _get_shard_offset_mapping(self, loaded_shard_id: str): shard_offset_mapping = { "q": 0, @@ -913,15 +937,21 @@ def weight_loader(self, if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.num_heads * self.head_size), - "k": (self.num_heads * self.head_size, - self.num_kv_heads * self.head_size), - "v": - ((self.num_heads + self.num_kv_heads) * self.head_size, - self.num_kv_heads * self.head_size), - "total": - ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, - 0) } + if self.split_qk_v: + orig_qkv_offsets["total"] = ( + (self.num_heads) * self.head_size, + 0) + else: + orig_qkv_offsets["k"] = ( + (self.num_heads) * self.head_size, + self.num_kv_heads * self.head_size) + orig_qkv_offsets["v"] = ( + (self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size) + orig_qkv_offsets["total"] = ( + (self.num_heads + 2 * self.num_kv_heads) * + self.head_size, 0) shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( param, orig_qkv_offsets, loaded_shard_id) @@ -961,6 +991,14 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + def forward(self, input_): + q, output_bias = super().forward(input_) + if not self.split_qk_v: + return q, output_bias + k, _ = self.k_proj(input_) + v, _ = self.v_proj(input_) + return q, k, v, output_bias + class RowParallelLinear(LinearBase): """Linear layer with row parallelism. 
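Patch 04's linear.py change gives QKVParallelLinear a split_qk_v mode: the fused projection keeps only the Q slice, K and V come from dedicated ColumnParallelLinear layers, and forward() returns q, k, v (plus the bias) instead of one concatenated tensor. The snippet below is a minimal single-device sketch of that equivalence using plain torch.nn.Linear, with no tensor parallelism and made-up head counts; it illustrates the splitting scheme, not the vLLM implementation itself.

import torch
import torch.nn as nn

hidden_size, head_dim = 4096, 128
num_heads, num_kv_heads = 32, 8                 # illustrative GQA shapes
q_size, kv_size = num_heads * head_dim, num_kv_heads * head_dim
x = torch.randn(2, 16, hidden_size)

# Fused path: one projection, then a split along the last dim (the default in llama.py).
qkv_proj = nn.Linear(hidden_size, q_size + 2 * kv_size, bias=False)
q, k, v = qkv_proj(x).split([q_size, kv_size, kv_size], dim=-1)

# Split path: Q keeps the original projection, K and V get their own layers
# (standing in for the extra ColumnParallelLinear modules added by the patch).
q_proj = nn.Linear(hidden_size, q_size, bias=False)
k_proj = nn.Linear(hidden_size, kv_size, bias=False)
v_proj = nn.Linear(hidden_size, kv_size, bias=False)

# Copy the fused weight slices into the split layers so both paths compute the same values.
with torch.no_grad():
    wq, wk, wv = qkv_proj.weight.split([q_size, kv_size, kv_size], dim=0)
    q_proj.weight.copy_(wq)
    k_proj.weight.copy_(wk)
    v_proj.weight.copy_(wv)

assert torch.allclose(q, q_proj(x), atol=1e-5)
assert torch.allclose(k, k_proj(x), atol=1e-5)
assert torch.allclose(v, v_proj(x), atol=1e-5)

Functionally the two paths are equivalent; the split form trades one large GEMM for three smaller ones, which is presumably why the series keeps it behind the --split-qk-v flag rather than enabling it unconditionally.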
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index a6b7ae591f789..06f81237eefef 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -147,6 +147,7 @@ def __init__( self.scaling = self.head_dim**-0.5 self.rope_theta = rope_theta self.max_position_embeddings = max_position_embeddings + self.split_qk_v = cache_config.split_qk_v self.qkv_proj = QKVParallelLinear( hidden_size=hidden_size, @@ -156,6 +157,7 @@ def __init__( bias=bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", + split_qk_v=self.split_qk_v, ) self.o_proj = RowParallelLinear( @@ -212,8 +214,12 @@ def forward( skip_seq_split: bool = False, **kwargs, ) -> torch.Tensor: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) + if self.split_qk_v: + q, k, v, _ = self.qkv_proj(hidden_states) + else: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + dim=-1) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata, **kwargs) self.o_proj.skip_seq_split=skip_seq_split @@ -452,6 +458,13 @@ def __init__(self, make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) + if is_hpu: + import os + self.config_hidden_layers = int( + os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) + + self.split_qk_v = cache_config.split_qk_v + def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -500,11 +513,14 @@ def load_weights(self, weights: Iterable[Tuple[str, stacked_params_mapping = [ # (param_name, shard_name, shard_id) (".qkv_proj", ".q_proj", "q"), - (".qkv_proj", ".k_proj", "k"), - (".qkv_proj", ".v_proj", "v"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] + if self.split_qk_v: + stacked_params_mapping.append((".qkv_proj.v_proj", ".v_proj", "v")) + stacked_params_mapping.append((".qkv_proj.k_proj", ".k_proj", "k")) + else: + stacked_params_mapping.append((".qkv_proj", ".v_proj", "v")) params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: @@ -537,7 +553,11 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[name] weight_loader = param.weight_loader - weight_loader(param, loaded_weight, shard_id) + if self.split_qk_v and (shard_id == "v" or shard_id == "k") : + weight_loader(param, loaded_weight) + else: + weight_loader(param, loaded_weight, shard_id) + break else: # Skip loading extra bias for GPTQ models. 
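Patch 04 also adjusts Llama's load_weights: with a fused qkv_proj, checkpoint tensors named q_proj/k_proj/v_proj are renamed onto the fused parameter and copied in shard by shard, while split projections load their tensors directly. Below is a simplified sketch of that name-remapping idea in the form the series converges on in later patches; the helper and parameter names are illustrative, not vLLM's exact code.

# Simplified view of the stacked-parameter remapping used during weight loading.
# When a projection is split, its fused entries are dropped, so a checkpoint name
# such as "...k_proj.weight" already matches a real module and loads without a shard id.
def build_mapping(split_qk_v: bool, split_gate_up: bool):
    mapping = []
    if not split_qk_v:
        mapping += [(".qkv_proj", ".q_proj", "q"),
                    (".qkv_proj", ".k_proj", "k"),
                    (".qkv_proj", ".v_proj", "v")]
    if not split_gate_up:
        mapping += [(".gate_up_proj", ".gate_proj", 0),
                    (".gate_up_proj", ".up_proj", 1)]
    return mapping

def resolve(ckpt_name: str, mapping):
    """Map a checkpoint tensor name to (model parameter name, shard id)."""
    for param_name, shard_name, shard_id in mapping:
        if shard_name in ckpt_name:
            return ckpt_name.replace(shard_name, param_name), shard_id
    return ckpt_name, None   # direct load, e.g. the split k_proj / v_proj / gate_proj

print(resolve("model.layers.0.self_attn.k_proj.weight", build_mapping(False, False)))
# -> ('model.layers.0.self_attn.qkv_proj.weight', 'k')
print(resolve("model.layers.0.self_attn.k_proj.weight", build_mapping(True, True)))
# -> ('model.layers.0.self_attn.k_proj.weight', None)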
From 388388bc964b94b31f8d1cfd055658a24d5f3bed Mon Sep 17 00:00:00 2001 From: Tianmu Li Date: Mon, 16 Dec 2024 03:15:28 +0200 Subject: [PATCH 05/14] WIP --- vllm/model_executor/models/llama.py | 60 ++++++++++++++++++++++------- 1 file changed, 46 insertions(+), 14 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 06f81237eefef..21600b357874f 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -37,6 +37,7 @@ from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, + ColumnParallelLinear, RowParallelLinear) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig @@ -149,16 +150,43 @@ def __init__( self.max_position_embeddings = max_position_embeddings self.split_qk_v = cache_config.split_qk_v - self.qkv_proj = QKVParallelLinear( - hidden_size=hidden_size, - head_size=self.head_dim, - total_num_heads=self.total_num_heads, - total_num_kv_heads=self.total_num_kv_heads, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - split_qk_v=self.split_qk_v, - ) + if self.split_qk_v: + print("Using split qk_v") + self.q_proj = ColumnParallelLinear(input_size=self.hidden_size, + output_size=self.hidden_size, + bias=bias, + gather_output=False, + skip_bias_add=False, + params_dtype=None, + quant_config=quant_config, + prefix=f"{prefix}.q_proj") + self.k_proj = ColumnParallelLinear(input_size=self.hidden_size, + output_size=self.kv_size * tp_size, + bias=bias, + gather_output=False, + skip_bias_add=False, + params_dtype=None, + quant_config=quant_config, + prefix=f"{prefix}.k_proj") + self.v_proj = ColumnParallelLinear(input_size=self.hidden_size, + output_size=self.kv_size * tp_size, + bias=bias, + gather_output=False, + skip_bias_add=False, + params_dtype=None, + quant_config=quant_config, + prefix=f"{prefix}.v_proj") + else: + self.qkv_proj = QKVParallelLinear( + hidden_size=hidden_size, + head_size=self.head_dim, + total_num_heads=self.total_num_heads, + total_num_kv_heads=self.total_num_kv_heads, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.qkv_proj", + split_qk_v=self.split_qk_v, + ) self.o_proj = RowParallelLinear( input_size=self.total_num_heads * self.head_dim, @@ -215,7 +243,10 @@ def forward( **kwargs, ) -> torch.Tensor: if self.split_qk_v: - q, k, v, _ = self.qkv_proj(hidden_states) + # q, k, v, _ = self.qkv_proj(hidden_states) + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) else: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], @@ -512,14 +543,15 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) - (".qkv_proj", ".q_proj", "q"), (".gate_up_proj", ".gate_proj", 0), (".gate_up_proj", ".up_proj", 1), ] if self.split_qk_v: - stacked_params_mapping.append((".qkv_proj.v_proj", ".v_proj", "v")) - stacked_params_mapping.append((".qkv_proj.k_proj", ".k_proj", "k")) + pass + # stacked_params_mapping.append((".qkv_proj.v_proj", ".v_proj", "v")) + # stacked_params_mapping.append((".qkv_proj.k_proj", ".k_proj", "k")) else: + stacked_params_mapping.append((".qkv_proj", ".q_proj", "q")) stacked_params_mapping.append((".qkv_proj", ".v_proj", "v")) params_dict = dict(self.named_parameters()) 
loaded_params: Set[str] = set() From d4eb9a7df0517450005e960bd3a1e4d35056e534 Mon Sep 17 00:00:00 2001 From: Tianmu Li Date: Mon, 16 Dec 2024 03:50:09 +0200 Subject: [PATCH 06/14] Resolve rebase issues --- vllm/engine/llm_engine.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 91a8d4dbc79f1..6175c2d37530d 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -340,9 +340,9 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer: "enforce_eager": self.model_config.enforce_eager, "disable_custom_all_reduce": - parallel_config.disable_custom_all_reduce, + self.parallel_config.disable_custom_all_reduce, "split_qk_v": - cache_config.split_qk_v, + self.cache_config.split_qk_v, }) if self.tokenizer: From 217c58a1a0324cdbb0181164cca9b9a0e124dd41 Mon Sep 17 00:00:00 2001 From: Tianmu Li Date: Mon, 16 Dec 2024 04:51:44 +0200 Subject: [PATCH 07/14] Remove print --- vllm/model_executor/models/llama.py | 1 - 1 file changed, 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 21600b357874f..9e71c00bbb427 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -151,7 +151,6 @@ def __init__( self.split_qk_v = cache_config.split_qk_v if self.split_qk_v: - print("Using split qk_v") self.q_proj = ColumnParallelLinear(input_size=self.hidden_size, output_size=self.hidden_size, bias=bias, From a76a6b8d5994dc4994d42e2365b0f77a5396ad26 Mon Sep 17 00:00:00 2001 From: Tianmu Li Date: Wed, 18 Dec 2024 01:48:15 +0200 Subject: [PATCH 08/14] Remove old split_qk_v implementation --- vllm/model_executor/layers/linear.py | 65 +++++++--------------------- 1 file changed, 15 insertions(+), 50 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 218d435d2e9a1..d778596154bb7 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -669,8 +669,7 @@ def __init__(self, skip_bias_add: bool = False, params_dtype: Optional[torch.dtype] = None, quant_config: Optional[QuantizationConfig] = None, - prefix: str = "", - split_qk_v: bool = False): + prefix: str = ""): self.hidden_size = hidden_size self.head_size = head_size self.total_num_heads = total_num_heads @@ -687,19 +686,16 @@ def __init__(self, else: self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_head_replicas = 1 - self.split_qk_v = split_qk_v self.q_size = self.num_heads * self.head_size * tp_size self.kv_size = self.num_kv_heads * self.head_size * tp_size input_size = self.hidden_size self.output_sizes = [ self.q_size, # q_proj ] - if split_qk_v: - output_size = (self.num_heads) * tp_size * self.head_size - else: - output_size = (self.num_heads + - 2 * self.num_kv_heads) * tp_size * self.head_size - self.output_sizes.append(self.kv_size) # v_proj + + output_size = (self.num_heads + + 2 * self.num_kv_heads) * tp_size * self.head_size + self.output_sizes.append(self.kv_size) # v_proj super().__init__(input_size=input_size, output_size=output_size, @@ -710,24 +706,6 @@ def __init__(self, quant_config=quant_config, prefix=prefix) - if split_qk_v: - self.k_proj = ColumnParallelLinear(input_size=input_size, - output_size=self.kv_size, - bias=bias, - gather_output=False, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix) - self.v_proj = ColumnParallelLinear(input_size=input_size, - output_size=self.kv_size, - 
bias=bias, - gather_output=False, - skip_bias_add=skip_bias_add, - params_dtype=params_dtype, - quant_config=quant_config, - prefix=prefix) - def _get_shard_offset_mapping(self, loaded_shard_id: str): shard_offset_mapping = { "q": 0, @@ -938,20 +916,16 @@ def weight_loader(self, orig_qkv_offsets = { "q": (0, self.num_heads * self.head_size), } - if self.split_qk_v: - orig_qkv_offsets["total"] = ( - (self.num_heads) * self.head_size, - 0) - else: - orig_qkv_offsets["k"] = ( - (self.num_heads) * self.head_size, - self.num_kv_heads * self.head_size) - orig_qkv_offsets["v"] = ( - (self.num_heads + self.num_kv_heads) * self.head_size, - self.num_kv_heads * self.head_size) - orig_qkv_offsets["total"] = ( - (self.num_heads + 2 * self.num_kv_heads) * - self.head_size, 0) + + orig_qkv_offsets["k"] = ( + (self.num_heads) * self.head_size, + self.num_kv_heads * self.head_size) + orig_qkv_offsets["v"] = ( + (self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size) + orig_qkv_offsets["total"] = ( + (self.num_heads + 2 * self.num_kv_heads) * + self.head_size, 0) shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( param, orig_qkv_offsets, loaded_shard_id) @@ -991,15 +965,6 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) - def forward(self, input_): - q, output_bias = super().forward(input_) - if not self.split_qk_v: - return q, output_bias - k, _ = self.k_proj(input_) - v, _ = self.v_proj(input_) - return q, k, v, output_bias - - class RowParallelLinear(LinearBase): """Linear layer with row parallelism. From e95ea97a1f3b0f8fd8555fb912bca8df03802f25 Mon Sep 17 00:00:00 2001 From: Tianmu Li Date: Wed, 18 Dec 2024 03:15:29 +0200 Subject: [PATCH 09/14] Remove additional prints --- vllm/engine/llm_engine.py | 18 ++---------------- 1 file changed, 2 insertions(+), 16 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 6175c2d37530d..e94bc510a61de 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -249,22 +249,8 @@ def __init__( ) logger.info( - "Initializing an LLM engine (v%s) with config: " - "model=%r, speculative_config=%r, tokenizer=%r, " - "skip_tokenizer_init=%s, tokenizer_mode=%s, revision=%s, " - "override_neuron_config=%s, tokenizer_revision=%s, " - "trust_remote_code=%s, dtype=%s, max_seq_len=%d, " - "download_dir=%r, load_format=%s, tensor_parallel_size=%d, " - "pipeline_parallel_size=%d, " - "disable_custom_all_reduce=%s, quantization=%s, " - "weights_load_device=%s, enforce_eager=%s, kv_cache_dtype=%s, " - "quantization_param_path=%s, device_config=%s, " - "decoding_config=%r, observability_config=%r, " - "seed=%d, served_model_name=%s, " - "num_scheduler_steps=%d, chunked_prefill_enabled=%s " - "multi_step_stream_outputs=%s, enable_prefix_caching=%s, " - "use_async_output_proc=%s, use_cached_outputs=%s, " - "mm_processor_kwargs=%s, pooler_config=%r, split_qk_v=%s)", + "Initializing an LLM engine (v%s) with config: %r," + "use_cached_outputs=%s, ", VLLM_VERSION, vllm_config, use_cached_outputs, From f6bcffbe10fab73947f5e1a7827cc08ee8a97f5e Mon Sep 17 00:00:00 2001 From: Tianmu Li Date: Wed, 18 Dec 2024 08:05:46 +0200 Subject: [PATCH 10/14] Remove old code --- vllm/model_executor/models/llama.py | 83 +++++++++++++++++------------ 1 file changed, 49 insertions(+), 34 deletions(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 9e71c00bbb427..456442cebcbbd 100644 --- 
a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -76,13 +76,30 @@ def __init__( split_size: int = 2 ) -> None: super().__init__() - self.gate_up_proj = MergedColumnParallelLinear( - input_size=hidden_size, - output_sizes=[intermediate_size] * 2, - bias=bias, - quant_config=quant_config, - prefix=f"{prefix}.gate_up_proj", - ) + self.split_gate_up = True + if self.split_gate_up: + self.gate_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_proj", + ) + self.up_proj = ColumnParallelLinear( + input_size=hidden_size, + output_size=intermediate_size, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.up_proj" + ) + else: + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + prefix=f"{prefix}.gate_up_proj", + ) self.down_proj = RowParallelLinear( input_size=intermediate_size, output_size=hidden_size, @@ -98,8 +115,11 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x, skip_seq_split=False): - x, _ = self.gate_up_proj(x) - x = self.act_fn(x) + # if self.split_gate_up: + x = nn.functional.silu(self.gate_proj(x)[0]) * self.up_proj(x)[0] + # else: + # x, _ = self.gate_up_proj(x) + # x = self.act_fn(x) self.down_proj.skip_seq_split=skip_seq_split x, _ = self.down_proj(x) return x @@ -183,8 +203,7 @@ def __init__( total_num_kv_heads=self.total_num_kv_heads, bias=bias, quant_config=quant_config, - prefix=f"{prefix}.qkv_proj", - split_qk_v=self.split_qk_v, + prefix=f"{prefix}.qkv_proj" ) self.o_proj = RowParallelLinear( @@ -241,15 +260,15 @@ def forward( skip_seq_split: bool = False, **kwargs, ) -> torch.Tensor: - if self.split_qk_v: - # q, k, v, _ = self.qkv_proj(hidden_states) - q, _ = self.q_proj(hidden_states) - k, _ = self.k_proj(hidden_states) - v, _ = self.v_proj(hidden_states) - else: - qkv, _ = self.qkv_proj(hidden_states) - q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - dim=-1) + # if self.split_qk_v: + # q, k, v, _ = self.qkv_proj(hidden_states) + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + # else: + # qkv, _ = self.qkv_proj(hidden_states) + # q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], + # dim=-1) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata, **kwargs) self.o_proj.skip_seq_split=skip_seq_split @@ -488,12 +507,8 @@ def __init__(self, make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size)) - if is_hpu: - import os - self.config_hidden_layers = int( - os.getenv('VLLM_CONFIG_HIDDEN_LAYERS', '1')) - self.split_qk_v = cache_config.split_qk_v + self.split_gate_up = True def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -519,8 +534,6 @@ def forward( residual = intermediate_tensors["residual"] if is_hpu: - for i in range(self.start_layer, self.end_layer): - self.layers[i].self_attn.rotary_emb.prepare_cos_sin(positions) import habana_frameworks.torch as htorch htorch.core.mark_step() @@ -542,16 +555,18 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> Set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) - (".gate_up_proj", ".gate_proj", 0), - (".gate_up_proj", ".up_proj", 1), + # (".gate_up_proj", ".gate_proj", 0), + # (".gate_up_proj", 
".up_proj", 1), ] - if self.split_qk_v: - pass - # stacked_params_mapping.append((".qkv_proj.v_proj", ".v_proj", "v")) - # stacked_params_mapping.append((".qkv_proj.k_proj", ".k_proj", "k")) - else: + if not self.split_qk_v: stacked_params_mapping.append((".qkv_proj", ".q_proj", "q")) + stacked_params_mapping.append((".qkv_proj", ".k_proj", "k")) stacked_params_mapping.append((".qkv_proj", ".v_proj", "v")) + + if not self.split_gate_up: + stacked_params_mapping.append((".gate_up_proj", ".gate_proj", 0)) + stacked_params_mapping.append((".gate_up_proj", ".up_proj", 1)) + params_dict = dict(self.named_parameters()) loaded_params: Set[str] = set() for name, loaded_weight in weights: From 4d508b8fbed0650019f7c777fb7874f6366f0795 Mon Sep 17 00:00:00 2001 From: Tianmu Li Date: Thu, 19 Dec 2024 07:06:26 +0200 Subject: [PATCH 11/14] Cleanup --- vllm/config.py | 3 +++ vllm/engine/arg_utils.py | 6 +++++ vllm/engine/llm_engine.py | 2 ++ vllm/model_executor/layers/linear.py | 27 +++++++++---------- vllm/model_executor/models/llama.py | 40 +++++++++++++++------------- 5 files changed, 45 insertions(+), 33 deletions(-) diff --git a/vllm/config.py b/vllm/config.py index 65ca86bf77cb1..62a7c07b363c1 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -738,6 +738,7 @@ class CacheConfig: enable_prefix_caching: Whether to enable prefix caching. cpu_offload_gb: Size of the CPU offload buffer in GiB. split_qk_v: Whether to split qk and v calculations. + split_gate_up: Whether to split gate and up calculations. """ def __init__( @@ -752,6 +753,7 @@ def __init__( enable_prefix_caching: bool = False, cpu_offload_gb: float = 0, split_qk_v: bool = False, + split_gate_up: bool = False, ) -> None: self.block_size = block_size self.gpu_memory_utilization = gpu_memory_utilization @@ -763,6 +765,7 @@ def __init__( self.enable_prefix_caching = enable_prefix_caching self.cpu_offload_gb = cpu_offload_gb self.split_qk_v = split_qk_v + self.split_gate_up = split_gate_up self._verify_args() self._verify_cache_dtype() self._verify_prefix_caching() diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index fb286c505ae81..fc062e253f40d 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -124,6 +124,7 @@ class EngineArgs: cpu_offload_gb: float = 0 # GiB gpu_memory_utilization: float = 0.90 split_qk_v: bool = False + split_gate_up: bool = False max_num_batched_tokens: Optional[int] = None max_num_seqs: int = 256 max_num_prefill_seqs: Optional[int] = None @@ -507,6 +508,10 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: action='store_true', default=EngineArgs.split_qk_v, help='Whether to separate qk and v calculations.') + parser.add_argument('--split-gate-up', + action='store_true', + default=EngineArgs.split_gate_up, + help='Whether to separate gate and up calculations.') parser.add_argument('--max-num-batched-tokens', type=int, default=EngineArgs.max_num_batched_tokens, @@ -1056,6 +1061,7 @@ def create_engine_config(self, is_attention_free=model_config.is_attention_free, num_gpu_blocks_override=self.num_gpu_blocks_override, split_qk_v=self.split_qk_v, + split_gate_up=self.split_gate_up, sliding_window=model_config.get_sliding_window(), enable_prefix_caching=self.enable_prefix_caching, cpu_offload_gb=self.cpu_offload_gb, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index e94bc510a61de..e7bb20063af79 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -329,6 +329,8 @@ def get_tokenizer_for_seq(sequence: Sequence) -> 
AnyTokenizer: self.parallel_config.disable_custom_all_reduce, "split_qk_v": self.cache_config.split_qk_v, + "split_gate_up": + self.cache_config.split_gate_up, }) if self.tokenizer: diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index d778596154bb7..3ae2883c33d31 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -692,10 +692,13 @@ def __init__(self, self.output_sizes = [ self.q_size, # q_proj ] - output_size = (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size - self.output_sizes.append(self.kv_size) # v_proj + self.output_sizes = [ + self.num_heads * self.head_size * tp_size, # q_proj + self.num_kv_heads * self.head_size * tp_size, # k_proj + self.num_kv_heads * self.head_size * tp_size, # v_proj + ] super().__init__(input_size=input_size, output_size=output_size, @@ -915,20 +918,16 @@ def weight_loader(self, if use_bitsandbytes_4bit: orig_qkv_offsets = { "q": (0, self.num_heads * self.head_size), + "k": (self.num_heads * self.head_size, + self.num_kv_heads * self.head_size), + "v": + ((self.num_heads + self.num_kv_heads) * self.head_size, + self.num_kv_heads * self.head_size), + "total": + ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, + 0) } - orig_qkv_offsets["k"] = ( - (self.num_heads) * self.head_size, - self.num_kv_heads * self.head_size) - orig_qkv_offsets["v"] = ( - (self.num_heads + self.num_kv_heads) * self.head_size, - self.num_kv_heads * self.head_size) - orig_qkv_offsets["total"] = ( - (self.num_heads + 2 * self.num_kv_heads) * - self.head_size, 0) - shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( - param, orig_qkv_offsets, loaded_shard_id) - param_data = param_data.narrow(output_dim, shard_offset, shard_size) if loaded_shard_id == "q": diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 456442cebcbbd..03136f1dbebd3 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -73,10 +73,11 @@ def __init__( bias: bool = False, prefix: str = "", do_split: bool = False, - split_size: int = 2 + split_size: int = 2, + split_gate_up: bool = False ) -> None: super().__init__() - self.split_gate_up = True + self.split_gate_up = split_gate_up if self.split_gate_up: self.gate_proj = ColumnParallelLinear( input_size=hidden_size, @@ -115,11 +116,11 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x, skip_seq_split=False): - # if self.split_gate_up: - x = nn.functional.silu(self.gate_proj(x)[0]) * self.up_proj(x)[0] - # else: - # x, _ = self.gate_up_proj(x) - # x = self.act_fn(x) + if self.split_gate_up: + x = nn.functional.silu(self.gate_proj(x)[0]) * self.up_proj(x)[0] + else: + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) self.down_proj.skip_seq_split=skip_seq_split x, _ = self.down_proj(x) return x @@ -260,15 +261,15 @@ def forward( skip_seq_split: bool = False, **kwargs, ) -> torch.Tensor: - # if self.split_qk_v: - # q, k, v, _ = self.qkv_proj(hidden_states) - q, _ = self.q_proj(hidden_states) - k, _ = self.k_proj(hidden_states) - v, _ = self.v_proj(hidden_states) - # else: - # qkv, _ = self.qkv_proj(hidden_states) - # q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], - # dim=-1) + if self.split_qk_v: + # q, k, v, _ = self.qkv_proj(hidden_states) + q, _ = self.q_proj(hidden_states) + k, _ = self.k_proj(hidden_states) + v, _ = self.v_proj(hidden_states) + else: + qkv, _ = self.qkv_proj(hidden_states) + q, k, v = qkv.split([self.q_size, self.kv_size, 
self.kv_size], + dim=-1) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v, kv_cache, attn_metadata, **kwargs) self.o_proj.skip_seq_split=skip_seq_split @@ -330,7 +331,8 @@ def __init__( bias=getattr(config, "mlp_bias", False), prefix=f"{prefix}.mlp", do_split=do_split, - split_size=split_size + split_size=split_size, + split_gate_up=cache_config.split_gate_up ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) @@ -508,7 +510,7 @@ def __init__(self, ["hidden_states", "residual"], config.hidden_size)) self.split_qk_v = cache_config.split_qk_v - self.split_gate_up = True + self.split_gate_up = cache_config.split_gate_up def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) @@ -599,7 +601,7 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[name] weight_loader = param.weight_loader - if self.split_qk_v and (shard_id == "v" or shard_id == "k") : + if self.split_qk_v and (shard_id == "v" or shard_id == "k" or shard_id == "q") : weight_loader(param, loaded_weight) else: weight_loader(param, loaded_weight, shard_id) From e011f91eb0485b3bda068552c73e77fb273a76ce Mon Sep 17 00:00:00 2001 From: Tianmu Li Date: Thu, 19 Dec 2024 20:58:31 +0200 Subject: [PATCH 12/14] Remove unneeded changes --- vllm/model_executor/layers/linear.py | 10 ++++------ vllm/model_executor/models/llama.py | 4 ---- 2 files changed, 4 insertions(+), 10 deletions(-) diff --git a/vllm/model_executor/layers/linear.py b/vllm/model_executor/layers/linear.py index 3ae2883c33d31..3461acbb95ee9 100644 --- a/vllm/model_executor/layers/linear.py +++ b/vllm/model_executor/layers/linear.py @@ -686,14 +686,9 @@ def __init__(self, else: self.num_kv_heads = divide(self.total_num_kv_heads, tp_size) self.num_kv_head_replicas = 1 - self.q_size = self.num_heads * self.head_size * tp_size - self.kv_size = self.num_kv_heads * self.head_size * tp_size input_size = self.hidden_size - self.output_sizes = [ - self.q_size, # q_proj - ] output_size = (self.num_heads + - 2 * self.num_kv_heads) * tp_size * self.head_size + 2 * self.num_kv_heads) * tp_size * self.head_size self.output_sizes = [ self.num_heads * self.head_size * tp_size, # q_proj self.num_kv_heads * self.head_size * tp_size, # k_proj @@ -927,6 +922,8 @@ def weight_loader(self, ((self.num_heads + 2 * self.num_kv_heads) * self.head_size, 0) } + shard_size, shard_offset = adjust_bitsandbytes_4bit_shard( + param, orig_qkv_offsets, loaded_shard_id) param_data = param_data.narrow(output_dim, shard_offset, shard_size) @@ -964,6 +961,7 @@ def weight_loader(self, assert param_data.shape == loaded_weight.shape param_data.copy_(loaded_weight) + class RowParallelLinear(LinearBase): """Linear layer with row parallelism. 
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 03136f1dbebd3..89fcfbe97b84c 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -601,10 +601,6 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[name] weight_loader = param.weight_loader - if self.split_qk_v and (shard_id == "v" or shard_id == "k" or shard_id == "q") : - weight_loader(param, loaded_weight) - else: - weight_loader(param, loaded_weight, shard_id) break else: From 69ffcb273e25247be342394a5bb7e4318012de72 Mon Sep 17 00:00:00 2001 From: Tianmu Li Date: Thu, 19 Dec 2024 23:16:00 +0200 Subject: [PATCH 13/14] Typo --- vllm/model_executor/models/llama.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index 89fcfbe97b84c..e631161cada89 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -601,7 +601,7 @@ def load_weights(self, weights: Iterable[Tuple[str, param = params_dict[name] weight_loader = param.weight_loader - + weight_loader(param, loaded_weight, shard_id) break else: # Skip loading extra bias for GPTQ models. From b19dbf8fa9e34cce8b1c64836c5a190ac70c9a23 Mon Sep 17 00:00:00 2001 From: Tianmu Li Date: Wed, 25 Dec 2024 03:32:29 +0200 Subject: [PATCH 14/14] seq_len=512 prefill w/a --- vllm/model_executor/models/llama.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py index e631161cada89..65e70c082ee18 100644 --- a/vllm/model_executor/models/llama.py +++ b/vllm/model_executor/models/llama.py @@ -78,6 +78,7 @@ def __init__( ) -> None: super().__init__() self.split_gate_up = split_gate_up + self.hidden_size = hidden_size if self.split_gate_up: self.gate_proj = ColumnParallelLinear( input_size=hidden_size, @@ -116,6 +117,10 @@ def __init__( self.act_fn = SiluAndMul() def forward(self, x, skip_seq_split=False): + batch_size = x.size(0) + seq_len = x.size(1) + if (seq_len*batch_size)%512==0: + x = x.view(-1,512,self.hidden_size) if self.split_gate_up: x = nn.functional.silu(self.gate_proj(x)[0]) * self.up_proj(x)[0] else: @@ -123,6 +128,8 @@ def forward(self, x, skip_seq_split=False): x = self.act_fn(x) self.down_proj.skip_seq_split=skip_seq_split x, _ = self.down_proj(x) + if (seq_len*batch_size)%512==0: + x = x.view(batch_size,seq_len,self.hidden_size) return x
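The final patch folds the MLP input into 512-token chunks whenever batch_size * seq_len divides evenly by 512, presumably so the HPU prefill kernels always see a fixed sequence length. The reshape is purely a view change: per-token operations only touch the hidden dimension, so the folded and unfolded passes produce the same values. Here is a standalone sketch with made-up sizes and plain torch.nn.Linear standing in for the parallel layers.

import torch
import torch.nn as nn

hidden_size, intermediate_size = 1024, 2048     # illustrative sizes
batch_size, seq_len = 4, 640                    # 4 * 640 = 2560 tokens, divisible by 512

gate = nn.Linear(hidden_size, intermediate_size, bias=False)
up = nn.Linear(hidden_size, intermediate_size, bias=False)
down = nn.Linear(intermediate_size, hidden_size, bias=False)

def mlp(x):
    # SiLU(gate(x)) * up(x) followed by the down projection, as in LlamaMLP.
    return down(nn.functional.silu(gate(x)) * up(x))

x = torch.randn(batch_size, seq_len, hidden_size)
ref = mlp(x)                                    # direct pass

tokens = batch_size * seq_len
if tokens % 512 == 0:
    # Fold (batch, seq) into chunks of 512 tokens, run the MLP, unfold again.
    out = mlp(x.view(-1, 512, hidden_size)).view(batch_size, seq_len, hidden_size)
else:
    out = mlp(x)

assert torch.allclose(ref, out, atol=1e-5)      # the fold/unfold does not change the math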