Skip to content

Commit

Permalink
Fix Dtype
Browse files (browse the repository at this point in the history)
  • Loading branch information
yuslepukhin committed Feb 14, 2024
1 parent 75042c7 commit 0bbd3c4
Showing 1 changed file with 66 additions and 57 deletions.
123 changes: 66 additions & 57 deletions onnxruntime/core/providers/cuda/tensor/resize_antialias_impl.cu
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ __global__ void ComputeInterpolationAtLevel1(
const fast_divmod div_output_width,
int32_t window_size,
bool use_extrapolation, float extrapolation_value,
const uint8_t* clip8_lookups,
const uint8_t* clip8_table,
const int64_t* bound_data,
std::tuple<int64_t*, int64_t*> outof_bounds_buffers,
const AccumType* weight_coefficients,
Expand All @@ -162,30 +162,30 @@ __global__ void ComputeInterpolationAtLevel1(
int output_y, output_x;
div_output_width.divmod(id, output_y, output_x);

// if (use_extrapolation) {
// const auto* w_outof_bounds = std::get<1>(outof_bounds_buffers);
// // Extrapolate along the w dimension
// if (w_outof_bounds[static_cast<ptrdiff_t>(output_x)] != -1) {
// T* Ydata_offset = Ydata + output_y * output_width;
//#ifdef CPU_TESTING
// assert(w_outof_bounds[static_cast<ptrdiff_t>(output_x)] == output_x);
//#endif
// Ydata_offset[static_cast<ptrdiff_t>(output_x)] = static_cast<T>(extrapolation_value);
// return;
// }
//
// // Extrapolate along the y dimension
// const auto* y_outof_bounds = std::get<0>(outof_bounds_buffers);
// if (y_outof_bounds[static_cast<ptrdiff_t>(output_y)] != -1) {
//#ifdef CPU_TESTING
// assert(y_outof_bounds[static_cast<ptrdiff_t>(output_y)] == output_y);
//#endif
// T* Ydata_offset = Ydata + output_y * output_width;
// Ydata_offset[static_cast<ptrdiff_t>(output_x)] = static_cast<T>(extrapolation_value);
// return;
// }
// /// XXX Add z dimension test
// }
// if (use_extrapolation) {
// const auto* w_outof_bounds = std::get<1>(outof_bounds_buffers);
// // Extrapolate along the w dimension
// if (w_outof_bounds[static_cast<ptrdiff_t>(output_x)] != -1) {
// T* Ydata_offset = Ydata + output_y * output_width;
// #ifdef CPU_TESTING
// assert(w_outof_bounds[static_cast<ptrdiff_t>(output_x)] == output_x);
// #endif
// Ydata_offset[static_cast<ptrdiff_t>(output_x)] = static_cast<T>(extrapolation_value);
// return;
// }
//
// // Extrapolate along the y dimension
// const auto* y_outof_bounds = std::get<0>(outof_bounds_buffers);
// if (y_outof_bounds[static_cast<ptrdiff_t>(output_y)] != -1) {
// #ifdef CPU_TESTING
// assert(y_outof_bounds[static_cast<ptrdiff_t>(output_y)] == output_y);
// #endif
// T* Ydata_offset = Ydata + output_y * output_width;
// Ydata_offset[static_cast<ptrdiff_t>(output_x)] = static_cast<T>(extrapolation_value);
// return;
// }
// /// XXX Add z dimension test
// }

auto* Ydata_offset = Ydata + output_width * output_y + output_x;
const auto* bound = bound_data;
Expand All @@ -200,13 +200,18 @@ __global__ void ComputeInterpolationAtLevel1(
const auto* Xdata_offset = Xdata + input_width * output_y + xmin;

for (; xmin < xmax; ++xmin) {
// This cast is needed when we deal with half
output += static_cast<AccumType>((*Xdata_offset++)) * (*weight_coeff++);
if constexpr (std::is_same<T, half>::value) {
// This cast is needed when we deal with half
output += static_cast<AccumType>((*Xdata_offset++)) * (*weight_coeff++);
} else {
output += (*Xdata_offset++) * (*weight_coeff++);
}
}

std::cout << " " << output;
std::cout << output << ' ';

if constexpr (onnxruntime::is_8bit_v<T>) {
const uint8_t* clip8_lookups = &clip8_table[640];
*Ydata_offset = static_cast<T>(clip8_lookups[output >> 22]);
} else if constexpr (std::is_same<T, int32_t>::value) {
*Ydata_offset = static_cast<int32_t>(std::round(output));
Expand All @@ -225,7 +230,7 @@ __global__ void ComputeInterpolationAtLevel2(
const fast_divmod div_output_width,
int32_t window_size,
bool use_extrapolation, float extrapolation_value,
const uint8_t* clip8_lookups,
const uint8_t* clip8_table,
const int64_t* bound_data,
std::tuple<int64_t*, int64_t*> outof_bounds_buffers,
const AccumType* weight_coefficients,
Expand All @@ -243,29 +248,29 @@ __global__ void ComputeInterpolationAtLevel2(
int output_y, output_x;
div_output_width.divmod(id, output_y, output_x);

// if (use_extrapolation) {
// const auto* w_outof_bounds = std::get<1>(outof_bounds_buffers);
// // Extrapolate along the w dimension
// if (w_outof_bounds[static_cast<ptrdiff_t>(output_x)] != -1) {
// T* Ydata_offset = Ydata + output_y * output_width;
//#ifdef CPU_TESTING
// assert(w_outof_bounds[static_cast<ptrdiff_t>(output_x)] == output_x);
//#endif
// Ydata_offset[static_cast<ptrdiff_t>(output_x)] = static_cast<T>(extrapolation_value);
// return;
// }
//
// // Extrapolate along the y dimension
// const auto* y_outof_bounds = std::get<0>(outof_bounds_buffers);
// if (y_outof_bounds[static_cast<ptrdiff_t>(output_y)] != -1) {
//#ifdef CPU_TESTING
// assert(y_outof_bounds[static_cast<ptrdiff_t>(output_y)] == output_y);
//#endif
// T* Ydata_offset = Ydata + output_y * output_width;
// Ydata_offset[static_cast<ptrdiff_t>(output_x)] = static_cast<T>(extrapolation_value);
// return;
// }
// }
// if (use_extrapolation) {
// const auto* w_outof_bounds = std::get<1>(outof_bounds_buffers);
// // Extrapolate along the w dimension
// if (w_outof_bounds[static_cast<ptrdiff_t>(output_x)] != -1) {
// T* Ydata_offset = Ydata + output_y * output_width;
// #ifdef CPU_TESTING
// assert(w_outof_bounds[static_cast<ptrdiff_t>(output_x)] == output_x);
// #endif
// Ydata_offset[static_cast<ptrdiff_t>(output_x)] = static_cast<T>(extrapolation_value);
// return;
// }
//
// // Extrapolate along the y dimension
// const auto* y_outof_bounds = std::get<0>(outof_bounds_buffers);
// if (y_outof_bounds[static_cast<ptrdiff_t>(output_y)] != -1) {
// #ifdef CPU_TESTING
// assert(y_outof_bounds[static_cast<ptrdiff_t>(output_y)] == output_y);
// #endif
// T* Ydata_offset = Ydata + output_y * output_width;
// Ydata_offset[static_cast<ptrdiff_t>(output_x)] = static_cast<T>(extrapolation_value);
// return;
// }
// }

auto* Ydata_offset = Ydata + output_width * output_y + output_x;
const auto* bound = bound_data;
Expand All @@ -279,14 +284,19 @@ __global__ void ComputeInterpolationAtLevel2(
const auto* Xdata_offset = Xdata + ymin * output_width + output_x;

for (; ymin < ymax; ++ymin) {
output += static_cast<AccumType>((*Xdata_offset)) * (*weight_coeff);
if constexpr (std::is_same<T, half>::value) {
// We cast to AccumType to resolve ambiguous call to operator* for half in CUDA
output += static_cast<AccumType>((*Xdata_offset)) * (*weight_coeff++);
} else {
output += (*Xdata_offset) * (*weight_coeff++);
}
Xdata_offset += input_width;
weight_coeff++;
}

std::cout << ", " << output;
std::cout << output << ' ';

if constexpr (onnxruntime::is_8bit_v<T>) {
const uint8_t* clip8_lookups = &clip8_table[640];
*Ydata_offset = static_cast<T>(clip8_lookups[output >> 22]);
} else if constexpr (std::is_same<T, int32_t>::value) {
*Ydata_offset = static_cast<int32_t>(std::round(output));
Expand Down Expand Up @@ -594,7 +604,7 @@ __global__ void _SetupTrilinerarUpsampleFilterAntiAlias(
const auto scale = 1.f / inv_scale;
const float inv_scale_1 = (scale >= 1.0f) ? 1.0f / scale : 1.0f;

std::cout << "Rscale: " << inv_scale_1 << " Scale: " << scale << " Scaled Support: " << scaled_support << " Window Size: " << window_size
std::cout << "\nRscale: " << inv_scale_1 << " Scale: " << scale << " Scaled Support: " << scaled_support << " Window Size: " << window_size
<< " Input Size: " << input_size << " Output Size: " << output_size << " Inv Scale: " << inv_scale
<< " roi_start: " << roi_start << " roi_end " << roi_end;
}
Expand Down Expand Up @@ -855,7 +865,6 @@ void ResizeAntiAliasImpl(

std::cout << std::endl;


CUDA_CALL_THROW(cudaMemcpyAsync(output_data, host_output_buffer.get(),
N * sizeof(T), cudaMemcpyHostToDevice, stream));
#else
Expand Down

0 comments on commit 0bbd3c4

Please sign in to comment.