From 668348ed6456d027b915b4a7b463aa5c7c896f09 Mon Sep 17 00:00:00 2001 From: Nicolas Hug Date: Mon, 30 Oct 2023 08:01:28 +0000 Subject: [PATCH] PSRoiPool: SymInt support + meta-implem (#8062) --- torchvision/_meta_registrations.py | 34 +++++++++++ .../csrc/ops/autograd/ps_roi_pool_kernel.cpp | 56 +++++++++---------- torchvision/csrc/ops/ps_roi_pool.cpp | 45 ++++++++++++++- torchvision/csrc/ops/ps_roi_pool.h | 19 +++++++ 4 files changed, 124 insertions(+), 30 deletions(-) diff --git a/torchvision/_meta_registrations.py b/torchvision/_meta_registrations.py index 82ef7751f18..15513e538f5 100644 --- a/torchvision/_meta_registrations.py +++ b/torchvision/_meta_registrations.py @@ -126,6 +126,40 @@ def meta_roi_pool_backward( return grad.new_empty((batch_size, channels, height, width)) +@register_meta("ps_roi_pool") +def meta_ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width): + torch._check(rois.size(1) == 5, lambda: "rois must have shape as Tensor[K, 5]") + torch._check( + input.dtype == rois.dtype, + lambda: ( + "Expected tensor for input to have the same type as tensor for rois; " + f"but type {input.dtype} does not equal {rois.dtype}" + ), + ) + channels = input.size(1) + torch._check( + channels % (pooled_height * pooled_width) == 0, + "input channels must be a multiple of pooling height * pooling width", + ) + num_rois = rois.size(0) + out_size = (num_rois, channels // (pooled_height * pooled_width), pooled_height, pooled_width) + return input.new_empty(out_size), torch.empty(out_size, device="meta", dtype=torch.int32) + + +@register_meta("_ps_roi_pool_backward") +def meta_ps_roi_pool_backward( + grad, rois, channel_mapping, spatial_scale, pooled_height, pooled_width, batch_size, channels, height, width +): + torch._check( + grad.dtype == rois.dtype, + lambda: ( + "Expected tensor for grad to have the same type as tensor for rois; " + f"but type {grad.dtype} does not equal {rois.dtype}" + ), + ) + return grad.new_empty((batch_size, channels, height, width)) + + @torch._custom_ops.impl_abstract("torchvision::nms") def meta_nms(dets, scores, iou_threshold): torch._check(dets.dim() == 2, lambda: f"boxes should be a 2d tensor, got {dets.dim()}D") diff --git a/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp b/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp index ddc37262382..39b83819f94 100644 --- a/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp +++ b/torchvision/csrc/ops/autograd/ps_roi_pool_kernel.cpp @@ -15,15 +15,15 @@ class PSROIPoolFunction : public torch::autograd::Function { const torch::autograd::Variable& input, const torch::autograd::Variable& rois, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { + c10::SymInt pooled_height, + c10::SymInt pooled_width) { ctx->saved_data["spatial_scale"] = spatial_scale; ctx->saved_data["pooled_height"] = pooled_height; ctx->saved_data["pooled_width"] = pooled_width; - ctx->saved_data["input_shape"] = input.sizes(); + ctx->saved_data["input_shape"] = input.sym_sizes(); at::AutoDispatchBelowADInplaceOrView g; - auto result = - ps_roi_pool(input, rois, spatial_scale, pooled_height, pooled_width); + auto result = ps_roi_pool_symint( + input, rois, spatial_scale, pooled_height, pooled_width); auto output = std::get<0>(result); auto channel_mapping = std::get<1>(result); @@ -40,18 +40,18 @@ class PSROIPoolFunction : public torch::autograd::Function { auto saved = ctx->get_saved_variables(); auto rois = saved[0]; auto channel_mapping = saved[1]; - auto input_shape = 
ctx->saved_data["input_shape"].toIntList(); - auto grad_in = detail::_ps_roi_pool_backward( + auto input_shape = ctx->saved_data["input_shape"].toList(); + auto grad_in = detail::_ps_roi_pool_backward_symint( grad_output[0], rois, channel_mapping, ctx->saved_data["spatial_scale"].toDouble(), - ctx->saved_data["pooled_height"].toInt(), - ctx->saved_data["pooled_width"].toInt(), - input_shape[0], - input_shape[1], - input_shape[2], - input_shape[3]); + ctx->saved_data["pooled_height"].toSymInt(), + ctx->saved_data["pooled_width"].toSymInt(), + input_shape[0].get().toSymInt(), + input_shape[1].get().toSymInt(), + input_shape[2].get().toSymInt(), + input_shape[3].get().toSymInt()); return { grad_in, @@ -72,14 +72,14 @@ class PSROIPoolBackwardFunction const torch::autograd::Variable& rois, const torch::autograd::Variable& channel_mapping, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { at::AutoDispatchBelowADInplaceOrView g; - auto grad_in = detail::_ps_roi_pool_backward( + auto grad_in = detail::_ps_roi_pool_backward_symint( grad, rois, channel_mapping, @@ -105,8 +105,8 @@ std::tuple ps_roi_pool_autograd( const at::Tensor& input, const at::Tensor& rois, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width) { + c10::SymInt pooled_height, + c10::SymInt pooled_width) { auto result = PSROIPoolFunction::apply( input, rois, spatial_scale, pooled_height, pooled_width); @@ -118,12 +118,12 @@ at::Tensor ps_roi_pool_backward_autograd( const at::Tensor& rois, const at::Tensor& channel_mapping, double spatial_scale, - int64_t pooled_height, - int64_t pooled_width, - int64_t batch_size, - int64_t channels, - int64_t height, - int64_t width) { + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { return PSROIPoolBackwardFunction::apply( grad, rois, diff --git a/torchvision/csrc/ops/ps_roi_pool.cpp b/torchvision/csrc/ops/ps_roi_pool.cpp index c9f64661033..92469d5e380 100644 --- a/torchvision/csrc/ops/ps_roi_pool.cpp +++ b/torchvision/csrc/ops/ps_roi_pool.cpp @@ -20,6 +20,19 @@ std::tuple ps_roi_pool( return op.call(input, rois, spatial_scale, pooled_height, pooled_width); } +std::tuple ps_roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width) { + C10_LOG_API_USAGE_ONCE("torchvision.csrc.ops.ps_roi_pool.ps_roi_pool"); + static auto op = c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::ps_roi_pool", "") + .typed(); + return op.call(input, rois, spatial_scale, pooled_height, pooled_width); +} + namespace detail { at::Tensor _ps_roi_pool_backward( @@ -50,13 +63,41 @@ at::Tensor _ps_roi_pool_backward( width); } +at::Tensor _ps_roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width) { + static auto op = + c10::Dispatcher::singleton() + .findSchemaOrThrow("torchvision::_ps_roi_pool_backward", "") + .typed(); + return op.call( + grad, + rois, + channel_mapping, + spatial_scale, + pooled_height, + pooled_width, + 
batch_size, + channels, + height, + width); +} + } // namespace detail TORCH_LIBRARY_FRAGMENT(torchvision, m) { m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, int pooled_height, int pooled_width) -> (Tensor, Tensor)")); + "torchvision::ps_roi_pool(Tensor input, Tensor rois, float spatial_scale, SymInt pooled_height, SymInt pooled_width) -> (Tensor, Tensor)")); m.def(TORCH_SELECTIVE_SCHEMA( - "torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, int pooled_height, int pooled_width, int batch_size, int channels, int height, int width) -> Tensor")); + "torchvision::_ps_roi_pool_backward(Tensor grad, Tensor rois, Tensor channel_mapping, float spatial_scale, SymInt pooled_height, SymInt pooled_width, SymInt batch_size, SymInt channels, SymInt height, SymInt width) -> Tensor")); } } // namespace ops diff --git a/torchvision/csrc/ops/ps_roi_pool.h b/torchvision/csrc/ops/ps_roi_pool.h index 20c2511e7aa..4a3cc54e0e5 100644 --- a/torchvision/csrc/ops/ps_roi_pool.h +++ b/torchvision/csrc/ops/ps_roi_pool.h @@ -13,6 +13,13 @@ VISION_API std::tuple ps_roi_pool( int64_t pooled_height, int64_t pooled_width); +VISION_API std::tuple ps_roi_pool_symint( + const at::Tensor& input, + const at::Tensor& rois, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width); + namespace detail { at::Tensor _ps_roi_pool_backward( @@ -27,6 +34,18 @@ at::Tensor _ps_roi_pool_backward( int64_t height, int64_t width); +at::Tensor _ps_roi_pool_backward_symint( + const at::Tensor& grad, + const at::Tensor& rois, + const at::Tensor& channel_mapping, + double spatial_scale, + c10::SymInt pooled_height, + c10::SymInt pooled_width, + c10::SymInt batch_size, + c10::SymInt channels, + c10::SymInt height, + c10::SymInt width); + } // namespace detail } // namespace ops
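
Illustrative usage (not part of the patch): with the meta kernel registered above and the
schema arguments switched to SymInt, torchvision::ps_roi_pool can be traced with fake
tensors and compiled with dynamic shapes instead of graph-breaking. A minimal sketch,
assuming a torchvision build that includes this change; ps_roi_pool is the existing public
wrapper in torchvision.ops, and the tensor shapes and ROI values below are arbitrary
example inputs, not anything prescribed by the patch.

    import torch
    from torchvision.ops import ps_roi_pool

    def f(x, rois):
        # position-sensitive pooling: 16 input channels / (2 * 2) bins
        # -> 4 output channels per ROI
        return ps_roi_pool(x, rois, output_size=(2, 2), spatial_scale=1.0)

    x = torch.rand(1, 16, 32, 32)
    # each ROI is [batch_index, x1, y1, x2, y2]
    rois = torch.tensor([[0.0, 0.0, 0.0, 16.0, 16.0]])

    compiled = torch.compile(f, dynamic=True)
    out = compiled(x, rois)
    print(out.shape)  # expected: torch.Size([1, 4, 2, 2])

The meta implementation only computes output shapes (the forward returns an empty output
tensor plus an int32 channel_mapping on the meta device), which is what fake-tensor tracing
needs; the real kernels are still dispatched at runtime.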