Integrate sbgemm kernel into cpu provider MatMul operator
The sbgemm kernel is integrated behind a fastmath-mode environment variable.
It is disabled by default; set the environment variable below to enable it:
ORT_USE_FASTMATH_MODE=1
snadampal committed Sep 11, 2023
1 parent 4c511ef commit 97a9826
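In fastmath mode, fp32 MatMul calls are routed (subject to the size and transpose checks in the diff below) to the sbgemm kernel, which computes in bfloat16, trading a small amount of precision for speed. As a minimal usage sketch (not part of the commit, assuming a POSIX environment), the flag can be set process-wide before any session is created, since each MatMul kernel reads it once at construction time:

#include <cstdlib>

int main() {
  // Hypothetical sketch: enable the bfloat16 fastmath MatMul path for this
  // process. setenv is POSIX; use _putenv_s on Windows. Set it before any
  // ONNX Runtime session is created, because the kernel constructor caches it.
  setenv("ORT_USE_FASTMATH_MODE", "1", /*overwrite=*/1);
  // ... create the ONNX Runtime environment/session and run inference ...
  return 0;
}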
Showing 2 changed files with 72 additions and 3 deletions.
59 changes: 56 additions & 3 deletions onnxruntime/core/providers/cpu/math/matmul.cc
@@ -126,15 +126,64 @@ Status MatMul<T>::Compute(OpKernelContext* ctx) const {
  return Status::OK();
}

bool GemmPackBbf16(AllocatorPtr& alloc,
                   const Tensor& tensor_b,
                   bool trans_b,
                   IAllocatorUniquePtr<void>& packed_b,
                   size_t& packed_b_size,
                   TensorShape& b_shape) {
  // Only handle the common case of a 2D weight matrix. Additional matrices
  // could be handled by stacking the packed buffers.
  if (tensor_b.Shape().NumDimensions() != 2) {
    return false;
  }

  b_shape = tensor_b.Shape();

  const size_t K = trans_b ? static_cast<size_t>(b_shape[1]) : static_cast<size_t>(b_shape[0]);
  const size_t N = trans_b ? static_cast<size_t>(b_shape[0]) : static_cast<size_t>(b_shape[1]);

  packed_b_size = MlasSBGemmPackBSize(N, K);
  if (packed_b_size == 0) {
    return false;
  }

  packed_b = IAllocator::MakeUniquePtr<void>(alloc, packed_b_size, true);
  auto* packed_b_data = packed_b.get();

  // Initialize the memory to 0: there may be padding in the pre-packed buffer,
  // and we do not want uninitialized bytes to produce different hashes if and
  // when we cache this pre-packed buffer for sharing between sessions.
  memset(packed_b_data, 0, packed_b_size);
  MlasSBGemmPackB(trans_b ? CblasTrans : CblasNoTrans,
                  N,
                  K,
                  tensor_b.Data<float>(),
                  trans_b ? K : N,
                  packed_b_data);
  return true;
}


Status MatMul<float>::PrePack(const Tensor& tensor, int input_idx, /*out*/ AllocatorPtr alloc,
                              /*out*/ bool& is_packed,
                              /*out*/ PrePackedWeights* prepacked_weights) {
  is_packed = false;

  TensorShape b_shape = tensor.Shape();

  const size_t dim1 = static_cast<size_t>(b_shape[0]);
  const size_t dim2 = static_cast<size_t>(b_shape[1]);

  // only pack Matrix B
  if (input_idx == 1) {
    size_t packed_b_size;
    if (use_fastmath_mode && (trans_b_attr_ == 0) && ((dim1 * dim2) >= fastmath_mode_kernelsize_threshold)) {
      is_packed = GemmPackBbf16(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_);
    } else {
      is_packed = GemmPackBFp32(alloc, tensor, trans_b_attr_ != 0, packed_b_, packed_b_size, b_shape_);
    }

    bool share_prepacked_weights = (prepacked_weights != nullptr);
    if (is_packed && share_prepacked_weights) {
      prepacked_weights->buffers_.push_back(std::move(packed_b_));
@@ -199,9 +248,13 @@ Status MatMul<float>::Compute(OpKernelContext* ctx) const {
    data[i].alpha = alpha_attr_;
    data[i].beta = 0.0f;
  }
  if (use_fastmath_mode && !trans_b && ((N * K) >= fastmath_mode_kernelsize_threshold)) {
    MlasSBGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans,
                    M, N, K, data.data(), max_len, thread_pool);
  } else {
    MlasGemmBatch(trans_a ? CblasTrans : CblasNoTrans, trans_b ? CblasTrans : CblasNoTrans,
                  M, N, K, data.data(), max_len, thread_pool);
  }
  return Status::OK();
}

16 changes: 16 additions & 0 deletions onnxruntime/core/providers/cpu/math/matmul.h
@@ -15,6 +15,14 @@ class MatMul final : public OpKernel {
  Status Compute(OpKernelContext* context) const override;
};

static inline bool ort_use_fastmath_mode() {
  static auto value = [&] {
    const char* ptr = std::getenv("ORT_USE_FASTMATH_MODE");
    return ptr != nullptr ? std::atoi(ptr) : 0;
  }();
  return value;
}

template <>
class MatMul<float> final : public OpKernel {
 public:
@@ -27,6 +35,8 @@ class MatMul<float> final : public OpKernel {
info.GetAttrOrDefault<int64_t>("transBatchB", &trans_batch_b_attr, 0);
trans_batch_a_ = trans_batch_a_attr != 0;
trans_batch_b_ = trans_batch_b_attr != 0;

use_fastmath_mode = ort_use_fastmath_mode();
}

  Status PrePack(const Tensor& tensor, int input_idx, AllocatorPtr alloc,
@@ -48,6 +58,12 @@ class MatMul<float> final : public OpKernel {
  int64_t trans_b_attr_;
  bool trans_batch_a_;
  bool trans_batch_b_;

  // fastmath mode state
  bool use_fastmath_mode;
  // The sbgemm kernel is implemented as 8x8 blocks with weights pre-packed into
  // 4 blocks of 4x2, so a minimum of 32 elements is required before the speedup
  // outweighs the additional pre-packing overhead.
  const size_t fastmath_mode_kernelsize_threshold = 32;
};

} // namespace onnxruntime
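To make the gating concrete, the following standalone sketch (not part of the commit; the names are illustrative) mirrors the condition applied in PrePack (dim1 * dim2) and Compute (N * K) above:

#include <cstddef>

// Mirrors fastmath_mode_kernelsize_threshold: one 8x8 sbgemm block covers
// 4 pre-packed 4x2 blocks = 32 weight elements, the break-even point for
// the extra pre-packing overhead.
constexpr std::size_t kFastmathThreshold = 32;

// The bfloat16 sbgemm path is taken only when fastmath mode is enabled,
// B is not transposed, and the weight matrix has at least 32 elements.
bool UseSBGemm(bool fastmath_enabled, bool trans_b, std::size_t n, std::size_t k) {
  return fastmath_enabled && !trans_b && (n * k) >= kFastmathThreshold;
}

// Example: a 4x8 weight (32 elements) takes the sbgemm path; a 3x8 weight
// (24 elements) falls back to the fp32 MlasGemmBatch path.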
