Commit 7486cff: Add the pixtral config.

LaurentMazare committed Sep 30, 2024
1 parent e8ea0d1
Showing 4 changed files with 52 additions and 5 deletions.
15 changes: 12 additions & 3 deletions candle-examples/examples/pixtral/main.rs
```diff
@@ -7,7 +7,7 @@ extern crate accelerate_src;
 use anyhow::{Error as E, Result};
 use clap::Parser;
 
-use candle_transformers::models::pixtral::vision_model::{Config, Model};
+use candle_transformers::models::pixtral::{vision_model, Config};
 
 use candle::{DType, Module};
 use candle_nn::VarBuilder;
@@ -53,6 +53,9 @@ struct Args {
     #[arg(long)]
     tokenizer_file: Option<String>,
 
+    #[arg(long)]
+    config_file: Option<String>,
+
     #[arg(long)]
     weight_files: Option<String>,
 
@@ -126,7 +129,13 @@ fn main() -> Result<()> {
     } else {
         DType::F32
     };
-    let config = Config::pixtral_12b_2409();
+    let config: Config = match args.config_file {
+        Some(config_file) => serde_json::from_slice(&std::fs::read(config_file)?)?,
+        None => {
+            let config_file = repo.get("config.json")?;
+            serde_json::from_slice(&std::fs::read(config_file)?)?
+        }
+    };
     let image = if args.image.ends_with(".safetensors") {
         match candle::safetensors::load(&args.image, &device)?.remove("img") {
             None => anyhow::bail!("no img tensor in {}", args.image),
@@ -144,7 +153,7 @@ fn main() -> Result<()> {
     println!("loaded image with shape {:?}", image);
     let start = std::time::Instant::now();
     let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? };
-    let model = Model::new(&config, vb.pp("vision_tower"))?;
+    let model = vision_model::Model::new(&config.vision_config, vb.pp("vision_tower"))?;
     println!("loaded the model in {:?}", start.elapsed());
     let embs = model.forward(&image)?;
     println!("EMBS\n{embs}");
```
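The example binary now takes an optional --config-file argument and otherwise falls back to the config.json shipped in the hub repo. A minimal, self-contained sketch of that fallback pattern (the Config struct and load_config helper here are illustrative stand-ins, not the crate's actual types):

```rust
use serde::Deserialize;

// Illustrative stand-in for the real pixtral Config.
#[derive(Debug, Deserialize)]
struct Config {
    image_size: usize,
    patch_size: usize,
}

fn load_config(config_file: Option<&str>) -> anyhow::Result<Config> {
    let bytes = match config_file {
        // An explicit --config-file wins.
        Some(path) => std::fs::read(path)?,
        // Otherwise read a local config.json; the real example fetches this
        // file from the hub with repo.get("config.json").
        None => std::fs::read("config.json")?,
    };
    Ok(serde_json::from_slice(&bytes)?)
}
```

With clap's kebab-case convention, the new field should surface on the CLI as a --config-file flag.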
10 changes: 10 additions & 0 deletions candle-transformers/src/models/mistral.rs
```diff
@@ -4,19 +4,29 @@ use candle::{DType, Device, Module, Result, Tensor, D};
 use candle_nn::{Activation, VarBuilder};
 use std::sync::Arc;
 
+fn default_num_attention_heads() -> usize {
+    32
+}
+
 fn default_use_flash_attn() -> bool {
     false
 }
 
+fn default_hidden_act() -> candle_nn::Activation {
+    candle_nn::Activation::Silu
+}
+
 #[derive(Debug, Clone, PartialEq, serde::Deserialize)]
 pub struct Config {
     pub vocab_size: usize,
     pub hidden_size: usize,
     pub intermediate_size: usize,
     pub num_hidden_layers: usize,
+    #[serde(default = "default_num_attention_heads")]
     pub num_attention_heads: usize,
     pub head_dim: Option<usize>,
     pub num_key_value_heads: usize,
+    #[serde(default = "default_hidden_act")]
     pub hidden_act: Activation,
     pub max_position_embeddings: usize,
     pub rms_norm_eps: f64,
```
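The #[serde(default = "...")] attributes make num_attention_heads and hidden_act optional in the JSON, presumably because pixtral's text_config omits those keys; absent keys now deserialize to the Mistral defaults. A quick sketch of the mechanism with an illustrative struct:

```rust
use serde::Deserialize;

fn default_num_attention_heads() -> usize {
    32
}

#[derive(Debug, Deserialize)]
struct TextConfig {
    hidden_size: usize,
    // Filled in with 32 whenever the key is absent from the JSON.
    #[serde(default = "default_num_attention_heads")]
    num_attention_heads: usize,
}

fn main() -> serde_json::Result<()> {
    // `num_attention_heads` is deliberately omitted here.
    let cfg: TextConfig = serde_json::from_str(r#"{"hidden_size": 5120}"#)?;
    assert_eq!(cfg.num_attention_heads, 32);
    Ok(())
}
```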
2 changes: 1 addition & 1 deletion candle-transformers/src/models/pixtral/llava.rs
```diff
@@ -9,7 +9,7 @@ pub struct Config {
     pub projector_hidden_act: candle_nn::Activation,
     pub text_config: mistral::Config,
     pub vision_config: vision_model::Config,
-    pub image_token_size: usize,
+    pub image_token_index: usize,
     pub image_seq_length: usize,
 }
 
```
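The rename matches the key used in pixtral config.json files: image_token_index is the vocabulary id of the image placeholder token, not a size. A hedged sketch of how a llava-style pipeline typically uses such a field (expand_image_tokens is a hypothetical helper, not part of this crate):

```rust
// Hypothetical helper: wherever the image placeholder id occurs, reserve
// `image_seq_length` token slots so image embeddings can be spliced in
// at those positions later.
fn expand_image_tokens(tokens: &[u32], image_token_index: u32, image_seq_length: usize) -> Vec<u32> {
    let mut out = Vec::with_capacity(tokens.len() + image_seq_length);
    for &t in tokens {
        if t == image_token_index {
            out.extend(std::iter::repeat(t).take(image_seq_length));
        } else {
            out.push(t);
        }
    }
    out
}
```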
30 changes: 29 additions & 1 deletion candle-transformers/src/models/pixtral/vision_model.rs
```diff
@@ -5,15 +5,41 @@ fn default_act() -> candle_nn::Activation {
     candle_nn::Activation::Gelu
 }
 
+fn default_hidden_size() -> usize {
+    1024
+}
+
+fn default_intermediate_size() -> usize {
+    4096
+}
+
+fn default_num_channels() -> usize {
+    3
+}
+
+fn default_num_hidden_layers() -> usize {
+    24
+}
+
+fn default_num_attention_heads() -> usize {
+    16
+}
+
 #[derive(serde::Deserialize, Debug, Clone)]
 pub struct Config {
+    #[serde(default = "default_hidden_size")]
     pub hidden_size: usize,
+    #[serde(default = "default_num_channels")]
     pub num_channels: usize,
     pub image_size: usize,
     pub patch_size: usize,
     pub rope_theta: f64,
+    #[serde(default = "default_intermediate_size")]
     pub intermediate_size: usize,
+    #[serde(default = "default_num_hidden_layers")]
     pub num_hidden_layers: usize,
+    pub head_dim: Option<usize>,
+    #[serde(default = "default_num_attention_heads")]
     pub num_attention_heads: usize,
     #[serde(default = "default_act")]
     pub hidden_act: candle_nn::Activation,
@@ -30,13 +56,15 @@ impl Config {
             intermediate_size: 4096,
             num_hidden_layers: 24,
             num_attention_heads: 16,
+            head_dim: None,
             // Default
             hidden_act: candle_nn::Activation::Gelu,
         }
     }
 
     fn head_dim(&self) -> usize {
-        self.hidden_size / self.num_attention_heads
+        self.head_dim
+            .unwrap_or(self.hidden_size / self.num_attention_heads)
     }
 }
 
```
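With head_dim now an Option, an explicit head_dim key in config.json takes precedence, while older configs without the key keep the derived hidden_size / num_attention_heads value. A compact sketch of that precedence with an illustrative struct:

```rust
struct VisionConfig {
    hidden_size: usize,
    num_attention_heads: usize,
    head_dim: Option<usize>,
}

impl VisionConfig {
    fn head_dim(&self) -> usize {
        // An explicit value wins; otherwise derive it from the model width.
        self.head_dim
            .unwrap_or(self.hidden_size / self.num_attention_heads)
    }
}

fn main() {
    let derived = VisionConfig { hidden_size: 1024, num_attention_heads: 16, head_dim: None };
    assert_eq!(derived.head_dim(), 64); // 1024 / 16
    let explicit = VisionConfig { head_dim: Some(128), ..derived };
    assert_eq!(explicit.head_dim(), 128);
}
```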
