linear_int4_kernel for XPU #1130
Merged: +438 −0, 36 commits
Commits:
1e32bbc (xytintel): Sync main into release/2.6 branch (#1117)
f312190 (xytintel): [Release-2.6] Fix bugs of `empty_xpu` and `soft_shrink` (#1139)
7ecb0b1 (xytintel): [Release-2.6] Capture rrelu_with_noise noise mutation in compile (#1145)
5410f51 (airMeng): contiguous layout for sycl int4 kernel
e9311a3 (sunjiweiswift): push without compile
e3eaffa (sunjiweiswift): update linearkernel
2a664af (sunjiweiswift): fix some comiple error(not all)
0156ba5 (sunjiweiswift): add sycl_ker_config_convention
a58afec (sunjiweiswift): reg kernel for pytorch
f487b20 (sunjiweiswift): add yaml for int4mm
ce1c894 (sunjiweiswift): update yaml file
d61b198 (sunjiweiswift): Modified some review comments
d76a0ce (sunjiweiswift): modify fun name
870a3b5 (sunjiweiswift): autogen: _weight_int4pack_mm_with_scales_and_zeros.out
a9627f6 (sunjiweiswift): param int->int64_t(python int is int64)
952ead9 (sunjiweiswift): use AT_DISPATCH_FLOATING_TYPES_AND
93804f9 (sunjiweiswift): Keep the same name as pytorch's _weight_int4pack_mm
9e50b68 (sunjiweiswift): modify UT for int4
81a72f1 (sunjiweiswift): sync UT with pytoch UT(linalg)
a70df0a (sunjiweiswift): col-major
c08382c (sunjiweiswift): UT pass for B ones
14bb4e0 (sunjiweiswift): update gemv
70a3e13 (sunjiweiswift): fix scale and zp address
a590ad6 (sunjiweiswift): fix K large than 1024 UT
d6a2f3a (sunjiweiswift): bug fix for FP16(BF16 maybe incorrect)
27f18c2 (sunjiweiswift): save
7f94b9b (sunjiweiswift): Merge branch 'main' into fp_zp
42c18e9 (sunjiweiswift): bugfix for Big Endian
d832050 (sunjiweiswift): Unify BF16 and FP16 Funtion
8385f7e (sunjiweiswift): fix compile warning
f44ed70 (sunjiweiswift): modify by review
09696b1 (sunjiweiswift): Merge branch 'main' into fp_zp
ebe8c7c (sunjiweiswift): Merge branch 'main' into fp_zp
ce6c16b (sunjiweiswift): Merge branch 'main' into fp_zp
dacf3b9 (sunjiweiswift): Merge branch 'main' into fp_zp
8a5c000 (sunjiweiswift): Merge branch 'main' into fp_zp
New file (+63 lines): the host-side operator wrapper that validates arguments and launches the SYCL kernel.

```cpp
#include <ATen/core/op_registration/adaption.h>
#include <ATen/div_rtn.h>
#include <ATen/native/TensorIterator.h>
#include <torch/library.h>

#include <ATen/native/xpu/sycl/LinearInt4.h>
#include <comm/xpu_aten.h>

namespace at::native {
Tensor _weight_int4pack_mm_xpu(
    const Tensor& A,
    const Tensor& B,
    int64_t qGroupSize,
    const Tensor& qScaleAndZeros) {
  auto M = A.size(0);
  auto N = B.size(0);
  auto K = A.size(1);
  TORCH_CHECK(
      A.dtype() == kBFloat16 || A.dtype() == kHalf || A.dtype() == kFloat,
      __func__,
      " : expect A to be either 32-bit or 16-bit float tensor.");
  TORCH_CHECK(A.is_contiguous(), __func__, " : expect A to be contiguous.");
  TORCH_CHECK(A.dim() == 2, __func__, " : expect A to be 2D tensor.");

  TORCH_CHECK(
      B.dtype() == kInt || B.dtype() == kUInt32 || B.dtype() == kByte,
      __func__,
      " : expect B to be int32, uint32, or uint8 tensor.");
  TORCH_CHECK(B.is_contiguous(), __func__, " : expect B to be contiguous.");
  TORCH_CHECK(B.dim() == 2, __func__, " : expect B to be 2D tensor.");

  TORCH_CHECK(
      qGroupSize == 32 || qGroupSize == 64 || qGroupSize == 128 ||
          qGroupSize == 256,
      __func__,
      ": expect qGroupSize to be 32, 64, 128 or 256, got ",
      qGroupSize);

  TORCH_CHECK(
      qScaleAndZeros.dim() == 3 && qScaleAndZeros.size(0) == N &&
          qScaleAndZeros.size(2) == 2,
      __func__,
      ": expect qScaleAndZeros to be 3d tensor with sizes [",
      N,
      ", :, 2]");

  std::optional<Device> common_device = std::nullopt;
  c10::impl::check_and_update_common_device(
      common_device, A, "xpu::_weight_int4pack_mm", "A");
  c10::impl::check_and_update_common_device(
      common_device, B, "xpu::_weight_int4pack_mm", "B");
  c10::impl::check_and_update_common_device(
      common_device,
      qScaleAndZeros,
      "xpu::_weight_int4pack_mm",
      "qScaleAndZeros");
  Tensor C = at::empty({M, N}, A.options());

  at::native::xpu::linear_int4_kernel(A, B, qGroupSize, qScaleAndZeros, C);
  return C;
}
} // namespace at::native
```
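For context, here is a hypothetical call site (not part of the PR) that satisfies the checks above. The op name follows the commit "Keep the same name as pytorch's _weight_int4pack_mm" and the device-check strings in the wrapper; shapes and dtypes are inferred from the TORCH_CHECKs, and the concrete sizes are illustrative only.

```cpp
// Hedged usage sketch: assumes a PyTorch build with XPU support and this
// extension loaded so that aten::_weight_int4pack_mm dispatches to XPU.
#include <ATen/ATen.h>

at::Tensor int4_mm_example() {
  const int64_t M = 4, N = 64, K = 256, qGroupSize = 32;
  // A: [M, K] activation, fp16/bf16/fp32, contiguous.
  at::Tensor A = at::randn({M, K}, at::device(at::kXPU).dtype(at::kHalf));
  // B: [N, K/2] packed int4 weights, two 4-bit values per byte (the uint8
  // variant accepted by the dtype check above).
  at::Tensor B = at::randint(
      0, 256, {N, K / 2}, at::device(at::kXPU).dtype(at::kByte));
  // qScaleAndZeros: [N, K/qGroupSize, 2], one (scale, zero_point) per group.
  at::Tensor qsz = at::randn(
      {N, K / qGroupSize, 2}, at::device(at::kXPU).dtype(at::kHalf));
  return at::_weight_int4pack_mm(A, B, qGroupSize, qsz); // result: [M, N]
}
```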
New file (+247 lines): the SYCL kernel implementation.
```cpp
#include <ATen/native/xpu/sycl/LinearInt4.h>
#include <comm/SYCLContext.h>

namespace at::native::xpu {
static inline int padto_le(int src, int padding) {
  return src / padding * padding;
}

static inline int64_t padto_le(int64_t src, int64_t padding) {
  return src / padding * padding;
}

static inline size_t padto_le(size_t src, int padding) {
  return src / size_t(padding) * size_t(padding);
}

template <typename scalar_t = sycl::half, int block_size = 32>
struct LinearInt4KernelFunctor : public __SYCL_KER_CONFIG_CONVENTION__ {
  LinearInt4KernelFunctor(
      const scalar_t* A,
      const uint8_t* B,
      scalar_t* C,
      const scalar_t* ScaleAndZeros,
      int m,
      int n,
      int k,
      int lda,
      int ldb,
      int ldc)
      : A(A),
        B(B),
        C(C),
        ScaleAndZeros(ScaleAndZeros),
        m(m),
        n(n),
        k(k),
        lda(lda),
        ldb(ldb),
        ldc(ldc) {}
  void sycl_ker_config_convention(sycl::handler& cgh) {}

  void operator()(sycl::nd_item<1> it) const {
    int constexpr Unroll = 2;
    int constexpr SgSize = 16;
    int constexpr blocksize = block_size;
    using scalarx2_t = sycl::vec<scalar_t, 2>;

    if (k % (SgSize * 32 * Unroll) == 0) {
      int constexpr TileK = 32;
      int constexpr GroupK = SgSize * TileK;

      int g_idx = it.get_group(0);
      auto sg = it.get_sub_group();
      int sg_id = sg.get_local_id()[0];
      int g_n = g_idx;
      auto sptr = ScaleAndZeros + g_n * ldb * 2;
      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
      auto bptr = B + g_n * k / 2;
      auto aptr = A;
      auto cptr = C + g_n;

      sycl::float2 tmpAcc = {0.f, 0.f};
      for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
        for (int iu = 0; iu < Unroll; iu++) {
          uint8_t tmps8[TileK / 2];
          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
          int scale_offset = sg_id * (TileK / blocksize) * 2;
          int zp_offset = sg_id * (TileK / blocksize) * 2;
          scalar_t scale = *(sptr + scale_offset);
          scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
          for (int ikk = 0; ikk < TileK; ikk += 2) {
            scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
            scalarx2_t tmpB = {
                static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
                static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
            auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
            tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
          }
          sptr += (GroupK / blocksize) * 2;
          aptr += GroupK;
          bptr += GroupK / 2;
        }
      }
      sycl::float2 sum = {0.f, 0.f};
      sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
      if (sg_id == 0) {
        *cptr = static_cast<scalar_t>(sum[0] + sum[1]);
      }
    } else { // k % (SgSize * 32 * Unroll) != 0
      int constexpr TileK = 32;
      int constexpr GroupK = SgSize * TileK;
      int k_body = padto_le(k, GroupK * Unroll);
      int constexpr TileK2 = 8;
      int constexpr GroupK2 = SgSize * TileK2;
      int k_body2 = padto_le(k, GroupK2 * Unroll);
      int g_idx = it.get_group(0);
      auto sg = it.get_sub_group();
      int sg_id = sg.get_local_id()[0];
      int g_n = g_idx;
      auto sptr = ScaleAndZeros + g_n * ldb * 2;
      auto zptr = ScaleAndZeros + g_n * ldb * 2 + 1;
      auto bptr = B + g_n * k / 2;
      auto aptr = A;
      auto cptr = C + g_n;
      sycl::float2 tmpAcc = {0.f, 0.f};
      int i = 0;
      for (; i < k_body; i += GroupK * Unroll) {
#pragma unroll
        for (int iu = 0; iu < Unroll; iu++) {
          uint8_t tmps8[TileK / 2];
          *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
              *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);

          int scale_offset = sg_id * (TileK / blocksize) * 2;
          int zp_offset = sg_id * (TileK / blocksize) * 2;
          scalar_t scale = *(sptr + scale_offset);
          scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
          for (int ikk = 0; ikk < TileK; ikk += 2) {
            scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK + ikk);
            scalarx2_t tmpB = {
                static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
                static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
            auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
            tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
          }
          sptr += (GroupK / blocksize) * 2;
          aptr += GroupK;
          bptr += GroupK / 2;
        }
      }
      if (i + GroupK2 * Unroll < k_body2) {
        for (; i < k_body2; i += GroupK2 * Unroll) {
#pragma unroll
          for (int iu = 0; iu < Unroll; iu++) {
            uint8_t tmps8[TileK2 / 2];
            *(sycl::vec<uint8_t, TileK2 / 2>*)tmps8 =
                *(sycl::vec<uint8_t, TileK2 / 2>*)(bptr + sg_id * TileK2 / 2);

            int scale_offset = sg_id * (TileK2 / blocksize) * 2;
            int zp_offset = sg_id * (TileK2 / blocksize) * 2;
            scalar_t scale = *(sptr + scale_offset);
            scalar_t zero_point = *(zptr + zp_offset);
#pragma unroll
            for (int ikk = 0; ikk < TileK2; ikk += 2) {
              scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TileK2 + ikk);
              scalarx2_t tmpB = {
                  static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
                  static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
              auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
              tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
            }
            sptr += (GroupK2 / blocksize) * 2;
            aptr += GroupK2;
            bptr += GroupK2 / 2;
          }
        }
      }
      if (i + SgSize * 2 <= k) {
        for (; i < k; i += SgSize * 2) {
          uint8_t tmps8 = *(bptr + sg_id);
          scalarx2_t tmpB = {
              static_cast<int8_t>((tmps8 & 0x0f) - 8),
              static_cast<int8_t>((tmps8 >> 4) - 8)};

          int scale_offset = sg_id * (2 / blocksize) * 2;
          int zp_offset = sg_id * (2 / blocksize) * 2;
          scalar_t scale = *(sptr + scale_offset);
          scalar_t zero_point = *(zptr + zp_offset);
          scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * 2);
          auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
          tmpAcc += {tmpAmulB[0], tmpAmulB[1]};
          sptr += (SgSize * 2 / blocksize) * 2;
          aptr += SgSize * 2;
          bptr += SgSize * 2 / 2;
        }
      }
      sycl::float2 sum = {0.f, 0.f};
      sum += sycl::reduce_over_group(sg, tmpAcc, sycl::plus<>());
      if (sg_id == 0) {
        *cptr = static_cast<scalar_t>(sum[0] + sum[1]);
      }
    }
  }

 private:
  const scalar_t* A;
  const uint8_t* B;
  scalar_t* C;
  const scalar_t* ScaleAndZeros;
  int m;
  int n;
  int k;
  int lda;
  int ldb;
  int ldc;
};

void linear_int4_kernel(
    const Tensor& A,
    const Tensor& B,
    int qGroupSize,
    const Tensor& qScaleAndZeros,
    Tensor& C) {
  auto& sycl_queue = at::xpu::getCurrentSYCLQueue();
  int64_t m = A.size(0);
  int64_t n = C.size(1);
  int64_t k = A.size(1);
  int constexpr SgSize = 16;
  sycl::range<1> local_range{SgSize};
  sycl::range<1> global_range{static_cast<size_t>(n) * SgSize};
  AT_DISPATCH_REDUCED_FLOATING_TYPES(
      A.scalar_type(), "linear_int4_kernel", [&]() {
        using scalar_sycl_t = std::conditional_t<
            std::is_same_v<scalar_t, at::Half>,
            sycl::half,
            sycl::ext::oneapi::bfloat16>;

        const scalar_sycl_t* input_data =
            reinterpret_cast<scalar_sycl_t*>(A.data_ptr<scalar_t>());
        uint8_t* weight_data =
            reinterpret_cast<uint8_t*>(B.data_ptr()); // int4x2 or int4x8

        scalar_sycl_t* output_data =
            reinterpret_cast<scalar_sycl_t*>(C.data_ptr<scalar_t>());
        scalar_sycl_t* scale_zeros_data = reinterpret_cast<scalar_sycl_t*>(
            qScaleAndZeros.data_ptr<scalar_t>());
        LinearInt4KernelFunctor<scalar_sycl_t, 32> kfn =
            LinearInt4KernelFunctor<scalar_sycl_t, 32>(
                input_data,
                weight_data,
                output_data,
                scale_zeros_data,
                m,
                n,
                k,
                k,
                k / qGroupSize,
                n);
        sycl_kernel_submit(global_range, local_range, sycl_queue, kfn);
      });
}

} // namespace at::native::xpu
```
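The core dequantization in the kernel above decodes two signed int4 values from each byte of B (low nibble via `& 0x0f`, high nibble via `>> 4`, both recentered by subtracting 8), then applies the per-group affine transform `q * scale + zero_point`. A minimal standalone sketch of that decode (plain C++, no SYCL; the byte value is illustrative):

```cpp
#include <cstdint>
#include <cstdio>

// Minimal sketch of the nibble decode used in LinearInt4KernelFunctor:
// each byte packs two unsigned 4-bit values; subtracting 8 maps the
// stored range [0, 15] to signed [-8, 7].
int main() {
  uint8_t packed = 0x2B; // low nibble 0xB (11), high nibble 0x2 (2)
  int8_t lo = static_cast<int8_t>((packed & 0x0f) - 8); // 11 - 8 = 3
  int8_t hi = static_cast<int8_t>((packed >> 4) - 8);   // 2 - 8 = -6
  // The kernel then computes w = q * scale + zero_point per group,
  // exactly as tmpB * scale + zero_point above.
  std::printf("lo=%d hi=%d\n", lo, hi);
  return 0;
}
```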
New file (+14 lines): the kernel's public header, included above as ATen/native/xpu/sycl/LinearInt4.h.
```cpp
#pragma once
#include <ATen/native/TensorIterator.h>
#include <comm/xpu_aten.h>

namespace at::native::xpu {

TORCH_XPU_API void linear_int4_kernel(
    const Tensor& input,
    const Tensor& weight,
    int qGroupSize,
    const Tensor& weight_scale_zero_point,
    Tensor& output);

} // namespace at::native::xpu
```
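The `weight_scale_zero_point` argument is read by the kernel as an interleaved [N, K/qGroupSize, 2] buffer: `sptr` starts at offset 0 of each row for scales and `zptr` at offset 1 for zero points. A hypothetical host-side packing helper (invented here for illustration, not part of the PR, and using float rather than the kernel's fp16/bf16) that produces this layout:

```cpp
#include <cstdint>
#include <vector>

// Hypothetical helper: builds the interleaved qScaleAndZeros layout the
// kernel indexes, i.e. [N][num_groups][2] with scale at index 0 and
// zero_point at index 1 for each quantization group (num_groups = K / qGroupSize).
std::vector<float> pack_scale_zeros(
    const std::vector<float>& scales, // [N * num_groups], row-major
    const std::vector<float>& zeros,  // [N * num_groups], row-major
    int64_t N,
    int64_t num_groups) {
  std::vector<float> out(N * num_groups * 2);
  for (int64_t n = 0; n < N; ++n) {
    for (int64_t g = 0; g < num_groups; ++g) {
      out[(n * num_groups + g) * 2 + 0] = scales[n * num_groups + g];
      out[(n * num_groups + g) * 2 + 1] = zeros[n * num_groups + g];
    }
  }
  return out;
}
```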
Review comment: move `uint8_t tmps8[TileK / 2];` outside the loop.
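A sketch of what this suggestion amounts to (hypothetical, not in the merged diff): declare the staging buffer once in the enclosing scope so each unroll iteration reuses it rather than redeclaring it. A fragment of the first main loop in `operator()` with the buffer hoisted:

```cpp
// Fragment of the kernel body with the reviewer's suggestion applied:
// tmps8 is declared once per work-item instead of once per unroll step.
uint8_t tmps8[TileK / 2];
for (int i = 0; i < k; i += GroupK * Unroll) {
#pragma unroll
  for (int iu = 0; iu < Unroll; iu++) {
    *(sycl::vec<uint8_t, TileK / 2>*)tmps8 =
        *(sycl::vec<uint8_t, TileK / 2>*)(bptr + sg_id * TileK / 2);
    // ... remaining loop body unchanged ...
  }
}
```

In practice the compiler will usually promote such a small local array to registers either way, so this is a readability change more than a performance one.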
Review comment: maybe we can do a little template trick to simplify this piece of logic: have one template that handles all scenarios and pass the corresponding arguments when called. Not a must-have, just a little trick.
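One possible reading of that suggestion, as a hypothetical sketch (the name `accumulate_tile` is invented here, not from the PR): the three loop bodies in `operator()` differ only in the tile width, so a helper templated on the tile size could replace the duplication.

```cpp
// Hypothetical refactor sketched from the review comment: one templated
// helper covering the TileK = 32, TileK2 = 8, and tail (TILE_K = 2) paths.
template <int TILE_K, int SG_SIZE, typename scalar_t>
inline void accumulate_tile(
    const scalar_t*& aptr,
    const uint8_t*& bptr,
    const scalar_t*& sptr,
    const scalar_t*& zptr,
    sycl::float2& acc,
    int sg_id,
    int blocksize) {
  using scalarx2_t = sycl::vec<scalar_t, 2>;
  constexpr int GroupK = SG_SIZE * TILE_K;
  uint8_t tmps8[TILE_K / 2];
  *(sycl::vec<uint8_t, TILE_K / 2>*)tmps8 =
      *(sycl::vec<uint8_t, TILE_K / 2>*)(bptr + sg_id * TILE_K / 2);
  scalar_t scale = *(sptr + sg_id * (TILE_K / blocksize) * 2);
  scalar_t zero_point = *(zptr + sg_id * (TILE_K / blocksize) * 2);
#pragma unroll
  for (int ikk = 0; ikk < TILE_K; ikk += 2) {
    scalarx2_t tmpA = *(scalarx2_t*)(aptr + sg_id * TILE_K + ikk);
    scalarx2_t tmpB = {
        static_cast<int8_t>((tmps8[ikk / 2] & 0x0f) - 8),
        static_cast<int8_t>((tmps8[ikk / 2] >> 4) - 8)};
    auto tmpAmulB = tmpA * (tmpB * scale + zero_point);
    acc += {tmpAmulB[0], tmpAmulB[1]};
  }
  // Advance the per-group pointers exactly as the inlined loops do now.
  sptr += (GroupK / blocksize) * 2;
  aptr += GroupK;
  bptr += GroupK / 2;
}
```

Call sites would then reduce to, e.g., `accumulate_tile<32, 16>(aptr, bptr, sptr, zptr, tmpAcc, sg_id, blocksize);` inside the unrolled body, with `<8, 16>` and `<2, 16>` instantiations covering the tail loops.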