# Revert NeuralSpeed code for x64 MatMulNBits (#19382)
### Description
Revert PR #19016
Revert PR #17669
luoyu-intel authored Feb 7, 2024
1 parent 75f0631 commit 0d10c7f
Showing 11 changed files with 0 additions and 1,023 deletions.
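
For context, the reverted PRs keyed the Neural Speed path off an `ORT_NEURAL_SPEED` compile definition, set when the `onnxruntime_USE_NEURAL_SPEED` CMake option (removed in the hunks below) was enabled. The following is a minimal, hypothetical C++ sketch of that compile-time dispatch pattern; the identifiers are simplified stand-ins, not the actual ONNX Runtime kernel code.

```cpp
// Hypothetical, simplified sketch of the compile-time dispatch this commit
// removes from matmul_nbits.cc; names are illustrative stand-ins only.
#include <cstdio>

#ifdef ORT_NEURAL_SPEED
// With the reverted PRs, builds that enabled onnxruntime_USE_NEURAL_SPEED
// defined ORT_NEURAL_SPEED and routed 4-bit MatMul through Neural Speed.
void RunMatMulNBits() { std::puts("Neural Speed (bestla) path"); }
#else
// After this revert, only the default MLAS-based path remains.
void RunMatMulNBits() { std::puts("MLAS path"); }
#endif

int main() {
  RunMatMulNBits();
  return 0;
}
```

Building with `ORT_NEURAL_SPEED` defined selected the first branch; after this revert the definition no longer exists, so only the MLAS branch is ever compiled.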
10 changes: 0 additions & 10 deletions cgmanifests/generated/cgmanifest.json
@@ -202,16 +202,6 @@
"comments": "mp11"
}
},
{
"component": {
"type": "git",
"git": {
"commitHash": "c11386eb632eec7c1c2aa323142f73519f946e2a",
"repositoryUrl": "https://github.com/intel/neural-speed.git"
},
"comments": "neural_speed"
}
},
{
"component": {
"type": "git",
12 changes: 0 additions & 12 deletions cmake/CMakeLists.txt
@@ -88,7 +88,6 @@ option(onnxruntime_USE_QNN "Build with QNN support" OFF)
option(onnxruntime_USE_SNPE "Build with SNPE support" OFF)
option(onnxruntime_USE_RKNPU "Build with RKNPU support" OFF)
option(onnxruntime_USE_DNNL "Build with DNNL support" OFF)
option(onnxruntime_USE_NEURAL_SPEED "Build with Neural Speed support" OFF)
option(onnxruntime_USE_JSEP "Build with JavaScript implemented kernels support" OFF)
option(onnxruntime_BUILD_UNIT_TESTS "Build ONNXRuntime unit tests" ON)
option(onnxruntime_BUILD_CSHARP "Build C# library" OFF)
@@ -902,10 +901,6 @@ function(onnxruntime_set_compile_flags target_name)
target_compile_definitions(${target_name} PRIVATE ENABLE_ATEN)
endif()

if(USE_NEURAL_SPEED)
target_compile_definitions(${target_name} PRIVATE ORT_NEURAL_SPEED)
endif()

set_target_properties(${target_name} PROPERTIES COMPILE_WARNING_AS_ERROR ON)
if (onnxruntime_USE_CUDA)
# Suppress a "conversion_function_not_usable" warning in gsl/span
@@ -1193,13 +1188,6 @@ if (onnxruntime_USE_DNNL)
add_compile_definitions(DNNL_OPENMP)
endif()

if (onnxruntime_USE_NEURAL_SPEED AND NOT onnxruntime_MINIMAL_BUILD)
include(neural_speed)
if (USE_NEURAL_SPEED)
list(APPEND onnxruntime_EXTERNAL_LIBRARIES neural_speed::bestla)
endif()
endif()

# TVM EP
if (onnxruntime_USE_TVM)
if (NOT TARGET tvm)
1 change: 0 additions & 1 deletion cmake/deps.txt
@@ -35,7 +35,6 @@ microsoft_gsl;https://github.com/microsoft/GSL/archive/refs/tags/v4.0.0.zip;cf36
microsoft_wil;https://github.com/microsoft/wil/archive/refs/tags/v1.0.230629.1.zip;e4a542a323c070376f7c2d1973d0f7ddbc1d2fa5
mimalloc;https://github.com/microsoft/mimalloc/archive/refs/tags/v2.1.1.zip;d5ee7d34223d0567892db5179849939c8769dc41
mp11;https://github.com/boostorg/mp11/archive/refs/tags/boost-1.82.0.zip;9bc9e01dffb64d9e0773b2e44d2f22c51aace063
neural_speed;https://github.com/intel/neural-speed/archive/refs/tags/bestlav0.1.1.zip;65b0f7a0d04f72f0d5a8d48af70f0366f2ab3939
onnx;https://github.com/onnx/onnx/archive/refs/tags/v1.15.0.zip;54c3f960a0541c5d8d3e60c2933e11f5d3688a11
#use the commit of supporting all the plugins and TRT 8.6-GA (https://github.com/onnx/onnx-tensorrt/commit/0462dc31ae78f48744b6141ae376df1f96d3f459)
onnx_tensorrt;https://github.com/onnx/onnx-tensorrt/archive/a43ce67187bab219520fd80f21af8bbd4354bc8c.zip;572535aefef477050f86744dfab1fef840198035
15 changes: 0 additions & 15 deletions cmake/external/neural_speed.cmake

This file was deleted.

15 changes: 0 additions & 15 deletions cmake/onnxruntime_providers_cpu.cmake
@@ -60,15 +60,6 @@ if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/aten_ops/aten_op_executor.cc"
)
endif()
set(onnxruntime_cpu_neural_speed_srcs
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_wrapper.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_defs.h"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.cc"
"${ONNXRUNTIME_ROOT}/contrib_ops/cpu/quantization/neural_speed_gemm.h"
)
if(NOT USE_NEURAL_SPEED)
list(REMOVE_ITEM onnxruntime_cpu_contrib_ops_srcs ${onnxruntime_cpu_neural_speed_srcs})
endif()
# add using ONNXRUNTIME_ROOT so they show up under the 'contrib_ops' folder in Visual Studio
source_group(TREE ${ONNXRUNTIME_ROOT} FILES ${onnxruntime_cpu_contrib_ops_srcs})
list(APPEND onnxruntime_providers_src ${onnxruntime_cpu_contrib_ops_srcs})
@@ -153,12 +144,6 @@ if (HAS_BITWISE_INSTEAD_OF_LOGICAL)
target_compile_options(onnxruntime_providers PRIVATE "-Wno-bitwise-instead-of-logical")
endif()

if(NOT onnxruntime_DISABLE_CONTRIB_OPS)
if(USE_NEURAL_SPEED)
onnxruntime_add_include_to_target(onnxruntime_providers neural_speed::bestla)
endif()
endif()

if (MSVC)
target_compile_options(onnxruntime_providers PRIVATE "/bigobj")
# if(NOT CMAKE_SIZEOF_VOID_P EQUAL 8)
144 changes: 0 additions & 144 deletions onnxruntime/contrib_ops/cpu/quantization/matmul_nbits.cc
@@ -10,10 +10,6 @@
#include "core/providers/cpu/math/matmul_helper.h"
#include "core/providers/common.h"

#ifdef ORT_NEURAL_SPEED
#include "contrib_ops/cpu/quantization/neural_speed_gemm.h"
#endif

namespace onnxruntime {
namespace contrib {

@@ -23,16 +19,6 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level
static_cast<int64_t>(CompMostAccurate),
static_cast<int64_t>(CompLeastAccurate));

#if defined(ORT_NEURAL_SPEED)

ORT_UNUSED_PARAMETER(nbits);
ORT_UNUSED_PARAMETER(block_size);

// Neural Speed APIs already expect a minimum accuracy level so just use the given value.
return accuracy_level;

#else // defined(ORT_NEURAL_SPEED)

// Find a supported accuracy level that is not less accurate than the one given.
// CompMostAccurate is always supported with the fallback implementation.
// Note: A higher numeric accuracy level value means lower accuracy, so the comparison order is reversed.
@@ -45,8 +31,6 @@ int64_t GetAccuracyLevel(size_t nbits, size_t block_size, int64_t accuracy_level
}

return effective_accuracy_level;

#endif // defined(ORT_NEURAL_SPEED)
}
} // namespace

@@ -61,17 +45,6 @@ class MatMulNBits final : public OpKernel {
accuracy_level_{GetAccuracyLevel(nbits_, block_size_, info.GetAttr<int64_t>("accuracy_level"))} {
ORT_ENFORCE(nbits_ == 4,
"Only 4b quantization is supported for MatMulNBits op, additional bits support is planned.");
#ifdef ORT_NEURAL_SPEED
const Tensor* tensor_B = nullptr;
const Tensor* tensor_scale = nullptr;
const Tensor* tensor_zero_point = nullptr;
bool B_constant = info.TryGetConstantInput(1, &tensor_B);
bool scale_constant = info.TryGetConstantInput(2, &tensor_scale);
bool zero_point_constant = info.TryGetConstantInput(3, &tensor_zero_point);
is_asym_ = info.GetInputCount() >= 4;
all_constant_ = B_constant && scale_constant;
all_constant_ = is_asym_ ? all_constant_ && zero_point_constant : all_constant_;
#endif
}

Status Compute(OpKernelContext* context) const override;
@@ -92,68 +65,13 @@ class MatMulNBits final : public OpKernel {
const bool column_wise_quant_{true};
IAllocatorUniquePtr<void> packed_b_;
size_t packed_b_size_{0};

#if defined(ORT_NEURAL_SPEED)

bool is_asym_{false};
bool all_constant_{false};

#endif // defined(ORT_NEURAL_SPEED)
};

Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
/*out*/ bool& is_packed,
/*out*/ PrePackedWeights* prepacked_weights) {
is_packed = false;

#if defined(ORT_NEURAL_SPEED)

if (!all_constant_) {
return Status::OK();
}
MLAS_THREADPOOL* pool = NULL;
if (nbits_ != 4) {
return Status::OK();
}
auto comp_type = static_cast<NS_SQNBIT_COMPUTE_TYPE>(accuracy_level_);
auto nbits = static_cast<int>(nbits_);
if (input_idx == 1) {
packed_b_size_ = NSNBitsGemmPackBSize(N_, K_, block_size_, nbits, is_asym_, comp_type);
if (packed_b_size_ == 0) return Status::OK();
auto qptr = tensor.Data<uint8_t>();
packed_b_ = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size_, true);
std::memset(packed_b_.get(), 0, packed_b_size_);
NSNBitsGemmPackB(packed_b_.get(), qptr, nullptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, false,
comp_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
if (input_idx == 2 && packed_b_ != nullptr) {
auto sptr = tensor.Data<float>();
NSNBitsGemmPackB(packed_b_.get(), nullptr, sptr, nullptr, N_, K_, K_, block_size_, nbits, is_asym_, !is_asym_,
comp_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}
if (input_idx == 3 && packed_b_ != nullptr) {
auto zptr = tensor.Data<uint8_t>();
NSNBitsGemmPackB(packed_b_.get(), nullptr, nullptr, zptr, N_, K_, K_, block_size_, nbits, is_asym_, is_asym_,
comp_type, pool);
if (prepacked_weights) {
prepacked_weights->buffers_.push_back(std::move(packed_b_));
prepacked_weights->buffer_sizes_.push_back(packed_b_size_);
}
is_packed = true;
}

#else // defined(ORT_NEURAL_SPEED)

if (input_idx == 1) {
const auto compute_type = static_cast<MLAS_SQNBIT_GEMM_COMPUTE_TYPE>(accuracy_level_);
if (!MlasIsSQNBitGemmAvailable(nbits_, block_size_, compute_type)) {
@@ -173,40 +91,18 @@ Status MatMulNBits::PrePack(const Tensor& tensor, int input_idx, /*out*/ Allocat
is_packed = true;
}

#endif // defined(ORT_NEURAL_SPEED)

return Status::OK();
}

Status MatMulNBits::UseSharedPrePackedBuffers(std::vector<BufferUniquePtr>& prepacked_buffers, int input_idx,
/*out*/ bool& used_shared_buffers) {
used_shared_buffers = false;

#if defined(ORT_NEURAL_SPEED)

// Pack three tensors into one buffer
if (input_idx == 1) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
if (input_idx == 2) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}
if (input_idx == 3) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}

#else // defined(ORT_NEURAL_SPEED)

if (input_idx == 1) {
used_shared_buffers = true;
packed_b_ = std::move(prepacked_buffers[0]);
}

#endif // defined(ORT_NEURAL_SPEED)

return Status::OK();
}

@@ -216,46 +112,6 @@ Status MatMulNBits::Compute(OpKernelContext* ctx) const {
const Tensor* a = ctx->Input<Tensor>(0);
const auto* a_data = a->Data<float>();

#if defined(ORT_NEURAL_SPEED)

if (packed_b_) {
TensorShape b_shape({static_cast<int64_t>(N_), static_cast<int64_t>(K_)});

MatMulComputeHelper helper;
ORT_RETURN_IF_ERROR(helper.Compute(a->Shape(), b_shape, false, true));

Tensor* y = ctx->Output(0, helper.OutputShape());

// Bail out early if the output is going to be empty
if (y->Shape().Size() == 0) return Status::OK();

auto* y_data = y->MutableData<float>();

const size_t max_len = helper.OutputOffsets().size();
const size_t M = static_cast<size_t>(helper.M());
const size_t N = static_cast<size_t>(helper.N());
const size_t K = static_cast<size_t>(helper.K());
const size_t lda = helper.Lda(false);
std::vector<NS_SQNBITS_GEMM_DATA_PACKED_PARAMS> gemm_params(max_len);
AllocatorPtr allocator;
auto status = ctx->GetTempSpaceAllocator(&allocator);
ORT_RETURN_IF_ERROR(status);
for (size_t i = 0; i < max_len; i++) {
gemm_params[i].A = a_data + helper.LeftOffsets()[i];
gemm_params[i].lda = lda;
gemm_params[i].B = packed_b_.get();
gemm_params[i].C = y_data + helper.OutputOffsets()[i];
gemm_params[i].ldc = N;
}
auto ws_size = NSSQNBitsGemmBatchWorkspaceSize(M, N, K, max_len, gemm_params.data());
// workspace for activation process(dynamic quantization and others)
auto ws_ptr = IAllocator::MakeUniquePtr<int8_t>(allocator, ws_size);
NSSQNBitsGemmBatchPackedB(M, N, K, max_len, gemm_params.data(), ws_ptr.get(), thread_pool);
return Status::OK();
}

#endif // defined(ORT_NEURAL_SPEED)

const Tensor* scales = ctx->Input<Tensor>(2);
const Tensor* zero_points = ctx->Input<Tensor>(3);
const auto* scales_data = scales->Data<float>();
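
The `GetAccuracyLevel` hunk above removes the Neural Speed branch that simply returned the requested accuracy level, leaving the fallback selection described by the retained comments: clamp the request, then pick a supported level that is not less accurate, keeping in mind that a higher numeric value means lower accuracy. A rough, self-contained sketch of that selection, using made-up constants and a stand-in availability check instead of the real MLAS query, might look like:

```cpp
// Rough sketch of the remaining accuracy-level selection; the constants and
// IsLevelSupported are illustrative placeholders, not the kernel's real
// enum values or the MlasIsSQNBitGemmAvailable query.
#include <algorithm>
#include <cstdint>
#include <cstdio>

constexpr int64_t kCompMostAccurate = 0;   // illustrative values only
constexpr int64_t kCompLeastAccurate = 4;

// Stand-in for a backend capability check such as MlasIsSQNBitGemmAvailable.
bool IsLevelSupported(int64_t level) { return level <= 1; }

int64_t GetEffectiveAccuracyLevel(int64_t requested) {
  // Clamp the requested level to the valid range.
  const int64_t clamped =
      std::min(std::max(requested, kCompMostAccurate), kCompLeastAccurate);
  // Walk toward more accurate levels (lower numeric value) until one is
  // supported; the most accurate level is always available via the fallback.
  int64_t level = clamped;
  while (level > kCompMostAccurate && !IsLevelSupported(level)) {
    --level;
  }
  return level;
}

int main() {
  // Requesting the least accurate level walks up to a supported one.
  std::printf("effective level: %lld\n",
              static_cast<long long>(GetEffectiveAccuracyLevel(4)));
  return 0;
}
```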
45 changes: 0 additions & 45 deletions onnxruntime/contrib_ops/cpu/quantization/neural_speed_defs.h

This file was deleted.
