
linear_int4_kernel for XPU #1130

Merged 36 commits on Jan 6, 2025

Commits
1e32bbc
Sync main into release/2.6 branch (#1117)
xytintel Nov 22, 2024
f312190
[Release-2.6] Fix bugs of `empty_xpu` and `soft_shrink` (#1139)
xytintel Dec 3, 2024
7ecb0b1
[Release-2.6] Capture rrelu_with_noise noise mutation in compile (#1145)
xytintel Dec 5, 2024
5410f51
contiguous layout for sycl int4 kernel
airMeng Nov 22, 2024
e9311a3
push without compile
sunjiweiswift Nov 26, 2024
e3eaffa
update linearkernel
sunjiweiswift Nov 28, 2024
2a664af
fix some comiple error(not all)
sunjiweiswift Nov 28, 2024
0156ba5
add sycl_ker_config_convention
sunjiweiswift Nov 28, 2024
a58afec
reg kernel for pytorch
sunjiweiswift Nov 29, 2024
f487b20
add yaml for int4mm
sunjiweiswift Nov 29, 2024
ce1c894
update yaml file
sunjiweiswift Dec 3, 2024
d61b198
Modified some review comments
sunjiweiswift Dec 3, 2024
d76a0ce
modify fun name
sunjiweiswift Dec 9, 2024
870a3b5
autogen: _weight_int4pack_mm_with_scales_and_zeros.out
sunjiweiswift Dec 10, 2024
a9627f6
param int->int64_t(python int is int64)
sunjiweiswift Dec 10, 2024
952ead9
use AT_DISPATCH_FLOATING_TYPES_AND
sunjiweiswift Dec 10, 2024
93804f9
Keep the same name as pytorch's _weight_int4pack_mm
sunjiweiswift Dec 11, 2024
9e50b68
modify UT for int4
sunjiweiswift Dec 11, 2024
81a72f1
sync UT with pytoch UT(linalg)
sunjiweiswift Dec 12, 2024
a70df0a
col-major
sunjiweiswift Dec 12, 2024
c08382c
UT pass for B ones
sunjiweiswift Dec 13, 2024
14bb4e0
update gemv
sunjiweiswift Dec 16, 2024
70a3e13
fix scale and zp address
sunjiweiswift Dec 17, 2024
a590ad6
fix K large than 1024 UT
sunjiweiswift Dec 18, 2024
d6a2f3a
bug fix for FP16(BF16 maybe incorrect)
sunjiweiswift Dec 18, 2024
27f18c2
save
sunjiweiswift Dec 20, 2024
7f94b9b
Merge branch 'main' into fp_zp
sunjiweiswift Dec 20, 2024
42c18e9
bugfix for Big Endian
sunjiweiswift Dec 20, 2024
d832050
Unify BF16 and FP16 Funtion
sunjiweiswift Dec 20, 2024
8385f7e
fix compile warning
sunjiweiswift Dec 20, 2024
f44ed70
modify by review
sunjiweiswift Dec 23, 2024
09696b1
Merge branch 'main' into fp_zp
sunjiweiswift Dec 24, 2024
ebe8c7c
Merge branch 'main' into fp_zp
sunjiweiswift Dec 25, 2024
ce6c16b
Merge branch 'main' into fp_zp
sunjiweiswift Jan 2, 2025
dacf3b9
Merge branch 'main' into fp_zp
sunjiweiswift Jan 3, 2025
8a5c000
Merge branch 'main' into fp_zp
sunjiweiswift Jan 6, 2025
src/ATen/native/xpu/LinearInt4.cpp: 63 additions, 0 deletions
@@ -0,0 +1,63 @@

#include <ATen/core/op_registration/adaption.h>
#include <ATen/div_rtn.h>
#include <ATen/native/TensorIterator.h>
#include <torch/library.h>

#include <ATen/native/xpu/sycl/LinearInt4.h>
#include <comm/xpu_aten.h>

namespace at::native {

// XPU entry point: validates shapes and dtypes, then dispatches to the SYCL kernel.
Tensor _weight_int4pack_mm_xpu(
    const Tensor& A,
    const Tensor& B,
    int64_t qGroupSize,
    const Tensor& qScaleAndZeros) {
  auto M = A.size(0);
  auto N = B.size(0);
  auto K = A.size(1);
  TORCH_CHECK(
      A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
      __func__,
      " : expect A to be either 32-bit or 16-bit float tensor.");
  TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous.");
  TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");

  TORCH_CHECK(
      B.dtype() == kInt || B.dtype() == kUInt32 || B.dtype() == kByte,
      __func__,
      " : expect B to be int32 or uint32 or uint8 tensor.");
  TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
  TORCH_CHECK(B.dim() == 2, __func__, " : expect B to be 2D tensor.");

  TORCH_CHECK(
      qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
          qGroupSize == 256,
      __func__,
      ": expect qGroupSize to be 32, 64, 128 or 256, got ",
      qGroupSize);

  TORCH_CHECK(
      qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(0) == N &&
          qScaleAndZeros.size(2) == 2,
      __func__,
      ": expect qScaleAndZeros to be 3d tensor with sizes [",
      N,
      ", :, 2]");

  std::optional<Device> common_device = std::nullopt;
  c10::impl::check_and_update_common_device(
      common_device, A, "xpu::_weight_int4pack_mm", "A");
  c10::impl::check_and_update_common_device(
      common_device, B, "xpu::_weight_int4pack_mm", "B");
  c10::impl::check_and_update_common_device(
      common_device,
      qScaleAndZeros,
      "xpu::_weight_int4pack_mm",
      "qScaleAndZeros");
  Tensor C = at::empty({M, N}, A.options());

  at::native::xpu::linear_int4_kernel(A, B, qGroupSize, qScaleAndZeros, C);
  return C;
}
} // namespace at::native
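For context, a minimal usage sketch of the op registered here, assuming the packing the SYCL kernel below reads (B as uint8 of shape [N, K/2] with two int4 values per byte, and qScaleAndZeros of shape [N, K/qGroupSize, 2]); the call goes through at::_weight_int4pack_mm, which is assumed to dispatch to this XPU implementation:

#include <ATen/ATen.h>

// Hedged sketch, not part of this PR; shapes follow the TORCH_CHECKs above.
void example_int4_mm() {
  int64_t M = 1, N = 4096, K = 4096, qGroupSize = 32;
  auto A = at::randn({M, K}, at::device(at::kXPU).dtype(at::kHalf));
  // Two int4 weights per byte, one row per output column (assumed packing).
  auto B = at::randint(
      0, 256, {N, K / 2}, at::device(at::kXPU).dtype(at::kByte));
  // One (scale, zero_point) pair per group of qGroupSize elements along K.
  auto qScaleAndZeros =
      at::randn({N, K / qGroupSize, 2}, at::device(at::kXPU).dtype(at::kHalf));
  auto C = at::_weight_int4pack_mm(A, B, qGroupSize, qScaleAndZeros); // [M, N]
}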
src/ATen/native/xpu/sycl/LinearInt4.cpp: 247 additions, 0 deletions
@@ -0,0 +1,247 @@
#include <ATen/native/xpu/sycl/LinearInt4.h>
#include <comm/SYCLContext.h>

namespace at::native::xpu {

static inline int padto_le(int src, int padding) {
  return src / padding * padding;
}

static inline int64_t padto_le(int64_t src, int64_t padding) {
  return src / padding * padding;
}

static inline size_t padto_le(size_t src, int padding) {
  return src / size_t(padding) * size_t(padding);
}
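padto_le rounds its first argument down to a multiple of the second; the kernel below uses it to split k into a body covered by full tiles and a remainder handled by smaller tiles. A small illustration with the constants used in operator():

// With SgSize = 16, Unroll = 2: GroupK * Unroll = 1024 and GroupK2 * Unroll = 256.
int k = 1500;
int k_body = padto_le(k, 16 * 32 * 2);  // 1024: covered by the TileK = 32 loop
int k_body2 = padto_le(k, 16 * 8 * 2);  // 1280: the TileK = 8 loop picks up 256 more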

template <typename scalar_t = sycl::half, int block_size = 32>
struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
  LinearInt4KernelFunctor(
      const scalar_t* A,
      const uint8_t* B,
      scalar_t* C,
      const scalar_t* ScaleAndZeros,
      int m,
      int n,
      int k,
      int lda,
      int ldb,
      int ldc)
      : A(A),
        B(B),
        C(C),
        ScaleAndZeros(ScaleAndZeros),
        m(m),
        n(n),
        k(k),
        lda(lda),
        ldb(ldb),
        ldc(ldc) {}
  void sycl_ker_config_convention(sycl::handler& cgh) {}

  void operator()(sycl::nd_item<1> it) const {
    // Each work-group is a single sub-group of SgSize work-items and
    // accumulates one output element (column g_n of the first output row).
    int constexpr Unroll = 2;
    int constexpr SgSize = 16;
    int constexpr blocksize = block_size;
    using scalarx2_t = sycl::vec<scalar_t, 2>;

    if (k % (SgSize * 32 * Unroll) == 0) {
      // Fast path: k is a multiple of GroupK * Unroll, so no tail handling is needed.
      int constexpr TileK = 32;
      int constexpr GroupK = SgSize * TileK;

      int g_idx = it.get_group(0);
      auto sg = it.get_sub_group();
      int sg_id = sg.get_local_id()[0];
      int g_n = g_idx;
      auto sptr = ScaleAndZeros + g_n * ldb * 2;
      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
      auto bptr = B + g_n * k / 2;
      auto aptr = A;
      auto cptr = C + g_n;

      sycl::float2 tmpAcc = {0.f, 0.f};
      for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
        for (int iu = 0; iu < Unroll; iu++) {
          uint8_t tmps8[TileK / 2];
Review comment: move uint8_t tmps8[TileK / 2]; outside the loop.

Review comment: maybe we can do a little template trick to simplify this piece of logic: have a template that handles all scenarios and then pass the corresponding args when called.

template <typename scalar_t, int SgSize, int TileK, int Unroll>
void tinygemm_kernel(...);

if (k % (SgSize * 32 * Unroll) == 0) {
  // use tinygemm_kernel<...>
} else {
  // use tinygemm_kernel<...>
}

Not a must-have, just a little trick.
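A hedged sketch of the refactor suggested in that comment; the helper name tinygemm_tile and its exact parameter list are illustrative and not part of this PR, but the body mirrors the inner loop that is repeated in the branches below:

template <typename scalar_t, int SgSize, int TileK, int blocksize>
static inline void tinygemm_tile(
    sycl::float2& acc,
    const scalar_t*& aptr,
    const uint8_t*& bptr,
    const scalar_t*& sptr,
    const scalar_t*& zptr,
    int sg_id) {
  using scalarx2_t = sycl::vec<scalar_t, 2>;
  constexpr int GroupK = SgSize * TileK;
  uint8_t tmps8[TileK / 2];
  *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
      *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
  scalar_t scale = *(sptr + sg_id * (TileK / blocksize) * 2);
  scalar_t zero_point = *(zptr + sg_id * (TileK / blocksize) * 2);
#pragma unroll
  for (int ikk = 0; ikk < TileK; ikk += 2) {
    scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
    scalarx2_t tmpB = {
        static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
        static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
    auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
    acc += {tmpAmulB[0], tmpAmulB[1]};
  }
  sptr += (GroupK / blocksize) * 2;
  aptr += GroupK;
  bptr += GroupK / 2;
}

// Callers would then pick the tile width, e.g.
//   tinygemm_tile<scalar_t, SgSize, /*TileK=*/32, blocksize>(tmpAcc, aptr, bptr, sptr, zptr, sg_id);
// for the body and TileK = 8 for the first tail loop.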

          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
          int scale_offset = sg_id * (TileK / blocksize) * 2;
          int zp_offset = sg_id * (TileK / blocksize) * 2;
          scalar_t scale = *(sptr + scale_offset);
          scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
          for (int ikk = 0; ikk < TileK; ikk += 2) {
            scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
            scalarx2_t tmpB = {
                static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
                static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
            auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
            tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
          }
          sptr += (GroupK / blocksize) * 2;
          aptr += GroupK;
          bptr += GroupK / 2;
        }
      }
      sycl::float2 sum = {0.f, 0.f};
      sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
      if (sg_id == 0) {
        *cptr = static_cast<scalar_t>(sum[0] + sum[1]);
      }
    } else { // k % (SgSize * 32 * Unroll) != 0
      // General path: full TileK = 32 tiles first, then TileK = 8 tiles, then a per-pair tail.
      int constexpr TileK = 32;
      int constexpr GroupK = SgSize * TileK;
      int k_body = padto_le(k, GroupK * Unroll);
      int constexpr TileK2 = 8;
      int constexpr GroupK2 = SgSize * TileK2;
      int k_body2 = padto_le(k, GroupK2 * Unroll);
      int g_idx = it.get_group(0);
      auto sg = it.get_sub_group();
      int sg_id = sg.get_local_id()[0];
      int g_n = g_idx;
      auto sptr = ScaleAndZeros + g_n * ldb * 2;
      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
      auto bptr = B + g_n * k / 2;
      auto aptr = A;
      auto cptr = C + g_n;
      sycl::float2 tmpAcc = {0.f, 0.f};
      int i = 0;
      for (; i < k_body; i += GroupK * Unroll) {
#pragma unroll
        for (int iu = 0; iu < Unroll; iu++) {
          uint8_t tmps8[TileK / 2];
          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);

          int scale_offset = sg_id * (TileK / blocksize) * 2;
          int zp_offset = sg_id * (TileK / blocksize) * 2;
          scalar_t scale = *(sptr + scale_offset);
          scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
          for (int ikk = 0; ikk < TileK; ikk += 2) {
            scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
            scalarx2_t tmpB = {
                static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
                static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
            auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
            tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
          }
          sptr += (GroupK / blocksize) * 2;
          aptr += GroupK;
          bptr += GroupK / 2;
        }
      }
      if (i + GroupK2 * Unroll < k_body2) {
        for (; i < k_body2; i += GroupK2 * Unroll) {
#pragma unroll
          for (int iu = 0; iu < Unroll; iu++) {
            uint8_t tmps8[TileK2 / 2];
            *(sycl::vec<uint8_t, TileK2 / 2>*)tmps8 =
                *(sycl::vec<uint8_t, TileK2 / 2>*)(bptr + sg_id * TileK2 / 2);

            int scale_offset = sg_id * (TileK2 / blocksize) * 2;
            int zp_offset = sg_id * (TileK2 / blocksize) * 2;
            scalar_t scale = *(sptr + scale_offset);
            scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
            for (int ikk = 0; ikk < TileK2; ikk += 2) {
              scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK2 + ikk);
              scalarx2_t tmpB = {
                  static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
                  static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
              auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
              tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
            }
            sptr += (GroupK2 / blocksize) * 2;
            aptr += GroupK2;
            bptr += GroupK2 / 2;
          }
        }
      }
      if (i + SgSize * 2 <= k) {
        // Per-pair tail: each work-item handles one packed byte (two elements) per step.
        for (; i < k; i += SgSize * 2) {
          uint8_t tmps8 = *(bptr + sg_id);
          scalarx2_t tmpB = {
              static_cast<int8_t>((tmps8 & 0x0f) - 8),
              static_cast<int8_t>((tmps8 >> 4) - 8)};

          int scale_offset = sg_id * (2 / blocksize) * 2;
          int zp_offset = sg_id * (2 / blocksize) * 2;
          scalar_t scale = *(sptr + scale_offset);
          scalar_t zero_point = *(zptr + zp_offset);
          scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * 2);
          auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
          tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
          sptr += (SgSize * 2 / blocksize) * 2;
          aptr += SgSize * 2;
          bptr += SgSize * 2 / 2;
        }
      }
      sycl::float2 sum = {0.f, 0.f};
      sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
      if (sg_id == 0) {
        *cptr = static_cast<scalar_t>(sum[0] + sum[1]);
      }
    }
  }

 private:
  const scalar_t* A;
  const uint8_t* B;
  scalar_t* C;
  const scalar_t* ScaleAndZeros;
  int m;
  int n;
  int k;
  int lda;
  int ldb;
  int ldc;
};
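For reference, the decode convention used in every branch above, spelled out on a single packed byte (illustrative only; the value of packed is hypothetical, while the bias and nibble order match the casts in operator()):

// Each byte of B holds two 4-bit weights stored with a +8 bias.
uint8_t packed = 0x3A;              // hypothetical packed byte
int8_t w_lo = (packed & 0x0f) - 8;  // low nibble  -> element 2*i     (0x0A - 8 =  2)
int8_t w_hi = (packed >> 4) - 8;    // high nibble -> element 2*i + 1 (0x03 - 8 = -5)
// The dot product then uses w * scale + zero_point, with one
// (scale, zero_point) pair per blocksize elements along k.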

void linear_int4_kernel(
    const Tensor& A,
    const Tensor& B,
    int qGroupSize,
    const Tensor& qScaleAndZeros,
    Tensor& C) {
  auto& sycl_queue = at::xpu::getCurrentSYCLQueue();
  int64_t m = A.size(0);
  int64_t n = C.size(1);
  int64_t k = A.size(1);
  int constexpr SgSize = 16;
  // One work-group of SgSize work-items per output column.
  sycl::range<1> local_range{SgSize};
  sycl::range<1> global_range{static_cast<size_t>(n) * SgSize};
  AT_DISPATCH_REDUCED_FLOATING_TYPES(
      A.scalar_type(), "linear_int4_kernel", [&]() {
        using scalar_sycl_t = std::conditional_t<
            std::is_same_v<scalar_t, at::Half>,
            sycl::half,
            sycl::ext::oneapi::bfloat16>;

        const scalar_sycl_t* input_data =
            reinterpret_cast<scalar_sycl_t*>(A.data_ptr<scalar_t>());
        uint8_t* weight_data =
            reinterpret_cast<uint8_t*>(B.data_ptr()); // int4x2 or int4x8

        scalar_sycl_t* output_data =
            reinterpret_cast<scalar_sycl_t*>(C.data_ptr<scalar_t>());
        scalar_sycl_t* scale_zeros_data = reinterpret_cast<scalar_sycl_t*>(
            qScaleAndZeros.data_ptr<scalar_t>());
        LinearInt4KernelFunctor<scalar_sycl_t, 32> kfn =
            LinearInt4KernelFunctor<scalar_sycl_t, 32>(
                input_data,
                weight_data,
                output_data,
                scale_zeros_data,
                m,
                n,
                k,
                k,
                k / qGroupSize,
                n);
        sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
      });
}

} // namespace at::native::xpu
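The launcher passes ldb = k / qGroupSize, so the functor's sptr/zptr arithmetic addresses qScaleAndZeros as an [N, K/qGroupSize, 2] buffer with the scale in the even slot and the zero point in the odd slot. A small sketch of that mapping (hypothetical helper names, not part of this diff):

// Flat offsets into qScaleAndZeros for output column g_n and k-group g.
inline int64_t scale_offset(int64_t g_n, int64_t g, int64_t ldb) {
  return (g_n * ldb + g) * 2;      // matches sptr = ScaleAndZeros + g_n * ldb * 2
}
inline int64_t zero_offset(int64_t g_n, int64_t g, int64_t ldb) {
  return (g_n * ldb + g) * 2 + 1;  // matches zptr = sptr + 1
}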
src/ATen/native/xpu/sycl/LinearInt4.h: 14 additions, 0 deletions
@@ -0,0 +1,14 @@
#pragma once
#include <ATen/native/TensorIterator.h>
#include <comm/xpu_aten.h>

namespace at::native::xpu {

TORCH_XPU_API void linear_int4_kernel(
    const Tensor& input,
    const Tensor& weight,
    int qGroupSize,
    const Tensor& weight_scale_zero_point,
    Tensor& output);

} // namespace at::native::xpu