diff --git a/colossalai/inference/core/engine.py b/colossalai/inference/core/engine.py
index d0d46d81bc0f..a1b54fa1c89a 100644
--- a/colossalai/inference/core/engine.py
+++ b/colossalai/inference/core/engine.py
@@ -466,6 +466,7 @@ def steps_spec_dec(self) -> List[Sequence]:
                 self.k_cache[-1],  # use kv cahces of the last layer
                 self.v_cache[-1],
                 batch.get_sequence_lengths(),
+                n_spec_tokens=self.n_spec_tokens,
             )
 
             drafter_out = self.drafter.speculate(
diff --git a/colossalai/inference/modeling/models/glide_llama.py b/colossalai/inference/modeling/models/glide_llama.py
index 013b0f06185d..0ee78a303004 100644
--- a/colossalai/inference/modeling/models/glide_llama.py
+++ b/colossalai/inference/modeling/models/glide_llama.py
@@ -319,7 +319,8 @@ def forward(
         query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)
 
         # for RoPE
-        cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
+        position_ids = position_ids + glide_input.n_spec_tokens
+        cos, sin = self.rotary_emb(query_states, position_ids)
         query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
         query_states = query_states.transpose(1, 2)
         query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
diff --git a/colossalai/inference/spec/struct.py b/colossalai/inference/spec/struct.py
index 143f26d09a59..9b52437db2ab 100644
--- a/colossalai/inference/spec/struct.py
+++ b/colossalai/inference/spec/struct.py
@@ -46,6 +46,7 @@ class GlideInput:
     large_k_cache: torch.Tensor = None
     large_v_cache: torch.Tensor = None
     sequence_lengths: torch.Tensor = None
+    n_spec_tokens: int = 5
 
     @property
     def glimpse_ready(self):
diff --git a/examples/inference/llama/README.md b/examples/inference/llama/README.md
index cde81a41d839..dae7f771cb44 100644
--- a/examples/inference/llama/README.md
+++ b/examples/inference/llama/README.md
@@ -43,5 +43,8 @@ colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --drafter_mo
 
 If you want to try the GLIDE model (glide-vicuna7b) as the drafter model with vicuna-7B, you could provide the GLIDE model path or model card as drafter model and enable the feature by
 ```python
+from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM
+drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name)
+...
 engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
 ```