PaliGemma inference loop.
LaurentMazare committed Sep 29, 2024
1 parent dc2ac98 commit f2a1672
Showing 3 changed files with 42 additions and 3 deletions.
2 changes: 1 addition & 1 deletion candle-examples/examples/paligemma/main.rs
@@ -76,7 +76,7 @@ impl TextGeneration {
             let start_pos = tokens.len().saturating_sub(context_size);
             let ctxt = &tokens[start_pos..];
             let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?;
-            let logits = self.model.forward(&input, start_pos)?;
+            let logits = self.model.forward(&input)?;
             let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
             let logits = if self.repeat_penalty == 1. {
                 logits
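Note: the `start_pos` argument disappears because the PaliGemma model now tracks its decoding position itself (see the `pos` field added in paligemma.rs below). Before entering this loop, the model is expected to be primed once with the image and the prompt through the new `setup` method; a minimal sketch of that priming call, where `pixel_values` and `prompt_tokens` are assumed to be prepared elsewhere in the example:

    // Sketch only: `pixel_values` (preprocessed image tensor) and `prompt_tokens`
    // (tokenized prompt of shape (1, prompt_len)) are assumptions, not part of this hunk.
    // `setup` embeds image + prompt, fills the KV cache and returns the first logits.
    let logits = self.model.setup(&pixel_values, &prompt_tokens)?;

A fuller driver sketch follows the paligemma.rs diff below.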
20 changes: 20 additions & 0 deletions candle-transformers/src/models/gemma.rs
@@ -362,6 +362,10 @@ impl Model {
         })
     }
 
+    pub fn embed_tokens(&self) -> &candle_nn::Embedding {
+        &self.embed_tokens
+    }
+
     fn prepare_decoder_attention_mask(
         &self,
         b_size: usize,
@@ -400,6 +404,22 @@ impl Model {
             .apply(&self.lm_head)
     }
 
+    pub fn forward_embeds(
+        &mut self,
+        xs: &Tensor,
+        attn_mask: Option<&Tensor>,
+        seqlen_offset: usize,
+    ) -> Result<Tensor> {
+        let (_, seq_len, _) = xs.dims3()?;
+        let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
+        for layer in self.layers.iter_mut() {
+            xs = layer.forward(&xs, attn_mask, seqlen_offset)?
+        }
+        xs.narrow(1, seq_len - 1, 1)?
+            .apply(&self.norm)?
+            .apply(&self.lm_head)
+    }
+
     pub fn clear_kv_cache(&mut self) {
         for layer in self.layers.iter_mut() {
             layer.clear_kv_cache()
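Note: unlike `forward`, which looks up token embeddings from ids, `forward_embeds` starts from caller-supplied embeddings, so a prefix of non-text embeddings (e.g. projected image patches) can be spliced in front of the prompt. A hedged sketch of the intended call pattern, assuming the repo's `candle`/`candle-core` crate alias; the helper name and tensor shapes are illustrative, not part of this commit:

    use candle::{Module, Result, Tensor};
    use candle_transformers::models::gemma;

    // Illustrative helper: run Gemma over a sequence whose leading positions come
    // from another encoder (e.g. an image projector) rather than from token ids.
    fn prefix_forward(
        model: &mut gemma::Model,
        prefix_embeds: &Tensor, // (batch, n_prefix, hidden_size)
        input_ids: &Tensor,     // (batch, seq_len) token ids
    ) -> Result<Tensor> {
        let text_embeds = model.embed_tokens().forward(input_ids)?;
        let input_embeds = Tensor::cat(&[prefix_embeds, &text_embeds], 1)?;
        // No attention mask and offset 0: the whole prefix + prompt is a fresh sequence.
        model.forward_embeds(&input_embeds, None, 0)
    }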
23 changes: 21 additions & 2 deletions candle-transformers/src/models/paligemma.rs
@@ -40,6 +40,7 @@ impl Module for MultiModalProjector {
 
 #[derive(Clone, Debug)]
 pub struct Model {
+    pos: usize,
     vision_tower: siglip::VisionModel,
     multi_modal_projector: MultiModalProjector,
     language_model: gemma::Model,
@@ -52,17 +53,35 @@ impl Model {
         let multi_modal_projector = MultiModalProjector::new(cfg, vb.pp("multi_modal_projector"))?;
         let language_model = gemma::Model::new(false, &cfg.text_config, vb.pp("language_model"))?;
         Ok(Self {
+            pos: 0,
             language_model,
             vision_tower,
             multi_modal_projector,
         })
     }
 
-    pub fn forward(&mut self, _input_ids: &Tensor, _pos: usize) -> Result<Tensor> {
-        todo!()
+    pub fn setup(&mut self, pixel_values: &Tensor, input_ids: &Tensor) -> Result<Tensor> {
+        self.clear_kv_cache();
+        let image_features = self
+            .vision_tower
+            .forward(pixel_values)?
+            .apply(&self.multi_modal_projector)?;
+        let image_features = crate::models::clip::div_l2_norm(&image_features)?;
+        let text_features = self.language_model.embed_tokens().forward(input_ids)?;
+        let input_embeds = Tensor::cat(&[image_features, text_features], 1)?;
+        self.pos = input_embeds.dim(1)?;
+        self.language_model.forward_embeds(&input_embeds, None, 0)
     }
 
+    pub fn forward(&mut self, input_ids: &Tensor) -> Result<Tensor> {
+        let pos = self.pos;
+        let seq_len = input_ids.dim(1)?;
+        self.pos = pos + seq_len;
+        self.language_model.forward(input_ids, pos)
+    }
+
     pub fn clear_kv_cache(&mut self) {
+        self.pos = 0;
         self.language_model.clear_kv_cache()
     }
 }
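Note: together these changes give PaliGemma a two-phase inference loop: one `setup` call that embeds the image with the prompt and fills the KV cache, then one `forward` call per generated token while the model advances `pos` internally (`clear_kv_cache` resets both). A minimal driver sketch; `pixel_values`, `prompt_tokens`, `sample_len`, `eos_token` and the greedy `LogitsProcessor` are assumptions about the surrounding example code, not part of this commit:

    use candle::{DType, Device, Result, Tensor};
    use candle_transformers::generation::LogitsProcessor;
    use candle_transformers::models::paligemma;

    fn generate(
        model: &mut paligemma::Model,
        pixel_values: &Tensor, // (1, channels, height, width), already preprocessed
        prompt_tokens: &[u32], // tokenized prompt
        device: &Device,
        sample_len: usize,
        eos_token: u32,
    ) -> Result<Vec<u32>> {
        // Greedy sampling: no temperature, no top-p.
        let mut logits_processor = LogitsProcessor::new(299792458, None, None);
        let prompt = Tensor::new(prompt_tokens, device)?.unsqueeze(0)?;
        // Phase 1: embed image + prompt once; `setup` clears the KV cache itself
        // and returns the logits for the first generated token.
        let logits = model.setup(pixel_values, &prompt)?;
        let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
        let mut next_token = logits_processor.sample(&logits)?;
        let mut tokens = vec![next_token];
        // Phase 2: feed back one token at a time; the model tracks its own position.
        for _ in 1..sample_len {
            if next_token == eos_token {
                break;
            }
            let input = Tensor::new(&tokens[tokens.len() - 1..], device)?.unsqueeze(0)?;
            let logits = model.forward(&input)?;
            let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?;
            next_token = logits_processor.sample(&logits)?;
            tokens.push(next_token);
        }
        Ok(tokens)
    }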
