
Commit

restrict rocm
chenfucn committed Jul 2, 2024
1 parent a77b042 commit 47893b3
Showing 3 changed files with 14 additions and 5 deletions.
8 changes: 8 additions & 0 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.cc
@@ -15,6 +15,7 @@ namespace onnxruntime {
namespace contrib {
namespace cuda {

#ifndef USE_ROCM
template <>
Status MatMulNBits<MLFloat16>::PrepackedGemm(
cudaStream_t stream,
@@ -39,9 +40,12 @@ Status MatMulNBits<MLFloat16>::PrepackedGemm(
zero_points_ptr, zero_points_size,
Y->MutableData<MLFloat16>(), Y->Shape().Size());
}
#endif // !USE_ROCM

template <typename T>
Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
using CudaT = typename ToCudaType<T>::MappedType;

const Tensor* a = ctx->Input<Tensor>(0);
const Tensor* b = ctx->Input<Tensor>(1);
const Tensor* scales = ctx->Input<Tensor>(2);
@@ -60,13 +64,17 @@ Status MatMulNBits<T>::ComputeInternal(OpKernelContext* ctx) const {
if (Y->Shape().Size() == 0) return Status::OK();

if (prepack_ > 0) {
#ifndef USE_ROCM
ORT_RETURN_IF(reorder_idx != nullptr,
"Internal Error: Prepacked gemm does not support reorder index. Fix the prepacking logic!");
ORT_RETURN_IF(zero_points != nullptr && zero_points->IsDataType<T>(),
"Internal Error: Prepacked gemm does not support zero points of type T. Fix the prepacking logic!");
return PrepackedGemm(
static_cast<cudaStream_t>(ctx->GetComputeStream()->GetHandle()),
static_cast<int>(helper.M()), a, b, scales, zero_points, Y);
#else
ORT_RETURN_IF(true, "Prepacked gemm is not supported for MatMulNBits op.");
#endif // !USE_ROCM
}

const auto* a_data = a->Data<T>();
7 changes: 4 additions & 3 deletions onnxruntime/contrib_ops/cuda/quantization/matmul_nbits.h
@@ -14,12 +14,11 @@
namespace onnxruntime {
namespace contrib {
namespace cuda {
using namespace onnxruntime::cuda;

template <typename T>
class MatMulNBits final : public CudaKernel {
public:
using CudaT = typename ToCudaType<T>::MappedType;

MatMulNBits(const OpKernelInfo& info) : CudaKernel(info) {
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("K", &K_));
ORT_ENFORCE(Status::OK() == info.GetAttr<int64_t>("N", &N_));
@@ -34,15 +33,17 @@ class MatMulNBits final : public CudaKernel {
info.GetAttrOrDefault<int64_t>("prepacked", &prepack_, int64_t(0));
}

#ifndef USE_ROCM
Status PrepackedGemm([[maybe_unused]] cudaStream_t stream,
[[maybe_unused]] int M,
[[maybe_unused]] const Tensor* a,
[[maybe_unused]] const Tensor* b,
[[maybe_unused]] const Tensor* scales,
[[maybe_unused]] const Tensor* zero_points,
[[maybe_unused]] Tensor* Y) const {
ORT_THROW("Prepacked gemm is not supported for MatMulNBits op.");
return ORT_MAKE_STATUS(ONNXRUNTIME, NOT_IMPLEMENTED, "Prepacked gemm is not supported for MatMulNBits op.");
}
#endif // !USE_ROCM

Status ComputeInternal(OpKernelContext* context) const override;

4 changes: 2 additions & 2 deletions onnxruntime/core/optimizer/graph_transformer_utils.cc
@@ -395,13 +395,13 @@ InlinedVector<std::unique_ptr<GraphTransformer>> GenerateTransformers(
// while we can fuse more activation.
transformers.emplace_back(std::make_unique<ConvAddActivationFusion>(cpu_ep));

#ifdef USE_CUDA
#if defined(USE_CUDA) && !defined(USE_ROCM)
// Cuda weight prepacking.
auto* cuda_ep = execution_providers.Get(onnxruntime::kCudaExecutionProvider);
if (cuda_ep != nullptr) {
transformers.emplace_back(std::make_unique<GpuOpsPrepack>());
}
#endif // USE_CUDA
#endif // USE_CUDA && !USE_ROCM

#endif // !defined(DISABLE_CONTRIB_OPS)

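For context, a minimal standalone sketch (not part of the commit) of how the tightened guard in graph_transformer_utils.cc behaves. It assumes, as the change from #ifdef USE_CUDA to #if defined(USE_CUDA) && !defined(USE_ROCM) suggests, that a ROCm build of the hipified sources can see both macros defined at once; the file name and program are hypothetical illustrations only.

// guard_demo.cpp -- hypothetical illustration, not an onnxruntime source file.
// Build with `g++ -DUSE_CUDA guard_demo.cpp` versus
// `g++ -DUSE_CUDA -DUSE_ROCM guard_demo.cpp` and compare the output.
#include <iostream>

int main() {
#if defined(USE_CUDA) && !defined(USE_ROCM)
  // Same condition as the updated guard: only a pure CUDA build takes this path.
  std::cout << "GpuOpsPrepack transformer registered (CUDA-only build)\n";
#else
  // A ROCm (or CPU-only) build skips the CUDA weight-prepacking transformer.
  std::cout << "GpuOpsPrepack transformer skipped\n";
#endif
  return 0;
}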
