//! main.rs — CLIP example: encodes a fixed list of candidate texts and the
//! images under `./examples/clip/images`, then prints the best-scoring text
//! for each image.
use usls::{models::Clip, DataLoader, Options};
/// Runs the CLIP example end to end.
///
/// Builds a `Clip` model from a visual and a textual ONNX file, encodes a
/// fixed list of candidate texts once, then iterates over the batches yielded
/// by the `DataLoader` for `./examples/clip/images`. For every image it prints
/// the highest-scoring text together with the full row of scores.
///
/// # Errors
///
/// Returns an error when building the options, the model or the data loader
/// fails, when encoding texts or images fails, when the text features carry no
/// embedding, when `dot2` fails, or when a score row is empty.
fn main() -> Result<(), Box<dyn std::error::Error>> {
    // visual
    let options_visual = Options::default().with_model("clip/visual-base-dyn.onnx")?;

    // textual
    let options_textual = Options::default()
        .with_model("clip/textual-base-dyn.onnx")?
        .with_tokenizer("clip/tokenizer.json")?;

    // build model
    let mut model = Clip::new(options_visual, options_textual)?;

    // texts
    let texts = vec![
        "A photo of a dinosaur ".to_string(),
        "A photo of a cat".to_string(),
        "A photo of a dog".to_string(),
        "几个胡萝卜".to_string(),
        "There are some playing cards on a striped table cloth".to_string(),
        "There is a doll with red hair and a clock on a table".to_string(),
        "Some people holding wine glasses in a restaurant".to_string(),
    ];
    let feats_text = model.encode_texts(&texts)?; // [n, ndim]

    // load image
    let dl = DataLoader::new("./examples/clip/images")?.build()?;

    // loop
    for (images, paths) in dl {
        // Propagate encoding failures like every other fallible call here,
        // instead of panicking in the middle of the run.
        let feats_image = model.encode_images(&images)?;

        // use image to query texts; a batch without an image embedding is skipped
        let matrix = match feats_image.embedding() {
            Some(x) => {
                // A missing text embedding is reported as an error, not a panic.
                let text_embedding = feats_text
                    .embedding()
                    .ok_or("text features contain no embedding")?;
                x.dot2(text_embedding)?
            }
            None => continue,
        };

        // summary: pair each image path with its row of scores
        for (path, probs) in paths.iter().zip(matrix.iter()) {
            // Keep the first maximum on ties; an empty row becomes an error.
            let (id, &score) = probs
                .iter()
                .enumerate()
                .reduce(|max, x| if x.1 > max.1 { x } else { max })
                .ok_or("similarity row is empty")?;
            println!(
                "({:?}%) {} => {} ",
                score * 100.0,
                path.display(),
                &texts[id]
            );
            println!("{:?}\n", probs);
        }
    }
    Ok(())
}