Skip to content

Commit

Permalink
Merge pull request #413 from robertknight/detr-finetunes
Browse files Browse the repository at this point in the history
Support other DETR-based models in the DETR example
  • Loading branch information
robertknight authored Nov 24, 2024
2 parents 6117201 + f932f20 commit 39dc4c7
Showing 1 changed file with 46 additions and 2 deletions.
48 changes: 46 additions & 2 deletions rten-examples/src/detr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ struct Args {
annotated_image: Option<String>,
min_size: Option<u32>,
max_size: Option<u32>,
labels: Option<Vec<String>>,
threshold: f32,
}

fn parse_args() -> Result<Args, lexopt::Error> {
Expand All @@ -21,8 +23,10 @@ fn parse_args() -> Result<Args, lexopt::Error> {
let mut values = VecDeque::new();
let mut parser = lexopt::Parser::from_env();
let mut annotated_image = None;
let mut labels = None;
let mut min_size = None;
let mut max_size = None;
let mut threshold = 0.5;

while let Some(arg) = parser.next()? {
match arg {
Expand All @@ -39,6 +43,11 @@ Options:
Annotate image with bounding boxes and save to <path>
--labels <string>
Comma-separated list of class labels. For Hugging Face models see the
\"id2label\" map in the model's config.json file.
--max <length>
Set the maximum length of the longest side of the model's input. Defaults to
Expand All @@ -51,6 +60,10 @@ Options:
Set the minimum length of the shortest side of the model's input. Defaults to
`800 / 1333` * `<longest side or 1333>`.
--threshold <value>
Probability threshold for object detections.
",
bin_name = parser.bin_name().unwrap_or("detr")
);
Expand All @@ -59,12 +72,24 @@ Options:
Long("annotate") => {
annotated_image = Some(parser.value()?.string()?);
}
Long("labels") => {
let label_list = parser.value()?.string()?;
labels = Some(
label_list
.split(',')
.map(|lb| lb.trim().to_string())
.collect(),
);
}
Long("max") => {
max_size = Some(parser.value()?.parse()?);
}
Long("min") => {
min_size = Some(parser.value()?.parse()?);
}
Long("threshold") => {
threshold = parser.value()?.parse::<f32>()?.clamp(0., 1.);
}
_ => return Err(arg.unexpected()),
}
}
Expand All @@ -76,8 +101,10 @@ Options:
model,
image,
annotated_image,
labels,
min_size,
max_size,
threshold,
};

Ok(args)
Expand Down Expand Up @@ -160,7 +187,11 @@ fn rescaled_size(

// Labels obtained from `id2label` map in
// https://huggingface.co/facebook/detr-resnet-50/blob/main/config.json.
const LABELS: &[&str] = &[
//
// The original model uses class 0 to represent no object. Some other
// DETR-based models use the maximum class ID to represent "no object" instead.
const DEFAULT_LABELS: &[&str] = &[
"N/A",
"person",
"bicycle",
"car",
Expand Down Expand Up @@ -253,6 +284,9 @@ const LABELS: &[&str] = &[
"toothbrush",
];

/// Labels which represent no detection.
const NO_OBJECT_LABELS: &[&str] = &["N/A"];

/// Detect objects in images using DETR [^1].
///
/// The DETR model [^2] can be obtained from Hugging Face and converted to this
Expand Down Expand Up @@ -343,13 +377,23 @@ fn main() -> Result<(), Box<dyn Error>> {
painter.set_stroke_width(stroke_width);
}

let labels = args
.labels
.unwrap_or(DEFAULT_LABELS.iter().map(|lb| lb.to_string()).collect());

for obj in 0..n_objects {
let cls = classes[[0, obj]] as usize;
let prob = probs[[0, obj, cls]];

let Some(label) = LABELS.get(cls - 1) else {
let Some(label) = labels.get(cls) else {
continue;
};
if NO_OBJECT_LABELS.contains(&label.as_str()) {
continue;
}
if prob < args.threshold {
continue;
}

let [center_x, center_y, width, height] = boxes.slice([0, obj]).to_array();

Expand Down

0 comments on commit 39dc4c7

Please sign in to comment.