[Fix] Fix spec-dec Glide LlamaModel for compatibility with transformers (#5837)

* fix glide llama model

* revise
yuanheng-zhao authored Jun 19, 2024
1 parent fd1dc41 commit 7b249c7
Showing 4 changed files with 7 additions and 1 deletion.
1 change: 1 addition & 0 deletions colossalai/inference/core/engine.py
@@ -466,6 +466,7 @@ def steps_spec_dec(self) -> List[Sequence]:
self.k_cache[-1], # use kv caches of the last layer
self.v_cache[-1],
batch.get_sequence_lengths(),
+ n_spec_tokens=self.n_spec_tokens,
)

drafter_out = self.drafter.speculate(
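The added keyword forwards the engine's speculation length into the glimpse input consumed by the GLIDE drafter. Below is a minimal, hypothetical sketch of that call site, assuming the object built here is the `GlideInput` dataclass from `colossalai/inference/spec/struct.py` (the constructor line itself sits above the shown hunk, the import path is inferred from the file location, dummy tensors stand in for `self.k_cache[-1]` / `self.v_cache[-1]`, and the real class may carry more fields than those visible in this diff):

```python
# Hypothetical sketch only: import path and tensor shapes are assumptions,
# and the tensors are placeholders for the engine's last-layer KV caches.
import torch

from colossalai.inference.spec.struct import GlideInput

k_cache_last_layer = torch.zeros(2, 8, 16, 128)  # stand-in for self.k_cache[-1]
v_cache_last_layer = torch.zeros(2, 8, 16, 128)  # stand-in for self.v_cache[-1]
sequence_lengths = torch.tensor([5, 9], dtype=torch.int32)

glide_input = GlideInput(
    large_k_cache=k_cache_last_layer,
    large_v_cache=v_cache_last_layer,
    sequence_lengths=sequence_lengths,
    n_spec_tokens=5,  # newly forwarded from the engine's self.n_spec_tokens
)
```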
3 changes: 2 additions & 1 deletion colossalai/inference/modeling/models/glide_llama.py
@@ -319,7 +319,8 @@ def forward(
query_states = query_states.view(bsz, -1, self.large_num_heads, self.large_head_dim).transpose(1, 2)

# for RoPE
- cos, sin = self.rotary_emb(query_states, seq_len=kv_seq_len + 32)
+ position_ids = position_ids + glide_input.n_spec_tokens
+ cos, sin = self.rotary_emb(query_states, position_ids)
query_states = apply_single_rotary_pos_emb(query_states, cos, sin, position_ids)
query_states = query_states.transpose(1, 2)
query_states = query_states.reshape(-1, self.large_num_heads, self.large_head_dim)
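This replaced call is the crux of the compatibility fix: newer transformers releases changed `LlamaRotaryEmbedding.forward` to take explicit `position_ids` (returning per-position cos/sin) instead of a `seq_len`, so the GLIDE cross-attention now offsets the positions by `glide_input.n_spec_tokens` and passes them to `self.rotary_emb` directly, rather than using the old `seq_len=kv_seq_len + 32` margin. The snippet below is a self-contained illustration of that position-ids-driven RoPE pattern; the helpers `rope_cos_sin` and `apply_rope` are made up for the sketch and are not the ColossalAI `apply_single_rotary_pos_emb` or the transformers implementation.

```python
# Self-contained sketch of RoPE driven by explicit position_ids; illustrative only.
import torch


def rope_cos_sin(position_ids: torch.Tensor, head_dim: int, base: float = 10000.0):
    """Per-position cos/sin tables, shape [bsz, seq_len, head_dim]."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2).float() / head_dim))
    freqs = position_ids[..., None].float() * inv_freq  # [bsz, seq_len, head_dim/2]
    emb = torch.cat((freqs, freqs), dim=-1)
    return emb.cos(), emb.sin()


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def apply_rope(q: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    # q: [bsz, num_heads, seq_len, head_dim]; unsqueeze cos/sin to broadcast over heads
    return q * cos.unsqueeze(1) + rotate_half(q) * sin.unsqueeze(1)


bsz, num_heads, seq_len, head_dim, n_spec_tokens = 1, 4, 3, 64, 5
query_states = torch.randn(bsz, num_heads, seq_len, head_dim)
position_ids = torch.arange(seq_len).unsqueeze(0)  # [bsz, seq_len]

# Mirror of the fix: shift the query positions by the number of speculative
# tokens, then compute cos/sin from the shifted positions.
position_ids = position_ids + n_spec_tokens
cos, sin = rope_cos_sin(position_ids, head_dim)
query_states = apply_rope(query_states, cos, sin)
print(query_states.shape)  # torch.Size([1, 4, 3, 64])
```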
1 change: 1 addition & 0 deletions colossalai/inference/spec/struct.py
@@ -46,6 +46,7 @@ class GlideInput:
large_k_cache: torch.Tensor = None
large_v_cache: torch.Tensor = None
sequence_lengths: torch.Tensor = None
+ n_spec_tokens: int = 5

@property
def glimpse_ready(self):
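Because `GlideInput` is a plain dataclass with defaults, giving the new field a default of 5 keeps every existing constructor call working while letting the engine override it. A minimal sketch of that pattern follows; only the fields visible in this hunk are reproduced (the real class may carry more), and the `glimpse_ready` body is an assumption rather than the actual implementation.

```python
# Sketch of the GlideInput pattern; field list is limited to what this hunk shows
# and the glimpse_ready body is assumed, not copied from the real class.
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class GlideInputSketch:
    large_k_cache: Optional[torch.Tensor] = None
    large_v_cache: Optional[torch.Tensor] = None
    sequence_lengths: Optional[torch.Tensor] = None
    n_spec_tokens: int = 5  # new field; the default keeps older call sites valid

    @property
    def glimpse_ready(self) -> bool:
        # Assumed meaning: the drafter can "glimpse" into the large model's KV
        # caches once the caches and the sequence lengths are all provided.
        return all(
            t is not None
            for t in (self.large_k_cache, self.large_v_cache, self.sequence_lengths)
        )


inp = GlideInputSketch(sequence_lengths=torch.tensor([4], dtype=torch.int32))
print(inp.glimpse_ready, inp.n_spec_tokens)  # False 5
```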
3 changes: 3 additions & 0 deletions examples/inference/llama/README.md
@@ -43,5 +43,8 @@ colossalai run --nproc_per_node 2 llama_generation.py -m PATH_MODEL --drafter_mo

If you want to try the GLIDE model (glide-vicuna7b) as the drafter model with vicuna-7B, you can pass the GLIDE model path or model card as the drafter model and enable the feature by
```python
+ from colossalai.inference.modeling.models.glide_llama import GlideLlamaForCausalLM
+ drafter_model = GlideLlamaForCausalLM.from_pretrained(drafter_model_path_or_name)
+ ...
engine.enable_spec_dec(drafter_model, use_glide_drafter=True)
```
