Skip to content

Commit

Permalink
Bilinear works
Browse files Browse the repository at this point in the history
  • Loading branch information
yuslepukhin committed Feb 14, 2024
1 parent 2cfe280 commit 75042c7
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 230 deletions.
15 changes: 15 additions & 0 deletions onnxruntime/core/providers/cpu/tensor/upsample_antialias.h
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,8 @@ void ComputeInterpolationAtLevel1(int64_t num_channels, int64_t input_height, in
concurrency::ThreadPool* tp) {
const uint8_t* clip8_lookups = &p.GetClip8LookupTable()[640];

std::cout << "L1: ";

concurrency::ThreadPool::TrySimpleParallelFor(
tp, narrow<std::ptrdiff_t>(num_channels),
[&](std::ptrdiff_t c) {
Expand Down Expand Up @@ -286,6 +288,8 @@ void ComputeInterpolationAtLevel1(int64_t num_channels, int64_t input_height, in
output += (*Xdata_offset++) * (*weight_coeff++);
}

std::cout << " " << output;

if constexpr (is_8bit_v<InputType>) {
*Ydata_offset++ = static_cast<InputType>(clip8_lookups[output >> 22]);
} else if constexpr (std::is_same<InputType, int32_t>::value) {
Expand All @@ -296,6 +300,8 @@ void ComputeInterpolationAtLevel1(int64_t num_channels, int64_t input_height, in
}
}
});

std::cout << std::endl;
}

/**
Expand All @@ -322,6 +328,8 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in
const FilterParamsAntiAlias<AccumulateType>& p,
const FilterParamsBaseAntiAlias<AccumulateType>& p_dim,
concurrency::ThreadPool* tp) {
std::cout << "L2: ";

const uint8_t* clip8_lookups = &p.GetClip8LookupTable()[640];
// This condition is set for higher performance.
// Observed that TrySimpleParallelFor in dim num_channels is always have higher efficiency, so I would rather
Expand Down Expand Up @@ -357,6 +365,9 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in
output += *Xdata_offset * (*weight_coeff_start++);
Xdata_offset += output_width;
}

std::cout << " " << output;

if constexpr (is_8bit_v<InputType>) {
*Ydata_offset++ = static_cast<InputType>(clip8_lookups[output >> 22]);
} else if constexpr (std::is_same<InputType, int32_t>::value) {
Expand Down Expand Up @@ -403,6 +414,9 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in
output += *Xdata_offset * (*weight_coeff_start++);
Xdata_offset += output_width;
}

std::cout << " " << output;

if constexpr (is_8bit_v<InputType>) {
*Ydata_offset++ = static_cast<InputType>(clip8_lookups[output >> 22]);
} else if constexpr (std::is_same<InputType, int32_t>::value) {
Expand All @@ -414,6 +428,7 @@ void ComputeInterpolationAtLevel2(int64_t num_channels, int64_t input_height, in
}
});
}
std::cout << std::endl;
}

template <typename InputType, typename AccumulateType>
Expand Down
Loading

0 comments on commit 75042c7

Please sign in to comment.