Use graph optimizer for gpu tensor prepack #19814

Closed · wants to merge 5 commits
Changes from 3 commits
4 changes: 4 additions & 0 deletions cmake/onnxruntime_optimizer.cmake
@@ -109,6 +109,10 @@ onnxruntime_add_static_library(onnxruntime_optimizer ${onnxruntime_optimizer_src

onnxruntime_add_include_to_target(onnxruntime_optimizer onnxruntime_common onnxruntime_framework onnx onnx_proto ${PROTOBUF_LIB} flatbuffers::flatbuffers Boost::mp11 safeint_interface)
target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT})

# the optimizer is used for CUDA prepacking, so extra headers are needed
target_include_directories(onnxruntime_optimizer PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)

if (onnxruntime_ENABLE_TRAINING)
target_include_directories(onnxruntime_optimizer PRIVATE ${ORTTRAINING_ROOT})
if (onnxruntime_ENABLE_TRAINING_TORCH_INTEROP)
5 changes: 5 additions & 0 deletions cmake/onnxruntime_providers_cuda.cmake
@@ -213,6 +213,7 @@

include(cutlass)
target_include_directories(${target} PRIVATE ${cutlass_SOURCE_DIR}/include ${cutlass_SOURCE_DIR}/examples ${cutlass_SOURCE_DIR}/tools/util/include)
target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT}/core/mickey)

target_include_directories(${target} PRIVATE ${ONNXRUNTIME_ROOT} ${CMAKE_CURRENT_BINARY_DIR} ${eigen_INCLUDE_DIRS} ${TVM_INCLUDES}
PUBLIC ${CUDAToolkit_INCLUDE_DIRS})
@@ -284,6 +285,10 @@
endif()
config_cuda_provider_shared_module(onnxruntime_providers_cuda)

# TODO: these flags are only needed in Debug builds; need CMake advice on how to scope them to that configuration
set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cu PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ")
set_source_files_properties(${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cc PROPERTIES COMPILE_FLAGS " -Wno-unknown-pragmas ")

install(TARGETS onnxruntime_providers_cuda
ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
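A possible way to resolve the TODO above without waiting on a CMake expert is to scope the warning suppression to Debug with a generator expression on the source-file `COMPILE_OPTIONS` property. This is only a sketch, not part of this change, and assumes a CMake version new enough to evaluate generator expressions in source-file properties:

```cmake
# Sketch only: apply -Wno-unknown-pragmas to these two sources in Debug builds alone,
# instead of unconditionally, via a $<CONFIG:Debug> generator expression.
set_source_files_properties(
  ${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cu
  ${ONNXRUNTIME_ROOT}/contrib_ops/cuda/quantization/matmul_nbits.cc
  PROPERTIES COMPILE_OPTIONS "$<$<CONFIG:Debug>:-Wno-unknown-pragmas>")
```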
5 changes: 3 additions & 2 deletions include/onnxruntime/core/optimizer/graph_transformer_utils.h
@@ -22,6 +22,7 @@

namespace onnxruntime {
class IExecutionProvider;
class ExecutionProviders;

namespace optimizer_utils {

@@ -48,7 +49,7 @@ std::unique_ptr<RuleBasedGraphTransformer> GenerateRuleBasedGraphTransformer(
InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
TransformerLevel level,
const SessionOptions& session_options,
const IExecutionProvider& execution_provider /*required by constant folding*/,
const ExecutionProviders& execution_providers /* cpu ep required by constant folding*/,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});

#endif // !defined(ORT_MINIMAL_BUILD)
@@ -77,7 +78,7 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformersForMinimalB
TransformerLevel level,
const SessionOptions& session_options,
const SatApplyContextVariant& apply_context,
const IExecutionProvider& cpu_execution_provider,
const ExecutionProviders& execution_providers,
const InlinedHashSet<std::string>& rules_and_transformers_to_disable = {});

#endif // !defined(ORT_MINIMAL_BUILD) || defined(ORT_EXTENDED_MINIMAL_BUILD)
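For readers skimming the header change: GenerateTransformers now receives the whole ExecutionProviders collection instead of a single CPU IExecutionProvider, so transformers that need a specific provider (constant folding needs the CPU EP; the new prepacking path needs the CUDA EP) can look it up themselves. A rough sketch of that lookup, where the accessor name and registration key are assumptions for illustration rather than verified against this PR:

```cpp
// Hypothetical sketch: inside a transformer that received `execution_providers`,
// pull out the CPU EP for constant folding. The Get() accessor and the
// kCpuExecutionProvider key are assumed here, not confirmed by this diff.
const IExecutionProvider* cpu_ep = execution_providers.Get(kCpuExecutionProvider);
ORT_ENFORCE(cpu_ep != nullptr, "Constant folding requires the CPU execution provider");
```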
50 changes: 41 additions & 9 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -14,7 +14,31 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;

template<>
Status MatMulNBits<MLFloat16>::PrepackedGemm(
cudaStream_t stream,
const Tensor* a,
const Tensor* b,
const Tensor* scales,
const Tensor* zero_points,
Tensor* Y) const {
int64_t M = a->Shape()[0];
uint8_t const* zero_points_ptr = nullptr;
size_t zero_points_size = 0;
if (zero_points != nullptr) {
zero_points_ptr = zero_points->Data<uint8_t>();
zero_points_size = zero_points->Shape().Size();
}

return blkq4_fp16_gemm_sm80_dispatch<MLFloat16>(
static_cast<int>(block_size_), column_wise_quant_blk_, static_cast<int>(M), static_cast<int>(N_), static_cast<int>(K_), stream,
a->Data<MLFloat16>(), a->Shape().Size(),
b->Data<uint8_t>(), b->Shape().Size(),
scales->Data<MLFloat16>(), scales->Shape().Size(),
zero_points_ptr, zero_points_size,
Y->MutableData<MLFloat16>(), Y->Shape().Size());
}

template <typename T>
Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
@@ -24,14 +48,6 @@
const Tensor* zero_points = ctx->Input<Tensor>(3);
const Tensor* reorder_idx = ctx->Input<Tensor>(4);

const auto* a_data = a->Data<T>();
const uint8_t* blob_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<T>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();

typedef typename ToCudaType<T>::MappedType CudaT;

constexpr bool transa = false;
constexpr bool transb = true;
MatMulComputeHelper helper;
@@ -43,6 +59,22 @@
// Bail out early if the output is going to be empty
if (Y->Shape().Size() == 0) return Status::OK();

if (prepack_ > 0) {
ORT_RETURN_IF(reorder_idx != nullptr,
"Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!");
ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType<T>(),
"Internal Error: Prepacked gemm does not support zero points of type T. Fix the prepacking logic!");
return PrepackedGemm(
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()),
a, b, scales, zero_points, Y);
}

const auto* a_data = a->Data<T>();
const uint8_t* blob_data = b->Data<uint8_t>();
const auto* scales_data = scales->Data<T>();
const auto* zero_points_data = zero_points == nullptr ? nullptr : zero_points->DataRaw();
const auto* reorder_idx_data = reorder_idx == nullptr ? nullptr : reorder_idx->Data<int32_t>();

bool is_4bit_done = (reorder_idx_data == nullptr) &&
(!zero_points || !zero_points->IsDataType<T>()) &&
TryMatMul4Bits(
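The prepacked branch added above enforces its requirements as internal errors before dispatching. A small illustrative predicate (the helper name is hypothetical, not part of the PR) restating those requirements:

```cpp
// Hypothetical restatement of the guards in ComputeInternal: the prepacked SM80
// GEMM requires prepacking to have run, no reorder index input, and zero points
// (if present) stored as packed uint8 rather than as type T.
template <typename T>
static bool PrepackedGemmPreconditionsHold(int64_t prepack, const Tensor* reorder_idx,
                                           const Tensor* zero_points) {
  return prepack > 0 &&
         reorder_idx == nullptr &&
         (zero_points == nullptr || !zero_points->IsDataType<T>());
}
```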
223 changes: 221 additions & 2 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cu
@@ -10,8 +10,7 @@
#include "core/providers/cuda/cuda_common.h"
#include "matmul_nbits.cuh"

using namespace onnxruntime::cuda;
using namespace cub;
#include "blk_q4/f16_gemm_sm80.h"

namespace onnxruntime {
namespace contrib {
@@ -348,6 +347,226 @@
int shared_mem_per_block,
cudaStream_t stream);

/**
* @brief Helper function to run the 4-bit quantized GEMM kernel on SM80.
*        Only supports fp16 for now.
*/
template<
typename ElementT,
int block_size,
bool column_wise_blocking,
bool small_m,
bool has_offsets>
Status blkq4_gemm_sm80(int m, int n, int k, cudaStream_t stream,
// [review] Contributor: I assume this can be supported by sm86, sm89?
// [review] Author: it should
gsl::span<ElementT const> a,
gsl::span<uint8_t const> weights,
gsl::span<ElementT const> scales,
gsl::span<uint8_t const> offsets,
gsl::span<ElementT> output) {
static_assert(std::is_same<ElementT, half>::value
|| std::is_same<ElementT, MLFloat16>::value
|| std::is_same<ElementT, cutlass::half_t>::value,
"Only support fp16 for now");
using ElementDequant = cutlass::half_t;
using QuantBlocking =
typename std::conditional<column_wise_blocking,
cutlass::MatrixShape<block_size, 1>,
cutlass::MatrixShape<1, block_size>>::type;

using GemmRunner = onnxruntime::cuda::BlkQ4F16GemmImpl<ElementDequant, QuantBlocking, small_m, has_offsets>;

using ElementAccumulator = typename GemmRunner::ElementAccumulator;
using ElementComputeEpilogue = typename GemmRunner::ElementComputeEpilogue;
using ElementOutput = typename GemmRunner::ElementOutput;
using ElementW = typename GemmRunner::ElementW;
using ElementWPack = typename GemmRunner::ElementWPack;
using ElementQScale = typename GemmRunner::ElementQScale;
using ElementQOffset = typename GemmRunner::ElementQOffset;

using LayoutInputA = typename GemmRunner::LayoutInputA;
using LayoutOutput = typename GemmRunner::LayoutOutput;
using LayoutInputWPack = typename GemmRunner::LayoutInputWPack;
using LayoutInputQScale = typename GemmRunner::LayoutInputQScale;

const cutlass::gemm::GemmCoord problem_size = {m, n, k};

ORT_RETURN_IF_NOT(a.size_bytes() == m * k * sizeof(ElementDequant), "Activation tensor size is not correct");
cutlass::TensorRef<ElementDequant const, LayoutInputA> ref_a(
reinterpret_cast<ElementDequant const *>(a.data()),
LayoutInputA::packed({m, k}));

ORT_RETURN_IF_NOT(weights.size_bytes() == k/2 * n/2 * sizeof(ElementWPack), "weights size is not correct");
cutlass::TensorRef<ElementWPack const, LayoutInputWPack> ref_W(
reinterpret_cast<ElementWPack const *>(weights.data()),
LayoutInputWPack::packed({k/2, n/2}));

ORT_RETURN_IF_NOT(scales.size_bytes() == (k/QuantBlocking::kRow) * (n/QuantBlocking::kColumn) * sizeof(ElementQScale),
"scales size is not correct");
cutlass::TensorRef<ElementQScale const, LayoutInputQScale> ref_scales(
reinterpret_cast<ElementQScale const *>(scales.data()),
LayoutInputQScale::packed({k/QuantBlocking::kRow, n/QuantBlocking::kColumn}));

ORT_RETURN_IF_NOT(output.size_bytes() == m * n * sizeof(ElementOutput), "output size is not correct");
cutlass::TensorRef<ElementOutput, LayoutOutput> ref_output(
reinterpret_cast<ElementOutput *>(output.data()),
LayoutOutput::packed({m, n}));

// run GEMM
cutlass::Status status;
if constexpr (has_offsets) {
ORT_RETURN_IF_NOT(
    offsets.size_bytes() == (k/QuantBlocking::kRow) * (n/QuantBlocking::kColumn) * sizeof(ElementQOffset),
    "offsets size is not correct");
cutlass::TensorRef<ElementQOffset const, LayoutInputQScale> ref_offsets(
reinterpret_cast<ElementQOffset const *>(offsets.data()),
LayoutInputQScale::packed({k/QuantBlocking::kRow, n/QuantBlocking::kColumn}));
status = GemmRunner::run(
stream, problem_size, ref_a, ref_W, ref_scales, ref_offsets,
ref_output, ref_output);
} else {
status = GemmRunner::run(
stream, problem_size, ref_a, ref_W, ref_scales,
ref_output, ref_output);
}
ORT_RETURN_IF_NOT(status == cutlass::Status::kSuccess, "Kernel execution failed: ", cutlassGetStatusString(status));
return Status::OK();
}

template<typename ElementT>
Status
blkq4_fp16_gemm_sm80_dispatch(
int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream,
ElementT const* a_ptr, size_t a_size,
uint8_t const* weights_ptr, size_t weights_size,
ElementT const* scales_ptr, size_t scales_size,
uint8_t const* offsets_ptr, size_t offsets_size,
ElementT* output_ptr, size_t output_size) {
auto a = gsl::make_span(a_ptr, a_size);
auto weights = gsl::make_span(weights_ptr, weights_size);
auto scales = gsl::make_span(scales_ptr, scales_size);
auto offsets = gsl::make_span(offsets_ptr, offsets_size);
auto output = gsl::make_span(output_ptr, output_size);

switch (block_size) {
case 16:
if (column_wise_blocking) {
if (m > 16) {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 16, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 16, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
} else {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 16, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 16, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
}
} else {
if (m > 16) {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 16, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 16, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
} else {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 16, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 16, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
}
}
break;

case 32:
if (column_wise_blocking) {
if (m > 16) {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 32, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 32, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
} else {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 32, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 32, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
}
} else {
if (m > 16) {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 32, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 32, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
} else {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 32, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 32, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
}
}
break;

case 64:
// [review] Contributor: Would you support case=128? which is used widely.
// [review] Author: it's difficult for this kernel. I am working on another version which hopefully can support that.

if (column_wise_blocking) {
if (m > 16) {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 64, true, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 64, true, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
} else {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 64, true, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 64, true, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
}
} else {
if (m > 16) {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 64, false, false, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 64, false, false, true>(m, n, k, stream, a, weights, scales, offsets, output);
} else {
if (offsets.empty())
return blkq4_gemm_sm80<ElementT, 64, false, true, false>(m, n, k, stream, a, weights, scales, offsets, output);
else
return blkq4_gemm_sm80<ElementT, 64, false, true, true>(m, n, k, stream, a, weights, scales, offsets, output);
}
}
break;
}

return ORT_MAKE_STATUS(ONNXRUNTIME, FAIL, "Unsupported block size: ", block_size);
}

template
Status blkq4_fp16_gemm_sm80_dispatch<half>(
int block_size,
bool column_wise_blocking,
int m, int n, int k, cudaStream_t stream,
half const* a_ptr, size_t a_size,
uint8_t const* weights_ptr, size_t weights_size,
half const* scales_ptr, size_t scales_size,
uint8_t const* offsets_ptr, size_t offsets_size,
half* output_ptr, size_t output_size);

template
Status blkq4_fp16_gemm_sm80_dispatch<cutlass::half_t>(
int block_size,
bool column_wise_blocking,
int m, int n, int k, cudaStream_t stream,
cutlass::half_t const* a_ptr, size_t a_size,
uint8_t const* weights_ptr, size_t weights_size,
cutlass::half_t const* scales_ptr, size_t scales_size,
uint8_t const* offsets_ptr, size_t offsets_size,
cutlass::half_t* output_ptr, size_t output_size);

template
Status blkq4_fp16_gemm_sm80_dispatch<onnxruntime::MLFloat16>(
int block_size, bool column_wise_blocking, int m, int n, int k, cudaStream_t stream,
onnxruntime::MLFloat16 const* a_ptr, size_t a_size,
uint8_t const* weights_ptr, size_t weights_size,
onnxruntime::MLFloat16 const* scales_ptr, size_t scales_size,
uint8_t const* offsets_ptr, size_t offsets_size,
onnxruntime::MLFloat16* output_ptr, size_t output_size);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
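The long switch in blkq4_fp16_gemm_sm80_dispatch is purely a runtime-to-compile-time bridge: it maps four runtime values onto the template parameters of blkq4_gemm_sm80. An illustrative summary of how those values are derived (a standalone sketch, not a replacement for the dispatch code):

```cpp
#include <cstddef>
#include <tuple>

// block_size           -> 16, 32 or 64; any other value falls through to the FAIL status.
// column_wise_blocking -> quantization blocks run down a column (block_size x 1)
//                         rather than along a row (1 x block_size).
// small_m              -> m <= 16 selects the instantiation tuned for skinny activations.
// has_offsets          -> a non-empty zero-point span selects the has_offsets variants.
std::tuple<int, bool, bool, bool> SelectBlkQ4Variant(int block_size, bool column_wise_blocking,
                                                     int m, std::size_t offsets_size) {
  const bool small_m = (m <= 16);
  const bool has_offsets = (offsets_size != 0);
  return {block_size, column_wise_blocking, small_m, has_offsets};
}
```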
11 changes: 11 additions & 0 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cuh
@@ -22,6 +22,17 @@ bool TryMatMul4Bits(
int shared_mem_per_block,
cudaStream_t stream);

template <typename ElementT>
Status blkq4_fp16_gemm_sm80_dispatch(
int block_size,
bool column_wise_blocking,
int m, int n, int k, cudaStream_t stream,
ElementT const* a_ptr, size_t a_size,
uint8_t const* weights_ptr, size_t weights_size,
ElementT const* scales_ptr, size_t scales_size,
uint8_t const* offsets_ptr, size_t offsets_size,
ElementT* output_ptr, size_t output_size);

} // namespace cuda
} // namespace contrib
} // namespace onnxruntime
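A minimal host-side usage sketch for the entry point declared above, mirroring how MatMulNBits<MLFloat16>::PrepackedGemm calls it: the size arguments are element counts of the corresponding buffers, and all pointers are assumed to be CUDA device memory produced by the prepacking pass (the wrapper name below is hypothetical).

```cpp
// Sketch only: wraps blkq4_fp16_gemm_sm80_dispatch for fp16 inputs. Pointers are
// assumed to be device buffers; sizes are element counts, as in PrepackedGemm.
Status RunPrepackedBlkQ4Fp16Gemm(cudaStream_t stream, int block_size, bool column_wise,
                                 int m, int n, int k,
                                 const MLFloat16* d_a, size_t a_elems,
                                 const uint8_t* d_packed_b, size_t b_elems,
                                 const MLFloat16* d_scales, size_t scales_elems,
                                 const uint8_t* d_zero_points, size_t zp_elems,
                                 MLFloat16* d_y, size_t y_elems) {
  return blkq4_fp16_gemm_sm80_dispatch<MLFloat16>(
      block_size, column_wise, m, n, k, stream,
      d_a, a_elems, d_packed_b, b_elems, d_scales, scales_elems,
      d_zero_points, zp_elems, d_y, y_elems);
}
```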