Enable splitting qkv and gate_up #662

Merged
merged 14 commits on Jan 7, 2025
7 changes: 6 additions & 1 deletion vllm/config.py
@@ -737,6 +737,8 @@ class CacheConfig:
prefix caching enabled.
enable_prefix_caching: Whether to enable prefix caching.
cpu_offload_gb: Size of the CPU offload buffer in GiB.
split_qk_v: Whether to compute the q, k and v projections with separate linear layers instead of a fused QKV projection.
split_gate_up: Whether to compute the MLP gate and up projections with separate linear layers instead of a fused gate_up projection.
"""

def __init__(
@@ -750,6 +752,8 @@ def __init__(
sliding_window: Optional[int] = None,
enable_prefix_caching: bool = False,
cpu_offload_gb: float = 0,
split_qk_v: bool = False,
split_gate_up: bool = False,
) -> None:
self.block_size = block_size
self.gpu_memory_utilization = gpu_memory_utilization
@@ -760,7 +764,8 @@
self.sliding_window = sliding_window
self.enable_prefix_caching = enable_prefix_caching
self.cpu_offload_gb = cpu_offload_gb

self.split_qk_v = split_qk_v
self.split_gate_up = split_gate_up
self._verify_args()
self._verify_cache_dtype()
self._verify_prefix_caching()
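What the two new booleans actually toggle is how the Llama projection layers are built further down in llama.py: left off, vLLM keeps the fused QKV and gate_up matmuls; turned on, each sub-projection becomes its own linear layer. A minimal sketch of the equivalence in plain PyTorch (illustrative shapes only, not vLLM's tensor-parallel classes):

```python
import torch

torch.manual_seed(0)
hidden, q_out, kv_out = 32, 32, 8

# Per-projection weights; the fused weight is just their row-wise concatenation.
w_q = torch.randn(q_out, hidden)
w_k = torch.randn(kv_out, hidden)
w_v = torch.randn(kv_out, hidden)
w_qkv = torch.cat([w_q, w_k, w_v], dim=0)

x = torch.randn(4, hidden)

# Fused layout (split_qk_v=False): one large GEMM, then slice the output.
q1, k1, v1 = (x @ w_qkv.t()).split([q_out, kv_out, kv_out], dim=-1)

# Split layout (split_qk_v=True): three smaller GEMMs.
q2, k2, v2 = x @ w_q.t(), x @ w_k.t(), x @ w_v.t()

assert torch.allclose(q1, q2) and torch.allclose(k1, k2) and torch.allclose(v1, v2)
```

The results are identical; the flags only trade one large matmul for several smaller ones, which can be preferable on some backends.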
14 changes: 13 additions & 1 deletion vllm/engine/arg_utils.py
@@ -123,6 +123,8 @@ class EngineArgs:
swap_space: float = 4 # GiB
cpu_offload_gb: float = 0 # GiB
gpu_memory_utilization: float = 0.90
split_qk_v: bool = False
split_gate_up: bool = False
max_num_batched_tokens: Optional[int] = None
max_num_seqs: int = 256
max_num_prefill_seqs: Optional[int] = None
@@ -501,7 +503,15 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
type=int,
default=None,
help='If specified, ignore GPU profiling result and use this number'
' of GPU blocks. Used for testing preemption.')
parser.add_argument('--split-qk-v',
action='store_true',
default=EngineArgs.split_qk_v,
help='Whether to compute q, k and v with separate linear layers.')
parser.add_argument('--split-gate-up',
action='store_true',
default=EngineArgs.split_gate_up,
help='Whether to compute the MLP gate and up projections separately.')
parser.add_argument('--max-num-batched-tokens',
type=int,
default=EngineArgs.max_num_batched_tokens,
@@ -1050,6 +1060,8 @@ def create_engine_config(self,
cache_dtype=self.kv_cache_dtype,
is_attention_free=model_config.is_attention_free,
num_gpu_blocks_override=self.num_gpu_blocks_override,
split_qk_v=self.split_qk_v,
split_gate_up=self.split_gate_up,
sliding_window=model_config.get_sliding_window(),
enable_prefix_caching=self.enable_prefix_caching,
cpu_offload_gb=self.cpu_offload_gb,
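For trying the flags outside the CLI: EngineArgs fields are also reachable as keyword arguments of the vllm.LLM entrypoint, so a sketch like the one below should work once this change lands. The keyword passthrough is an assumption based on how the other engine arguments are exposed, and the model name is only an example.

```python
from vllm import LLM, SamplingParams

# Both flags default to False; enabling them makes the Llama layers in
# llama.py build separate q/k/v and gate/up projections instead of fused ones.
llm = LLM(
    model="meta-llama/Llama-2-7b-hf",  # any Llama-family checkpoint
    split_qk_v=True,
    split_gate_up=True,
)
outputs = llm.generate(["Hello, my name is"], SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
```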
4 changes: 4 additions & 0 deletions vllm/engine/llm_engine.py
@@ -327,6 +327,10 @@ def get_tokenizer_for_seq(sequence: Sequence) -> AnyTokenizer:
self.model_config.enforce_eager,
"disable_custom_all_reduce":
self.parallel_config.disable_custom_all_reduce,
"split_qk_v":
self.cache_config.split_qk_v,
"split_gate_up":
self.cache_config.split_gate_up,
})

if self.tokenizer:
129 changes: 101 additions & 28 deletions vllm/model_executor/models/llama.py
@@ -37,6 +37,7 @@
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (MergedColumnParallelLinear,
QKVParallelLinear,
ColumnParallelLinear,
RowParallelLinear)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationConfig
@@ -72,16 +73,35 @@ def __init__(
bias: bool = False,
prefix: str = "",
do_split: bool = False,
split_size: int = 2,
split_gate_up: bool = False
) -> None:
super().__init__()
self.split_gate_up = split_gate_up
self.hidden_size = hidden_size
if self.split_gate_up:
self.gate_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_proj",
)
self.up_proj = ColumnParallelLinear(
input_size=hidden_size,
output_size=intermediate_size,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.up_proj"
)
else:
self.gate_up_proj = MergedColumnParallelLinear(
input_size=hidden_size,
output_sizes=[intermediate_size] * 2,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.gate_up_proj",
)
self.down_proj = RowParallelLinear(
input_size=intermediate_size,
output_size=hidden_size,
@@ -97,10 +117,19 @@ def __init__(
self.act_fn = SiluAndMul()

def forward(self, x, skip_seq_split=False):
    batch_size = x.size(0)
    seq_len = x.size(1)
    if (seq_len * batch_size) % 512 == 0:
        x = x.view(-1, 512, self.hidden_size)
    if self.split_gate_up:
        x = nn.functional.silu(self.gate_proj(x)[0]) * self.up_proj(x)[0]
    else:
        x, _ = self.gate_up_proj(x)
        x = self.act_fn(x)
    self.down_proj.skip_seq_split = skip_seq_split
    x, _ = self.down_proj(x)
    if (seq_len * batch_size) % 512 == 0:
        x = x.view(batch_size, seq_len, self.hidden_size)
    return x
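The split branch computes silu(gate(x)) * up(x) directly, which matches what SiluAndMul does on the fused gate_up output (SiLU on the first half of the last dimension, multiplied by the second half). A quick check with plain tensors (illustrative sizes, no tensor parallelism):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
hidden, intermediate = 16, 32

w_gate = torch.randn(intermediate, hidden)
w_up = torch.randn(intermediate, hidden)
x = torch.randn(4, hidden)

# Fused path (split_gate_up=False): one matmul, then SiLU(gate half) * up half,
# i.e. what act_fn = SiluAndMul() computes on the gate_up_proj output.
gate_up = x @ torch.cat([w_gate, w_up], dim=0).t()
fused = F.silu(gate_up[..., :intermediate]) * gate_up[..., intermediate:]

# Split path (split_gate_up=True): the expression used in the new branch above.
split = F.silu(x @ w_gate.t()) * (x @ w_up.t())

assert torch.allclose(fused, split, atol=1e-6)
```

The surrounding view(-1, 512, hidden_size) reshape in forward() only re-buckets the batch and sequence dimensions when their product is a multiple of 512; it does not change this arithmetic.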


@@ -147,16 +176,43 @@ def __init__(
self.scaling = self.head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings

self.split_qk_v = cache_config.split_qk_v

if self.split_qk_v:
self.q_proj = ColumnParallelLinear(input_size=self.hidden_size,
output_size=self.hidden_size,
bias=bias,
gather_output=False,
skip_bias_add=False,
params_dtype=None,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
self.k_proj = ColumnParallelLinear(input_size=self.hidden_size,
output_size=self.kv_size * tp_size,
bias=bias,
gather_output=False,
skip_bias_add=False,
params_dtype=None,
quant_config=quant_config,
prefix=f"{prefix}.k_proj")
self.v_proj = ColumnParallelLinear(input_size=self.hidden_size,
output_size=self.kv_size * tp_size,
bias=bias,
gather_output=False,
skip_bias_add=False,
params_dtype=None,
quant_config=quant_config,
prefix=f"{prefix}.v_proj")
else:
self.qkv_proj = QKVParallelLinear(
hidden_size=hidden_size,
head_size=self.head_dim,
total_num_heads=self.total_num_heads,
total_num_kv_heads=self.total_num_kv_heads,
bias=bias,
quant_config=quant_config,
prefix=f"{prefix}.qkv_proj"
)

self.o_proj = RowParallelLinear(
input_size=self.total_num_heads * self.head_dim,
@@ -212,8 +268,15 @@ def forward(
skip_seq_split: bool = False,
**kwargs,
) -> torch.Tensor:
if self.split_qk_v:
# q, k, v, _ = self.qkv_proj(hidden_states)
q, _ = self.q_proj(hidden_states)
k, _ = self.k_proj(hidden_states)
v, _ = self.v_proj(hidden_states)
else:
qkv, _ = self.qkv_proj(hidden_states)
q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size],
dim=-1)
q, k = self.rotary_emb(positions, q, k)
attn_output = self.attn(q, k, v, kv_cache, attn_metadata, **kwargs)
self.o_proj.skip_seq_split=skip_seq_split
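The output sizes of the split projections are worth a second look in review: q_proj is given hidden_size, which is the full query width only because total_num_heads * head_dim == hidden_size for the Llama configs, and k_proj/v_proj are given kv_size * tp_size, i.e. the unsharded KV width, since ColumnParallelLinear shards the output dimension by tp_size itself. A small arithmetic check with Llama-2-7B-like numbers (illustrative only, not vLLM code):

```python
# Size bookkeeping for the split q/k/v projections; mirrors the constructor
# arguments above with example values.
hidden_size = 4096
total_num_heads = 32
total_num_kv_heads = 32
head_dim = hidden_size // total_num_heads     # 128
tp_size = 2

num_kv_heads = total_num_kv_heads // tp_size  # KV heads per rank
kv_size = num_kv_heads * head_dim             # per-rank K/V width (2048)

# q_proj output_size=hidden_size is the full query width ...
assert total_num_heads * head_dim == hidden_size
# ... and kv_size * tp_size recovers the full (unsharded) K/V width.
assert kv_size * tp_size == total_num_kv_heads * head_dim
```

For grouped-query checkpoints with fewer KV heads the same expressions apply, as long as total_num_kv_heads divides evenly across tp_size.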
@@ -275,7 +338,8 @@ def __init__(
bias=getattr(config, "mlp_bias", False),
prefix=f"{prefix}.mlp",
do_split=do_split,
split_size=split_size,
split_gate_up=cache_config.split_gate_up
)
self.input_layernorm = RMSNorm(config.hidden_size,
eps=config.rms_norm_eps)
@@ -452,6 +516,9 @@ def __init__(self,
make_empty_intermediate_tensors_factory(
["hidden_states", "residual"], config.hidden_size))

self.split_qk_v = cache_config.split_qk_v
self.split_gate_up = cache_config.split_gate_up

def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
return self.embed_tokens(input_ids)

@@ -497,12 +564,18 @@ def load_weights(self, weights: Iterable[Tuple[str,
torch.Tensor]]) -> Set[str]:
stacked_params_mapping = [
# (param_name, shard_name, shard_id)
(".qkv_proj", ".q_proj", "q"),
(".qkv_proj", ".k_proj", "k"),
(".qkv_proj", ".v_proj", "v"),
(".gate_up_proj", ".gate_proj", 0),
(".gate_up_proj", ".up_proj", 1),
# (".gate_up_proj", ".gate_proj", 0),
# (".gate_up_proj", ".up_proj", 1),
]
if not self.split_qk_v:
stacked_params_mapping.append((".qkv_proj", ".q_proj", "q"))
stacked_params_mapping.append((".qkv_proj", ".k_proj", "k"))
stacked_params_mapping.append((".qkv_proj", ".v_proj", "v"))

if not self.split_gate_up:
stacked_params_mapping.append((".gate_up_proj", ".gate_proj", 0))
stacked_params_mapping.append((".gate_up_proj", ".up_proj", 1))

params_dict = dict(self.named_parameters())
loaded_params: Set[str] = set()
for name, loaded_weight in weights:
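The effect of trimming stacked_params_mapping: with splitting disabled, checkpoint tensors named q_proj/k_proj/v_proj (or gate_proj/up_proj) get remapped onto the fused parameter and loaded shard by shard; with splitting enabled the names already match real module parameters, so the loop falls through to the default loader. A stripped-down sketch of that dispatch (hypothetical helper, not the actual vLLM loader):

```python
# Hedged illustration of the name routing done in load_weights.
def route_weight(name: str, stacked_params_mapping):
    for fused_name, shard_name, shard_id in stacked_params_mapping:
        if shard_name in name:
            # Fused path: e.g. "...self_attn.k_proj.weight" is written into the
            # "k" shard of "...self_attn.qkv_proj.weight".
            return name.replace(shard_name, fused_name), shard_id
    # Split path: the checkpoint name already matches a module parameter.
    return name, None

fused_mapping = [(".qkv_proj", ".q_proj", "q"),
                 (".qkv_proj", ".k_proj", "k"),
                 (".qkv_proj", ".v_proj", "v")]
name = "model.layers.0.self_attn.k_proj.weight"
print(route_weight(name, fused_mapping))  # -> (...qkv_proj.weight, "k")
print(route_weight(name, []))             # -> (...k_proj.weight, None)
```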