Skip to content

Commit

Permalink
correct dtype is u8
Browse files Browse the repository at this point in the history
  • Loading branch information
Michal Conos committed Sep 12, 2024
1 parent 48d88a9 commit b65fa7d
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,7 @@ pub fn preprocess(image: &Tensor, square_size: i64) -> Tensor {
let (_, height, width) = image.size3().unwrap();
let (uw, uh) = square64(square_size, width, height);
let scaled_image = if height == uh && width == uw {
let out = Tensor::ones(image.size(), (Kind::Int, Device::Cpu));
let out = Tensor::ones(image.size(), (Kind::Uint8, Device::Cpu));
image.clone(&out)
} else {
tch::vision::image::resize(&image, uw, uh).expect("can't resize image")
Expand Down Expand Up @@ -355,10 +355,10 @@ mod test {

#[test]
fn test_preprocess() {
let image = Tensor::from_slice(&[
let image_data: [u8; 24] = [
1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28,
])
.reshape([3, 2, 4]);
];
let image = Tensor::from_slice(&image_data).reshape([3, 2, 4]);
let bg = preprocess(&image, 4);
let expect_bg = Tensor::from_slice(&[
114, 114, 114, 114, 1, 2, 3, 4, 5, 6, 7, 8, 114, 114, 114, 114, 114, 114, 114, 114, 11,
Expand Down

0 comments on commit b65fa7d

Please sign in to comment.