diff --git a/candle-examples/examples/pixtral/main.rs b/candle-examples/examples/pixtral/main.rs index 6e26d30128..2118e47b13 100644 --- a/candle-examples/examples/pixtral/main.rs +++ b/candle-examples/examples/pixtral/main.rs @@ -90,7 +90,6 @@ impl TextGeneration { let image_embeds = self.model.vision_tower.forward(&self.image)?; let image_embeds = self.model.multi_modal_projector.forward(&image_embeds)?; println!("generated image embeddings {image_embeds:?}"); - let image_embeds = image_embeds.to_dtype(self.model.dtype)?; for &t in tokens.iter() { if let Some(t) = self.tokenizer.next_token(t)? { print!("{t}") @@ -279,7 +278,7 @@ fn main() -> Result<()> { let device = candle_examples::device(args.cpu)?; let dtype = if device.supports_bf16() && !args.vision_only { - DType::BF16 + DType::F32 } else { DType::F32 }; @@ -303,10 +302,7 @@ fn main() -> Result<()> { &[0.26862954, 0.261_302_6, 0.275_777_1], )? }; - let image = image - .to_device(&device)? - .to_dtype(DType::F32)? - .unsqueeze(0)?; + let image = image.to_device(&device)?.to_dtype(dtype)?.unsqueeze(0)?; println!("loaded image with shape {:?}", image); let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };