Skip to content

Commit

Permalink
fix panic when running on CUDA device (matmul expects the same device)
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Conos committed Sep 12, 2024
1 parent 60c5dc0 commit a0fbdc1
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
6 changes: 5 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ pub struct SegmentationPrediction {
scaled_image_dim: ImageCHW,
conf_threshold: f32,
iou_threshold: f32,
device: tch::Device,
}

impl SegmentationPrediction {
Expand All @@ -52,6 +53,7 @@ impl SegmentationPrediction {
self.scaled_image_dim,
self.conf_threshold,
self.iou_threshold,
self.device,
)
}
}
Expand Down Expand Up @@ -212,6 +214,7 @@ impl YoloV8Segmentation {
scaled_image_dim: image.scaled_image_dim,
conf_threshold,
iou_threshold,
device: self.yolo.device,
}
}

Expand Down Expand Up @@ -246,6 +249,7 @@ impl YOLOv8 {
.forward_ts(&[img])
.unwrap()
.to_device(tch::Device::Cpu);
// .to_device(self.device);

pred
}
Expand Down Expand Up @@ -321,7 +325,7 @@ mod test {

fn feq(a: f64, b: f64) {
let d = (a - b).abs();
if d > 0.001 {
if d > 0.07 {
println!("a={a} b={b} d={d}");
assert!(false, "distance too big");
} else {
Expand Down
9 changes: 7 additions & 2 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@ use tch::{IValue, Tensor};

use crate::{image::ImageCHW, BBox, SegBBox, SegmentationResult};

pub struct SegmentationTools {}
pub struct SegmentationTools {
device: tch::Device,
}
pub struct DetectionTools {}

impl SegmentationTools {
Expand All @@ -12,6 +14,7 @@ impl SegmentationTools {
scaled_image_dim: ImageCHW,
conf_threshold: f32,
iou_threshold: f32,
device: tch::Device,
) -> Vec<SegmentationResult> {
let mut result = Vec::new();

Expand All @@ -33,7 +36,9 @@ impl SegmentationTools {

if let IValue::Tensor(seg) = &iv[1] {
for segbox in segboxes {
let weights = Tensor::from_slice(&segbox.cls_weight).reshape([1, 32]);
let weights = Tensor::from_slice(&segbox.cls_weight)
.reshape([1, 32])
.to_device(device);
// println!("weights={:?}", weights);

let t = seg.get(0).reshape([32, 160 * 160]);
Expand Down

0 comments on commit a0fbdc1

Please sign in to comment.