diff --git a/vllm/config.py b/vllm/config.py
index 4e5c755055f1f..62a7c07b363c1 100644
--- a/vllm/config.py
+++ b/vllm/config.py
@@ -737,6 +737,8 @@ class CacheConfig:
             prefix caching enabled.
         enable_prefix_caching: Whether to enable prefix caching.
         cpu_offload_gb: Size of the CPU offload buffer in GiB.
+        split_qk_v: Whether to use separate q, k and v projections.
+        split_gate_up: Whether to use separate gate and up projections.
     """
 
     def __init__(
@@ -750,6 +752,8 @@ def __init__(
         sliding_window: Optional[int] = None,
         enable_prefix_caching: bool = False,
         cpu_offload_gb: float = 0,
+        split_qk_v: bool = False,
+        split_gate_up: bool = False,
     ) -> None:
         self.block_size = block_size
         self.gpu_memory_utilization = gpu_memory_utilization
@@ -760,7 +764,8 @@ def __init__(
         self.sliding_window = sliding_window
         self.enable_prefix_caching = enable_prefix_caching
         self.cpu_offload_gb = cpu_offload_gb
-
+        self.split_qk_v = split_qk_v
+        self.split_gate_up = split_gate_up
         self._verify_args()
         self._verify_cache_dtype()
         self._verify_prefix_caching()
diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py
index 9f932c6f26eaa..fc062e253f40d 100644
--- a/vllm/engine/arg_utils.py
+++ b/vllm/engine/arg_utils.py
@@ -123,6 +123,8 @@ class EngineArgs:
     swap_space: float = 4  # GiB
     cpu_offload_gb: float = 0  # GiB
     gpu_memory_utilization: float = 0.90
+    split_qk_v: bool = False
+    split_gate_up: bool = False
     max_num_batched_tokens: Optional[int] = None
     max_num_seqs: int = 256
     max_num_prefill_seqs: Optional[int] = None
@@ -501,7 +503,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
                             type=int,
                             default=None,
                             help='If specified, ignore GPU profiling result and use this number'
                             ' of GPU blocks. Used for testing preemption.')
+        parser.add_argument('--split-qk-v',
+                            action='store_true',
+                            default=EngineArgs.split_qk_v,
+                            help='Use separate q, k and v projections.')
+        parser.add_argument('--split-gate-up',
+                            action='store_true',
+                            default=EngineArgs.split_gate_up,
+                            help='Use separate gate and up projections.')
         parser.add_argument('--max-num-batched-tokens',
                             type=int,
                             default=EngineArgs.max_num_batched_tokens,
@@ -1050,6 +1060,8 @@ def create_engine_config(self,
             cache_dtype=self.kv_cache_dtype,
             is_attention_free=model_config.is_attention_free,
             num_gpu_blocks_override=self.num_gpu_blocks_override,
+            split_qk_v=self.split_qk_v,
+            split_gate_up=self.split_gate_up,
             sliding_window=model_config.get_sliding_window(),
             enable_prefix_caching=self.enable_prefix_caching,
             cpu_offload_gb=self.cpu_offload_gb,
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 34044b358faca..e7bb20063af79 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -327,6 +327,10 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
                     self.model_config.enforce_eager,
                     "disable_custom_all_reduce":
                     self.parallel_config.disable_custom_all_reduce,
+                    "split_qk_v":
+                    self.cache_config.split_qk_v,
+                    "split_gate_up":
+                    self.cache_config.split_gate_up,
                 })
 
         if self.tokenizer:
diff --git a/vllm/model_executor/models/llama.py b/vllm/model_executor/models/llama.py
index 44d34a4e3f20a..65e70c082ee18 100644
--- a/vllm/model_executor/models/llama.py
+++ b/vllm/model_executor/models/llama.py
@@ -37,6 +37,7 @@
 from vllm.model_executor.layers.layernorm import RMSNorm
 from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
                                                QKVParallelLinear,
+                                               ColumnParallelLinear,
                                                RowParallelLinear)
 from vllm.model_executor.layers.logits_processor import LogitsProcessor
 from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -72,16 +73,35 @@ def __init__(
         bias: bool = False,
         prefix: str = "",
         do_split: bool = False,
-        split_size: int = 2
+        split_size: int = 2,
+        split_gate_up: bool = False
     ) -> None:
         super().__init__()
-        self.gate_up_proj = MergedColumnParallelLinear(
-            input_size=hidden_size,
-            output_sizes=[intermediate_size] * 2,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.gate_up_proj",
-        )
+        self.split_gate_up = split_gate_up
+        self.hidden_size = hidden_size
+        if self.split_gate_up:
+            self.gate_proj = ColumnParallelLinear(
+                input_size=hidden_size,
+                output_size=intermediate_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_proj",
+            )
+            self.up_proj = ColumnParallelLinear(
+                input_size=hidden_size,
+                output_size=intermediate_size,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.up_proj",
+            )
+        else:
+            self.gate_up_proj = MergedColumnParallelLinear(
+                input_size=hidden_size,
+                output_sizes=[intermediate_size] * 2,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.gate_up_proj",
+            )
         self.down_proj = RowParallelLinear(
             input_size=intermediate_size,
             output_size=hidden_size,
@@ -97,10 +117,19 @@ def __init__(
         self.act_fn = SiluAndMul()
 
     def forward(self, x, skip_seq_split=False):
-        x, _ = self.gate_up_proj(x)
-        x = self.act_fn(x)
+        batch_size = x.size(0)
+        seq_len = x.size(1)
+        if (seq_len * batch_size) % 512 == 0:
+            x = x.view(-1, 512, self.hidden_size)
+        if self.split_gate_up:
+            x = nn.functional.silu(self.gate_proj(x)[0]) * self.up_proj(x)[0]
+        else:
+            x, _ = self.gate_up_proj(x)
+            x = self.act_fn(x)
         self.down_proj.skip_seq_split=skip_seq_split
         x, _ = self.down_proj(x)
+        if (seq_len * batch_size) % 512 == 0:
+            x = x.view(batch_size, seq_len, self.hidden_size)
         return x
 
 
@@ -147,16 +176,43 @@ def __init__(
         self.scaling = self.head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
-
-        self.qkv_proj = QKVParallelLinear(
-            hidden_size=hidden_size,
-            head_size=self.head_dim,
-            total_num_heads=self.total_num_heads,
-            total_num_kv_heads=self.total_num_kv_heads,
-            bias=bias,
-            quant_config=quant_config,
-            prefix=f"{prefix}.qkv_proj",
-        )
+        self.split_qk_v = cache_config.split_qk_v
+
+        if self.split_qk_v:
+            self.q_proj = ColumnParallelLinear(input_size=self.hidden_size,
+                                               output_size=self.hidden_size,
+                                               bias=bias,
+                                               gather_output=False,
+                                               skip_bias_add=False,
+                                               params_dtype=None,
+                                               quant_config=quant_config,
+                                               prefix=f"{prefix}.q_proj")
+            self.k_proj = ColumnParallelLinear(input_size=self.hidden_size,
+                                               output_size=self.kv_size * tp_size,
+                                               bias=bias,
+                                               gather_output=False,
+                                               skip_bias_add=False,
+                                               params_dtype=None,
+                                               quant_config=quant_config,
+                                               prefix=f"{prefix}.k_proj")
+            self.v_proj = ColumnParallelLinear(input_size=self.hidden_size,
+                                               output_size=self.kv_size * tp_size,
+                                               bias=bias,
+                                               gather_output=False,
+                                               skip_bias_add=False,
+                                               params_dtype=None,
+                                               quant_config=quant_config,
+                                               prefix=f"{prefix}.v_proj")
+        else:
+            self.qkv_proj = QKVParallelLinear(
+                hidden_size=hidden_size,
+                head_size=self.head_dim,
+                total_num_heads=self.total_num_heads,
+                total_num_kv_heads=self.total_num_kv_heads,
+                bias=bias,
+                quant_config=quant_config,
+                prefix=f"{prefix}.qkv_proj",
+            )
 
         self.o_proj = RowParallelLinear(
             input_size=self.total_num_heads * self.head_dim,
@@ -212,8 +268,15 @@ def forward(
         skip_seq_split: bool = False,
         **kwargs,
     ) -> torch.Tensor:
-        qkv, _ = self.qkv_proj(hidden_states)
-        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        if self.split_qk_v:
+            # q, k, v, _ = self.qkv_proj(hidden_states)
+            q, _ = self.q_proj(hidden_states)
+            k, _ = self.k_proj(hidden_states)
+            v, _ = self.v_proj(hidden_states)
+        else:
+            qkv, _ = self.qkv_proj(hidden_states)
+            q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
+                                dim=-1)
         q, k = self.rotary_emb(positions, q, k)
         attn_output = self.attn(q, k, v, kv_cache, attn_metadata, **kwargs)
         self.o_proj.skip_seq_split=skip_seq_split
@@ -275,7 +338,8 @@ def __init__(
             bias=getattr(config, "mlp_bias", False),
             prefix=f"{prefix}.mlp",
             do_split=do_split,
-            split_size=split_size
+            split_size=split_size,
+            split_gate_up=cache_config.split_gate_up
         )
         self.input_layernorm = RMSNorm(config.hidden_size,
                                        eps=config.rms_norm_eps)
@@ -452,6 +516,9 @@ def __init__(self,
             make_empty_intermediate_tensors_factory(
                 ["hidden_states", "residual"], config.hidden_size))
 
+        self.split_qk_v = cache_config.split_qk_v
+        self.split_gate_up = cache_config.split_gate_up
+
     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.embed_tokens(input_ids)
 
@@ -497,12 +564,18 @@ def load_weights(self, weights: Iterable[Tuple[str,
                                                    torch.Tensor]]) -> Set[str]:
         stacked_params_mapping = [
             # (param_name, shard_name, shard_id)
-            (".qkv_proj", ".q_proj", "q"),
-            (".qkv_proj", ".k_proj", "k"),
-            (".qkv_proj", ".v_proj", "v"),
-            (".gate_up_proj", ".gate_proj", 0),
-            (".gate_up_proj", ".up_proj", 1),
+            # (".gate_up_proj", ".gate_proj", 0),
+            # (".gate_up_proj", ".up_proj", 1),
         ]
+        if not self.split_qk_v:
+            stacked_params_mapping.append((".qkv_proj", ".q_proj", "q"))
+            stacked_params_mapping.append((".qkv_proj", ".k_proj", "k"))
+            stacked_params_mapping.append((".qkv_proj", ".v_proj", "v"))
+
+        if not self.split_gate_up:
+            stacked_params_mapping.append((".gate_up_proj", ".gate_proj", 0))
+            stacked_params_mapping.append((".gate_up_proj", ".up_proj", 1))
+
         params_dict = dict(self.named_parameters())
         loaded_params: Set[str] = set()
         for name, loaded_weight in weights:
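
A minimal usage sketch, not part of the patch: with this diff applied, the new options can be enabled via the added `--split-qk-v` / `--split-gate-up` CLI flags or programmatically through `EngineArgs`. The snippet assumes the patched vLLM is installed and that `LLM` forwards extra keyword arguments to `EngineArgs`, as it does for other engine options; the model name is only an example.

# Usage sketch (assumption, not part of the patch): exercise the new
# split_qk_v / split_gate_up engine options added by this diff.
# CLI equivalent: pass --split-qk-v --split-gate-up to the vLLM server.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",  # example model only
    split_qk_v=True,     # separate q_proj / k_proj / v_proj instead of fused qkv_proj
    split_gate_up=True,  # separate gate_proj / up_proj instead of fused gate_up_proj
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=16))
print(outputs[0].outputs[0].text)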
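
For reference, the split MLP path is numerically equivalent to the fused one. The standalone sketch below, using plain torch with arbitrary sizes and no vLLM dependency, illustrates the identity the `split_gate_up` branch of `LlamaMLP.forward` relies on: `silu(gate(x)) * up(x)` matches gating the output of a fused projection whose weight is the row-wise concatenation of the gate and up weights.

# Standalone sketch, not part of the patch: check that the split gate/up
# path matches the fused gate_up path when the fused weight is the
# concatenation [gate_w; up_w]. Sizes are arbitrary.
import torch
import torch.nn.functional as F

hidden_size, intermediate_size = 64, 128
x = torch.randn(4, hidden_size)

gate_w = torch.randn(intermediate_size, hidden_size)
up_w = torch.randn(intermediate_size, hidden_size)
gate_up_w = torch.cat([gate_w, up_w], dim=0)  # fused weight, gate rows first

# Fused path: one GEMM, then silu on the first half times the second half.
gate_out, up_out = (x @ gate_up_w.t()).split(intermediate_size, dim=-1)
fused = F.silu(gate_out) * up_out

# Split path: two GEMMs, as in the split_gate_up branch of the patch.
split = F.silu(x @ gate_w.t()) * (x @ up_w.t())

torch.testing.assert_close(fused, split)
print("fused gate_up and split gate/up paths agree")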