Support embedding model gte-Qwen1.5-7B-instruct (#2190)

* Support embedding model gte-Qwen1.5-7B-instruct

This is a text embedding model based on Qwen2. The two models share the
same architecture except for the final output projection (the embedding
model has no `lm_head`). This commit makes a minimal modification to the
old Qwen2 implementation so that it supports both models.

An example is provided and has been verified against the official
PyTorch implementation.

* Avoid keying the 'last-token filtering' on the absence of an attention mask; the filtering now lives in `ModelForCausalLM::forward`.

---------

Co-authored-by: Laurent <[email protected]>
yinguobing and LaurentMazare authored May 16, 2024
1 parent bdaa342 commit 349c3e8
Showing 4 changed files with 260 additions and 16 deletions.
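
In short, the refactor splits the old Qwen2 `Model` into a bare transformer that returns hidden states for every position and a new `ModelForCausalLM` wrapper that owns the `lm_head` and the last-token narrowing. A minimal sketch of the two resulting call sites, using hypothetical helper names (the types and signatures are taken from the diff below):

```rust
use candle::{Result, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::qwen2::{Config, Model, ModelForCausalLM};

// Embedding path (gte-Qwen): all hidden states, padding mask honored.
fn embed(cfg: &Config, vb: VarBuilder, tokens: &Tensor, mask: &Tensor) -> Result<Tensor> {
    let mut model = Model::new(cfg, vb)?;
    model.forward(tokens, 0, Some(mask)) // (batch, seq_len, hidden_size)
}

// Generation path (qwen example): lm_head plus last-token narrowing.
fn next_token_logits(cfg: &Config, vb: VarBuilder, tokens: &Tensor) -> Result<Tensor> {
    let mut model = ModelForCausalLM::new(cfg, vb)?;
    model.forward(tokens, 0) // (batch, 1, vocab_size)
}
```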
19 changes: 19 additions & 0 deletions candle-examples/examples/gte-qwen/README.md
@@ -0,0 +1,19 @@
# gte-Qwen1.5-7B-instruct

gte-Qwen1.5-7B-instruct is a model in the GTE text embedding family.

- [Model card](https://huggingface.co/Alibaba-NLP/gte-Qwen1.5-7B-instruct) on the HuggingFace Hub.
- [Technical report](https://arxiv.org/abs/2308.03281) *Towards General Text Embeddings with Multi-stage Contrastive Learning*


## Running the example

Automatically download the model from the HuggingFace hub:
```bash
cargo run --example gte-qwen --release
```

or load the model from a local directory:
```bash
cargo run --example gte-qwen --release --features cuda -- --local-repo /path/to/gte_Qwen1.5-7B-instruct/
```
178 changes: 178 additions & 0 deletions candle-examples/examples/gte-qwen/main.rs
@@ -0,0 +1,178 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

#[cfg(feature = "accelerate")]
extern crate accelerate_src;

use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::qwen2::{Config, Model};

use candle::{DType, Tensor};
use candle_nn::VarBuilder;
use hf_hub::{api::sync::Api, Repo, RepoType};
use tokenizers::{
utils::padding::{PaddingDirection, PaddingParams, PaddingStrategy},
Tokenizer,
};

// gte-Qwen1.5-7B-instruct uses the EOS token as its padding token.
const EOS_TOKEN: &str = "<|endoftext|>";
const EOS_TOKEN_ID: u32 = 151643;

#[derive(Parser, Debug)]
#[command(author, version, about, long_about = None)]
struct Args {
/// Run on CPU rather than on GPU.
#[arg(long)]
cpu: bool,

/// Enable tracing (generates a trace-timestamp.json file).
#[arg(long)]
tracing: bool,

#[arg(long, default_value = "Alibaba-NLP/gte-Qwen1.5-7B-instruct")]
model_id: String,

#[arg(long, default_value = "main")]
revision: String,

#[arg(long)]
local_repo: Option<String>,
}

#[derive(Debug)]
struct ConfigFiles {
pub config: std::path::PathBuf,
pub tokenizer: std::path::PathBuf,
pub weights: Vec<std::path::PathBuf>,
}

// Fetch the config, tokenizer and weight files from the HuggingFace Hub. Network access is required.
fn load_from_hub(model_id: &str, revision: &str) -> Result<ConfigFiles> {
let api = Api::new()?;
let repo = api.repo(Repo::with_revision(
model_id.to_string(),
RepoType::Model,
revision.to_string(),
));
Ok(ConfigFiles {
config: repo.get("config.json")?,
tokenizer: repo.get("tokenizer.json")?,
weights: candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?,
})
}

// Resolve the config, tokenizer and weight files from a local directory.
fn load_from_local(local_path: &str) -> Result<ConfigFiles> {
let local_path = std::path::PathBuf::from(local_path);
let weight_path = local_path.join("model.safetensors.index.json");
let json: serde_json::Value = serde_json::from_str(&std::fs::read_to_string(weight_path)?)?;
let weight_map = match json.get("weight_map") {
Some(serde_json::Value::Object(map)) => map,
Some(_) => panic!("`weight_map` is not a map"),
None => panic!("`weight_map` not found"),
};
let mut safetensors_files = std::collections::HashSet::new();
for value in weight_map.values() {
safetensors_files.insert(
value
.as_str()
.expect("Weight files should be parsed as strings"),
);
}
let safetensors_paths = safetensors_files
.iter()
.map(|v| local_path.join(v))
.collect::<Vec<_>>();
Ok(ConfigFiles {
config: local_path.join("config.json"),
tokenizer: local_path.join("tokenizer.json"),
weights: safetensors_paths,
})
}

fn main() -> Result<()> {
use tracing_chrome::ChromeLayerBuilder;
use tracing_subscriber::prelude::*;

let args = Args::parse();
let _guard = if args.tracing {
let (chrome_layer, guard) = ChromeLayerBuilder::new().build();
tracing_subscriber::registry().with(chrome_layer).init();
Some(guard)
} else {
None
};

// Fetch the model files, or resolve them locally if a path is provided.
println!("Fetching model files...");
let start = std::time::Instant::now();
let config_files = match args.local_repo {
Some(local_path) => load_from_local(&local_path)?,
None => load_from_hub(&args.model_id, &args.revision)?,
};
println!("Model file retrieved in {:?}", start.elapsed());

// Inputs will be padded to the longest sequence in the batch.
let padding = PaddingParams {
strategy: PaddingStrategy::BatchLongest,
direction: PaddingDirection::Left,
pad_to_multiple_of: None,
pad_id: EOS_TOKEN_ID,
pad_type_id: 0,
pad_token: String::from(EOS_TOKEN),
};

// Tokenizer setup
let mut tokenizer = Tokenizer::from_file(config_files.tokenizer).map_err(E::msg)?;
tokenizer.with_padding(Some(padding));

// Model initialization
let device = candle_examples::device(args.cpu)?;
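// Run in BF16 on CUDA, F32 elsewhere.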
let dtype = if device.is_cuda() {
DType::BF16
} else {
DType::F32
};
let config: Config = serde_json::from_slice(&std::fs::read(config_files.config)?)?;
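// Memory-mapping the weight files is unsafe because the underlying files must not change while mapped.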
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&config_files.weights, dtype, &device)? };
let mut model = Model::new(&config, vb)?;
println!("Model loaded in {:?}", start.elapsed());

// Encode the queries and the targets
let instruct = "Instruct: Given a web search query, retrieve relevant passages that answer the query\nQuery: ";
let documents = vec![
format!("{instruct}how much protein should a female eat{EOS_TOKEN}"),
format!("{instruct}summit define{EOS_TOKEN}"),
format!("As a general guideline, the CDC's average requirement of protein for women ages 19 to 70 is 46 grams per day. But, as you can see from this chart, you'll need to increase that if you're expecting or training for a marathon. Check out the chart below to see how much protein you should be eating each day.{EOS_TOKEN}"),
format!("Definition of summit for English Language Learners. : 1 the highest point of a mountain : the top of a mountain. : 2 the highest level. : 3 a meeting or series of meetings between the leaders of two or more governments.{EOS_TOKEN}"),
];
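// Tokenize the whole batch; left padding aligns each sequence's last real token at the final position.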
let encoded = tokenizer.encode_batch(documents, true).map_err(E::msg)?;
let tokens: Vec<&[u32]> = encoded.iter().map(|x| x.get_ids()).collect();
let tokens = Tensor::new(tokens, &device)?;
let mask: Vec<&[u32]> = encoded.iter().map(|x| x.get_attention_mask()).collect();
let mask = Tensor::new(mask, &device)?;

// Inference
let start_gen = std::time::Instant::now();
let hidden_states = model.forward(&tokens, 0, Some(&mask))?;

// Use the hidden state at the last position as each sequence's embedding;
// left padding guarantees that position holds a real token.
let (_, seq_len, _) = hidden_states.dims3()?;
let embd = hidden_states
.narrow(1, seq_len - 1, 1)?
.squeeze(1)?
.to_dtype(DType::F32)?;

// Calculate the similarity scores. Note that the embeddings must be L2-normalized first.
let norm = embd.broadcast_div(&embd.sqr()?.sum_keepdim(1)?.sqrt()?)?;
let scores = norm.narrow(0, 0, 2)?.matmul(&norm.narrow(0, 2, 2)?.t()?)?;

// Print the results
println!("Embedding done in {:?}", start_gen.elapsed());
println!("Scores: {:?}", scores.to_vec2::<f32>()?);

Ok(())
}
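
As an aside, the normalize-then-matmul step above is ordinary cosine similarity. A minimal standalone sketch of the same computation on toy data (assumed 3-dimensional vectors, not the model's real embeddings):

```rust
use candle::{Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // Toy "embeddings": two query rows and two document rows.
    let queries = Tensor::new(&[[1f32, 0., 0.], [0., 1., 0.]], &device)?;
    let docs = Tensor::new(&[[1f32, 1., 0.], [0., 0., 1.]], &device)?;
    // L2-normalize each row, exactly as in the example above.
    let qn = queries.broadcast_div(&queries.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    let dn = docs.broadcast_div(&docs.sqr()?.sum_keepdim(1)?.sqrt()?)?;
    // Dot products of unit vectors are cosine similarities.
    let scores = qn.matmul(&dn.t()?)?;
    println!("{:?}", scores.to_vec2::<f32>()?);
    Ok(())
}
```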
2 changes: 1 addition & 1 deletion candle-examples/examples/qwen/main.rs
@@ -7,7 +7,7 @@ extern crate accelerate_src;
use anyhow::{Error as E, Result};
use clap::Parser;

use candle_transformers::models::qwen2::{Config as ConfigBase, Model as ModelBase};
use candle_transformers::models::qwen2::{Config as ConfigBase, ModelForCausalLM as ModelBase};
use candle_transformers::models::qwen2_moe::{Config as ConfigMoe, Model as ModelMoe};

use candle::{DType, Device, Tensor};
77 changes: 62 additions & 15 deletions candle-transformers/src/models/qwen2.rs
@@ -1,5 +1,5 @@
use crate::models::with_tracing::{linear, linear_no_bias, Linear, RmsNorm};
use candle::{DType, Device, Module, Result, Tensor, D};
use candle::{DType, Device, IndexOp, Module, Result, Tensor, D};
use candle_nn::{Activation, VarBuilder};
use std::sync::Arc;

@@ -250,7 +250,6 @@ pub struct Model {
embed_tokens: candle_nn::Embedding,
layers: Vec<DecoderLayer>,
norm: RmsNorm,
lm_head: Linear,
sliding_window: usize,
device: Device,
dtype: DType,
@@ -269,19 +268,17 @@ impl Model {
layers.push(layer)
}
let norm = RmsNorm::new(cfg.hidden_size, cfg.rms_norm_eps, vb_m.pp("norm"))?;
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
Ok(Self {
embed_tokens,
layers,
norm,
lm_head,
sliding_window: cfg.sliding_window,
device: vb.device().clone(),
dtype: vb.dtype(),
})
}

fn prepare_decoder_attention_mask(
fn prepare_causal_attention_mask(
&self,
b_size: usize,
tgt_len: usize,
@@ -301,7 +298,7 @@
.collect();
let mask = Tensor::from_slice(&mask, (tgt_len, tgt_len), &self.device)?;
let mask = if seqlen_offset > 0 {
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), DType::F32, &self.device)?;
let mask0 = Tensor::zeros((tgt_len, seqlen_offset), self.dtype, &self.device)?;
Tensor::cat(&[&mask0, &mask], D::Minus1)?
} else {
mask
@@ -310,21 +307,42 @@
.to_dtype(self.dtype)
}

pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
fn prepare_attention_mask(&self, attn_mask: &Tensor) -> Result<Tensor> {
let (b_sz, seq_len) = attn_mask.dims2()?;
let mut mask: Vec<Tensor> = vec![];
for b in 0..b_sz {
mask.push(attn_mask.i((b, ..))?.expand((1, 1, seq_len, seq_len))?);
}
let mask = Tensor::cat(&mask, 0)?;
let on_true = mask.zeros_like()?.to_dtype(self.dtype)?;
let on_false = Tensor::new(f32::NEG_INFINITY, &self.device)?
.broadcast_as(mask.shape())?
.to_dtype(self.dtype)?;
mask.where_cond(&on_true, &on_false)
}

pub fn forward(
&mut self,
input_ids: &Tensor,
seqlen_offset: usize,
attn_mask: Option<&Tensor>,
) -> Result<Tensor> {
let (b_size, seq_len) = input_ids.dims2()?;
let attention_mask = if seq_len <= 1 {
None
} else {
let mask = self.prepare_decoder_attention_mask(b_size, seq_len, seqlen_offset)?;
Some(mask)
let attention_mask: Option<Tensor> = match attn_mask {
Some(mask) => Some(self.prepare_attention_mask(mask)?),
None => {
if seq_len <= 1 {
None
} else {
Some(self.prepare_causal_attention_mask(b_size, seq_len, seqlen_offset)?)
}
}
};
let mut xs = self.embed_tokens.forward(input_ids)?;
for layer in self.layers.iter_mut() {
xs = layer.forward(&xs, attention_mask.as_ref(), seqlen_offset)?
}
xs.narrow(1, seq_len - 1, 1)?
.apply(&self.norm)?
.apply(&self.lm_head)
xs.apply(&self.norm)
}

pub fn clear_kv_cache(&mut self) {
@@ -333,3 +351,32 @@
}
}
}

#[derive(Debug, Clone)]
pub struct ModelForCausalLM {
base_model: Model,
lm_head: Linear,
}

impl ModelForCausalLM {
pub fn new(cfg: &Config, vb: VarBuilder) -> Result<Self> {
let lm_head = linear_no_bias(cfg.hidden_size, cfg.vocab_size, vb.pp("lm_head"))?;
let base_model = Model::new(cfg, vb)?;
Ok(Self {
base_model,
lm_head,
})
}

pub fn forward(&mut self, input_ids: &Tensor, seqlen_offset: usize) -> Result<Tensor> {
let (_b_size, seq_len) = input_ids.dims2()?;
self.base_model
.forward(input_ids, seqlen_offset, None)?
.narrow(1, seq_len - 1, 1)?
.apply(&self.lm_head)
}

pub fn clear_kv_cache(&mut self) {
self.base_model.clear_kv_cache()
}
}
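
For reference, the new `prepare_attention_mask` turns a 0/1 padding mask into an additive bias: entries of 1 become 0.0 and entries of 0 become negative infinity, so padded positions vanish after the softmax. A minimal standalone illustration of that `where_cond` pattern on a toy mask (CPU, F32; not the batched 4-D expansion above):

```rust
use candle::{DType, Device, Result, Tensor};

fn main() -> Result<()> {
    let device = Device::Cpu;
    // A left-padded batch: 0 marks padding, 1 marks real tokens.
    let attn_mask = Tensor::new(&[[0u32, 1, 1], [1, 1, 1]], &device)?;
    let on_true = attn_mask.zeros_like()?.to_dtype(DType::F32)?;
    let on_false = Tensor::new(f32::NEG_INFINITY, &device)?
        .broadcast_as(attn_mask.shape())?
        .to_dtype(DType::F32)?;
    // Non-zero mask entries select from on_true (0.0); zeros select -inf.
    let bias = attn_mask.where_cond(&on_true, &on_false)?;
    println!("{:?}", bias.to_vec2::<f32>()?);
    Ok(())
}
```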
