diff --git a/src/utils.rs b/src/utils.rs index 207bcfb..8a27007 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -317,7 +317,7 @@ pub fn preprocess(image: &Tensor, square_size: i64) -> Tensor { let (_, height, width) = image.size3().unwrap(); let (uw, uh) = square64(square_size, width, height); let scaled_image = if height == uh && width == uw { - let out = Tensor::ones(image.size(), (Kind::Int, Device::Cpu)); + let out = Tensor::ones(image.size(), (Kind::Uint8, Device::Cpu)); image.clone(&out) } else { tch::vision::image::resize(&image, uw, uh).expect("can't resize image") @@ -355,10 +355,10 @@ mod test { #[test] fn test_preprocess() { - let image = Tensor::from_slice(&[ + let image_data: [u8; 24] = [ 1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, - ]) - .reshape([3, 2, 4]); + ]; + let image = Tensor::from_slice(&image_data).reshape([3, 2, 4]); let bg = preprocess(&image, 4); let expect_bg = Tensor::from_slice(&[ 114, 114, 114, 114, 1, 2, 3, 4, 5, 6, 7, 8, 114, 114, 114, 114, 114, 114, 114, 114, 11,