Skip to content

Commit

Permalink
[Refactor] Refactor code to avoid potential nested queue submit issue. (#2695)
Browse files: browse the repository at this point in the history
  • Loading branch information
cboss6 authored and mini-goel committed Dec 18, 2024
1 parent 8decc58 commit 66dd673
Showing 1 changed file with 19 additions and 16 deletions.
35 changes: 19 additions & 16 deletions itex/core/kernels/gpu/matmul_op.cc
Original file line number | Diff line number | Diff line change
Expand Up @@ -349,26 +349,29 @@ void LaunchBmmCustomKernel(OpKernelContext* ctx, const T* A, const T* B, T* C,
sycl::range<3> local{1, BS_X, BS_Y};
Tensor A_offset_tensor, B_offset_tensor;

if (src_dims > 3 && is_bcast_required) {
const std::vector<int64_t>& x_batch_indices = bcast.x_batch_indices();
const std::vector<int64_t>& y_batch_indices = bcast.y_batch_indices();
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(DataTypeToEnum<int64_t>::value,
TensorShape({bs}), &A_offset_tensor));
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(DataTypeToEnum<int64_t>::value,
TensorShape({bs}), &B_offset_tensor));
stream
->memcpy(GetTensorBuffer<int64_t>(&A_offset_tensor),
x_batch_indices.data(), bs * sizeof(int64_t))
.wait();
stream
->memcpy(GetTensorBuffer<int64_t>(&B_offset_tensor),
y_batch_indices.data(), bs * sizeof(int64_t))
.wait();
}

stream->submit([&](sycl::handler& cgh) {
LocalAcc<T> Asub(sycl::range<2>{c_M * BS_X, TILE_K}, cgh);
LocalAcc<T> Bsub(sycl::range<2>{TILE_K, c_P * BS_Y}, cgh);
if (src_dims > 3 && is_bcast_required) {
const std::vector<int64_t>& x_batch_indices = bcast.x_batch_indices();
const std::vector<int64_t>& y_batch_indices = bcast.y_batch_indices();
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(DataTypeToEnum<int64_t>::value,
TensorShape({bs}), &A_offset_tensor));
OP_REQUIRES_OK(ctx,
ctx->allocate_temp(DataTypeToEnum<int64_t>::value,
TensorShape({bs}), &B_offset_tensor));
stream
->memcpy(GetTensorBuffer<int64_t>(&A_offset_tensor),
x_batch_indices.data(), bs * sizeof(int64_t))
.wait();
stream
->memcpy(GetTensorBuffer<int64_t>(&B_offset_tensor),
y_batch_indices.data(), bs * sizeof(int64_t))
.wait();
BatchMatMulWithBcastKernel<T, c_M, c_P, BS_X, BS_Y, TILE_K, TILE_AB> task(
A, B, C, bs, M, N, P, Asub, Bsub, adj_A, adj_B,
static_cast<int64_t*>(GetTensorBuffer<int64_t>(&A_offset_tensor)),
Expand Down

0 comments on commit 66dd673

Please sign in to comment.