diff --git a/onnxruntime/core/providers/cpu/tensor/affine_grid.cc b/onnxruntime/core/providers/cpu/tensor/affine_grid.cc index 12ef99ab75e55..df25a7a233685 100644 --- a/onnxruntime/core/providers/cpu/tensor/affine_grid.cc +++ b/onnxruntime/core/providers/cpu/tensor/affine_grid.cc @@ -71,44 +71,39 @@ void generate_base_grid_3d(int64_t D, int64_t H, int64_t W, bool align_corners, } template -struct AffineGridGenerator2D { - void operator()(const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t H, int64_t W, Tensor* grid) const { - const Eigen::StorageOptions option = Eigen::RowMajor; - auto theta_batch_offset = batch_num * 2 * 3; - const T* theta_data = theta->Data() + theta_batch_offset; - const Eigen::Matrix theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}}; - const Eigen::Array theta_T(theta_data[2], theta_data[5]); - - auto grid_batch_offset = batch_num * H * W * 2; - T* grid_data = grid->MutableData() + grid_batch_offset; - Eigen::Map> grid_matrix(grid_data, narrow(H * W), 2); - grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); - } -}; +void affine_grid_generator_2d (const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t H, int64_t W, Tensor* grid) { + const Eigen::StorageOptions option = Eigen::RowMajor; + auto theta_batch_offset = batch_num * 2 * 3; + const T* theta_data = theta->Data() + theta_batch_offset; + const Eigen::Matrix theta_R{{theta_data[0], theta_data[1]}, {theta_data[3], theta_data[4]}}; + const Eigen::Array theta_T(theta_data[2], theta_data[5]); + + auto grid_batch_offset = batch_num * H * W * 2; + T* grid_data = grid->MutableData() + grid_batch_offset; + Eigen::Map> grid_matrix(grid_data, narrow(H * W), 2); + grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); +} template -struct AffineGridGenerator3D { - void operator()(const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t D, int64_t H, int64_t W, Tensor* grid) { - const Eigen::StorageOptions option = Eigen::RowMajor; - auto theta_batch_offset = batch_num * 3 * 4; - const T* theta_data = theta->Data() + theta_batch_offset; - const Eigen::Matrix theta_R{ - {theta_data[0], theta_data[1], theta_data[2]}, - {theta_data[4], theta_data[5], theta_data[6]}, - {theta_data[8], theta_data[9], theta_data[10]}}; - const Eigen::Array theta_T(theta_data[3], theta_data[7], theta_data[11]); - - auto grid_batch_offset = batch_num * D * H * W * 3; - T* grid_data = grid->MutableData() + grid_batch_offset; - Eigen::Map> grid_matrix(grid_data, narrow(D * H * W), 3); - grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); - } -}; +void affine_grid_generator_3d (const Tensor* theta, const Eigen::Matrix& base_grid_transposed, int64_t batch_num, int64_t D, int64_t H, int64_t W, Tensor* grid) { + const Eigen::StorageOptions option = Eigen::RowMajor; + auto theta_batch_offset = batch_num * 3 * 4; + const T* theta_data = theta->Data() + theta_batch_offset; + const Eigen::Matrix theta_R{ + {theta_data[0], theta_data[1], theta_data[2]}, + {theta_data[4], theta_data[5], theta_data[6]}, + {theta_data[8], theta_data[9], theta_data[10]}}; + const Eigen::Array theta_T(theta_data[3], theta_data[7], theta_data[11]); + + auto grid_batch_offset = batch_num * D * H * W * 3; + T* grid_data = grid->MutableData() + grid_batch_offset; + Eigen::Map> grid_matrix(grid_data, narrow(D * H * W), 3); + grid_matrix = ((theta_R * base_grid_transposed).array().colwise() + theta_T).matrix().transpose(); +} template Status AffineGrid::Compute(OpKernelContext* context) const { const Tensor* theta = context->Input(0); - const auto elem_type = theta->GetElementType(); const TensorShape& theta_shape = theta->Shape(); if (theta_shape.NumDimensions() != 3) { return ORT_MAKE_STATUS(ONNXRUNTIME, INVALID_ARGUMENT, "AffineGrid : Input theta tensor dimension is not 3"); @@ -128,9 +123,8 @@ Status AffineGrid::Compute(OpKernelContext* context) const { generate_base_grid_2d(H, W, align_corners_, base_grid); Eigen::Matrix base_grid_transposed = base_grid.transpose(); - std::function fn = [elem_type, theta, base_grid_transposed, H, W, grid](ptrdiff_t batch_num) { - utils::MLTypeCallDispatcher t_disp(elem_type); - t_disp.Invoke(theta, base_grid_transposed, batch_num, H, W, grid); + std::function fn = [theta, base_grid_transposed, H, W, grid](ptrdiff_t batch_num) { + affine_grid_generator_2d(theta, base_grid_transposed, batch_num, H, W, grid); }; concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(N), std::move(fn), 0); @@ -144,9 +138,8 @@ Status AffineGrid::Compute(OpKernelContext* context) const { generate_base_grid_3d(D, H, W, align_corners_, base_grid); Eigen::Matrix base_grid_transposed = base_grid.transpose(); - std::function fn = [elem_type, theta, base_grid_transposed, D, H, W, grid](ptrdiff_t batch_num) { - utils::MLTypeCallDispatcher t_disp(elem_type); - t_disp.Invoke(theta, base_grid_transposed, batch_num, D, H, W, grid); + std::function fn = [theta, base_grid_transposed, D, H, W, grid](ptrdiff_t batch_num) { + affine_grid_generator_3d(theta, base_grid_transposed, batch_num, D, H, W, grid); }; concurrency::ThreadPool::TryBatchParallelFor(context->GetOperatorThreadPool(), narrow(N), std::move(fn), 0);